#3 master (Open)

JKsaigo wants to merge 41 commits from Wan_/Bi_Real_Net:master into master.
Files changed (5):

  1. src/args.py                         +1   -1
  2. src/configs/birealnet34.yaml        +3   -3
  3. src/models/birealnet/birealnet.py   +6   -6
  4. src/tools/cell.py                   +1   -1
  5. src/tools/optimizer.py              +24  -12

src/args.py  (+1, -1)

@@ -75,7 +75,7 @@ def parse_arguments():
     parser.add_argument("--num_classes", default=1000, type=int)
     parser.add_argument("--pretrained", dest="pretrained", default=None, type=str, help="use pre-trained model")
     parser.add_argument("--config_file", help="Config file to use (see configs dir)", default=None, required=False)
-    parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
+    parser.add_argument("--seed", default=42, type=int, help="seed for initializing training. ")
     parser.add_argument("--save_every", default=10, type=int, help="Save every ___ epochs(default:10)")
     parser.add_argument("--label_smoothing", type=float, help="Label smoothing to use, default 0.0", default=0.1)
     parser.add_argument("--image_size", default=224, help="Image Size.", type=int)


src/configs/birealnet34.yaml  (+3, -3)

@@ -15,9 +15,9 @@ cutmix: 0.

 # ===== Learning Rate Policy ======== #
 optimizer: adam
-base_lr: 0.002
+base_lr: 0.001
 warmup_lr: 0.000006
-min_lr: 0.
+min_lr: 0.000001
 lr_scheduler: lambda_lr
 warmup_length: 0

@@ -36,7 +36,7 @@ momentum: 0.9
 batch_size: 128

 # ===== Hardware setup ===== #
-num_parallel_workers: 16
+num_parallel_workers: 32
 device_target: Ascend

 # ===== Model config ===== #
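These are pure hyper-parameter edits: base_lr is halved from 0.002 to 0.001, min_lr is raised from 0 to 1e-6 so the tail of the lambda_lr schedule never decays to exactly zero, and num_parallel_workers is doubled to 32 for the data pipeline. As a rough illustration of how a config file like this is usually folded over the argparse defaults from src/args.py (the project's real loader is not shown in this diff, so treat the helper as an assumption):

# Illustrative only: merge YAML overrides on top of parsed CLI arguments.
import argparse

import yaml


def apply_config(args: argparse.Namespace, config_file: str) -> argparse.Namespace:
    with open(config_file, "r") as f:
        overrides = yaml.safe_load(f)   # e.g. {"base_lr": 0.001, "min_lr": 1e-06, ...}
    for key, value in overrides.items():
        setattr(args, key, value)       # config value replaces the default
    return args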

src/models/birealnet/birealnet.py  (+6, -6)

@@ -5,10 +5,10 @@ from mindspore.common import dtype as mstype
 __all__ = ['birealnet18', 'birealnet34']


-if os.getenv("DEVICE_TARGET") == "Ascend" and int(os.getenv("DEVICE_NUM")) > 1:
-    BatchNorm2d = nn.SyncBatchNorm
-else:
-    BatchNorm2d = nn.BatchNorm2d
+# if os.getenv("DEVICE_TARGET") == "Ascend" and int(os.getenv("DEVICE_NUM")) > 1:
+#     BatchNorm2d = nn.SyncBatchNorm
+# else:
+BatchNorm2d = nn.BatchNorm2d


 class AdaptiveAvgPool2d(nn.Cell):
@@ -18,7 +18,7 @@ class AdaptiveAvgPool2d(nn.Cell):
         self.mean = ops.ReduceMean(True)

     def construct(self, x):
-        x = self.mean(x, (2, 3))
+        x = self.mean(x[:, :, 0:7, 0:7 ], (-2,-1))
         return x

@@ -63,7 +63,7 @@ class HardBinaryConv(nn.Cell):
         self.padding = padding
         self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
         self.shape = (out_chn, in_chn, kernel_size, kernel_size)
-        self.weights = Parameter(ops.UniformReal()((self.number_of_weights,1)), requires_grad=True)
+        self.weights = Parameter(ops.UniformReal()((self.number_of_weights,1)) * 0.001, requires_grad=True)
         self.conv2d = ops.Conv2D(out_channel=out_chn, kernel_size=3, stride=self.stride, pad=self.padding, pad_mode="pad")
         self.mean = ops.ReduceMean(keep_dims=True)
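Three edits in this file: the SyncBatchNorm switch is commented out so the model always uses plain nn.BatchNorm2d, the adaptive average pool now slices an explicit 7x7 window before reducing over the last two axes, and the latent weights of HardBinaryConv are initialised 1000x smaller (uniform values scaled by 0.001). For the standard 224x224 ImageNet input, the feature map reaching this pooling layer in a ResNet-style backbone is already 7x7, so the sliced mean is numerically identical to the old global mean over axes (2, 3); the NumPy check below illustrates that equivalence under that shape assumption (it is not code from the repo).

# Sanity check, assuming the pooled feature map is N x C x 7 x 7
# (true for a 224x224 input with the usual ResNet stage strides).
import numpy as np

x = np.random.randn(2, 512, 7, 7).astype(np.float32)

global_mean = x.mean(axis=(2, 3), keepdims=True)                      # old behaviour
cropped_mean = x[:, :, 0:7, 0:7].mean(axis=(-2, -1), keepdims=True)   # new behaviour

assert np.allclose(global_mean, cropped_mean)   # identical when H == W == 7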




src/tools/cell.py  (+1, -1)

@@ -33,7 +33,7 @@ def cast_amp(net):
         print(f"=> using amp_level {args.amp_level}\n"
               f"=> change {args.arch} to fp16")
         net.to_float(mstype.float16)
-        cell_types = (nn.GELU, nn.Softmax, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d, nn.LayerNorm, nn.SyncBatchNorm)
+        cell_types = (nn.GELU, nn.Softmax, nn.BatchNorm2d, nn.LayerNorm, nn.SyncBatchNorm)
         # cell_types = (nn.GELU, nn.Softmax, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d, nn.LayerNorm, nn.ReLU, nn.Dense)
         print(f"=> cast {cell_types} to fp32 back")
         do_keep_fp32(net, cell_types)
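With nn.Conv2d and nn.Conv1d removed from cell_types, convolutions now stay in fp16 after net.to_float(mstype.float16); only GELU, Softmax, the normalisation layers and SyncBatchNorm are cast back to fp32. The body of do_keep_fp32 is not part of this PR; a minimal sketch of what such a helper usually does in MindSpore (walk the cell tree and re-cast matching cells) is shown below, under the assumption that the real implementation is equivalent.

# Sketch under assumptions; the actual do_keep_fp32 in src/tools/cell.py may differ.
import mindspore.nn as nn
from mindspore.common import dtype as mstype


def keep_fp32(net: nn.Cell, cell_types: tuple) -> None:
    """Cast every sub-cell of the given types back to float32."""
    for _, cell in net.cells_and_names():   # walk all nested cells
        if isinstance(cell, cell_types):
            cell.to_float(mstype.float32)   # overrides the earlier fp16 cast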


src/tools/optimizer.py  (+24, -12)

@@ -14,6 +14,7 @@
 # ============================================================================
 """Functions of optimizer"""
 import os
+import numpy as np

 from mindspore.nn.optim import AdamWeightDecay, Adam
 from mindspore.nn.optim.momentum import Momentum
@@ -33,6 +34,17 @@ def get_optimizer(args, model, batch_num):
     optim_type = args.optimizer.lower()
     params = get_param_groups(model)
     learning_rate = get_learning_rate(args, batch_num)
+    learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.
+
+    additional_list = []
+    # additional_lr = 0.0000078125
+    for additional_epoch in range(44):
+        for additional_step in range(batch_num):
+            additional_list.append(args.min_lr)
+            # additional_list.append(additional_lr * (1.0 - additional_epoch / 44))
+    learning_rate = np.append(learning_rate, additional_list)  # add 44 epochs
+    args.epochs += 44
+
     step = int(args.start_epoch * batch_num)
     accumulation_step = int(args.accumulation_step)
     learning_rate = learning_rate[step::accumulation_step]
@@ -41,7 +53,7 @@ def get_optimizer(args, model, batch_num):
           f"=> Start step: {step}\n"
          f"=> Total step: {train_step}\n"
           f"=> Accumulation step:{accumulation_step}")
-    learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.
     if accumulation_step > 1:
         learning_rate = learning_rate * accumulation_step
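Two schedule changes: the linear-scaling step (lr * batch_size * DEVICE_NUM / 512) is moved up so it runs before the schedule is extended, and 44 extra epochs at a constant min_lr are appended to the per-step learning-rate array, with args.epochs bumped to match. Note that the appended tail uses args.min_lr directly, so it is not rescaled. With the new base_lr of 0.001 and, say, 8 devices at batch size 128, the scaled peak works out to 0.001 * 128 * 8 / 512 = 0.002, i.e. the old base_lr. The NumPy fragment below reproduces the arithmetic on a toy schedule; the device count and steps-per-epoch are assumptions for illustration, not values from the repo.

# Toy reproduction of the schedule manipulation above (all values illustrative).
import numpy as np

batch_num = 100                                             # assumed steps per epoch
base_schedule = np.linspace(0.001, 1e-6, 10 * batch_num)    # stand-in for get_learning_rate()

# linear scaling rule, applied before the extension
batch_size, device_num = 128, 8                             # assumed 8 devices
schedule = base_schedule * batch_size * device_num / 512.0

# append 44 epochs of constant, unscaled min_lr
min_lr = 1e-6
tail = np.full(44 * batch_num, min_lr)
schedule = np.append(schedule, tail)

print(schedule.shape)                                       # (5400,) = 100 * (10 + 44)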


@@ -77,20 +89,20 @@ def get_param_groups(network):
     decay_params = []
     no_decay_params = []

+    for x in network.trainable_params():
+        if len(x.shape) == 4 or x.name=='classifier.0.weight' or x.name == 'classifier.0.bias':
+            decay_params.append(x)
+        else:
+            no_decay_params.append(x)
+
     # for x in network.trainable_params():
-    #     if len(x.shape) == 4 or x.name=='classifier.0.weight' or x.name == 'classifier.0.bias':
+    #     parameter_name = x.name
+    #     if parameter_name.endswith(".weight"):
+    #         # Dense or Conv's weight using weight decay
     #         decay_params.append(x)
     #     else:
+    #         # all bias not using weight decay
+    #         # bn weight bias not using weight decay, be carefully for now x not include LN
     #         no_decay_params.append(x)

-    for x in network.trainable_params():
-        parameter_name = x.name
-        if parameter_name.endswith(".weight"):
-            # Dense or Conv's weight using weight decay
-            decay_params.append(x)
-        else:
-            # all bias not using weight decay
-            # bn weight bias not using weight decay, be carefully for now x not include LN
-            no_decay_params.append(x)
-
     return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
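The active grouping rule changes from "every parameter whose name ends in .weight gets weight decay" to "only 4-D parameters (convolution kernels) plus the classifier weight and bias get weight decay"; everything else, including BatchNorm gammas/betas and the flat (N, 1) latent weights of HardBinaryConv, lands in the no-decay group. A small self-contained illustration of the new predicate follows; the parameter names and shapes are made up for the example.

# Standalone illustration of the new decay/no-decay split (hypothetical params).
example_params = [
    ("layer1.0.binary_conv.weights", (4608, 1)),         # flat latent weights -> no decay
    ("layer1.0.bn1.gamma", (128,)),                      # BatchNorm scale     -> no decay
    ("layer1.0.downsample.1.weight", (128, 64, 1, 1)),   # conv kernel (4-D)   -> decay
    ("classifier.0.weight", (1000, 512)),                # classifier weight   -> decay
    ("classifier.0.bias", (1000,)),                      # classifier bias     -> decay
]

decay, no_decay = [], []
for name, shape in example_params:
    if len(shape) == 4 or name == 'classifier.0.weight' or name == 'classifier.0.bias':
        decay.append(name)
    else:
        no_decay.append(name)

print(decay)     # ['layer1.0.downsample.1.weight', 'classifier.0.weight', 'classifier.0.bias']
print(no_decay)  # ['layer1.0.binary_conv.weights', 'layer1.0.bn1.gamma']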
