보통 pretrained된 가중치를 불러올 때 checkpoint로 불러온다.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
1. 먼저 인스턴스 모델을 가져와준다.
2. checkpoint로 pretrained된 가중치 = state_dict을 넣어준다.
3. 이제 이 state_dict을 인스턴스 모델에 넣어준다.(load_state_dict)
그런데 이 때, 오류가 날 때가 있다.
내 추측으론 pretrained에 있는 가중치나 하이퍼 파라미터가 인스턴스 모델의 state_dict엔 없는데 그대로 가져다 넣을려니 오류가 생기는 것 같다.
그래서
import torch
checkpoint=torch.load('학습된 매개변수 경로')
p = checkpoint['state_dict']
instance_dict=model.state_dict()
new_model_dict={}
for k,v in p.items():
if k in instance_dict:
new_model_dict[k] = v
instance_dict.update(new_model_dict)
model.load_state_dict(instance_dict)
로 해결한다.
학습된 state_dict에서 key를 하나씩 꺼내다가 그것이 인스턴스 모델 안에 있으면, 그 값을 인스턴스 모델 state_dict에 넣는다.
그 후 model에 load해주면 해결된다.
'AI > AI노하우' 카테고리의 다른 글
Pytorch(tutorial)- MNIST 데이터 분류 (0) | 2021.10.04 |
---|
댓글