|
- import argparse
- import json
-
- import torch
-
- import utils
- from onnxexport.model_onnx_speaker_mix import SynthesizerTrn
-
# Module-level command-line parser for the exporter; its options are
# registered and parsed only when this file is run as a script.
parser = argparse.ArgumentParser(
    description="SoVitsSvc OnnxExport",
)
-
def _build_export_inputs(model, hps, num_frames, export_mix, device):
    """Create the dummy tensors used to trace ``model`` for ONNX export.

    Args:
        model: The loaded synthesizer. ``gin_channels`` and ``vol_embedding``
            are read from it, plus ``inter_channels`` when it exists.
        hps: Hyper-parameters loaded from the model's ``config.json``; only
            ``hps.spk`` is read (its length is the speaker count).
        num_frames: Length of every dummy frame sequence.
        export_mix: When true, ``sid`` is a ``(num_frames, n_spk)`` matrix of
            equal speaker-mix weights; otherwise it is the single index ``[0]``.
        device: Device every dummy tensor is moved to.

    Returns:
        ``(inputs, input_names, dynamic_axes)`` in the form expected by
        ``torch.onnx.export``. ``vol`` is appended only when the model has
        ``vol_embedding`` enabled.
    """
    hidden_unit = torch.rand(1, num_frames, model.gin_channels)
    pitch = torch.rand(1, num_frames)
    vol = torch.rand(1, num_frames)
    mel2ph = torch.arange(num_frames, dtype=torch.long).unsqueeze(0)
    uv = torch.ones(1, num_frames, dtype=torch.float32)
    # NOTE(review): the noise channel count is assumed to equal the model's
    # ``inter_channels``; 192 (the previously hard-coded value) is used when
    # the model does not expose that attribute.
    noise_channels = getattr(model, "inter_channels", 192)
    noise = torch.randn(1, noise_channels, num_frames)

    dynamic_axes = {
        "c": [0, 1],
        "f0": [1],
        "mel2ph": [1],
        "uv": [1],
        "noise": [2],
    }
    if export_mix:
        n_spk = len(hps.spk)
        # Equal weight for every speaker, repeated for every frame.
        sid = torch.full((num_frames, n_spk), 1.0 / n_spk)
        dynamic_axes["sid"] = [0]
    else:
        sid = torch.LongTensor([0])

    input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
    inputs = [hidden_unit, pitch, mel2ph, uv, noise, sid]
    if model.vol_embedding:
        input_names.append("vol")
        inputs.append(vol)
        dynamic_axes["vol"] = [1]

    return tuple(t.to(device) for t in inputs), input_names, dynamic_axes


def _write_moevs_config(path, hps, model, export_mix):
    """Write the MoeVoiceStudio description of the export to ``checkpoints/<path>.json``."""
    hidden_size = model.gin_channels
    # Layer tag used in the "Hubert" entry: layer 12 for 768-wide features,
    # layer 9 otherwise.
    vec_layer = "layer-12" if hidden_size == 768 else "layer-9"
    moevs_conf = {
        "Folder": f"{path}",
        "Name": f"{path}",
        "Type": "SoVits",
        "Rate": hps.data.sampling_rate,
        "Hop": hps.data.hop_length,
        "Hubert": f"vec-{hidden_size}-{vec_layer}",
        "SoVits4": True,
        "SoVits3": False,
        "CharaMix": export_mix,
        "Volume": model.vol_embedding,
        "HiddenSize": hidden_size,
        "Characters": list(hps.spk.keys()),
        "Cluster": "",
    }
    with open(f"checkpoints/{path}.json", "w", encoding="utf-8") as conf_file:
        json.dump(moevs_conf, conf_file, indent=4)


def OnnxExport(path=None):
    """Export a So-VITS-SVC checkpoint to ONNX plus a MoeVoiceStudio config.

    Reads ``checkpoints/<path>/config.json`` and ``checkpoints/<path>/model.pth``.
    Writes the traced model to ``checkpoints/<path>/<path>_SoVits.onnx`` and
    the MoeVS description to ``checkpoints/<path>.json``.

    Args:
        path: Name of the model folder under ``checkpoints/``.

    Raises:
        ValueError: If ``path`` is ``None``.
    """
    if path is None:
        # Without a folder name every path below would become "checkpoints/None/...".
        raise ValueError("OnnxExport needs the model folder name under checkpoints/")

    device = torch.device("cpu")
    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
    model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    utils.load_checkpoint(f"checkpoints/{path}/model.pth", model, None)
    model.eval().to(device)  # nn.Module.eval/to modify the module in place
    for param in model.parameters():
        param.requires_grad = False

    # Speaker mixing is only exported when the config lists 2+ speakers.
    export_mix = len(hps.spk) >= 2
    if export_mix:
        model.export_chara_mix(hps.spk)

    inputs, input_names, dynamic_axes = _build_export_inputs(
        model, hps, num_frames=200, export_mix=export_mix, device=device)

    # Forward pass with exactly the tensors that will be traced, run before
    # model.dec.OnnxExport() is called.
    model(*inputs)
    model.dec.OnnxExport()

    torch.onnx.export(
        model,
        inputs,
        f"checkpoints/{path}/{path}_SoVits.onnx",
        dynamic_axes=dynamic_axes,
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=input_names,
        output_names=["audio"],
    )

    _write_moevs_config(path, hps, model, export_mix)
-
-
if __name__ == '__main__':
    # Help text (Chinese), in English: "Model folder name — create a
    # `checkpoints` folder in the project root, create a new folder inside it
    # holding the model; that folder's name is the value of this option."
    # The folder name in the help text is spelled "checkpoints" to match the
    # paths OnnxExport actually reads.
    parser.add_argument('-n', '--model_name', type=str, default="TransformerFlow", help='模型文件夹名(根目录下新建checkpoints文件夹,在此文件夹下建立一个新的文件夹,放置模型,该文件夹名即为此项)')
    args = parser.parse_args()
    OnnxExport(args.model_name)
|