일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- CSS
- 순환신경망
- html/css
- 개발일기
- 디버깅
- c언어
- 자바
- 컴퓨터구조
- 상속
- 정보처리기사실기
- JSP
- 중학수학
- 연습문제
- 머신러닝
- 데이터베이스
- rnn
- ChatGPT
- 자바 실습
- 중학1-1
- 딥러닝
- 데이터분석
- 자바스크립트
- JSP/Servlet
- 정보처리기사필기
- 컴퓨터비전
- SQL
- JDBC
- 혼공머신
- 파이썬
- 자바스크립트심화
- Today
- Total
클라이언트/ 서버/ 엔지니어 "게임 개발자"를 향한 매일의 공부일지
다양한 분류 알고리즘 5 - 로지스틱 회귀 심화 학습 본문
조금 전까지는 어제치 학습 분량을 정리했었고 지금부터 쓰는 학습일지는 오늘 분량이다. 오늘도 12개 이상의 학습일지를 올리며 공부를 진행해 볼 것이다.
4장의 첫번째 단원인 로지스틱 회귀에 대해서 공부해 볼 것이다.
로지스틱 회귀 실습해보기
인터넷에서 직접 CSV 데이터를 읽어 들인다. 판다스의 read_csv() 함수로 CSV 파일을 데이터프레임으로 변환한 다음 head() 메서드로 처음 5개 행을 출력했다.
어떤 종류의 생선이 있는지 알기 위해 Species 열에서 고유한 값을 추출한다. 여기에는 생선의 종류가 담겨있다. 이 데이터프레임에서 Species 열을 타깃으로 만들고 나머지 5개 열은 입력 데이터로 사용한다.
이제 데이터를 훈련 세트와 데이터 세트로 나눈다. 그리고 훈련 세트와 테스트 세트를 표준화 전처리한다. 여기서도 훈련 세트의 통계값으로 테스트 세트를 변환해야 한다.
필요한 데이터를 모두 준비했다. k-최근접 이웃 분류기로 테스트 세트에 들어있는 확률을 예측해 보겠다.
최근접 이웃 개수인 k를 3으로 지정하여 사용했다. 타깃 데이터를 만들 때 fish['Species']를 사용해 만들었기 때문에 훈련 세트와 테스트 세트의 타깃 데이터에도 7개의 생선 종류가 들어있다. 깃 데이터에 2개 이상의 클래스가 포함된 문제를 다중 분류라고 한다.
테스트 세트에 있는 처음 5개의 샘플의 타깃값을 예측해 보았다. 5개 샘플에 대한 예측은 어떤 확률로 만들어졌을까? 테스트 세트에 있는 5개 샘플에 대한 확률을 출력해 보면 이와 같다. round() 함수는 기본적으로 소수점 첫째 자리에서 반올림을 하는데, decimals 매개변수로 유지할 소수점 아래 자릿수를 지정할 수 있다. 이렇게 5개의 샘플이 7개 생선에 대한 확률로 출력되었다.
즉, 첫 번째 열이 'Bream'에 대한 확률, 두 번째 열이 'Parkki'에 대한 확률이다. 첫 번째 행에 세 번째 열의 확률을 1로 예측했는데, 이 값은 Perch 값으로 정확하게 예측했음을 알 수 있다.
이 모델이 계산한 확률이 가장 가까운 이웃의 비율과 맞는지 확인해 볼 것이다.
출력 결과 위의 3개의 최근접 이웃을 사용하기 때문에 가능한 확률은 0/3, 1/3, 2/3, 3/3이 전부이다. 뭔가 더 좋을 방법을 찾아보자.
넘파이를 사용하면 로지스틱 회귀 그래프를 간단히 그릴 수 있다. -5와 5 사이에 0.1 간격으로 배열 z를 만든 다음 z 위치마다 시그모이드 함수를 계산한다. 지수 함 수 계산은 np.exp() 함수를 사용한다. 사이킷런에는 로지스틱 회귀 모델인 LogisticRegression 클래스가 준비되어 있다. 시그모이드 함수는 정말로 1에서 1까지 변한다.
훈련하기 전에 이진 분류를 수행해 보겠다. 이진 분류일 경우 시그모이드 함수의 출력이 0.5보다 크면 양성 클래스, 0.5보다 작으면 음성 클래스로 분류한다.
훈련한 모델을 사용해 train_bream_smelt에 있는 처음 5개의 샘플을 예측해 보았다. 두 번째 샘플을 제외하고는 모두 도미로 예측했다. 샘플마다 2개의 확률이 출력되었다. 첫 번째 열이 음성 클래스(0)에 대한 확률이고 두 번째 열이 양성 클래스(1)에 대한 확률이다. 이 둘 중 어떤 것이 양성 클래스일까?
classes_ 속성을 통해 타깃값을 알파벳순으로 정렬했다. predict_proba() 메서드가 반환한 배열값을 보면 두 번째 샘플만 양성 클래스인 빙어의 확률이 높다. 나머지는 모두 도미(Bream)로 예측할 것이다. 이제 선형 회귀에서처럼 로지스틱 회귀가 학습한 계수를 확인해 보자.
이 로지스틱 회귀 모델이 학습한 방정식은 다음과 같다.
z = -0.404 x (weight) - 0.576 x (Length) - 0.663 x (Diagonal) - 1.013 x (Height) - 0.732 x (width) - 2.161
확실히 로지스틱 회귀는 선형 회귀와 매우 비슷하다. 이 모델로 z값도 계산해 보았다. LogisticRegression 클래스는 decision_function()에서 메서드로 z값을 출력할 수 있다. train_bream_smelt의 처음 5개 샘플의 z값을 출력했다. 이 z 값을 시그모이드 함수에 통과시키면 확률을 얻을 수 있다.
이제 이진 분류의 경험을 바탕으로 7개의 생선을 분류하는 다중 분류 문제로 넘어가 보겠다.
훈련 세트와 테스트 세트에 대한 점수가 높고 과대적합이나 과소적합으로 치우친 것 같지 않다. 이제 5개의 샘플에 대한 예측 확률을 출력했다. 이진 분류일 경우 2개의 열만 있었다는 것을 기억하자.
첫 번째 샘플을 보면 세 번째 열의 확률이 가장 높다. 84.1%나 된다. 세 번째 열이 농어(Perch)에 대한 확률인지 확인해 보자. classes_ 속성에서 정보를 확인할 수 있다.
농어가 맞다. 첫 번째 샘플은 Perch를 가장 높은 확률로 예측했다. 두 번째 샘플은 여섯 번째 열인 Semlt를 가장 높은 확률로 예측했다. 이진 분류는 샘플마다 2개의 확률을 출력하고 다중 분류는 샘플마다 클래스 개수만큼 확률을 출력한다.
coef_와 intercept_의 크기도 출력해 보았다. 이 데이터는 5개의 특성을 사용하므로 coef_ 배열의 행렬은 5개이다. intercept_도 7개나 있다. 이 말은 z를 7개나 계산한다는 의미이다. 다중 분류는 클래스마다 z 값을 하나씩 계산한다. 가장 높은 z 값을 출력하는 클래스가 예측 클래스가 된다. 그럼 확률을 어떻게 계산한 것일까?
이진 분류에서는 시그모이드 함수를 사용해 z를 0과 1 사이의 값으로 변환했다. 다중 분류는 이와 달리 소프트맥스 함수를 사용하여 7개의 z값을 확률로 변환한다.
먼저 7개의 z값의 이름을 z1에서 z7이라고 붙이겠다. 이를 지수함수를 사용해 모두 더한다. 그런 다음 e_sum으로 나누어주면 된다. s1에서 s7까지 모두 더하면 분자와 분모가 같아지므로 1이 된다. 7개 생선에 대한 확률의 합은 1이 되어야 하므로 잘 맞다.
이진 분류에서처럼 decision_funciton() 메서드로 z1~z7까지의 값을 구한 다음 소프트맥스 함수를 사용해 확률로 바꾸어 본다. 테스트 세트의 처음 5개 샘플에 대한 z1~z7의 값을 구한다.
앞서 구한 decision 배열을 softmax() 함수에 전달했다. softmax()의 axis 매개변수는 소프트맥스를 계산할 축을 지정한다. 여기서는 axis=1로 지정하여 각 행, 즉 각 샘플에 대해 소프트맥스를 계산했다.
학습을 마치고
마지막에 소프트맥스 함수를 구하는 건 많이 어려웠다. 그래도 처음 학습할 때보다는 이해를 많이 한 편이다. 너무 깊이 있는 건 알 필요가 없고 이 정도만 해도 충분한 것 같다. 뒤에 딥러닝 학습을 할 때 이 내용이 똑같이 등장한다고 한다. 그때 더 자세히 이해하기로 했다.
로지스틱 회귀는 회귀 모델이 아닌 분류 모델이기에 선형 회귀처럼 선형 방정식을 사용한다. 다중 분류일 경우에는 클래스 개수만큼 방정식을 훈련하고.. 아무튼 여러 가지를 많이 배웠다.
'인공지능 > 머신러닝' 카테고리의 다른 글
트리 알고리즘 7 - 결정 트리 심화 학습 (1) | 2024.10.02 |
---|---|
다양한 분류 알고리즘 6 - 확률적 경사 하강법 심화 학습 (0) | 2024.10.02 |
회귀 알고리즘과 규제 모델 9 - 특성 공학과 규제 심화 학습 (0) | 2024.10.01 |
회귀 알고리즘과 모델 규제 8 - 선형 회귀 심화 학습 (0) | 2024.10.01 |
회귀 알고리즘과 모델 규제 7 - k-최근접 이웃 회귀 심화 학습 (0) | 2024.10.01 |