Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
xiahb 7bdfbe0754 | 1 year ago | |
---|---|---|
utils | 1 year ago | |
README.md | 1 year ago | |
test.py | 1 year ago | |
train.py | 1 year ago |
Pytorch实现:使用ResNet18网络训练Cifar10数据集,测试集准确率达到95.46%(从0开始,不使用预训练模型)
Pytorch
实现:使用ResNet18
网络训练Cifar10
数据集,测试集准确率达到95.46%(从0开始,不使用预训练模型)作者:ZOMIN
:ZOMIN28 (github.com)
本文将介绍如何使用数据增强和模型修改的方式,在不使用任何预训练模型参数的情况下,在ResNet18
网络上对Cifar10
数据集进行分类任务。在测试集上,我们的模型准确率可以达到95.46%。在Kaggle
的Cifar10
比赛上,我训练的模型在300,000的超大Cifar10
数据集上依然可以达到95.46%的准确率。
Cifar10
数据集Cifar10
数据集由10个类的60000个尺寸为32x32
的RGB
彩色图像组成,每个类有6000个图像, 有50000个训练图像和10000个测试图像。
在使用Pytorch
时,我们可以直接使用torchvision.datasets.CIFAR10()
方法获取该数据集。
为了提高模型的泛化性,防止训练时在训练集上过拟合,往往在训练的过程中会对训练集进行数据增强操作,例如随机翻转、遮挡、填充后裁剪等操作。我们这里对训练集做如下三种处理:
代码如下:
transforms.RandomHorizontalFlip()
我们可以将尺寸为32x32
的图像填充为40x40
,然后随机裁剪成32x32
。
transforms.RandomCrop(32, padding=4)
Cutout操作会随机遮挡图片的若干尺寸的若干块,尺寸和块可以根据自己的需要设置。
调用代码如下,这里我们设置块为1,尺寸长度为16个像素。cutout的完整操作将在后面给出。Github链接:https://github.com/uoguelph-mlrg/Cutout
Cutout(n_holes=1, length=16)
ResNet18
模型考虑到CIFAR10
数据集的图片尺寸太小,ResNet18
网络的7x7
降采样卷积和池化操作容易丢失一部分信息,所以在实验中我们将7x7
的降采样层和最大池化层去掉,替换为一个3x3
的降采样卷积,同时减小该卷积层的步长和填充大小,这样可以尽可能保留原始图像的信息。
修改卷积层如下:
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
删去最大池化层:
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
#x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
在模型的训练上,我们采用的策略是:设置初始学习率为0.1,每当经过10个epoch训练的验证集损失没有下降时,学习率变为原来的0.5,共训练250个epoch。在训练中,我们的batch_size大小为128,优化器为SGD
:
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
完整的代码已上传
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》