# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
|
|
|
"""train midas.""" |
|
|
|
import os |
|
|
|
import json |
|
|
|
from mindspore import dtype as mstype |
|
|
|
from mindspore import context |
|
|
|
from mindspore import nn |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore.context import ParallelMode |
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore.common import set_seed |
|
|
|
from mindspore.train.serialization import load_checkpoint |
|
|
|
from mindspore.train.model import Model |
|
|
|
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig |
|
|
|
from mindspore.communication.management import init, get_rank |
|
|
|
from src.midas_net import MidasNet, Loss, NetwithCell |
|
|
|
from src.utils import loadImgDepth |
|
|
|
from src.config import config |
|
|
|
|
|
|
|
# Fix both the global MindSpore RNG and the dataset-pipeline RNG so that
# shuffling and weight initialization are reproducible across runs.
set_seed(1)
ds.config.set_seed(1)
|
|
|
|
|
|
|
|
|
|
|
def dynamic_lr(num_epoch_per_decay, total_epochs, steps_per_epoch, lr, end_lr):
    """Build a per-step learning-rate schedule.

    The rate follows a polynomial decay (power 0.5) from ``lr`` down to
    ``end_lr`` across the first ``num_epoch_per_decay`` epochs and is then
    held constant at ``end_lr`` for the remaining steps.

    Returns:
        list: one learning-rate value per training step
        (``total_epochs * steps_per_epoch`` entries).
    """
    total_steps = total_epochs * steps_per_epoch
    decay_steps = num_epoch_per_decay * steps_per_epoch
    decay_fn = nn.PolynomialDecayLR(lr, end_lr, decay_steps, 0.5)
    schedule = []
    for step in range(total_steps):
        if step < decay_steps:
            # PolynomialDecayLR expects the global step as an int32 Tensor.
            schedule.append(decay_fn(Tensor(step, mstype.int32)).asnumpy())
        else:
            schedule.append(end_lr)
    return schedule
