pytorch-ctc

PyTorch-CTC is an implementation of CTC (Connectionist Temporal Classification) beam search decoding for PyTorch. The C++ code is borrowed liberally from TensorFlow, with some improvements to increase flexibility.

Installation

The library is largely self-contained and requires only PyTorch and CFFI. Building the C++ library requires GCC 5 or later. If gcc-5 or later is not your default compiler, you may specify its path via the CC and CXX environment variables. KenLM language model support is optional and enabled by default.

# get the code
git clone --recursive https://github.com/ryanleary/pytorch-ctc.git
cd pytorch-ctc

# install dependencies (PyTorch and CFFI)
pip install -r requirements.txt

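# build the extension and install the python package (requires gcc-5 or later)
# if gcc-5 or later is already your default compiler, run:
#   python setup.py install
# otherwise, point CC and CXX at a suitable compiler:
CC=/path/to/gcc-5 CXX=/path/to/g++-5 python setup.py install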

# If you do NOT require kenlm, the `--recursive` flag is not required on git clone
# and `--exclude-kenlm` should be appended to the `python setup.py install` command

API

pytorch-ctc includes a CTC beam search decoder with multiple scorer implementations. A scorer is a function that the decoder calls to condition the probability of a given beam based on its state.
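
To make the scorer idea concrete, here is a minimal pure-Python sketch of a decoder consulting a scorer to adjust beam scores. It does not use the library's C++ interface; the names NoOpScorer, VocabBonusScorer, adjust, and rank_beams are hypothetical and exist only for illustration.

```python
import math


class NoOpScorer:
    """Hypothetical scorer that leaves every beam's score unchanged."""

    def adjust(self, prefix, log_prob):
        return log_prob


class VocabBonusScorer:
    """Hypothetical scorer that rewards beams whose last completed word is known."""

    def __init__(self, vocab, bonus):
        self.vocab = set(vocab)
        self.bonus = bonus

    def adjust(self, prefix, log_prob):
        words = prefix.split()
        # Only score a word once it is complete, i.e. the prefix ends with a space.
        if prefix.endswith(" ") and words and words[-1] in self.vocab:
            return log_prob + self.bonus
        return log_prob


def rank_beams(beams, scorer):
    """Re-rank (prefix, log_prob) beams after the scorer adjusts each score."""
    rescored = [(prefix, scorer.adjust(prefix, lp)) for prefix, lp in beams]
    return sorted(rescored, key=lambda beam: beam[1], reverse=True)


beams = [("the cat ", math.log(0.4)), ("the kat ", math.log(0.5))]
plain = rank_beams(beams, NoOpScorer())
biased = rank_beams(beams, VocabBonusScorer({"the", "cat"}, bonus=0.5))
```

With the no-op scorer the acoustically stronger "the kat " stays on top; with the vocabulary bonus, "the cat " overtakes it.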

Scorers

Two scorer implementations are currently provided with pytorch-ctc.

Scorer: a no-op scorer that lets the decoder perform a vanilla beam decode.

scorer = Scorer()

KenLMScorer: conditions beams based on the provided KenLM binary language model.

scorer = KenLMScorer(labels, lm_path, trie_path, blank_index=0, space_index=28)

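where lm_path is the path to the KenLM binary language model, trie_path is the path to the vocabulary trie (see generate_lm_trie under Utilities), and blank_index and space_index are the positions of the blank and space symbols in labels.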

The KenLMScorer may be further configured with weights for the language model contribution to the score (lm_weight), as well as word and valid word bonuses (to offset decreasing probability as a function of sequence length).

scorer.set_lm_weight(2.1)
scorer.set_word_weight(1.1)
scorer.set_valid_word_weight(1.5)
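
This section does not say exactly how the library combines these weights. As an illustration only, the sketch below assumes a common log-linear form in which the language model score is scaled by lm_weight and each word (and each valid word) adds a fixed bonus. The function combined_score and its arguments are hypothetical and are not part of pytorch-ctc.

```python
def combined_score(acoustic_logp, lm_logp, n_words, n_valid_words,
                   lm_weight=2.1, word_weight=1.1, valid_word_weight=1.5):
    """Assumed log-linear combination; not necessarily the library's formula."""
    return (acoustic_logp
            + lm_weight * lm_logp
            + word_weight * n_words
            + valid_word_weight * n_valid_words)


# A one-word hypothesis and a three-word hypothesis with lower raw log-probabilities.
short = combined_score(-4.0, -3.0, n_words=1, n_valid_words=1)
long_ = combined_score(-7.0, -5.0, n_words=3, n_valid_words=3)

# The same hypotheses with both bonuses switched off.
short_plain = combined_score(-4.0, -3.0, 1, 1, word_weight=0.0, valid_word_weight=0.0)
long_plain = combined_score(-7.0, -5.0, 3, 3, word_weight=0.0, valid_word_weight=0.0)
```

Under this assumed formula, the per-word bonuses shrink the penalty the longer hypothesis pays for its length: the gap between the two scores narrows from about 7.2 without bonuses to about 2.0 with them.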

Decoder

decoder = CTCBeamDecoder(scorer, labels, top_paths=3, beam_width=20,
                         blank_index=0, space_index=28, merge_repeated=False)

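where scorer is one of the scorers described above, labels is the set of output labels, top_paths is the number of hypotheses to return, beam_width is the number of beams kept at each step, and blank_index and space_index are the positions of the blank and space symbols in labels. merge_repeated controls whether repeated labels in the output are merged.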
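
For readers unfamiliar with how beam_width, top_paths, and blank_index interact, here is a toy CTC prefix beam search in pure Python. It is a sketch of the general technique, not the library's C++ implementation: it works on plain probabilities in nested lists, and it omits the scorer, space_index, and merge_repeated. The function name ctc_beam_search is hypothetical.

```python
from collections import defaultdict


def ctc_beam_search(probs, labels, beam_width=20, top_paths=3, blank_index=0):
    """Toy CTC prefix beam search over per-frame probabilities (lists of floats)."""
    # Each beam maps a label prefix to [p_blank, p_non_blank]: the probability
    # mass of alignments for that prefix ending in blank or in a non-blank label.
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        next_beams = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for idx, p in enumerate(frame):
                if idx == blank_index:
                    # A blank leaves the prefix unchanged.
                    next_beams[prefix][0] += (p_b + p_nb) * p
                elif prefix and prefix[-1] == idx:
                    # A repeated label extends the prefix only after a blank;
                    # otherwise it collapses into the existing last label.
                    next_beams[prefix + (idx,)][1] += p_b * p
                    next_beams[prefix][1] += p_nb * p
                else:
                    next_beams[prefix + (idx,)][1] += (p_b + p_nb) * p
        # Prune to the beam_width most probable prefixes.
        ranked = sorted(next_beams.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = {prefix: tuple(ps) for prefix, ps in ranked[:beam_width]}
    best = sorted(beams.items(), key=lambda kv: sum(kv[1]), reverse=True)[:top_paths]
    return [("".join(labels[i] for i in prefix), sum(ps)) for prefix, ps in best]


# Three frames over the labels "_ab", where index 0 is the blank symbol.
frames = [[0.1, 0.8, 0.1],
          [0.8, 0.1, 0.1],
          [0.1, 0.1, 0.8]]
hypotheses = ctc_beam_search(frames, "_ab", beam_width=20, top_paths=3)
```

Here the top hypothesis is "ab", whose score sums every alignment that collapses to it (such as a, blank, b and a, a, b). A smaller beam_width prunes more prefixes per frame, trading accuracy for speed.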

output, score, out_seq_len = decoder.decode(probs, sizes=None)

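where probs is the model's per-time-step output over the labels and sizes, if given, holds the length of each sequence in the batch. The call returns the decoded label sequences (output), their scores (score), and the length of each decoded sequence (out_seq_len).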
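
The decoder returns label indices rather than text, so a small post-processing step is usually needed. The sketch below assumes output is indexed as [path][batch item][time step] and out_seq_len as [path][batch item]; that layout is an assumption, so check it against the tensors you actually get back. Plain nested lists stand in for tensors, and to_strings is a hypothetical helper.

```python
def to_strings(output, out_seq_len, labels):
    """Map decoded label indices to text, trimming each row to its true length."""
    results = []
    for path_rows, path_lens in zip(output, out_seq_len):
        results.append(["".join(labels[i] for i in row[:n])
                        for row, n in zip(path_rows, path_lens)])
    return results


labels = "_abc "                          # hypothetical label string; index 0 is blank
output = [[[1, 2, 0, 0], [3, 1, 2, 0]]]   # 1 path, 2 utterances, zero-padded
out_seq_len = [[2, 3]]                    # valid length of each decoded row
texts = to_strings(output, out_seq_len, labels)
```

Trimming with out_seq_len matters because padded positions would otherwise be rendered as spurious label characters.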

Utilities

generate_lm_trie(dictionary_path, kenlm_path, output_path, labels, blank_index, space_index)

A vocabulary trie is required for the KenLMScorer. generate_lm_trie creates the trie from a lexicon, given as a newline-separated text file of the words in the vocabulary.
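
To illustrate what a vocabulary trie provides during decoding, here is a small pure-Python sketch that builds a character trie from newline-separated lexicon text and checks whether a partial word can still become a vocabulary word. It shows the data structure only; the on-disk format written by generate_lm_trie is not described here, and build_trie and is_valid_prefix are hypothetical helpers.

```python
def build_trie(words):
    """Build a nested-dict character trie; the '$' key marks a complete word."""
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node["$"] = True
    return root


def is_valid_prefix(trie, prefix):
    """Return True if some vocabulary word starts with the given prefix."""
    node = trie
    for ch in prefix:
        if ch not in node:
            return False
        node = node[ch]
    return True


# Stands in for the contents of a newline-separated lexicon file.
lexicon_text = "cat\ncar\ndog\n"
vocab = [line.strip() for line in lexicon_text.splitlines() if line.strip()]
trie = build_trie(vocab)
```

A decoder can use such a lookup to tell beams that may still complete a valid word from those that cannot, without scanning the whole vocabulary at each step.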

Acknowledgements

Thanks to ebrevdo for the original TensorFlow CTC decoder implementation, timediv for his KenLM extension, and SeanNaren for his assistance.