본문 바로가기

머신러닝공부

Callback 함수 간단 정리

728x90
반응형

개념

Callback 함수는 1. 특정 상황에서 실행되는 함수를 시스템에 먼저 등록해두면 2. 해당 상황이 발생했을 때 등록되어 있는 함수가 실행되고 3. 시스템에서의 결과를 통해서 개발자는 등록된 콜백 함수가 실행된 것을 알 수 있다.

즉, 콜백 함수라는 것은 특정상황에서 실행될 함수를 시스템에 등록해두면, 그러한 상황이 발생했을때 시스템이 자동으로 실행해주는 함수라고 볼 수 있다. 그래서 TensorFlow에서도 다양한 콜백 함수가 있는데, TensorFlow 콜백 함수는 모델의 학습 방향, 저장 시점 그리고 학습 정지 시점 등에 관한 상황을 모니터링 하기 위해 주로 사용됨

예를들어, 학습 도중에 학습율(learning_rate)을 변화시키거나 일정시간이 지나도 검증데이터(validation data), 손실값(val_loss)이 개선되지 않으면 학습을 멈추게 하는 등의 작업을 할 수 있는데, 이러한 기능을 수행하는 TensorFlow의 대표적인 콜백 함수는

1. 학습 중에 학습율을 변화시킬 수 있는 ReduceLROnPlateau

2. 모델의 가중치(Weight) 값을 중간에 저장할 수 있는 ModelCheckpoint

3. 모델 성능 지표가 일정 시간 동안 개선되지 않을 때 조기 종료 할 수 있게 하는 EarlyStopping 등이 있다.

 

학습율 조정 : ReduceLROnPlateau

모델의 성능 개선이 없을 경우, 학습율(Learning Rate)를 조절해서 모델의 개선을 유도하는 콜백 함수로서 factor 파라미터를 통해 학습율을 조정함(factor<1.0)

from tensorflow.keras.callbacks import ReduceLROnPlateau
reduceLR = ReduceLROnPlateau(monitor = 'val_loss', # val_loss 기준으로 callback 호출
			factor = 0.5, # callback 호출시 학습률을 1/2로 줄임
                     	patience = 5, # epoch 5동안 계선되지 않으면 callback 호출
                       	verbose = 1) # 로그 출력
hist = model.fit(x_train, t_train, epochs=50, validation_split=0.2, callbacks=[reduceLR]) # 콜백함수를 시스템에 등록

모델 가중치 중간 저장 - ModelCheckpoint

모델이 학습 도중에 조건을 만족했을 때 현재의 가중치(weight)를 중간 저장함

학습 시간이 오래 걸린다면, 혹시 중간에 memory overflow나 crash가 나더라도 다시 가중치를 불러와서 학습을 이어나갈 수 있기 때문에, 이러한 중간 저장 기능은 학습시간을 단축시킬 수 있음

from tensorflow.keras.callbacks import ModelCheckpoint

file_path = '.modelchpoint_test.h5' # 저장할 file path
checkpoint = ModelCheckpoint(file_path, # 저장할 file path
			monitor = 'val_loss', # val_loss값이 개선 되었을때 호출
			verbose = 1, # log 출력
                     	save_best_only = True # best 값만 저장
                       	mode = 'auto') # auto는 자동으로 best를 찾음
hist = model.fit(x_train, t_train, epochs=50, validation_split=0.2, callbacks=[checkpoint]) # 콜백함수를 시스템에 등록

학습 조기 종료 : EarlyStopping

모델 성능 지표가 우리가 설정한 epoch동안 개선되지 않았을 때 조기종료할 수 있음

일반적으로 EarlyStopping과 ModelCheckpoint 조항을 통해서, 개선되지 않는 학습에 대한 조기 종료를 실행하고, ModelCheckpoint로 부터 가장 best model을 다시 로드하여 학습을 재게하는 경우가 일반적임

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

file_path = '.modelchpoint_test.h5' # 저장할 file path
checkpoint = ModelCheckpoint(file_path, # 저장할 file path
			monitor = 'val_loss', # val_loss값이 개선 되었을때 호출
			verbose = 1, # log 출력
                     	save_best_only = True # best 값만 저장
                       	mode = 'auto') # auto는 자동으로 best를 찾음
                        
# 콜백함수 정의
stopping = EarlyStopping(monitor = 'val_loss', # 관찰대상은 val_loss
			patience = 5) # 5 epoch 동안 개선되지 않으면 조기종료
                    
hist = model.fit(x_train, t_train, epochs=50, validation_split=0.2, callbacks=[checkpoint, stopping]) # 콜백함수를 시스템에 등록

반응형