본문 바로가기

머신러닝공부

Fashion MNIST 딥러닝 스크립트 예제

728x90
반응형

Fashion MNIST

손글씨가 아닌 옷과 신발 등의 흑백 이미지로서 MNIST보다는 좀 더 어려운 문제로 평가되고 있음

데이터 정의 -> 데이터 전처리 -> 모델 구축 -> 모델 컴파일 -> 학습 -> 모델 평가

# 데이터정의 (데이터 불러오기 및 확인)
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import fashion_mnist

(x_train, t_train), (x_test, t_test) = fashion_mnist.load_data()
print(학습데이터 정답)
print(테스트 데이터 정답)

import matplotlib.pyplot as plt
plt.figure(figsize = (6,6))
for index in range(25):
	plot.subplot(5, 5, index+1)
    plt.imshow(x_train[index], cmp='gray')
    plt.axis('off')
plt.show()

# 데이터 전처리
x_train = (x_train - 0.0) / (255.0 - 0.0) # 학습 데이터 정규화
x_test = (x_test - 0.0) / (255.0 - 0.0) # 테스트 데이터 정규화

# 모델 구축
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) # 입력층
model.add(tf.keras.layers.Dense(100, activation = 'relu')) # 은닉층
model.add(tf.keras.layers.Dense(10, activation = 'softmax')) # 출력층, softmax는 이미지 분류 활성화 함수

# 모델 컴파일
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3), loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 정답이 원핫 오프닝 방식으로 설정하지 않았으므로 loss에는 sparse_categorical_crossentropy가 온다.
model.summary()

# 모델 학습
hist = model.fit(x_train, t_train, epochs=30, validation_split=0.3)

# 모델 (정확도) 평가
model.evaluate(x_test, t_text)

# 손실 및 정확도 시각화
plt.title('Loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.grid()

plt.plot(hist.history['loss'], label='train loss')
plt.plot(hist.history['val_loss'], label = 'validation loss')
plt.legend(loc='best')
plt.show()

plt.title('Accuracy')
plt.xlabel('epchos')
plt.ylabel('accuracy')
plt.grid()

plt.plot(hist.history['accuracy'], label = 'train accuracy')
plt.plot(hist.history['val_accuracy'], label='validation accuracy')
plt.legend(loc='best')
plt.show()

# 혼동 행렬
from sklearn, metrics import confusion_matrix
import seaborn as sns

plt.figure(figsize=(6,6))
predicted_value = model.predict(x_test)
cm = confusion_matrix(t_test, np.argmax(predicted_value, axis=-1))
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
반응형