|
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import os
- import posixpath
- import sys
- import time
- import warnings
- from pathlib import Path
-
- import torch
- import yaml
- from torch_ecg.utils.misc import dict_to_str
- from utils.import_utils import is_timm_model_init, is_torchvision_model_init
-
# Silence TensorFlow C++ logging (suppress INFO/WARNING) in case any
# dependency pulls in TensorFlow at import time.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
-
-
if __name__ == "__main__" and __package__ is None:
    # Allow running this file directly (``python path/to/this_file.py``) while
    # still supporting package-style imports: put the directory ``level``
    # parents up onto sys.path and set ``__package__`` accordingly.
    level = 1
    # https://gist.github.com/vaultah/d63cb4c86be2774377aa674b009f759a
    import importlib

    file = Path(__file__).resolve()
    parent, top = file.parent, file.parents[level]

    sys.path.append(str(top))
    try:
        # Remove the script's own directory so sibling modules are not
        # importable under the wrong (top-level) name.
        sys.path.remove(str(parent))
    except ValueError:  # already removed
        pass
    # Derive the dotted package name from the path relative to ``top``.
    __package__ = ".".join(parent.parts[len(top.parts) :])
    importlib.import_module(__package__)  # won't be needed after that
-
-
- from image.test.testimport2 import main as cli_main
- from image.utils.misc import AITESTING_DOMAIN, download_if_needed
-
# Directory of this module, used to resolve bundled resources.
_MODULE_DIR = Path(__file__).resolve().parent

# Directory holding the bundled example/default YAML config files.
_CONFIGS_DIR = _MODULE_DIR / "configs"

# Registry of built-in datasets.  Each entry supplies the config keys a user
# would otherwise have to provide (Data_path, Dict_path, image sizes) plus the
# number of examples.  Paths are either repo-relative or remote URLs under
# AITESTING_DOMAIN (downloaded on demand in ``main``).
_BUILTIN_DATASETS = {
    f"cifar10_{num}": {
        "Data_path": {
            "image_path": f"Datasets/CIFAR_cln_data/cifar10_{num}_origin_inputs.npy",
            "label_path": f"Datasets/CIFAR_cln_data/cifar10_{num}_origin_labels.npy",
        },
        "Dict_path": "Datasets/CIFAR_cln_data/cifar10_dict.txt",
        "Scale_ImageSize": [32, 32],
        "Crop_ImageSize": [32, 32],
        "num_examples": num,
    }
    # CIFAR-10 subsets of increasing size.
    for num in [30, 100, 300, 1000]
}
_BUILTIN_DATASETS["ImageNetBuiltin"] = {
    "Data_path": {
        "image_path": "Datasets/ImageNet/images/",
        "label_path": "Datasets/ImageNet/val_5000.txt",
    },
    "Dict_path": "Datasets/ImageNet/ImageNet_12_dict.txt",
    "Scale_ImageSize": [224, 224],
    "Crop_ImageSize": [224, 224],
    "num_examples": 20,
}
# Remotely hosted ImageNet subsets; archives are fetched and extracted lazily.
_BUILTIN_DATASETS["ImageNetTiny"] = {
    "Data_path": {
        "image_path": posixpath.join(AITESTING_DOMAIN, "data", "ImageNetTiny.tar.gz"),
        "label_path": posixpath.join(AITESTING_DOMAIN, "data", "ImageNetValLabels.txt"),
    },
    "Dict_path": "Datasets/ImageNet/ImageNet_12_dict.txt",
    "Scale_ImageSize": [224, 224],
    "Crop_ImageSize": [224, 224],
    "num_examples": 500,
}
_BUILTIN_DATASETS["ImageNetSmall"] = {
    "Data_path": {
        "image_path": posixpath.join(AITESTING_DOMAIN, "data", "ImageNetSmall.tar.gz"),
        "label_path": posixpath.join(AITESTING_DOMAIN, "data", "ImageNetValLabels.txt"),
    },
    "Dict_path": "Datasets/ImageNet/ImageNet_12_dict.txt",
    "Scale_ImageSize": [224, 224],
    "Crop_ImageSize": [224, 224],
    "num_examples": 5000,
}

# Registry of built-in models: dotted import path of the model class plus the
# path (local or remote) of its pretrained weights.
_BUILTIN_MODELS = {
    "ResNet2": {
        "model": "Models.UserModel.ResNet2",
        "model_path": "Models/weights/resnet20_cifar.pt",
    },
    "CIFAR10_RAND_enhanced": {
        "model": "Models.UserModel.ResNet2",
        "model_path": "Models/weights/CIFAR10_RAND_enhanced.pt",
    },
    "CIFAR10_PAT_enhanced": {
        "model": "Models.UserModel.ResNet2",
        "model_path": "Models/weights/CIFAR10_PAT_enhanced.pt",
    },
    "TRADES": {
        "model": "Models.UserModel.wideresnet_trades.WideResNet",
        "model_path": posixpath.join(AITESTING_DOMAIN, "models", "TRADES.pt"),
    },
}

