Setfit-PyTorch-Lightning
We are happy to be featured in the official SetFit repository.
About SetFit
SetFit is a strong few-shot learning method for text classification. With SetFit, you can build a text classifier whose accuracy is comparable to GPT-3's using only a few dozen labeled examples. See the official paper, blog, and code of SetFit.
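SetFit first fine-tunes a Sentence Transformer contrastively on pairs of labeled examples, then trains a classification head on the resulting embeddings. The sketch below shows only the pair-generation idea in plain Python: two texts with the same label form a positive pair (target 1.0) and two texts with different labels form a negative pair (target 0.0). The function name `make_pairs` and the exhaustive pairing are illustrative assumptions; the official library samples pairs (controlled by parameters such as `num_iterations`) rather than enumerating all of them.

```python
from itertools import combinations


def make_pairs(examples):
    """Build (text_a, text_b, similarity) pairs from (text, label) examples.

    Same-label pairs get similarity 1.0 (positive),
    different-label pairs get 0.0 (negative).
    """
    pairs = []
    for (text_a, label_a), (text_b, label_b) in combinations(examples, 2):
        similarity = 1.0 if label_a == label_b else 0.0
        pairs.append((text_a, text_b, similarity))
    return pairs


if __name__ == "__main__":
    few_shot = [
        ("a gripping, heartfelt film", 1),
        ("one of the best movies this year", 1),
        ("dull and far too long", 0),
        ("the plot makes no sense", 0),
    ]
    pairs = make_pairs(few_shot)
    # 4 examples yield 4 * 3 / 2 = 6 pairs: 2 positive and 4 negative.
    print(len(pairs), sum(p[2] for p in pairs))  # 6 2.0
```

Because the number of pairs grows quadratically with the number of examples, even a few dozen labeled texts yield hundreds of training pairs, which is one intuition for why SetFit works with so little data.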
If you want to try SetFit right away, go here to find example notebooks that run SetFit.
This repository provides code for running SetFit with PyTorch Lightning, which makes it easier to manage parameters, experiments, and so on.
This repository is based on lightning-hydra-template.
How to use this repository
step 0. Create a Miniconda GPU environment and run an operation check
Create a Miniconda GPU environment
# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv
# install requirements
pip install -r requirements.txt
Operation check
Run the following command to execute the sample code (sentiment classification on sst2):
make operation-check
or
python src/train.py ++trainer.fast_dev_run=true
step 1. Customize the LightningDataModule.
Data is managed by a LightningDataModule. In the sample code, the training data is loaded from the sst2 dataset.
If you are not familiar with PyTorch Lightning, we recommend changing only self.train_dataset, self.valid_dataset, and self.test_dataset in the __init__ of the DataModule.
The DataModule's parameters are managed in a config file.
If you want to customize further, the README of lightning-hydra-template offers useful information.
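To make the recommendation above concrete, here is a framework-free sketch of the idea: the only thing you swap is the three dataset attributes set in `__init__`, while the batching logic stays untouched. The class name `ToyTextDataModule`, its `batch_size` parameter, and the toy data are hypothetical. The repository's real DataModule subclasses PyTorch Lightning's `LightningDataModule`, loads sst2, and would typically return PyTorch `DataLoader`s from hooks with the same names used here.

```python
class ToyTextDataModule:
    """Hypothetical stand-in for a LightningDataModule (no Lightning dependency)."""

    def __init__(self, batch_size=2):
        self.batch_size = batch_size
        # Replace these three attributes with your own (text, label) data.
        self.train_dataset = [("great acting", 1), ("boring plot", 0), ("loved it", 1)]
        self.valid_dataset = [("not bad at all", 1)]
        self.test_dataset = [("a waste of time", 0)]

    def _batches(self, dataset):
        # Yield consecutive slices of at most batch_size examples.
        for start in range(0, len(dataset), self.batch_size):
            yield dataset[start:start + self.batch_size]

    # Hook names mirror LightningDataModule's; here they return plain lists of batches.
    def train_dataloader(self):
        return list(self._batches(self.train_dataset))

    def val_dataloader(self):
        return list(self._batches(self.valid_dataset))

    def test_dataloader(self):
        return list(self._batches(self.test_dataset))


if __name__ == "__main__":
    dm = ToyTextDataModule(batch_size=2)
    print([len(batch) for batch in dm.train_dataloader()])  # [2, 1]
```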
step 2. Customize the LightningModule.
Parameters that you would pass to the original SetFit trainer and SetFitModel can be passed to the LightningModule instead, and you can manage them in a config file.
If you want to customize further, see here to find out how we implemented SetFit in PyTorch Lightning.
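lightning-hydra-template builds objects from config entries that carry a `_target_` key (via `hydra.utils.instantiate`), which is how values in a config file end up as constructor arguments of the LightningModule. The sketch below reimplements only that core idea with the standard library; it is a simplified illustration, not Hydra's actual behavior (Hydra also instantiates nested configs recursively, which this version does not). `types.SimpleNamespace` stands in for the LightningModule class, and the hyperparameter names mirror arguments of the original SetFit trainer, but the values and the exact keys used in this repository's config files are assumptions.

```python
import importlib


def instantiate(cfg):
    """Simplified, non-recursive stand-in for hydra.utils.instantiate.

    Imports the class named by cfg["_target_"] and calls it with the
    remaining keys as keyword arguments.
    """
    kwargs = dict(cfg)  # copy so the caller's config is not mutated
    module_name, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(**kwargs)


if __name__ == "__main__":
    # Hypothetical model config; SimpleNamespace stands in for the LightningModule.
    model_cfg = {
        "_target_": "types.SimpleNamespace",
        "num_iterations": 20,
        "batch_size": 16,
        "learning_rate": 2e-5,
    }
    module = instantiate(model_cfg)
    print(module.batch_size)  # 16
```

Keeping construction in config this way means a different head or learning rate is a one-line config change rather than a code change.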
step 3. Customize other options such as callbacks and loggers.
PyTorch Lightning offers useful callbacks and loggers for saving models, metrics, and so on. You can control which callbacks and loggers are used, and how they are called, in config files.
Note: if you want to use the ModelCheckpoint callback and the model head is a scikit-learn model (as in the sample code), use SetFitModelCheckpoint to save the model.
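As we understand it, the reason for a dedicated callback is that a standard checkpoint stores torch weights (the `state_dict`), while a scikit-learn head is not a torch module, so its fitted parameters must be written out separately. The sketch below illustrates only that separation, using a hypothetical logistic-regression head stored as plain JSON next to a dict that stands in for the body's `state_dict`. It is not how SetFitModelCheckpoint is implemented; the file names and helper functions are made up, and a real scikit-learn head would usually be serialized with pickle or joblib instead.

```python
import json
import math
import os
import tempfile


def predict(head, embedding):
    """Hypothetical logistic-regression head: label 1 if sigmoid(w.x + b) > 0.5."""
    z = sum(w * x for w, x in zip(head["coef"], embedding)) + head["intercept"]
    return int(1.0 / (1.0 + math.exp(-z)) > 0.5)


def save_checkpoint(dirpath, body_state, head):
    """Write the body weights and the non-torch head to separate files."""
    # Stand-in for torch.save(state_dict): the part a normal checkpoint covers.
    with open(os.path.join(dirpath, "body.json"), "w") as f:
        json.dump(body_state, f)
    # The head lives outside the state_dict, so it needs its own file.
    with open(os.path.join(dirpath, "head.json"), "w") as f:
        json.dump(head, f)


def load_checkpoint(dirpath):
    """Read back both parts written by save_checkpoint."""
    with open(os.path.join(dirpath, "body.json")) as f:
        body_state = json.load(f)
    with open(os.path.join(dirpath, "head.json")) as f:
        head = json.load(f)
    return body_state, head


if __name__ == "__main__":
    head = {"coef": [2.0, -1.0], "intercept": 0.0}
    with tempfile.TemporaryDirectory() as tmp:
        save_checkpoint(tmp, {"encoder.weight": [0.1, 0.2]}, head)
        _, restored = load_checkpoint(tmp)
    print(predict(restored, [1.0, 0.5]))  # 1
```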
step 4. Run training
Run
python src/train.py
You can also override the experiment configuration from the command line, as below:
python src/train.py trainer.max_epochs=1
step 5. Load the trained model
Since a SetFit model may include a scikit-learn head, load the trained model as shown in this notebook.
Others
Experiment management
To manage your experiments, add an experiment configuration to a config file like this one and run it as below:
python src/train.py experiment=example
For more information, this might be useful.
Hyperparameter optimization
If you want to run hyperparameter optimization, add a config file like this one and run it as below:
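An experiment config typically contains only the values you want to change; when you pass `experiment=example`, those values are layered over the default config. The sketch below imitates that layering with a recursive dict merge, as a simplification of what Hydra/OmegaConf composition does. The keys and values in `defaults` and `example_experiment` are hypothetical, not the contents of this repository's files.

```python
import copy


def deep_merge(base, override):
    """Return a new dict where values in override replace or extend those in base.

    Nested dicts are merged key by key; any other value in override wins outright.
    Neither input is modified.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


if __name__ == "__main__":
    defaults = {"seed": 12345, "trainer": {"max_epochs": 10, "accelerator": "gpu"}}
    example_experiment = {"seed": 42, "trainer": {"max_epochs": 3}}
    print(deep_merge(defaults, example_experiment))
    # {'seed': 42, 'trainer': {'max_epochs': 3, 'accelerator': 'gpu'}}
```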
python src/train.py -m hparams_search=setfit_optuna
For more information, this might be useful.
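With `-m` (multirun), Hydra launches one training run per sampled hyperparameter combination, and the `setfit_optuna` config presumably delegates the sampling to Optuna through Hydra's Optuna sweeper. As a framework-free illustration of that loop, the sketch below runs a plain random search and keeps the best score. This is a deliberate substitute: Optuna's default TPE sampler chooses new trials based on earlier results, whereas this sketch samples each trial independently. The search space and the toy `toy_objective` (standing in for a full training run that returns a validation metric) are hypothetical.

```python
import math
import random


def random_search(objective, space, n_trials, seed=0):
    """Sample n_trials configurations from space; return the best (score, params).

    space maps each hyperparameter name to a function that draws a value from an RNG.
    Higher objective values are treated as better (e.g. validation accuracy).
    """
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        params = {name: sample(rng) for name, sample in space.items()}
        score = objective(params)
        if best is None or score > best[0]:
            best = (score, params)
    return best


# Hypothetical search space: log-uniform learning rate, categorical batch size.
space = {
    "learning_rate": lambda rng: 10 ** rng.uniform(-6, -4),
    "batch_size": lambda rng: rng.choice([8, 16, 32]),
}


def toy_objective(params):
    """Hypothetical stand-in for a training run; peaks near lr=2e-5, batch_size=16."""
    lr_penalty = abs(math.log10(params["learning_rate"]) - math.log10(2e-5))
    bs_penalty = abs(params["batch_size"] - 16) / 16
    return -(lr_penalty + bs_penalty)


if __name__ == "__main__":
    score, params = random_search(toy_objective, space, n_trials=30, seed=1)
    print(round(score, 3), params)
```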
Contributions welcome
If you find any errors or have suggestions, feel free to let us know by opening a PR or an issue! Opinions on any content are welcome.
Appendix
This implementation is based on our experience adapting SetFit to JX PRESS Corporation's training template code.
JX PRESS Corporation created and uses this training template code to improve its team development capability and development speed.
For more information on JX PRESS's training template code, see "How we at JX PRESS Corporation devise team development for R&D, which tends to depend on specific individuals" and "PyTorch Lightning explained by a heavy user". (These blogs are currently available only in Japanese; please use a translation tool if needed. We would like to publish English translations someday.)