Introduction

This is the implementation of our paper FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning (accepted by AAAI 2024).

Keywords: federated learning, data heterogeneity, model heterogeneity, communication overhead, intellectual property (IP) protection

Takeaway: We enhance FedProto, a typical heterogeneous federated learning (HtFL) method, with Trainable Global Prototypes (TGP) and Adaptive-margin-enhanced Contrastive Learning (ACL). This makes it more versatile and more resilient to various kinds of model heterogeneity.
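To give a feel for what a margin-enhanced contrastive objective over prototypes looks like, here is a minimal sketch. It assumes Euclidean distance and a cross-entropy over negative distances, with the margin added to the distance between a client prototype and the global prototype of its own class. The function name `margin_contrastive_loss` is hypothetical, and this is an illustration, not the repository's implementation.

```python
import math


def margin_contrastive_loss(client_proto, label, global_protos, margin):
    """Margin-enhanced contrastive loss for a single client prototype.

    The logits are negative Euclidean distances to every global prototype.
    The distance to the global prototype of the true class is enlarged by
    `margin`, so the loss only becomes small once that prototype is closer
    than every other class's prototype by clearly more than `margin`.
    """
    logits = []
    for c, g in enumerate(global_protos):
        d = math.dist(client_proto, g)
        if c == label:
            d += margin  # make the positive pair look farther apart than it is
        logits.append(-d)
    # Numerically stable cross-entropy: logsumexp(logits) - logits[label].
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]


# Toy usage: two classes in 2-D and one client prototype of class 0.
global_protos = [[0.0, 0.0], [10.0, 0.0]]
loss = margin_contrastive_loss([1.0, 0.0], 0, global_protos, margin=1.0)
```

For the same geometry, a larger `margin` gives a larger loss. Minimizing such a loss with respect to the global prototypes therefore pushes the prototypes of different classes farther apart.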

Citation

@inproceedings{zhang2024fedtgp,
  title={FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning},
  author={Zhang, Jianqing and Liu, Yang and Hua, Yang and Cao, Jian},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2024}
}

The global and client prototypes in FedProto and our FedTGP. Different colors and numbers represent classes and clients, respectively. Circles represent the client prototypes and triangles represent the global prototypes. The black and yellow dotted arrows show the inter-class separation among the client and global prototypes, respectively. Triangles with dotted borders represent our Trainable Global Prototypes (TGP). The red arrows show the inter-class intervals between TGP and the client prototypes of other classes in our Adaptive-margin-enhanced Contrastive Learning (ACL).
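One way to picture an adaptive margin is to recompute it in each round from the spread of the uploaded client prototypes and cap it with a threshold. The sketch below is an illustrative assumption. It takes the largest pairwise distance between class-averaged client prototypes and caps it at `tau`. The paper's exact definition may differ, and `class_means` and `adaptive_margin` are hypothetical names.

```python
import math
from itertools import combinations


def class_means(protos_by_class):
    """Average the client prototypes uploaded for each class."""
    means = {}
    for c, protos in protos_by_class.items():
        dim = len(protos[0])
        means[c] = [sum(p[i] for p in protos) / len(protos) for i in range(dim)]
    return means


def adaptive_margin(protos_by_class, tau):
    """Hypothetical adaptive margin: the largest distance between two
    class-averaged client prototypes, capped at the threshold `tau`.
    Requires prototypes for at least two classes."""
    means = class_means(protos_by_class)
    widest = max(math.dist(means[a], means[b]) for a, b in combinations(means, 2))
    return min(widest, tau)


# Toy usage: per-class lists of 2-D client prototypes from one round.
protos_by_class = {
    0: [[0.0, 0.0], [0.0, 2.0]],  # class mean (0, 1)
    1: [[3.0, 5.0], [3.0, 5.0]],  # class mean (3, 5), distance 5 from class 0
}
margin = adaptive_margin(protos_by_class, tau=10.0)
```

Because the value is derived from the current client prototypes, it changes as they move during training, while the cap keeps it from exceeding `tau`.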

Dataset

Due to the file size limit, we only upload the statistics (config.json) of the CIFAR-10 dataset in the practical setting ($\beta=0.1$). Please refer to our popular repositories PFLlib and HtFLlib to generate all the datasets and create the required Python environment.
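The next sketch shows what a label-skewed split of this kind can look like. It assumes the practical setting divides each class's samples among clients according to Dirichlet($\beta$) proportions, where a smaller $\beta$ yields more heterogeneous clients. `dirichlet_split` is a hypothetical helper. PFLlib's generation scripts may apply additional rules (for example, minimum client sizes), so use them to match our setup exactly.

```python
import random


def dirichlet_split(labels, num_clients, beta, seed=0):
    """Partition sample indices across clients with per-class Dirichlet(beta)
    proportions. Smaller beta gives more skewed label distributions per client."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # Dirichlet(beta, ..., beta) sample = normalized independent Gamma(beta, 1) draws.
        weights = [rng.gammavariate(beta, 1.0) for _ in range(num_clients)]
        total = sum(weights)
        # Turn cumulative proportions into cut points over this class's indices.
        cuts, acc = [], 0.0
        for w in weights[:-1]:
            acc += w / total
            cuts.append(int(round(acc * len(idxs))))
        start = 0
        for client, end in zip(clients, cuts + [len(idxs)]):
            client.extend(idxs[start:end])
            start = end
    return clients


# Toy usage: 1000 samples over 10 classes, split across 20 clients with beta = 0.1.
labels = [i % 10 for i in range(1000)]
parts = dirichlet_split(labels, num_clients=20, beta=0.1, seed=1)
```

With $\beta=0.1$, most of each class typically ends up on a few clients. This produces the label skew that makes prototype aggregation harder than in an IID split.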

System

Learning reasonable global prototypes can be challenging in some cases, particularly because of the limited number of client prototypes and the adaptive margin introduced by ACL. To address this, consider setting a larger top_cnt and making sure that the number of global communication iterations is larger than 1000. This should bring the Server loss below 0.001. The best accuracy is typically achieved when the Server loss reaches its minimum, and in most of our experiments the Server loss reached 0.0.
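To make the role of the Server loss more concrete, the toy sketch below trains global prototypes on the server by gradient descent on a margin-enhanced contrastive loss and records the loss at every step. Its assumptions are as follows. The prototypes are optimized directly as free vectors, which is a simplification and not necessarily how the repository parameterizes them. The distance is Euclidean and the margin is a fixed number. `top_cnt` is not modeled. `loss_and_grads` and `train_global_protos` are hypothetical helpers, not functions from this repository.

```python
import math


def loss_and_grads(client_protos, labels, global_protos, margin):
    """Mean margin-enhanced contrastive loss over all client prototypes and its
    gradient with respect to every global prototype (Euclidean distance)."""
    dim = len(global_protos[0])
    grads = [[0.0] * dim for _ in global_protos]
    total = 0.0
    for p, y in zip(client_protos, labels):
        dists = [math.dist(p, g) for g in global_protos]
        # Logits are negative distances; the true class is penalized by the margin.
        logits = [-(d + margin) if c == y else -d for c, d in enumerate(dists)]
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        norm = sum(exps)
        total += m + math.log(norm) - logits[y]  # cross-entropy for this prototype
        for c, (g, d) in enumerate(zip(global_protos, dists)):
            # dL/d(dist_c) = [c == y] - softmax_c, and d(dist_c)/dg_c = (g_c - p) / dist_c.
            coef = ((1.0 if c == y else 0.0) - exps[c] / norm) / max(d, 1e-12)
            for i in range(dim):
                grads[c][i] += coef * (g[i] - p[i])
    n = len(client_protos)
    return total / n, [[v / n for v in row] for row in grads]


def train_global_protos(client_protos, labels, global_protos, margin, lr=0.5, steps=200):
    """Plain gradient descent on the global prototypes (updated in place).
    Returns the loss recorded before each step."""
    history = []
    for _ in range(steps):
        loss, grads = loss_and_grads(client_protos, labels, global_protos, margin)
        history.append(loss)
        for g, grad in zip(global_protos, grads):
            for i in range(len(g)):
                g[i] -= lr * grad[i]
    return history


# Toy usage: two classes, two client prototypes per class, and both global
# prototypes initialized between the two clusters.
client_protos = [[0.0, 0.5], [0.5, 0.0], [10.0, 9.5], [9.5, 10.0]]
labels = [0, 0, 1, 1]
global_protos = [[4.0, 5.0], [5.0, 4.0]]
history = train_global_protos(client_protos, labels, global_protos, margin=1.0)
print(f"first loss {history[0]:.4f}, last loss {history[-1]:.4f}")
```

In this toy run, the recorded loss falls as each global prototype moves toward its own cluster and away from the other class. How low the loss can go depends on how far apart the client prototypes of different classes are relative to the margin.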