PyTorch Tabular

PyTorch Tabular aims to make Deep Learning with tabular data easy and accessible for real-world use cases and research alike. The core principles behind the design of the library are:

It is built on the shoulders of giants like PyTorch (obviously) and PyTorch Lightning.

Installation

Although installing PyTorch Tabular also installs PyTorch, the recommended approach is to first install PyTorch yourself from the official PyTorch website, picking the right CUDA version for your machine.

Once you have PyTorch installed, use:

pip install -U "pytorch_tabular[extra]"

to install the complete library with the extra dependencies (Weights & Biases and Plotly).

Or:

pip install -U "pytorch_tabular"

for the bare essentials.

The sources for pytorch_tabular can be downloaded from the GitHub repo.

You can clone the public repository:

git clone https://github.com/manujosephv/pytorch_tabular

Once you have a copy of the source, you can install it with:

cd pytorch_tabular && pip install ".[extra]"
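Whichever route you take, a quick way to confirm what ended up in your environment is to ask the standard library's `importlib.metadata` for installed versions. This is a minimal sketch, not a PyTorch Tabular utility: the helper `installed_versions` is hypothetical, and it assumes the distribution names `torch`, `pytorch_tabular`, `wandb` (Weights & Biases) and `plotly`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed PyPI distribution names: PyTorch and the core library, plus the two
# optional extras (Weights & Biases and Plotly) pulled in by [extra].
CORE = ("torch", "pytorch_tabular")
EXTRAS = ("wandb", "plotly")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(CORE + EXTRAS).items():
        print(f"{name:16} {ver or 'not installed'}")
```

If `wandb` and `plotly` both report "not installed", you most likely have the bare-essentials install.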

Documentation

For complete documentation with tutorials, visit ReadTheDocs.

Available Models

Semi-Supervised Learning

Implement Custom Models

To implement new models, see the How to implement new models tutorial. It covers basic as well as advanced architectures.

Usage

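# Assumes `train`, `val` and `test` are pandas DataFrames holding the feature and target columns,
# and that `num_col_names` / `cat_col_names` list the continuous and categorical column names.
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)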

data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list.
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
)
optimizer_config = OptimizerConfig()

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation applied between layers
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
result = tabular_model.evaluate(test)
pred_df = tabular_model.predict(test)
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_model("examples/basic")
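In the usage example, `layers="1024-512-512"` describes the hidden layers as a hyphen-separated list of widths. To make that concrete, here is a pure-Python sketch that turns such a spec into `(in_features, out_features)` pairs for a stack of fully connected layers. `parse_layers`, `linear_shapes`, the input width of 40 and the 2-unit output head are hypothetical; this is not PyTorch Tabular's internal code, and the real model also inserts the chosen activation (and possibly other layers) between these linear layers.

```python
from typing import List, Tuple


def parse_layers(spec: str) -> List[int]:
    """Turn a hyphen-separated spec such as "1024-512-512" into [1024, 512, 512]."""
    widths = [int(part) for part in spec.split("-")]  # int("") raises on "1024--512"
    if any(w <= 0 for w in widths):
        raise ValueError(f"layer widths must be positive, got {spec!r}")
    return widths


def linear_shapes(input_dim: int, spec: str, output_dim: int) -> List[Tuple[int, int]]:
    """(in_features, out_features) for each fully connected layer, including the output head."""
    dims = [input_dim] + parse_layers(spec) + [output_dim]
    return list(zip(dims[:-1], dims[1:]))


# Hypothetical sizes: 40 input features after embedding, 2 output units.
print(linear_shapes(40, "1024-512-512", 2))
# → [(40, 1024), (1024, 512), (512, 512), (512, 2)]
```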

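The `auto_lr_find=True` flag in the usage example runs a learning-rate finder before training. The usual idea behind such a finder (an LR range test) is to train briefly while increasing the learning rate exponentially, record the loss at each step, and suggest a rate where the loss falls most steeply. The pure-Python sketch below shows only that selection step on a made-up loss curve; `exponential_lrs`, `suggest_lr` and all numbers are hypothetical, and the finder PyTorch Tabular actually runs (via PyTorch Lightning) may differ in details such as smoothing or ignoring the ends of the curve.

```python
import math
from typing import List, Sequence


def exponential_lrs(min_lr: float, max_lr: float, num_steps: int) -> List[float]:
    """Learning rates spaced evenly on a log scale from min_lr to max_lr."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    ratio = max_lr / min_lr
    return [min_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


def suggest_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the lr at the start of the segment where loss drops most steeply."""
    if len(lrs) != len(losses) or len(lrs) < 2:
        raise ValueError("need at least two (lr, loss) pairs of equal length")
    # Finite-difference slope of loss with respect to log(lr) between neighbours.
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    steepest = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[steepest]


lrs = exponential_lrs(1e-5, 1.0, 6)  # roughly 1e-5, 1e-4, ..., 1.0
losses = [0.70, 0.69, 0.60, 0.35, 0.30, 2.50]  # made-up curve that diverges at the end
print(suggest_lr(lrs, losses))
```

Picking the steepest descent rather than the lowest loss avoids choosing a rate right at the edge of divergence.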
Blogs

Future Roadmap (Contributions are Welcome)

  1. Integrate Optuna Hyperparameter Tuning
  2. Migrate Datamodule to Polars or NVTabular for faster data loading and to handle larger-than-RAM datasets.
  3. Add GaussRank as Feature Transformation
  4. Have a scikit-learn compatible API
  5. Enable support for multi-label classification
  6. Keep adding more architectures
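Roadmap item 3 refers to GaussRank, a rank-based transformation that maps a numeric feature onto an approximately standard normal distribution. The roadmap lists it as future work, so the sketch below is a standalone, pure-Python illustration of one common formulation (rank the values, rescale the ranks into (0, 1), then apply the inverse normal CDF via the standard library's `statistics.NormalDist`). The names `average_ranks` and `gauss_rank` and the tie handling (tied values share their average rank) are illustrative choices, not a specification of how the library would implement it.

```python
from statistics import NormalDist
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """0-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next sorted value ties with the current one.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def gauss_rank(values: Sequence[float]) -> List[float]:
    """Map each value to the standard-normal quantile of its (average) rank."""
    n = len(values)
    if n == 0:
        return []
    std_normal = NormalDist()  # mean 0, standard deviation 1
    # (rank + 0.5) / n keeps every probability strictly inside (0, 1),
    # so inv_cdf never receives 0 or 1.
    return [std_normal.inv_cdf((r + 0.5) / n) for r in average_ranks(values)]


print([round(z, 3) for z in gauss_rank([10.0, 1000.0, 3.0, 3.0, 50.0])])
```

Because only ranks matter, the outlier 1000.0 ends up as an ordinary high quantile instead of dominating the scale. For real feature columns, scikit-learn's `QuantileTransformer(output_distribution="normal")` offers a related, well-tested transform.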

Contributors

Manu Joseph (manujosephv), Jirka Borovec (Borda), Jinu Sunil (wsad1), Programador Artificial (ProgramadorArtificial), Soren Macbeth (sorenmacbeth), Chris Fonnesbeck (fonnesbeck), Snehil Chatterjee (snehilchatterjee), jxtrbtk, Abhishar Sinha (abhisharsinha), Andreas (ndrsfel), Charitarth Chugh (charitarthchugh), Earlee (EeyoreLee), JulianRein, Kushashwa Ravi Shrimali (krshrimali), Luca Actis Grosso (Actis92), Sterling G. Baird (sgbaird), Teck Meng (furyhawk), Yinyu Nie (yinyunie), YonyBresler, Liu Zhen (HernandoR)

Citation

If you use PyTorch Tabular for a scientific publication, we would appreciate citations to the published software and the following paper:

@misc{joseph2021pytorch,
      title={PyTorch Tabular: A Framework for Deep Learning with Tabular Data},
      author={Manu Joseph},
      year={2021},
      eprint={2104.13638},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@software{manu_joseph_2023_7554473,
  author       = {Manu Joseph and
                  Jinu Sunil and
                  Jiri Borovec and
                  Chris Fonnesbeck and
                  jxtrbtk and
                  Andreas and
                  JulianRein and
                  Kushashwa Ravi Shrimali and
                  Luca Actis Grosso and
                  Sterling G. Baird and
                  Yinyu Nie},
  title        = {manujosephv/pytorch\_tabular: v1.0.1},
  month        = jan,
  year         = 2023,
  publisher    = {Zenodo},
  version      = {v1.0.1},
  doi          = {10.5281/zenodo.7554473},
  url          = {https://doi.org/10.5281/zenodo.7554473}
}