This is the official implementation of the ICCV 2023 paper "Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning".
Abstract: Deep neural networks are susceptible to adversarial examples, posing a significant security risk in critical applications. Adversarial Training (AT) is a well-established technique to enhance adversarial robustness, but it often comes at the cost of decreased generalization ability. This paper proposes Robustness Critical Fine-Tuning (RiFT), a novel approach to enhance generalization without compromising adversarial robustness. The core idea of RiFT is to exploit the redundant capacity for robustness by fine-tuning the adversarially trained model on its non-robust-critical module. To do so, we introduce module robust criticality (MRC), a measure that evaluates the significance of a given module to model robustness under worst-case weight perturbations. Using this measure, we identify the module with the lowest MRC value as the non-robust-critical module and fine-tune its weights to obtain fine-tuned weights. Subsequently, we linearly interpolate between the adversarially trained weights and fine-tuned weights to derive the optimal fine-tuned model weights. We demonstrate the efficacy of RiFT on ResNet18, ResNet34, and WideResNet34-10 models trained on CIFAR10, CIFAR100, and Tiny-ImageNet datasets. Our experiments show that RiFT can significantly improve both generalization and out-of-distribution robustness by around 1.5% while maintaining or even slightly enhancing adversarial robustness. Code is available at https://github.com/microsoft/robustlearn.
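The final interpolation step described in the abstract can be sketched as below. This is a minimal illustration, not the repository's actual API: `interpolate_weights` is a hypothetical helper, weights are represented as plain name-to-value dicts, and the interpolation factor is chosen by RiFT via a search rather than fixed at 0.5.

```python
def interpolate_weights(w_at, w_ft, alpha):
    """Linearly interpolate between adversarially trained weights (w_at)
    and fine-tuned weights (w_ft):
        theta = (1 - alpha) * theta_AT + alpha * theta_FT
    Both weight sets are dicts mapping parameter name -> value."""
    return {name: (1 - alpha) * w_at[name] + alpha * w_ft[name] for name in w_at}

# Illustrative example: a point halfway between the two sets of weights.
w = interpolate_weights({"conv.weight": 0.0}, {"conv.weight": 1.0}, alpha=0.5)
```

With `alpha=0` this recovers the adversarially trained model, and with `alpha=1` the fully fine-tuned one; RiFT picks the factor that improves generalization without sacrificing robustness.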
To install requirements:
conda env create -f env.yaml
conda activate rift
CIFAR10 and CIFAR100 can be downloaded via PyTorch.
Other datasets (e.g., Tiny-ImageNet) must be downloaded separately. After downloading these datasets, move them to ./data.
The images in Tiny-ImageNet datasets are 64x64 with 200 classes.
Here we present an example of applying RiFT to ResNet18 on CIFAR10.
Download the adversarially trained model weights here.
python main.py --layer=layer2.1.conv2 --resume="./ResNet18_CIFAR10.pth"
Here, layer2.1.conv2 is a non-robust-critical module.
The non-robust-critical modules of each model on each dataset are summarized as follows:
| | CIFAR10 | CIFAR100 | Tiny-ImageNet |
|---|---|---|---|
| ResNet18 | layer2.1.conv2 | layer2.1.conv2 | layer3.1.conv2 |
| ResNet34 | layer2.3.conv2 | layer2.3.conv2 | layer3.5.conv2 |
| WRN34-10 | block1.layer.3.conv2 | block1.layer.2.conv2 | block1.layer.2.conv2 |
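Fine-tuning only the non-robust-critical module amounts to freezing every other parameter of the adversarially trained model. A minimal PyTorch sketch, assuming a standard `nn.Module`; `freeze_except` is a hypothetical helper for illustration, not a function from this repository:

```python
import torch.nn as nn

def freeze_except(model: nn.Module, module_name: str) -> None:
    """Freeze all parameters except those under `module_name`
    (e.g. "layer2.1.conv2" for ResNet18 on CIFAR10), so the optimizer
    only updates the non-robust-critical module."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(module_name)
```

The optimizer then only needs the trainable subset, e.g. `optim.SGD(p for p in model.parameters() if p.requires_grad)`.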
To compute the MRC value of each module of a trained model:

python main.py --cal_mrc --resume=/path/to/your/model

To fine-tune a specific module, replacing xxx, yyy, and zzz with the module name, learning rate, and checkpoint path:

python main.py --layer=xxx --lr=yyy --resume=zzz

To evaluate out-of-distribution robustness, replacing xxx with the checkpoint path:

python eval_ood.py --resume=xxx
@inproceedings{zhu2023improving,
  title={Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning},
  author={Zhu, Kaijie and Hu, Xixu and Wang, Jindong and Xie, Xing and Yang, Ge},
  year={2023},
  booktitle={International Conference on Computer Vision},
}