Image Test Time Augmentation with Paddle2.0!
```
           Input
             |           # input batch of images
          / / /|\ \ \    # apply augmentations (flips, rotation, scale, etc.)
         | | | | | | |   # pass augmented batches through model
         | | | | | | |   # reverse transformations for each batch of masks/labels
          \ \ \ / / /    # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
```
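The pipeline above can be sketched framework-agnostically. The following is a minimal NumPy illustration (not PaTTA's actual implementation): a batch is augmented with a horizontal flip, passed through a stand-in identity model, de-augmented by flipping back, and the predictions are averaged.

```python
import numpy as np

def model(batch):
    # stand-in "model" for illustration: an identity mapping
    return batch

image = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)  # (B, C, H, W)

predictions = []
for augment, deaugment in [
    (lambda x: x, lambda y: y),                        # identity transform
    (lambda x: x[..., ::-1], lambda y: y[..., ::-1]),  # horizontal flip and its inverse
]:
    output = model(augment(image))        # predict on the augmented batch
    predictions.append(deaugment(output)) # undo the augmentation on the prediction

merged = np.mean(predictions, axis=0)  # merge predictions, i.e. merge_mode='mean'
```

Because each de-augmentation exactly inverts its augmentation and the model here is the identity, `merged` equals the input batch; with a real model, averaging the de-augmented predictions is what smooths out augmentation-sensitive errors.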
After defining your network, you can wrap it for test-time augmentation as follows:

```python
import patta as tta

tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
```
Note: the model must return keypoints in the format `Tensor([x1, y1, ..., xn, yn])`
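A short NumPy sketch of this flat keypoint layout (hypothetical values, not PaTTA code): the flat vector reshapes into `(n, 2)` pairs, and de-augmenting a horizontal flip on keypoints scaled to `[0, 1]` (the usual `scaled=True` convention) mirrors each x coordinate around 0.5.

```python
import numpy as np

# hypothetical model output: 3 keypoints, flat [x1, y1, x2, y2, x3, y3],
# scaled to the [0, 1] range
flat = np.array([0.1, 0.2, 0.5, 0.5, 0.9, 0.8])
points = flat.reshape(-1, 2)  # -> (n, 2) rows of (x, y)

# de-augmenting a horizontal flip on scaled keypoints mirrors x around 0.5
deaug = points.copy()
deaug[:, 0] = 1.0 - deaug[:, 0]
```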
If you have exported a static model (`*.pdmodel`, `*.pdiparams`, `*.pdiparams.info`), you can load it and test as follows:

```python
import patta as tta

model = tta.load_model(path='output/model')
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
```
We recommend modifying the file `seg.py` according to your own model:

```shell
python seg.py --model_path='output/model' \
              --batch_size=16 \
              --test_dataset='test.txt'
```

Note: this script is built around PaddleSeg models.
You can also compose custom transforms:

```python
# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
```
```python
# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
masks, labels = [], []
for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()
    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min
mask = mean(masks)
label = mean(labels)
```
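The final reduction step can be illustrated in NumPy (hypothetical prediction values, not PaTTA's internals): the de-augmented predictions are stacked along a new axis and collapsed with the chosen merge function.

```python
import numpy as np

# three de-augmented predictions for the same (B, N) grid, hypothetical values
preds = np.stack([
    np.array([[0.2, 0.8]]),
    np.array([[0.4, 0.6]]),
    np.array([[0.6, 0.4]]),
])

mean_merge = preds.mean(axis=0)                    # 'mean' merge
max_merge = preds.max(axis=0)                      # 'max' merge
gmean_merge = np.exp(np.log(preds).mean(axis=0))   # 'gmean' (geometric mean) merge
```

`mean` is the usual default; `max` favors the most confident prediction, while the geometric mean penalizes any transform that assigns a near-zero probability.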
Transform | Parameters | Values
---|---|---
HorizontalFlip | - | -
VerticalFlip | - | -
Rotate90 | angles | List[0, 90, 180, 270]
Scale | scales, interpolation | List[float], "nearest"/"linear"
Resize | sizes, original_size, interpolation | List[Tuple[int, int]], Tuple[int, int], "nearest"/"linear"
Add | values | List[float]
Multiply | factors | List[float]
FiveCrops | crop_height, crop_width | int, int
Install from PyPI:

```shell
$ pip install patta
```

or install from source:

```shell
$ git clone https://github.com/AgentMaker/PaTTA.git
$ pip install PaTTA/
```
```shell
# run test_transforms.py and test_base.py for test
python tests/test_transforms.py
python tests/test_base.py
```