|
- import numpy as np
- import itertools
- from torchvision.utils import save_image
-
- from torch.autograd import Variable
- import torch.nn as nn
- import torch
-
- from src.dataset import DataLoader
- from src.config import train_cfg as cfg
- from src.fid import calculate_fid_given_paths
-
-
# Shape of one image tensor: (channels, height, width), taken from the training config.
img_shape = (cfg.channels, cfg.img_size, cfg.img_size)

# Whether a CUDA device is usable for this run.
cuda = bool(torch.cuda.is_available())
print("cuda is available: ",cuda,"----------------------------------------------------")
-
def reparameterization(mu, logvar):
    """Sample a latent vector via the reparameterization trick: z = mu + sigma * eps.

    :param mu: latent means, shape (N, cfg.latent_dim)
    :param logvar: latent log-variances, same shape as ``mu``
    :return: sampled latent batch z
    """
    sigma = torch.exp(0.5 * logvar)
    eps = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), cfg.latent_dim))))
    return mu + sigma * eps
-
# Encoder: maps a flattened image to a sampled latent code z.
# FIX: the original subclassed `nn.Cell` and defined `construct`, which is
# MindSpore's API — torch.nn has no `Cell`, so the class definition would
# raise AttributeError. Under PyTorch the base class is nn.Module and the
# callable entry point is `forward`.
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Two heads predicting the mean and log-variance of q(z|x).
        self.mu = nn.Linear(512, cfg.latent_dim)
        self.logvar = nn.Linear(512, cfg.latent_dim)

    def forward(self, img):
        """Encode a batch of images into sampled latent vectors.

        :param img: image batch, shape (N, *img_shape)
        :return: latent batch z, shape (N, cfg.latent_dim)
        """
        img_flat = img.view(img.shape[0], -1)
        x = self.model(img_flat)
        mu = self.mu(x)
        logvar = self.logvar(x)
        z = reparameterization(mu, logvar)
        return z
-
# Decoder: maps a latent code z back to an image in [-1, 1] (Tanh output).
# (The original header comment said "encoder" — this is the decoder.)
# FIX: `nn.Cell`/`construct` is MindSpore API; under PyTorch the base class
# is nn.Module and the entry point is `forward`.
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(cfg.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode latent vectors into images.

        :param z: latent batch, shape (N, cfg.latent_dim)
        :return: image batch, shape (N, *img_shape)
        """
        img_flat = self.model(z)
        img = img_flat.view(img_flat.shape[0], *img_shape)
        return img
-
# Discriminator: scores latent vectors as "from the prior" (1) vs "encoded" (0).
# FIX: `nn.Cell`/`construct` is MindSpore API; under PyTorch the base class
# is nn.Module and the entry point is `forward`.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(cfg.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return a validity score in (0, 1) for each latent vector.

        :param z: latent batch, shape (N, cfg.latent_dim)
        :return: validity scores, shape (N, 1)
        """
        validity = self.model(z)
        return validity
-
# Training loop for the adversarial autoencoder.
def train(dataloader, optimizer_G, optimizer_D, encoder, decoder, discriminator, adversarial_loss, pixelwise_loss):
    """Train encoder/decoder (generator) and latent discriminator adversarially.

    :param dataloader: iterable yielding (imgs, labels) batches
    :param optimizer_G: optimizer over encoder + decoder parameters
    :param optimizer_D: optimizer over discriminator parameters
    :param encoder: Encoder network
    :param decoder: Decoder network
    :param discriminator: latent-space Discriminator network
    :param adversarial_loss: BCE loss on discriminator outputs
    :param pixelwise_loss: L1 reconstruction loss
    :return: (encoder, decoder, discriminator) after training
    """
    # FIX: `device` used to be a module global defined in __main__ only when
    # CUDA was available, so CPU-only runs crashed in the FID call below.
    device = 'cuda' if cuda else 'cpu'

    for epoch in range(cfg.epoch_size):
        for batch, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths.
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

            # Move the input batch onto the working tensor type/device.
            real_imgs = Variable(imgs.type(Tensor))

            # ---------------- Train Generator (encoder + decoder) ----------------
            optimizer_G.zero_grad()

            encoded_imgs = encoder(real_imgs)
            decoded_imgs = decoder(encoded_imgs)

            # Reconstruction dominates (0.999); the adversarial term (0.001)
            # nudges the encoded distribution toward the prior.
            g_loss = (0.001 * adversarial_loss(discriminator(encoded_imgs), valid)
                      + 0.999 * pixelwise_loss(decoded_imgs, real_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------- Train Discriminator ----------------
            optimizer_D.zero_grad()

            # Samples from the N(0, 1) prior act as "real" for the discriminator.
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], cfg.latent_dim))))

            # Discriminator must accept prior samples and reject encoder outputs
            # (detached so generator gradients don't flow here).
            real_loss = adversarial_loss(discriminator(z), valid)
            fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

            d_loss.backward()
            optimizer_D.step()

            batches_done = epoch * len(dataloader) + batch
            if batches_done and batches_done % cfg.sample_interval == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, cfg.epoch_size, batch, len(dataloader), d_loss.item(), g_loss.item())
                )
                # Dump current reconstructions to disk, then score them with FID.
                for i, data in enumerate(decoded_imgs.data, start=1):
                    save_image(data, cfg.save_img_path + "batch%d_%d.png" % (batches_done, i), normalize=True)
                fid = calculate_fid_given_paths([cfg.data_path, cfg.save_img_path], cfg.batch_size, device=device, dims=64)
                print("训练过程 average fid_score: ", fid)

    return encoder, decoder, discriminator
