|
- """
- Implementation of Base SSD-GAN models.
- """
- import torch
-
- from torch_mimicry.nets.basemodel import basemodel
- from torch_mimicry.modules import losses
- import numpy as np
-
-
class SSD_Generator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional SSD-GAN generator model.

    The generator is trained against a discriminator whose forward pass
    returns a pair ``(out_spectral, out_spatial)`` of logits; the two are
    averaged with equal weight before computing the GAN loss.

    Attributes:
        nz (int): Noise dimension for upsampling.
        ngf (int): Variable controlling generator feature map sizes.
        bottom_width (int): Starting width for upsampling generator output to an image.
        loss_type (str): Name of loss to use for GAN loss.
    """
    def __init__(self, nz, ngf, bottom_width, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.nz = nz
        self.ngf = ngf
        self.bottom_width = bottom_width
        self.loss_type = loss_type

    def generate_images(self, num_images, device=None):
        r"""
        Generates num_images randomly.

        Args:
            num_images (int): Number of images to generate.
            device (torch.device): Device to send images to. Falls back to
                ``self.device`` when None.

        Returns:
            Tensor: A batch of generated images.
        """
        if device is None:
            device = self.device

        # Standard-normal noise of shape (num_images, nz), decoded by forward().
        noise = torch.randn((num_images, self.nz), device=device)
        fake_images = self.forward(noise)

        return fake_images

    def compute_gan_loss(self, output):
        r"""
        Computes GAN loss for generator.

        Args:
            output (Tensor): A batch of output logits from the discriminator of shape (N, 1).

        Returns:
            Tensor: A batch of GAN losses for the generator.

        Raises:
            ValueError: If ``self.loss_type`` is not one of
                "gan", "ns", "hinge" or "wasserstein".
        """
        if self.loss_type == "gan":
            errG = losses.minimax_loss_gen(output)

        elif self.loss_type == "ns":
            errG = losses.ns_loss_gen(output)

        elif self.loss_type == "hinge":
            errG = losses.hinge_loss_gen(output)

        elif self.loss_type == "wasserstein":
            errG = losses.wasserstein_loss_gen(output)

        else:
            raise ValueError("Invalid loss_type {} selected.".format(
                self.loss_type))

        return errG

    def train_step(self,
                   real_batch,
                   netD,
                   optG,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for G.

        Args:
            real_batch (tuple): A batch whose first element is a tensor of real
                images of shape (N, C, H, W). Used only for obtaining the
                current batch size.
            netD (nn.Module): Discriminator model for obtaining losses. Its
                forward pass must return ``(out_spectral, out_spatial)``.
            optG (Optimizer): Optimizer for updating generator's parameters.
            log_data (MetricLog): Logging object; ``errG`` is added to it.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and checkpointing.
                Useful for dynamic changes to model amidst training.

        Returns:
            MetricLog: The ``log_data`` object with updated logging variables
            after 1 training step.
        """
        self.zero_grad()

        # Only the batch size of the real batch is needed here.
        batch_size = real_batch[0].shape[0]

        # Produce fake images.
        fake_images = self.generate_images(num_images=batch_size,
                                           device=device)

        # Discriminator logits for the fake images.
        out_spectral, out_spatial = netD(fake_images)

        # Equal-weight mix of spectral and spatial logits. The spectral term is
        # detached, so G receives gradients only through the spatial logits.
        # NOTE(review): confirm that blocking the spectral gradient for G is
        # intended (e.g. because the spectral branch is non-differentiable).
        out = 0.5 * out_spectral.detach() + 0.5 * out_spatial
        errG = self.compute_gan_loss(out)

        # Backprop and update gradients.
        errG.backward()
        optG.step()

        # Log a plain Python float, matching the discriminator's logging.
        log_data.add_metric('errG', errG.item(), group='loss')

        return log_data
