|
- # -*- coding: utf-8 -*-
- """
- Created on Wed May 11 21:56:26 2022
-
- @author: Administrator
- """
-
-
- from __future__ import print_function, division
- import os
- import torch
- import pandas as pd
- from skimage import io, transform
- import numpy as np
- import matplotlib.pyplot as plt
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, utils, datasets, models
-
-
# Silence every Python warning for the remainder of the run.
import warnings
import os

warnings.filterwarnings("ignore")

# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
# These variables only take effect if set before CUDA is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})
-
-
-
- #from lib import RadioUNet_modules3, RadioUNet_loaders2
- from lib import cognitive_loaders1_Tx as loaders
- from lib import modulesGAN1_Tx as modules
-
-
-
-
-
# Options shared by every split; they are passed straight through to
# loaders.RadioUNet_s (see lib.cognitive_loaders1_Tx for their meaning).
_radio_split_kwargs = dict(simulation="rand", cityMap="missing", missing=3)

# Built in train -> val -> test order.
Radio_train, Radio_val, Radio_test = (
    loaders.RadioUNet_s(phase=split, **_radio_split_kwargs)
    for split in ("train", "val", "test")
)

# Splits that get a DataLoader (the test split is constructed but not loaded here).
image_datasets = {'train': Radio_train, 'val': Radio_val}

batch_size = 15

# NOTE(review): the training split is not shuffled. Before enabling shuffling,
# check compatibility with the CUDA default tensor type set later in this script.
dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    for phase, dataset in image_datasets.items()
}
-
-
-
-
# Only needed for the optional architecture printouts commented out below.
from torchsummary import summary

# float32 everywhere; float tensors created without an explicit device now
# default to torch.cuda.FloatTensor, i.e. they are allocated on the GPU.
torch.set_default_dtype(torch.float32)
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Previously a bare attribute read (a no-op statement); state the intent explicitly.
torch.backends.cudnn.enabled = True

generator = modules.GeneratorUNet(inputs=3)
generator.cuda()
# NOTE(review): input_size uses 2 channels while the generator is built with
# inputs=3 -- confirm before re-enabling.
#summary(generator, input_size=(2, 256, 256))

