|
- """
- Generate a folder containing all the main variables' value.
-
- Example:
- python test.py --mode=0 --device_target="Ascend" --model_name="RN50" --pretrained="openai" --quickgelu=True
-
- P.S. This generated folder can be used by difference.py to calculate the difference statistics.
-
- """
-
- import argparse
- import os
- import sys
-
- from PIL import Image
- from src.open_clip import create_model_and_transforms, get_tokenizer
-
- import mindspore as ms
- from mindspore import Tensor, ops
-
-
- def parse_args(args):
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--mode",
- type=int,
- default=0,
- help="Mode of set_context, GRAPH_MODE(0) or PYNATIVE_MODE(1)",
- )
- parser.add_argument(
- "--device_target",
- type=str,
- default="Ascend",
- help="Ascend, CPU or GPU",
- )
- parser.add_argument(
- "--model_name",
- type=str,
- default=None,
- help="",
- )
- parser.add_argument(
- "--pretrained",
- type=str,
- default=None,
- help="A keyword (refer to ./src/open_clip/pretrained.py) or path of ckpt file",
- )
- parser.add_argument(
- "--quickgelu",
- type=bool,
- default=False,
- help="",
- )
- args = parser.parse_args(args)
- return args
-
-
- def main(args):
- args = parse_args(args)
- ms.set_context(device_target=args.device_target, mode=args.mode)
- model, preprocess_train, preprocess_val = create_model_and_transforms(
- args.model_name,
- args.pretrained,
- force_quick_gelu=args.quickgelu,
- force_custom_text=False,
- force_patch_dropout=None,
- force_image_size=None,
- image_mean=None,
- image_std=None,
- aug_cfg={},
- )
- tokenizer = get_tokenizer(args.model_name)
-
- image = Tensor(preprocess_val(Image.open("CLIP.png")))
- text = tokenizer(["a diagram", "a dog", "a cat"])
-
- image_features = model.encode_image(image)
- text_features = model.encode_text(text)
-
- root = "./" + args.model_name + args.pretrained
- if not os.path.exists(root):
- root = "./" + args.model_name + args.pretrained
- os.mkdir(root)
-
- # file = open(root + "/image.txt", "w+")
- # file.write(str(image.asnumpy().tolist()))
- # file.close()
- #
- # file = open(root + "/text.txt", "w+")
- # file.write(str(text.asnumpy().tolist()))
- # file.close()
-
- file = open(root + "/image_features.txt", "w+")
- file.write(str(image_features.asnumpy().tolist()))
- file.close()
-
- file = open(root + "/text_features.txt", "w+")
- file.write(str(text_features.asnumpy().tolist()))
- file.close()
-
- image_features /= image_features.norm(dim=-1, keepdim=True)
- text_features /= text_features.norm(dim=-1, keepdim=True)
-
- text_probs = ops.softmax(100.0 * image_features @ text_features.T, axis=-1)
-
- file = open(root + "/text_probs.txt", "w+")
- file.write(str(text_probs.asnumpy().tolist()))
- file.close()
-
-
- if __name__ == "__main__":
- main(sys.argv[1:])
- print("Done!")
|