|
- import numpy as np
- from torchvision.utils import save_image
- from matplotlib import pyplot as plt
- import os
- from torch.autograd import Variable
-
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
-
- from src.dataset import DataLoader
- from src.config import train_cfg as cfg
- from src.fid import calculate_fid_given_paths
-
-
# Image tensor shape produced by the generator / consumed by the discriminator.
img_shape = (cfg.channels, cfg.img_size, cfg.img_size)

# Run on the GPU whenever PyTorch can see one.
cuda = torch.cuda.is_available()
# Default device and legacy tensor type derived from `cuda`. Defined at module
# level so functions that read them as globals also work on CPU-only machines
# and when this module is imported rather than run as a script.
device = "cuda" if cuda else "cpu"
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Training history: appended to by `train` at each sampling step, plotted by `plot`.
fid_list = []
D_loss_list = []
G_loss_list = []
batch_list = []
# Generator
class Generator(nn.Module):
    """Fully-connected generator mapping latent vectors to images of shape ``img_shape``.

    Input width is ``cfg.latent_dim``. Hidden widths are 128, 256, 512 and 1024,
    each followed by LeakyReLU(0.2); every hidden layer except the first also
    applies BatchNorm1d. The final Linear layer is squashed into [-1, 1] by Tanh.
    """

    def __init__(self):
        super().__init__()
        widths = [cfg.latent_dim, 128, 256, 512, 1024]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if depth:  # the first hidden layer is not normalised
                # NOTE(review): eps=0.8 is unusually large for BatchNorm; it may
                # have been meant as `momentum`. Kept to preserve behaviour — confirm.
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a batch of latent vectors ``z`` to images of shape ``(batch, *img_shape)``."""
        return self.model(z).view(z.size(0), *img_shape)
# Discriminator
class Discriminator(nn.Module):
    """Fully-connected discriminator scoring each image with a value in (0, 1).

    Images are flattened, passed through Linear layers of width 512 and 256
    (each followed by LeakyReLU(0.2)), then a single-unit Linear and a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = (int(np.prod(img_shape)), 512, 256)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores for the batch ``img``."""
        flat = img.view(img.shape[0], -1)
        return self.model(flat)
-
def _save_samples_and_score(gen_imgs, batches_done, device):
    """Save each image of ``gen_imgs`` as ``<cfg.save_img_path>/batch<N>/<i>.png``
    (``i`` starting at 1) and return the FID of that directory against ``cfg.data_path``."""
    sample_dir = os.path.join(cfg.save_img_path, "batch%d" % batches_done)
    os.makedirs(sample_dir, exist_ok=True)
    for i, sample in enumerate(gen_imgs.detach(), start=1):
        save_image(sample, os.path.join(sample_dir, "%d.png" % i), normalize=True)
    # Trailing separator kept so the FID helper receives a directory-style path.
    return calculate_fid_given_paths(
        [cfg.data_path, os.path.join(sample_dir, "")],
        cfg.batch_size,
        device=device,
        dims=64,
    )


# Training function
def train(dataloader, optimizer_G, optimizer_D, generator, discriminator, adversarial_loss):
    """Train the GAN for ``cfg.epoch_size`` epochs.

    Each batch does one generator step followed by one discriminator step.
    Every ``cfg.sample_interval`` batches (batch 0 excluded) the current batch of
    generated images is saved under ``cfg.save_img_path`` and scored with FID;
    the FID, both losses and the global batch index are appended to the
    module-level ``fid_list``, ``D_loss_list``, ``G_loss_list`` and ``batch_list``.

    :param dataloader: iterable of ``(images, labels)`` batches; must support ``len()``
    :param optimizer_G: optimizer used to train the generator
    :param optimizer_D: optimizer used to train the discriminator
    :param generator: generator network
    :param discriminator: discriminator network
    :param adversarial_loss: loss applied to discriminator outputs vs. 0/1 targets
    :return: ``(generator, discriminator)``
    """
    # Resolved locally: the globals of the same names were previously only set
    # inside ``__main__`` (and ``device`` only when CUDA was available).
    device = "cuda" if cuda else "cpu"
    n_batches = len(dataloader)

    for epoch in range(cfg.epoch_size):
        for batch, (imgs, _) in enumerate(dataloader):
            n = imgs.size(0)

            # Adversarial ground truths: 1 = real, 0 = fake.
            valid = torch.ones(n, 1, device=device)
            fake = torch.zeros(n, 1, device=device)
            real_imgs = imgs.to(device=device, dtype=torch.float32)

            # ---- Train generator: make the discriminator label fakes as real ----
            optimizer_G.zero_grad()
            z = torch.as_tensor(
                np.random.normal(0, 1, (n, cfg.latent_dim)),
                dtype=torch.float32,
                device=device,
            )
            gen_imgs = generator(z)
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---- Train discriminator: separate real images from (detached) fakes ----
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            batches_done = epoch * n_batches + batch
            if batches_done and batches_done % cfg.sample_interval == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, cfg.epoch_size, batch, n_batches, d_loss.item(), g_loss.item())
                )

                fid = _save_samples_and_score(gen_imgs, batches_done, device)
                print("fid_score: ", fid)

                fid_list.append(float(fid))
                D_loss_list.append(d_loss.item())
                G_loss_list.append(g_loss.item())
                batch_list.append(float(batches_done))

    return generator, discriminator
-
# Evaluation function
def eval(generator, ouput_path=cfg.ouput_path, num_images=1000):
    """Generate images with ``generator``, save them, and print their FID.

    One image is generated per batch of ``DataLoader(batch_size=1)``, stopping
    after ``num_images`` images (or earlier if the loader is exhausted). Images
    are written as ``<ouput_path>/<index>.png`` with index starting at 0, and the
    FID is computed between ``cfg.data_path`` and ``ouput_path``.

    :param generator: trained generator; moved to the device and put in eval mode
    :param ouput_path: directory the generated images are written to and scored from
    :param num_images: maximum number of images to generate (default 1000)
    """
    # Always defined (previously unbound when CUDA was unavailable).
    device = "cuda" if cuda else "cpu"
    generator.to(device)
    generator.eval()

    dataloader = DataLoader(batch_size=1)
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch, (img, _) in enumerate(dataloader):
            if batch >= num_images:
                break  # stop instead of iterating the rest of the dataset
            z = torch.as_tensor(
                np.random.normal(0, 1, (img.shape[0], cfg.latent_dim)),
                dtype=torch.float32,
                device=device,
            )
            save_image(generator(z), os.path.join(ouput_path, "%d.png" % batch), normalize=True)

    # Score the directory that was actually written to.
    fid = calculate_fid_given_paths([cfg.data_path, ouput_path], cfg.batch_size, device=device, dims=64)
    print("最终生成的1k张图片 fid_score: ", fid, "------------------------------")
-
def plot(fid_list, batch_list, D_loss_list, G_loss_list):
    """Plot FID, generator loss and discriminator loss against batch index, then show the figure.

    :param fid_list: FID values, one per sampled batch
    :param batch_list: global batch indices used as the shared x axis
    :param D_loss_list: discriminator losses aligned with ``batch_list``
    :param G_loss_list: generator losses aligned with ``batch_list``
    """
    # One panel per series (the previous 2x2 grid left its fourth panel empty).
    fig, axes = plt.subplots(1, 3, figsize=(12, 3.5))
    series = (
        (fid_list, "fid score"),
        (G_loss_list, "G_loss"),
        (D_loss_list, "D_loss"),
    )
    for ax, (values, title) in zip(axes, series):
        ax.plot(batch_list, values)
        ax.set_title(title)
        ax.set_xlabel("batch")
    fig.tight_layout()
    plt.show()
-
-
if __name__ == "__main__":
    # Output directories for periodic training samples and final evaluation images.
    os.makedirs(cfg.save_img_path, exist_ok=True)
    os.makedirs(cfg.ouput_path, exist_ok=True)

    # Module-level code elsewhere in this file reads `device` and `Tensor` as
    # globals, so define both on every path (`device` used to exist only with CUDA).
    device = "cuda" if cuda else "cpu"
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Loss function: binary cross-entropy on the discriminator's output.
    adversarial_loss = torch.nn.BCELoss().to(device)

    # Initialize generator and discriminator on the selected device.
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Data loading
    dataloader = DataLoader()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))

    # Training, final evaluation, and training-curve plots.
    generator, discriminator = train(dataloader, optimizer_G, optimizer_D, generator, discriminator, adversarial_loss)
    eval(generator)
    plot(fid_list, batch_list, D_loss_list, G_loss_list)
|