Loading Partial Model Parameters in PyTorch

Loading a PyTorch Model

The typical flow for loading a PyTorch model:

python
import torch

model = HRNet()                            # initialize the model
model_file = 'state_8.pth'
checkpoint_dict = torch.load(model_file)   # load the checkpoint file
model.load_state_dict(checkpoint_dict)     # load all parameters into the model
model.cuda()
out = model(input)                         # run a forward pass

Loading a model this way is typically done for testing: it loads every parameter of the model. If you only need to load the parameters of one particular layer, this approach no longer works and some extra steps are required.
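
To make concrete why the all-in-one approach breaks down once the architecture only partially matches the checkpoint, here is a minimal sketch (the NewModel class and its layer names are made up for illustration): with the default strict=True, load_state_dict raises a RuntimeError as soon as a key is missing.

python
import torch
import torch.nn as nn

class NewModel(nn.Module):
    """Hypothetical model: keeps an old backbone but adds a new head."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.new_head = nn.Linear(4, 2)   # this layer does not exist in the checkpoint

# Checkpoint from an older version of the model that only had the backbone
old_checkpoint = {'backbone.weight': torch.zeros(4, 8), 'backbone.bias': torch.zeros(4)}

model = NewModel()
try:
    model.load_state_dict(old_checkpoint)   # strict=True by default
except RuntimeError as err:
    print(err)                              # reports missing key(s) for 'new_head.*'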


Loading Partial Parameters of a PyTorch Model

python
import os
import torch

model_file = 'state_8.pth'
if os.path.exists(model_file):
    pre_dict = torch.load(model_file)
    pre_dict = pre_dict['state_dict']  # depends on how the checkpoint was saved; ['state_dict'] holds the parameters
    need_param = {}                    # parameters to be loaded
    for name, m in pre_dict.items():   # name is the parameter name, m is the parameter tensor
        if name.split('.')[0] == 'feature_global':  # keep only parameters of the layer named feature_global
            need_param[name] = m

    # Load the filtered parameters; strict=False means the keys do not have to match
    # one-to-one, only the matching ones are loaded (this runs inside a model method, hence self).
    self.load_state_dict(need_param, strict=False)
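
A related filtering idiom, not from the original post but a common alternative sketch (the helper name load_matching_params is hypothetical), is to intersect the checkpoint with the model's own state_dict and then do a strict load, which also guards against shape mismatches:

python
import torch

def load_matching_params(model, checkpoint_path):
    """Copy into `model` only the checkpoint entries whose name and shape both match."""
    pre_dict = torch.load(checkpoint_path)['state_dict']
    model_dict = model.state_dict()
    matched = {k: v for k, v in pre_dict.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(matched)          # overwrite only the matching entries
    model.load_state_dict(model_dict)   # strict load succeeds: every expected key is present
    return sorted(matched)              # names that were actually loaded, for logging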

In short, parameters are filtered by layer name. A handy trick is to put the layers you plan to modify into an nn.Sequential container so their parameters can be loaded in one batch, as with feature_global above:

python
self.feature_global = nn.Sequential(
    nn.Conv2d(256, 512, 6, 2),
    nn.BatchNorm2d(512),
    nn.MaxPool2d(3, 2),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 384, 3, 1),
    nn.BatchNorm2d(384),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 384, 3, 1),
    nn.BatchNorm2d(384),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, 3, 1),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, 3, 1))

With this, the parameters of any layer can be loaded. Since loading is keyed on layer names, the one requirement is that layer names must be unique.
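
To check which names the filter will actually see, and that nothing collides, you can simply print the parameter names. The sketch below assumes a minimal module wrapping a shortened feature_global container like the one above:

python
import torch.nn as nn

class Demo(nn.Module):          # minimal stand-in for the full model above
    def __init__(self):
        super().__init__()
        self.feature_global = nn.Sequential(
            nn.Conv2d(256, 512, 6, 2),
            nn.BatchNorm2d(512))

for name, _ in Demo().named_parameters():
    print(name)
# feature_global.0.weight
# feature_global.0.bias
# feature_global.1.weight
# feature_global.1.bias

Children of nn.Sequential are numbered by position, so entries inside the container cannot collide; the prefix filter in the earlier snippet only needs the top-level attribute name feature_global to be unique in the model.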


That covers loading PyTorch models; I will update this post if other issues come up.

Author: 星星泡饭
Link: https://luoyou.art/2020/06/16/Pytorch%E5%8A%A0%E8%BD%BD%E9%83%A8%E5%88%86%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.