MXNet Gluon Dynamic-batching

This repository includes a simplified implementation of Fold, a helper for dynamic batching.

This animation from TensorFlow Fold shows a recursive neural network run with dynamic batching. Operations with the same signature at the same depth in the computation graph (indicated by color in the animation) are batched together, regardless of whether they appear in the same sample.
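The grouping rule described above can be sketched in plain Python. This is a minimal illustration, not the library's code: the `Node` tuple, the `group_by_depth_and_signature` helper, and the operation names are all hypothetical.

```python
from collections import defaultdict, namedtuple

# Hypothetical record of one operation in one sample's computation graph.
Node = namedtuple("Node", ["sample", "depth", "signature"])

def group_by_depth_and_signature(nodes):
    """Bucket nodes so that each bucket could run as one batched call."""
    groups = defaultdict(list)
    for node in nodes:
        groups[(node.depth, node.signature)].append(node)
    return dict(groups)

# Two samples (trees) with different shapes.
nodes = [
    Node("a", 0, "embed"), Node("a", 0, "embed"), Node("a", 1, "compose"),
    Node("b", 0, "embed"), Node("b", 0, "embed"), Node("b", 0, "embed"),
    Node("b", 1, "compose"), Node("b", 2, "compose"),
]
groups = group_by_depth_and_signature(nodes)
for key in sorted(groups):
    print(key, len(groups[key]))
```

In this toy input, eight individual operations collapse into three groups, and the groups at depths 0 and 1 each mix nodes from both samples.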

Usage

This library performs dynamic batching by pooling the computation for multiple samples and merging the shared operators through concatenation. For example:

import fold
import mxnet as mx
from mxnet.gluon import nn

embed_layer = nn.Embedding(10, 1)
embed_layer.initialize()
fold_pool = fold.Fold()
lazy_values = []
lazy_results = []
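# Record 15 embedding lookups lazily.
for i in range(15):
    lazy_values.append(fold_pool.record(0, embed_layer, mx.nd.array([i % 10])))
# Concatenate the lazy values into a single value that is shared rather than batched.
shared_value = fold_pool.record(0, mx.nd.concat, *lazy_values).no_batch()
# Record 100 dot products, each against the shared value.
for i in range(100):
    lazy_results.append(fold_pool.record(0, mx.nd.dot, lazy_values[i % 10], shared_value))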

# collect actual results
actual_result = fold_pool([lazy_results])[0]
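The record-then-run pattern in the example above can be illustrated without MXNet. The following is a rough sketch under simplifying assumptions, not how `fold.Fold` is implemented: `ToyPool` and `batched_square` are hypothetical, each operation takes a single input, and dependencies between lazy values are not handled.

```python
class ToyPool:
    """Hypothetical stand-in for a dynamic-batching pool (not fold.Fold)."""

    def __init__(self):
        self.pending = {}  # batched op -> list of recorded inputs

    def record(self, batched_op, value):
        """Queue `value` for `batched_op` and return a handle to its future result."""
        queue = self.pending.setdefault(batched_op, [])
        queue.append(value)
        return (batched_op, len(queue) - 1)

    def run(self):
        """Call each op exactly once on all of its queued inputs."""
        return {op: op(values) for op, values in self.pending.items()}

calls = []

def batched_square(values):
    calls.append(len(values))  # track how often the op is actually invoked
    return [v * v for v in values]

pool = ToyPool()
handles = [pool.record(batched_square, x) for x in [1, 2, 3, 4]]
results = pool.run()
outputs = [results[op][index] for op, index in handles]
print(outputs, calls)
```

Four recorded calls result in one invocation of the operation on a merged input, and each handle then picks its own slice out of the batched result.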

Also, for use with Gluon RNN cells, there is the fold_unroll function:

import random
cell = mx.gluon.rnn.LSTMCell(20)
fold_pool = fold.Fold()
batch_size = 5
for _ in range(batch_size):
    length = random.randint(1, 5)
    fold.fold_unroll(cell, fold_pool, length, mx.nd.random.uniform(shape=(length, 1, 10)),
                     layout='TNC')

# show the complete computation graph info
print(fold_pool)
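To see why batching helps with variable-length sequences, the sketch below counts how many sequences are still active at each time step. It assumes that step t of every sequence ends up in the same batched group, which is an assumption made for illustration rather than a documented property of `fold_unroll`; the helper `active_per_step` and the lengths are hypothetical.

```python
def active_per_step(lengths):
    """Return, for each time step, how many sequences are still running."""
    return [sum(1 for n in lengths if n > t) for t in range(max(lengths))]

lengths = [3, 1, 5, 2, 5]  # hypothetical sequence lengths for a batch of 5
batch_sizes = active_per_step(lengths)
print(batch_sizes)
print(sum(lengths), "cell calls unbatched vs", len(batch_sizes), "batched")
```

Under that assumption, the number of cell invocations is bounded by the longest sequence rather than by the total number of steps across all samples.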

Performance

The following results are obtained from an MXNet Gluon implementation of treelstm.pytorch, which can be found in the example/tree_lstm folder. It implements the Child-Sum Tree-LSTM of Tai et al. on the semantic-relatedness task on the SICK dataset. The performance is evaluated on:

The following speed benchmark is performed with the following settings:

The following results are obtained on an EC2 c4.8xlarge host.

| Implementation | Training | Inference |
| --- | --- | --- |
| MXNet Gluon w/o Fold | 33.77 samples/s | 50.46 samples/s |
| MXNet Gluon w/o Fold Hybridized | 66.79 samples/s | 131.86 samples/s |
| MXNet Gluon w/ Fold | 201.11 samples/s | 315.54 samples/s |
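The relative speedups implied by these figures can be worked out with a few lines of Python. The throughput numbers are copied from the results above; the dictionary layout is only for illustration.

```python
# (training, inference) throughput in samples/s, taken from the results above.
throughput = {
    "w/o Fold": (33.77, 50.46),
    "w/o Fold Hybridized": (66.79, 131.86),
    "w/ Fold": (201.11, 315.54),
}

fold_train, fold_infer = throughput["w/ Fold"]
speedups = {}
for name, (train, infer) in throughput.items():
    if name != "w/ Fold":
        speedups[name] = (fold_train / train, fold_infer / infer)
        print(f"{name}: {speedups[name][0]:.2f}x training, "
              f"{speedups[name][1]:.2f}x inference")
```

With Fold, throughput is roughly 6x the non-hybridized baseline for both training and inference, and roughly 3x (training) and 2.4x (inference) the hybridized baseline.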