AI/AI노하우2 Pytorch(tutorial)- MNIST 데이터 분류 CNN을 하면서 MNIST 데이터셋은 거즘 튜토리얼 같은 것이다. 이전에도 MNIST로 몇 번 공부해 봤는데, 현재 공부하고 있는 책인 '파이토치 딥러닝 프로젝트 모음집'에서도 튜토리얼로 나와있어 복습할 겸 다시 따라해봤다. 1. 모듈을 불러온다. import torch import torch.nn as nn # 딥러닝 네트워크의 기본구성 요소를 포함한 모듈 import torch.nn.functional as F import torch.optim as optim # 가중치 추정에 필요한 최적화 알고리즘 from torchvision import datasets, transforms from matplotlib import pyplot as plt %matplotlib inline 2. 분석환경을 설정한.. 2021. 10. 4. pretrained된 가중치 불러오기 보통 pretrained된 가중치를 불러올 때 checkpoint로 불러온다. model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train() 1. 먼저 인스턴스 모델을 가져와준다. 2. checkpo.. 2021. 9. 1. 이전 1 다음