discriminator = modules.Discriminator()
discriminator.cuda()
#summary(discriminator, [(2, 256, 256), (1, 256, 256)])
-
-
-
- import time
def weights_init_normal(m):
    """Re-initialise the parameters of a single module in place.

    Meant to be used as ``model.apply(weights_init_normal)``, which calls it
    on every submodule. Behaviour depends on the class name:

    * the name contains "Conv": weight drawn from N(0, 0.02);
    * the name contains "BatchNorm2d": weight drawn from N(1, 0.02) and bias set to 0;
    * anything else: left untouched.

    Modules that match by name but carry no weight/bias tensor are skipped
    instead of raising ``AttributeError``. Examples are a custom container
    class named ``ConvBlock`` or ``BatchNorm2d(affine=False)``.

    Args:
        m: the ``torch.nn.Module`` visited by ``apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias.data, 0.0)
-
-
-
def discriminator_train_step(real_src, real_trg, fake_trg):
    """Run one optimisation step of the discriminator.

    The discriminator scores (real_src, real_trg) against an all-ones target
    and (real_src, fake_trg) against an all-zeros target, using
    ``criterion_GAN``. Both losses are back-propagated, then ``d_optimizer``
    takes one step.

    Relies on the module-level ``discriminator``, ``d_optimizer`` and
    ``criterion_GAN``.

    Args:
        real_src: conditioning input batch (the generator's input).
        real_trg: ground-truth target batch.
        fake_trg: generator output for ``real_src``. It is detached here so
            this step does not propagate gradients into the generator.

    Returns:
        The summed real + fake loss, detached (both graphs were already freed
        by ``backward``), suitable for logging.
    """
    d_optimizer.zero_grad()

    prediction_real = discriminator(real_src, real_trg)
    # Labels take shape/device/dtype from the prediction itself, so the patch
    # map size (previously hard-coded as 1x16x16) and the device always match.
    error_real = criterion_GAN(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    prediction_fake = discriminator(real_src, fake_trg.detach())
    error_fake = criterion_GAN(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    d_optimizer.step()
    return (error_real + error_fake).detach()
-
def generator_train_step(real_src, fake_trg, real_trg=None):
    """Run one optimisation step of the generator.

    Loss:
        criterion_GAN(D(real_src, fake_trg), 1)
        + lambda_pixel * criterion_pixelwise(fake_trg, real_trg)

    Relies on the module-level ``discriminator``, ``g_optimizer``,
    ``criterion_GAN``, ``criterion_pixelwise`` and ``lambda_pixel``.

    Args:
        real_src: conditioning input batch.
        fake_trg: generator output for ``real_src``. It is NOT detached, so
            gradients flow back into the generator.
        real_trg: ground-truth batch for the pixel-wise term. When omitted,
            the module-level ``targets`` is used, i.e. the current batch of
            the training loop. This preserves the original two-argument call.

    Returns:
        The total generator loss, detached (its graph was freed by
        ``backward``), suitable for logging.
    """
    if real_trg is None:
        real_trg = targets  # backward-compatible fallback to the loop's global

    g_optimizer.zero_grad()
    prediction = discriminator(real_src, fake_trg)

    # The generator wants the discriminator to label its output as real.
    loss_GAN = criterion_GAN(prediction, torch.ones_like(prediction))
    loss_pixel = criterion_pixelwise(fake_trg, real_trg)
    loss_G = loss_GAN + lambda_pixel * loss_pixel

    loss_G.backward()
    g_optimizer.step()
    return loss_G.detach()
-
-
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Re-initialise Conv / BatchNorm2d submodules: generator first, then discriminator.
for _net in (generator, discriminator):
    _net.apply(weights_init_normal)

# Adversarial loss (mean-squared error) and pixel-wise L1 reconstruction loss.
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Weight of the L1 term relative to the adversarial term in the generator loss.
lambda_pixel = 100

# Both networks use Adam with identical hyper-parameters.
_adam_hparams = {"lr": 0.0002, "betas": (0.5, 0.999)}
g_optimizer = torch.optim.Adam(generator.parameters(), **_adam_hparams)
d_optimizer = torch.optim.Adam(discriminator.parameters(), **_adam_hparams)
-
-
epochs = 50
for epoch in range(epochs):
    print('This is epoch', epoch)
    since = time.time()

    # Running loss sums for the epoch report. They are accumulated as tensors,
    # so there is no host sync per batch.
    sum_errD = 0.0
    sum_errG = 0.0
    n_batches = 0

    # Keep the loop variable named `targets`: generator_train_step, called
    # here with two arguments, reads the module-level `targets` for its L1 term.
    for inputs, targets in dataloaders['train']:
        inputs = inputs.to(device)
        targets = targets.to(device)

        fake_trg = generator(inputs)
        errD = discriminator_train_step(inputs, targets, fake_trg)
        errG = generator_train_step(inputs, fake_trg)

        sum_errD += errD.detach()
        sum_errG += errG.detach()
        n_batches += 1

    time_elapsed = time.time() - since
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    if n_batches:
        print('mean errD: {:.4f}  mean errG: {:.4f}'.format(
            float(sum_errD) / n_batches, float(sum_errG) / n_batches))
-
-
-
import os

# Best-effort creation of the run directory (errors are printed, not raised).
# NOTE(review): nothing below writes into it, and its name says "miss4Build"
# while the datasets are built with missing=3 -- confirm the intended location.
try:
    os.makedirs('RadioWNet_s_randSim_miss4Build_Thr2', exist_ok=True)
except OSError as error:
    print(error)

checkpoint_path = '/tmp/output/Trained_Model_FirstU.pt'
# torch.save does not create missing parent directories. Without this, the
# save fails only after the whole training run has finished.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(generator.state_dict(), checkpoint_path)
-
|