

PyTorch Tutorial: Save and Load the Model (from the official PyTorch website)


Save and Load the Model

In this section we will look at how to persist model state by saving and loading models, and how to run model predictions.

import torch
import torchvision.models as models

Saving and Loading Model Weights

PyTorch models store the learned parameters in an internal state dictionary, called state_dict.

These can be persisted via the torch.save method:

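# `pretrained=True` is deprecated since torchvision 0.13; pass `weights` instead
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')  # saves only the learned parameters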
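To see what actually ends up in model_weights.pth, here is a minimal sketch that inspects a state_dict without downloading VGG16. The small nn.Sequential model is a hypothetical stand-in, and an in-memory io.BytesIO buffer replaces the file path. Both are assumptions made only so the example runs quickly.

```python
import io

import torch
from torch import nn

# Hypothetical tiny stand-in for VGG16 so no weights need to be downloaded.
model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
)

# state_dict() maps each parameter/buffer name to its tensor;
# layers without parameters (the ReLU) contribute no entries.
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# torch.save accepts a path or any file-like object; an in-memory
# buffer is used here instead of 'model_weights.pth'.
buffer = io.BytesIO()
torch.save(state, buffer)
print(f"serialized size: {buffer.getbuffer().nbytes} bytes")
```

For this stand-in the printed keys are 0.weight, 0.bias, 2.weight and 2.bias. The ReLU at index 1 has no parameters, so nothing is stored for it.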

To load model weights, you need to create an instance of the same model first, and then load the parameters using the load_state_dict() method.

model = models.vgg16()  # we do not specify weights, i.e. create an untrained model
# weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors and basic types
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
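The two-step pattern above can be sketched end to end on a small stand-in model: first build the same architecture, then call load_state_dict(). The build_model() helper and its layer sizes are hypothetical, and weights_only=True assumes PyTorch 1.13 or newer. The last part shows why the architecture has to match. A model with different layer sizes rejects the saved tensors.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical small architecture standing in for models.vgg16().
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


torch.manual_seed(0)
trained = build_model()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# 1) Create a fresh instance of the SAME architecture (randomly initialized) ...
restored = build_model()
# 2) ... then copy the saved tensors into it.
buffer.seek(0)
restored.load_state_dict(torch.load(buffer, weights_only=True))
restored.eval()

x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print("outputs identical:", same)

# A different architecture cannot accept these weights: load_state_dict
# raises a RuntimeError when tensor shapes do not match.
mismatched = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
buffer.seek(0)
try:
    mismatched.load_state_dict(torch.load(buffer, weights_only=True))
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print("mismatched architecture rejected:", mismatch_error is not None)
```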

Note:

Be sure to call the model.eval() method before inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
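Here is a minimal sketch of what eval() changes. It uses a hypothetical model that contains only a Dropout layer. In training mode, dropout randomly zeros elements and scales the rest by 1/(1-p). In evaluation mode, it passes the input through unchanged. Batch normalization layers similarly switch from batch statistics to their running statistics, which this sketch does not show. torch.no_grad() is added as the usual companion for inference; eval() does not require it.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy model: a single Dropout layer, no learned weights.
model = nn.Sequential(nn.Dropout(p=0.5))
x = torch.ones(1, 8)

model.train()                 # training mode: dropout is active
train_out = model(x)          # each element is either 0 or scaled by 1/(1-p) = 2
print(train_out)

model.eval()                  # evaluation mode: dropout becomes a no-op
with torch.no_grad():         # no gradient tracking needed for inference
    eval_out = model(x)
print(eval_out)               # identical to x
```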

 

Saving and Loading Models with Shapes

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of a network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

torch.save(model, 'model.pth')

We can then load the model like this:

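# the file holds a pickled model object, so weights_only must be False (recent releases default to True)
model = torch.load('model.pth', weights_only=False)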
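Here is a small round trip that saves and loads the whole model object. It assumes PyTorch 1.13 or newer for the weights_only argument, a hypothetical nn.Sequential model, and an in-memory buffer in place of 'model.pth'. Because this stand-in uses only classes that ship with torch, it unpickles without any extra imports.

```python
import io

import torch
from torch import nn

# Hypothetical stand-in model built only from torch's own classes.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

buffer = io.BytesIO()
torch.save(model, buffer)      # pickles the module object itself, not just its tensors

buffer.seek(0)
# weights_only=False allows unpickling arbitrary Python objects;
# only do this for files you trust.
loaded = torch.load(buffer, weights_only=False)
loaded.eval()

x = torch.randn(2, 4)
with torch.no_grad():
    outputs_match = torch.equal(model(x), loaded(x))
print(type(loaded).__name__, outputs_match)
```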

Note:

This approach uses the Python pickle module when serializing the model, so it relies on the actual class definition being available when loading the model.
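A sketch can make the dependency on the class definition visible. Here the hypothetical class TinyNet is registered under a throwaway module name, hypothetical_models, to imitate a model defined in its own file. Removing that module from sys.modules then stands in for loading the file in an environment where the definition is missing. This module swapping is only a demonstration trick, not something to do in real code.

```python
import io
import sys
import types

import torch
from torch import nn


class TinyNet(nn.Module):
    # Hypothetical custom model class.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Pretend TinyNet lives in "hypothetical_models.py" by registering a
# throwaway module object under that name.
fake_module = types.ModuleType("hypothetical_models")
TinyNet.__module__ = fake_module.__name__
fake_module.TinyNet = TinyNet
sys.modules[fake_module.__name__] = fake_module

buffer = io.BytesIO()
torch.save(TinyNet(), buffer)  # pickle stores a reference "hypothetical_models.TinyNet"

# Simulate loading where the class definition is not importable.
del sys.modules[fake_module.__name__]
buffer.seek(0)
try:
    torch.load(buffer, weights_only=False)
    failed = False
except (ImportError, AttributeError) as err:
    failed = True
    print("load failed:", type(err).__name__)

# Once the definition is importable again, the same bytes load fine.
sys.modules[fake_module.__name__] = fake_module
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)
print(type(restored).__name__)
```

Saving only the state_dict avoids tying the file to the class's module path. You still need some definition of the architecture to call load_state_dict().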
