Browse Source

上传文件至 ''

master
1073270530 2 months ago
parent
commit
8f8e7d5dac
1 changed files with 412 additions and 0 deletions
  1. +412
    -0
      a(2).py

+ 412
- 0
a(2).py View File

@@ -0,0 +1,412 @@
import os
import copy
import PIL.Image as Image
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import datasets
from torch.utils.data import DataLoader
import argparse
import logging
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from tqdm import tqdm
from torch import optim

# Path to the cleaned label CSV consumed by FaceDataset below.
path_csv = '/code/data_clean(1).csv'
# Loaded eagerly at import time as a raw numpy array of rows (.values).
# NOTE(review): FaceDataset reads columns 1 (file index), 3 and 4 (bracketed
# ints like '[5]') — confirm the CSV schema matches before reuse.
label_csv = pd.read_csv(path_csv, engine='python').values
# Root directory holding the face images named '<index>.jpg'.
path_image = '/dataset'

class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    Keeps a shadow copy of the weights that trails the live model, which
    typically samples better than the raw weights in diffusion training.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta   # decay factor: closer to 1 -> slower-moving average
        self.step = 0      # number of step_ema() calls seen so far

    def update_model_average(self, ma_model, current_model):
        # Blend every EMA parameter toward the corresponding live parameter.
        for live_p, ema_p in zip(current_model.parameters(), ma_model.parameters()):
            ema_p.data = self.update_average(ema_p.data, live_p.data)

    def update_average(self, old, new):
        # No history yet -> adopt the new value outright.
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Warm-up phase: mirror the raw model until `step_start_ema` steps
        # have elapsed, then start the exponential averaging.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        # Hard-copy the live model's weights into the EMA model.
        ema_model.load_state_dict(model.state_dict())


class SelfAttention(nn.Module):
    """Pre-LN multi-head self-attention over flattened spatial positions.

    Treats the (size x size) feature map as a sequence of `size*size`
    tokens of dimension `channels`, applies 4-head attention plus a
    feed-forward block, each with a residual connection.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens                  # residual around attention
        attended = self.ff_self(attended) + attended  # residual around feed-forward
        # Back to the (B, C, H, W) image layout.
        return attended.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)


class DoubleConv(nn.Module):
    """Two 3x3 conv + GroupNorm(1) layers, optionally as a residual block.

    With `residual=True` the input is added back and passed through GELU,
    so `in_channels` must equal `out_channels` in that mode.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        # Default the hidden width to the output width.
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        out = self.double_conv(x)
        return F.gelu(x + out) if self.residual else out


class Down(nn.Module):
    """Encoder stage: 2x max-pool downsample, double conv, plus time embedding.

    The time embedding `t` (dim `emb_dim`) is projected to `out_channels`
    and broadcast over the spatial grid, then added to the features.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        feats = self.maxpool_conv(x)
        # Project t to the channel dim and tile it across H and W.
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, feats.shape[-2], feats.shape[-1])
        return feats + emb


class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip-concat, double conv, time embedding.

    `in_channels` counts the channels AFTER concatenation with the encoder
    skip tensor; the time embedding is projected and broadcast as in Down.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        upsampled = self.up(x)
        # Channel-wise concat with the matching encoder skip connection.
        merged = torch.cat([skip_x, upsampled], dim=1)
        feats = self.conv(merged)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, feats.shape[-2], feats.shape[-1])
        return feats + emb


class UNet_conditional(nn.Module):
    """Conditional UNet noise predictor for 64x64 RGB diffusion.

    Conditions on the timestep via a sinusoidal positional encoding and,
    when `num_classes` is given, on a class label via a learned embedding
    that is added to the time embedding. Passing `y=None` at forward time
    yields the unconditional prediction (for classifier-free guidance).
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        # Encoder path (spatial sizes hard-coded for 64x64 inputs).
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        # Bottleneck.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Decoder path (in_channels include the concatenated skip tensors).
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        # Standard transformer sinusoidal encoding of the timestep.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        sin_part = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        cos_part = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        return torch.cat([sin_part, cos_part], dim=-1)

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # y=None -> unconditional branch for classifier-free guidance.
        if y is not None:
            t += self.label_emb(y)

        x1 = self.inc(x)
        x2 = self.sa1(self.down1(x1, t))
        x3 = self.sa2(self.down2(x2, t))
        x4 = self.sa3(self.down3(x3, t))

        x4 = self.bot3(self.bot2(self.bot1(x4)))

        h = self.sa4(self.up1(x4, x3, t))
        h = self.sa5(self.up2(h, x2, t))
        h = self.sa6(self.up3(h, x1, t))
        return self.outc(h)

