안녕하세요.

이번 포스팅에서 소개해드릴 내용은 신경망 모델의 warm-starting 방법입니다.

Model warm-starting from old one after model reconstruction


warm-starting이란 처음부터 시작하지 않는다는 거죠, 완전하게 task에 맞추어 훈련이 되지는 않았지만 모델의 파라미터의 일부가 유의미한 정보(trained)를 가지고 있는 채로 훈련을 시작하는 것을 의미합니다. 일반적으로 기존 모델에 추가할 수 있는 새로운 정보가 있다면, 새로운 모듈을 통해 그 정보를 기존 모델에 통합해야합니다. 모델에 새로운 모델을 덧붙인다고 표현할 수 있습니다. 예를 들어, 기존 모델 A에 새로운 모듈 α를 추가하여 새로운 모델 B ( = A + α) 구축한 후, B 훈련 시 기존 모델 A의 parameter는 A 모델의 checkpoint로 로드하고, α 부분은 랜덤 초기화시키는 방법입니다.


새로운 모델을 추가한 후! 만약 모델이 너무 크거나, 훈련 시간이 오래 걸리거나, 추가한 모듈이 크면 처음부터(from scratch) 훈련하는 것이 부담스럽습니다. 시간, GPU 제한 등등 (하지만 요즘은 무료로 GPU 자원을 제공해주는 곳도 많기에 시간을 절약한다고 보는 것이 더 맞을 수 있습니다.) 예를 들어, Transformer 에 BERT 모델을 덧붙인다(?)고 가정하면 Transformer에 BERT 모델의 final hidden state를 끼워넣기 위해 새로운 모듈을 추가해야 합니다. 생각해보면 base Transformer (6 encoder layer + 6 decoder layer, model dim 512)와 base BERT (12 encoder layer, model dim 768)을 합쳐서 fine-tuning 한다는 말인데 제가 사용할 수 있는 GPU(V100, P100, GTX 1080 TI)에서는 바로 OOM 에러가 발생합니다. 그럼 BERT를 freezing한다고 가정하면, OOM 에러는 없이 training이 가능합니다. 하지만 Transformer에 BERT가 붙었다면 최소 2배 이상 훈련 속도가 떨어집니다.
그럼 대략적으로 훈련 시간을 계산해보겠습니다. vanillar Transformer라면 1일 정도면 50 epoch까지 훈련할 수 있지만, (Transformer + BERT)를 epoch 0부터 모델을 훈련하여 모델이 수렴할 정도 (epoch 50 ~100)까지 훈련한다하면 최소 2~3일(epoch 100 훈련하려면 5일 이상) 정도가 걸립니다. (실제 모델 훈련 경험에서 기다림으로 체득한 정보입니다.. GPU P100, 훈련 데이터 160M, 참고로 BERT 모델의 parameter는 freezing)


그럼 어떻게 warm-starting을 할 수 있을까요?

pytorch 공식 홈페이지에서 제공하는 모델 save/load는 다음과 같습니다. 

#Save:
torch.save(model.state_dict(), PATH)

#Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

간단합니다. 모델의 Parameter를 dict으로 저장하고 model의 내부 network 선언명에 따라 Parameter를 load합니다.

예제를 통한 설명 전 실행 환경 정보입니다.


실행 환경 (Google colab pro)

Pytorch 1.7.1 cuda
Python 3.6.9
NAME="Ubuntu"
VERSION="18.04.5 LTS (Bionic Beaver)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 18.04.5 LTS"
VERSION_ID="18.04"

그럼 작고 아름다운 모델을 구축하겠습니다.

[Code]

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from  torch.nn import Parameter


# toy feed-forward net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(10, 3)
        self.fc2 = nn.Linear(3, 3)
        # self.fc3 = nn.Linear(3, 3)        
        # self.fc4 = nn.Linear(3, 3)
        self.fc5 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        # x = self.fc3(x)
        # x = self.fc4(x)
        x = self.fc5(x)
        return x


net = Net()
net.cuda()


# define new random data (input&label)
random_input = Variable(torch.randn(10,)).cuda()
random_target = Variable(torch.randn(1,)).cuda()

# loss
criterion = nn.MSELoss()

# optimizer
optimizer = optim.Adam(net.parameters(), lr=0.1)


