English | 简体中文
最新版本 v0.10.3 在 2024.1.24 发布。
亮点:
支持安装不依赖于 opencv 的 mmengine-lite 版本。可阅读安装文档了解用法。
支持使用 ColossalAI 进行训练。可阅读大模型训练了解用法。
支持梯度检查点。详见用法。
支持多种可视化后端,包括NeptuneVisBackend
、DVCLiveVisBackend
和 AimVisBackend
。可阅读可视化后端了解用法。
如果想了解更多版本更新细节和历史信息,请阅读更新日志。
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它为开发人员提供了坚实的工程基础,以此避免在工作流上编写冗余代码。作为 OpenMMLab 所有代码库的训练引擎,其在不同研究领域支持了上百个算法。此外,MMEngine 也可以用于非 OpenMMLab 项目中。
主要特性:
通用且强大的执行器:
接口统一的开放架构:
可定制的训练流程:
在安装 MMEngine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 PyTorch 官方安装文档。
安装 MMEngine
pip install -U openmim
mim install mmengine
验证是否安装成功
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
更多安装方式请阅读安装文档。
以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、可配置的训练和验证流程。
首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel
,并且其 forward
方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode
。
mode
接受字符串 "loss",并返回一个包含 "loss" 字段的字典。mode
接受字符串 "predict",并返回同时包含预测信息和真实信息的结果。import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels
其次,我们需要构建训练和验证所需要的数据集(Dataset)和数据加载器(DataLoader)。在该示例中,我们使用 TorchVision 支持的方式构建数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric
,并实现 process
和 compute_metrics
方法。
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
# 将一个批次的中间结果保存至 `self.results`
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
# 返回保存有评测指标结果的字典,其中键为指标名称
return dict(accuracy=100 * total_correct / total_size)
最后,我们利用构建好的模型
,数据加载器
,评测指标
构建一个执行器(Runner),并伴随其他的配置信息,如下所示。
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
# 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
# 训练配置,例如 epoch 等
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
)
runner.train()
我们感谢所有的贡献者为改进和提升 MMEngine 所作出的努力。请参考贡献指南来了解参与项目贡献的相关指引。
如果您觉得 MMEngine 对您的研究有所帮助,请考虑引用它:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
该项目采用 Apache 2.0 license 开源许可证。
扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,扫描下方微信二维码添加喵喵好友,进入 MMEngine 微信交流社群。【加好友申请格式:研究方向+地区+学校/公司+姓名】
我们会在 OpenMMLab 社区为大家
干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》