-
# Evaluation: reconstruct single images, save them, and report FID.
# NOTE(review): the name shadows the builtin `eval`; kept because __main__
# calls it by this name.
def eval(encoder, decoder, ouput_path=cfg.ouput_path):
    """Generate reconstructions for the first 1001 images and print their FID.

    :param encoder: trained Encoder network
    :param decoder: trained Decoder network
    :param ouput_path: directory to write reconstructed images to
        (param name kept, typo and all, for caller compatibility)
    """
    # FIX: derive `device` locally; the module global only existed when CUDA
    # was available, crashing CPU-only runs at the FID call.
    device = 'cuda' if cuda else 'cpu'

    dataloader = DataLoader(batch_size=1)
    encoder.eval()
    decoder.eval()

    # Inference only — no gradients needed.
    with torch.no_grad():
        for batch, (img, _) in enumerate(dataloader):
            # FIX: the original used `pass` here, which kept iterating the whole
            # dataloader doing nothing; `break` stops after ~1k images as intended.
            if batch > 1000:
                break

            img = Variable(img.type(Tensor))
            encoded_imgs = encoder(img)
            decoded_imgs = decoder(encoded_imgs)
            save_image(decoded_imgs, ouput_path + "%d.png" % batch, normalize=True)

    fid = calculate_fid_given_paths([cfg.data_path, cfg.ouput_path], cfg.batch_size, device=device, dims=64)
    print("最终生成的1k张图片 fid_score: ", fid, "------------------------------")
-
-
if __name__ == "__main__":
    # Binary cross-entropy for the adversarial game, L1 for reconstruction.
    adversarial_loss = torch.nn.BCELoss()
    pixelwise_loss = torch.nn.L1Loss()

    # Instantiate the three networks.
    encoder = Encoder()
    decoder = Decoder()
    discriminator = Discriminator()

    # FIX: always define `device` — previously it was only assigned inside the
    # `if cuda:` branch, so CPU-only runs hit a NameError in the FID calls.
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
        pixelwise_loss.cuda()

    # Training data.
    dataloader = DataLoader()

    # One optimizer over encoder+decoder (the "generator"), one for the discriminator.
    optimizer_G = torch.optim.Adam(
        itertools.chain(encoder.parameters(), decoder.parameters()),
        lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=cfg.learning_rate, betas=(cfg.b1, cfg.b2))

    # Tensor factory matching the selected device; used throughout train()/eval().
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # Train, then score the final 1k generated images with FID.
    encoder, decoder, discriminator = train(dataloader, optimizer_G, optimizer_D, encoder, decoder,
                                            discriminator, adversarial_loss, pixelwise_loss)
    eval(encoder, decoder)
|