for epoch in range(1,10):
    net.zero_grad()

    if epoch == 1:
      print('fc1 parameter before model training (initializing)')
      print("net.fc1.weight: ", net.fc1.weight)
      print("net.fc1.bias: ", net.fc1.bias) 
                
    output = net(random_input)
    loss = criterion(output, random_target)
    loss.backward()
    optimizer.step()

print("*"*89)
print('fc1 parameter after model training (trained)')
print("net.fc1.weight: ", net.fc1.weight)
print("net.fc1.bias: ", net.fc1.bias) 


torch.save(net.state_dict(), "/content/net.pt")

3개의 fully connected layer로 구성된 모델입니다. 랜덤 데이터를 통해 10 epoch 동안 학습하였고, 출력된 것은 모델의 첫번째 fully connected layer (fc1)의 weight과 bias 입니다. (주의하실 점은 fc3, fc4를 추가하여 새로운 모델로 사용할 것이기 때문에 fc3과 fc4를 모델에서 주석 처리하여 뺏습니다.) 그리고 모델을 훈련 및 저장합니다. 그대로 코드를 실행하신다면 저장 루트를 바꿔주세요.

[Output]

fc1 parameter before model training (initializing)
net.fc1.weight:  Parameter containing:
tensor([[ 0.2413,  0.1984, -0.1557,  0.2183, -0.0026,  0.2063,  0.0356,  0.2401,
          0.0841,  0.0411],
        [ 0.2648,  0.0499,  0.1177,  0.0172, -0.0913, -0.1910,  0.1510,  0.0115,
         -0.2350, -0.0122],
        [-0.0457,  0.0920, -0.1096,  0.2145,  0.1477, -0.0229, -0.2680, -0.1000,
          0.2053, -0.2606]], device='cuda:0', requires_grad=True)
net.fc1.bias:  Parameter containing:
tensor([-0.0475, -0.2888,  0.0440], device='cuda:0', requires_grad=True)
*****************************************************************************************
fc1 parameter after model training (trained)
net.fc1.weight:  Parameter containing:
tensor([[-0.0752, -0.1181, -0.4722,  0.5348, -0.3191, -0.1102,  0.3521, -0.0764,
         -0.2324,  0.3576],
        [ 0.5753,  0.3604,  0.4282, -0.2933,  0.2192,  0.1195, -0.1595,  0.3220,
          0.0755, -0.3227],
        [ 0.1871,  0.3248,  0.1233, -0.0184,  0.3806,  0.2100, -0.5009,  0.1329,
          0.4382, -0.4934]], device='cuda:0', requires_grad=True)
net.fc1.bias:  Parameter containing:
tensor([ 0.2690, -0.5993, -0.1888], device='cuda:0', requires_grad=True)

fc1의 parameter를 출력했습니다.


저장된 모델을 출력해보겠습니다.

[Code]

torch.load("/content/net.pt")

[Output]

OrderedDict([('fc1.weight',
              tensor([[-0.0752, -0.1181, -0.4722,  0.5348, -0.3191, -0.1102,  0.3521, -0.0764,
                       -0.2324,  0.3576],
                      [ 0.5753,  0.3604,  0.4282, -0.2933,  0.2192,  0.1195, -0.1595,  0.3220,
                        0.0755, -0.3227],
                      [ 0.1871,  0.3248,  0.1233, -0.0184,  0.3806,  0.2100, -0.5009,  0.1329,
                        0.4382, -0.4934]], device='cuda:0')),
             ('fc1.bias',
              tensor([ 0.2690, -0.5993, -0.1888], device='cuda:0')),
             ('fc2.weight', tensor([[ 0.6975,  0.2921,  0.2655],
                      [-0.3196, -0.6429,  0.2746],
                      [ 0.2035,  0.1762,  0.4912]], device='cuda:0')),
             ('fc2.bias',
              tensor([-0.2143, -0.2884, -0.0638], device='cuda:0')),
             ('fc5.weight',
              tensor([[-0.0023, -0.2436, -0.1727]], device='cuda:0')),
             ('fc5.bias', tensor([0.8086], device='cuda:0'))])

fc1의 Parameter를 확인한 결과 잘 저장된 것을 알 수 있습니다.


