

FixMatch experiments in PyTorch and Ignite

Experiments with FixMatch on the CIFAR-10 dataset.

Based on "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence" and its official code.
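
The core of FixMatch, as described in the paper, is its unlabeled loss. The model's prediction on a weakly augmented image becomes a hard pseudo-label when its maximum class probability reaches a threshold τ, and the prediction on a strongly augmented view of the same image is trained against that pseudo-label with cross-entropy. Below is a minimal pure-Python sketch of this loss for one batch of logits. The function names are hypothetical, and the default τ = 0.95 is the paper's value, not necessarily this repository's default (which presumably corresponds to the ssl.confidence_threshold option).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fixmatch_unlabeled_loss(
    weak_logits: Sequence[Sequence[float]],
    strong_logits: Sequence[Sequence[float]],
    threshold: float = 0.95,
) -> float:
    """Hypothetical helper: masked pseudo-label cross-entropy over an unlabeled batch."""
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        probs = softmax(weak)
        confidence = max(probs)
        pseudo_label = probs.index(confidence)  # hard pseudo-label from the weak view
        # Only confident weak-view predictions contribute to the loss
        if confidence >= threshold:
            strong_probs = softmax(strong)
            total += -math.log(strong_probs[pseudo_label])
    # Average over the whole unlabeled batch: masked samples count as zero
    return total / len(weak_logits)
```

During training this term is added to the usual supervised cross-entropy on labeled images, weighted by a coefficient λ_u, and no gradient flows through the pseudo-labels.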

The data-augmentation policy is CTA (CTAugment).
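
CTAugment learns which augmentation magnitudes keep the model's predictions correct. Each transformation's magnitude range is split into bins with weights starting at 1. Two transformations are drawn uniformly, and a magnitude bin is sampled from the normalized weights after low-weight bins are zeroed. Bin weights are updated from how well the model still predicts the label of augmented labeled images. The sketch below is a simplified, hypothetical illustration of that bookkeeping, not this repository's class. The bin count (10), decay (0.99), threshold (0.8), normalization by the largest weight and the match score 1 − ½·Σ|p − y| are assumptions. How often such updates happen is presumably what the ssl.cta_update_every option controls.

```python
import random
from typing import Dict, List, Sequence, Tuple


class CTAugmentSketch:
    """Hypothetical, simplified CTAugment bookkeeping (magnitude-bin weights only)."""

    def __init__(self, ops: Sequence[str], num_bins: int = 10,
                 decay: float = 0.99, threshold: float = 0.8) -> None:
        self.ops = list(ops)
        self.decay = decay
        self.threshold = threshold
        # Every magnitude bin of every transformation starts with weight 1
        self.weights: Dict[str, List[float]] = {op: [1.0] * num_bins for op in self.ops}

    def bin_probabilities(self, op: str) -> List[float]:
        w = self.weights[op]
        top = max(w)
        if top == 0.0:  # degenerate case: fall back to uniform sampling
            return [1.0 / len(w)] * len(w)
        # Scale by the largest weight, zero bins below the threshold, renormalize
        kept = [x / top if x / top >= self.threshold else 0.0 for x in w]
        total = sum(kept)
        return [x / total for x in kept]

    def sample(self, rng: random.Random, depth: int = 2) -> List[Tuple[str, int]]:
        policy = []
        for _ in range(depth):
            op = rng.choice(self.ops)  # transformations are drawn uniformly
            probs = self.bin_probabilities(op)
            bin_index = rng.choices(range(len(probs)), weights=probs)[0]
            policy.append((op, bin_index))
        return policy

    def update(self, policy: Sequence[Tuple[str, int]],
               probs: Sequence[float], label: int) -> None:
        # Match score in [0, 1]: 1 when the prediction equals the one-hot label
        error = sum(abs(p - (1.0 if k == label else 0.0)) for k, p in enumerate(probs))
        match = 1.0 - 0.5 * error
        for op, b in policy:
            # Exponential moving average of the match score for each applied bin
            self.weights[op][b] = self.decay * self.weights[op][b] + (1.0 - self.decay) * match
```

Bins that repeatedly make the model wrong decay below the threshold and stop being sampled, while the remaining bins keep a non-zero probability.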

Online logging on W&B: https://app.wandb.ai/vfdev-5/fixmatch-pytorch

Requirements

pip install --upgrade --pre hydra-core tensorboardX
pip install --upgrade git+https://github.com/pytorch/ignite
# pip install --upgrade --pre pytorch-ignite

Optionally, we can install wandb for online experiment tracking:

pip install wandb

We can also replace Pillow with Pillow-SIMD to accelerate image processing:

pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd

Training

python -u main_fixmatch.py model=WRN-28-2

This script automatically trains on multiple GPUs (torch.nn.DataParallel).

To specify the input and output folders:

python -u main_fixmatch.py dataflow.data_path=/data/cifar10/ hydra.run.dir=/output-fixmatch model=WRN-28-2
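
Arguments such as dataflow.data_path=/data/cifar10/ are Hydra overrides: each dotted key addresses a nested entry of the configuration. The sketch below is a hypothetical, heavily simplified illustration of that mapping, not Hydra's implementation. Real Hydra also converts value types (for example true to a boolean) and validates keys against the config, whereas this sketch keeps every value as a string.

```python
import copy
from typing import Any, Dict, Iterable


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical helper: apply 'a.b.c=value' strings to a nested dict copy."""
    result = copy.deepcopy(config)  # leave the caller's config untouched
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            # Walk down (or create) one nested level per dotted component
            node = node.setdefault(name, {})
        node[leaf] = value
    return result
```

For the command above, the three overrides end up as data_path under dataflow, dir under hydra.run, and a top-level model entry.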

To use the wandb logger, log in first and then run with online_exp_tracking.wandb=true:

wandb login <token>
python -u main_fixmatch.py model=WRN-28-2 online_exp_tracking.wandb=true

To see other options:

python -u main_fixmatch.py --help

Training curves visualization

By default, we use TensorBoard to log training curves:

tensorboard --logdir=/tmp/output-fixmatch-cifar10/

Distributed Data Parallel (DDP) on multiple GPUs (Experimental)

For example, training on 2 GPUs:

python -u -m torch.distributed.launch --nproc_per_node=2 main_fixmatch.py model=WRN-28-2 distributed.backend=nccl
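
In DDP mode each process sees only its own shard of the dataset. Assuming the data loaders use a torch.utils.data.DistributedSampler (as Ignite's idist.auto_dataloader typically sets up in distributed mode), the sketch below reproduces its non-shuffled, drop_last=False partitioning in pure Python. The function name is hypothetical. The index list is padded by repeating its start until it divides evenly, and each rank then takes every num_replicas-th index.

```python
import math
from typing import List


def shard_indices(dataset_size: int, num_replicas: int, rank: int) -> List[int]:
    """Hypothetical helper mirroring DistributedSampler without shuffling."""
    if dataset_size <= 0 or not 0 <= rank < num_replicas:
        raise ValueError("need a non-empty dataset and 0 <= rank < num_replicas")
    indices = list(range(dataset_size))
    num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    # Pad by repeating indices so every rank gets the same number of samples
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Interleaved split: rank r takes indices r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]
```

With shuffling enabled, the real sampler first permutes the indices with a seed that changes per epoch, which is why training code usually calls sampler.set_epoch(epoch).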

TPU(s) on Colab (Experimental)

For example, training on 8 TPUs in distributed mode:

python -u main_fixmatch.py model=resnet18 distributed.backend=xla-tpu distributed.nproc_per_node=8
# or python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8


Faster ResNet-18 training

python main_fixmatch.py distributed.backend=nccl online_exp_tracking.wandb=true solver.num_epochs=500 \
    ssl.confidence_threshold=0.8 ema_decay=0.9 ssl.cta_update_every=15
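
This run lowers ema_decay to 0.9. Assuming, as in the FixMatch paper, that this is the decay of an exponential moving average (EMA) of the model weights, each EMA weight is updated as decay · ema + (1 − decay) · weight after every optimizer step. The pure-Python sketch below (hypothetical helper name, flat lists instead of tensors) shows the update. As a rule of thumb, the EMA averages over roughly 1 / (1 − decay) recent steps: about 10 for 0.9, versus about 1000 for a decay of 0.999.

```python
from typing import List, Sequence


def ema_update(ema_params: List[float], params: Sequence[float], decay: float) -> None:
    """Hypothetical helper: in-place EMA update of a flat list of weights."""
    for i, p in enumerate(params):
        # Keep `decay` of the old average and blend in (1 - decay) of the new weight
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p
```

Starting from 0 and tracking a constant weight of 1, the EMA reaches 1 − 0.9ⁿ after n updates with decay 0.9, so a smaller decay follows the trained weights more closely.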