|
-
- import os
-
- import numpy as np
- import onnxruntime as rt
-
- from inference.svs.opencpop.map import cpop_pinyin2ph_func
-
- from utils.hparams import set_hparams, hparams
- from utils.text_encoder import TokenTextEncoder
- from pypinyin import pinyin, lazy_pinyin, Style
- from collections import deque
- from tqdm import tqdm
- import librosa
- import glob
- import re
-
-
class Infer:
    """Singing-voice inference pipeline driven by exported ONNX models.

    The pipeline, as implemented below, is:

    1. ``preprocess_input``  - lyrics/notes -> phoneme tokens, MIDI pitches,
       note durations and slur flags.
    2. ``input_to_batch``    - wrap a single item into batch-of-one arrays.
    3. ``model``             - run ``fs2.onnx``, noise its mel with
       ``q_sample.onnx`` and iteratively denoise with ``p_sample_plms.onnx``
       or ``p_sample.onnx``.
    4. ``forward_model``     - optionally predict f0 with ``pe.onnx`` and
       synthesize audio with ``vocoder.onnx``.

    All model files are loaded from the relative ``model/`` directory.
    """

    def __init__(self, hparams):
        """Create the ONNX Runtime sessions and cache configuration values.

        :param hparams: configuration mapping. Keys read here: ``K_step``,
            ``spec_min``, ``spec_max``, ``keep_bins``,
            ``audio_num_mel_bins`` and (optionally) ``pe_enable``. Other
            methods additionally read ``gaussian_start``, ``pndm_speedup``,
            ``use_nsf`` and ``max_frames`` from the same mapping.
        """
        self.hparams = hparams

        # Every available provider except the TensorRT one is used.
        providers = [
            p for p in rt.get_available_providers() if 'Tensorrt' not in p
        ]
        print('Using these as onnxruntime providers:', providers)

        phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g",
                      "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k", "l", "m", "n", "o",
                      "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "ui", "un", "uo", "v",
                      "van", "ve", "vn", "w", "x", "y", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(
            None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = cpop_pinyin2ph_func()
        self.spk_map = {'opencpop': 0}

        options = rt.SessionOptions()
        # When a DirectML provider is present, memory-pattern optimization is
        # switched off and sequential execution is forced.
        if any('dml' in provider.lower() for provider in providers):
            options.enable_mem_pattern = False
            options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL

        self.fs2 = rt.InferenceSession(
            'model/fs2.onnx', options, providers=providers)
        # NOTE: model() always uses self.q_sample, so a missing
        # model/q_sample.onnx only surfaces at inference time.
        if os.path.exists('model/q_sample.onnx'):
            self.q_sample = rt.InferenceSession(
                'model/q_sample.onnx', options, providers=providers)
        # Exactly one of the two denoising sessions is created.
        if os.path.exists('model/p_sample_plms.onnx'):
            self.p_sample_plms = rt.InferenceSession(
                'model/p_sample_plms.onnx', options, providers=providers)
        else:
            self.p_sample = rt.InferenceSession(
                'model/p_sample.onnx', options, providers=providers)
        if os.path.exists('model/pe.onnx'):
            self.pe = rt.InferenceSession(
                'model/pe.onnx', options, providers=providers)
        self.vocoder = rt.InferenceSession(
            'model/vocoder.onnx', options, providers=providers)

        keep_bins = hparams['keep_bins']
        self.K_step = hparams['K_step']
        # Two leading axes are added so the bounds broadcast against the mel
        # arrays handled in norm_spec/denorm_spec.
        self.spec_min = np.asarray(
            hparams['spec_min'], np.float32)[None, None, :keep_bins]
        self.spec_max = np.asarray(
            hparams['spec_max'], np.float32)[None, None, :keep_bins]
        self.mel_bins = hparams['audio_num_mel_bins']
        self.use_pe = bool(hparams.get('pe_enable'))

    def model(self, txt_tokens, **kwargs):
        """Run the acoustic model followed by the diffusion sampling loop.

        :param txt_tokens: phoneme-token array bound to the ``txt_tokens``
            input of ``fs2.onnx``.
        :param kwargs: extra candidate inputs. Only values that are
            ``numpy.ndarray`` and whose key matches an input name of
            ``fs2.onnx`` are forwarded; ``mel2ph`` (if given) is also used to
            mask the final mel.
        :return: dict with ``decoder_inp`` and ``fs2_mel`` (OrtValues produced
            by ``fs2.onnx``), ``mel_out`` (denormalized numpy mel after
            sampling) and, when pitch extraction is disabled, ``f0_denorm``.
        """
        fs_input_names = {node.name for node in self.fs2.get_inputs()}
        inputs = {'txt_tokens': txt_tokens}
        inputs.update({
            k: v for k, v in kwargs.items()
            if isinstance(v, np.ndarray) and k in fs_input_names
        })

        io_binding = self.fs2.io_binding()
        for k, v in inputs.items():
            io_binding.bind_cpu_input(k, v)
        io_binding.bind_output('decoder_inp')
        io_binding.bind_output('mel_out')
        if not self.use_pe:
            # Without the pitch extractor, f0 is taken from fs2 itself.
            io_binding.bind_output('f0_denorm')
        self.fs2.run_with_iobinding(io_binding)
        fs2_outputs = io_binding.get_outputs()
        decoder_inp, mel_out = fs2_outputs[:2]
        device = mel_out.device_name()
        # Remembered so forward_model can place its tensors on the same device.
        self.device_name = device
        ret = {
            'decoder_inp': decoder_inp,
            'mel_out': mel_out
        }
        if not self.use_pe:
            ret.update({'f0_denorm': fs2_outputs[-1]})
        cond = decoder_inp.numpy().transpose([0, 2, 1])

        ret['fs2_mel'] = ret['mel_out']
        t = self.K_step
        fs2_mels = self.norm_spec(mel_out.numpy())
        fs2_mels = fs2_mels.transpose([0, 2, 1])[:, None, :, :]

        # Noise the normalized fs2 mel at step t - 1 to get the start point.
        io_binding = self.q_sample.io_binding()
        io_binding.bind_cpu_input('x_start', fs2_mels)
        io_binding.bind_cpu_input('noise', np.random.randn(
            *fs2_mels.shape).astype(fs2_mels.dtype))
        io_binding.bind_cpu_input('t', np.asarray([t - 1], dtype=np.int64))
        io_binding.bind_output('x_next')
        self.q_sample.run_with_iobinding(io_binding)
        x = io_binding.get_outputs()[0].numpy()
        if self.hparams.get('gaussian_start'):
            # Discard the q_sample result and start from pure Gaussian noise.
            print('===> gaussion start.')
            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
            x = np.random.randn(*shape).astype(fs2_mels.dtype)

        cond = rt.OrtValue.ortvalue_from_numpy(cond, device, 0)
        x = rt.OrtValue.ortvalue_from_numpy(x, device, 0)

        iteration_interval = self.hparams.get('pndm_speedup')
        if iteration_interval:
            self.noise_list = deque(maxlen=4)
            interval = rt.OrtValue.ortvalue_from_numpy(
                np.asarray([iteration_interval], np.int64), device, 0
            )
            steps = range(0, t, iteration_interval)
            # len(steps) is the true iteration count even when t is not a
            # multiple of the interval.
            for i in tqdm(reversed(steps), desc='sample time step',
                          total=len(steps)):
                io_binding = self.p_sample_plms.io_binding()
                io_binding.bind_ortvalue_input('x', x)
                # x is an OrtValue here: shape is a method and there is no
                # dtype attribute, so mirror the non-PLMS branch below.
                io_binding.bind_cpu_input(
                    'noise', np.random.randn(*x.shape()).astype(np.float32))
                io_binding.bind_ortvalue_input('cond', cond)
                io_binding.bind_cpu_input(
                    't', np.asarray([i], dtype=np.int64))  # torch i-1 but here i
                io_binding.bind_ortvalue_input('interval', interval)
                io_binding.bind_output('x_next')
                self.p_sample_plms.run_with_iobinding(io_binding)
                x = io_binding.get_outputs()[0]
        else:
            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                io_binding = self.p_sample.io_binding()
                io_binding.bind_ortvalue_input('x', x)
                io_binding.bind_cpu_input(
                    'noise', np.random.randn(*x.shape()).astype(np.float32))
                io_binding.bind_ortvalue_input('cond', cond)
                io_binding.bind_cpu_input(
                    't', np.asarray([i], dtype=np.int64))  # torch i-1 but here i
                io_binding.bind_output('x_next')
                self.p_sample.run_with_iobinding(io_binding)
                x = io_binding.get_outputs()[0]

        # Drop the singleton axis added before q_sample and swap the last two
        # axes back.
        x = x.numpy()[:, 0].transpose([0, 2, 1])
        mel2ph = kwargs.get('mel2ph', None)
        if mel2ph is not None:  # for singing
            # Zero out frames whose mel2ph entry is not positive.
            ret['mel_out'] = self.denorm_spec(
                x) * ((mel2ph > 0).astype(np.float32)[:, :, None])
        else:
            ret['mel_out'] = self.denorm_spec(x)
        return ret

    def norm_spec(self, x):
        """Linearly map ``x`` from ``[spec_min, spec_max]`` to ``[-1, 1]``."""
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``: map ``[-1, 1]`` back to the mel range."""
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

    def forward_model(self, inp):
        """Generate audio for one preprocessed item.

        :param inp: item dict as returned by ``preprocess_input``.
        :return: first element of the vocoder's ``wav_out`` output.
        """
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_id = sample.get('spk_ids')

        output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
                            pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
                            is_slur=sample['is_slur'])
        mel_out = output['mel_out']  # [B, T,80]
        mel_out = rt.OrtValue.ortvalue_from_numpy(mel_out, self.device_name, 0)
        if self.use_pe:
            # pe predict from Pred mel
            io_binding = self.pe.io_binding()
            io_binding.bind_ortvalue_input('mel_input', mel_out)
            io_binding.bind_output('f0_denorm_pred')
            self.pe.run_with_iobinding(io_binding)
            f0_pred = io_binding.get_outputs()[0]
        else:
            f0_pred = output['f0_denorm']
        wav_out = self.run_vocoder(mel_out, f0=f0_pred.numpy())

        return wav_out[0]

    def run_vocoder(self, c, **kwargs):
        """Run ``vocoder.onnx`` on mel ``c`` and return its ``wav_out`` output.

        :param c: value fed to the vocoder's ``mel_out`` input.
        :param kwargs: optional ``f0``; it is fed to the vocoder only when it
            is not ``None`` and the ``use_nsf`` hyper-parameter is truthy.
        """
        f0 = kwargs.get('f0')  # [B, T]
        feeds = {'mel_out': c}
        if f0 is not None and self.hparams.get('use_nsf'):
            feeds['f0'] = f0
        return self.vocoder.run(['wav_out'], feeds)[0]

    @staticmethod
    def _expand_word_level(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst):
        """Expand per-word phonemes/notes/durations to per-phoneme lists.

        Every phoneme of a word receives the word's first note and duration
        with slur flag 0. Each additional note of the word (a slur) repeats
        the word's last phoneme with that note/duration and slur flag 1, e.g.::

            j        ie       ie
            F#4/Gb4  F#4/Gb4  C#4/Db4
            0        0        1

        :return: ``(ph_lst, note_lst, midi_dur_lst, is_slur)``, or ``None``
            when a word has fewer durations than notes or a slurred word has
            no phoneme to repeat.
        """
        ph_lst, note_lst, midi_dur_lst, is_slur = [], [], [], []
        for ph_per_word, notes, durs in zip(
                ph_per_word_lst, note_per_word_lst, mididur_per_word_lst):
            # single ph like ['ai'] or multiple phs like ['n', 'i']
            ph_in_this_word = ph_per_word.split()
            # single note like ['D4'] or several like ['D4', 'E4'] (a slur)
            note_in_this_word = notes.split()
            midi_dur_in_this_word = durs.split()

            missing_duration = len(midi_dur_in_this_word) < len(note_in_this_word)
            nothing_to_slur = len(note_in_this_word) > 1 and not ph_in_this_word
            if missing_duration or nothing_to_slur:
                print('The notes and durations inside a word do not line up:',
                      ph_per_word, notes, durs)
                return None

            # Step 1: non-slur case, or the first note of a slur.
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)

            # Step 2: 2nd, 3rd... notes of a slur repeat the final phoneme.
            for extra_note, extra_dur in zip(
                    note_in_this_word[1:], midi_dur_in_this_word[1:]):
                ph_lst.append(ph_in_this_word[-1])
                note_lst.append(extra_note)
                midi_dur_lst.append(extra_dur)
                is_slur.append(1)
        return ph_lst, note_lst, midi_dur_lst, is_slur

    def preprocess_word_level_input(self, inp):
        """Convert word-level input into phoneme-level sequences.

        :param inp: dict with ``text`` (lyrics), ``notes`` and
            ``notes_duration``; the latter two separate words with ``|``.
        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` where
            ``ph_seq`` is a space-joined string, or ``None`` when the word,
            note and duration counts do not match.
        """
        # Pypinyin can't solve polyphonic words
        # We hope someone could provide a better g2p module for us by opening pull requests.
        text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
            .replace('那么长', '那么常').replace('多长', '多常') \
            .replace('很长', '很常')

        # lyric: syllables absent from the pinyin->phoneme table are dropped.
        pinyins = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[py.strip()]
                           for py in pinyins if py.strip() in self.pinyin2phs]

        # Note
        note_per_word_lst = [x.strip()
                             for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [
            x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(
                note_per_word_lst), len(mididur_per_word_lst))
            return None

        expanded = self._expand_word_level(
            ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
        if expanded is None:
            return None
        ph_lst, note_lst, midi_dur_lst, is_slur = expanded
        ph_seq = ' '.join(ph_lst)

        # The helper appends to all lists in lockstep, so they always agree.
        print(len(ph_lst), len(note_lst), len(midi_dur_lst))
        print('Pass word-notes check.')
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_phoneme_level_input(self, inp):
        """Parse phoneme-level input (like transcriptions.txt in Opencpop).

        :param inp: dict with space-separated ``ph_seq``, ``note_seq``,
            ``note_dur_seq`` and ``is_slur_seq`` strings.
        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` with ``ph_seq``
            left as the original string and ``is_slur`` as floats, or ``None``
            when the phoneme, note and duration counts differ.
        """
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
        ph_count = len(ph_seq.split())
        print(len(note_lst), ph_count, len(midi_dur_lst))
        if len(note_lst) == ph_count == len(midi_dur_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_input(self, inp, input_type='word'):
        """Turn raw user input into a model-ready item.

        :param inp: {'text': str, 'item_name': (str, optional),
            'spk_name': (str, optional)} plus the note fields required by the
            chosen ``input_type`` (see the two ``preprocess_*_level_input``
            methods).
        :param input_type: ``'word'`` or ``'phoneme'``.
        :return: item dict (``item_name``, ``text``, ``ph``, ``spk_id``,
            ``ph_token``, ``pitch_midi``, ``midi_dur``, ``is_slur``,
            ``ph_len``), or ``None`` when the input is invalid.
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'opencpop')

        # single spk
        spk_id = self.spk_map[spk_name]

        # get ph seq, note lst, midi dur lst, is slur lst.
        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':
            # like transcriptions.txt in Opencpop dataset.
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if not ret:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None
        ph_seq, note_lst, midi_dur_lst, is_slur = ret

        # convert note lst to midi id; convert note dur lst to midi duration
        try:
            # For names like 'F#4/Gb4' only the part before '/' is converted;
            # 'rest' maps to MIDI id 0.
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:  # deliberately broad: any parse failure -> None
            print(e)
            print('Invalid Input Type.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap one item into a batch of size one.

        ``pitch_midi``, ``midi_dur`` and ``is_slur`` are truncated to the
        ``max_frames`` hyper-parameter along their sequence axis.

        :param item: item dict as returned by ``preprocess_input``.
        :return: dict of lists / numpy arrays with a leading batch axis.
        """
        max_frames = self.hparams['max_frames']
        txt_tokens = np.asarray(item['ph_token'], dtype=np.int64)[None, :]

        return {
            'item_name': [item['item_name']],
            'text': [item['text']],
            'ph': [item['ph']],
            'txt_tokens': txt_tokens,
            'txt_lengths': np.asarray([txt_tokens.shape[1]], dtype=np.int64),
            'spk_ids': np.asarray(item['spk_id'], np.int64)[None],
            'pitch_midi': np.asarray(
                item['pitch_midi'], dtype=np.int64)[None, :max_frames],
            'midi_dur': np.asarray(
                item['midi_dur'], dtype=np.float32)[None, :max_frames],
            'is_slur': np.asarray(
                item['is_slur'], dtype=np.int64)[None, :max_frames],
        }

    def postprocess_output(self, output):
        """Hook for post-processing the generated audio; returns it unchanged."""
        return output

    def infer_once(self, inp):
        """Run the full pipeline for one input dict.

        :param inp: raw input for ``preprocess_input``; an optional
            ``input_type`` key selects ``'word'`` (default) or ``'phoneme'``.
        :return: the post-processed output of ``forward_model``.
        :raises ValueError: if preprocessing rejects the input.
        """
        item = self.preprocess_input(
            inp, input_type=inp.get('input_type') or 'word')
        if item is None:
            # Fail here with a clear message rather than deeper in the
            # pipeline when None gets subscripted.
            raise ValueError(
                'Input preprocessing failed; see the messages printed above.')
        output = self.forward_model(item)
        return self.postprocess_output(output)

    @classmethod
    def example_run(cls, inp):
        """Load ``model/config.yaml``, synthesize ``inp`` and save the audio.

        The result is written to ``infer_out/example_out.wav`` using the
        ``audio_sample_rate`` hyper-parameter.
        """
        from utils.audio import save_wav
        set_hparams('model/config.yaml')
        infer_ins = cls(hparams)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, 'infer_out/example_out.wav',
                 hparams['audio_sample_rate'])
-
- # if __name__ == '__main__':
- # # debug
- # set_hparams('model/config.yaml')
- # a = Infer(hparams)
- # a.preprocess_input({'text': '你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP',
- # 'notes': 'D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest',
- # 'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
- # })
-
- # b = {
- # 'text': '小酒窝长睫毛AP是你最美的记号',
- # 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
- # 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340'
- # }
- # c = {
- # 'text': '小酒窝长睫毛AP是你最美的记号',
- # 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
- # 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
- # 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
- # 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'
- # } # input like Opencpop dataset.
- # a.preprocess_input(b)
- # a.preprocess_input(c, input_type='phoneme')
|