mindcv
套件可以通过python的argparse
库和pyyaml
库解析模型的yaml文件来进行参数的配置,下面我们以squeezenet_1.0模型为例,解释如何配置相应的参数。
mode:使用静态图模式(0)或动态图模式(1)。
distribute:是否使用分布式。
mode: 0
distribute: True
...
python train.py --mode 0 --distribute False ...
args.model
代表参数mode
,args.distribute
代表参数distribute
。
def train(args):
ms.set_context(mode=args.mode)
if args.distribute:
init()
device_num = get_group_size()
rank_id = get_rank()
ms.set_auto_parallel_context(device_num=device_num,
parallel_mode='data_parallel',
gradients_mean=True)
else:
device_num = None
rank_id = None
...
dataset:数据集名称。
data_dir:数据集文件所在路径。
shuffle:是否进行数据混洗。
dataset_download:是否下载数据集。
batch_size:每个批处理数据包含的数据条目。
drop_remainder:当最后一个批处理数据包含的数据条目小于 batch_size 时,是否将该批处理丢弃。
num_parallel_workers:读取数据的工作线程数。
dataset: 'imagenet'
data_dir: './imagenet2012'
shuffle: True
dataset_download: False
batch_size: 32
drop_remainder: True
num_parallel_workers: 8
...
python train.py ... --dataset imagenet --data_dir ./imagenet2012 --shuffle True \
--dataset_download False --batch_size 32 --drop_remainder True \
--num_parallel_workers 8 ...
def train(args):
...
dataset_train = create_dataset(
name=args.dataset,
root=args.data_dir,
split='train',
shuffle=args.shuffle,
num_samples=args.num_samples,
num_shards=device_num,
shard_id=rank_id,
num_parallel_workers=args.num_parallel_workers,
download=args.dataset_download,
num_aug_repeats=args.aug_repeats)
...
target_transform = transforms.OneHot(num_classes) if args.loss == 'BCE' else None
loader_train = create_loader(
dataset=dataset_train,
batch_size=args.batch_size,
drop_remainder=args.drop_remainder,
is_training=True,
mixup=args.mixup,
cutmix=args.cutmix,
cutmix_prob=args.cutmix_prob,
num_classes=args.num_classes,
transform=transform_list,
target_transform=target_transform,
num_parallel_workers=args.num_parallel_workers,
)
...
image_resize:图像的输出尺寸大小。
scale:要裁剪的原始尺寸大小的各个尺寸的范围。
ratio:裁剪宽高比的范围。
hfilp:图像被翻转的概率。
interpolation:图像插值方式。
crop_pct:输入图像中心裁剪百分比。
color_jitter:颜色抖动因子(亮度调整因子,对比度调整因子,饱和度调整因子)。
re_prob:执行随机擦除的概率。
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
interpolation: 'bilinear'
crop_pct: 0.875
color_jitter: [0.4, 0.4, 0.4]
re_prob: 0.5
...
python train.py ... --image_resize 224 --scale [0.08, 1.0] --ratio [0.75, 1.333] \
--hflip 0.5 --interpolation "bilinear" --crop_pct 0.875 \
--color_jitter [0.4, 0.4, 0.4] --re_prob 0.5 ...
def train(args):
...
transform_list = create_transforms(
dataset_name=args.dataset,
is_training=True,
image_resize=args.image_resize,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
interpolation=args.interpolation,
auto_augment=args.auto_augment,
mean=args.mean,
std=args.std,
re_prob=args.re_prob,
re_scale=args.re_scale,
re_ratio=args.re_ratio,
re_value=args.re_value,
re_max_attempts=args.re_max_attempts
)
...
model:模型名称。
num_classes:分类的类别数。
pretrained:是否加载预训练模型。
ckpt_path:参数文件所在的路径。
keep_checkpoint_max:最多保存多少个checkpoint文件。
ckpt_save_dir:保存参数文件的路径。
epoch_size:训练执行轮次。
dataset_sink_mode:数据是否直接下沉至处理器进行处理。
amp_level:混合精度等级。
model: 'squeezenet1_0'
num_classes: 1000
pretrained: False
ckpt_path: './squeezenet1_0_gpu.ckpt'
keep_checkpoint_max: 10
ckpt_save_dir: './ckpt/'
epoch_size: 200
dataset_sink_mode: True
amp_level: 'O0'
...
python train.py ... --model squeezenet1_0 --num_classes 1000 --pretrained False \
--ckpt_path ./squeezenet1_0_gpu.ckpt --keep_checkpoint_max 10 \
--ckpt_save_path ./ckpt/ --epoch_size 200 --dataset_sink_mode True \
--amp_level O0 ...
def train(args):
...
network = create_model(model_name=args.model,
num_classes=args.num_classes,
in_channels=args.in_channels,
drop_rate=args.drop_rate,
drop_path_rate=args.drop_path_rate,
pretrained=args.pretrained,
checkpoint_path=args.ckpt_path,
ema=args.ema)
...
loss:损失函数的简称。
label_smoothing:标签平滑值,用于计算Loss时防止模型过拟合的正则化手段。
loss: 'CE'
label_smoothing: 0.1
...
python train.py ... --loss CE --label_smoothing 0.1 ...
def train(args):
...
loss = create_loss(name=args.loss,
reduction=args.reduction,
label_smoothing=args.label_smoothing,
aux_factor=args.aux_factor)
...
scheduler:学习率策略的名称。
min_lr:学习率的最小值。
lr:学习率的最大值。
warmup_epochs:学习率warmup的轮次。
decay_epochs:进行衰减的step数。
scheduler: 'cosine_decay'
min_lr: 0.0
lr: 0.01
warmup_epochs: 0
decay_epochs: 200
...
python train.py ... --scheduler cosine_decay --min_lr 0.0 --lr 0.01 \
--warmup_epochs 0 --decay_epochs 200 ...
def train(args):
...
lr_scheduler = create_scheduler(num_batches,
scheduler=args.scheduler,
lr=args.lr,
min_lr=args.min_lr,
warmup_epochs=args.warmup_epochs,
warmup_factor=args.warmup_factor,
decay_epochs=args.decay_epochs,
decay_rate=args.decay_rate,
milestones=args.multi_step_decay_milestones,
num_epochs=args.epoch_size,
lr_epoch_stair=args.lr_epoch_stair)
...
opt:优化器名称。
filter_bias_and_bn:参数中是否包含bias,gamma或者beta。
momentum:移动平均的动量。
weight_decay:权重衰减(L2 penalty)。
loss_scale:梯度缩放系数
use_nesterov:是否使用Nesterov Accelerated Gradient (NAG)算法更新梯度。
opt: 'momentum'
filter_bias_and_bn: True
momentum: 0.9
weight_decay: 0.00007
loss_scale: 1024
use_nesterov: False
...
python train.py ... --opt momentum --filter_bias_and_bn True --weight_decay 0.00007 \
--loss_scale 1024 --use_nesterov False ...
def train(args):
...
if args.ema:
optimizer = create_optimizer(network.trainable_params(),
opt=args.opt,
lr=lr_scheduler,
weight_decay=args.weight_decay,
momentum=args.momentum,
nesterov=args.use_nesterov,
filter_bias_and_bn=args.filter_bias_and_bn,
loss_scale=args.loss_scale,
checkpoint_path=opt_ckpt_path,
eps=args.eps)
else:
optimizer = create_optimizer(network.trainable_params(),
opt=args.opt,
lr=lr_scheduler,
weight_decay=args.weight_decay,
momentum=args.momentum,
nesterov=args.use_nesterov,
filter_bias_and_bn=args.filter_bias_and_bn,
checkpoint_path=opt_ckpt_path,
eps=args.eps)
...
使用parse设置参数可以覆盖yaml文件中的参数设置。以下面的shell命令为例,
python train.py -c ./configs/squeezenet/squeezenet_1.0_gpu.yaml --data_dir ./data
上面的命令将args.data_dir
参数的值由yaml文件中的 ./imagenet2012 覆盖为 ./data。
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》