<div align="center">

[NeurIPS 2024] <br> Defensive Unlearning with Adversarial Training <br> for Robust Concept Erasure in Diffusion Models

<div align="left">

arXiv Preprint | Fine-tuned Weights | HF Model | Unlearned DM Benchmark | Demo <br>

Our proposed robust unlearning framework, AdvUnlearn, improves the safety of diffusion models (DMs) by robustly erasing unwanted concepts through adversarial training, while balancing the trade-off between concept erasure and image generation quality.

This repository contains the code for AdvUnlearn, our robust DM unlearning framework. The code builds on the SD and ESD code bases.

<div align='center'> <img src = 'assets/nudity_main.png'> </div>

Simple Usage of AdvUnlearn Text Encoders (HuggingFace Model)

from transformers import CLIPTextModel
cache_path = ".cache"

Base model of our unlearned text encoders

model_name_or_path = "CompVis/stable-diffusion-v1-4"

text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="text_encoder", cache_dir=cache_path)

AdvUnlearn (Ours): Unlearned text encoder

model_name_or_path = "OPTML-Group/AdvUnlearn"

# Nudity-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="nudity_unlearned", cache_dir=cache_path)

# Style-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="vangogh_unlearned", cache_dir=cache_path)

# Object-Unlearned
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="church_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="garbage_truck_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="parachute_unlearned", cache_dir=cache_path)
text_encoder = CLIPTextModel.from_pretrained(model_name_or_path, subfolder="tench_unlearned", cache_dir=cache_path)
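To generate images with one of these encoders, it can be swapped into a Stable Diffusion v1.4 pipeline. The following is a minimal sketch, assuming the `diffusers` library is installed alongside `transformers`. The names `UNLEARNED_SUBFOLDERS`, `subfolder_for`, and `build_pipeline` are illustrative helpers and are not part of this repository.

```python
# Hypothetical helpers (not part of this repository): pick an unlearned text
# encoder by concept name and plug it into a Stable Diffusion v1.4 pipeline.

# Subfolder names are taken from the loading examples above.
UNLEARNED_SUBFOLDERS = {
    "nudity": "nudity_unlearned",
    "vangogh": "vangogh_unlearned",
    "church": "church_unlearned",
    "garbage_truck": "garbage_truck_unlearned",
    "parachute": "parachute_unlearned",
    "tench": "tench_unlearned",
}


def subfolder_for(concept):
    """Return the HF subfolder holding the text encoder for `concept`."""
    if concept not in UNLEARNED_SUBFOLDERS:
        known = ", ".join(sorted(UNLEARNED_SUBFOLDERS))
        raise ValueError(f"unknown concept {concept!r}; expected one of: {known}")
    return UNLEARNED_SUBFOLDERS[concept]


def build_pipeline(concept, cache_path=".cache"):
    """Load SD v1.4 with its text encoder replaced by an unlearned one."""
    # Imported lazily so the lookup helper works without the heavy libraries.
    from diffusers import StableDiffusionPipeline
    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained(
        "OPTML-Group/AdvUnlearn",
        subfolder=subfolder_for(concept),
        cache_dir=cache_path,
    )
    # Passing `text_encoder=` overrides the component stored with the base model.
    return StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        text_encoder=text_encoder,
        cache_dir=cache_path,
    )
```

A typical call would be `pipe = build_pipeline("church")` followed by `pipe("a photo of a church").images[0]`. Only the text encoder differs from the base model, so the rest of the pipeline is loaded unchanged.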

Download ckpts from Google Drive

pip install gdown

# AdvUnlearn ckpts
gdown --folder https://drive.google.com/drive/folders/1toiNWxJEX0X8pm8b88Og5_nMAqq6NZRd

# [Baseline ckpts] Nudity
gdown https://drive.google.com/uc?id=1wOqACzdpWpjRjl-a3hWgtjGAhemnelQr

# [Baseline ckpts] Style
gdown https://drive.google.com/uc?id=1-Gyl7Ls-Pa4vJzReG58bXjfI_YWu8teb

# [Baseline ckpts] Objects
gdown https://drive.google.com/uc?id=1OjYiFxYwd1B9R7vfKG6mooY27O2txYNg
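
For scripted setups, the same downloads can be driven from Python through the `gdown` package. This is a sketch only: the folder URL and file IDs are copied from the commands above, while the function name `download_checkpoints` and the `ckpts` destination directory are assumptions chosen for illustration.

```python
import os

# URL and file IDs copied from the gdown commands above.
ADVUNLEARN_FOLDER_URL = (
    "https://drive.google.com/drive/folders/1toiNWxJEX0X8pm8b88Og5_nMAqq6NZRd"
)
BASELINE_FILE_IDS = {
    "nudity": "1wOqACzdpWpjRjl-a3hWgtjGAhemnelQr",
    "style": "1-Gyl7Ls-Pa4vJzReG58bXjfI_YWu8teb",
    "objects": "1OjYiFxYwd1B9R7vfKG6mooY27O2txYNg",
}


def download_checkpoints(dest="ckpts", baselines=("nudity", "style", "objects")):
    """Download the AdvUnlearn folder and the selected baseline archives.

    `dest` is a hypothetical target directory; adjust it to your layout.
    """
    unknown = [name for name in baselines if name not in BASELINE_FILE_IDS]
    if unknown:
        raise ValueError(f"unknown baseline group(s): {unknown}")

    import gdown  # imported lazily; install with `pip install gdown`

    os.makedirs(dest, exist_ok=True)
    gdown.download_folder(
        url=ADVUNLEARN_FOLDER_URL,
        output=os.path.join(dest, "AdvUnlearn"),
        quiet=False,
    )
    for name in baselines:
        # A trailing separator tells gdown to keep the original file name
        # and place the file inside `dest`.
        gdown.download(
            id=BASELINE_FILE_IDS[name],
            output=os.path.join(dest, ""),
            quiet=False,
        )
```

Calling `download_checkpoints(baselines=("nudity",))` would fetch the AdvUnlearn folder plus only the nudity baselines.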

Prepare

Environment Setup

A suitable conda environment named AdvUnlearn can be created and activated with:

conda env create -f environment.yaml
conda activate AdvUnlearn

Files Download

<br>

Code Implementation

Step 1: AdvUnlearn [Train]

Hyperparameters:

a) Command Example: Multi-step Attack

python train-scripts/AdvUnlearn.py --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3

b) Command Example: Fast AT variant

python train-scripts/AdvUnlearn.py --attack_method fast_at --attack_init random --attack_step 30 --retain_train 'reg' --dataset_retain 'coco_object' --prompt 'nudity' --train_method 'text_encoder_full' --retain_loss_w 0.3
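
The two commands differ only in the `--attack_method fast_at` flag. When launching several runs, the argument list can be assembled programmatically. The sketch below uses only the flags and values shown in the two examples; the helper name `build_train_command` is hypothetical, and whether other values for `--prompt`, `--attack_step`, or `--retain_loss_w` are accepted should be checked against `train-scripts/AdvUnlearn.py`.

```python
import shlex


def build_train_command(prompt="nudity", attack_step=30, retain_loss_w=0.3,
                        fast_at=False):
    """Assemble the AdvUnlearn training command as an argument list.

    Hypothetical helper: defaults mirror the command examples above.
    """
    cmd = ["python", "train-scripts/AdvUnlearn.py"]
    if fast_at:
        # Example (b): the Fast AT variant adds this single flag.
        cmd += ["--attack_method", "fast_at"]
    cmd += [
        "--attack_init", "random",
        "--attack_step", str(attack_step),
        "--retain_train", "reg",
        "--dataset_retain", "coco_object",
        "--prompt", prompt,
        "--train_method", "text_encoder_full",
        "--retain_loss_w", str(retain_loss_w),
    ]
    return cmd


# Print both variants as shell-ready strings.
print(shlex.join(build_train_command()))
print(shlex.join(build_train_command(fast_at=True)))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root, which avoids shell quoting issues.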

Step 2: Attack Evaluation [Robustness Evaluation]

Follow the instructions in UnlearnDiffAtk to attack DMs equipped with the AdvUnlearn text encoder and evaluate their robustness.

Step 3: Image Generation Quality Evaluation [Model Utility Evaluation]

Generate 10k images for FID & CLIP evaluation

bash jobs/fid_10k_generate.sh

Calculate FID & CLIP scores using T2IBenchmark

bash jobs/tri_quality_eval.sh
<br>
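
For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and the generated images. The sketch below shows only that distance computation with NumPy, assuming feature arrays have already been extracted. It is not the T2IBenchmark implementation, and the function names are illustrative.

```python
import numpy as np


def feature_statistics(features):
    """Mean vector and covariance matrix of an (n_samples, dim) feature array."""
    features = np.asarray(features, dtype=float)
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 sigma2)^(1/2))
    """
    mu1 = np.asarray(mu1, dtype=float)
    mu2 = np.asarray(mu2, dtype=float)
    sigma1 = np.asarray(sigma1, dtype=float)
    sigma2 = np.asarray(sigma2, dtype=float)

    diff = mu1 - mu2
    # The trace of the matrix square root equals the sum of the square roots
    # of the eigenvalues of sigma1 @ sigma2 (non-negative for PSD inputs).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * trace_sqrt)
```

Lower values indicate that the generated images are statistically closer to the reference set, which is why FID serves here as the model utility metric after unlearning.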

Checkpoints

All checkpoints for the different DM unlearning tasks can be found here.

DM Unlearning Methods | Nudity | Van Gogh | Objects
ESD (Erased Stable Diffusion)
FMN (Forget-Me-Not)
AC (Ablating Concepts)
UCE (Unified Concept Editing)
SalUn (Saliency Unlearning)
SH (ScissorHands)
ED (EraseDiff)
SPM (concept-SemiPermeable Membrane)
AdvUnlearn (Ours)
<br>

Cite Our Work

The preprint can be cited as follows:

@article{zhang2024defensive,
  title={Defensive Unlearning with Adversarial Training for Robust Concept Erasure in Diffusion Models},
  author={Zhang, Yimeng and Chen, Xin and Jia, Jinghan and Zhang, Yihua and Fan, Chongyu and Liu, Jiancheng and Hong, Mingyi and Ding, Ke and Liu, Sijia},
  journal={arXiv preprint arXiv:2405.15234},
  year={2024}
}