|
-
- <!DOCTYPE html>
-
- <html lang="zh">
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
-
- <title>3.9 PyTorch Optimizers - Thorough PyTorch</title>
-
- <!-- Loaded before other Sphinx assets -->
- <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
- <link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
-
-
- <link rel="stylesheet"
- href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
-
- <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
- <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
- <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
- <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
- <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
-
- <!-- Pre-loaded scripts that we'll load fully later -->
- <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">
-
- <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
- <script src="../_static/jquery.js"></script>
- <script src="../_static/underscore.js"></script>
- <script src="../_static/doctools.js"></script>
- <script>let toggleHintShow = 'Click to show';</script>
- <script>let toggleHintHide = 'Click to hide';</script>
- <script>let toggleOpenOnPrint = 'true';</script>
- <script src="../_static/togglebutton.js"></script>
- <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
- <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
- <link rel="index" title="Index" href="../genindex.html" />
- <link rel="search" title="Search" href="../search.html" />
- <link rel="next" title="Chapter 4: PyTorch Basics in Practice" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html" />
- <link rel="prev" title="3.8 Visualization" href="3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html" />
- <meta name="docsearch:language" content="zh">
-
-
- <!-- Google Analytics -->
-
- </head>
- <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
- <!-- Checkboxes to toggle the left sidebar -->
- <input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
- <label class="overlay overlay-navbar" for="__navigation">
- <div class="visually-hidden">Toggle navigation sidebar</div>
- </label>
- <!-- Checkboxes to toggle the in-page toc -->
- <input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
- <label class="overlay overlay-pagetoc" for="__page-toc">
- <div class="visually-hidden">Toggle in-page Table of Contents</div>
- </label>
- <!-- Headers at the top -->
- <div class="announcement header-item noprint"></div>
- <div class="header header-item noprint"></div>
-
-
- <div class="container-fluid" id="banner"></div>
-
-
-
- <div class="container-xl">
- <div class="row">
-
- <!-- Sidebar -->
-
-
-
-
-
-
- <!-- A tiny helper pixel to detect if we've scrolled -->
- <div class="sbt-scroll-pixel-helper"></div>
- <!-- Main content -->
- <div class="col py-0 content-container">
-
- <div class="header-article row sticky-top noprint">
-
-
-
-
-
- <!-- Table of contents -->
- <div class="col-md-3 bd-toc show noprint">
- <div class="tocsection onthispage pt-5 pb-3">
- <i class="fas fa-list"></i> Contents
- </div>
- <nav id="bd-toc-nav" aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id1">
- 3.9.1 Optimizers Provided by PyTorch
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id2">
- 3.9.2 Hands-On Practice
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id3">
- 3.9.3 Output
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#id4">
- 3.9.4 Experiments
- </a>
- </li>
- </ul>
-
- </nav>
- </div>
- </div>
- <div class="article row">
- <div class="col pl-md-3 pl-lg-5 content-container">
- <main id="main-content" role="main">
-
- <div>
-
- <section class="tex2jax_ignore mathjax_ignore" id="pytorch">
- <h1>3.9 PyTorch Optimizers<a class="headerlink" href="#pytorch" title="Permalink to this heading">#</a></h1>
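- <p>The goal of deep learning is to keep adjusting a network's parameters until the nonlinear transformations they apply to the input fit the desired output. In essence this is a search for the optimal solution of a function, except that the solution consists of matrices (tensors) rather than a single number, and finding it quickly is a central topic in deep learning research. Take the classic ResNet-50: it has roughly 25 million parameters to determine. How can we compute that many parameters? There are two approaches:</p>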
- <ol class="simple">
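- <li><p>The first is brute-force search: enumerate every possible combination of parameter values. This is essentially infeasible, a task even more hopeless than the fabled Foolish Old Man moving the mountains.</p></li>
- <li><p>To find the parameters much faster, a second approach was developed: backpropagation (BP) combined with an optimizer that approaches the solution iteratively.</p></li>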
- </ol>
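- <p>An optimizer, then, uses the gradients computed by backpropagation to update the network's parameters, so that the value of the loss function decreases and the model's outputs move closer to the ground-truth labels.</p>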
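- <p>To make the idea concrete, here is a minimal sketch built on a hypothetical one-parameter problem (fitting y = 3x with a single weight <code class="docutils literal notranslate"><span class="pre">w</span></code>). It performs one step of plain gradient descent, w &larr; w - lr &middot; dL/dw, first by hand and then with <code class="docutils literal notranslate"><span class="pre">torch.optim.SGD</span></code>. Both paths should yield the same updated parameter. The data, loss function and learning rate are illustrative assumptions.</p>

```python
import torch

# Hypothetical toy problem: fit y = 3x with a single weight w.
x = torch.tensor([1.0, 2.0, 3.0])
y = 3 * x

def loss_fn(w):
    # Mean squared error between w * x and the target y.
    return ((w * x - y) ** 2).mean()

lr = 0.05

# 1) One manual gradient-descent step: w <- w - lr * dL/dw
w_manual = torch.tensor(0.0, requires_grad=True)
loss_fn(w_manual).backward()
with torch.no_grad():  # the update itself must not be tracked by autograd
    w_manual -= lr * w_manual.grad

# 2) The same step, carried out by an optimizer
w_opt = torch.tensor(0.0, requires_grad=True)
optimizer = torch.optim.SGD([w_opt], lr=lr)
optimizer.zero_grad()
loss_fn(w_opt).backward()
optimizer.step()

print(w_manual.item(), w_opt.item())  # both about 1.4
```

- <p>Here the gradient at w = 0 is -28, so one step with lr = 0.05 moves w to 1.4. The optimizer simply packages this update rule, and more elaborate variants of it, behind <code class="docutils literal notranslate"><span class="pre">step()</span></code>.</p>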
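- <p>After working through this section, you will:</p>
- <ul class="simple">
- <li><p>understand PyTorch's optimizers</p></li>
- <li><p>know how to optimize a model with the optimizers PyTorch provides</p></li>
- <li><p>understand the attributes and structure of an optimizer</p></li>
- <li><p>be able to compare different optimizers</p></li>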
- </ul>
- <section id="id1">
- <h2>3.9.1 Optimizers Provided by PyTorch<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h2>
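- <p>PyTorch conveniently provides an optimizer library, <code class="docutils literal notranslate"><span class="pre">torch.optim</span></code>, which includes the following eleven optimizers (the exact set depends on the PyTorch version):</p>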
- <ul class="simple">
- <li><p>torch.optim.ASGD</p></li>
- <li><p>torch.optim.Adadelta</p></li>
- <li><p>torch.optim.Adagrad</p></li>
- <li><p>torch.optim.Adam</p></li>
- <li><p>torch.optim.AdamW</p></li>
- <li><p>torch.optim.Adamax</p></li>
- <li><p>torch.optim.LBFGS</p></li>
- <li><p>torch.optim.RMSprop</p></li>
- <li><p>torch.optim.Rprop</p></li>
- <li><p>torch.optim.SGD</p></li>
- <li><p>torch.optim.SparseAdam</p></li>
- </ul>
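- <p>Because the list changes between releases (newer versions also ship optimizers such as NAdam and RAdam), it can help to check what your installed version actually provides. The sketch below assumes only that every optimizer is a public class in <code class="docutils literal notranslate"><span class="pre">torch.optim</span></code> that subclasses <code class="docutils literal notranslate"><span class="pre">Optimizer</span></code>.</p>

```python
import torch
from torch.optim import Optimizer

# Collect every public class in torch.optim that subclasses Optimizer
# (excluding the base class itself).
optimizer_names = sorted(
    name
    for name, obj in vars(torch.optim).items()
    if not name.startswith("_")
    and isinstance(obj, type)
    and issubclass(obj, Optimizer)
    and obj is not Optimizer
)
print(optimizer_names)
```
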
- <p>All of these optimizers inherit from <code class="docutils literal notranslate"><span class="pre">Optimizer</span></code>, the base class of every optimizer. Let's look at it first. Its definition, in simplified form, is as follows:</p>
- <div class="highlight-Python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span>
-
- <span class="k">class</span> <span class="nc">Optimizer</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
-     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">):</span>
-         <span class="bp">self</span><span class="o">.</span><span class="n">defaults</span> <span class="o">=</span> <span class="n">defaults</span>  <span class="c1"># hyperparameter defaults, e.g. lr, momentum</span>
-         <span class="bp">self</span><span class="o">.</span><span class="n">state</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">dict</span><span class="p">)</span>  <span class="c1"># per-parameter state, e.g. momentum buffers</span>
-         <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span> <span class="o">=</span> <span class="p">[]</span>  <span class="c1"># list of parameter-group dicts</span>
-         <span class="c1"># The full implementation then registers each group via self.add_param_group()</span>
- </pre></div>
- </div>
- <p><strong><code class="docutils literal notranslate"><span class="pre">Optimizer</span></code> has three attributes:</strong></p>
- <ul class="simple">
- <li><p><code class="docutils literal notranslate"><span class="pre">defaults</span></code>: the optimizer's hyperparameters, for example:</p></li>
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="p">{</span><span class="s1">'lr'</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span> <span class="s1">'momentum'</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">,</span> <span class="s1">'dampening'</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">'weight_decay'</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">'nesterov'</span><span class="p">:</span> <span class="kc">False</span><span class="p">}</span>
- </pre></div>
- </div>
- <ul class="simple">
- <li><p><code class="docutils literal notranslate"><span class="pre">state</span></code>: cached per-parameter state (such as momentum buffers), keyed by the parameter tensor, for example:</p></li>
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>defaultdict(&lt;class 'dict'&gt;, {tensor([[ 0.3864, -0.0131],
-         [-0.1911, -0.4511]], requires_grad=True): {'momentum_buffer': tensor([[0.0052, 0.0052],
-         [0.0052, 0.0052]])}})
- </pre></div>
- </div>
- <ul class="simple">
- <li><p><code class="docutils literal notranslate"><span class="pre">param_groups</span></code>: the parameter groups being managed. It is a list whose elements are dicts. For SGD, each dict holds the keys params, lr, momentum, dampening, weight_decay and nesterov, for example:</p></li>
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="p">[{</span><span class="s1">'params'</span><span class="p">:</span> <span class="p">[</span><span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">0.1022</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.6890</span><span class="p">],[</span><span class="o">-</span><span class="mf">1.5116</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.7846</span><span class="p">]],</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)],</span> <span class="s1">'lr'</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">'momentum'</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">'dampening'</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">'weight_decay'</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s1">'nesterov'</span><span class="p">:</span> <span class="kc">False</span><span class="p">}]</span>
- </pre></div>
- </div>
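- <p>The three attributes can be inspected directly. The following is a minimal sketch assuming a single random 2x2 weight and SGD with momentum, mirroring the examples above. Instead of running a real backward pass, it assigns a hand-made gradient, which is an illustrative shortcut. The exact keys printed for <code class="docutils literal notranslate"><span class="pre">defaults</span></code> vary across PyTorch versions.</p>

```python
import torch

torch.manual_seed(0)
# A single 2x2 weight tensor, like the examples above (its values are random).
weight = torch.randn(2, 2, requires_grad=True)
optimizer = torch.optim.SGD([weight], lr=0.1, momentum=0.9)

print(optimizer.defaults)            # hyperparameters (exact keys vary by version)
print(len(optimizer.state))          # no per-parameter state before the first step
print(len(optimizer.param_groups))   # one group holding [weight]

# Pretend backward() produced a gradient of all ones, then take one step.
weight.grad = torch.ones(2, 2)
optimizer.step()

# SGD with momentum now caches a momentum buffer for this parameter.
print(optimizer.state[weight]["momentum_buffer"])
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[0]["momentum"])
```

- <p>Note that <code class="docutils literal notranslate"><span class="pre">state</span></code> stays empty until the first <code class="docutils literal notranslate"><span class="pre">step()</span></code>. Optimizers create their buffers lazily, the first time they update a parameter.</p>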
- <p><strong><code class="docutils literal notranslate"><span class="pre">Optimizer</span></code> also provides the following methods:</strong></p>
- <ul class="simple">
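- <li><p><code class="docutils literal notranslate"><span class="pre">zero_grad()</span></code>: clears the gradients of all managed parameters. PyTorch does not reset gradients automatically. Each call to backward() adds to the existing .grad, so gradients must be cleared before every new backward pass. The listing below follows an older PyTorch release; since PyTorch 2.0 the default is <code class="docutils literal notranslate"><span class="pre">set_to_none=True</span></code>, which sets gradients to None instead of filling them with zeros.</p></li>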
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">zero_grad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">set_to_none</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
-     <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
-         <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]:</span>
-             <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>  <span class="c1"># skip parameters that have no gradient yet</span>
-                 <span class="k">if</span> <span class="n">set_to_none</span><span class="p">:</span>
-                     <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="kc">None</span>  <span class="c1"># drop the gradient tensor entirely</span>
-                 <span class="k">else</span><span class="p">:</span>
-                     <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">grad_fn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-                         <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">detach_</span><span class="p">()</span>
-                     <span class="k">else</span><span class="p">:</span>
-                         <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
-                     <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span>  <span class="c1"># set the gradient to zero in place</span>
- </pre></div>
- </div>
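- <p>The accumulation behaviour that makes <code class="docutils literal notranslate"><span class="pre">zero_grad()</span></code> necessary is easy to observe. This minimal sketch uses a hypothetical scalar parameter and the loss w&sup2;. It calls backward() twice without clearing, then clears with an explicit <code class="docutils literal notranslate"><span class="pre">set_to_none=False</span></code>, so the result does not depend on the version's default.</p>

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without clearing: the gradients add up.
(w ** 2).backward()          # d(w^2)/dw = 2w = 4
first = w.grad.item()
(w ** 2).backward()          # adds another 4
accumulated = w.grad.item()

# Clear before the next iteration; set_to_none=False keeps a zero tensor.
optimizer.zero_grad(set_to_none=False)
cleared = w.grad.item()

print(first, accumulated, cleared)  # 4.0 8.0 0.0
```

- <p>In an ordinary training loop the order is therefore <code class="docutils literal notranslate"><span class="pre">optimizer.zero_grad()</span></code>, forward pass, <code class="docutils literal notranslate"><span class="pre">loss.backward()</span></code>, <code class="docutils literal notranslate"><span class="pre">optimizer.step()</span></code>. Skipping the clear on purpose is how gradient accumulation over several mini-batches is done.</p>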
- <ul class="simple">
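- <li><p><code class="docutils literal notranslate"><span class="pre">step()</span></code>: performs a single optimization step, i.e. updates the parameters using their current gradients. The base class leaves it unimplemented, and each concrete optimizer defines its own update rule. The optional <code class="docutils literal notranslate"><span class="pre">closure</span></code> argument is a function that re-evaluates the model and returns the loss; some algorithms, such as LBFGS, require it.</p></li>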
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="p">):</span>
-     <span class="k">raise</span> <span class="ne">NotImplementedError</span>  <span class="c1"># each concrete optimizer overrides step()</span>
- </pre></div>
- </div>
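- <p>To see how a subclass fills in <code class="docutils literal notranslate"><span class="pre">step()</span></code>, here is a sketch of a hypothetical optimizer, <code class="docutils literal notranslate"><span class="pre">PlainSGD</span></code>. It is not part of PyTorch. It applies p &larr; p - lr &middot; grad to every parameter in every group and supports an optional closure. It is a simplified illustration, not the implementation of <code class="docutils literal notranslate"><span class="pre">torch.optim.SGD</span></code>.</p>

```python
import torch
from torch.optim import Optimizer

class PlainSGD(Optimizer):
    """Hypothetical minimal optimizer: p <- p - lr * p.grad."""

    def __init__(self, params, lr=0.01):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # The base class builds param_groups and fills in missing hyperparameters.
        super().__init__(params, defaults={"lr": lr})

    @torch.no_grad()  # parameter updates must not be recorded by autograd
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # The closure recomputes the loss, so gradients must be enabled for it.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group["lr"])  # in-place update
        return loss

# Usage: one step without a closure ...
w = torch.tensor([1.0, 2.0], requires_grad=True)
opt = PlainSGD([w], lr=0.1)
opt.zero_grad()
(w ** 2).sum().backward()        # grad = 2w = [2, 4]
opt.step()                       # w = [0.8, 1.6]
after_first = w.detach().clone()

# ... and one step with a closure that recomputes the loss.
def closure():
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    return loss

loss = opt.step(closure)         # loss evaluated at [0.8, 1.6] is 3.2
print(after_first, w, loss.item())
```
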
- <ul class="simple">
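- <li><p><code class="docutils literal notranslate"><span class="pre">add_param_group()</span></code>: adds a parameter group; any hyperparameters missing from the new group are filled in from <code class="docutils literal notranslate"><span class="pre">defaults</span></code>.</p></li>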
- </ul>
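- <p>A common use of <code class="docutils literal notranslate"><span class="pre">add_param_group()</span></code> is to hand a newly added or newly unfrozen part of a model its own hyperparameters. Here is a minimal sketch of that usage. The two-part model (<code class="docutils literal notranslate"><span class="pre">backbone</span></code> and <code class="docutils literal notranslate"><span class="pre">head</span></code>) and the learning rates are hypothetical.</p>

```python
import torch
from torch import nn

# Hypothetical two-part model: a "backbone" and a "head".
backbone = nn.Linear(4, 8)
head = nn.Linear(8, 2)

# Start by optimizing only the head.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

# Later, add the backbone as a second group with a smaller learning rate.
# Hyperparameters not given here (e.g. momentum) are taken from optimizer.defaults.
optimizer.add_param_group({"params": backbone.parameters(), "lr": 0.01})

for i, group in enumerate(optimizer.param_groups):
    # Each nn.Linear contributes two tensors: weight and bias.
    print(i, group["lr"], group["momentum"], len(group["params"]))
# 0 0.1 0.9 2
# 1 0.01 0.9 2
```
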
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">add_param_group</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param_group</span><span class="p">):</span>
-     <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param_group</span><span class="p">,</span> <span class="nb">dict</span><span class="p">),</span> <span class="s2">"param group must be a dict"</span>
-     <span class="c1"># Check that params is a Tensor or an ordered collection of Tensors</span>
-     <span class="n">params</span> <span class="o">=</span> <span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span>
-     <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
-         <span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">params</span><span class="p">]</span>
-     <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="nb">set</span><span class="p">):</span>
-         <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">'optimizer parameters need to be organized in ordered collections, but '</span>
-                         <span class="s1">'the ordering of tensors in sets will change between runs. Please use a list instead.'</span><span class="p">)</span>
-     <span class="k">else</span><span class="p">:</span>
-         <span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
-     <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]:</span>
-         <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
-             <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"optimizer can only optimize Tensors, "</span>
-                             <span class="s2">"but one of the params is "</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">typename</span><span class="p">(</span><span class="n">param</span><span class="p">))</span>
-         <span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">is_leaf</span><span class="p">:</span>
-             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"can't optimize a non-leaf Tensor"</span><span class="p">)</span>
-
-     <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">default</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">defaults</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
-         <span class="k">if</span> <span class="n">default</span> <span class="ow">is</span> <span class="n">required</span> <span class="ow">and</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">param_group</span><span class="p">:</span>
-             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"parameter group didn't specify a value of required optimization parameter "</span> <span class="o">+</span>
-                              <span class="n">name</span><span class="p">)</span>
-         <span class="k">else</span><span class="p">:</span>
-             <span class="n">param_group</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">default</span><span class="p">)</span>
-
-     <span class="n">params</span> <span class="o">=</span> <span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span>
-     <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">params</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">params</span><span class="p">)):</span>
- <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"optimizer contains a parameter group with duplicate parameters; "</span>
- <span class="s2">"in future, this will cause an error; "</span>
- <span class="s2">"see github.com/pytorch/pytorch/issues/40967 for more information"</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
- <span class="c1"># The checks above validate the group's contents and raise warnings or errors on bad input</span>
- <span class="n">param_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
- <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
- <span class="n">param_set</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]))</span>
-
- <span class="k">if</span> <span class="ow">not</span> <span class="n">param_set</span><span class="o">.</span><span class="n">isdisjoint</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">param_group</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])):</span>
- <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"some parameters appear in more than one parameter group"</span><span class="p">)</span>
- <span class="c1"># Append the validated parameter group</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param_group</span><span class="p">)</span>
- </pre></div>
- </div>
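<p>The checks above can be observed directly. The sketch below assumes a recent PyTorch release, and the tensors <code>w1</code> and <code>w2</code> are hypothetical. It shows four behaviors: a bare tensor is wrapped into a list, options the group omits are filled in from <code>defaults</code>, a <code>set</code> is rejected with <code>TypeError</code>, and a tensor that already belongs to a group is rejected with <code>ValueError</code>.</p>

```python
import torch

torch.manual_seed(0)
w1 = torch.randn(2, 2, requires_grad=True)  # hypothetical parameters
w2 = torch.randn(3, requires_grad=True)
opt = torch.optim.SGD([w1], lr=0.1, momentum=0.9)

# A bare tensor is wrapped into a one-element list; momentum is taken from defaults
opt.add_param_group({"params": w2, "lr": 0.01})
group = opt.param_groups[1]
print(type(group["params"]).__name__, len(group["params"]))  # list 1
print(group["lr"], group["momentum"])                         # 0.01 0.9

# Sets are rejected because their iteration order is not stable between runs
try:
    opt.add_param_group({"params": {torch.randn(1, requires_grad=True)}})
    set_rejected = False
except TypeError:
    set_rejected = True

# A tensor may belong to only one parameter group
try:
    opt.add_param_group({"params": [w1]})
    dup_rejected = False
except ValueError:
    dup_rejected = True

print(set_rejected, dup_rejected, len(opt.param_groups))  # True True 2
```
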
- <ul class="simple">
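- <li><p><code class="docutils literal notranslate"><span class="pre">load_state_dict()</span></code>: loads a previously saved state dictionary. Used together with the model's own state dict, it lets you resume training from a checkpoint. Training then continues with the optimizer state (such as momentum buffers) and hyperparameters from the previous run.</p></li>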
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">load_state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">):</span>
- <span class="sa">r</span><span class="sd">"""Loads the optimizer state.</span>
-
- <span class="sd"> Arguments:</span>
- <span class="sd"> state_dict (dict): optimizer state. Should be an object returned</span>
- <span class="sd"> from a call to :meth:`state_dict`.</span>
- <span class="sd"> """</span>
- <span class="c1"># deepcopy, to be consistent with module API</span>
- <span class="n">state_dict</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span>
- <span class="c1"># Validate the state_dict</span>
- <span class="n">groups</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span>
- <span class="n">saved_groups</span> <span class="o">=</span> <span class="n">state_dict</span><span class="p">[</span><span class="s1">'param_groups'</span><span class="p">]</span>
-
- <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">groups</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="n">saved_groups</span><span class="p">):</span>
- <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"loaded state dict has a different number of "</span>
- <span class="s2">"parameter groups"</span><span class="p">)</span>
- <span class="n">param_lens</span> <span class="o">=</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">g</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">groups</span><span class="p">)</span>
- <span class="n">saved_lens</span> <span class="o">=</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">g</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">saved_groups</span><span class="p">)</span>
- <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">p_len</span> <span class="o">!=</span> <span class="n">s_len</span> <span class="k">for</span> <span class="n">p_len</span><span class="p">,</span> <span class="n">s_len</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">param_lens</span><span class="p">,</span> <span class="n">saved_lens</span><span class="p">)):</span>
- <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"loaded state dict contains a parameter group "</span>
- <span class="s2">"that doesn't match the size of optimizer's group"</span><span class="p">)</span>
-
- <span class="c1"># Update the state</span>
- <span class="n">id_map</span> <span class="o">=</span> <span class="p">{</span><span class="n">old_id</span><span class="p">:</span> <span class="n">p</span> <span class="k">for</span> <span class="n">old_id</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span>
- <span class="nb">zip</span><span class="p">(</span><span class="n">chain</span><span class="o">.</span><span class="n">from_iterable</span><span class="p">((</span><span class="n">g</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">saved_groups</span><span class="p">)),</span>
- <span class="n">chain</span><span class="o">.</span><span class="n">from_iterable</span><span class="p">((</span><span class="n">g</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="n">groups</span><span class="p">)))}</span>
-
- <span class="k">def</span> <span class="nf">cast</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
- <span class="sa">r</span><span class="sd">"""Make a deep copy of value, casting all tensors to device of param."""</span>
- <span class="o">.....</span>
-
- <span class="c1"># Copy state assigned to params (and cast tensors to appropriate types).</span>
- <span class="c1"># State that is not assigned to params is copied as is (needed for</span>
- <span class="c1"># backward compatibility).</span>
- <span class="n">state</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">dict</span><span class="p">)</span>
- <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">state_dict</span><span class="p">[</span><span class="s1">'state'</span><span class="p">]</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
- <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">id_map</span><span class="p">:</span>
- <span class="n">param</span> <span class="o">=</span> <span class="n">id_map</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
- <span class="n">state</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">state</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
-
- <span class="c1"># Update parameter groups, setting their 'params' value</span>
- <span class="k">def</span> <span class="nf">update_group</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">new_group</span><span class="p">):</span>
- <span class="o">...</span>
- <span class="n">param_groups</span> <span class="o">=</span> <span class="p">[</span>
- <span class="n">update_group</span><span class="p">(</span><span class="n">g</span><span class="p">,</span> <span class="n">ng</span><span class="p">)</span> <span class="k">for</span> <span class="n">g</span><span class="p">,</span> <span class="n">ng</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">groups</span><span class="p">,</span> <span class="n">saved_groups</span><span class="p">)]</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">__setstate__</span><span class="p">({</span><span class="s1">'state'</span><span class="p">:</span> <span class="n">state</span><span class="p">,</span> <span class="s1">'param_groups'</span><span class="p">:</span> <span class="n">param_groups</span><span class="p">})</span>
- </pre></div>
- </div>
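<p>The following usage sketch applies <code>load_state_dict()</code> to resume training. It saves a model's state and an optimizer's state together, then restores both into freshly built objects. Several details are illustrative assumptions: the tiny <code>nn.Linear</code> model, the checkpoint keys <code>"model"</code>, <code>"optimizer"</code> and <code>"epoch"</code>, and the in-memory <code>io.BytesIO</code> buffer, which stands in for a file path.</p>

```python
import io

import torch
from torch import nn

# Hypothetical tiny model used only for illustration
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so that SGD creates its momentum buffers
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save model and optimizer state together (the buffer stands in for a file)
buffer = io.BytesIO()
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": 1}, buffer)

# Later: rebuild objects with the same structure, then restore their state
buffer.seek(0)
checkpoint = torch.load(buffer)
new_model = nn.Linear(4, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(checkpoint["model"])
new_optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1  # continue from the next epoch
```

<p>The new optimizer must be built with the same parameter-group layout as the saved one, because <code>load_state_dict()</code> checks both the number of groups and the size of each group.</p>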
- <ul class="simple">
- <li><p><code class="docutils literal notranslate"><span class="pre">state_dict()</span></code>: returns the optimizer's current state as a dictionary</p></li>
- </ul>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">state_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
- <span class="sa">r</span><span class="sd">"""Returns the state of the optimizer as a :class:`dict`.</span>
-
- <span class="sd"> It contains two entries:</span>
-
- <span class="sd"> * state - a dict holding current optimization state. Its content</span>
- <span class="sd"> differs between optimizer classes.</span>
- <span class="sd"> * param_groups - a list containing all parameter groups (each one a dict)</span>
- <span class="sd"> """</span>
- <span class="c1"># Save order indices instead of Tensors</span>
- <span class="n">param_mappings</span> <span class="o">=</span> <span class="p">{}</span>
- <span class="n">start_index</span> <span class="o">=</span> <span class="mi">0</span>
-
- <span class="k">def</span> <span class="nf">pack_group</span><span class="p">(</span><span class="n">group</span><span class="p">):</span>
- <span class="o">......</span>
- <span class="n">param_groups</span> <span class="o">=</span> <span class="p">[</span><span class="n">pack_group</span><span class="p">(</span><span class="n">g</span><span class="p">)</span> <span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">]</span>
- <span class="c1"># Remap state to use order indices as keys</span>
- <span class="n">packed_state</span> <span class="o">=</span> <span class="p">{(</span><span class="n">param_mappings</span><span class="p">[</span><span class="nb">id</span><span class="p">(</span><span class="n">k</span><span class="p">)]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">k</span><span class="p">):</span> <span class="n">v</span>
- <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
- <span class="k">return</span> <span class="p">{</span>
- <span class="s1">'state'</span><span class="p">:</span> <span class="n">packed_state</span><span class="p">,</span>
- <span class="s1">'param_groups'</span><span class="p">:</span> <span class="n">param_groups</span><span class="p">,</span>
- <span class="p">}</span>
- </pre></div>
- </div>
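<p>The sketch below illustrates the "order indices instead of Tensors" comment in <code>state_dict()</code>. It uses a hypothetical tensor <code>w</code>. In the returned dictionary, parameters are referred to by integer positions rather than by the tensors themselves. This is what allows the state to be loaded into a different optimizer instance later.</p>

```python
import torch

w = torch.randn(2, 2, requires_grad=True)  # hypothetical parameter
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
w.grad = torch.ones(2, 2)
opt.step()  # first step: the momentum buffer is initialised to the gradient

sd = opt.state_dict()
print(sorted(sd.keys()))                # ['param_groups', 'state']
print(sd["param_groups"][0]["params"])  # [0]
print(list(sd["state"].keys()))         # [0]
print(sd["state"][0]["momentum_buffer"])
```
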
- </section>
- <section id="id2">
- <h2>3.9.2 Hands-on Example<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
-
- <span class="c1"># Create a 2 x 2 weight drawn from a standard normal distribution</span>
- <span class="n">weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="c1"># Set its gradient to a 2 x 2 matrix of ones</span>
- <span class="n">weight</span><span class="o">.</span><span class="n">grad</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
- <span class="c1"># Print the current data and gradient of weight</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"The data of weight before step:</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"The grad of weight before step:</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">))</span>
- <span class="c1"># Instantiate the optimizer</span>
- <span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">([</span><span class="n">weight</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">)</span>
- <span class="c1"># Take one optimization step</span>
- <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
- <span class="c1"># Inspect the data and gradient after the step</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"The data of weight after step:</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"The grad of weight after step:</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">))</span>
- <span class="c1"># Zero the gradients (the weights themselves are not changed)</span>
- <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
- <span class="c1"># Check that the gradient is now zero</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"The grad of weight after optimizer.zero_grad():</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">grad</span><span class="p">))</span>
- <span class="c1"># Print the parameter groups</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"optimizer.params_group is </span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">))</span>
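- <span class="c1"># Compare ids: the optimizer holds a reference to the same tensor object as weight, not a copy (Python names and containers store references to objects)</span>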
- <span class="nb">print</span><span class="p">(</span><span class="s2">"weight in optimizer:</span><span class="si">{}</span><span class="se">\n</span><span class="s2">weight in weight:</span><span class="si">{}</span><span class="se">\n</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s1">'params'</span><span class="p">][</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">id</span><span class="p">(</span><span class="n">weight</span><span class="p">)))</span>
- <span class="c1"># Add a new parameter group containing weight2</span>
- <span class="n">weight2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">optimizer</span><span class="o">.</span><span class="n">add_param_group</span><span class="p">({</span><span class="s2">"params"</span><span class="p">:</span> <span class="n">weight2</span><span class="p">,</span> <span class="s1">'lr'</span><span class="p">:</span> <span class="mf">0.0001</span><span class="p">,</span> <span class="s1">'nesterov'</span><span class="p">:</span> <span class="kc">True</span><span class="p">})</span>
- <span class="c1"># Inspect the parameter groups again</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"optimizer.param_groups is</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">))</span>
- <span class="c1"># Inspect the current optimizer state</span>
- <span class="n">opt_state_dict</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"state_dict before step:</span><span class="se">\n</span><span class="s2">"</span><span class="p">,</span> <span class="n">opt_state_dict</span><span class="p">)</span>
- <span class="c1"># Take 50 steps (weight2 has no .grad, so SGD skips it)</span>
- <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
-     <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
- <span class="c1"># Print the state after stepping</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"state_dict after step:</span><span class="se">\n</span><span class="s2">"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">())</span>
- <span class="c1"># Save the optimizer state (change the directory to your own path)</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="sa">r</span><span class="s2">"D:\pythonProject\Attention_Unet"</span><span class="p">,</span> <span class="s2">"optimizer_state_dict.pkl"</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"----------done-----------"</span><span class="p">)</span>
- <span class="c1"># Load the saved optimizer state</span>
- <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="sa">r</span><span class="s2">"D:\pythonProject\Attention_Unet\optimizer_state_dict.pkl"</span><span class="p">)</span> <span class="c1"># change to your own path</span>
- <span class="n">optimizer</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"load state_dict successfully</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">state_dict</span><span class="p">))</span>
- <span class="c1"># Print the optimizer's defaults, state and param_groups attributes</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">defaults</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">state</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">))</span>
- </pre></div>
- </div>
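<p>The <code>id()</code> comparison in the code above can be complemented with a direct identity check. This minimal sketch uses a hypothetical tensor <code>w</code> and plain SGD without momentum. It checks that the optimizer keeps the very same tensor object, so <code>step()</code> updates <code>w</code> in place rather than updating a copy.</p>

```python
import torch

w = torch.zeros(2, requires_grad=True)  # hypothetical parameter
opt = torch.optim.SGD([w], lr=1.0)

same_object = opt.param_groups[0]["params"][0] is w
w.grad = torch.ones(2)
opt.step()  # plain SGD: w = w - lr * grad, applied in place

print(same_object)  # True
print(w)            # tensor([-1., -1.], requires_grad=True)
```
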
- </section>
- <section id="id3">
- <h2>3.9.3 Output<a class="headerlink" href="#id3" title="Permalink to this heading">#</a></h2>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Data and gradient before the update
- The data of weight before step:
- tensor([[-0.3077, -0.1808],
- [-0.7462, -1.5556]])
- The grad of weight before step:
- tensor([[1., 1.],
- [1., 1.]])
- # Data and gradient after the update
- The data of weight after step:
- tensor([[-0.4077, -0.2808],
- [-0.8462, -1.6556]])
- The grad of weight after step:
- tensor([[1., 1.],
- [1., 1.]])
- # Gradient after zero_grad()
- The grad of weight after optimizer.zero_grad():
- tensor([[0., 0.],
- [0., 0.]])
- # Parameter groups
- optimizer.params_group is
- [{'params': [tensor([[-0.4077, -0.2808],
- [-0.8462, -1.6556]], requires_grad=True)], 'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
-
- # Both ids are equal: the optimizer stores a reference to the same tensor object as weight
- weight in optimizer:1841923407424
- weight in weight:1841923407424
-
- # Parameter groups after adding weight2 (options it omits are filled in from defaults)
- optimizer.param_groups is
- [{'params': [tensor([[-0.4077, -0.2808],
- [-0.8462, -1.6556]], requires_grad=True)], 'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[ 0.4539, -2.1901, -0.6662],
- [ 0.6630, -1.5178, -0.8708],
- [-2.0222, 1.4573, 0.8657]], requires_grad=True)], 'lr': 0.0001, 'nesterov': True, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0}]
-
- # state_dict before the 50 steps (parameters are referenced by index)
- state_dict before step:
- {'state': {0: {'momentum_buffer': tensor([[1., 1.],
- [1., 1.]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}, {'lr': 0.0001, 'nesterov': True, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'params': [1]}]}
- # state_dict after the 50 steps
- state_dict after step:
- {'state': {0: {'momentum_buffer': tensor([[0.0052, 0.0052],
- [0.0052, 0.0052]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}, {'lr': 0.0001, 'nesterov': True, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'params': [1]}]}
-
- # State saved
- ----------done-----------
- # State loaded successfully
- load state_dict successfully
- # Loaded state dictionary
- {'state': {0: {'momentum_buffer': tensor([[0.0052, 0.0052],
- [0.0052, 0.0052]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0]}, {'lr': 0.0001, 'nesterov': True, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'params': [1]}]}
-
- # optimizer.defaults
- {'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}
-
- # optimizer.state
- defaultdict(<class 'dict'>, {tensor([[-1.3031, -1.1761],
- [-1.7415, -2.5510]], requires_grad=True): {'momentum_buffer': tensor([[0.0052, 0.0052],
- [0.0052, 0.0052]])}})
-
- # optimizer.param_groups
- [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [tensor([[-1.3031, -1.1761],
- [-1.7415, -2.5510]], requires_grad=True)]}, {'lr': 0.0001, 'nesterov': True, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'params': [tensor([[ 0.4539, -2.1901, -0.6662],
- [ 0.6630, -1.5178, -0.8708],
- [-2.0222, 1.4573, 0.8657]], requires_grad=True)]}]
-
- </pre></div>
- </div>
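<p>The numbers above follow from SGD's momentum update. The buffer is updated as <code>buf = momentum * buf + grad</code>, and the parameter as <code>p = p - lr * buf</code>. Dampening, weight decay and Nesterov are all off for this group. The sketch below replays that arithmetic by hand, starting from the printed initial weights, which are rounded to 4 decimals. It assumes that <code>zero_grad()</code> filled the gradient with zeros, which matches the output shown. From PyTorch 2.0 on, <code>zero_grad()</code> defaults to <code>set_to_none=True</code>. With a recent version, the gradient therefore prints as <code>None</code>, and the 50 subsequent steps skip <code>weight</code> entirely.</p>

```python
import torch

lr, momentum = 0.1, 0.9
# Initial weights as printed above (rounded to 4 decimals)
w = torch.tensor([[-0.3077, -0.1808], [-0.7462, -1.5556]])

# First step: gradient of ones, so the buffer is initialised to the gradient
buf = torch.ones(2, 2)
w = w - lr * buf

# zero_grad() then 50 steps: the gradient is zero, so only the decaying buffer moves w
for _ in range(50):
    buf = momentum * buf + torch.zeros(2, 2)
    w = w - lr * buf

print(buf)  # every entry is 0.9 ** 50, about 0.0052
print(w)    # close to [[-1.3031, -1.1761], [-1.7415, -2.5510]]
```
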
- <p><strong>Notes:</strong></p>
- <ol class="simple">
- <li><p>Every optimizer is a class, so it has to be instantiated before it can be used, for example:</p></li>
- </ol>
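- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
- from torch import nn
-
- class Net(nn.Module):
-     ...  # define the layers and forward() here
-
- net = Net()
- optim = torch.optim.SGD(net.parameters(), lr=lr)  # lr: your chosen learning rate
- optim.step()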
- </pre></div>
- </div>
- <ol class="simple" start="2">
- <li><p>In each training iteration, the optimizer takes part in two steps:</p>
- <ol class="simple">
- <li><p>zeroing the gradients;</p></li>
- <li><p>updating the parameters with the gradients.</p></li>
- </ol>
- </li>
- </ol>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>
- <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">EPOCH</span><span class="p">):</span>
-     <span class="o">...</span>
-     <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># zero the gradients</span>
-     <span class="n">loss</span> <span class="o">=</span> <span class="o">...</span> <span class="c1"># compute the loss</span>
-     <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># backpropagate to compute the gradients</span>
-     <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># update the parameters</span>
- </pre></div>
- </div>
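<p>A self-contained version of this loop might look like the sketch below. It assumes a hypothetical toy regression task (<code>y = 3x</code> plus noise) and a single <code>nn.Linear</code> layer. The call to <code>zero_grad()</code> is needed because <code>backward()</code> accumulates into <code>.grad</code> instead of overwriting it.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy regression data: y = 3x + noise
x = torch.randn(64, 1)
y = 3 * x + 0.1 * torch.randn(64, 1)

net = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

for epoch in range(100):
    optimizer.zero_grad()      # 1. clear gradients left over from the last iteration
    loss = loss_fn(net(x), y)  # forward pass and loss
    loss.backward()            # backpropagate: fills .grad of each parameter
    optimizer.step()           # 2. update the parameters using .grad

print(net.weight.item())  # close to 3
```
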
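- <ol class="simple" start="3">
- <li><p>Different layers of a network can be given different optimizer hyperparameters by passing a list of parameter-group dicts. Any option that a group does not specify falls back to the keyword default passed to the optimizer.</p></li>
- </ol>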
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
- <span class="kn">from</span> <span class="nn">torchvision.models</span> <span class="kn">import</span> <span class="n">resnet18</span>
-
- <span class="n">net</span> <span class="o">=</span> <span class="n">resnet18</span><span class="p">()</span>
-
- <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">([</span>
- <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span><span class="n">net</span><span class="o">.</span><span class="n">fc</span><span class="o">.</span><span class="n">parameters</span><span class="p">()},</span><span class="c1"># fc uses the default lr=1e-5</span>
- <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span><span class="n">net</span><span class="o">.</span><span class="n">layer4</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">conv1</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span><span class="s1">'lr'</span><span class="p">:</span><span class="mf">1e-2</span><span class="p">}],</span><span class="n">lr</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span>
-
- <span class="c1"># Inspect the resulting per-group settings via optimizer.param_groups</span>
- </pre></div>
- </div>
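<p>The resnet18 example requires torchvision. The dependency-free sketch below shows the same idea, using a hypothetical two-layer <code>nn.Sequential</code> in place of resnet18, and prints how the per-group settings end up in <code>param_groups</code>.</p>

```python
from torch import nn, optim

# Hypothetical small network standing in for resnet18
net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

optimizer = optim.SGD([
    {"params": net[0].parameters(), "lr": 1e-2},  # first layer: its own lr
    {"params": net[2].parameters()},              # last layer: falls back to lr=1e-5
], lr=1e-5, momentum=0.9)

for i, group in enumerate(optimizer.param_groups):
    print(i, group["lr"], group["momentum"], len(group["params"]))
# 0 0.01 0.9 2
# 1 1e-05 0.9 2
```

<p>Parameters that are not listed in any group are not updated by this optimizer. Make sure every layer you want to train appears in exactly one group.</p>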
- </section>
- <section id="id4">
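- <h2>3.9.4 Experiment<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h2>
- <p>To give a better feel for how the optimizers behave, we ran a small test of the optimizers available in PyTorch.</p>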
- <p><strong>Data generation</strong>:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
-
- <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1000</span><span class="p">)</span>
- <span class="c1"># Add a feature dimension: shape (1000,) becomes (1000, 1)</span>
- <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">y</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">()))</span>
- </pre></div>
- </div>
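<p>The <code>unsqueeze</code> step is needed because <code>nn.Linear</code> expects inputs shaped (batch, features). The quick shape check below is a sketch; it also relies on the fact that <code>torch.normal</code> with a zero mean tensor and its default <code>std=1.0</code> draws standard normal noise, the same distribution as <code>torch.randn</code>, which it uses instead.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch

a = torch.linspace(-1, 1, 1000)
x = torch.unsqueeze(a, dim=1)
print(a.shape, x.shape)  # torch.Size([1000]) torch.Size([1000, 1])

# a.unsqueeze(1) and a.view(-1, 1) produce the same (1000, 1) tensor
same = torch.equal(x, a.unsqueeze(1)) and torch.equal(x, a.view(-1, 1))
print(same)  # True

# Standard normal noise scaled to std 0.1, equivalent in distribution
# to 0.1 * torch.normal(torch.zeros(x.size()))
noise = 0.1 * torch.randn(x.size())
y = x.pow(2) + noise
print(y.shape)  # torch.Size([1000, 1])
</pre></div>
</div>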
- <p><strong>Data distribution</strong>:</p>
- <p><img alt="" src="../_images/3.6.1.png" /></p>
- <p><strong>Network architecture</strong>:</p>
- <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
- <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
-
- <span class="k">class</span> <span class="nc">Net</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
-     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-         <span class="nb">super</span><span class="p">(</span><span class="n">Net</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
-         <span class="bp">self</span><span class="o">.</span><span class="n">hidden</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">)</span>   <span class="c1"># 1 input feature to 20 hidden units</span>
-         <span class="bp">self</span><span class="o">.</span><span class="n">predict</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>  <span class="c1"># 20 hidden units to 1 output</span>
-
-     <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
-         <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
-         <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
-         <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
-         <span class="k">return</span> <span class="n">x</span>
-
- </pre></div>
- </div>
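<p>A quick sanity check of this architecture, sketched here with the class repeated so the block runs on its own: the hidden layer holds 1×20 weights plus 20 biases and the output layer 20×1 weights plus 1 bias, 61 trainable parameters in total, and a (1000, 1) input maps to a (1000, 1) prediction.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 20)   # 1 input feature to 20 hidden units
        self.predict = nn.Linear(20, 1)  # 20 hidden units to 1 output

    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)


net = Net()
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
out = net(x)
n_params = sum(p.numel() for p in net.parameters())
print(out.shape, n_params)  # torch.Size([1000, 1]) 61
</pre></div>
</div>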
- <p>The test results are plotted below; the vertical axis is the loss and the horizontal axis is the training step:</p>
- <p><img alt="" src="../_images/3.6.2.png" /></p>
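- <p>In the figure above, how steeply each curve falls, and after how many steps it levels off, shows how quickly each optimizer converges for this particular data and model.</p>
- <p><strong>Note:</strong></p>
- <p>The right optimizer depends on the model; none is universally better than the others, so it is worth running several experiments before settling on one.</p>
- <p>Visualizations for the SparseAdam and LBFGS optimizers will be added later.</p>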
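<p>The training script behind the figure is not shown, so the following is a hypothetical reconstruction, not the authors' code. It repeats the data and <code>Net</code> so it runs on its own, trains four copies of the network (starting from identical weights) with SGD, SGD with momentum, RMSprop and Adam on the same mini-batches, and records the MSE loss at every step. The learning rate, batch size, epoch count and the <code>momentum</code>, <code>alpha</code> and <code>betas</code> values are all assumptions.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Same data as above: y = x^2 plus Gaussian noise with std 0.1
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size()))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        return self.predict(F.relu(self.hidden(x)))


# Assumed hyperparameters, not taken from the original experiment
LR, BATCH_SIZE, EPOCHS = 0.01, 32, 12
loader = DataLoader(TensorDataset(x, y), batch_size=BATCH_SIZE, shuffle=True)

# Four copies of one network, so every optimizer starts from the same weights
base = Net()
nets = {name: copy.deepcopy(base) for name in ['SGD', 'Momentum', 'RMSprop', 'Adam']}
optimizers = {
    'SGD': torch.optim.SGD(nets['SGD'].parameters(), lr=LR),
    'Momentum': torch.optim.SGD(nets['Momentum'].parameters(), lr=LR, momentum=0.8),
    'RMSprop': torch.optim.RMSprop(nets['RMSprop'].parameters(), lr=LR, alpha=0.9),
    'Adam': torch.optim.Adam(nets['Adam'].parameters(), lr=LR, betas=(0.9, 0.99)),
}
loss_func = nn.MSELoss()
history = {name: [] for name in nets}  # one loss value per step for each optimizer

for epoch in range(EPOCHS):
    for batch_x, batch_y in loader:
        # Every network sees the same mini-batch at each step
        for name, net in nets.items():
            opt = optimizers[name]
            loss = loss_func(net(batch_x), batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            history[name].append(loss.item())

for name, losses in history.items():
    print(f'{name}: first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}')
</pre></div>
</div>
<p>Plotting each <code>history[name]</code> against the step index gives curves of the kind shown above. Sharing the initial weights and the batch order keeps the comparison between optimizers fair.</p>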
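<p>One practical difference already matters for those two: LBFGS may re-evaluate the loss several times within a single step, so its <code>step()</code> must be given a closure that recomputes the loss and gradients, and it is usually run on the full batch. SparseAdam is designed for sparse gradients (for example from <code>nn.Embedding(sparse=True)</code>) and rejects dense ones, so it does not plug directly into this dense <code>Net</code>. The LBFGS sketch below reuses the same data and network; <code>max_iter=20</code> and the strong-Wolfe line search are assumed settings.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(x.size()))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        return self.predict(F.relu(self.hidden(x)))


net = Net()
loss_func = nn.MSELoss()
# Assumed settings: up to 20 inner iterations per step, strong-Wolfe line search
optimizer = torch.optim.LBFGS(net.parameters(), lr=1, max_iter=20,
                              line_search_fn='strong_wolfe')


def closure():
    # Recompute loss and gradients on the full batch; LBFGS may call this repeatedly
    optimizer.zero_grad()
    loss = loss_func(net(x), y)
    loss.backward()
    return loss


losses = []
for step in range(10):
    loss = optimizer.step(closure)  # loss from the first closure call in this step
    losses.append(loss.item())
print(losses[0], losses[-1])
</pre></div>
</div>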
- </section>
- </section>
-
-
- </div>
-
- </main>
- <footer class="footer-article noprint">
-
- <!-- Previous / next buttons -->
- <div class='prev-next-area'>
- <a class='left-prev' id="prev-link" href="3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html" title="Previous page">
- <i class="fas fa-angle-left"></i>
- <div class="prev-next-info">
- <p class="prev-next-subtitle">Previous</p>
- <p class="prev-next-title">3.8 Visualization</p>
- </div>
- </a>
- <a class='right-next' id="next-link" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html" title="Next page">
- <div class="prev-next-info">
- <p class="prev-next-subtitle">Next</p>
- <p class="prev-next-title">Chapter 4: PyTorch Basics in Practice</p>
- </div>
- <i class="fas fa-angle-right"></i>
- </a>
- </div>
- </footer>
- </div>
- </div>
- <div class="footer-content row">
- <footer class="col footer"><p>
-
- By ZhikangNiu<br/>
-
- © Copyright 2022, ZhikangNiu.<br/>
- </p>
- </footer>
- </div>
-
- </div>
-
-
- </div>
- </div>
-
- <!-- Scripts loaded after <body> so the DOM is not blocked -->
- <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>
-
-
- </body>
- </html>
|