|
- #!/usr/bin/env python
-
- # Edited on 2020/12
- # Reference: Kaituo XU
- # Author: yoonsanghyu
-
- import argparse
- import os
-
- from collections import OrderedDict
- import librosa
- from mir_eval.separation import bss_eval_sources
- import numpy as np
- import torch
-
- from data import AudioDataLoader, AudioDataset
- from pit_criterion import cal_loss
- # from dptnet import DPTNet
- from models import DPTNet_base
- from utils import remove_pad
-
- #os.environ['CUDA_VISIBLE_DEVICES'] = '4'
-
# Command-line interface. Options are declared data-driven so defaults,
# types and help strings live in one table.
parser = argparse.ArgumentParser('Evaluate separation performance using DPTNet')

# (flag, type, default, help) — order matters only for --help output.
_CLI_SPEC = [
    # Network architecture
    ('--N', int, 64, 'Number of filters in autoencoder'),
    ('--C', int, 2, 'Maximum number of speakers'),
    ('--L', int, 4, 'Length of window in autoencoder'),  # L=2 in paper
    ('--H', int, 4, 'Number of head in Multi-head attention'),
    ('--K', int, 250, 'segment size'),
    ('--B', int, 6, 'Number of repeats'),
    # Evaluation settings
    ('--model_path', str, '/home/amax/frb/transformer/exp/temp/temp_best.pth.tar',
     'Path to model file created by training'),
    ('--data_dir', str, '/home/amax/frb/out/tt',
     'directory including mix.json, s1.json and s2.json'),
    ('--cal_sdr', int, 1,
     'Whether calculate SDR, add this option because calculation of SDR is very slow'),
    ('--use_cuda', int, 1, 'Whether use GPU'),
    ('--sample_rate', int, 8000, 'Sample rate'),
    ('--batch_size', int, 1, 'Batch size'),
]

for _flag, _type, _default, _help in _CLI_SPEC:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
-
def evaluate(args):
    """Evaluate separation quality of a trained DPTNet model on a test set.

    Prints per-utterance SI-SNRi (and SDRi when ``args.cal_sdr`` is truthy)
    followed by the dataset averages.

    Args:
        args: parsed argparse namespace; uses model_path, data_dir, cal_sdr,
            use_cuda, sample_rate and batch_size.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model.
    # NOTE(review): hyper-parameters are hard-coded here and the CLI options
    # --N/--L/--K/... are ignored — confirm these match the trained checkpoint.
    model = DPTNet_base(enc_dim=256, feature_dim=64, hidden_dim=128, layer=4,
                        segment_size=250, nspk=2, win_len=2)
    if args.use_cuda:
        # Use the default visible device instead of a hard-coded index 4:
        # with CUDA_VISIBLE_DEVICES restricting visibility (see __main__),
        # device 4 does not exist and .cuda(4) would fail.
        model.cuda()

    # Load the checkpoint on CPU first so evaluation also works on machines
    # with a different (or no) GPU layout than the training machine.
    model_info = torch.load(args.model_path, map_location='cpu')

    # Strip the 'module.' prefix that DataParallel adds to parameter names.
    state_dict = OrderedDict()
    for k, v in model_info['model_state_dict'].items():
        state_dict[k.replace("module.", "")] = v
    model.load_state_dict(state_dict)

    # Disable dropout etc. for deterministic evaluation.
    model.eval()

    # Load data (segment=-1 keeps full-length utterances).
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for data in data_loader:
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flatten the batch
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use the estimate reordered by the best PIT permutation
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # Score each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi (slow — optional)
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1

    # Guard against an empty test set (would otherwise ZeroDivisionError).
    if total_cnt == 0:
        print("No utterances evaluated.")
        return
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
-
-
def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).

    NOTE: bss_eval_sources is very very slow.

    Args:
        src_ref: numpy.ndarray, [C, T], reference sources
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T], input mixture
    Returns:
        average_SDRi: SDRi averaged over the C sources
    """
    # Generalized from a hard-coded 2-speaker stack to any number of sources:
    # the unprocessed mixture is the "do nothing" baseline for every source.
    n_src = src_ref.shape[0]
    src_anchor = np.stack([mix] * n_src, axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    # Mean per-source improvement (identical to the pairwise average for C=2).
    avg_SDRi = np.mean(sdr - sdr0)
    return avg_SDRi
-
-
def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    Args:
        src_ref: numpy.ndarray, [C, T], reference sources
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T], input mixture
    Returns:
        average_SISNRi: SI-SNRi averaged over the C sources
    """
    # Generalized from a hard-coded 2-speaker computation to any C:
    # improvement = SI-SNR(ref, est) - SI-SNR(ref, mix) for each source,
    # identical to the original pairwise average when C == 2.
    improvements = [cal_SISNR(ref, est) - cal_SISNR(ref, mix)
                    for ref, est in zip(src_ref, src_est)]
    avg_SISNRi = sum(improvements) / len(improvements)
    return avg_SISNRi
-
-
def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR).

    Args:
        ref_sig: numpy.ndarray, [T], reference signal
        out_sig: numpy.ndarray, [T], estimated signal
    Returns:
        SI-SNR value in dB
    """
    assert len(ref_sig) == len(out_sig)
    # Remove DC offset so the measure is invariant to a constant shift.
    ref = ref_sig - np.mean(ref_sig)
    est = out_sig - np.mean(out_sig)
    # Orthogonal projection of the estimate onto the reference direction;
    # eps keeps the division stable for (near-)silent references.
    target = np.sum(ref * est) * ref / (np.sum(ref ** 2) + eps)
    residual = est - target
    power_ratio = np.sum(target ** 2) / (np.sum(residual ** 2) + eps)
    return 10 * np.log(power_ratio + eps) / np.log(10.0)
-
-
if __name__ == '__main__':
    # Pin evaluation to one physical GPU and make CUDA kernel launches
    # synchronous so failures point at the offending call.
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    cli_args = parser.parse_args()
    print(cli_args)
    evaluate(cli_args)
|