class FaceDataset(Dataset):
    """Face image dataset driven by the module-level `label_csv` table.

    Each CSV row supplies a file index (column 1) and two attribute columns
    (3 and 4) stored as bracketed strings like '[5]'; these are combined
    into a single class id `attr3 * 12 + attr4`.

    Args:
        size: target square crop size in pixels.
        path: directory containing images named '<index>.jpg'.
    """

    def __init__(self, size=64, path='/dataset'):
        super(FaceDataset, self).__init__()
        self.dataset = []
        for row in label_csv:
            file_index = row[1]
            sample = {}
            sample['image_path'] = os.path.join(path, str(file_index) + '.jpg')
            # Columns hold bracketed strings like '[5]' -> parse to ints.
            row[3] = int(row[3].strip('[]'))
            row[4] = int(row[4].strip('[]'))
            # Fold the two attributes into one class id (12 values per slot
            # of row[4] — assumes row[4] in [0, 12); TODO confirm range).
            sample['label'] = row[3] * 12 + row[4]
            self.dataset.append(sample)
        # Resize slightly larger, then random-crop + flip for augmentation;
        # normalize to [-1, 1] as expected by the diffusion model.
        self.train_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(size + int(0.25 * size)),
                torchvision.transforms.RandomCrop(size),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        sample = self.dataset[item]
        # Force RGB so grayscale/RGBA/palette files don't break the
        # 3-channel Normalize above.
        image = Image.open(sample['image_path']).convert('RGB')
        image_tensor = self.train_transform(image)
        return image_tensor, sample['label']

def plot_images(images):
    """Display a batch of image tensors tiled horizontally in one figure."""
    plt.figure(figsize=(32, 32))
    # Tile all images side by side along the width dimension.
    row = torch.cat([img for img in images.cpu()], dim=-1)
    grid = torch.cat([row], dim=-2)
    plt.imshow(grid.permute(1, 2, 0).cpu())  # CHW -> HWC for matplotlib
    plt.show()


def save_images(images, path, **kwargs):
    """Save a batch of uint8 image tensors to `path` as a single grid image.

    Extra keyword args are forwarded to torchvision's make_grid.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()  # CHW -> HWC for PIL
    Image.fromarray(ndarr).save(path)


def get_data(args):
    """Build a shuffled DataLoader over the face dataset.

    Uses `args.image_size` and `args.batch_size`; images come from the
    module-level `path_image` directory.
    """
    face_ds = FaceDataset(size=args.image_size, path=path_image)
    return DataLoader(face_ds, batch_size=args.batch_size, shuffle=True)


def setup_logging(run_name):
    """Create the checkpoint/sample output directory tree and configure logging.

    For each output category, ensures both the base directory and a
    per-run subdirectory exist (idempotent via exist_ok).
    """
    for base in ("/model/models", "/model/results", "/model/models_ema",
                 "/model/results_ema", "/model/optims"):
        os.makedirs(base, exist_ok=True)
        os.makedirs(os.path.join(base, run_name), exist_ok=True)

    logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s",
                        level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    """DDPM forward/reverse diffusion with a linear beta schedule.

    Precomputes beta, alpha = 1 - beta, and alpha_hat (cumulative product
    of alpha) on `device` for `noise_steps` timesteps.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        # Linear interpolation from beta_start to beta_end over all steps.
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # q(x_t | x_0): scale the clean image and mix in Gaussian noise.
        a_hat = self.alpha_hat[t][:, None, None, None]
        eps = torch.randn_like(x)
        noised = torch.sqrt(a_hat) * x + torch.sqrt(1 - a_hat) * eps
        return noised, eps

    def sample_timesteps(self, n):
        # Uniform timesteps in [1, noise_steps).
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        """Generate `n` images by reverse diffusion with classifier-free guidance.

        Returns uint8 images in [0, 255]; the model is put back in train mode.
        """
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for step in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * step).long().to(self.device)
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    # Guidance: extrapolate from the unconditional prediction
                    # toward the conditional one by cfg_scale.
                    uncond = model(x, t, None)
                    predicted_noise = torch.lerp(uncond, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No noise is added on the final (step == 1) denoising step.
                noise = torch.randn_like(x) if step > 1 else torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] to uint8 pixel values.
        x = (x.clamp(-1, 1) + 1) / 2
        return (x * 255).type(torch.uint8)


