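"""Conditional DDPM training script.

Trains a class-conditional diffusion model (a U-Net noise predictor with an EMA
weight copy and classifier-free-guidance sampling) on a labelled face dataset.
"""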
import os
import copy
import argparse
import logging

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

path_csv = '/dataset/data_clean/data_clean(1).csv'
label_csv = pd.read_csv(path_csv, engine='python').values
path_image = '/dataset/faces/faces'

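# Exponential moving average of the model weights. After a warm-up of
# `step_start_ema` steps (during which the EMA model is simply reset to the
# current weights), each EMA parameter is updated as
#   ema = beta * ema + (1 - beta) * current
# The EMA copy is sampled alongside the raw model during training.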
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())

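# Self-attention over spatial positions: the (C, H, W) feature map is flattened
# into a sequence of H*W tokens of dimension C, passed through multi-head
# attention and a small feed-forward block (both with pre-LayerNorm and
# residual connections), then reshaped back to (C, H, W).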
class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)

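# Two 3x3 conv + GroupNorm + GELU blocks; with residual=True the input is added
# back and the result passed through GELU, keeping the channel count unchanged.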
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)

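# Downsampling block: max-pool by 2, two DoubleConvs, then add a learned
# projection of the time/class embedding broadcast over the spatial dimensions.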
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

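# Upsampling block: bilinear upsample by 2, concatenate the skip connection,
# two DoubleConvs, then add the projected time/class embedding as in Down.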
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

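# Conditional U-Net noise predictor. The timestep t is mapped to a sinusoidal
# embedding of size time_dim; if a class label y is given, a learned label
# embedding is added to it, and the combined embedding conditions every
# Down/Up block.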
class UNet_conditional(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

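# Dataset of face crops with one class index per image. The index combines two
# attribute columns of the CSV (the commented-out code refers to them as eye
# and hair labels) via label = i[3] * 12 + i[4].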
class FaceDataset(Dataset):

    def __init__(self, size=64, path='/dataset'):
        super(FaceDataset, self).__init__()
        self.dataset = []
        for i in label_csv:
            file_index = i[1]
            sample = {}
            sample['image_path'] = os.path.join(path, str(file_index) + '.jpg')
            i[3] = int(i[3].strip('[]'))
            i[4] = int(i[4].strip('[]'))
            sample['label'] = i[3] * 12 + i[4]
            # print(sample['label'])
            # sample['eye_label'] = i[3]
            # sample['hair_label'] = i[4]
            self.dataset.append(sample)
        self.train_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(size + int(0.25 * size)),
                torchvision.transforms.RandomCrop(size),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        image_path = self.dataset[item]['image_path']
        # image_eye_label = self.dataset[item]['eye_label']
        # image_hair_label = self.dataset[item]['hair_label']
        image_label = self.dataset[item]['label']
        # convert('RGB') guards against grayscale/RGBA files, since Normalize expects 3 channels
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.train_transform(image)
        return image_tensor, image_label

def plot_images(images):
    plt.figure(figsize=(32, 32))
    plt.imshow(torch.cat([
        torch.cat([i for i in images.cpu()], dim=-1),
    ], dim=-2).permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    im.save(path)


def get_data(args):
    dataset = FaceDataset(size=args.image_size, path=path_image)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader


def setup_logging(run_name):
    os.makedirs("/model/models", exist_ok=True)
    os.makedirs("/model/results", exist_ok=True)
    os.makedirs("/model/models_ema", exist_ok=True)
    os.makedirs("/model/results_ema", exist_ok=True)
    os.makedirs("/model/optims", exist_ok=True)
    os.makedirs(os.path.join("/model/models", run_name), exist_ok=True)
    os.makedirs(os.path.join("/model/results", run_name), exist_ok=True)
    os.makedirs(os.path.join("/model/models_ema", run_name), exist_ok=True)
    os.makedirs(os.path.join("/model/results_ema", run_name), exist_ok=True)
    os.makedirs(os.path.join("/model/optims", run_name), exist_ok=True)

    logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

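# DDPM utilities. With a linear beta schedule and alpha_hat_t = prod_s(1 - beta_s),
# the forward process noises an image as
#   x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps,  eps ~ N(0, I),
# and sample() runs the reverse process from pure noise, combining the
# conditional and unconditional noise predictions with classifier-free guidance
# (lerp(uncond, cond, cfg_scale)) before each denoising step.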
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

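# Training loop: predict the noise added at a random timestep and minimise the
# MSE against the true noise, updating the EMA copy after every optimiser step.
# Note that the label-dropout step usually used to train the unconditional
# branch for classifier-free guidance is commented out below.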
def train(args):
    setup_logging(args.run_name)
    device = args.device
    print(device)
    dataloader = get_data(args)
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
    model = UNet_conditional(num_classes=args.num_classes, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    lossList = []
    for epoch in range(args.epochs):
        epoch_loss = 0.0
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            # if np.random.random() < args.cond_threshold:
            #     labels = None
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.step_ema(ema_model, model)

            epoch_loss += loss.item()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        print("Epoch: ", epoch, " Loss:", epoch_loss / len(dataloader))
        lossList.append(epoch_loss / len(dataloader))

        if epoch % 10 == 0:
            labels = torch.arange(16).long().to(device)
            sampled_images = diffusion.sample(model, n=len(labels), labels=labels)
            ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            plot_images(sampled_images)
            save_images(sampled_images, os.path.join("/model/results", args.run_name, f"{epoch}.jpg"))
            save_images(ema_sampled_images, os.path.join("/model/results_ema", args.run_name, f"{epoch}_ema.jpg"))
            torch.save(model.state_dict(), os.path.join("/model/models", args.run_name, f"{epoch}_ckpt.pt"))
            torch.save(ema_model.state_dict(), os.path.join("/model/models_ema", args.run_name, f"{epoch}_ema_ckpt.pt"))
            torch.save(optimizer.state_dict(), os.path.join("/model/optims", args.run_name, f"{epoch}_optim.pt"))

    lst = list(range(args.epochs))
    plt.title('Loss-Epoch')
    plt.plot(lst, lossList)
    plt.savefig('/model/results/loss.jpg')

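# Hyperparameters are set directly on the args namespace rather than parsed
# from the command line; parse_known_args() is used only to create the namespace.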
def launch():
    parser = argparse.ArgumentParser()
    # args = parser.parse_args()
    args = parser.parse_known_args()[0]
    args.run_name = "DDPM_conditional"
    args.epochs = 800
    args.batch_size = 12
    args.image_size = 64
    args.num_classes = 12 * 11
    args.dataset_name = "faces"
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    args.lr = 1e-5
    train(args)


if __name__ == '__main__':
    launch()