# Attack recipes run when the config file does not specify any.
_DEFAULT_RECIPES = ["PGD", "CW2", "AutoPGD"]

# A model whose adversarial attack success rate reaches this threshold is
# reported as NOT robust (see the conclusion logic in ``main``).
_DEFAULT_ROBUST_THR = 0.6
-
-
def parse_args() -> Path:
    """Parse command-line arguments of the image testing entry point.

    Returns
    -------
    pathlib.Path
        Resolved path to the YAML config file.  When no path is given on
        the command line, falls back to the bundled ``ExampleConfigFile.yml``
        and emits a ``RuntimeWarning`` naming the file actually used.
    """
    parser = argparse.ArgumentParser(
        description="AI-Testing Image Module",
    )
    parser.add_argument(
        "config_file_path",
        nargs=argparse.OPTIONAL,  # zero or one positional argument
        type=str,
        help="Config file (.yml or .yaml file) path",
    )

    args = vars(parser.parse_args())
    if args["config_file_path"] is None:
        # No path supplied: use the shipped example config, but warn so the
        # user knows which file is being run.
        args["config_file_path"] = _CONFIGS_DIR / "ExampleConfigFile.yml"
        warnings.warn(
            "No input config file path, use default config file "
            f'"{Path(args["config_file_path"]).relative_to(_MODULE_DIR)}"',
            RuntimeWarning,
        )

    return Path(args["config_file_path"]).resolve()