-
-
class SSD_Discriminator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional SSD-GAN discriminator model.

    The forward pass is expected to return a pair ``(out_spectral, out_spatial)``
    of logits for a batch of images. Training optimises two terms: a loss on the
    spectral logits alone (``errC``) and a loss on an equal-weight mix of the
    spectral and spatial logits (``errD``).

    Attributes:
        ndf (int): Variable controlling discriminator feature map sizes.
        loss_type (str): Name of loss to use for GAN loss.
    """
    def __init__(self, ndf, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.ndf = ndf
        self.loss_type = loss_type

    def compute_gan_loss(self, output_real, output_fake):
        r"""
        Computes GAN loss for discriminator.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.

        Returns:
            errD (Tensor): A batch of GAN losses for the discriminator.

        Raises:
            ValueError: If ``self.loss_type`` is not one of
                "gan", "ns", "hinge" or "wasserstein".
        """
        # "gan" and "ns" share the same discriminator objective.
        if self.loss_type == "gan" or self.loss_type == "ns":
            errD = losses.minimax_loss_dis(output_fake=output_fake,
                                           output_real=output_real)

        elif self.loss_type == "hinge":
            errD = losses.hinge_loss_dis(output_fake=output_fake,
                                         output_real=output_real)

        elif self.loss_type == "wasserstein":
            errD = losses.wasserstein_loss_dis(output_fake=output_fake,
                                               output_real=output_real)

        else:
            raise ValueError("Invalid loss_type {} selected.".format(
                self.loss_type))

        return errD

    def compute_probs(self, output_real, output_fake):
        r"""
        Computes probabilities from real/fake images logits.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.

        Returns:
            tuple: Average sigmoid probabilities (as Python floats) of real/fake
            images being considered real, over the batch.
        """
        D_x = torch.sigmoid(output_real).mean().item()
        D_Gz = torch.sigmoid(output_fake).mean().item()

        return D_x, D_Gz

    def train_step(self,
                   real_batch,
                   netG,
                   optD,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for D.

        Args:
            real_batch (tuple): An ``(images, labels)`` pair, where images is a
                tensor of real images of shape (N, C, H, W). Labels are unused.
            netG (nn.Module): Generator model for obtaining fake images.
            optD (Optimizer): Optimizer for updating discriminator's parameters.
            log_data (MetricLog): Logging object; losses and probabilities are added to it.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and checkpointing.
                Useful for dynamic changes to model amidst training.

        Returns:
            MetricLog: The ``log_data`` object with updated logging variables
            after 1 training step.
        """
        self.zero_grad()
        real_images, _ = real_batch
        batch_size = real_images.shape[0]  # Match batch sizes for last iter

        # Produce logits for real images.
        out_spectral_real, out_spatial_real = self.forward(real_images)

        # Produce fake images; detached so no gradient reaches the generator.
        fake_images = netG.generate_images(num_images=batch_size,
                                           device=device).detach()

        # Produce logits for fake images.
        out_spectral_fake, out_spatial_fake = self.forward(fake_images)

        # Loss on the spectral logits alone. This is the only term whose
        # gradient flows through the spectral logits.
        errC = self.compute_gan_loss(output_real=out_spectral_real,
                                     output_fake=out_spectral_fake)

        # Loss on the equal-weight mix. The spectral logits are detached, so
        # gradients from this term flow only through the spatial logits.
        out_real = 0.5 * out_spectral_real.detach() + 0.5 * out_spatial_real
        out_fake = 0.5 * out_spectral_fake.detach() + 0.5 * out_spatial_fake
        errD = self.compute_gan_loss(output_real=out_real,
                                     output_fake=out_fake)

        # Backprop and update gradients.
        errD_total = errD + errC
        errD_total.backward()
        optD.step()

        # Compute probabilities (sigmoid of the mixed logits) for logging;
        # previously the raw logit means were logged under the 'prob' group.
        D_x, D_Gz = self.compute_probs(output_real=out_real,
                                       output_fake=out_fake)

        # Log statistics for D once out of loop.
        log_data.add_metric('errD', errD.item(), group='loss')
        log_data.add_metric('errC', errC.item(), group='loss')
        log_data.add_metric('D(x)', D_x, group='prob')
        log_data.add_metric('D(G(z))', D_Gz, group='prob')

        return log_data
|