import torch
import torchaudio
from transformers import AutoModel


def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss between real and generated discriminator
    feature maps; the real maps are detached so only the generator trains."""
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2
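

# Illustrative sketch (assumed shapes, not part of the training code):
# feature_loss expects nested lists of per-layer feature maps, one inner list
# per discriminator, as produced by HiFi-GAN-style discriminators.
def _demo_feature_loss():
    fmap_r = [[torch.randn(2, 32, 100), torch.randn(2, 64, 50)]]
    fmap_g = [[torch.randn(2, 32, 100), torch.randn(2, 64, 50)]]
    return feature_loss(fmap_r, fmap_g)  # 0-dim tensor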


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss: real outputs are pushed toward 1 and
    generated outputs toward 0. Also returns per-discriminator losses."""
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    """LSGAN generator loss: generated outputs are pushed toward 1."""
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        dg_loss = torch.mean((1 - dg) ** 2)
        gen_losses.append(dg_loss)
        loss += dg_loss

    return loss, gen_losses
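

# Illustrative sketch of the least-squares GAN objective implemented above
# (assumed shapes): the tensors stand in for per-discriminator output logits.
def _demo_lsgan_losses():
    disc_real = [torch.rand(2, 1, 100), torch.rand(2, 1, 50)]
    disc_fake = [torch.rand(2, 1, 100), torch.rand(2, 1, 50)]
    d_loss, r_losses, g_losses = discriminator_loss(disc_real, disc_fake)
    adv_loss, gen_losses = generator_loss(disc_fake)
    return d_loss, adv_loss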


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """Masked KL divergence between the posterior q and the prior p,
    estimated from a single posterior sample.

    z_p: sample drawn from the posterior q, [b, h, t_t]
    logs_q: posterior log standard deviation, [b, h, t_t]
    m_p, logs_p: prior mean and log standard deviation, [b, h, t_t]
    z_mask: mask over valid frames, [b, 1, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    return kl / torch.sum(z_mask)
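

# Sanity-check sketch: with z_p sampled from q = N(m_q, exp(logs_q)), kl_loss
# is a single-sample Monte Carlo estimate of KL(q || p) between diagonal
# Gaussians. Averaged over many samples (single channel and full mask, so the
# masked sum reduces to a plain mean), it should approach the closed-form KL
# from torch.distributions. All tensors below are placeholders.
def _check_kl_loss(n_samples=10_000):
    from torch.distributions import Normal, kl_divergence

    m_q, logs_q = torch.zeros(1, 1, 4), torch.zeros(1, 1, 4)
    m_p, logs_p = torch.ones(1, 1, 4), 0.5 * torch.ones(1, 1, 4)
    z_mask = torch.ones(1, 1, 4)
    estimate = torch.stack(
        [
            kl_loss(
                m_q + torch.randn_like(m_q) * logs_q.exp(),
                logs_q, m_p, logs_p, z_mask,
            )
            for _ in range(n_samples)
        ]
    ).mean()
    exact = kl_divergence(Normal(m_q, logs_q.exp()), Normal(m_p, logs_p.exp())).mean()
    return estimate, exact  # the two values should be close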


class WavLMLoss(torch.nn.Module):
    """Losses computed on the hidden states of a frozen WavLM model:
    a feature-matching loss (forward) and LSGAN terms for a discriminator
    head `wd` operating on stacked hidden states."""

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        super().__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd  # discriminator head over stacked hidden states
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
        self.wavlm.eval()
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def forward(self, wav, y_rec):
        # Real audio is embedded without gradients; the reconstruction is
        # embedded with gradients enabled so the loss can train the generator
        # (the WavLM weights themselves stay frozen).
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(
                input_values=wav_16, output_hidden_states=True
            ).hidden_states
        y_rec_16 = self.resample(y_rec)
        # squeeze(1) drops a possible (B, 1, T) channel dim without
        # collapsing the batch dimension when B == 1.
        y_rec_embeddings = self.wavlm(
            input_values=y_rec_16.squeeze(1), output_hidden_states=True
        ).hidden_states

        floss = 0
        for er, eg in zip(wav_embeddings, y_rec_embeddings):
            floss += torch.mean(torch.abs(er - eg))

        return floss.mean()

    def generator(self, y_rec):
        # Adversarial generator term (LSGAN): push the discriminator outputs
        # on the reconstruction toward 1.
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(
            input_values=y_rec_16, output_hidden_states=True
        ).hidden_states
        # (B, L, T, H) -> (B, L, H, T) -> (B, L * H, T)
        y_rec_embeddings = (
            torch.stack(y_rec_embeddings, dim=1)
            .transpose(-1, -2)
            .flatten(start_dim=1, end_dim=2)
        )
        y_df_hat_g = self.wd(y_rec_embeddings)
        loss_gen = torch.mean((1 - y_df_hat_g) ** 2)

        return loss_gen

    def discriminator(self, wav, y_rec):
        # Discriminator term (LSGAN): real embeddings toward 1, reconstructed
        # embeddings toward 0. Embeddings are computed without gradients so
        # only the head `wd` is trained here.
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(
                input_values=wav_16, output_hidden_states=True
            ).hidden_states
            y_rec_16 = self.resample(y_rec)
            y_rec_embeddings = self.wavlm(
                input_values=y_rec_16, output_hidden_states=True
            ).hidden_states

            y_embeddings = (
                torch.stack(wav_embeddings, dim=1)
                .transpose(-1, -2)
                .flatten(start_dim=1, end_dim=2)
            )
            y_rec_embeddings = (
                torch.stack(y_rec_embeddings, dim=1)
                .transpose(-1, -2)
                .flatten(start_dim=1, end_dim=2)
            )

        y_df_hat_r = self.wd(y_embeddings)
        y_df_hat_g = self.wd(y_rec_embeddings)

        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
        g_loss = torch.mean(y_df_hat_g**2)

        return r_loss + g_loss

    def discriminator_forward(self, wav):
        # Score real audio with the discriminator head; WavLM itself runs
        # without gradients.
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(
                input_values=wav_16, output_hidden_states=True
            ).hidden_states
            y_embeddings = (
                torch.stack(wav_embeddings, dim=1)
                .transpose(-1, -2)
                .flatten(start_dim=1, end_dim=2)
            )

        y_d_rs = self.wd(y_embeddings)

        return y_d_rs
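

# Illustrative usage sketch: "microsoft/wavlm-base-plus" is a real Hugging
# Face checkpoint, but the 1x1-conv head below is a hypothetical stand-in for
# the trained discriminator `wd`; its only contract is mapping stacked hidden
# states of shape (B, n_layers * hidden, T) to per-frame logits.
def _demo_wavlm_loss():
    n_layers, hidden = 13, 768  # wavlm-base-plus: 12 layers + input embeddings
    wd = torch.nn.Conv1d(n_layers * hidden, 1, kernel_size=1)  # placeholder head
    slm_loss = WavLMLoss("microsoft/wavlm-base-plus", wd, model_sr=24000)
    wav = torch.randn(2, 24000)       # "real" audio, (B, T) at 24 kHz
    y_rec = torch.randn(2, 1, 24000)  # generated audio, (B, 1, T)
    fm = slm_loss(wav, y_rec)                   # feature-matching term
    adv = slm_loss.generator(y_rec.squeeze(1))  # adversarial generator term
    return fm, adv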