

Overview | Tutorials | Examples | Installation | FAQ | API Docs | How to Cite



Welcome to ktrain

a "Swiss Army knife" for machine learning

News and Announcements


Overview

ktrain is a lightweight wrapper for the deep learning library TensorFlow Keras (and other libraries) to help build, train, and deploy neural networks and other machine learning models. Inspired by ML framework extensions like fastai and ludwig, ktrain is designed to make deep learning and AI more accessible and easier to apply for both newcomers and experienced practitioners. With only a few lines of code, ktrain allows you to easily and quickly:

Tutorials

Please see the following tutorial notebooks for a guide on how to use ktrain in your projects:

Some blog tutorials and other guides about ktrain are shown below:

ktrain: A Lightweight Wrapper for Keras to Help Train Neural Networks

BERT Text Classification in 3 Lines of Code

Text Classification with Hugging Face Transformers in TensorFlow 2 (Without Tears)

Build an Open-Domain Question-Answering System With BERT in 3 Lines of Code

Finetuning BERT using ktrain for Disaster Tweets Classification by Hamiz Ahmed

Indonesian NLP Examples with ktrain by Sandy Khosasi

Examples

Using ktrain on Google Colab? See these Colab examples:

Tasks such as text classification and image classification can be accomplished easily with only a few lines of code.

Example: Text Classification of IMDb Movie Reviews Using BERT [see notebook]

import ktrain
from ktrain import text as txt

# load data
(x_train, y_train), (x_test, y_test), preproc = txt.texts_from_folder('data/aclImdb', maxlen=500,
                                                                     preprocess_mode='bert',
                                                                     train_test_names=['train', 'test'],
                                                                     classes=['pos', 'neg'])

# load model
model = txt.text_classifier('bert', (x_train, y_train), preproc=preproc)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model,
                             train_data=(x_train, y_train),
                             val_data=(x_test, y_test),
                             batch_size=6)

# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate

# train using 1cycle learning rate schedule for 3 epochs
learner.fit_onecycle(2e-5, 3)
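fit_onecycle applies a 1cycle learning-rate schedule. As a rough illustration only (not ktrain's internal code), the sketch below computes one simple linear variant: the rate climbs from an assumed starting value of max_lr / 10 to the peak over the first half of training, then falls back over the second half. The helper one_cycle_lr, its div_factor default, and the 100 batches per epoch used in the example line are hypothetical.

```python
def one_cycle_lr(step: int, total_steps: int, max_lr: float, div_factor: float = 10.0) -> float:
    """Return the learning rate at `step` for a simple linear 1cycle schedule.

    Hypothetical helper for illustration: the rate rises from max_lr / div_factor
    to max_lr over the first half of training, then falls back over the second half.
    """
    if total_steps < 2:
        raise ValueError("total_steps must be at least 2")
    if not 0 <= step <= total_steps:
        raise ValueError("step must lie in [0, total_steps]")
    start_lr = max_lr / div_factor
    peak = total_steps // 2
    if step <= peak:
        # warm-up phase: linear increase toward max_lr
        return start_lr + (max_lr - start_lr) * step / peak
    # cool-down phase: linear decrease back toward start_lr
    return max_lr - (max_lr - start_lr) * (step - peak) / (total_steps - peak)


# Assumed: 3 epochs x 100 batches per epoch, peak rate 2e-5 as in fit_onecycle(2e-5, 3)
schedule = [one_cycle_lr(s, 300, 2e-5) for s in range(301)]
```

Some 1cycle variants also add a short final phase that pushes the rate even lower; that refinement is omitted here.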

Example: Classifying Images of Dogs and Cats Using a Pretrained ResNet50 Model [see notebook]

import ktrain
from ktrain import vision as vis

# load data
(train_data, val_data, preproc) = vis.images_from_folder(
                                              datadir='data/dogscats',
                                              data_aug = vis.get_data_aug(horizontal_flip=True),
                                              train_test_names=['train', 'valid'],
                                              target_size=(224,224), color_mode='rgb')

# load model
model = vis.image_classifier('pretrained_resnet50', train_data, val_data, freeze_layers=80)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data,
                             workers=8, use_multiprocessing=False, batch_size=64)

# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate

# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(1e-4, checkpoint_folder='/tmp/saved_weights')
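The comment above says autofit uses a triangular policy. Assuming this refers to Leslie Smith's triangular cyclical learning rate, a minimal sketch of that formula follows. triangular_lr is a hypothetical helper, and the base_lr of 1e-5 and step_size of 500 in the example line are illustrative assumptions rather than ktrain defaults.

```python
import math


def triangular_lr(iteration: int, step_size: int, base_lr: float, max_lr: float) -> float:
    """Triangular cyclical learning rate: oscillates linearly between base_lr and max_lr.

    step_size is the number of iterations in half a cycle (base -> max).
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    # x is 1 at the start/end of a cycle and 0 at its midpoint
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


# Two full cycles with a peak of 1e-4, as passed to autofit(1e-4, ...) above
lrs = [triangular_lr(i, step_size=500, base_lr=1e-5, max_lr=1e-4) for i in range(2001)]
```

The checkpointing, learning-rate reduction on plateau, and early stopping mentioned in the comment are not modeled in this sketch.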

Example: Sequence Labeling for Named Entity Recognition Using a Randomly Initialized Bidirectional LSTM-CRF Model [see notebook]

import ktrain
from ktrain import text as txt

# load data
(trn, val, preproc) = txt.entities_from_txt('data/ner_dataset.csv',
                                            sentence_column='Sentence #',
                                            word_column='Word',
                                            tag_column='Tag',
                                            data_format='gmb',
                                            use_char=True) # enable character embeddings

# load model
model = txt.sequence_tagger('bilstm-crf', preproc)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val)


# conventional training for 1 epoch using a learning rate of 0.001 (the Keras default for the Adam optimizer)
learner.fit(1e-3, 1)
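A sequence tagger like the one above predicts one tag per token. Assuming the GMB-style B-/I-/O tags used in this dataset (for example B-geo and I-geo), the hypothetical helper below, which is not part of ktrain, groups tagged tokens into entity spans. It treats an I- tag that does not continue the current entity as the start of a new one, which is one common lenient convention.

```python
def bio_to_spans(tokens, tags):
    """Group tokens with B-/I-/O tags into (entity_type, text) spans."""
    spans = []
    current_type, current_tokens = None, []

    def flush():
        # close the open entity, if any
        if current_type is not None:
            spans.append((current_type, " ".join(current_tokens)))

    for token, tag in zip(tokens, tags):
        if tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != current_type):
            flush()  # a new entity begins
            current_type, current_tokens = tag[2:], [token]
        elif tag.startswith("I-"):
            current_tokens.append(token)  # continue the current entity
        else:
            flush()  # "O" closes any open entity
            current_type, current_tokens = None, []
    flush()
    return spans


tokens = ["John", "Smith", "visited", "New", "York", "."]
tags = ["B-per", "I-per", "O", "B-geo", "I-geo", "O"]
entities = bio_to_spans(tokens, tags)
```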

Example: Node Classification on the Cora Citation Graph Using a GraphSAGE Model [see notebook]

import ktrain
from ktrain import graph as gr

# load data with supervision ratio of 10%
(trn, val, preproc) = gr.graph_nodes_from_csv(
    'cora.content',  # node attributes/labels
    'cora.cites',    # edge list
    sample_size=20,
    holdout_pct=None,
    holdout_for_inductive=False,
    train_pct=0.1,
    sep='\t')

# load model
model = gr.graph_node_classifier('graphsage', trn)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=64)


# find good learning rate
learner.lr_find(max_epochs=100) # briefly simulate training to find good learning rate
learner.lr_plot()               # visually identify best learning rate

# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(0.01, checkpoint_folder='/tmp/saved_weights')
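The GraphSAGE model above learns node representations by aggregating features from a sampled neighborhood. As an illustrative NumPy sketch of one layer with a mean aggregator (not ktrain's or StellarGraph's implementation), each node averages the features of a fixed number of uniformly sampled neighbors, concatenates that average with its own features, and applies a weight matrix followed by ReLU. sage_mean_layer is hypothetical, its sample_size is only assumed to play a role similar to the sample_size=20 argument above, and the per-layer normalization described in the GraphSAGE paper is omitted.

```python
import numpy as np


def sage_mean_layer(features, adjacency, weight, sample_size, rng):
    """Apply one GraphSAGE-style layer with a mean aggregator.

    features:  (num_nodes, d) array of input node features
    adjacency: list where adjacency[v] is a list of v's neighbor indices
    weight:    (d_out, 2 * d) matrix applied to [own features || neighbor mean]
    """
    outputs = []
    for v, neighbors in enumerate(adjacency):
        if neighbors:
            # uniformly sample a fixed-size neighborhood (with replacement, for simplicity)
            sampled = rng.choice(neighbors, size=sample_size, replace=True)
            neighbor_mean = features[sampled].mean(axis=0)
        else:
            # isolated node: nothing to aggregate
            neighbor_mean = np.zeros(features.shape[1])
        combined = np.concatenate([features[v], neighbor_mean])
        outputs.append(np.maximum(weight @ combined, 0.0))  # ReLU
    return np.stack(outputs)


# toy graph: node 0 is linked to nodes 1 and 2
rng = np.random.default_rng(0)
features = np.eye(3)
adjacency = [[1, 2], [0], [0]]
weight = np.ones((2, 6))
hidden = sage_mean_layer(features, adjacency, weight, sample_size=20, rng=rng)
```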

Example: Text Classification with Hugging Face Transformers on the 20 Newsgroups Dataset Using DistilBERT [see notebook]

# load text data
from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test', categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)

# build, train, and validate model (Transformer is a wrapper around the Hugging Face transformers library)
import ktrain
from ktrain import text
MODEL_NAME = 'distilbert-base-uncased'
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 4)
learner.validate(class_names=t.get_classes()) # class_names must be string values

# Output from learner.validate()
#                        precision    recall  f1-score   support
#
#           alt.atheism       0.92      0.93      0.93       319
#         comp.graphics       0.97      0.97      0.97       389
#               sci.med       0.97      0.95      0.96       396
#soc.religion.christian       0.96      0.96      0.96       398
#
#              accuracy                           0.96      1502
#             macro avg       0.95      0.96      0.95      1502
#          weighted avg       0.96      0.96      0.96      1502
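The two averages in the report weight classes differently: macro avg is the plain mean of the per-class scores, while weighted avg weights each class by its support (this is how scikit-learn's classification_report defines them). The sketch below recomputes both F1 averages from the rounded per-class numbers shown; averaged_f1 is a hypothetical helper, and because its inputs are already rounded, results can differ slightly from the printed averages.

```python
def averaged_f1(f1_scores, supports):
    """Return (macro_avg, weighted_avg) of per-class F1 scores."""
    if not f1_scores or len(f1_scores) != len(supports):
        raise ValueError("need one support count per class")
    # macro: every class counts equally
    macro = sum(f1_scores) / len(f1_scores)
    # weighted: each class counts in proportion to its number of test examples
    weighted = sum(f * n for f, n in zip(f1_scores, supports)) / sum(supports)
    return macro, weighted


# rounded per-class F1 scores and supports copied from the report above
f1 = [0.93, 0.97, 0.96, 0.96]
support = [319, 389, 396, 398]
macro, weighted = averaged_f1(f1, support)
```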

Example: Tabular Classification for Titanic Survival Prediction Using an MLP [see notebook]

import ktrain
from ktrain import tabular
import pandas as pd
train_df = pd.read_csv('train.csv', index_col=0)
train_df = train_df.drop(columns=['Name', 'Ticket', 'Cabin'])  # drop columns not used as features
trn, val, preproc = tabular.tabular_from_df(train_df, label_columns=['Survived'], random_state=42)
learner = ktrain.get_learner(tabular.tabular_classifier('mlp', trn), train_data=trn, val_data=val)
learner.lr_find(show_plot=True, max_epochs=5) # estimate learning rate
learner.fit_onecycle(5e-3, 10)

# evaluate held-out labeled test set
tst = preproc.preprocess_test(pd.read_csv('heldout.csv', index_col=0))
learner.evaluate(tst, class_names=preproc.get_classes())
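lr_find briefly simulates training while the learning rate grows, and lr_plot shows loss against learning rate. The toy below sketches that general idea (a learning-rate range test) on a one-parameter quadratic so it runs without ktrain or data. The function name, the toy objective, and the default parameter values are assumptions and are not claimed to match ktrain's.

```python
def lr_range_test(start_lr=1e-4, lr_mult=1.1, stop_factor=4.0, max_steps=1000):
    """Toy learning-rate range test on f(w) = (w - 3)**2 trained by gradient descent.

    Each step applies one update at the current learning rate, records the loss,
    then multiplies the learning rate by lr_mult. The test stops once the loss
    exceeds stop_factor times the best loss seen so far (training has diverged).
    """
    w, lr = 0.0, start_lr
    best_loss = float("inf")
    history = []  # (learning_rate, loss) pairs, i.e. what a loss-vs-LR plot would show
    for _ in range(max_steps):
        grad = 2.0 * (w - 3.0)
        w -= lr * grad
        loss = (w - 3.0) ** 2
        history.append((lr, loss))
        if loss > stop_factor * best_loss:
            break
        best_loss = min(best_loss, loss)
        lr *= lr_mult
    return history


history = lr_range_test()
```

For this quadratic, plain gradient descent diverges once the rate exceeds 1.0, so the recorded loss keeps falling until roughly that point and then blows up; on a plot, one would pick a rate somewhat below where the loss starts rising.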

Additional examples can be found here.

Installation

  1. Make sure pip is up-to-date with: pip install -U pip

  2. Install TensorFlow 2 if it is not already installed (e.g., pip install tensorflow).

  3. Install ktrain: pip install ktrain

  4. If using tensorflow>=2.16:

    • Install tf_keras: pip install tf_keras
    • Set the environment variable TF_USE_LEGACY_KERAS to true before importing ktrain
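For tensorflow>=2.16, the step above requires TF_USE_LEGACY_KERAS to be set before ktrain is imported. One way to do that from Python is sketched below. It assumes TensorFlow was installed under the pip distribution name tensorflow (variants such as tensorflow-cpu would not be detected), and needs_legacy_keras is a hypothetical helper.

```python
import os
from importlib.metadata import PackageNotFoundError, version


def needs_legacy_keras(tf_version: str) -> bool:
    """True if a TensorFlow version string such as '2.16.1' is 2.16 or newer."""
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return (major, minor) >= (2, 16)


try:
    installed = version("tensorflow")  # assumes the distribution is named "tensorflow"
except PackageNotFoundError:
    installed = None

if installed is not None and needs_legacy_keras(installed):
    # Must happen before TensorFlow (and therefore ktrain) is first imported.
    os.environ["TF_USE_LEGACY_KERAS"] = "true"

# import ktrain  # import only after the variable has been set
```

Exporting the variable in the shell before launching Python (for example, export TF_USE_LEGACY_KERAS=true on Linux) works as well.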

The above should be all you need on Linux systems and cloud computing environments like Google Colab and AWS EC2. If you are using ktrain on a Windows computer, you can follow these more detailed instructions that include some extra steps.

Notes About TensorFlow Versions

Additional Notes About Installation

# for graph module:
pip install https://github.com/amaiya/stellargraph/archive/refs/heads/no_tf_dep_082.zip
# for text.TextPredictor.explain and vision.ImagePredictor.explain:
pip install https://github.com/amaiya/eli5-tf/archive/refs/heads/master.zip
# for tabular.TabularPredictor.explain:
pip install shap
# for text.zsl (ZeroShotClassifier), text.summarization, text.translation, text.speech:
pip install torch
# for text.speech:
pip install librosa
# for tabular.causal_inference_model:
pip install causalnlp
# for text.summarization.core.LexRankSummarizer:
pip install sumy
# for text.kw.KeywordExtractor
pip install textblob
# for text.generative_ai
pip install onprem
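Because each of these extras is optional, it can help to check which ones are importable before using the corresponding feature. The sketch below uses importlib.util.find_spec from the standard library. The OPTIONAL_MODULES mapping and the missing_extras helper are hypothetical, and the module names are assumptions inferred from the pip commands above (for instance, that the eli5-tf fork installs a module named eli5).

```python
from importlib.util import find_spec

# Hypothetical mapping from a ktrain feature to the top-level module its extra provides.
OPTIONAL_MODULES = {
    "graph module": "stellargraph",
    "text/vision explain": "eli5",
    "tabular explain": "shap",
    "zsl/summarization/translation/speech": "torch",
    "speech": "librosa",
    "causal inference": "causalnlp",
    "LexRank summarizer": "sumy",
    "keyword extraction": "textblob",
    "generative AI": "onprem",
}


def missing_extras(modules=OPTIONAL_MODULES):
    """Return the features whose optional module cannot be found on this machine."""
    return sorted(feature for feature, module in modules.items() if find_spec(module) is None)


if __name__ == "__main__":
    for feature in missing_extras():
        print(f"not installed: {OPTIONAL_MODULES[feature]} (needed for {feature})")
```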

Feature | TensorFlow | PyTorch | Sklearn
Training any neural network (e.g., text or image classification)
End-to-End Question-Answering (pretrained)
QA-Based Information Extraction (pretrained)
Zero-Shot Classification (pretrained)
Language Translation (pretrained)
Summarization (pretrained)
Speech Transcription (pretrained)
Image Captioning (pretrained)
Object Detection (pretrained)
Sentiment Analysis (pretrained)
Generative AI (sentence-transformers)
Topic Modeling (sklearn)
Keyphrase Extraction (textblob/nltk/sklearn)

As noted above, end-to-end question-answering and information extraction in ktrain can be used with either TensorFlow (using framework='tf') or PyTorch (using framework='pt').


How to Cite

Please cite the following paper when using ktrain:

@article{maiya2020ktrain,
    title={ktrain: A Low-Code Library for Augmented Machine Learning},
    author={Arun S. Maiya},
    year={2020},
    eprint={2004.10703},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    journal={arXiv preprint arXiv:2004.10703},
}


Creator: Arun S. Maiya

Email: arun [at] maiya [dot] net