본문 바로가기
AI/AI노하우

pretrained된 가중치 불러오기

by lucian 2021. 9. 1.

보통 pretrained된 가중치를 불러올 때 checkpoint로 불러온다.

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

1. 먼저 인스턴스 모델을 가져와준다.

2. checkpoint로 pretrained된 가중치 = state_dict을 넣어준다.

3. 이제 이 state_dict을 인스턴스 모델에 넣어준다.(load_state_dict)

 

그런데 이 때, 오류가 날 때가 있다.

내 추측으론 pretrained에 있는 가중치나 하이퍼 파라미터가 인스턴스 모델의 state_dict엔 없는데 그대로 가져다 넣을려니 오류가 생기는 것 같다.

 

그래서 

import torch
checkpoint=torch.load('학습된 매개변수 경로')
p = checkpoint['state_dict']
instance_dict=model.state_dict()

new_model_dict={}

for k,v in p.items():
    if k in instance_dict:
        new_model_dict[k] = v
instance_dict.update(new_model_dict)
model.load_state_dict(instance_dict)

로 해결한다.

학습된 state_dict에서 key를 하나씩 꺼내다가 그것이 인스턴스 모델 안에 있으면, 그 값을 인스턴스 모델 state_dict에 넣는다.

그 후 model에 load해주면 해결된다.

'AI > AI노하우' 카테고리의 다른 글

Pytorch(tutorial)- MNIST 데이터 분류  (0) 2021.10.04

댓글