ChemBERTa

ChemBERTa: A collection of BERT-like models applied to chemical SMILES data for drug design, chemical modelling, and property prediction. To be presented at Baylearn and the Royal Society of Chemistry's Chemical Science Symposium.

Tutorial | arXiv ChemBERTa-2 Paper | arXiv ChemBERTa Paper | Poster | Abstract | BibTeX

License: MIT

Right now the notebooks all use the RoBERTa model (a variant of BERT) trained on the task of masked-language modelling (MLM). Training ran for 10 epochs, until the loss converged to around 0.26 on the ZINC 250k dataset. Model weights for ChemBERTa pre-trained on various datasets (ZINC 100k, ZINC 250k, PubChem 100k, PubChem 250k, PubChem 1M, PubChem 10M) are available via HuggingFace. We expect to release larger models pre-trained on even larger subsets of ZINC, ChEMBL, and PubChem in the near future.

This library is currently primarily a set of notebooks containing our pre-training and fine-tuning setup; it will be updated soon with model implementation and attention-visualization code, likely after the arXiv publication. Stay tuned!

I hope this is of use to developers, students, and researchers exploring the use of transformers and the attention mechanism for chemistry!

Citing Our Work

Please cite ChemBERTa-2's arXiv paper if you have used these models, notebooks, or examples in any way. The BibTeX entry is provided below:

@article{ahmad2022chemberta,
  title={{ChemBERTa-2}: Towards Chemical Foundation Models},
  author={Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath},
  journal={arXiv preprint arXiv:2209.01712},
  year={2022}
}

Example

You can load the tokenizer + model for MLM prediction tasks using the following code:

from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

# Any of the pre-trained model weights listed above will work here.
# (AutoModelForMaskedLM replaces the deprecated AutoModelWithLMHead.)
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Build a fill-mask pipeline for masked-token prediction on SMILES strings.
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

Todo: