# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for pretraining."""
import numpy as np

import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")

@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): Gradient.

    Outputs:
        Tensor, clipped gradient.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad

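# Illustrative sketch (not used by the training code): the two clipping modes above,
# shown with plain numpy on a made-up gradient. Clip-by-value bounds each element
# independently; clip-by-norm rescales the whole tensor when its L2 norm exceeds
# the threshold.
def _clip_grad_numpy_sketch():
    """Hypothetical helper, for illustration only."""
    grad = np.array([3.0, -4.0, 0.5], dtype=np.float32)
    clip_value = 1.0
    by_value = np.clip(grad, -clip_value, clip_value)      # [1.0, -1.0, 0.5]
    norm = np.linalg.norm(grad)                            # ~5.02
    by_norm = grad * min(1.0, clip_value / norm)           # direction kept, norm <= 1
    return by_value, by_norm
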
class GetMaskedLMOutput(nn.Cell):
    """
    Get masked lm output.

    Args:
        config (BertConfig): The config of BertModel.

    Returns:
        Tensor, masked lm output.
    """

    def __init__(self, config):
        super(GetMaskedLMOutput, self).__init__()
        self.width = config.hidden_size
        self.reshape = P.Reshape()
        self.gather = P.Gather()

        weight_init = TruncatedNormal(config.initializer_range)
        self.dense = nn.Dense(self.width,
                              config.hidden_size,
                              weight_init=weight_init,
                              activation=config.hidden_act).to_float(config.compute_type)
        self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
        self.output_bias = Parameter(
            initializer(
                'zero',
                config.vocab_size))
        self.matmul = P.MatMul(transpose_b=True)
        self.log_softmax = nn.LogSoftmax(axis=-1)
        self.shape_flat_offsets = (-1, 1)
        self.last_idx = (-1,)
        self.shape_flat_sequence_tensor = (-1, self.width)
        self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
        self.cast = P.Cast()
        self.compute_type = config.compute_type
        self.dtype = config.dtype

    def construct(self,
                  input_tensor,
                  output_weights,
                  positions):
        """Get output log_probs"""
        rng = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0]))
        flat_offsets = self.reshape(rng * self.seq_length_tensor, self.shape_flat_offsets)
        flat_position = self.reshape(positions + flat_offsets, self.last_idx)
        flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
        input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
        input_tensor = self.cast(input_tensor, self.compute_type)
        output_weights = self.cast(output_weights, self.compute_type)
        input_tensor = self.dense(input_tensor)
        input_tensor = self.layernorm(input_tensor)
        logits = self.matmul(input_tensor, output_weights)
        logits = self.cast(logits, self.dtype)
        logits = logits + self.output_bias
        log_probs = self.log_softmax(logits)
        return log_probs

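# Illustrative sketch (not part of the model): how the flat-offset gather in
# GetMaskedLMOutput.construct picks out the hidden vectors of the masked positions.
# Shapes and values below are made up: batch_size=2, seq_length=4, hidden=3.
def _gather_masked_positions_sketch():
    """Hypothetical helper, for illustration only."""
    batch_size, seq_length, hidden = 2, 4, 3
    sequence_output = np.arange(batch_size * seq_length * hidden, dtype=np.float32)
    sequence_output = sequence_output.reshape(batch_size, seq_length, hidden)
    positions = np.array([[1, 3], [0, 2]])                    # masked positions per example
    flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)
    flat_positions = (positions + flat_offsets).reshape(-1)   # indices into [batch*seq, hidden]
    flat_sequence = sequence_output.reshape(-1, hidden)
    gathered = flat_sequence[flat_positions]                  # shape (batch*num_masked, hidden)
    return gathered
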
class GetNextSentenceOutput(nn.Cell):
    """
    Get next sentence output.

    Args:
        config (BertConfig): The config of Bert.

    Returns:
        Tensor, next sentence output.
    """

    def __init__(self, config):
        super(GetNextSentenceOutput, self).__init__()
        self.log_softmax = P.LogSoftmax()
        weight_init = TruncatedNormal(config.initializer_range)
        self.dense = nn.Dense(config.hidden_size, 2,
                              weight_init=weight_init, has_bias=True).to_float(config.compute_type)
        self.dtype = config.dtype
        self.cast = P.Cast()

    def construct(self, input_tensor):
        logits = self.dense(input_tensor)
        logits = self.cast(logits, self.dtype)
        log_prob = self.log_softmax(logits)
        return log_prob

class BertPreTraining(nn.Cell):
    """
    Bert pretraining network.

    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.

    Returns:
        Tensor, prediction_scores, seq_relationship_score.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings):
        super(BertPreTraining, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cls1 = GetMaskedLMOutput(config)
        self.cls2 = GetNextSentenceOutput(config)

    def construct(self, input_ids, input_mask, token_type_id,
                  masked_lm_positions):
        sequence_output, pooled_output, embedding_table = \
            self.bert(input_ids, token_type_id, input_mask)
        prediction_scores = self.cls1(sequence_output,
                                      embedding_table,
                                      masked_lm_positions)
        seq_relationship_score = self.cls2(pooled_output)
        return prediction_scores, seq_relationship_score

class BertPretrainingLoss(nn.Cell):
    """
    Provide bert pre-training loss.

    Args:
        config (BertConfig): The config of BertModel.

    Returns:
        Tensor, total loss.
    """

    def __init__(self, config):
        super(BertPretrainingLoss, self).__init__()
        self.vocab_size = config.vocab_size
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()

    def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
                  masked_lm_weights, next_sentence_labels):
        """Defines the computation performed."""
        label_ids = self.reshape(masked_lm_ids, self.last_idx)
        label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
        masked_lm_loss = numerator / denominator

        # next_sentence_loss
        labels = self.reshape(next_sentence_labels, self.last_idx)
        one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(
            one_hot_labels * seq_relationship_score, self.last_idx))
        next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)

        # total_loss
        total_loss = masked_lm_loss + next_sentence_loss

        return total_loss

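# Illustrative sketch (not used anywhere): the same weighted masked-LM loss and
# next-sentence loss as BertPretrainingLoss above, written in plain numpy. The tiny
# arrays are made-up sample values; vocab_size is shrunk to 5 to keep the one-hot
# step readable.
def _pretraining_loss_numpy_sketch():
    """Hypothetical helper, for illustration only."""
    vocab_size = 5
    # log-probabilities for 3 masked positions (already log_softmax-ed)
    prediction_scores = np.log(np.full((3, vocab_size), 0.2, dtype=np.float32))
    masked_lm_ids = np.array([1, 4, 0])
    masked_lm_weights = np.array([1.0, 1.0, 0.0], dtype=np.float32)   # last position is padding
    one_hot = np.eye(vocab_size, dtype=np.float32)[masked_lm_ids]
    per_example_loss = -np.sum(prediction_scores * one_hot, axis=-1)
    masked_lm_loss = np.sum(masked_lm_weights * per_example_loss) / (np.sum(masked_lm_weights) + 1e-5)

    # next-sentence log-probabilities for a batch of 2
    seq_relationship_score = np.log(np.array([[0.7, 0.3], [0.4, 0.6]], dtype=np.float32))
    next_sentence_labels = np.array([0, 1])
    nsp_one_hot = np.eye(2, dtype=np.float32)[next_sentence_labels]
    next_sentence_loss = np.mean(-np.sum(nsp_one_hot * seq_relationship_score, axis=-1))
    return masked_lm_loss + next_sentence_loss
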
class BertNetworkWithLoss(nn.Cell):
    """
    Provide bert pre-training loss through network.

    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """

    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(BertNetworkWithLoss, self).__init__()
        self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
        self.loss = BertPretrainingLoss(config)
        self.cast = P.Cast()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Get pre-training loss"""
        prediction_scores, seq_relationship_score = \
            self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
        total_loss = self.loss(prediction_scores, seq_relationship_score,
                               masked_lm_ids, masked_lm_weights, next_sentence_labels)
        return self.cast(total_loss, mstype.float32)

class BertTrainOneStepCell(nn.TrainOneStepCell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The scaling value used as the backpropagation sensitivity. Default: 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def set_sens(self, value):
        self.sens = value

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Defines the computation performed."""
        weights = self.weights

        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        grads = self.grad_reducer(grads)
        succ = self.optimizer(grads)
        return F.depend(loss, succ)

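# Illustrative wiring sketch (hypothetical, not part of this module's API): how a
# BertNetworkWithLoss instance can be combined with an optimizer and
# BertTrainOneStepCell for training. `config` and `dataset` are assumed to exist,
# and Lamb is just one possible optimizer choice.
def _train_one_step_usage_sketch(config, dataset):
    """Hypothetical helper, for illustration only."""
    from mindspore.nn import Lamb
    from mindspore.train.model import Model

    net_with_loss = BertNetworkWithLoss(config, is_training=True)
    optimizer = Lamb(net_with_loss.trainable_params(), learning_rate=1e-4)
    train_cell = BertTrainOneStepCell(net_with_loss, optimizer)
    Model(train_cell).train(1, dataset)
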
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)

class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)

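# Illustrative sketch (hypothetical, not part of this module's API): wiring the
# loss-scaled train cell with a dynamic loss-scale manager. The starting scale and
# window values are made-up examples; mindspore.nn.DynamicLossScaleUpdateCell lowers
# the scale on overflow and raises it again after `scale_window` clean steps.
def _loss_scale_usage_sketch(config, dataset):
    """Hypothetical helper, for illustration only."""
    from mindspore.nn import DynamicLossScaleUpdateCell, Lamb
    from mindspore.train.model import Model

    net_with_loss = BertNetworkWithLoss(config, is_training=True)
    optimizer = Lamb(net_with_loss.trainable_params(), learning_rate=1e-4)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                             scale_factor=2,
                                             scale_window=1000)
    train_cell = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer,
                                                   scale_update_cell=update_cell)
    Model(train_cell).train(1, dataset)
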
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.
    Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
    condition as input.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(scaling_sens, cond)
        succ = self.optimizer(grads, overflow)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)

cast = P.Cast()
add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, mstype.float32)


update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))


accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")


@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))

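# Illustrative sketch (not used by the cells below): the gradient-accumulation
# bookkeeping that the helpers above implement with MultitypeFuncGraph, shown in
# plain numpy. Gradients from N micro-batches are summed into a buffer, the update
# divides the buffer by (loss_scale * degree * N), and the buffer is then reset.
def _grad_accumulation_numpy_sketch():
    """Hypothetical helper, for illustration only."""
    accumulation_steps = 4
    loss_scale, degree = 1024.0, 1.0
    micro_batch_grads = [np.full(3, 0.5, dtype=np.float32) for _ in range(accumulation_steps)]

    accu_grad = np.zeros(3, dtype=np.float32)
    for grad in micro_batch_grads:                  # _accumulate_accu_grads
        accu_grad += grad
    update = accu_grad / (loss_scale * degree * accumulation_steps)   # grad_scale step
    accu_grad = np.zeros_like(accu_grad)            # _reset_accu_grads
    return update, accu_grad
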
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.

    To mimic a higher batch size, gradients are accumulated N times before the weight update.

    In distributed mode, allreduce is applied only in the weight-update step,
    i.e. the sub-step after gradients have been accumulated N times.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 1.
        enable_global_norm (bool): Whether to clip by global norm instead of clipping each gradient
            separately. Default: False.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        init = F.depend(init, mean_loss)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            if self.enable_global_norm:
                grads = C.clip_by_global_norm(grads, 1.0, None)
            else:
                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
            accu_overflow = F.depend(accu_overflow, grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)

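# Illustrative sketch (plain Python, not graph code): the local_step / is_accu_step
# bookkeeping used above. The step counter cycles through 1..accumulation_steps and
# the weight-update branch runs only on the step where it reaches accumulation_steps.
def _accumulation_step_counter_sketch():
    """Hypothetical helper, for illustration only."""
    accumulation_steps = 3
    local_step = 0
    schedule = []
    for _ in range(6):                               # six micro-batches
        is_accu_step = local_step != accumulation_steps
        local_step = local_step + 1 if is_accu_step else 1
        is_accu_step = local_step != accumulation_steps
        schedule.append("accumulate" if is_accu_step else "update")
    return schedule                                  # ['accumulate', 'accumulate', 'update', ...]
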
class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph.

    To mimic a higher batch size, gradients are accumulated N times before the weight update.

    In distributed mode, allreduce is applied after each sub-step and the trailing
    communication time is overlapped by the backend optimization pass.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 1.
        enable_global_norm (bool): Whether to clip by global norm instead of clipping each gradient
            separately. Default: False.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
        grads = self.grad_reducer(grads)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        flag_reduce = self.overflow_reducer(flag_sum)
        overflow = self.less_equal(self.base, flag_reduce)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        overflow = self.reshape(overflow, (()))

        if is_accu_step:
            succ = False
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
            succ = F.depend(succ, accu_succ)
        else:
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                if self.enable_global_norm:
                    grads = C.clip_by_global_norm(grads, 1.0, None)
                else:
                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

                succ = self.optimizer(grads)

            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            succ = F.depend(succ, accu_succ)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
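

# Illustrative wiring sketch (hypothetical, not part of this module's API): using an
# accumulation cell so that `accumulation_steps` micro-batches contribute to each
# weight update. The loss-scale and step values are made-up examples.
def _accumulation_cell_usage_sketch(config, dataset):
    """Hypothetical helper, for illustration only."""
    from mindspore.nn import DynamicLossScaleUpdateCell, Lamb
    from mindspore.train.model import Model

    net_with_loss = BertNetworkWithLoss(config, is_training=True)
    optimizer = Lamb(net_with_loss.trainable_params(), learning_rate=1e-4)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16,
                                             scale_factor=2,
                                             scale_window=1000)
    train_cell = BertTrainAccumulationAllReducePostWithLossScaleCell(
        net_with_loss, optimizer, scale_update_cell=update_cell,
        accumulation_steps=4, enable_global_norm=True)
    Model(train_cell).train(1, dataset)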