Getting Started
We provide some demo tasks under the folder `examples`. Here, we introduce how to conduct anomaly detection (AD) experiments based on READ.
To keep inputs consistent, READ takes a PyTorch dataset (torch.utils.data.Dataset) as its input during training. This still leaves some freedom: for example, you can apply any data augmentations you wish inside the dataset.
For inference, with further deployment of READ in mind, it takes a torch.Tensor as input to keep maximum flexibility. The mismatch between training input (a Dataset) and inference input (a Tensor) may be slightly confusing, but we believe this design suits the experiments we have designed.
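The two input types meet at inference time: images that a training Dataset would yield one by one must be stacked into a single (N, C, H, W) tensor. The sketch below shows one way to do this with NumPy for images loaded as (H, W, C) uint8 arrays. The helper name `to_nchw_batch` and the scaling to [0, 1] are assumptions for illustration, not part of READ; whatever preprocessing you use should match the one applied by your training Dataset.

```python
import numpy as np

def to_nchw_batch(images):
    """Stack (H, W, C) uint8 images into a float32 (N, C, H, W) array.

    Hypothetical helper: the [0, 1] scaling is an assumption and should
    mirror the preprocessing used by your training Dataset.
    """
    batch = np.stack(images, axis=0)                  # (N, H, W, C)
    batch = batch.transpose(0, 3, 1, 2)               # (N, C, H, W), a strided view
    batch = np.ascontiguousarray(batch, dtype=np.float32)  # contiguous float copy
    return batch / 255.0                              # scale pixel values to [0, 1]

# Usage: two dummy 64 x 64 RGB images
imgs = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(2)]
batch = to_nchw_batch(imgs)
print(batch.shape)  # (2, 3, 64, 64)
```

With PyTorch installed, `torch.from_numpy(batch)` turns the array into the torch.Tensor that is passed to `model.predict`.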
- Define a torch.utils.data.Dataset that returns defect-free samples for training. You only need to ensure that the first returned item is the training image:
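```python
# Define a torch.utils.data.Dataset first (MVTecDataset is used here)
train_data = MVTecDataset(data_path=args.data_dir, class_name=class_name)
# Optional: validation data, e.g. for estimating the threshold with model.est_thres
val_data = MVTecDataset(data_path=args.data_dir, class_name=class_name)
```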
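To illustrate the "image first" rule and the freedom to add augmentations, here is a minimal defect-free dataset with an optional random horizontal flip. The class `DefectFreeDataset`, its arguments, and the extra label it returns are hypothetical. To stay runnable without PyTorch, it only implements the map-style protocol (`__getitem__` and `__len__`) that torch.utils.data.Dataset subclasses provide; in a real project you would inherit from torch.utils.data.Dataset and return tensors.

```python
import numpy as np

class DefectFreeDataset:
    """Hypothetical training dataset of defect-free images.

    In a real project, subclass torch.utils.data.Dataset instead.
    """

    def __init__(self, images, flip_prob=0.5, seed=0):
        self.images = images                 # list of (C, H, W) float arrays
        self.flip_prob = flip_prob           # probability of a horizontal flip
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.rng.random() < self.flip_prob:
            # Flip along the width axis; copy() removes the negative stride,
            # which torch.from_numpy does not accept.
            img = img[:, :, ::-1].copy()
        label = 0                            # 0 = defect-free (assumed convention)
        return img, label                    # the image must be the first item

# Usage
data = DefectFreeDataset([np.random.rand(3, 8, 8).astype(np.float32) for _ in range(4)])
img, label = data[0]
print(len(data), img.shape)  # 4 (3, 8, 8)
```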
- Choose an algorithm supported by READ, taking RIAD as an example:
```python
from READ_pytorch.ad_algorithm import RIAD
model = RIAD()
```
Train from scratch
- Train the defined model:
```python
# Define out_dir to save trained weights
model.train(train_data, out_dir)
```
Note: the threshold is estimated automatically after training, but it is not applied automatically.
Load trained weights
- Load pre-trained weights and estimate the threshold manually:
```python
# Define location of pretrained weights
model.load_weights(weights_dir)
model.est_thres(val_data)
```
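How `est_thres` derives its thresholds is not described here. Purely as an illustration of what an unsupervised threshold can look like, one common choice is a high percentile of the anomaly scores obtained on defect-free validation data. The function name `percentile_threshold` and the 99th-percentile default below are assumptions, not READ's implementation.

```python
import numpy as np

def percentile_threshold(scores, q=99.0):
    """Hypothetical unsupervised threshold: the q-th percentile of anomaly
    scores computed on defect-free validation samples."""
    # ravel() pools everything, so this works for (N,) image scores
    # as well as (N, H, W) pixel scores.
    scores = np.asarray(scores, dtype=np.float64).ravel()
    return float(np.percentile(scores, q))

val_img_scores = np.array([0.10, 0.12, 0.08, 0.15, 0.11])
print(percentile_threshold(val_img_scores, q=100.0))  # 0.15
```

A lower `q` flags more samples as anomalous; since no defective samples are used, the choice trades missed defects against false alarms without being tuned on labels.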
- Input the test torch.Tensor into the model, which returns an image-level anomaly score and a pixel-wise anomaly score. The test input should have shape (N, C, H, W):
```python
img_score, pixel_score = model.predict(data)
```
Note: img_score has shape (N,) and pixel_score has shape (N, H, W).
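Before thresholding, it can help to confirm that the scores have the documented shapes. The helper below is a hypothetical sanity check (its name is an assumption) that compares img_score and pixel_score against an (N, C, H, W) input shape; it accepts NumPy arrays or anything with a `.shape` attribute, such as a torch.Tensor.

```python
import numpy as np

def check_score_shapes(data_shape, img_score, pixel_score):
    """Hypothetical check that predict() outputs match the documented shapes."""
    n, _, h, w = data_shape                      # input is (N, C, H, W)
    if tuple(img_score.shape) != (n,):
        raise ValueError(f"img_score shape {tuple(img_score.shape)}, expected {(n,)}")
    if tuple(pixel_score.shape) != (n, h, w):
        raise ValueError(
            f"pixel_score shape {tuple(pixel_score.shape)}, expected {(n, h, w)}"
        )

# Usage with dummy scores standing in for model.predict(data)
data_shape = (4, 3, 256, 256)
check_score_shapes(data_shape, np.zeros(4), np.zeros((4, 256, 256)))  # passes silently
```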
Get results from unsupervised threshold
- Apply the estimated thresholds to img_score and pixel_score to obtain the final classification and segmentation results:
```python
img_thres = model.cls_thres  # image-level (classification) threshold
seg_thres = model.seg_thres  # pixel-level (segmentation) threshold
```
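Since the thresholds are not applied automatically, the final decision is a comparison you make yourself. A minimal NumPy sketch, assuming that a higher score means more anomalous and that scores strictly above a threshold count as anomalous (the tie-breaking rule and the function name `apply_thresholds` are assumptions). If the scores are torch.Tensors on a GPU, move them to the CPU first, or perform the same comparison directly in PyTorch.

```python
import numpy as np

def apply_thresholds(img_score, pixel_score, img_thres, seg_thres):
    """Turn anomaly scores into binary decisions (higher score = more anomalous)."""
    img_pred = np.asarray(img_score) > img_thres     # (N,): True = anomalous image
    seg_mask = np.asarray(pixel_score) > seg_thres   # (N, H, W): True = anomalous pixel
    return img_pred, seg_mask

# Usage with dummy scores; with READ you would pass model.cls_thres and model.seg_thres
img_score = np.array([0.2, 0.9])
pixel_score = np.array([[[0.1, 0.8]], [[0.6, 0.3]]])  # shape (2, 1, 2)
img_pred, seg_mask = apply_thresholds(img_score, pixel_score, img_thres=0.5, seg_thres=0.5)
print(img_pred.tolist())              # [False, True]
print(seg_mask.astype(int).tolist())  # [[[0, 1]], [[1, 0]]]
```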
Get results from supervised threshold
- Coming soon.