Despite the breakthroughs in accuracy and speed of single image super-resolution using faster and deeper convolutional neural networks, one central problem remains largely unsolved: how do we recover the finer texture details when we super-resolve at large upscaling factors? The behavior of optimization-based super-resolution methods is principally driven by the choice of the objective function. Recent work has largely focused on minimizing the mean squared reconstruction error. The resulting estimates have high peak signal-to-noise ratios, but they often lack high-frequency details and are perceptually unsatisfying in the sense that they fail to match the fidelity expected at the higher resolution. In this paper, we present SRGAN, a generative adversarial network (GAN) for image super-resolution (SR). To our knowledge, it is the first framework capable of inferring photo-realistic natural images for 4× upscaling factors. To achieve this, we propose a perceptual loss function which consists of an adversarial loss and a content loss. The adversarial loss pushes our solution to the natural image manifold using a discriminator network that is trained to differentiate between the super-resolved images and original photo-realistic images. In addition, we use a content loss motivated by perceptual similarity instead of similarity in pixel space. Our deep residual network is able to recover photo-realistic textures from heavily downsampled images on public benchmarks.
Paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, Wenzhe Shi (Twitter).
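The perceptual loss described in the abstract (a content loss plus a weighted adversarial loss) can be sketched in plain Python. This is a minimal illustration, not this repository's implementation: the function names are hypothetical, feature maps are assumed to be flattened to lists of floats, the adversarial term is averaged over the batch for readability, and the default weight of 1e-3 is the value given in the SRGAN paper.

```python
import math


def content_loss(sr_features, hr_features):
    """Mean squared error between two feature vectors (e.g. VGG19 activations)."""
    if len(sr_features) != len(hr_features):
        raise ValueError("feature vectors must have the same length")
    squared = ((s - h) ** 2 for s, h in zip(sr_features, hr_features))
    return sum(squared) / len(sr_features)


def adversarial_loss(d_probs_on_sr, eps=1e-12):
    """Generator-side loss -log D(G(x)), averaged over the batch.

    d_probs_on_sr holds the discriminator's probability that each
    super-resolved image is a real high-resolution image.
    """
    logs = (math.log(max(p, eps)) for p in d_probs_on_sr)
    return -sum(logs) / len(d_probs_on_sr)


def perceptual_loss(sr_features, hr_features, d_probs_on_sr, adv_weight=1e-3):
    """Content loss plus the weighted adversarial loss."""
    return content_loss(sr_features, hr_features) + adv_weight * adversarial_loss(d_probs_on_sr)
```

When the discriminator is fooled completely (probability 1.0 for every super-resolved image), the adversarial term vanishes and only the content loss remains.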
The model is trained with MindSpore. Training SRGAN requires a VGG19 checkpoint pretrained on ImageNet.
Type | ckpt |
---|---|
srgan | ckpt |
Three datasets are used for training, validation, and evaluation.
Note: for the DIV2K dataset, you only need two archives: 'Train Data (HR images)' and 'Train Data Track 1 bicubic downscaling x4 (LR images)'.
Type | Resource |
---|---|
DIV2K | link |
Set5 | link |
Set14 | link |
Make sure your files are organized as follows:
./datasets/
├── DIV2K
│   ├── HR
│   └── LR
├── Set5
│   ├── HR
│   └── LR
└── Set14
    ├── HR
    └── LR
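Before starting a long training run, it can help to confirm that the layout above is in place. The following is a small sketch using only the standard library; the helper name `missing_dataset_dirs` and the `EXPECTED` mapping are illustrative and not part of this repository.

```python
from pathlib import Path

# Dataset folders and the sub-directories each one is expected to contain.
EXPECTED = {
    "DIV2K": ("HR", "LR"),
    "Set5": ("HR", "LR"),
    "Set14": ("HR", "LR"),
}


def missing_dataset_dirs(root):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    missing = []
    for dataset, subdirs in EXPECTED.items():
        for sub in subdirs:
            path = root / dataset / sub
            if not path.is_dir():
                missing.append(path.relative_to(root).as_posix())
    return missing


if __name__ == "__main__":
    problems = missing_dataset_dirs("./datasets")
    if problems:
        print("Missing directories:", ", ".join(problems))
    else:
        print("Dataset layout looks complete.")
```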
Parameters | Default | Description |
---|---|---|
train_LR_path | None | Path of low resolution image of training set |
train_GT_path | None | Path of high resolution image of training set |
val_LR_path | None | Path of low resolution image of validation set |
val_GT_path | None | Path of high resolution image of validation set |
vgg_ckpt | None | Path of pretrained vgg19 model |
image_size | 96 | Image size of high resolution image |
train_batch_size | 16 | Batch size for training stage |
val_batch_size | 1 | Batch size for validation stage |
psnr_epochs | 2000 | Number of epochs in psnr training stage |
gan_epochs | 1000 | Number of epochs in gan training stage |
init_type | normal | The method used in network initialization |
platform | Ascend | Platform used in training stage |
run_distribute | 0 | Distributed training |
device_num | 1 | Number of devices |
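The table above maps naturally onto a command-line parser. The sketch below mirrors the parameter names and defaults from the table with `argparse`; it is an assumption about how such options could be declared, and the actual parser in this repository's training script may differ (for example in types or help strings).

```python
import argparse


def build_parser():
    """Build a parser whose options and defaults mirror the parameter table."""
    parser = argparse.ArgumentParser(description="SRGAN training options (illustrative)")
    # Dataset and checkpoint paths have no default and must be supplied.
    parser.add_argument("--train_LR_path", type=str, default=None)
    parser.add_argument("--train_GT_path", type=str, default=None)
    parser.add_argument("--val_LR_path", type=str, default=None)
    parser.add_argument("--val_GT_path", type=str, default=None)
    parser.add_argument("--vgg_ckpt", type=str, default=None)
    # Training hyperparameters.
    parser.add_argument("--image_size", type=int, default=96)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--val_batch_size", type=int, default=1)
    parser.add_argument("--psnr_epochs", type=int, default=2000)
    parser.add_argument("--gan_epochs", type=int, default=1000)
    parser.add_argument("--init_type", type=str, default="normal")
    # Hardware and distribution settings.
    parser.add_argument("--platform", type=str, default="Ascend")
    parser.add_argument("--run_distribute", type=int, default=0)
    parser.add_argument("--device_num", type=int, default=1)
    return parser


# Example: override two options and keep every other default.
args = build_parser().parse_args(["--train_LR_path", "./datasets/DIV2K/LR", "--gan_epochs", "10"])
print(args.train_LR_path, args.gan_epochs, args.image_size)
```

Any option left off the command line keeps the default listed in the table.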
Dataset | Resource | Training speed |
---|---|---|
DIV2K | Ascend 910 | 1pc: 540 ms/step; 8pcs: 1500 ms/step |
DIV2K | NVIDIA GeForce RTX 3090 | 1pc: 350 ms/step |
Dataset | Resource | PSNR (dB) | PSNR (paper, dB) |
---|---|---|---|
Set5 | Ascend 910 | 31.00 | 29.40 |
Set14 | Ascend 910 | 27.93 | 26.02 |
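For reference, PSNR is computed from the mean squared error between the ground-truth image and the super-resolved image. The sketch below shows the standard formula on flat lists of pixel values. It is not this repository's evaluation code, and the exact protocol behind the numbers above (for example color space or border cropping) is not stated here, so treat the function and its `max_value` default of 255 as assumptions.

```python
import math


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences."""
    if len(reference) != len(estimate) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and have the same length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Every pixel off by one intensity level gives an MSE of 1.
print(round(psnr([0, 0, 0, 0], [1, 1, 1, 1]), 2))
```

A higher PSNR means a lower pixel-wise error, which is not the same as better perceptual quality; that gap is the motivation for the perceptual loss.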
Dataset | Resource | Inference speed |
---|---|---|
Set5 | Ascend 910 | 1pc: 7 ms/step |
Set14 | Ascend 910 | 1pc: 10 ms/step |
Use the DIV2K dataset for training and the Set5 dataset for validation. You can change the paths as needed.
Training results are stored in './ckpt'.
python -m src.train.train --train_LR_path './datasets/DIV2K/LR' --train_GT_path './datasets/DIV2K/HR' --val_LR_path './datasets/Set5/LR' --val_GT_path './datasets/Set5/HR' --vgg_ckpt './src/vgg19/vgg19.ckpt'
Output:
...
training 999 epoch
per step needs time:356ms
D_loss:
0.6385005
G_loss:
0.009284813
999/1000 epoch finished
training 1000 epoch
per step needs time:351ms
D_loss:
0.6633582
G_loss:
0.016534843
saving ckpt
1000/1000 epoch finished
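To track the losses over a long run, the log format shown above can be parsed into per-epoch records. The following is a sketch that assumes the log keeps exactly this layout (a `training N epoch` line, a step-time line, and each loss value on the line after its label); the function name `parse_training_log` is hypothetical.

```python
import re


def parse_training_log(lines):
    """Collect step time, D_loss and G_loss for each 'training N epoch' block."""
    records = []
    current = None
    pending = None  # loss name whose value is expected on the next line
    for raw in lines:
        line = raw.strip()
        match = re.fullmatch(r"training (\d+) epoch", line)
        if match:
            current = {"epoch": int(match.group(1))}
            records.append(current)
            pending = None
            continue
        if current is None:
            continue  # ignore anything before the first epoch header
        match = re.fullmatch(r"per step needs time:(\d+)ms", line)
        if match:
            current["step_ms"] = int(match.group(1))
        elif line in ("D_loss:", "G_loss:"):
            pending = line[:-1]
        elif pending is not None:
            try:
                current[pending] = float(line)
            except ValueError:
                pass  # value line was not a number; skip it
            pending = None
    return records


sample = """training 999 epoch
per step needs time:356ms
D_loss:
0.6385005
G_loss:
0.009284813
999/1000 epoch finished
training 1000 epoch
per step needs time:351ms
D_loss:
0.6633582
G_loss:
0.016534843
saving ckpt
1000/1000 epoch finished""".splitlines()

for record in parse_training_log(sample):
    print(record)
```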
You can run './src/train/run_distribute_train.sh' to train your model in a distributed environment.
sh ./src/train/run_distribute_train.sh 8 1 /home/user/work/srgan_p/datasets/DIV2K/LR /home/user/work/srgan_p/datasets/DIV2K/HR /home/user/work/srgan_p/vgg19.ckpt /home/user/work/srgan_p/datasets/Set5/LR /home/user/work/srgan_p/datasets/Set5/HR
You can run './src/eval.py' to evaluate the performance of your model.
python -m src.eval --test_LR_path './datasets/Set5/LR' --test_GT_path './datasets/Set5/HR' --generator_path './ckpt/G_model_1000.ckpt'
Output:
======load checkpoint
[WARNING] ME(3178:281473299786304,MainProcess):2022-08-28-14:22:01.705.760 [mindspore/train/serialization.py:674] For 'load_param_into_net', remove parameter generator.conv1.0.weight's prefix name: generator., continue to load it to net parameter conv1.0.weight.
=======starting test=====
avg PSNR: 30.997722005197154
Inference results are stored in './output'.
python -m src.infer --test_LR_path './datasets/Set14/LR' --generator_path './ckpt/G_model_1000.ckpt'
Output:
======load checkpoint
[WARNING] ME(3178:281473299786304,MainProcess):2022-08-28-14:22:01.705.760 [mindspore/train/serialization.py:674] For 'load_param_into_net', remove parameter generator.conv1.0.weight's prefix name: generator., continue to load it to net parameter conv1.0.weight.
=======starting test=====
Total 14 images need 134ms, per image needs 10ms.
Inference End.