|
- """TrainOnestepGen network"""
-
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import context
- from mindspore.communication.management import get_group_size
- from mindspore.context import ParallelMode
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
-
-
class TrainOnestepGen(nn.Cell):
    """One-step training cell for the generator network.

    Encapsulation class of DBPN network training. Wraps a forward network
    (with its loss already attached), gradient computation, and the
    optimizer update into a single ``construct`` call, so one invocation
    performs a full training step.

    Args:
        network (Cell): Generator-with-loss cell. The loss function must
            already be attached so that ``network(...)`` returns the loss.
        optimizer (Cell): Optimizer for updating the network's weights.
        sens (Number): Scale of the gradient sensitivity fed to
            backpropagation. Default: 1.0.

    Outputs:
        Tensor, the loss computed by ``network``.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOnestepGen, self).__init__(auto_prefix=False)
        self.network = network                    # forward network (with loss)
        self.network.set_grad()                   # build the backward graph
        self.optimizer = optimizer                # weight-update optimizer
        self.weights = self.optimizer.parameters  # parameters to be updated
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        # Fix: `sens` was previously accepted but silently ignored, and the
        # raw loss value was passed as the gradient sensitivity (scaling all
        # gradients by the loss). Store it for use in construct().
        self.sens = sens

    def construct(self, target, input, neigbor, flow):
        """Run one forward/backward/update step and return the loss."""
        loss = self.network(target, input, neigbor, flow)
        # Constant sensitivity tensor (filled with self.sens, shaped/typed
        # like the loss), following mindspore.nn.TrainOneStepCell.
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(target, input, neigbor, flow, sens)
        # Depend on the optimizer update so it is not pruned from the graph.
        loss = ops.depend(loss, self.optimizer(grads))
        return loss
|