|
- import numpy as np
- from torchvision.utils import save_image
- import random
- from torch.autograd import Variable
-
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
-
- from src.dataset import DataLoader
- from src.config import train_cfg as cfg
- from src.fid import calculate_fid_given_paths
-
-
# Shape of a single image tensor: (channels, height, width).
img_shape = (cfg.channels, cfg.img_size, cfg.img_size)

# Use the GPU whenever one is available.
cuda = torch.cuda.is_available()
print("cuda is_available: ",cuda,"----------------------------------------------------")
-
# Initialize network weights (DCGAN convention).
def weights_init_normal(m):
    """Initialize one module's parameters in DCGAN style.

    Conv layers get N(0, 0.02) weights and a zero bias (when present);
    BatchNorm2d layers get N(1, 0.02) scale and zero shift. All other
    module types are left untouched. Meant to be used via
    ``model.apply(weights_init_normal)``.
    """
    cls_name = m.__class__.__name__
    if "Conv" in cls_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)
    elif "BatchNorm2d" in cls_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
-
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent vector to an image.

    A linear layer projects the latent code onto a 128-channel feature map
    of spatial size (img_size // 4); two Upsample+Conv stages then grow it
    back to the full image resolution, ending in Tanh (output in [-1, 1]).
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = cfg.img_size // 4  # spatial size before the two 2x upsamples
        self.l1 = nn.Sequential(nn.Linear(cfg.latent_dim, 128 * self.init_size ** 2))

        # FIX: the original passed 0.8 positionally to BatchNorm2d, where the
        # second positional parameter is `eps` (an absurd epsilon); the
        # intended momentum=0.8 is made explicit here.
        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, cfg.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (N, latent_dim)."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
-
-
class Discriminator(nn.Module):
    """DCGAN-style discriminator: maps an image to one realness score.

    Four strided conv blocks each halve the spatial size (img_size / 2**4
    overall); a final linear layer emits one raw score per image. The loss
    applied to this score is MSE in this project (least-squares GAN).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # Conv (stride 2, halves H and W) -> LeakyReLU -> Dropout [-> BatchNorm]
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # FIX: the original passed 0.8 positionally, which is
                # BatchNorm2d's `eps` parameter; the intended momentum=0.8
                # is made explicit here.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(cfg.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of the downsampled feature map.
        ds_size = cfg.img_size // 2 ** 4
        self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img):
        """Return an (N, 1) realness score for a batch of images."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity
-
# Training loop.
def train(dataloader,optimizer_G,optimizer_D,generator,discriminator,adversarial_loss):
    '''
    Adversarially train the generator and discriminator (LSGAN objective).

    :param dataloader: yields (imgs, labels) batches of real images
    :param optimizer_G: optimizer for the generator
    :param optimizer_D: optimizer for the discriminator
    :param generator: generator network
    :param discriminator: discriminator network
    :param adversarial_loss: adversarial criterion (MSELoss in this project)
    :return: (generator, discriminator) after training
    '''
    # FIX: resolve these locally. The original read module globals that are
    # only assigned under __main__ (so train() crashed when imported), and
    # `device` was never defined at all on CPU-only machines.
    device = "cuda" if cuda else "cpu"
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    for epoch in range(cfg.epoch_size):
        for batch, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths (constant targets, no grad needed;
            # the deprecated Variable wrappers were dropped — plain tensors
            # behave identically since PyTorch 0.4).
            valid = Tensor(imgs.size(0), 1).fill_(1.0)
            fake = Tensor(imgs.size(0), 1).fill_(0.0)

            # Configure input
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], cfg.latent_dim)))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Generator loss: fool the discriminator into predicting "real".
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Discriminator loss: real images -> 1, detached fakes -> 0.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Periodically log progress, dump the current fake batch to disk
            # and report FID against the real data (skips step 0 on purpose).
            batches_done = epoch * len(dataloader) + batch
            if batches_done and batches_done % cfg.sample_interval == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, cfg.epoch_size, batch, len(dataloader), d_loss.item(), g_loss.item())
                )
                for i, data in enumerate(gen_imgs.data, start=1):
                    save_image(data, cfg.save_img_path+"batch%d_%d.png" % (batches_done,i), normalize=True)
                fid = calculate_fid_given_paths([cfg.data_path, cfg.save_img_path], cfg.batch_size, device=device, dims=64)
                print("训练过程 average fid_score: ", fid)

    return generator,discriminator
-
# Evaluation: generate images with the trained generator and score them with FID.
def eval(generator,ouput_path=cfg.ouput_path):
    """Generate 1000 images from random noise, save them, and report FID.

    :param generator: trained generator network
    :param ouput_path: directory the generated PNGs are written to
    """
    # FIX: resolve device/tensor type locally; the original read a global
    # `device` that is undefined on CPU-only machines and a `Tensor` global
    # assigned only under __main__.
    device = "cuda" if cuda else "cpu"
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    dataloader = DataLoader(batch_size=1)
    generator.to(device)
    generator.eval()

    for batch, (img, _) in enumerate(dataloader):
        # FIX: the original kept looping with `pass` after the limit,
        # wasting a full pass over the dataloader; and `> 1000` produced
        # 1001 images while the log claims 1k. Stop at exactly 1000.
        if batch >= 1000:
            break
        z = Tensor(np.random.normal(0, 1, (img.shape[0], cfg.latent_dim)))
        with torch.no_grad():  # inference only — no autograd bookkeeping
            gen_img = generator(z)
        save_image(gen_img, ouput_path+"%d.png" % batch, normalize=True)

    # FIX: use the `ouput_path` parameter (the original hard-coded
    # cfg.ouput_path, ignoring a caller-supplied directory).
    fid = calculate_fid_given_paths([cfg.data_path, ouput_path], cfg.batch_size, device=device, dims=64)
    print("最终生成的1k张图片 fid_score: ", fid,"------------------------------")
-
-
if __name__ == "__main__":
    # Loss function (MSE -> least-squares GAN objective).
    adversarial_loss = torch.nn.MSELoss()

    # Initialize the generator and discriminator.
    generator = Generator()
    discriminator = Discriminator()

    # FIX: always define `device`; the original only assigned it inside
    # `if cuda:`, causing a NameError in train()/eval() on CPU-only machines.
    device = "cuda" if cuda else "cpu"
    generator.to(device)
    discriminator.to(device)
    adversarial_loss.to(device)

    # Data loading.
    dataloader = DataLoader()

    # Optimizers.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))

    # Module-level float tensor type read by train()/eval().
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Training, then final evaluation.
    generator, discriminator = train(dataloader, optimizer_G, optimizer_D, generator, discriminator, adversarial_loss)
    eval(generator)
|