|
|
|
|
|
|
|
|
|
|
|
def train(mixdata_path):
    """Train MiDaS on the mixed depth dataset.

    Sets up the execution context for one of three modes (ModelArts cloud,
    local distributed, or standalone), loads the JSON manifest of training
    images, restores pretrained backbone weights, and runs the training loop
    with loss/time monitoring and periodic checkpointing.

    Args:
        mixdata_path (str): Path to the JSON manifest listing training images.
            On ModelArts this is a file name resolved under the per-device
            data cache.
    """
    epoch_number_total = config.epoch_size
    batch_size = config.batch_size
    if config.is_modelarts:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        local_data_path = '/cache/data'
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, max_call_depth=10000)
        context.set_context(device_id=device_id)
        # define distributed local data path
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
        local_data_path = os.path.join(local_data_path, str(device_id))
        mixdata_path = os.path.join(local_data_path, mixdata_path)
        load_path = os.path.join(local_data_path, 'midas_resnext_101_WSL.ckpt')
        output_path = config.train_url
        # Download the dataset from OBS into the local cache.
        mox.file.copy_parallel(src_url=config.data_url, dst_url=local_data_path)
    elif config.run_distribute:
        if config.device_target == 'GPU':
            device_num = int(os.getenv('RANK_SIZE', '1'))
            if device_num > 1:
                init("nccl")
                context.reset_auto_parallel_context()
                context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                                  gradients_mean=True)
                device_id = get_rank()
                context.set_context(device_id=device_id, enable_graph_kernel=True)
            else:
                # Single process despite run_distribute: fall back to shard 0 so
                # the dataset sharding below still works (device_id was
                # previously unbound on this path, causing a NameError).
                device_id = 0
        else:
            device_id = int(os.getenv('DEVICE_ID'))
            device_num = int(os.getenv('RANK_SIZE'))
            context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                                max_call_depth=10000)
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              device_num=device_num)
            init()
        local_data_path = config.train_data_dir
        mixdata_path = config.train_json_data_dir
        load_path = config.model_weights
    else:
        local_data_path = config.train_data_dir
        mixdata_path = config.train_json_data_dir
        load_path = config.model_weights
        device_id = config.device_id
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
                            save_graphs=False, device_id=device_id,
                            max_call_depth=10000)
    # Load the manifest of image paths (relative to the local data root).
    with open(mixdata_path) as f:
        data_config = json.load(f)
    img_paths = data_config['img']
    mix_dataset = loadImgDepth.LoadImagesDepth(local_path=local_data_path, img_paths=img_paths)
    ds.config.set_enable_shared_mem(False)
    if config.is_modelarts or config.run_distribute:
        # Shard the dataset across devices for data-parallel training.
        mix_dataset = ds.GeneratorDataset(mix_dataset, ['img', 'mask', 'depth'], shuffle=True, num_parallel_workers=8,
                                          num_shards=device_num, shard_id=device_id)
    else:
        mix_dataset = ds.GeneratorDataset(mix_dataset, ['img', 'mask', 'depth'], shuffle=True)
    # Use the configured batch size (was hard-coded to 8, contradicting the
    # batch size reported in the log line below).
    mix_dataset = mix_dataset.batch(batch_size, drop_remainder=True)
    per_step_size = mix_dataset.get_dataset_size()
    # Build network, loss and optimizer.
    net = MidasNet()
    net = net.set_train()
    loss = Loss()
    load_checkpoint(load_path, net=net)
    # Pretrained backbone parameters train with a smaller LR than the head.
    backbone_params = list(filter(lambda x: 'backbone' in x.name, net.trainable_params()))
    no_backbone_params = list(filter(lambda x: 'backbone' not in x.name, net.trainable_params()))
    if config.lr_decay:
        group_params = [{'params': backbone_params,
                         'lr': nn.PolynomialDecayLR(config.backbone_params_lr,
                                                    config.backbone_params_end_lr,
                                                    epoch_number_total * per_step_size, config.power)},
                        {'params': no_backbone_params,
                         'lr': nn.PolynomialDecayLR(config.no_backbone_params_lr,
                                                    config.no_backbone_params_end_lr,
                                                    epoch_number_total * per_step_size, config.power)},
                        {'order_params': net.trainable_params()}]
    else:
        group_params = [{'params': backbone_params, 'lr': 1e-5},
                        {'params': no_backbone_params, 'lr': 1e-4},
                        {'order_params': net.trainable_params()}]
    optim = nn.Adam(group_params)
    netwithLoss = NetwithCell(net, loss)
    midas_net = nn.TrainOneStepCell(netwithLoss, optim)
    model = Model(midas_net)
    # Callbacks: loss/time monitoring plus rolling checkpoints.
    loss_cb = LossMonitor()
    time_cb = TimeMonitor()
    checkpointconfig = CheckpointConfig(saved_network=net, save_checkpoint_steps=5, keep_checkpoint_max=2)
    if config.is_modelarts:
        ckpoint_cb = ModelCheckpoint(prefix='Midas_{}'.format(device_id), directory=local_data_path + '/output/ckpt',
                                     config=checkpointconfig)
    else:
        ckpoint_cb = ModelCheckpoint(prefix='Midas_{}'.format(device_id), directory='./ckpt/', config=checkpointconfig)
    callbacks = [loss_cb, time_cb, ckpoint_cb]
    # Run training.
    print("Starting Training:per_step_size={},batchsize={},epoch={}".format(per_step_size, batch_size,
                                                                            epoch_number_total))
    model.train(epoch_number_total, mix_dataset, callbacks=callbacks)
    if config.is_modelarts:
        # Upload results (checkpoints) back to OBS.
        mox.file.copy_parallel(local_data_path + "/output", output_path)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Default manifest file name; on ModelArts train() resolves it under the
    # per-device data cache.
    train(mixdata_path="mixdata.json")