|
-
- <!DOCTYPE html>
-
- <html lang="zh">
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
-
- <title>5.4 Saving and Loading PyTorch Models - Thorough PyTorch</title>
-
- <!-- Loaded before other Sphinx assets -->
- <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
- <link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
-
-
- <link rel="stylesheet"
- href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
-
- <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
- <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
- <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
- <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
- <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
-
- <!-- Pre-loaded scripts that we'll load fully later -->
- <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">
-
- <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
- <script src="../_static/jquery.js"></script>
- <script src="../_static/underscore.js"></script>
- <script src="../_static/doctools.js"></script>
- <script>let toggleHintShow = 'Click to show';</script>
- <script>let toggleHintHide = 'Click to hide';</script>
- <script>let toggleOpenOnPrint = 'true';</script>
- <script src="../_static/togglebutton.js"></script>
- <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
- <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
- <link rel="index" title="Index" href="../genindex.html" />
- <link rel="search" title="Search" href="../search.html" />
- <link rel="next" title="Chapter 6: Advanced PyTorch Training Techniques" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html" />
- <link rel="prev" title="5.3 Modifying PyTorch Models" href="5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html" />
- <meta name="viewport" content="width=device-width, initial-scale=1" />
- <meta name="docsearch:language" content="zh">
-
-
- <!-- Google Analytics -->
-
- </head>
- <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
- <!-- Checkboxes to toggle the left sidebar -->
- <input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
- <label class="overlay overlay-navbar" for="__navigation">
- <div class="visually-hidden">Toggle navigation sidebar</div>
- </label>
- <!-- Checkboxes to toggle the in-page toc -->
- <input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
- <label class="overlay overlay-pagetoc" for="__page-toc">
- <div class="visually-hidden">Toggle in-page Table of Contents</div>
- </label>
- <!-- Headers at the top -->
- <div class="announcement header-item noprint"></div>
- <div class="header header-item noprint"></div>
-
-
- <div class="container-fluid" id="banner"></div>
-
-
-
- <div class="container-xl">
- <div class="row">
-
- <!-- Sidebar -->
- <div class="bd-sidebar noprint" id="site-navigation">
- <div class="bd-sidebar__content">
- <div class="bd-sidebar__top"><div class="navbar-brand-box">
- <a class="navbar-brand text-wrap" href="../index.html">
-
-
-
- <h1 class="site-logo" id="site-title">Thorough PyTorch</h1>
-
- </a>
- </div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
- <i class="icon fas fa-search"></i>
- <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
- </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
- <div class="bd-toc-item active">
- <p aria-level="2" class="caption" role="heading">
- <span class="caption-text">
- Contents
- </span>
- </p>
- <ul class="current nav bd-sidenav">
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
- Chapter 1: Introduction to PyTorch and Installation
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
- <label for="toctree-checkbox-1">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
- 1.1 Introduction to PyTorch
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
- 1.2 Installing PyTorch
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
- 1.3 PyTorch Resources
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
- Chapter 2: PyTorch Fundamentals
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
- <label for="toctree-checkbox-2">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
- 2.1 Tensors
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
- 2.2 Automatic Differentiation
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
- 2.3 Introduction to Parallel Computing
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
- Chapter 3: Main Components of PyTorch
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
- <label for="toctree-checkbox-3">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
- 3.1 Reflection: The Essential Parts of a Deep Learning Task
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
- 3.2 Basic Configuration
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
- 3.3 Data Loading
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
- 3.4 Model Construction
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
- 3.5 Model Initialization
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
- 3.6 Loss Functions
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
- 3.7 Training and Evaluation
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
- 3.8 Visualization
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
- 3.9 PyTorch Optimizers
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
- Chapter 4: PyTorch Basics in Practice
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
- <label for="toctree-checkbox-4">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/%E5%9F%BA%E7%A1%80%E5%AE%9E%E6%88%98%E2%80%94%E2%80%94FashionMNIST%E6%97%B6%E8%A3%85%E5%88%86%E7%B1%BB.html">
- Hands-On Basics: FashionMNIST Clothing Classification
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 current active has-children">
- <a class="reference internal" href="index.html">
- Chapter 5: Defining PyTorch Models
- </a>
- <input checked="" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
- <label for="toctree-checkbox-5">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul class="current">
- <li class="toctree-l2">
- <a class="reference internal" href="5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
- 5.1 Ways to Define a PyTorch Model
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
- 5.2 Building Complex Networks Quickly with Model Blocks
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
- 5.3 Modifying PyTorch Models
- </a>
- </li>
- <li class="toctree-l2 current active">
- <a class="current reference internal" href="#">
- 5.4 Saving and Loading PyTorch Models
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
- Chapter 6: Advanced PyTorch Training Techniques
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
- <label for="toctree-checkbox-6">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
- 6.1 Custom Loss Functions
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
- 6.2 Dynamically Adjusting the Learning Rate
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
- 6.3 Model Fine-Tuning with torchvision
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
- 6.3 Model Fine-Tuning with timm
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
- 6.4 Half-Precision Training
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
- 6.5 Data Augmentation with imgaug
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
- 6.6 Tuning Hyperparameters with argparse
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.7%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E4%B8%8E%E8%BF%9B%E9%98%B6%E8%AE%AD%E7%BB%83%E6%8A%80%E5%B7%A7.html">
- PyTorch Model Definition and Advanced Training Techniques
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
- Chapter 7: PyTorch Visualization
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
- <label for="toctree-checkbox-7">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
- 7.1 Visualizing Network Architecture
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
- 7.2 CNN Visualization
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
- 7.3 Visualizing Training with TensorBoard
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
- Chapter 8: Overview of the PyTorch Ecosystem
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
- <label for="toctree-checkbox-8">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
- 8.1 Chapter Overview
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
- 8.2 torchvision
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
- 8.3 Introduction to PyTorchVideo
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
- 8.4 Introduction to torchtext
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/transforms%E5%AE%9E%E6%93%8D.html">
- transforms in Practice
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
- Chapter 9: Deploying PyTorch Models
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
- <label for="toctree-checkbox-9">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
- 9.1 Deployment and Inference with ONNX
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/index.html">
- Chapter 10: Walkthroughs of Common Code
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
- <label for="toctree-checkbox-10">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
- 10.1 Image Classification (in progress)
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
- Object Detection
- </a>
- </li>
- </ul>
- </li>
- </ul>
-
- </div>
- </nav></div>
- <div class="bd-sidebar__bottom">
- <!-- To handle the deprecated key -->
-
- <div class="navbar_extra_footer">
- Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
- </div>
-
- </div>
- </div>
- <div id="rtd-footer-container"></div>
- </div>
-
-
-
-
-
-
- <!-- A tiny helper pixel to detect if we've scrolled -->
- <div class="sbt-scroll-pixel-helper"></div>
- <!-- Main content -->
- <div class="col py-0 content-container">
-
- <div class="header-article row sticky-top noprint">
-
-
-
-
- <div class="col py-1 d-flex header-article-main">
- <div class="header-article__left">
-
- <label for="__navigation"
- class="headerbtn"
- data-toggle="tooltip"
- data-placement="right"
- title="Toggle navigation"
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fas fa-bars"></i>
- </span>
-
- </label>
-
-
- </div>
- <div class="header-article__right">
- <button onclick="toggleFullScreen()"
- class="headerbtn"
- data-toggle="tooltip"
- data-placement="bottom"
- title="Fullscreen mode"
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fas fa-expand"></i>
- </span>
-
- </button>
-
- <div class="menu-dropdown menu-dropdown-repository-buttons">
- <button class="headerbtn menu-dropdown__trigger"
- aria-label="Source repositories">
- <i class="fab fa-github"></i>
- </button>
- <div class="menu-dropdown__content">
- <ul>
- <li>
- <a href="https://github.com/datawhalechina/thorough-pytorch"
- class="headerbtn"
- data-toggle="tooltip"
- data-placement="left"
- title="Source repository"
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fab fa-github"></i>
- </span>
- <span class="headerbtn__text-container">repository</span>
- </a>
-
- </li>
-
-
- </ul>
- </div>
- </div>
-
- <div class="menu-dropdown menu-dropdown-download-buttons">
- <button class="headerbtn menu-dropdown__trigger"
- aria-label="Download this page">
- <i class="fas fa-download"></i>
- </button>
- <div class="menu-dropdown__content">
- <ul>
- <li>
- <a href="../_sources/第五章/5.4 PyTorh模型保存与读取.md.txt"
- class="headerbtn"
- data-toggle="tooltip"
- data-placement="left"
- title="Download source file"
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fas fa-file"></i>
- </span>
- <span class="headerbtn__text-container">.md</span>
- </a>
-
- </li>
-
- <li>
-
- <button onclick="printPdf(this)"
- class="headerbtn"
- data-toggle="tooltip"
- data-placement="left"
- title="Print to PDF"
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fas fa-file-pdf"></i>
- </span>
- <span class="headerbtn__text-container">.pdf</span>
- </button>
-
- </li>
-
- </ul>
- </div>
- </div>
- <label for="__page-toc"
- class="headerbtn headerbtn-page-toc"
-
- >
-
-
- <span class="headerbtn__icon-container">
- <i class="fas fa-list"></i>
- </span>
-
- </label>
-
- </div>
- </div>
-
- <!-- Table of contents -->
- <div class="col-md-3 bd-toc show noprint">
- <div class="tocsection onthispage pt-5 pb-3">
- <i class="fas fa-list"></i> Contents
- </div>
- <nav id="bd-toc-nav" aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id1">
- 5.4.1 Model Storage Formats
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id2">
- 5.4.2 What Gets Saved
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id3">
- 5.4.3 Single-GPU vs. Multi-GPU Model Saving
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id4">
- 5.4.4 Case-by-Case Discussion
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id5">
- Appendix: Test Environment
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id6">
- References for This Section
- </a>
- </li>
- </ul>
-
- </nav>
- </div>
- </div>
- <div class="article row">
- <div class="col pl-md-3 pl-lg-5 content-container">
- <!-- Table of contents that is only displayed when printing the page -->
- <div id="jb-print-docs-body" class="onlyprint">
- <h1>5.4 Saving and Loading PyTorch Models</h1>
- <!-- Table of contents -->
- <div id="print-main-content">
- <div id="jb-print-toc">
-
- <div>
- <h2> Contents </h2>
- </div>
- <nav aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id1">
- 5.4.1 Model Storage Formats
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id2">
- 5.4.2 What Gets Saved
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id3">
- 5.4.3 Single-GPU vs. Multi-GPU Model Saving
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id4">
- 5.4.4 Case-by-Case Discussion
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id5">
- Appendix: Test Environment
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id6">
- References for This Section
- </a>
- </li>
- </ul>
-
- </nav>
- </div>
- </div>
- </div>
- <main id="main-content" role="main">
-
- <div>
-
- <section class="tex2jax_ignore mathjax_ignore" id="pytorch">
- <h1>5.4 Saving and Loading PyTorch Models<a class="headerlink" href="#pytorch" title="Permalink to this heading">#</a></h1>
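- <p>In the previous sections we covered how to build and modify PyTorch models. This section discusses how to save a trained PyTorch model and how to load it back.</p>
- <p>Many scenarios also call for multi-GPU training, in which case the model is spread across several GPUs (see Section 2.3 on data-parallel training; model-parallel training is not considered here). Does saving and loading the model then work any differently from the single-GPU case?</p>
- <p>After working through this section, you will know:</p>
- <ul class="simple">
- <li><p>The file formats PyTorch uses to store models</p></li>
- <li><p>How PyTorch stores a model</p></li>
- <li><p>How to save and load models under single-GPU and multi-GPU training</p></li>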
- </ul>
- <section id="id1">
- <h2>5.4.1 Model Storage Formats<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h2>
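- <p>PyTorch models are mostly saved with one of three file extensions: .pkl, .pt, or .pth. In practice there is no difference between them, because torch.save writes the same serialized format regardless of the extension, so we do not discuss them further here. The references at the end of this section list some material we found on the topic; interested readers can dig deeper, and discussion is welcome.</p>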
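- <p>The extension is only a naming convention. As a minimal sketch (assuming a recent PyTorch install; the tiny nn.Linear model and the temporary file names are illustrative stand-ins, not part of the original example), the code below saves the same weights under all three extensions and checks that each file loads back identical tensors:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre>import os
- import tempfile
-
- import torch
- from torch import nn
-
- # Tiny stand-in model (an assumption for illustration; any nn.Module behaves the same)
- model = nn.Linear(4, 2)
-
- tmp_dir = tempfile.mkdtemp()
- results = {}
- for ext in ("pt", "pth", "pkl"):
-     path = os.path.join(tmp_dir, "weights." + ext)
-     # torch.save writes the same serialized format whatever the extension is
-     torch.save(model.state_dict(), path)
-     loaded = torch.load(path)
-     # True if every loaded tensor matches the original one exactly
-     results[ext] = all(torch.equal(loaded[k], v) for k, v in model.state_dict().items())
-
- print(results)  # {'pt': True, 'pth': True, 'pkl': True}
- </pre></div>
- </div>
- <p>The .pkl extension simply reflects that torch.save builds on Python's pickle; .pt and .pth are the names most often seen in PyTorch code.</p>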
- </section>
- <section id="id2">
- <h2>5.4.2 What Gets Saved<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
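- <p>A PyTorch model consists of two main parts: its architecture and its weights. The model is an instance of a class that inherits from nn.Module, while the weights are held in a dictionary (the state_dict) whose keys are the names of parameters and buffers, such as conv1.weight, and whose values are the corresponding tensors. Accordingly, there are two ways to save a model: save the entire model (architecture plus weights), or save only the weights.</p>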
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">save_dir</span> <span class="o">=</span> <span class="s1">'resnet152.pth'</span> <span class="c1"># placeholder path; .pt/.pth/.pkl all work</span>
- <span class="c1"># newer torchvision versions use weights=models.ResNet152_Weights.DEFAULT instead of pretrained=True</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
-
- <span class="c1"># Save the entire model (architecture + weights)</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_dir</span><span class="p">)</span>
- <span class="c1"># Save only the model weights; note that state_dict() must be called</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">save_dir</span><span class="p">)</span>
- </pre></div>
- </div>
- <p>For PyTorch, <strong>all three formats (pt, pth and pkl) can store either the model weights or the entire model</strong>, so there is no practical difference between them.</p>
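- <p>To see what a state_dict actually contains, here is a small sketch. It assumes a hypothetical three-layer nn.Sequential network in place of resnet152 so that it runs quickly on a CPU; the helper build_model and the file name are illustrative only. It prints the weight dictionary and then restores the weights into a freshly built copy of the same architecture:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre>import os
- import tempfile
-
- import torch
- from torch import nn
-
-
- def build_model():
-     # Hypothetical small network standing in for resnet152
-     return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
-
-
- model = build_model()
-
- # state_dict(): an ordered mapping from parameter names to tensors
- state = model.state_dict()
- for name, tensor in state.items():
-     print(name, tuple(tensor.shape))
- # 0.weight (8, 4)
- # 0.bias (8,)
- # 2.weight (2, 8)
- # 2.bias (2,)
-
- weights_path = os.path.join(tempfile.mkdtemp(), "weights.pth")
- torch.save(state, weights_path)
-
- # Weights only: rebuild the same architecture first, then copy the weights in
- restored = build_model()
- restored.load_state_dict(torch.load(weights_path))
- </pre></div>
- </div>
- <p>Because only tensors are stored, loading weights always requires rebuilding the architecture first. A saved entire model instead records which class to rebuild, so that class definition must still be importable when the model is loaded.</p>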
- </section>
- <section id="id3">
- <h2>5.4.3 Single-GPU vs. Multi-GPU Model Saving<a class="headerlink" href="#id3" title="Permalink to this heading">#</a></h2>
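- <p>PyTorch offers two ways to move a model and its data onto a GPU: .cuda() and .to(device). The rest of this section uses the former. For multi-GPU training, the model is wrapped in torch.nn.DataParallel, for example:</p>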
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># set before any CUDA call; for multiple GPUs use e.g. '0,1,2'</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> <span class="c1"># single GPU (use either this line or the next, not both)</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> <span class="c1"># multiple GPUs</span>
- </pre></div>
- </div>
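- <p>If we then print the layer names of model, we can see that the only difference is that every layer name of the multi-GPU model carries an extra "module." prefix.</p>
- <ul class="simple">
- <li><p>Layer names of the single-GPU model:</p></li>
- </ul>
- <p><img alt="Layer names of the single-GPU model" src="https://pic3.zhimg.com/v2-3490f6ab8bc806274dd017e1a66e2486_b.png" /></p>
- <ul class="simple">
- <li><p>Layer names of the multi-GPU model:</p></li>
- </ul>
- <p><img alt="Layer names of the multi-GPU model" src="https://pic3.zhimg.com/v2-4b611c24c2e702749cebbe65eaff7cde_b.png" /></p>
- <p>This difference in how the model is represented can cause conflicts when saving and loading models. The possible cases are discussed one by one below.</p>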
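- <p>The "module." prefix can be reproduced without any GPU. The sketch below assumes a hypothetical small nn.Sequential model instead of resnet152. It relies on the fact that nn.DataParallel keeps the wrapped network as a submodule named module; on a CPU-only machine the wrapper simply holds the model without replicating it:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre>import torch
- from torch import nn
-
- # Hypothetical small model standing in for resnet152
- model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
- single_keys = list(model.state_dict().keys())
-
- # DataParallel registers the wrapped network as its submodule "module",
- # so every key in its state_dict gains a "module." prefix
- dp_model = nn.DataParallel(model)
- multi_keys = list(dp_model.state_dict().keys())
-
- print(single_keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']
- print(multi_keys)   # ['module.0.weight', 'module.0.bias', 'module.2.weight', 'module.2.bias']
- </pre></div>
- </div>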
- </section>
- <section id="id4">
- <h2>5.4.4 Case-by-Case Discussion<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h2>
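- <p>Because training and testing may run on different hardware, moving between single-GPU and multi-GPU environments can cause problems such as mismatched models when saving and loading. Below we go through all combinations (2 × 2 = 4) of saving and loading on a single GPU or on multiple GPUs in PyTorch, using the pretrained resnet152 from torchvision as the example model. Additions for anything we have missed are welcome.</p>
- <ul class="simple">
- <li><p><strong>Single-GPU save + single-GPU load</strong></p></li>
- </ul>
- <p>After choosing the GPU with os.environ['CUDA_VISIBLE_DEVICES'], you can save and load the model directly. It does not matter if different GPUs are used for saving and for loading: CUDA_VISIBLE_DEVICES renumbers the visible GPUs from 0, so within each process the model is saved from, and loaded back onto, cuda:0.</p>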
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># replace with the GPU ID you want to use</span>
- <span class="n">save_dir</span> <span class="o">=</span> <span class="s1">'resnet152.pth'</span> <span class="c1"># placeholder path</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load the entire model</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span> <span class="c1"># in PyTorch 2.6+, pass weights_only=False to load a full model</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load only the weights</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined first</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_dict</span><span class="p">)</span> <span class="c1"># copy the saved weights into the model</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
- </pre></div>
- </div>
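- <p>If a checkpoint must instead be loaded on a machine with a different device setup, for example one without any GPU, torch.load accepts a map_location argument that remaps every saved tensor while loading. A minimal sketch, assuming a small nn.Linear stand-in for resnet152 and a temporary placeholder path:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre>import os
- import tempfile
-
- import torch
- from torch import nn
-
- # Small stand-in model; placed on the GPU if one is available
- model = nn.Linear(4, 2)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
-
- path = os.path.join(tempfile.mkdtemp(), "weights.pth")  # placeholder path
- torch.save(model.state_dict(), path)
-
- # map_location="cpu" moves every stored tensor to the CPU while loading,
- # so this also works on a machine where CUDA is unavailable
- state = torch.load(path, map_location="cpu")
- restored = nn.Linear(4, 2)
- restored.load_state_dict(state)
- print({k: v.device.type for k, v in state.items()})  # {'weight': 'cpu', 'bias': 'cpu'}
- </pre></div>
- </div>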
- <ul class="simple">
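- <li><p><strong>Single-GPU save + multi-GPU load</strong></p></li>
- </ul>
- <p>This case is simple: after loading the model saved on a single GPU, wrap it with nn.DataParallel to set up multi-GPU training (in other words, replace the .cuda() call in the single-GPU + single-GPU code above):</p>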
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># replace with the GPU ID you want to use</span>
- <span class="n">save_dir</span> <span class="o">=</span> <span class="s1">'resnet152.pth'</span> <span class="c1"># placeholder path</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load the entire model</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_dir</span><span class="p">)</span>
-
- <span class="c1"># Load in a new process: CUDA_VISIBLE_DEVICES only takes effect before CUDA is initialized</span>
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'1,2'</span> <span class="c1"># replace with the GPU IDs you want to use</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load only the weights</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">save_dir</span><span class="p">)</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'1,2'</span> <span class="c1"># in the loading process; replace with the GPU IDs you want to use</span>
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined first</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_dict</span><span class="p">)</span> <span class="c1"># load the weights before wrapping</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
- </pre></div>
- </div>
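- <p>When loading weights only, the order used above (load the weights first, then wrap with nn.DataParallel) matters. The following sketch uses a hypothetical small network and helper build in place of resnet152, so it also runs on a CPU. It shows that a plain single-GPU state_dict loads into the unwrapped model or into the wrapper's .module attribute, but not into the wrapper itself, whose keys carry the "module." prefix:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre>import torch
- from torch import nn
-
-
- def build():
-     # Hypothetical small network standing in for resnet152
-     return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
-
-
- single = build()
- single_state = single.state_dict()  # keys like '0.weight', without a prefix
-
- # Order used in the text: load the weights first, then wrap
- model_a = build()
- model_a.load_state_dict(single_state)
- model_a = nn.DataParallel(model_a)
-
- # If the model is already wrapped, load into the inner .module instead
- model_b = nn.DataParallel(build())
- model_b.module.load_state_dict(single_state)
-
- # Loading into the wrapper itself fails: its keys all start with "module."
- try:
-     nn.DataParallel(build()).load_state_dict(single_state)
-     wrapper_load_ok = True
- except RuntimeError:
-     wrapper_load_ok = False
- print(wrapper_load_ok)  # False
- </pre></div>
- </div>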
- <ul class="simple">
- <li><p><strong>Multi-GPU save + single-GPU load</strong></p></li>
- </ul>
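- <p>The core problem in this case is how to remove the "module." prefix from the key names of the weight dictionary, so that they match the parameter names of the single-GPU model.</p>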
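- <p>The prefix comes from the wrapper itself. nn.DataParallel keeps the original network as a submodule named module, so every key in its state_dict starts with "module.". Here is a minimal CPU-only illustration, using a small nn.Sequential as an assumed stand-in for resnet152:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>from torch import nn
-
- # Small stand-in for resnet152 (assumption, for illustration only)
- model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
- wrapped = nn.DataParallel(model)  # can be constructed even without a GPU
-
- plain_keys = list(model.state_dict().keys())
- wrapped_keys = list(wrapped.state_dict().keys())
- print(plain_keys)    # ['0.weight', '0.bias', '2.weight', '2.bias']
- print(wrapped_keys)  # ['module.0.weight', 'module.0.bias', 'module.2.weight', 'module.2.bias']
- </pre></div>
- </div>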
- <p>When loading the whole model, simply take the model's module attribute:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'1,2'</span> <span class="c1"># replace with the GPU ids you want to use</span>
-
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load the whole model</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_dir</span><span class="p">)</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># replace with the GPU id you want to use</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">loaded_model</span><span class="o">.</span><span class="n">module</span>
- </pre></div>
- </div>
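- <p>The module attribute is the original, unwrapped network object, so its state_dict has no prefix. A related habit is to save wrapped.module.state_dict() in the first place, so the file loads directly into a single-GPU model. The sketch below uses the same assumptions as before: CPU only, nn.Linear as a stand-in, and an in-memory buffer instead of save_dir.</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import io
-
- import torch
- from torch import nn
-
- model = nn.Linear(4, 2)  # stand-in for resnet152 (assumption)
- wrapped = nn.DataParallel(model)
- print(wrapped.module is model)  # True: .module is the very object that was wrapped
-
- # Saving the inner module's weights avoids the "module." prefix altogether
- buffer = io.BytesIO()  # in-memory stand-in for save_dir
- torch.save(wrapped.module.state_dict(), buffer)
- buffer.seek(0)
-
- single_gpu_model = nn.Linear(4, 2)
- single_gpu_model.load_state_dict(torch.load(buffer))  # no key renaming needed
- </pre></div>
- </div>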
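- <p>For loading the model weights, there are several approaches:</p>
- <p><strong>Removing "module" from the dict is tedious; adding "module" to the model is simple (recommended)</strong></p>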
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0,1,2'</span> <span class="c1"># replace with the GPU ids you want to use</span>
-
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load the model weights</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">save_dir</span><span class="p">)</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># replace with the GPU id you want to use</span>
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined here</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> <span class="c1"># wrap first so the keys gain the "module." prefix</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_dict</span><span class="p">)</span>
- </pre></div>
- </div>
- <p>This way training can start even on a single GPU (DataParallel simply runs on that one card).</p>
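- <p>The sketch below checks this approach without any GPU. It assumes nn.Linear as a stand-in for resnet152 and an in-memory buffer instead of save_dir. A dict with the "module." prefix loads into a freshly wrapped model without any renaming, and .module gives back a plain model.</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import io
-
- import torch
- from torch import nn
-
- # Weights saved from a DataParallel model carry the "module." prefix
- trained = nn.DataParallel(nn.Linear(4, 2))  # stand-in for the multi-GPU resnet152 (assumption)
- buffer = io.BytesIO()  # in-memory stand-in for save_dir
- torch.save(trained.state_dict(), buffer)
- buffer.seek(0)
- loaded_dict = torch.load(buffer)
- print(list(loaded_dict.keys()))  # ['module.weight', 'module.bias']
-
- # Wrap the fresh model first so its key names match, then load
- loaded_model = nn.DataParallel(nn.Linear(4, 2))
- loaded_model.load_state_dict(loaded_dict)
-
- # A plain model can be recovered afterwards if needed
- plain_model = loaded_model.module
- </pre></div>
- </div>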
- <p><strong>Iterating over the dict to remove "module"</strong></p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
-
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0'</span> <span class="c1"># replace with the GPU id you want to use</span>
-
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
-
- <span class="n">new_state_dict</span> <span class="o">=</span> <span class="n">OrderedDict</span><span class="p">()</span>
- <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">loaded_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
-     <span class="n">name</span> <span class="o">=</span> <span class="n">k</span><span class="p">[</span><span class="mi">7</span><span class="p">:]</span> <span class="c1"># "module." is a 7-character prefix, so slicing from index 7 removes it</span>
-     <span class="n">new_state_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span> <span class="c1"># keep each value under its renamed key</span>
-
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined here</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">new_state_dict</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">loaded_model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
- </pre></div>
- </div>
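- <p>The slice k[7:] assumes that every key starts with "module.". If the checkpoint actually came from an unwrapped model, the slice silently cuts the first seven characters off every key. A slightly more defensive variant strips the prefix only where it is present. The sketch below uses a hypothetical helper named strip_module_prefix and nn.Linear as an assumed stand-in model:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>from collections import OrderedDict
-
- import torch
- from torch import nn
-
-
- def strip_module_prefix(state_dict, prefix="module."):
-     """Hypothetical helper: return a copy with a leading prefix removed from each key."""
-     new_state_dict = OrderedDict()
-     for k, v in state_dict.items():
-         name = k[len(prefix):] if k.startswith(prefix) else k
-         new_state_dict[name] = v
-     return new_state_dict
-
-
- wrapped = nn.DataParallel(nn.Linear(4, 2))  # stand-in for a multi-GPU model
- plain_dict = strip_module_prefix(wrapped.state_dict())
- print(list(plain_dict.keys()))  # ['weight', 'bias']
-
- # Applying it to an already plain dict leaves the keys unchanged
- print(list(strip_module_prefix(plain_dict).keys()))  # ['weight', 'bias']
-
- model = nn.Linear(4, 2)
- model.load_state_dict(plain_dict)
- </pre></div>
- </div>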
- <p><strong>Removing "module" with a replace operation</strong></p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined here</span>
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">({</span><span class="n">k</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">'module.'</span><span class="p">,</span> <span class="s1">''</span><span class="p">):</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">loaded_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()})</span>
- </pre></div>
- </div>
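- <p>Note that str.replace removes "module." wherever it occurs in a key, not only at the start. A submodule whose name ends in "module" (for example one called submodule) therefore gets mangled. PyTorch also provides torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which strips only a leading prefix, in place. The sketch below assumes a PyTorch version recent enough to include that function, and the Net class is a hypothetical example model.</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>from torch import nn
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
-
-
- class Net(nn.Module):
-     """Hypothetical model with a submodule whose name contains "module"."""
-
-     def __init__(self):
-         super().__init__()
-         self.submodule = nn.Linear(4, 2)
-
-     def forward(self, x):
-         return self.submodule(x)
-
-
- loaded_dict = nn.DataParallel(Net()).state_dict()
- print(list(loaded_dict.keys()))  # ['module.submodule.weight', 'module.submodule.bias']
-
- # str.replace also rewrites the "module." inside "submodule."
- replaced = {k.replace('module.', ''): v for k, v in loaded_dict.items()}
- print(list(replaced.keys()))  # ['subweight', 'subbias']
-
- # consume_prefix_in_state_dict_if_present strips only the leading prefix, in place
- consume_prefix_in_state_dict_if_present(loaded_dict, 'module.')
- print(sorted(loaded_dict.keys()))  # ['submodule.bias', 'submodule.weight']
-
- model = Net()
- model.load_state_dict(loaded_dict)
- </pre></div>
- </div>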
- <ul class="simple">
- <li><p><strong>Multi-GPU save + multi-GPU load</strong></p></li>
- </ul>
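- <p>Because both saving and loading use multiple GPUs, the layer names carry the same prefix and there is no "module" mismatch. Multi-GPU mode does, however, raise a device-matching problem (which GPUs are used). When the <strong>whole model</strong> is saved, information such as the GPU ids it was using is saved with it. If this information does not match the GPUs available at load time, the program may raise an error or not behave as intended. This shows up in two ways:</p>
- <p><strong>Loading the whole model and then wrapping it with nn.DataParallel for multi-GPU training</strong></p>
- <p>In this case the GPU ids stored in the saved model are likely to differ from the GPU ids set in the loading environment. During training, the data and the model then end up on different devices, and an error is raised.</p>
- <p><strong>Loading the whole model without wrapping it in nn.DataParallel</strong></p>
- <p>This case may not raise an error. In our tests, the program automatically trained on the first n GPUs of the machine, where n is the number of GPUs the saved model used; if fewer than n GPUs are specified, an error is raised. The program trains on the intended GPUs only when the device ids at load time are the same as those at save time.</p>
- <p>By contrast, loading the model weights and then wrapping the model with nn.DataParallel works without problems. Therefore, <strong>in multi-GPU mode it is recommended to save and load models as weights</strong>:</p>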
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">models</span>
-
- <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'CUDA_VISIBLE_DEVICES'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'0,1,2'</span> <span class="c1"># replace with the GPU ids you want to use</span>
-
- <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
-
- <span class="c1"># Save + load the model weights (strongly recommended!)</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined here</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_dict</span><span class="p">)</span>
- </pre></div>
- </div>
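- <p>The device mismatch described above can also be handled at load time. torch.load accepts a map_location argument that remaps the device of every stored tensor, for example to 'cpu' or 'cuda:0'. Here is a CPU-only sketch, with nn.Linear as an assumed stand-in for resnet152 and an in-memory buffer in place of save_dir:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import io
-
- import torch
- from torch import nn
-
- model = nn.DataParallel(nn.Linear(4, 2))  # stand-in for the multi-GPU resnet152 (assumption)
- buffer = io.BytesIO()  # in-memory stand-in for save_dir
- torch.save(model.state_dict(), buffer)
- buffer.seek(0)
-
- # map_location='cpu' puts every tensor on the CPU regardless of the device it was saved from;
- # a string such as 'cuda:0' or a torch.device would target a specific GPU instead
- loaded_dict = torch.load(buffer, map_location='cpu')
-
- loaded_model = nn.DataParallel(nn.Linear(4, 2))
- loaded_model.load_state_dict(loaded_dict)
- if torch.cuda.is_available():
-     loaded_model = loaded_model.cuda()  # move to the GPUs visible in the current environment
- </pre></div>
- </div>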
- <p>If only the whole model was saved, you can still build a new model by extracting its weights:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Load the whole model</span>
- <span class="n">loaded_whole_model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">save_dir</span><span class="p">)</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet152</span><span class="p">()</span> <span class="c1"># the model architecture must be defined here</span>
- <span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_whole_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span> <span class="c1"># .module: the saved model is wrapped in DataParallel</span>
- <span class="n">loaded_model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
- </pre></div>
- </div>
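- <p>Below is a runnable sketch of extracting weights from a saved whole model, under the same assumptions (CPU only, nn.Linear in place of resnet152, an in-memory buffer). Two points to keep in mind. First, a whole model saved from a DataParallel wrapper has "module."-prefixed keys, so .module.state_dict() is what fits an unwrapped model. Second, as far as we know, torch.load defaults to weights_only=True since PyTorch 2.6, so unpickling a whole model needs weights_only=False explicitly; only do this for files you trust.</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import io
-
- import torch
- from torch import nn
-
- # A whole DataParallel model saved with torch.save(model, ...)
- whole_model = nn.DataParallel(nn.Linear(4, 2))  # stand-in for the multi-GPU resnet152 (assumption)
- buffer = io.BytesIO()  # in-memory stand-in for save_dir
- torch.save(whole_model, buffer)
- buffer.seek(0)
-
- # Unpickling a full nn.Module requires weights_only=False on recent PyTorch versions
- loaded_whole_model = torch.load(buffer, weights_only=False)
-
- # Take the inner module's weights (no "module." prefix) and load them into a fresh model
- loaded_model = nn.Linear(4, 2)
- loaded_model.load_state_dict(loaded_whole_model.module.state_dict())
- loaded_model = nn.DataParallel(loaded_model)
- </pre></div>
- </div>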
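- <p>Note that assigning to state_dict, as in loaded_model.state_dict = loaded_dict, does not load any weights. It only replaces the model's state_dict method with the dict, and the parameters keep their previous values. In PyTorch, weights must be loaded with the "load_state_dict" method:</p>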
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">loaded_model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">loaded_dict</span><span class="p">)</span>
- </pre></div>
- </div>
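- <p>The difference is easy to verify. In the sketch below, nn.Linear is again an assumed stand-in. Assigning a dict to state_dict leaves the parameters untouched and shadows the method, while load_state_dict actually copies the values:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
- from torch import nn
-
- source = nn.Linear(4, 2)
- with torch.no_grad():
-     source.weight.fill_(1.0)  # give the source model recognizable values
-     source.bias.fill_(1.0)
- loaded_dict = source.state_dict()
-
- # Wrong: this only stores the dict as an attribute named state_dict
- assigned = nn.Linear(4, 2)
- with torch.no_grad():
-     assigned.weight.zero_()
- assigned.state_dict = loaded_dict
- print(torch.equal(assigned.weight, source.weight))  # False: the weights were not loaded
- print(callable(assigned.state_dict))                # False: the method is now shadowed by a dict
-
- # Right: load_state_dict copies the values into the existing parameters
- loaded = nn.Linear(4, 2)
- loaded.load_state_dict(loaded_dict)
- print(torch.equal(loaded.weight, source.weight))    # True
- </pre></div>
- </div>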
- </section>
- <section id="id5">
- <h2>Appendix: Test Environment<a class="headerlink" href="#id5" title="永久链接至标题">#</a></h2>
- <p>OS: Ubuntu 20.04 LTS<br />GPU: GeForce RTX 2080 Ti (x3)</p>
- </section>
- <section id="id6">
- <h2>References<a class="headerlink" href="#id6" title="永久链接至标题">#</a></h2>
- <p>This chapter is also published on <a class="reference external" href="https://zhuanlan.zhihu.com/p/371090724">Zhihu</a> and <a class="reference external" href="https://blog.csdn.net/goodljq/article/details/117258032">CSDN</a>.</p>
- <p>[1] <a class="reference external" href="https://www.zhihu.com/question/274533811">What is the difference between pkl and pth in PyTorch? (Zhihu, in Chinese)</a><br />
- [2] <a class="reference external" href="https://stackoverflow.com/questions/59095824/what-is-the-difference-between-pt-pth-and-pwf-extentions-in-pytorch">What is the difference between .pt, .pth and .pwf extentions in PyTorch?</a></p>
- </section>
- </section>
-
-
- </div>
-
- </main>
- <footer class="footer-article noprint">
-
- <!-- Previous / next buttons -->
- <div class='prev-next-area'>
- <a class='left-prev' id="prev-link" href="5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html" title="上一页 页">
- <i class="fas fa-angle-left"></i>
- <div class="prev-next-info">
- <p class="prev-next-subtitle">Previous</p>
- <p class="prev-next-title">5.3 Modifying PyTorch Models</p>
- </div>
- </a>
- <a class='right-next' id="next-link" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html" title="下一页 页">
- <div class="prev-next-info">
- <p class="prev-next-subtitle">Next</p>
- <p class="prev-next-title">Chapter 6: Advanced PyTorch Training Techniques</p>
- </div>
- <i class="fas fa-angle-right"></i>
- </a>
- </div>
- </footer>
- </div>
- </div>
- <div class="footer-content row">
- <footer class="col footer"><p>
-
- By ZhikangNiu<br/>
-
- © Copyright 2022, ZhikangNiu.<br/>
- </p>
- </footer>
- </div>
-
- </div>
-
-
- </div>
- </div>
-
- <!-- Scripts loaded after <body> so the DOM is not blocked -->
- <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>
-
-
- </body>
- </html>
|