사용의 이유

딥러닝에서 가장 흔하게 사용되는 단어인 과적합(Overfitting)을 해결하기위해 나온 방법입니다.

보통은 너무 Train 데이터 셋에 훈련되는 것을 말하는데, Loss를 시각화해서 살펴보면 확실하게 알 수 있습니다. 또한, 저희가 사용하는 YOLO와 같은 곳에서도 어느 특정 부분을 넘어가면 Loss가 일정하게 나오는 것을 볼 수 있는데 이럴경우 계속해서 훈련을 진행하는 것이 의미가 없기도하고, 지속되면 과적합이 발생하기 때문에 모델 성능을 오히려 저하시키는 원인이 될 수 있습니다. 따라서 제작자가 진행할 수 있는한 최대의 에폭시를 사용해서 훈련하라고 할때에도 Early Stopping과 같은 Regularization 기법을 계속해서 적용하여야 합니다.

사실 이러한 과적합을 제거하기 위해 엄청나게 많은 연구들이 지속되고 있는데 가장 대표적으로 Backbone Network 내에서 사용하는 방법 중에서는 Dropout이 존재하고, Network 밖에서는 보통 Early stopping을 사용합니다. 이는 네트워크에는 영향을 주지 않아서 성능향상에는 도움을 주지는 않으나, 성능이 나빠지는 것을 방지할 수 있습니다.

성능을 높였는데 다시 낮아지는 것을 방지하지 못한다면 의미가 없어지기때문에 적절하게 잘 사용하는 것이 중요합니다.

과적합? Overfitting?

과적합은 모델이 너무 과도하게 훈련데이터에만 집착하는 것입니다. 너무 적은 데이터를 사용해서 훈련을 진행하거나, 너무 많은 에폭시를 돌렸거나, 네트워크를 데이터에 비해 너무 과도하게 크게 사용하거나 등의 이유로 다양하게 발생합니다. 과적합이 발생했다는 것은 쉽게 발견할 수 있는데 이는 이미지를 통해서 이해할 수 있습니다.

Untitled

Overfitting의 비교

Overfitting의 비교

위의 이미지에서 특정 지점을 잡아 갑작스럽게 Test Error 율이 떨어지지 않고, 증가하는 것을 볼 수 있습니다. 아래의 이미지에서 최적의 해를 찾은 것과 오버피팅, 언더피팅의 차이를 이해할 수 있습니다.

이를 해결하는 것은 모델을 적합하게 사용하거나, 적절한 에폭시를 사용하거나, 엄청나게 다양하고 많은 데이터를 사용하거나 등등의 방법이 존재하는데 이 글에서는 적절한 에폭시를 사용하도록 만들어주는 Early Stopping에 대해 알아봅니다.

Early Stopping Code 분석

Early Stopping은 다양한 방법으로 사용될 수 있으나, YOLO에서는 손 쉽게 옵션을 설정해서 사용할 수 있습니다.

하지만 코드를 알아두면 어떤 Detector든지, 어떠한 모델이든지에서 사용이 가능하기 때문에 가져왔습니다.

기본적으로 Pytorchtools 라고하는 패키지를 가져와서 사용할 수 있습니다.

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss