|
- import math
-
- import modules.scripts as scripts
- import gradio as gr
- from PIL import Image
-
- from modules import processing, shared, sd_samplers, images, devices
- from modules.processing import Processed
- from modules.shared import opts, cmd_opts, state
-
-
class Script(scripts.Script):
    """Img2img script: upscale the init image 2x, then re-render it tile by tile.

    The width/height sliders of the img2img tab are reused as the tile size.
    """

    def title(self):
        # Name shown in the script dropdown.
        return "SD upscale"

    def show(self, is_img2img):
        # Offered only on the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        """Create the script's controls.

        Returns, in order, the info banner, the tile-overlap slider and the
        upscaler selector. Their values are passed to ``run`` after ``p``.
        """
        controls = [
            gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>"),
            gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False),
            # type="index" makes the component yield the position in shared.sd_upscalers.
            gr.Radio(label='Upscaler', choices=[u.name for u in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False),
        ]
        return controls

    def run(self, p, _, overlap, upscaler_index):
        """Upscale ``p.init_images[0]`` to twice its size and redraw it in tiles.

        Args:
            p: img2img processing object. ``p.width``/``p.height`` give the tile
                size and ``p.n_iter`` the number of independent upscale passes.
                Its ``seed``, ``n_iter``, ``batch_size``, ``init_images`` and
                save flags are overwritten while the script runs.
            _: value of the informational HTML component (ignored).
            overlap: tile overlap in pixels, forwarded to ``images.split_grid``.
            upscaler_index: index into ``shared.sd_upscalers``.

        Returns:
            A ``Processed`` with one recombined image per pass, labelled with the
            seed fixed at the start and the info string of the first tile batch.
        """
        processing.fix_seed(p)
        upscaler = shared.sd_upscalers[upscaler_index]

        p.extra_generation_params["SD upscale overlap"] = overlap
        p.extra_generation_params["SD upscale upscaler"] = upscaler.name

        initial_info = None
        seed = p.seed

        source = p.init_images[0]
        upscaled = upscaler.upscale(source, source.width * 2, source.height * 2)

        devices.torch_gc()

        grid = images.split_grid(upscaled, tile_w=p.width, tile_h=p.height, overlap=overlap)

        def tile_cells():
            # Walk every tile entry of the grid, row by row. Slot 2 of an entry
            # holds the tile image and is overwritten in place during reassembly.
            for _row_y, _row_h, row in grid.tiles:
                yield from row

        batch_size = p.batch_size
        upscale_count = p.n_iter
        # Each tile batch runs as a single iteration. The two flags presumably keep
        # process_images from saving per-tile grids/samples; the recombined image
        # is saved below instead.
        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        work = [cell[2] for cell in tile_cells()]

        batch_count = math.ceil(len(work) / batch_size)
        state.job_count = batch_count * upscale_count

        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")

        result_images = []
        for pass_index in range(upscale_count):
            start_seed = seed + pass_index
            p.seed = start_seed

            work_results = []
            for batch_index in range(batch_count):
                offset = batch_index * batch_size
                p.batch_size = batch_size
                p.init_images = work[offset:offset + batch_size]

                state.job = f"Batch {batch_index + 1 + pass_index * batch_count} out of {state.job_count}"
                processed = processing.process_images(p)

                if initial_info is None:
                    initial_info = processed.info

                p.seed = processed.seed + 1
                work_results.extend(processed.images)

            produced = len(work_results)
            for index, cell in enumerate(tile_cells()):
                if index < produced:
                    cell[2] = work_results[index]
                else:
                    # Fewer images came back than there are tiles (presumably an
                    # interrupted run): fill the gap with a blank tile.
                    cell[2] = Image.new("RGB", (p.width, p.height))

            combined_image = images.combine_grid(grid)
            result_images.append(combined_image)

            if opts.samples_save:
                images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)

        return Processed(p, result_images, seed, initial_info)
|