python怎么使用预训练的模型_PyTorch使用预训练模型-程序员宅基地

技术标签: python怎么使用预训练的模型  

PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,如有错误恳请指正。

一、直接加载预训练模型进行训练

1、加载保存的整个模型

torch.save(model,'model.pkl')

...

model = torch.load('model.pkl')

2、加载保存的模型参数

torch.save(model.state_dict(),'model_state_dict.pkl')

...

model.load_state_dict(torch.load('model_state_dict.pkl'))

关于模型的保存和加载,可以详细参照我的这篇文章:HUST小菜鸡:Pytorch搭建简单神经网络(三)——快速搭建、保存与提取​zhuanlan.zhihu.comv2-5aed9b4858ee329f1de0b9d5ff33ce4a_180x120.jpg

通过对模型参数的保存的解析,我们可以深入的了解

load_dict = torch.load('models/cifar10_statedict.pkl')

print(load_dict.keys())

print(type(load_dict))

输出的结果如下所示:

odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'conv3.0.weight', 'conv3.0.bias', 'conv4.0.weight', 'conv4.0.bias', 'conv5.0.weight', 'conv5.0.bias', 'conv6.0.weight', 'conv6.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.5.weight', 'classifier.5.bias'])

可以看出保存的state_dict其实是一个collections.OrderedDict的Object,和普通的dict不同的是,该类别是有着严格的顺序,而dict中的元素是没有严格的顺序。

但是有一个问题值得深入考量——两个网络的结构是一样的,但是结构的命名是不一样的,那么对于这种模型的加载,如果不一样的话会出现报错,该如何解决

参照以上结果的输出,state_dict中key就是网络结构的名称,所以当网络结构一样的时候,只需要修改索引key,就可以解决以上的问题,至于如何修改可以参照如下方式:https://stackoverflow.com/questions/12150872/change-key-in-ordereddict-without-losing-order​stackoverflow.com

二、加载部分预训练模型

我们经常对现有的经典网络进行如下操作,我们不修改网络的主体部分,我们只修改网络的输出,或者在最后加上一些网络层来达到我们想要的输出结果,虽然很难保证网络模型和某些公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

model = cifar10_cnn.CIFAR10_Nettest()

pretrained_dict = torch.load('models/cifar10_statedict.pkl')

model_dict = model.state_dict()

print('随机初始化权重第一层:',model_dict['conv1.0.weight'])

# 将pretrained_dict里不属于model_dict的键剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])

# 更新现有的model_dict

model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型

model.load_state_dict(model_dict)

print('更新后权重第一层:',model_dict['conv1.0.weight'])

输出的部分结果如下所示,为了直观显示我只截取了中间的某一部分

随机初始化权重第一层: tensor([[[[ 0.0142, 0.1039, 0.1260],

[ 0.1805, -0.0533, 0.0007],

[-0.1032, -0.1039, -0.0633]],

[[ 0.0714, -0.0053, 0.0059],

[-0.0528, 0.0438, -0.1108],

[ 0.0544, 0.0157, 0.1265]],

预训练权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01, -1.5474e-01, 1.3142e-01],

[-9.4602e-02, 6.4120e-02, -9.4336e-02]],

[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],

[-5.8471e-02, -8.8146e-02, -1.6053e-01],

[-1.0788e-01, -5.9083e-02, -9.0651e-02]],

更新后权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01,

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_39567169/article/details/109940576

智能推荐

mysql load character_mysql load data Invalid utf8mb4 character string: ”-程序员宅基地

