Pytroch模型的加载
Pytorch
模型的加载流程:
1 2 3 4 5 6
| model = HRNet() model_file = 'state_8.pth' checkpoint_dict = torch.load(model_state_file) model.load_state_dict(checkpoint_dict) model.cuda() out = model(input)
|
这样加载模型一般用于测试,会加载模型的所有参数,但是如果需要加载某一层的参数就行不通了,需要后续操作。
Pytorch模型部分参数加载
1 2 3 4 5 6 7 8 9 10 11
| model_file = 'state_8.pth' if os.path.exists(model_feature): pre_dict = torch.load(model_feature) pre_dict = pre_dict['state_dict'] need_param = {} for name,m in pre_dict.items(): if name.split('.')[0] == 'feature_global': need_param[name] = m self.load_state_dict(need_param,strict=False)
|
总结来说,就是根据层名称来过滤,这里有一个trick就是可以把自己需要改的模型层写到容器nn.Sequence()
就可以批量加载,例如上述feature_global
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| self.feature_global = nn.Sequential( nn.Conv2d(256, 512, 6, 2), nn.BatchNorm2d(512), nn.MaxPool2d(3, 2), nn.ReLU(inplace=True), nn.Conv2d(512, 384, 3, 1), nn.BatchNorm2d(384), nn.ReLU(inplace=True), nn.Conv2d(384, 384, 3, 1), nn.BatchNorm2d(384), nn.ReLU(inplace=True), nn.Conv2d(384, 256, 3, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3, 1))
|
这样便可加载任意层参数,根据层名称加载就有个要求不能重名。
以上就是Pytorch模型的加载,后续有其他问题再更新。