관리 메뉴

클라이언트/ 서버/ 엔지니어 " 게임 개발자"를 향한 매일의 공부일지

회귀 알고리즘과 모델 규제 7 - k-최근접 이웃 회귀 심화 학습 본문

인공지능/머신러닝

회귀 알고리즘과 모델 규제 7 - k-최근접 이웃 회귀 심화 학습

huenuri 2024. 10. 1. 21:36

지도학습의 한 종류인 회귀 모델을 학습해볼 것이다. 코드 하나 하나 제대로 읽으면서 진짜 공부를 시작해본다.


 

 

 

k-최근접 이웃 회귀 실습해보기

 

 

데이터가 어떤 형태를 띄고 있는지 산점도를 그렸다. 하나의 특성을 사용하기 대문에 특성 데이터를 x에 놓고, 타깃 데이터를 y축에 놓는다. 농어의 길이에 따라 무게도 늘어나는 것을 볼 수 있다. 이 실습은 농어의 무게 단위로 가격을 책정하는 모델을 만드는 문제이다.

 

 

훈련 세트와 테스트 세트로 나누었다. 사이킷런에 사용할 훈련 세트는 2차원 배열이어야 한다. perch_length가 1차원 배열이기 때문에 이를 나눈 train_input과 test_input도 1차원 배열이다. 이것을 2차원 배열로 바꾸어야 한다.

지난 장에서는 2개의 특성을 사용했기 때문에 자연스럽게 열이 2개인 2차원 배열을 사용했다. 이번 예제에서는 특성을 1개만 사용하므로 수동으로 2차원 배열을 만들 것이다.


 

 

원본 배열의 원소는 4개인데 6개로 바꾸려고 하자 오류가 발생한다. 이제 reshape() 메서드를 사용해 train_input과 test_input을 2차원 배열로 바꾸겠다.

 

 

 

(42,)인 1차원 배열에서 (42, 1)인 2차원 배열로 변경되었다. 넘파이는 배열의 크기를 자동으로 지정하는 기능도 제공하는데 크기에 -1을 지정하면 나머지 원소 개수로 모두 채운다. 예를 들어 첫번째 크기를 나머지 원소로 채우고, 두번째 크기를 1로 하려면 (-1, 1)로 사용한다.


 

 

 

 

회귀에서는 정확한 숫자를 맞힌다는 것이 거의 불가능하다. 예측하는 값이나 타깃 모두 임의의 수치이기 때문이다. 회귀의 경우 조금 다른 값으로 평가하는데 이 점수를 결정계수라고 부른다.

만약 타깃의 평균 정도를 예측하는 수준이라면 R²은 0에 가까워지고, 예측이 타깃에 아주 가까워지면 1에 가까운 값이 된다. 0.99면 좋은 값이다.

타깃과 예측 한 값 사이의 차이를 구해보면 어느 정도 예측이 벗어났는지 가늠할 수 있다. 

 

 

 

결과에서 예측이 평균적으로 19g 정도 타깃값과 다르다는 것을 알 수 있다. 지금까지는 훈련 세트를 사용해 모델을 훈련하고 테스트 세트로 모델을 평가했다. 그런데 훈련 세트를 사용해 평가해 보면 어떨까?


 

 

 

 

훈련 세트에 적합하도록 만들어진 게 과대적합이고, 테스트 세트나 둘다 적합하지 않게 만들어진 게 과소적합이다. 훈련 세트보다 테스트세트의 점수가 높으니 과소적합이다. 이 문제를 해결하기 위해 이웃의 개수 k를 줄인다.


 

 

 

k값을 줄였더니 훈련 세트의 R² 점수가 높아졌다. 이제 과소적합의 문제를 해결했다.


 

 

 

 

5에서 45까지의 길이에 대한 예측 값을 선으로 표시한다. prediction은 n 값에 따라 예측한 결과이다.

 

 

 

이 코드는 KNN 회귀 모델을 이용하여 농어의 길이(5에서 45까지)를 입력받고, 그에 따른 무게를 예측한다. k 값을 1, 5, 10으로 바꿔가며 각각의 예측 결과를 시각화한다. 예측 값은 x 범위에 따라 선으로 나타낸다.

이 과정은 KNN 회귀 모델의 성능을 시각적으로 확인하는 데 유용하며, k 값에 따라 모델의 예측이 어떻게 변하는지 보여줍니다. 이제 k값이 1일 때와 5일 때도 확인해볼 것이다.


 

 


 

 

 

산점도 그래프를 확인해보니 k값이 커질수록 뭔가 달라지는 변화가 보인다. k값이 작을 때는 과대적합 가능성이 크다. 모델이 훈련 데이터에 지나치게 반응하는 것이다. 하지만 k값이 커질수록 모델의 복잡도가 감소한다. 즉, 모델이 훈련 데이터의 세부적인 변동에 덜 민감해진다.

더 많은 이웃의 평균을 사용하여 예측을 수행하므로, 전체적인 패턴을 반영하게 된다. 이웃이 많아질수록 더 많은 데이터 포인트를 고려하므로, 더 부드럽고 안정적인 예측이 가능하다는 것을 이 그래프를 통해 볼 수 있다.

 


 

 

 

학습을 마치고

확실히 2차 학습을 해보니 이전에는 보이지 않던 부분들이 잘 보였다. 그냥 대충 이런 게 있구나 하고 넘어갔던 것들도 보였고 이해를 하고 학습을 할 수 있었다. 그리고 확인 문제에서도 그냥 코드만 찍고 말았는데 코드를 분석할 수 있는 능력이 길러지는 것 같다.

다시 학습을 해보길 정말 잘했다는 생각이 든다.