그럼 본격적으로 Net 모델에 새로운 모델을 추가해보겠습니다. 아까 fc3과 fc4에 걸어뒀던 주석을 풀고 새로운 모델 new_net을 생성 후 아까 저장해둔 net 모델을 load 합니다.

[Code]

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(10, 3)
        self.fc2 = nn.Linear(3, 3)
        self.fc3 = nn.Linear(3, 3)        
        self.fc4 = nn.Linear(3, 3)
        self.fc5 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        return x

new_net = Net()
new_net.load_state_dict(torch.load("/content/net.pt"))

[Output]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-eeb026f923bc> in <module>()
     18 
     19 new_net = Net()
---> 20 new_net.load_state_dict(torch.load("/content/net.pt"))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1050         if len(error_msgs) > 0:
   1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1054 

RuntimeError: Error(s) in loading state_dict for Net:
    Missing key(s) in state_dict: "fc3.weight", "fc3.bias", "fc4.weight", "fc4.bias". 

출력을 확인해보니, 

load시, 저장된 모델에서 fc3, fc4과 일치하는 파라미터를 찾을 수 없다는 에러가 뜹니다. (load_state_dict 함수는 업로드할 모델과 저장해둔 모델의 keys가 완전히 일치해야 load할 수 있습니다.)


그럼 fc3과 fc4는 그냥 알아서하고 fc1, fc4, fc5만 받아오면 안될까?

load_state_dict의 strict를 False로 설정하면, 이 문제를 해결할 수 있습니다. (error 없이 실행됩니다.) 저장된 모델에서 keys가 일치하는 부분만 받아오고 나머지는 랜덤 초기화를 진행합니다. load_state_dict의 strict를 False로 설정합니다.

[Code]

new_net.load_state_dict(torch.load("/content/net.pt"), strict=False)

[Output]

_IncompatibleKeys(missing_keys=['fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias'], unexpected_keys=[])

에러는 아니지만 호환되지 않는 key가 있다는 메시지가 출력됩니다.

그럼 new_net의 fc1과 fc3의 파라미터를 확인해보겠습니다.

[Code]

print("fc1 parameter after loading model from the old one")
print("net.fc1.weight: ", new_net.fc1.weight)
print("new_net.fc1.bias: ", new_net.fc1.bias) 
print("fc3 parameter after loading model from the old one (random initializing)")
print("new_net.fc3.weight: ", new_net.fc3.weight)
print("new_net.fc3.bias: ", new_net.fc3.bias) 

[Output]

fc1 parameter after loading model from the old one
net.fc1.weight:  Parameter containing:
tensor([[-0.0752, -0.1181, -0.4722,  0.5348, -0.3191, -0.1102,  0.3521, -0.0764,
         -0.2324,  0.3576],
        [ 0.5753,  0.3604,  0.4282, -0.2933,  0.2192,  0.1195, -0.1595,  0.3220,
          0.0755, -0.3227],
        [ 0.1871,  0.3248,  0.1233, -0.0184,  0.3806,  0.2100, -0.5009,  0.1329,
          0.4382, -0.4934]], requires_grad=True)
new_net.fc1.bias:  Parameter containing:
tensor([ 0.2690, -0.5993, -0.1888], requires_grad=True)
fc3 parameter after loading model from the old one (random initializing)
new_net.fc3.weight:  Parameter containing:
tensor([[-0.1801,  0.1856, -0.4097],
        [ 0.4498, -0.3422, -0.3599],
        [-0.4008, -0.2709, -0.4433]], requires_grad=True)
new_net.fc3.bias:  Parameter containing:
tensor([-0.0126,  0.0906, -0.1108], requires_grad=True)

보시다시피 fc1은 checkpoint로부터 load되었고, fc3는 초기화된 것을 확인할 수 있습니다. 원하는 대로 warm-starting을 할 수 있습니다!


만약 checkpoint에 저장되어있는 모델들의 keys ( 선언명 )이 다른데 그 파라미터로 모델을 초기화하고 싶다면?

방법은 두가지가 있습니다.

  1. checkpoint의 keys를 로드할 모델의 선언명들로 바꾸어준다.

  2. checkpoint를 읽어와 모델의 파라미터와 바이어스를 직접 수정한다.


오늘 포스팅은 여기까지입니다.

감사합니다.