English | 简体中文
最新版本 v0.8.4 在 2023.08.03 发布。
亮点:
支持使用 efficient_conv_bn_eval
参数开启更高效的 ConvBN
推理模式。详见节省显存文档
新增微调 Llama2 的示例。
引入纯 Python 风格的配置文件:
请参考教程以获取更详细的用法说明。
如果想了解更多版本更新细节和历史信息,请阅读更新日志
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它为开发人员提供了坚实的工程基础,以此避免在工作流上编写冗余代码。作为 OpenMMLab 所有代码库的训练引擎,其在不同研究领域支持了上百个算法。此外,MMEngine 也可以用于非 OpenMMLab 项目中。
主要特性:
通用且强大的执行器:
接口统一的开放架构:
可定制的训练流程:
在安装 MMEngine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 PyTorch 官方安装文档。
安装 MMEngine
pip install -U openmim
mim install mmengine
验证是否安装成功
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
更多安装方式请阅读安装文档。
以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、可配置的训练和验证流程。
首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel
,并且其 forward
方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode
。
mode
接受字符串 "loss",并返回一个包含 "loss" 字段的字典。mode
接受字符串 "predict",并返回同时包含预测信息和真实信息的结果。import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels
其次,我们需要构建训练和验证所需要的数据集(Dataset)和数据加载器(DataLoader)。在该示例中,我们使用 TorchVision 支持的方式构建数据集。
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
val_dataloader = DataLoader(batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
'data/cifar10',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(**norm_cfg)
])))
为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric
,并实现 process
和 compute_metrics
方法。
from mmengine.evaluator import BaseMetric
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
# 将一个批次的中间结果保存至 `self.results`
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
# 返回保存有评测指标结果的字典,其中键为指标名称
return dict(accuracy=100 * total_correct / total_size)
最后,我们利用构建好的模型
,数据加载器
,评测指标
构建一个执行器(Runner),并伴随其他的配置信息,如下所示。
from torch.optim import SGD
from mmengine.runner import Runner
runner = Runner(
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
# 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
# 训练配置,例如 epoch 等
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
)
runner.train()
我们感谢所有的贡献者为改进和提升 MMEngine 所作出的努力。请参考贡献指南来了解参与项目贡献的相关指引。
如果您觉得 MMEngine 对您的研究有所帮助,请考虑引用它:
@article{mmengine2022,
title = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
author = {MMEngine Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmengine}},
year={2022}
}
该项目采用 Apache 2.0 license 开源许可证。
扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,加入 OpenMMLab 团队的 官方交流 QQ 群,或通过添加微信“Open小喵Lab”加入官方交流微信群。
我们会在 OpenMMLab 社区为大家
干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》