文章浏览阅读3.2k次。使用mysql的 load data 导入数据到 数据库中:LOAD DATA LOCAL INFILE '/tmp/2982/20200424/user.csv'INTO TABLE t_user CHARACTER SET utf8mb4 FIELDS TERMINATED BY ','LINES TERMINATED BY '\r\n'IGNORE 1 LINES(userName, use..._invalid utf8mb4 character string:

Java中的Map解析:探索键值对的奇妙世界_map键值对-程序员宅基地

文章浏览阅读543次。本文将深入解析Java中的Map接口及其常见实现类,详细介绍基本概念、常用方法、不同实现类的特点以及适用场景,帮助读者全面了解和灵活应用Map在Java编程中的威力。通过了解Map的基本概念、常用方法和不同实现类的特点,读者可以更加全面地理解和应用Map在Java编程中的重要性。因为TreeMap是有序的,所以在需要按照键的顺序访问的场景下非常有用。在Java中,有多个常见的Map实现类,每个实现类都有自己的特点和适用场景。根据具体的需求和场景,可以选择合适的Map实现类和相应的操作方法。_map键值对

编程实现下列功能:假设以两个元素依值非递减有序排列的顺序表A和B 分别表示两个集合(同一表中的元素值各不相同),求一个新的集合C=A-B,且表C中的元素也是依值递增有序排列。_建立两个按数据元素值非递减有序排列的线性表a和b,均以顺序表作为存储结构-程序员宅基地

文章浏览阅读3.1k次。编程实现下列功能:假设以两个元素依值非递减有序排列的顺序表A和B 分别表示两个集合(同一表中的元素值各不相同),求一个新的集合C=A-B,且表C中的元素也是依值递增有序排列。# include <stdio.h># include <stdlib.h># define initsize 20//初始分配量# define LISTINCREMENT 5//分配增量..._建立两个按数据元素值非递减有序排列的线性表a和b,均以顺序表作为存储结构

java readutf()方法,java socket writeUTF()和readUTF()-程序员宅基地

文章浏览阅读207次。I've been reading some Java socket code snippet and fonund out a fact that in socket communication, to send messages in sequence, you don't have to seperate them by hand, the writer/reader stream do t..._java socket writeutf

ssh远程登录_ssh登录命令-程序员宅基地

文章浏览阅读3.3k次,点赞2次,收藏7次。l 指定登录名称;基于口令认证时,使用强密码策略,比如:tr -dc A-Za-z0-9_ < /dev/urandom | head -c 12| xargs。服务端得到客户端的请求后,会到authorized_keys()中查找,如果有响应的IP和用户,就会随机生成一个字符串,例如:kgc。3)客户端得到服务端的信息后,通过算法生成密钥,结合自己的公钥生成密钥对,然后将密钥对发送给服务端。5)最后,客户端拥有自己的公钥和私钥以及服务端的公钥,服务端拥有自己的公钥和私钥以及客户端的公钥。_ssh登录命令

C神奇操作(1)--static inline_static inline的优点-程序员宅基地

文章浏览阅读258次。C神奇操作(1)_static inline的优点

随便推点

国产化服务器内网安装onlyoffice_内网部署onlyoffice社区版-程序员宅基地

文章浏览阅读6k次,点赞4次,收藏30次。国产化服务器安装onlyoffice_内网部署onlyoffice社区版

基于Ardupilot/PX4固件,APM/PIXhawk硬件的VTOL垂直起降固定翼软硬件参数调试(第三篇)故障保护及问题诊断_ekf故障安全-程序员宅基地

文章浏览阅读1.4w次,点赞12次,收藏134次。基于Ardupilot/PX4固件,APM/PIXhawk硬件的VTOL垂直起降固定翼软硬件参数调试(第三篇)故障保护及问题诊断PIX无法安装驱动双击下载的px4_driver_installer_v10_win.exe驱动安装文件,Pixhawk驱动下载(点击即可下载):http://www.inf.ethz.ch/personal/lomeier/downloads/px4_driver..._ekf故障安全

python opencv 实现图片,视频 转 字符/字符画/字符视频_python opencv字模转换-程序员宅基地

文章浏览阅读1.5k次,点赞2次,收藏19次。基于python 3.8图片转字符 升级版,顺便加了个 GUI地址:待定界面及使用图片转换时,界面会卡顿,表现为按钮按下去不会回弹,正常现象转换完成的图片视频在软件根目录视频转换请勿使用高分辨率,速度太慢,当然输入分辨率越高,转换后的分辨率也高参考 : i5-7200u 实测 500x300 mp4, 每秒只能处理 1.5 帧效果(这里只能上图,视频就不做展示,效果参考图片)- 原图-图片转 txt图片转 txt文本- 图片转指定字符-jpg彩色同一个字符,通过颜色的变换展_python opencv字模转换

零基础入门微信小程序开发-程序员宅基地

文章浏览阅读10w+次,点赞763次,收藏5.8k次。本课程是一个系列入门教程,目标是从 0 开始带领读者上手实战,课程以微信小程序的核心概念作为主线,介绍配置文件、页面样式文件、JavaScript 的基本知识并以指南针为例对基本知识进行扩展,另外加上开发工具的安装、小程序发布等内容,共 9 篇文章。_零基础学习微信小程序

SSM框架反向自动生成Mapper等_反向生成mapper-程序员宅基地

文章浏览阅读1.5k次。一、在pom文件中加入插件 web-ssm org.apache.tomcat.maven_反向生成mapper

使用Ionic.Zip压缩、分卷压缩、解压文件-程序员宅基地

文章浏览阅读1.8k次。分卷压缩使用方法: int iMinLength = 100;//最小压缩包单位(100M) ZipHelper.Compress("../../../xxx.xx", "../../../xx.zip", ZipDataUnit.MB, iMinLength);//分卷压缩解压使用方法 : string strUnZipPath = Environment.CurrentDirectory + "\\DeCompress"; _ionic.zip

推荐文章

热门文章

相关标签