|
- import torch
- from torch.nn import functional as F
-
- from dassl.optim import build_optimizer, build_lr_scheduler
- from dassl.utils import count_num_param
- from dassl.engine import TRAINER_REGISTRY, TrainerX
- from dassl.modeling import build_network
- from dassl.engine.trainer import SimpleNet
-
-
- @TRAINER_REGISTRY.register()
- class DDAIG(TrainerX):
- """Deep Domain-Adversarial Image Generation.
-
- https://arxiv.org/abs/2003.06054.
- """
-
- def __init__(self, cfg):
- super().__init__(cfg)
- self.lmda = cfg.TRAINER.DDAIG.LMDA
- self.clamp = cfg.TRAINER.DDAIG.CLAMP
- self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
- self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
- self.warmup = cfg.TRAINER.DDAIG.WARMUP
- self.alpha = cfg.TRAINER.DDAIG.ALPHA
-
- def build_model(self):
- cfg = self.cfg
-
- print("Building F")
- self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
- self.F.to(self.device)
- print("# params: {:,}".format(count_num_param(self.F)))
- self.optim_F = build_optimizer(self.F, cfg.OPTIM)
- self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
- self.register_model("F", self.F, self.optim_F, self.sched_F)
-
- print("Building D")
- self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
- self.D.to(self.device)
- print("# params: {:,}".format(count_num_param(self.D)))
- self.optim_D = build_optimizer(self.D, cfg.OPTIM)
- self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
- self.register_model("D", self.D, self.optim_D, self.sched_D)
-
- print("Building G")
- self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
- self.G.to(self.device)
- print("# params: {:,}".format(count_num_param(self.G)))
- self.optim_G = build_optimizer(self.G, cfg.OPTIM)
- self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
- self.register_model("G", self.G, self.optim_G, self.sched_G)
-
- def forward_backward(self, batch):
- input, label, domain = self.parse_batch_train(batch)
-
- #############
- # Update G
- #############
- input_p = self.G(input, lmda=self.lmda)
- if self.clamp:
- input_p = torch.clamp(
- input_p, min=self.clamp_min, max=self.clamp_max
- )
- loss_g = 0
- # Minimize label loss
- loss_g += F.cross_entropy(self.F(input_p), label)
- # Maximize domain loss
- loss_g -= F.cross_entropy(self.D(input_p), domain)
- self.model_backward_and_update(loss_g, "G")
-
- # Perturb data with new G
- with torch.no_grad():
- input_p = self.G(input, lmda=self.lmda)
- if self.clamp:
- input_p = torch.clamp(
- input_p, min=self.clamp_min, max=self.clamp_max
- )
-
- #############
- # Update F
- #############
- loss_f = F.cross_entropy(self.F(input), label)
- if (self.epoch + 1) > self.warmup:
- loss_fp = F.cross_entropy(self.F(input_p), label)
- loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
- self.model_backward_and_update(loss_f, "F")
-
- #############
- # Update D
- #############
- loss_d = F.cross_entropy(self.D(input), domain)
- self.model_backward_and_update(loss_d, "D")
-
- loss_summary = {
- "loss_g": loss_g.item(),
- "loss_f": loss_f.item(),
- "loss_d": loss_d.item(),
- }
-
- if (self.batch_idx + 1) == self.num_batches:
- self.update_lr()
-
- return loss_summary
-
- def model_inference(self, input):
- return self.F(input)
|