@@ -14,6 +14,7 @@
# ============================================================================
"""Functions of optimizer"""
import os
import numpy as np
from mindspore.nn.optim import AdamWeightDecay, Adam
from mindspore.nn.optim.momentum import Momentum
@@ -33,6 +34,17 @@ def get_optimizer(args, model, batch_num):
optim_type = args.optimizer.lower()
params = get_param_groups(model)
learning_rate = get_learning_rate(args, batch_num)
    # Extend the schedule with 44 extra epochs at a constant min_lr tail;
    # the global-batch scaling below is applied once, to the full schedule.
    additional_epochs = 44
    additional_list = [args.min_lr] * (additional_epochs * batch_num)
    learning_rate = np.append(learning_rate, additional_list)
    args.epochs += additional_epochs
step = int(args.start_epoch * batch_num)
accumulation_step = int(args.accumulation_step)
learning_rate = learning_rate[step::accumulation_step]
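    # A minimal sketch (toy values, not from this repo) of what the slice
    # above produces, assuming start_epoch * batch_num = 2 and
    # accumulation_step = 2 on an 8-step schedule:
    #
    #     lr = np.array([1., 2., 3., 4., 5., 6., 7., 8.])
    #     lr[2::2]   # -> array([3., 5., 7.])
    #
    # Already-consumed steps are dropped first, then every
    # accumulation_step-th value is kept, so one LR remains per effective
    # (accumulated) optimizer update.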
@@ -41,7 +53,7 @@ def get_optimizer(args, model, batch_num):
f"=> Start step: {step}\n"
f"=> Total step: {train_step}\n"
f"=> Accumulation step:{accumulation_step}")
learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.
if accumulation_step > 1:
learning_rate = learning_rate * accumulation_step
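    # A minimal worked example of the scaling above (the linear scaling
    # rule of Goyal et al.), assuming hypothetical values batch_size = 32
    # and DEVICE_NUM = 8:
    #
    #     32 * 8 / 512 = 0.5   # every LR in the schedule is halved
    #
    # With accumulation_step = 2, the schedule is then doubled back,
    # since each optimizer update now aggregates two micro-batches and
    # the effective global batch doubles.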
@@ -77,20 +89,20 @@ def get_param_groups(network):
decay_params = []
no_decay_params = []
    for x in network.trainable_params():
        # Weight decay applies to conv weights (4-D tensors) and to the
        # classifier head (weight and bias); everything else is exempt.
        if len(x.shape) == 4 or x.name == 'classifier.0.weight' or x.name == 'classifier.0.bias':
            decay_params.append(x)
        else:
            no_decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
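# A minimal usage sketch, assuming `model` is a mindspore.nn.Cell and that
# `args.weight_decay` exists (neither is shown in this patch); `Tensor` and
# `mstype` would come from mindspore / mindspore.common.dtype:
#
#     groups = get_param_groups(model)
#     optimizer = AdamWeightDecay(groups,
#                                 learning_rate=Tensor(learning_rate, mstype.float32),
#                                 weight_decay=args.weight_decay)
#
# MindSpore optimizers accept a list of parameter-group dicts; a group
# without its own 'weight_decay' key falls back to the optimizer-level
# value, which is why only the no-decay group pins it to 0.0.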