2 분 소요

목표

머신러닝을 이용하여 도미와 빙어의 분류 기준을 알려주지 않고 데이터와 정답만 알려주고 스스로 분류기준을 찾아 생선의 길이와 무게 데이터가 들어왔을 때 이것이 도미인지 빙어인지 분류하게 한다.

사용 데이터

https://www.kaggle.com/aungpyaeap/fish-market

이진분류

머신러닝에서 여러개의 종류(클래스) 중 하나를 구별해 내는 문제를 분류라고 부른다. 2개의 클래스 중 하나를 분류한는 문제를 이진 분류라고 한다.

도미와 빙어 분류 프로그램 만들기

도미 데이터 입력

저울로 잰 도미의 길이와 무게를 파이썬 리스트로 만든다. 아래 코드는 일일이 타이핑할 필요 없이 bream_list.py 에서 가져올 수 있다.

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

도미 데이터 산점도

산점도는 x,y축 좌표계에 두 변수(x,y)의 관계를 나타내는 방법이다. 여기서는 길이를 x축으로 하고 무게를 y축으로 한다.

import matplotlib.pyplot as plt # 그래프를 그리는 대표적인 패키지

plt.scatter(bream_length, bream_weight) #bream_length를 x축, bream_weight를 y축으로하여 산점도를 그림
plt.xlabel('length') # x축의 이름을 length로 정함
plt.ylabel('weight') # y축의 이름을 weight로 정함
plt.show() # 그래프를 화면에 출력

산점도는 아래와 같다. pic1

생선의 길이가 길수록 대체로 무게도 많이 나간다. 산점도 그래프가 직선 형태에 가깝게 나타나는 경우를 선형 적이라고 한다.

빙어 데이터 입력

저울로 잰 빙어의 길이와 무게를 파이썬 리스트로 만든다. 아래 코드는 일일이 타이핑할 필요 없이 smelt_list.py 에서 가져올 수 있다.

smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

도미와 빙어 데이터 산점도

scatter 함수를 연속으로 사용하면 여러 산점도를 한번에 그릴 수 있따.

plt.scatter(bream_length,bream_weight)
plt.scatter(smelt_length,smelt_weight)
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

pic2

matplotlib가 알아서 빙어와 도미의 데이터를 색깔로 구분시켜준다.

그러나 이 그래프를 처음 보면 어떤 것이 도미이고 빙어인지 구분하기 어렵다. 따라서 책에는 없지만 legend를 추가시켜주었다.

plt.scatter(bream_length,bream_weight,label='bream')
plt.scatter(smelt_length,smelt_weight,label='smelt')
plt.legend()
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

pic5

두 리스트를 하나로 합치기

length = bream_length + smelt_length
weight = bream_weight + smelt_weight

이제 length와 weight 앞쪽 35개에는 도미의 데이터가, 뒤쪽 14개는 빙어의 데이터가 들어있다.

머신러닝 패키치 사이킷런에서 학습 데이터를 넣으려면 2차원 리스트가 필요하다.

[[L1,W1],[L2,W2],[L3,W3]...[Ln,Wn]]
#L1,L2,L3... ->길이
#W1,W2,W3... ->무게

이것을 만드는 가장 쉬운 방법은 zip 함수와 리스트 내포 구문을 사용하는 것이다.

fish_data = [[l,w] for l,w in zip(length,weight)]

fish_data의 값은 아래 사진과 같다.

pic3

정답 데이터 만들기

fish_data의 앞쪽 35개의 데이터는 도미, 뒤쪽 14개의 데이터는 빙어이다. 머신러닝 알고리즘에 어떤 것이 도미이고 빙어인지 알려줘야한다. 도미는 1 빙어는 0으로 하면

fish_target = [1]*35 +[0]*14
print(fish_target)

위와 같이 fish_target이라는 정답 데이터를 만들 수 있다.

학습시키기

사이킷런에서 k-최근접 이웃 알고리즘 클래스인 KNeighborsClassifier를 임포트하고 클래스의 객체를 만든다.

from sklearn.neighbors import KNeighborsClassifier
kn = KNeighborsClassifier() 

이제 fit 함수로 데이터를 학습시킨다.

kn.fit(fish_data,fish_target)

아래의 score함수로 학습의 정확도를 보면 1.0이다.

kn.score(fish_data,fish_target)

아마 도미와 빙어의 겹치는 부분이 없어서 그런 것 같다. 위의 그래프를 보면 파란점과 주황색 점이 겹치지 않는다.

KNeighborsClassifier

이 알고리즘은 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 한다. 사이킷런에서는 기본 주변 5개를 보고 판단한다. 즉 k최근접 알고리즘의 준비과정은 데이터를 모두 가지고 있는 것이다

새로운 데이터가 들어오면 가지고 있는 데이터와의 직선 거리를 비교하여 가장 가까운 n개의 데이터를 보고 어떤 데이터인지 판단하면 된다.

따라서 데이터를 다 메모리에 저장하고 있어야하므로, 데이터가 크다면 이 알고리즘은 사용하기 어렵다. 실제로 데이터를 저장하고 있는지 다음 함수를 통해 확인할 수 있다.

print(kn._fit_X)
print(kn._y)

pic4

머신러닝 프로그램을 이용하여 예측하기

kn.predict([[30,600]])

위의 코드를 돌리면 도미(1)로 예측한다. 실제로 산점도에서 (30,600) 근처에는 도미가 많다.

근처 49개의 데이터로 판단하기

kn49 = KNeighborsClassifier(n_neighbors=49)

위의 코드로 가장 가까운 데이터 49개를 사용하는 모델을 만들 수 있다. 즉 근처 49개의 데이터를 보고 도미인지 빙어인지 판단한다. 그러나 입력된 데이터가 총 49개이고, 도미가 35개 빙어가 14개 이므로 항상 도미로 예측할 것이다.

kn49.fit(fish_data,fish_target)
kn49.score(fish_data,fish_target)

실제로 위 코드를 돌리면 정확도가 약 0.714(넷째자리에서 반올림)인데, 이는 35/49의 값과 같다. 따라서 적절한 n_neighbors 값을 찾는 것이 중요하다.

Reference

박해선 저,<혼자 공부하는 머신러닝 + 딥러닝>, 한빛미디어, 2020

댓글남기기