-
-
def main(config_file_path):
    """Run the adversarial-robustness test recipes described by a YAML config.

    Loads the config, resolves built-in datasets/models (downloading remote
    assets when needed), runs every attack recipe through ``cli_main``, and
    prints summary tables with a final robust / not-robust conclusion.

    Parameters
    ----------
    config_file_path : str or pathlib.Path
        Path to the ``.yml``/``.yaml`` config file.

    Raises
    ------
    FileNotFoundError
        If ``config_file_path`` does not exist.
    ValueError
        If the file suffix is not ``.yml`` or ``.yaml``.
    """
    config_file_path = Path(config_file_path)
    if not config_file_path.exists():
        raise FileNotFoundError(f"Config file {config_file_path} not found")
    if config_file_path.suffix not in [".yml", ".yaml"]:
        raise ValueError(f"Config file {config_file_path} must be a .yml or .yaml file")
    config = yaml.safe_load(config_file_path.read_text())

    # Display copy of the raw config: ``None`` values are rendered "Default".
    config_bak = {k: v if v is not None else "Default" for k, v in config.items()}

    # "recipes" is a comma-separated string of attack names; an empty or
    # missing value falls back to the default recipe list.
    recipes = (config.pop("recipes") or "").split(",")
    if recipes == [""]:
        recipes = _DEFAULT_RECIPES

    if config_bak["recipes"] == "Default":
        config_bak["recipes"] = f"""Default ({",".join(_DEFAULT_RECIPES)})"""

    if config_bak["data_type"] in _BUILTIN_DATASETS:
        # Total examples = per-recipe example count times number of recipes.
        config_bak["num_examples"] = _BUILTIN_DATASETS[config_bak["data_type"]]["num_examples"] * len(recipes)
    else:
        config_bak["num_examples"] = "Custom"

    if config_bak.get("verbose", False):
        # log summary
        summary_rows = [[k, dict_to_str(v) if isinstance(v, (dict, list, tuple)) else v] for k, v in config_bak.items()]
        log_summary_rows(summary_rows, "Testing Args")
        print("\n")
    else:
        log_summary_rows([], "AI-Testing Image Module")

    # Translate the user-facing ``device_id`` into the internal ``GPU_Config``
    # pair: [str(number of GPUs), comma-joined device ids].
    device_id = config.pop("device_id", None)
    if device_id is not None:
        if isinstance(device_id, int) and device_id >= 0:
            config["GPU_Config"] = [str(device_id + 1), str(device_id)]
        elif isinstance(device_id, str):
            device_id = [int(i) for i in device_id.split(",")]
            device_count = torch.cuda.device_count()
            # A single negative id means "no GPU"; otherwise every id must be
            # a valid CUDA device index.
            assert all([0 <= i < device_count for i in device_id]) or (
                len(device_id) == 1 and device_id[0] < 0
            ), f"Invalid device id {device_id}, must be in range [0, {device_count - 1}] or negative"
            if all([0 <= i < device_count for i in device_id]):
                config["GPU_Config"] = [
                    str(max(device_id) + 1),
                    ",".join([str(i) for i in device_id]),
                ]

    if config["data_type"] in _BUILTIN_DATASETS:
        # Built-in datasets fully determine paths and image sizes; any
        # user-provided values are ignored (with a warning).
        if config.get("Data_path", None) is not None:
            warnings.warn(
                "No need to specify `Data_path` for built-in datasets, " "the input `Data_path` will be ignored",
                RuntimeWarning,
            )
        config["Data_path"] = _BUILTIN_DATASETS[config["data_type"]]["Data_path"]
        # Remote assets (hosted under AITESTING_DOMAIN) are downloaded once;
        # image archives are additionally extracted.
        if str(config["Data_path"]["image_path"]).startswith(AITESTING_DOMAIN):
            config["Data_path"]["image_path"] = download_if_needed(config["Data_path"]["image_path"], extract=True)
        if str(config["Data_path"]["label_path"]).startswith(AITESTING_DOMAIN):
            config["Data_path"]["label_path"] = download_if_needed(config["Data_path"]["label_path"], extract=False)
        if config.get("Dict_path", None) is not None:
            warnings.warn(
                "No need to specify `Dict_path` for built-in datasets, " "the input `Dict_path` will be ignored",
                RuntimeWarning,
            )
        config["Dict_path"] = _BUILTIN_DATASETS[config["data_type"]]["Dict_path"]
        config["Scale_ImageSize"] = _BUILTIN_DATASETS[config["data_type"]]["Scale_ImageSize"]
        config["Crop_ImageSize"] = _BUILTIN_DATASETS[config["data_type"]]["Crop_ImageSize"]
        # Normalize dataset names, e.g. "cifar10_100" -> "cifar10",
        # "ImageNetTiny"/"ImageNetSmall"/"ImageNetBuiltin" -> "ImageNet".
        config["data_type"] = config["data_type"].split("_")[0]
        if "ImageNet" in config["data_type"]:
            config["data_type"] = "ImageNet"
    else:
        # Custom datasets must provide all path/size keys explicitly.
        for key in ["Data_path", "Dict_path", "Scale_ImageSize", "Crop_ImageSize"]:
            assert key in config and config[key] is not None, f"Missing key `{key}` in config file `{str(config_file_path)}`"

    if config["model"] in _BUILTIN_MODELS:
        # Built-in models carry their own weight paths; a user-provided
        # ``model_path`` is overridden (with a warning).
        if config.get("model_path", None) is not None:
            warnings.warn(
                f"`model_path` {str(config['model_path'])} in the "
                f"config file `{str(config_file_path)}` is discarded, "
                f"and reset to `{str(_BUILTIN_MODELS[config['model']]['model_path'])}` "
                f"for built-in model `{config['model']}`",
                RuntimeWarning,
            )
        config["model_path"] = _BUILTIN_MODELS[config["model"]]["model_path"]
        if str(config["model_path"]).startswith(AITESTING_DOMAIN):
            config["model_path"] = download_if_needed(config["model_path"], extract=False)
        config["model"] = _BUILTIN_MODELS[config["model"]]["model"]
    elif (not is_torchvision_model_init(config["model"])) and (not is_timm_model_init(config["model"])):
        # Fully custom model: a weights file must be supplied.
        for key in ["model_path"]:
            assert key in config and config[key] is not None, f"Missing key `{key}` in config file `{str(config_file_path)}`"
    else:  # is torchvision model or timm model
        if config.get("model_path", None) is None and config["data_type"] not in [
            "ImageNet",
            "ImageCustom",
        ]:
            warnings.warn(
                "The default pretrained torchvision and timm models are trained on ImageNet, "
                "so you may need to specify `model_path` in the config file "
                f"if you are using other datasets `{config['data_type']}`",
                RuntimeWarning,
            )

    start_time = time.time()

    testing_results = []
    for idx, recipe in enumerate(recipes):
        _config = config.copy()
        _config["attack_method"] = recipe
        # ``cli_main`` expects an argparse.Namespace, not a dict.
        args = argparse.Namespace(**_config)
        try:
            rst = cli_main(args)
        except KeyboardInterrupt:
            # Abort the whole run on Ctrl-C.  (An unreachable ``break`` that
            # used to follow this ``return`` has been removed.)
            if len(testing_results) > 0:
                print("\n\nTesting terminated by user before completion\n\n")
            else:
                print("\n\nTesting cancelled by user\n\n")
            return
        testing_results.append(rst)
        # Progress line with attack counts accumulated over finished recipes.
        num_success = sum([item["Number of successful attacks"] for item in testing_results])
        num_fail = sum([rst["Number of failed attacks"] for rst in testing_results])
        num_skips = sum([rst["Number of skipped attacks"] for rst in testing_results])
        num_total = num_success + num_fail + num_skips
        print(
            f"Recipes run: {idx + 1}/{len(recipes)}. "
            "Accumulated [Succeeded / Failed / Skipped / Total: "
            f"{num_success} / {num_fail} / {num_skips} / {num_total}]"
        )

    # Attack counts are summed across recipes; every other metric is averaged.
    accumulated_items = [
        "Number of successful attacks",
        "Number of failed attacks",
        "Number of skipped attacks",
    ]
    average_rst = {
        k: sum([rst[k] for rst in testing_results]) / len(testing_results)
        for k in testing_results[0].keys()
        if k not in accumulated_items + ["Adversarial Attack Success Rate"]
    }
    for k in accumulated_items:
        average_rst[k] = sum([rst[k] for rst in testing_results])
    # NOTE(review): raises ZeroDivisionError if every attack was skipped —
    # assumed not to happen in practice; confirm against ``cli_main``.
    average_rst["Adversarial Attack Success Rate"] = average_rst["Number of successful attacks"] / (
        average_rst["Number of successful attacks"] + average_rst["Number of failed attacks"]
    )

    print(f"\n\nTestingImage elapsed time: {time.time() - start_time:.2f} seconds\n\n")

    # log summary
    summary_rows = [[k, dict_to_str(v) if isinstance(v, (dict, list, tuple)) else v] for k, v in config_bak.items()]

    log_summary_rows(summary_rows, "Testing Args")

    # log average result
    aasr = average_rst["Adversarial Attack Success Rate"]
    average_rst["Adversarial Attack Success Rate"] = f"{100 * aasr:.2f}%"
    # ``.get`` guards against configs that omit ``model_path`` entirely
    # (plain indexing here used to raise KeyError).
    if config_bak.get("model_path") is None or config_bak.get("model_path") == "Default":
        model_name = config_bak["model"]
    else:
        model_name = Path(config_bak["model_path"]).name
    # The model is declared NOT robust when the attack success rate reaches
    # the configured threshold.
    if aasr >= config_bak.get("robust_threshold", _DEFAULT_ROBUST_THR):
        conclusion = f"Assessed by TestingImage, " f"the model \042{model_name}\042 is NOT robust"
    else:
        conclusion = f"Assessed by TestingImage, the model \042{model_name}\042 is robust"
    summary_rows = [[k, f"{v:.4f}" if isinstance(v, float) else v] for k, v in average_rst.items()]
    summary_rows.extend([["", ""], [conclusion, ""]])

    log_summary_rows(summary_rows, "Testing Results")
