from itertools import permutations

import numpy as np

import mindspore
import mindspore.ops as ops
from mindspore import Tensor, nn

EPS = 1e-8

class loss(nn.Cell):
    def __init__(self):
        super(loss, self).__init__()
        self.mean = ops.ReduceMean()

    def construct(self, source, estimate_source, source_lengths):
        return self.cal_loss(source, estimate_source, source_lengths)

    def cal_loss(self, source, estimate_source, source_lengths):
        """
        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B]
        """
        max_snr, perms, max_snr_idx = self.cal_si_snr_with_pit(source, estimate_source, source_lengths)
        loss = 0 - self.mean(max_snr)
        reorder_estimate_source = self.reorder_source(estimate_source, perms, max_snr_idx)
        return loss, max_snr, estimate_source, reorder_estimate_source

    def cal_si_snr_with_pit(self, source, estimate_source, source_lengths):
        """Calculate SI-SNR with PIT training.
        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B], each item is between [0, T]
        """
        cast = ops.Cast()
        sum = ops.ReduceSum(keep_dims=True)
        _sum = ops.ReduceSum(keep_dims=False)
        B, C, _ = source.shape
        # mask padding positions along T
        mask = self.get_mask(source, source_lengths)
        estimate_source *= mask

        # Step 1. Zero-mean norm
        num_samples = cast(source_lengths.view(-1, 1, 1), mindspore.float32)  # [B, 1, 1]
        mean_target = sum(source, 2) / num_samples
        mean_estimate = sum(estimate_source, 2) / num_samples
        zero_mean_target = source - mean_target
        zero_mean_estimate = estimate_source - mean_estimate
        # mask padding positions along T
        zero_mean_target *= mask
        zero_mean_estimate *= mask

        # Step 2. SI-SNR with PIT
        # reshape to use broadcast
        expand_dims = ops.ExpandDims()
        s_target = expand_dims(zero_mean_target, 1)      # [B, 1, C, T]
        s_estimate = expand_dims(zero_mean_estimate, 2)  # [B, C, 1, T]
        # s_target = <s', s>s / ||s||^2
        pair_wise_dot = sum(s_estimate * s_target, 3)   # [B, C, C, 1]
        s_target_energy = sum(s_target ** 2, 3) + EPS   # [B, 1, C, 1]
        pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
        # e_noise = s' - s_target
        e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
        # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
        pair_wise_si_snr = _sum(pair_wise_proj ** 2, 3) / (_sum(e_noise ** 2, 3) + EPS)
        log = ops.Log()
        pair_wise_si_snr = 10 * log(pair_wise_si_snr + EPS) / log(
            Tensor(np.array([10.0]), mindspore.float32))  # [B, C, C]

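        # pair_wise_si_snr[b, i, j] is the SI-SNR obtained when estimated source i is
        # matched against reference source j of utterance b.  PIT (permutation
        # invariant training) then picks, per utterance, the assignment of estimates
        # to references whose summed SI-SNR is largest; the one-hot encoding below
        # turns that search into a single [B, C*C] x [C*C, C!] matmul.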
        # Get max_snr of each utterance
        # permutations, [C!, C]
        perms_np = np.array(list(permutations(range(C))))
        perms = Tensor(perms_np, dtype=mindspore.int64)
        # one-hot encoding of each permutation, [C!, C, C]:
        # perms_one_hot[p, c, perms[p, c]] = 1
        num_perms = perms_np.shape[0]
        one_hot_np = np.zeros((num_perms, C, C), dtype=np.float32)
        one_hot_np[np.arange(num_perms)[:, None], np.arange(C)[None, :], perms_np] = 1
        perms_one_hot = Tensor(one_hot_np, mindspore.float32)
        # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
        matmul = ops.MatMul()
        transpose = ops.Transpose()
        perms_one_hot = transpose(perms_one_hot.view(num_perms, -1), (1, 0))  # [C*C, C!]
        snr_set = matmul(pair_wise_si_snr.view(B, -1), perms_one_hot)  # [B, C!]
        max_snr_idx = ops.Argmax(axis=1, output_type=mindspore.int32)(snr_set)  # [B]
        argmax = ops.ArgMaxWithValue(axis=1, keep_dims=True)
        _, max_snr = argmax(snr_set)  # [B, 1]
        max_snr /= C
        return max_snr, perms, max_snr_idx

    def reorder_source(self, source, perms, max_snr_idx):
        """
        Args:
            source: [B, C, T]
            perms: [C!, C], permutations
            max_snr_idx: [B], each item is between [0, C!)
        Returns:
            reorder_source: [B, C, T]
        """
        B, C, _ = source.shape
        # [B, C], permutation whose SI-SNR is max for each utterance
        max_snr_perm = perms[max_snr_idx, :]
        # for each utterance, reorder the estimated sources according to this permutation
        zeros_like = ops.ZerosLike()
        reorder_source = zeros_like(source)
        for b in range(B):
            for c in range(C):
                reorder_source[b, c] = source[b, int(max_snr_perm[b][c])]
        return reorder_source

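    # Note: the double Python loop above is simple but runs on the host, element by
    # element.  A vectorised alternative (an untested sketch, assuming ops.GatherD,
    # ops.ExpandDims and ops.BroadcastTo behave as documented) would broadcast the
    # chosen permutation over the time axis and gather along the source dimension:
    #
    #     index = ops.BroadcastTo((B, C, T))(ops.ExpandDims()(max_snr_perm, 2))
    #     reorder_source = ops.GatherD()(source, 1, index)
    #
    # so that reorder_source[b, c, t] = source[b, max_snr_perm[b, c], t].
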
    def get_mask(self, source, source_lengths):
        """
        Args:
            source: [B, C, T]
            source_lengths: [B]
        Returns:
            mask: [B, 1, T]
        """
        B, _, T = source.shape
        ones = ops.Ones()
        mask = ones((B, 1, T), mindspore.float32)
        for i in range(B):
            mask[i, :, int(source_lengths[i]):] = 0
        return mask

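
# The helper below is a brute-force NumPy re-statement of the same SI-SNR/PIT
# computation, useful as a sanity check for the Cell above.  It is only an
# illustrative sketch: the function name is not part of the original code, it
# assumes full-length (unpadded) utterances, and it simply tries every
# permutation instead of using the one-hot matmul trick.
def si_snr_pit_numpy(source, estimate_source, eps=EPS):
    """Return the best average SI-SNR per utterance, shape [B]."""
    B, C, _ = source.shape
    # zero-mean along T (SI-SNR is defined on zero-mean signals)
    s = source - source.mean(axis=2, keepdims=True)
    s_hat = estimate_source - estimate_source.mean(axis=2, keepdims=True)
    # pair[b, i, j] = SI-SNR of estimate i measured against reference j
    pair = np.zeros((B, C, C))
    for b in range(B):
        for i in range(C):
            for j in range(C):
                proj = np.dot(s_hat[b, i], s[b, j]) * s[b, j] / (np.dot(s[b, j], s[b, j]) + eps)
                noise = s_hat[b, i] - proj
                pair[b, i, j] = 10 * np.log10(np.dot(proj, proj) / (np.dot(noise, noise) + eps) + eps)
    # brute-force PIT: keep the permutation with the largest average SI-SNR
    best = np.full(B, -np.inf)
    for perm in permutations(range(C)):
        score = np.stack([pair[:, c, perm[c]] for c in range(C)], axis=0).mean(axis=0)
        best = np.maximum(best, score)
    return best
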
if __name__ == "__main__":
    from mindspore import context

    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    my_loss = loss()
    print("______________________ test cal_loss _______________________")
    padded_source = Tensor(np.random.randn(1, 2, 46400), dtype=mindspore.float32)
    # lengths must be valid sample counts in [0, T], not random floats
    mixture_lengths = Tensor(np.array([46400]), dtype=mindspore.int32)
    estimate_source = Tensor(np.random.randn(1, 2, 46400), dtype=mindspore.float32)
    print("*" * 100)
    total_loss, max_snr, estimate_source, reorder_estimate_source = \
        my_loss(padded_source, estimate_source, mixture_lengths)
    print("_" * 100)
    print(total_loss.shape)
    print(max_snr.shape)
    print(estimate_source.shape)
    print(reorder_estimate_source.shape)
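    # Optional cross-check against the NumPy sketch above (illustrative only):
    # with full-length utterances the two computations should agree closely, so
    # total_loss should be close to the negative of the printed reference value.
    reference = si_snr_pit_numpy(padded_source.asnumpy(), estimate_source.asnumpy())
    print("numpy reference max SI-SNR per utterance:", reference)
    print("cell max SI-SNR per utterance:", max_snr.asnumpy().reshape(-1))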