<div align="center">[NeurIPS 2024] <br> Defensive Unlearning with Adversarial Training <br> for Robust Concept Erasure in Diffusion Models
<div align="left">arXiv Preprint | Fine-tuned Weights | HF Model | Unlearned DM Benchmark | Demo <br>
Our proposed robust unlearning framework, AdvUnlearn, enhances the safety of diffusion models (DMs) by robustly erasing unwanted concepts through adversarial training, while balancing concept erasure against image generation quality.
This is the code implementation of AdvUnlearn, our robust DM unlearning framework. The code is built on the SD (Stable Diffusion) and ESD (Erased Stable Diffusion) codebases.
Simple Usage of AdvUnlearn Text Encoders (HuggingFace Model)
from transformers import CLIPTextModel
cache_path = ".cache"
# Base model of our unlearned text encoders
model_name_or_path = "CompVis/stable-diffusion-v1-4"
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="text_encoder", cache_dir=cache_path)
# AdvUnlearn (Ours): unlearned text encoders
# Load the one you need; each call below replaces the previous `text_encoder`.
model_name_or_path = "OPTML-Group/AdvUnlearn"
# Nudity-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="nudity_unlearned", cache_dir=cache_path)
# Style-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="vangogh_unlearned", cache_dir=cache_path)
# Object-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="church_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="garbage_truck_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="parachute_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="tench_unlearned", cache_dir=cache_path)
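To generate images with one of these encoders, one option is to swap it into a Stable Diffusion pipeline. The sketch below assumes the `diffusers` library (not used in the snippet above) and its component override in `StableDiffusionPipeline.from_pretrained(..., text_encoder=...)`. The concept keys and the helpers `subfolder_for` and `build_pipeline` are hypothetical names; only the subfolder names come from the list above.

```python
# Hypothetical concept keys -> subfolder names listed above in OPTML-Group/AdvUnlearn.
ADVUNLEARN_SUBFOLDERS = {
    "nudity": "nudity_unlearned",
    "vangogh": "vangogh_unlearned",
    "church": "church_unlearned",
    "garbage_truck": "garbage_truck_unlearned",
    "parachute": "parachute_unlearned",
    "tench": "tench_unlearned",
}


def subfolder_for(concept: str) -> str:
    """Normalize a concept name (e.g. 'Garbage Truck') and return its subfolder."""
    key = concept.strip().lower().replace(" ", "_")
    if key not in ADVUNLEARN_SUBFOLDERS:
        raise ValueError(
            f"No AdvUnlearn text encoder for {concept!r}; "
            f"choose from {sorted(ADVUNLEARN_SUBFOLDERS)}"
        )
    return ADVUNLEARN_SUBFOLDERS[key]


def build_pipeline(concept: str, cache_dir: str = ".cache", device: str = "cuda"):
    """Load SD v1.4 with its text encoder replaced by an AdvUnlearn encoder.

    ASSUMPTION: uses diffusers, which the snippet above does not import.
    Heavy imports are deferred so the mapping above works without them.
    """
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained(
        "OPTML-Group/AdvUnlearn",
        subfolder=subfolder_for(concept),
        cache_dir=cache_dir,
    )
    # Every other component (UNet, VAE, tokenizer, scheduler) comes from the base model.
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        text_encoder=text_encoder,
        cache_dir=cache_dir,
    )
    return pipe.to(device)
```

For example, `build_pipeline("vangogh")("a painting of a starry night").images[0]` would return a PIL image; this downloads SD v1.4 and, with the default `device="cuda"`, needs a GPU.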
Download ckpts from Google Drive
pip install gdown
# AdvUnlearn ckpts
gdown --folder https://drive.google.com/drive/folders/1toiNWxJEX0X8pm8b88Og5_nMAqq6NZRd
# [Baseline ckpts] Nudity
gdown https://drive.google.com/uc?id=1wOqACzdpWpjRjl-a3hWgtjGAhemnelQr
# [Baseline ckpts] Style
gdown https://drive.google.com/uc?id=1-Gyl7Ls-Pa4vJzReG58bXjfI_YWu8teb
# [Baseline ckpts] Objects
gdown https://drive.google.com/uc?id=1OjYiFxYwd1B9R7vfKG6mooY27O2txYNg
Prepare
Environment Setup
A suitable conda environment named AdvUnlearn can be created and activated with:
conda env create -f environment.yaml
conda activate AdvUnlearn
File Downloads
- Base model (SD v1.4): download it from here, and move it to models/sd-v1-4-full-ema.ckpt
- COCO-10k (for CLIP score and FID): you can extract the image subset from the COCO dataset, or download it from here. Then, move it to data/imgs/coco_10k
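Before running the evaluation scripts, it can help to confirm that the two downloads landed where the steps above expect them. This sketch checks only the two paths named above, relative to the repository root; `check_layout`, `count_images`, and the list of image suffixes are hypothetical.

```python
from pathlib import Path

# Paths named in the download steps above, relative to the repository root.
EXPECTED = {
    "SD v1.4 checkpoint": ("models/sd-v1-4-full-ema.ckpt", "file"),
    "COCO-10k images": ("data/imgs/coco_10k", "dir"),
}


def check_layout(repo_root="."):
    """Return a list of problems; an empty list means both paths are in place."""
    root = Path(repo_root)
    problems = []
    for label, (rel, kind) in EXPECTED.items():
        path = root / rel
        ok = path.is_file() if kind == "file" else path.is_dir()
        if not ok:
            problems.append(f"missing {label} ({kind}): {path}")
    return problems


def count_images(repo_root=".", suffixes=(".jpg", ".jpeg", ".png")):
    """Count image files in the COCO-10k folder (suffix list is an assumption)."""
    folder = Path(repo_root) / "data/imgs/coco_10k"
    if not folder.is_dir():
        return 0
    return sum(1 for p in folder.iterdir() if p.suffix.lower() in suffixes)
```

Running `check_layout()` from the repository root should print nothing worrying once both downloads are in place; `count_images()` only reports a number and does not enforce the 10k size.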
Code Implementation
Step 1: AdvUnlearn [Train]
Hyperparameters:
- Concept to be unlearned: --prompt (e.g., 'nudity')
- Trainable module within the DM: --train_method
- Attack generation strategy: --attack_method
- Number of attack steps for adversarial prompt generation: --attack_step
- Adversarial prompting strategy: --attack_type ('prefix_k', 'replace_k', 'add')
- Retaining prompt dataset: --dataset_retain
- Utility regularization parameter: --retain_loss_w
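The three `--attack_type` values can be read as different ways of splicing adversarial tokens into the prompt. The sketch below is an assumption based only on the option names, not on the training script: prompts are lists of token-embedding vectors, `prefix_k` prepends the k adversarial embeddings, `replace_k` overwrites k prompt tokens (the first k here, a hypothetical choice), and `add` adds a per-token perturbation. `apply_attack_type` is a hypothetical helper.

```python
from typing import List

Vector = List[float]


def apply_attack_type(prompt_emb: List[Vector], adv_emb: List[Vector], attack_type: str) -> List[Vector]:
    """Combine prompt token embeddings with k adversarial embeddings.

    ASSUMED semantics, inferred from the option names only:
      prefix_k : [adv_1..adv_k] + prompt
      replace_k: first k prompt tokens overwritten by adv (position choice is hypothetical)
      add      : prompt_i + adv_i for every token (needs one perturbation per token)
    """
    k = len(adv_emb)
    if attack_type == "prefix_k":
        return [list(v) for v in adv_emb] + [list(v) for v in prompt_emb]
    if attack_type == "replace_k":
        if k > len(prompt_emb):
            raise ValueError("cannot replace more tokens than the prompt has")
        return [list(v) for v in adv_emb] + [list(v) for v in prompt_emb[k:]]
    if attack_type == "add":
        if k != len(prompt_emb):
            raise ValueError("'add' needs one perturbation per prompt token")
        return [[p + d for p, d in zip(pv, dv)] for pv, dv in zip(prompt_emb, adv_emb)]
    raise ValueError(f"unknown attack_type {attack_type!r}")
```

Under this reading, `prefix_k` lengthens the prompt, `replace_k` keeps its length but drops k original tokens, and `add` leaves every token in place while shifting its embedding.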
a) Command Example: Multi-step Attack
python train-scripts/AdvUnlearn.py --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3
b) Command Example: Fast AT variant
python train-scripts/AdvUnlearn.py --attack_method fast_at --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3
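The min-max structure behind these flags can be illustrated with a one-parameter toy. In the sketch below, which is not the training script's implementation, a scalar `theta` stands in for the text encoder, a scalar `delta` bounded by `eps` stands in for the adversarial prompt, a "leakage" term stands in for the unlearning loss, and `(theta - 1)^2` stands in for the retain loss weighted by `retain_loss_w`. The inner loop maximizes leakage with `attack_step` signed-gradient steps from a random start (PGD-style); `fast_at` takes one larger step instead, in the spirit of fast adversarial training. `eps`, `lr`, `outer_steps`, the step sizes, and the label `"pgd"` are all hypothetical.

```python
import random

THETA_REF = 1.0  # toy "original encoder" weight that the utility term anchors to
CONCEPT = 1.0    # toy embedding of the concept being erased


def objective(theta, delta, retain_loss_w):
    """Toy upper-level loss: concept leakage under attack + weighted utility term."""
    leakage = (theta * (CONCEPT + delta)) ** 2
    utility = (theta - THETA_REF) ** 2
    return leakage + retain_loss_w * utility


def attack(theta, eps, attack_method, attack_step, rng):
    """Inner maximization of the leakage over |delta| <= eps, from a random start."""
    delta = rng.uniform(-eps, eps)  # analogue of '--attack_init random'
    if attack_method == "fast_at":
        steps, alpha = 1, 1.25 * eps  # single large step
    else:
        steps = max(1, attack_step)
        alpha = 2.5 * eps / steps     # multi-step, PGD-style
    for _ in range(steps):
        grad = 2.0 * theta ** 2 * (CONCEPT + delta)  # d(leakage)/d(delta)
        sign = (grad > 0) - (grad < 0)
        delta = max(-eps, min(eps, delta + alpha * sign))  # project onto [-eps, eps]
    return delta


def adv_unlearn(attack_method="pgd", attack_step=30, retain_loss_w=0.3, eps=0.5,
                lr=0.1, outer_steps=200, seed=0, use_attack=True):
    """Outer minimization: gradient descent on theta against the current attack."""
    rng = random.Random(seed)
    theta = THETA_REF  # start from the "original" encoder
    for _ in range(outer_steps):
        delta = attack(theta, eps, attack_method, attack_step, rng) if use_attack else 0.0
        # d(objective)/d(theta) with delta held fixed
        grad = 2.0 * theta * (CONCEPT + delta) ** 2 + 2.0 * retain_loss_w * (theta - THETA_REF)
        theta -= lr * grad
    return theta
```

With `eps=0.5` the multi-step attack always reaches the boundary, so `theta` settles at `w / ((1 + eps)^2 + w)`, about 0.118 for `retain_loss_w=0.3`, versus `w / (1 + w)`, about 0.231, without the attack. Training against the worst case suppresses the concept more strongly, while the retain term keeps `theta` from collapsing to zero; the single-step variant lands in between because its random start does not always reach the boundary.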
Step 2: Attack Evaluation [Robustness Evaluation]
Follow the instructions in UnlearnDiffAtk to attack DMs equipped with the AdvUnlearn text encoder for robustness evaluation.
Step 3: Image Generation Quality Evaluation [Model Utility Evaluation]
Generate 10k images for FID & CLIP evaluation
bash jobs/fid_10k_generate.sh
Calculate FID & CLIP scores using T2IBenchmark
bash jobs/tri_quality_eval.sh
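For intuition about the two utility metrics: FID is the Fréchet distance between Gaussians fitted to image features of generated and reference images, ||mu_g - mu_r||^2 + Tr(S_g + S_r - 2(S_g S_r)^(1/2)), and a CLIP score measures image-text alignment via the cosine similarity of CLIP embeddings. The pure-Python sketch below assumes diagonal covariances (standard FID uses full covariances of Inception features) and takes embeddings as plain lists. It is not how T2IBenchmark computes either score, and the `scale` of 100 is one common convention, assumed here.

```python
import math


def _mean_std(values):
    """Sample mean and (n-1) standard deviation of one feature dimension."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two samples per set")
    mu = sum(values) / n
    var = sum((x - mu) ** 2 for x in values) / (n - 1)
    return mu, math.sqrt(var)


def fid_diagonal(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets, ASSUMING
    diagonal covariances. Then Tr(Sa + Sb - 2(Sa Sb)^(1/2)) = sum_i (sa_i - sb_i)^2,
    so the distance is ||mu_a - mu_b||^2 + ||sigma_a - sigma_b||^2."""
    dim = len(feats_a[0])
    total = 0.0
    for i in range(dim):
        mu_a, sd_a = _mean_std([f[i] for f in feats_a])
        mu_b, sd_b = _mean_std([f[i] for f in feats_b])
        total += (mu_a - mu_b) ** 2 + (sd_a - sd_b) ** 2
    return total


def clip_style_score(image_emb, text_emb, scale=100.0):
    """Scaled, non-negative cosine similarity between image and text embeddings."""
    dot = sum(a * b for a, b in zip(image_emb, text_emb))
    norm = math.sqrt(sum(a * a for a in image_emb)) * math.sqrt(sum(b * b for b in text_emb))
    return scale * max(dot / norm, 0.0)
```

Lower FID means the generated images are distributed more like COCO-10k; a higher CLIP score means the images match their prompts more closely.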
Checkpoints
All checkpoints for the different DM unlearning tasks can be found here.
DM Unlearning Methods | Nudity | Van Gogh | Objects |
---|---|---|---|
ESD (Erased Stable Diffusion) | ✅ | ✅ | ✅ |
FMN (Forget-Me-Not) | ✅ | ✅ | ✅ |
AC (Ablating Concepts) | ❌ | ✅ | ❌ |
UCE (Unified Concept Editing) | ✅ | ✅ | ❌ |
SalUn (Saliency Unlearning) | ✅ | ❌ | ✅ |
SH (ScissorHands) | ✅ | ❌ | ✅ |
ED (EraseDiff) | ✅ | ❌ | ✅ |
SPM (concept-SemiPermeable Membrane) | ✅ | ✅ | ✅ |
AdvUnlearn (Ours) | ✅ | ✅ | ✅ |
Cite Our Work
The preprint can be cited as follows:
@article{zhang2024defensive,
title={Defensive Unlearning with Adversarial Training for Robust Concept Erasure in Diffusion Models},
author={Zhang, Yimeng and Chen, Xin and Jia, Jinghan and Zhang, Yihua and Fan, Chongyu and Liu, Jiancheng and Hong, Mingyi and Ding, Ke and Liu, Sijia},
journal={arXiv preprint arXiv:2405.15234},
year={2024}
}