-
def log_summary_rows(rows, title, align_center=False):
    """Print ``rows`` as a two-column summary table under a ``#`` banner.

    Parameters
    ----------
    rows : list of [str, Any]
        Two-item lists.  The first column is padded so values line up;
        ``rows`` is modified in place (labels padded, multi-line string
        values re-indented).  An empty list prints the banner only.
    title : str
        Banner title, centered in a line of ``#`` characters.
    align_center : bool, default False
        Right-justify (instead of left-justify) the first column.
    """
    width, fillchar = 80, "#"
    title = title.center(len(title) + 10, " ")
    title = title.center(width, fillchar)
    msg = "\n" + fillchar * width + "\n" + title + "\n" + fillchar * width + "\n\n"
    if len(rows) == 0:
        print(msg)
        return
    # Width of the label column.  Rows with an empty second column (e.g. the
    # conclusion line) do not participate; ``default=0`` guards against the
    # previously-crashing case where every row has an empty second column.
    max_len = max([len(row[0]) for row in rows if row[1] != ""], default=0)
    for idx in range(len(rows)):
        if align_center:
            rows[idx][0] = rows[idx][0].rjust(max_len)
        else:
            rows[idx][0] = rows[idx][0].ljust(max_len)
        tmp = rows[idx][1]
        if not isinstance(tmp, str):
            continue
        # Indent continuation lines of multi-line values so they align with
        # the value column.
        tmp = tmp.splitlines()
        for i in range(1, len(tmp)):
            tmp[i] = " " * max_len + tmp[i]
        rows[idx][1] = "\n".join(tmp)
    msg += "\n".join([f"{row[0]} {row[1]}" for row in rows]) + "\n\n" + fillchar * width
    print(msg)
-
-
if __name__ == "__main__":
    # Top-level Ctrl-C guard: ``main`` handles interrupts during the testing
    # loop itself; this catches interrupts during setup/teardown and exits
    # with a non-zero status.
    try:
        sys.exit(main(parse_args()))
    except KeyboardInterrupt:
        print("\n\nTesting cancelled by user.\n\n")
        sys.exit(1)
|