|
- import math
- import pathlib
- import warnings
- from typing import BinaryIO, List, Optional, Tuple, Union
-
- import numpy as np
- import msadapter.pytorch as torch
- from PIL import Image, ImageColor, ImageDraw, ImageFont
-
# Public API of this module: the names exported by a star-import.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
-
-
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Arrange a batch of images into a single grid image.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W), or a list of
            equally-sized image tensors (stacked along a new first dimension). A 2D (H x W) or
            3D (C x H x W) tensor is treated as a single image. Single-channel images are
            replicated to three channels.
        nrow (int, optional): Number of images per grid row; the grid has
            ``ceil(B / nrow)`` rows. Default: ``8``.
        padding (int, optional): Pixels of padding placed around every cell. Default: ``2``.
        normalize (bool, optional): If ``True``, clamp to ``value_range`` (or to the data's own
            min/max when ``value_range`` is ``None``) and rescale linearly to ``[0, 1]``.
            Default: ``False``.
        value_range (tuple, optional): ``(min, max)`` used for normalization. By default min and
            max are computed from the tensor.
        scale_each (bool, optional): If ``True``, normalize each image separately instead of
            using one (min, max) for the whole batch. Default: ``False``.
        pad_value (float, optional): Value written to the padding pixels. Default: ``0``.
        **kwargs: Only ``range`` is read. It is a deprecated alias for ``value_range``
            (deprecated in ``0.12``, to be removed in ``0.14``), emits a warning, and overrides
            ``value_range``. Other keys are ignored.

    Returns:
        grid (Tensor): C x H' x W' tensor containing the grid. When the batch holds a single
        image, that image is returned without padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors, or if ``normalize``
            is ``True`` and ``value_range`` is given but is not a tuple.
    """
    # Validate the input container before touching any tensor API.
    if not torch.is_tensor(tensor):
        if not isinstance(tensor, list):
            raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
        for item in tensor:
            if not torch.is_tensor(item):
                raise TypeError(f"tensor or list of tensors expected, got a list containing {type(item)}")

    # Deprecated spelling of ``value_range``; it takes precedence when supplied.
    if "range" in kwargs:
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    batch = torch.stack(tensor, dim=0) if isinstance(tensor, list) else tensor

    # Bring every accepted layout to B x C x H x W, expanding grayscale to three channels.
    if batch.dim() == 2:  # H x W
        batch = batch.unsqueeze(0)
    if batch.dim() == 3:  # C x H x W
        if batch.size(0) == 1:
            batch = torch.cat((batch, batch, batch), 0)
        batch = batch.unsqueeze(0)
    if batch.dim() == 4 and batch.size(1) == 1:
        batch = torch.cat((batch, batch, batch), 1)

    if normalize is True:
        batch = batch.clone()  # work on a copy so the caller's tensor is left untouched
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def _rescale(img, low, high):
            # Clamp to [low, high] and map linearly onto [0, 1]; 1e-5 guards a zero-width span.
            return img.clamp(min=low, max=high).sub(low).div(max(high - low, 1e-5))

        def _rescale_auto(img):
            if value_range is None:
                return _rescale(img, float(img.min()), float(img.max()))
            return _rescale(img, value_range[0], value_range[1])

        if scale_each is True:
            batch = torch.stack([_rescale_auto(img) for img in batch])
        else:
            batch = _rescale_auto(batch)

    if not isinstance(batch, torch.Tensor):
        raise TypeError("tensor should be of type torch.Tensor")
    if batch.size(0) == 1:
        return batch.squeeze(0)

    # Each cell is the image plus ``padding`` on its leading edges; one extra
    # ``padding`` strip closes the trailing edge of the whole grid.
    num_images = batch.size(0)
    cols = min(nrow, num_images)
    rows = int(math.ceil(float(num_images) / cols))
    cell_h = int(batch.size(2) + padding)
    cell_w = int(batch.size(3) + padding)
    grid = batch.new_full((batch.size(1), cell_h * rows + padding, cell_w * cols + padding), pad_value)
    for index in range(num_images):
        row, col = divmod(index, cols)
        top, bottom = row * cell_h + padding, (row + 1) * cell_h
        left, right = col * cell_w + padding, (col + 1) * cell_w
        grid[:, top:bottom, left:right] = batch[index]
    return grid
-
-
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object.
        format (str, optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, **kwargs)
    # Scale to [0, 255] (pixel values are expected in [0, 1], e.g. via normalize=True);
    # adding 0.5 before the truncating uint8 cast rounds to the nearest integer.
    # ``.cpu()`` moves the data to host memory so ``.numpy()`` also works for
    # tensors that live on an accelerator.
    ndarr = grid.mul(255).add(0.5).clamp(0, 255).permute(1, 2, 0).to(torch.uint8).cpu().numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)
-
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: Optional[int] = None,
) -> torch.Tensor:

    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8, with C equal to 1 or 3.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color (drawn with alpha 100).
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points. Ignored (with a warning) when ``font`` is not set;
            defaults to 10 when ``font`` is set.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. Grayscale input
        is returned as a 3-channel image. If ``boxes`` is empty, ``image`` itself is returned.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not a uint8 (1 or 3) x H x W tensor, if any box has xmin > xmax
            or ymin > ymax, if ``labels`` does not have one entry per box, or if a list of
            ``colors`` is shorter than the number of boxes.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")

    num_boxes = boxes.shape[0]

    if num_boxes == 0:
        warnings.warn("boxes doesn't contain any box. No box was drawn")
        return image

    # Reject inverted corners up front with a clear message instead of leaving them to PIL.
    if (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
        raise ValueError(
            "Boxes need to be in (xmin, ymin, xmax, ymax) format with xmin <= xmax and ymin <= ymax."
        )

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    if colors is None:
        colors = _generate_color_palette(num_boxes)
    elif isinstance(colors, list):
        if len(colors) < num_boxes:
            raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
    else:  # colors specifies a single color for all boxes
        colors = [colors] * num_boxes

    # Resolve PIL color strings and turn every color into a tuple, so that appending the
    # alpha channel below (``color + (100,)``) also works for list-typed colors.
    colors = [tuple(ImageColor.getrgb(color)) if isinstance(color, str) else tuple(color) for color in colors]

    if font is None:
        if font_size is not None:
            warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
        txt_font = ImageFont.load_default()
    else:
        txt_font = ImageFont.truetype(font=font, size=font_size or 10)

    # Handle Grayscale images: replicate the single channel so PIL receives an RGB array.
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    img_boxes = boxes.to(torch.int64).tolist()

    # The RGBA draw mode is needed so the semi-transparent fill blends with the image.
    if fill:
        draw = ImageDraw.Draw(img_to_draw, "RGBA")
    else:
        draw = ImageDraw.Draw(img_to_draw)

    for bbox, color, label in zip(img_boxes, colors, labels):  # type: ignore[arg-type]
        if fill:
            fill_color = color + (100,)
            draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            # Offset the text just inside the box outline.
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
-
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:

    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8, with segmentation masks drawn on top.
        If ``masks`` is empty, ``image`` itself is returned.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not a uint8 3 x H x W tensor, if ``masks`` is not a bool tensor
            of shape (H, W) or (num_masks, H, W) matching the image size, if a list of ``colors``
            is shorter than the number of masks, or if a color has an unsupported type or shape.
    """

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if num_masks == 0:
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    if colors is None:
        colors = _generate_color_palette(num_masks)
    elif not isinstance(colors, list):
        # A single color (PIL string, RGB tuple or tensor) applies to every mask. Without
        # replicating it, zip() below would stop after the first mask.
        colors = [colors] * num_masks
    elif num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if not isinstance(colors[0], (torch.Tensor, str, int, tuple)):
        raise ValueError("colors must be a tuple or a string, or a list of int or tensor")
    if isinstance(colors[0], torch.Tensor) and colors[0].shape != (3,):
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = [
        torch.tensor(ImageColor.getrgb(color) if isinstance(color, str) else color, dtype=out_dtype)
        for color in colors
    ]

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        # Broadcast the RGB triple over every pixel selected by the boolean mask.
        img_to_draw[:, mask] = color[:, None]

    # Alpha-blend the painted copy over the original image, then cast back to uint8.
    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)
-
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
    """
    Draw keypoints (and optional skeleton lines) on an RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) holding the K keypoint
            locations of each instance, in the format [x, y]. Coordinates are cast to int64.
        connectivity (List[Tuple[int, int]]): Pairs of keypoint indices to join with a line,
            applied to every instance.
        colors (str, Tuple): Fill color of the keypoint dots, either a PIL string such as "red"
            or "#FF00FF" or an RGB tuple such as ``(240, 10, 157)``. The connection lines are
            drawn without an explicit fill, so PIL's default applies to them.
        radius (int): Radius of each keypoint dot.
        width (int): Width of the lines connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` is not a uint8 3 x H x W tensor or ``keypoints`` is not 3-D.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    painter = ImageDraw.Draw(canvas)

    for instance in keypoints.to(torch.int64).tolist():
        # One filled dot per keypoint, bounded by a square of side 2 * radius.
        for point in instance:
            cx, cy = point[0], point[1]
            painter.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=colors, outline=None, width=0)

        if connectivity:
            for pair in connectivity:
                head = instance[pair[0]]
                tail = instance[pair[1]]
                painter.line(((head[0], head[1]), (tail[0], tail[1])), width=width)

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)
-
-
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Render an optical-flow field as an RGB image.

    The flow is divided by the largest per-pixel vector norm across the whole input
    (plus the dtype's machine epsilon, to avoid dividing by zero) before being colorized.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.

    Raises:
        ValueError: if ``flow`` is not torch.float or does not have shape (2, H, W) / (N, 2, H, W).
    """
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    input_shape = flow.shape
    batched = flow[None] if flow.ndim == 3 else flow  # temporarily add a batch dimension
    if batched.ndim != 4 or batched.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {input_shape}.")

    largest_norm = (batched ** 2).sum(dim=1).sqrt().max()
    scale = largest_norm + torch.finfo(batched.dtype).eps
    rendered = _normalized_flow_to_image(batched / scale)

    return rendered[0] if len(input_shape) == 3 else rendered
-
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Colorize a batch of already-normalized flow fields.

    The flow direction selects a hue by linear interpolation between neighbouring
    entries of the color wheel; the flow magnitude controls saturation (zero flow is white).

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W).
            Presumably magnitudes are <= 1 (as produced by ``flow_to_image``) so that
            the output stays within [0, 255] — confirm for other callers.

    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """
    batch, _, height, width = normalized_flow.shape
    dev = normalized_flow.device
    result = torch.zeros((batch, 3, height, width), dtype=torch.uint8, device=dev)
    wheel = _make_colorwheel().to(dev)
    n_colors = wheel.shape[0]

    u = normalized_flow[:, 0, :, :]
    v = normalized_flow[:, 1, :, :]
    magnitude = torch.sum(normalized_flow ** 2, dim=1).sqrt()
    # Angle in [-1, 1] (units of pi), then mapped onto a fractional wheel position.
    angle = torch.atan2(-v, -u) / math.pi
    position = (angle + 1) / 2 * (n_colors - 1)
    lower = torch.floor(position).to(torch.long)
    upper = lower + 1
    upper[upper == n_colors] = 0  # wrap around the end of the wheel
    weight = position - lower

    for channel in range(wheel.shape[1]):
        ramp = wheel[:, channel]
        low_val = ramp[lower] / 255.0
        high_val = ramp[upper] / 255.0
        blended = (1 - weight) * low_val + weight * high_val
        # Fade toward white as the magnitude shrinks.
        blended = 1 - magnitude * (1 - blended)
        result[:, channel, :, :] = torch.floor(255 * blended)
    return result
-
-
- def _make_colorwheel() -> torch.Tensor:
- """
- Generates a color wheel for optical flow visualization as presented in:
- Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
- URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
-
- Returns:
- colorwheel (Tensor[55, 3]): Colorwheel Tensor.
- """
-
- RY = 15
- YG = 6
- GC = 4
- CB = 11
- BM = 13
- MR = 6
-
- ncols = RY + YG + GC + CB + BM + MR
- colorwheel = torch.zeros((ncols, 3))
- col = 0
-
- # RY
- colorwheel[0:RY, 0] = 255
- colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
- col = col + RY
- # YG
- colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
- colorwheel[col : col + YG, 1] = 255
- col = col + YG
- # GC
- colorwheel[col : col + GC, 1] = 255
- colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
- col = col + GC
- # CB
- colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(0, CB) / CB)
- colorwheel[col : col + CB, 2] = 255
- col = col + CB
- # BM
- colorwheel[col : col + BM, 2] = 255
- colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
- col = col + BM
- # MR
- colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(0, MR) / MR)
- colorwheel[col : col + MR, 0] = 255
- return colorwheel
-
-
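# NOTE: superseded tensor-based palette, kept for reference. The active version below
# returns plain RGB tuples, which draw_bounding_boxes relies on when it appends an alpha
# channel via ``color + (100,)``.
# def _generate_color_palette(num_objects: int):
#     palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
#     return [(i * palette) % 255 for i in range(num_objects)]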
-
- def _generate_color_palette(num_objects: int):
- a = 2 ** 25 - 1
- b = 2 ** 15 - 1
- c = 2 ** 21 - 1
- return [((i * a) % 255, (i * b) % 255, (i * c) % 255) for i in range(num_objects)]
|