def train(args):
    """Train the conditional DDPM UNet; sample and checkpoint every 10 epochs.

    Expects `args` with: run_name, device, num_classes, lr, image_size,
    batch_size, epochs. Writes samples/checkpoints under /model/* and
    TensorBoard scalars under runs/<run_name>.
    """
    setup_logging(args.run_name)
    device = args.device
    print(device)
    dataloader = get_data(args)
    torch.cuda.set_device(0)
    model = UNet_conditional(num_classes=args.num_classes, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    steps_per_epoch = len(dataloader)
    ema = EMA(0.995)
    # Frozen EMA copy of the model, updated outside the optimizer.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    epoch_losses = []
    for epoch in range(args.epochs):
        running_loss = 0
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)
            # Sample random timesteps and noise the batch accordingly.
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.step_ema(ema_model, model)
            running_loss += loss.item()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * steps_per_epoch + i)
        print("Epoch: ", epoch, " Loss:", running_loss / len(dataloader))
        epoch_losses.append(running_loss / len(dataloader))

        if epoch % 10 == 0:
            # Periodic sampling + checkpointing for both raw and EMA weights.
            sample_labels = torch.arange(16).long().to(device)
            sampled_images = diffusion.sample(model, n=len(sample_labels), labels=sample_labels)
            ema_sampled_images = diffusion.sample(ema_model, n=len(sample_labels), labels=sample_labels)
            plot_images(sampled_images)
            save_images(sampled_images, os.path.join("/model/results", args.run_name, f"{epoch}.jpg"))
            save_images(ema_sampled_images, os.path.join("/model/results_ema", args.run_name, f"{epoch}_ema.jpg"))
            torch.save(model.state_dict(), os.path.join("/model/models", args.run_name, f"{epoch}_ckpt.pt"))
            torch.save(ema_model.state_dict(), os.path.join("/model/models_ema", args.run_name, f"{epoch}_ema_ckpt.pt"))
            torch.save(optimizer.state_dict(), os.path.join("/model/optims", args.run_name, f"{epoch}_optim.pt"))

    # Final loss-vs-epoch curve.
    plt.title('Loss-Epoch')
    plt.plot([k for k in range(args.epochs)], epoch_losses)
    plt.savefig('/model/results/loss.jpg')


def launch():
    """Assemble the hard-coded hyperparameter namespace and start training."""
    import argparse
    args = argparse.Namespace(
        run_name="DDPM_conditional",
        epochs=800,
        batch_size=12,
        image_size=64,
        num_classes=12 * 11,   # combined attribute classes (see FaceDataset labels)
        dataset_name="faces",
        device=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu'),
        lr=1e-4,
    )
    train(args)


if __name__ == '__main__':
    launch()

Loading…
Cancel
Save