일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- 혼공머신
- 중학1-1
- 자바
- 연습문제
- 컴퓨터비전
- 텍스트마이닝
- 정수와유리수
- 딥러닝
- pandas
- CNN
- C++
- JSP/Servlet
- 컴퓨터구조
- CSS
- 머신러닝
- 자바 실습
- 데이터입출력구현
- 정보처리기사필기
- 파이썬라이브러리
- 데이터분석
- 중학수학
- html/css
- 정보처리기사실기
- 운영체제
- 코딩테스트
- numpy/pandas
- SQL
- 데이터베이스
- 파이썬
- 영어공부
- Today
- Total
클라이언트/ 서버/ 엔지니어 "게임 개발자"를 향한 매일의 공부일지
합성곱 신경망 4 - 복잡한 모델 생성 1 : Functional API로 다중 입력, 다중 출력 레이어 생성 및 다중 출력 분류 모델 본문
합성곱 신경망 4 - 복잡한 모델 생성 1 : Functional API로 다중 입력, 다중 출력 레이어 생성 및 다중 출력 분류 모델
huenuri 2024. 10. 29. 06:02이번 단원도 무척 길어서 두 개의 포스트로 나누어 작성할 예정이다. 단순한 모델 생성도 많이 어려웠는데 이번에는 복잡한 모델을 생성해 볼 것이다.
텐서플로 케라스의 Functional API를 사용하면 Sequential API로 구현할 수 없는 복잡한 구조의 모델을 정의할 수 있다. 예를 들면, 각 레이어를 기준으로 입력이 2개 이상이거나 출력이 2개 이상인 모델을 만들 수 있다. 또는 중간에 있는 레이어들을 건너뛰고 뒤쪽에 있는 레이어로 출력 텐서를 전달하는 방식으로 직접 연결하는 방법도 가능하다.
데이터셋 준비
필요한 라이브러리와 mnist 데이터셋을 불러와서 훈련 데이터셋과 검증 데이터셋으로 구분하여 저장한다.
원본 mnist 데이터셋은 0~9까지 숫자를 손글씨로 쓴 이미지(x)와 정답(y) 숫자로 구성되어 있다. 우리는 0~9까지 숫자를 맞추는 정답 외에 홀수인지 짝수인지를 판단하는 정답(y_odd)을 새로 만들어서 추가할 것이다. 다음과 같이 y_train 값이 홀수이면 1, 짝수이면 0으로 하는 y_train_odd 배열을 새로 만든다.
원본 y_train 배열과 홀짝으로 분류한 y_train_odd 배열을 동시에 출력하여 비교해 본다. 홀수 짝수 여부가 잘 정리된 것을 확인할 수 있다.
검증 데이터셋에 대해서도 홀수, 짝수를 나타내는 y_valid_odd 배열을 만든다.
입력 이미지 데이터를 255로 나눠서 정규화한다. 색상이 하나인 모노컬러 이미지이므로, 새로운 축을 추가하고 채널 개수를 1개로 지정한다. 텐서플로 expand_dims() 함수에 원본 배열을 입력하고, 새롭게 추가하려는 축의 인덱스를 지정하면 해당 인덱스에 새로운 축이 추가된다. 여기서는 축의 인덱스로 -1을 지정했기 때문에 새로운 축은 끝에 추가된다.
즉 (60000, 28, 28) 배열이 (60000, 28, 28, 1) 배열이 된다.
Functional API로 다중 입력, 다중 출력 레이어 생성
Functional API를 사용하면 사용자가 원하는 복잡한 구조의 모델을 만들 수 있다. Sequential API의 경우 레이어를 층층이 한 줄로 연결할 수밖에 없지만, Functional API를 사용하면 다중 입력 또는 다중 출력 같은 구조를 만들 수 있다.
다음 코드에서 입력 레이어 inputs의 경우, Conv2D 레이어와 Flatten 레이어의 입력으로 사용된다. 따라서 입력 레이어는 2개의 출력을 갖게 된다. 서로 다른 2개의 출력은 각각 다른 레이어의 입력으로 사용되고, 최종적으로 Concatenate 레이어에서 합쳐진 다음에 Dense 레이어를 통과한다. 이처럼 함수의 입력과 출력으로 표현할 수 있어 자유롭게 모델 구조를 정의할 수 있는 장점이 있다.
앞의 정의한 모델 구조를 그래프로 출력하면 다음과 같다. 케라스 utils 모듈의 plot_model 함수를 사용하고, 모델을 입력하면 모델 구조를 그려준다. 별도의 파일로 저장할 수도 있다.
입력 레이어가 2개의 출력으로 나누어지고, 마지막 Dense 레이어를 통과하기 전에 Concatenate 레이어에서 하나로 합쳐지는 구조를 시각적으로 확인할 수 있다.
다중 분류 모델에 맞게 손실함수와 평가지표를 지정하고, 옵티마이저로는 adam을 사용한다. 10 epoch에 대한 모델 훈련을 마치고 검증 데이터에 대한 모델의 예측 성능을 평가한다. 10개의 숫자 레이블을 맞추는 모델의 정확도는 약 98%로서 앞에서 학습했던 Sequential 모델과 큰 차이는 없다.
다중 출력 분류 모델
이번에는 두 가지 서로 다른 분류 문제를 예측하는 다중 출력 모델을 만들어 본다. 앞에서 데이터셋을 불러온 후에 홀수, 짝수 정답 배열을 추가해 주었다. 다음은 0~9까지 숫자를 맞추는 분류 문제와 홀수, 짝수를 맞추는 분류 문제를 동시에 풀어내는 모델을 정의하는 코드다. 각기 다른 문제에 맞도록 최종 분류기의 출력 레이어를 2개 만드는 것이 핵심이다.
즉, 다음 모델은 입력 이미지를 하나 받아서 해당 손글씨가 어떤 숫자인지를 분류하고 홀수인지 여부도 함께 분류하는 문제가 된다.
모델의 입력 텐서와 출력 텐서를 화면에 표시한다. 입력은 하나이고, 출력은 10개인 텐서와 1개인 텐서로 두 개가 확인된다.
모델 구조를 그림으로 그리면 다음과 같다.
모델을 컴파일할 때 주의할 내용이 있다. 앞서 모델을 정의할 때 각각의 출력 Dense 레이어에 name 속성을 digit_dense와 같은 이름을 지정한 것을 떠올리자. 여기서 지정한 이름을 key로 하고, 해당 key에 해당하는 레이어에 적용할 손실함수와 가중치를 딕셔너리 형태로 지정한다.
모델을 훈련시키기 위해 fit() 메서드를 적용할 때도 출력 값을 2개 지정하고, 레이어 이름 속성을 key로 하는 딕셔너리 형태로 각각의 출력에 맞는 정답 배열을 입력해야 한다.
이번에도 책에 틀린 코드가 있어서 수정해주었다. 실수 형태인 float으로 전달하는 것과 평가 지표를 리스트 형태로 만드는 것으로 수정했다.
모델의 성능을 평가해보자. 숫자를 맞추는 문제는 98%의 정확도를 보이는 반면, 홀수 여부를 판단하는 문제는 약 89%의 정확도를 보인다.
검증 데이터셋의 인덱스에 해당하는 이미지를 출력하면 숫자 7에 대한 손글씨 이미지다.
검증 데이터셋의 모든 이미지 데이터를 입력해서 2개의 분류 문제에 대한 예측 확률을 구하면 다음과 같다. 첫 번째 예측 값은 10개 분류 레이블에 대한 확률을 담고 있고, 두 번째 예측 값은 홀수일 확률을 담고 있다. 숫자 8 이미지에 대한 예측 확률을 print 함수로 출력하면 다음과 같다.
넘파이 argmax 함수를 이용하여 예측 확률을 실제 정답 레이블로 변환한다. 검증 데이터셋의 첫 10개 이미지에 대한 예측 레이블은 다음과 같고 첫 번째 이미지를 7로 예측하고 있다.
홀수, 짝수 여부에 대한 예측 레이블을 출력하면 다음과 같다. 임계값으로 0.5를 지정했는데, 홀수일 확률이 0.5보다 큰 경우 홀수로 분류하기로 한다. 첫 번째 이미지 7에 대하여 홀수(4)로 잘 분류하고 있는 것을 확인할 수 있다. 두 번째 샘플 이미지에 대해서는 숫자 2에 해당하는 짝수(0)로 정확하게 분류하고 있다.
학습을 마치고
홀수와 짝수를 판별하는 숫자 이미지에 대해서 모델을 생성해서 출력하는 학습을 진행했다. 이번에는 코드 오류가 많지 않아서 그런지 이전의 단순 모델보다 훨씬 더 이해하기 괜찮았다. 다음 포스트에서는 전이 학습 부분을 학습해 볼 것이다.
'인공지능 > 딥러닝' 카테고리의 다른 글
합성곱 신경망 6 - 위성 이미지 분류 1 : 데이터셋 로드 및 전처리와 모델 훈련 및 검증 (0) | 2024.10.29 |
---|---|
합성곱 신경망 5 - 복잡한 모델 생성 2 : 전이 학습 (0) | 2024.10.29 |
합성곱 신경망 3 - 간단한 모델 생성 2 : 모델 구조 파악 (0) | 2024.10.29 |
합성곱 신경망 2 - 간단한 모델 생성 1 : 데이터 로드 및 전처리와 Sequential API로 모델 생성 및 학습 (0) | 2024.10.28 |
합성곱 신경망 1 - 합성곱 신경망에 대하여 (2) | 2024.10.28 |