MXNet Gluon Dynamic-batching
This repository includes a simplified implementation of Fold, a helper for dynamic batching.
This animation from TensorFlow Fold shows a recursive neural network run with dynamic batching. Operations with the same signature at the same depth in the computation graph (indicated by color in the animation) are batched together, regardless of whether or not they appear in the same sample.
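To make the idea concrete, here is a minimal pure-Python sketch of depth-wise batching. It is not this repository's fold.Fold: MiniFold, Node, batched_square and batched_add are hypothetical names used only for illustration. Each recorded operation gets a depth one greater than its deepest lazy argument. Operations are grouped by (depth, operation), and each group runs as a single call whose inputs are the per-sample arguments gathered into lists, which stands in for concatenation along the batch axis.

```python
from collections import defaultdict


class Node:
    """Hypothetical lazy value: fn applied to args once the pool runs."""

    def __init__(self, fn, args):
        self.fn = fn
        self.args = args
        # One deeper than the deepest lazy argument; ops with only concrete inputs sit at depth 0.
        self.depth = max((a.depth + 1 for a in args if isinstance(a, Node)), default=0)
        self.value = None


class MiniFold:
    """Toy dynamic-batching pool; a sketch, not the repository's fold.Fold."""

    def __init__(self):
        self.nodes = []

    def record(self, fn, *args):
        # Only remember the call; evaluation is deferred to run().
        node = Node(fn, args)
        self.nodes.append(node)
        return node

    def run(self):
        # Group recorded ops by (depth, operation), regardless of which sample they came from.
        groups = defaultdict(list)
        for node in self.nodes:
            groups[(node.depth, node.fn)].append(node)
        calls = 0
        # Shallower groups first, so every lazy argument already has a value.
        for depth, fn in sorted(groups, key=lambda k: k[0]):
            members = groups[(depth, fn)]
            resolved = [[a.value if isinstance(a, Node) else a for a in n.args] for n in members]
            # Gather argument i of every member into one list (stand-in for concatenation).
            columns = [list(col) for col in zip(*resolved)]
            outputs = fn(*columns)  # one call for the whole group
            calls += 1
            # Split the batched output back to the individual nodes.
            for node, out in zip(members, outputs):
                node.value = out
        return calls


def batched_square(xs):
    return [x * x for x in xs]


def batched_add(xs, ys):
    return [x + y for x, y in zip(xs, ys)]


pool = MiniFold()
roots = []
# Two samples, each computing x*x + y*y.
for x, y in [(1, 2), (3, 4)]:
    sx = pool.record(batched_square, x)
    sy = pool.record(batched_square, y)
    roots.append(pool.record(batched_add, sx, sy))

calls = pool.run()
print(calls, [r.value for r in roots])  # 2 [5, 25]
```

The four squarings come from two different samples but run in one call, as do the two additions, so six recorded operations need only two batched calls.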
Usage
This library performs dynamic batching by pooling the computation for multiple samples and merging the shared operators through concatenation. For example:
```python
import fold
import mxnet as mx
from mxnet.gluon import nn

embed_layer = nn.Embedding(10, 1)
embed_layer.initialize()

fold_pool = fold.Fold()
lazy_values = []
lazy_results = []
# Record 15 embedding lookups in the pool.
for i in range(15):
    lazy_values.append(fold_pool.record(0, embed_layer, mx.nd.array([i % 10])))
# Concatenate all lookups into a single shared value.
shared_value = fold_pool.record(0, mx.nd.concat, *lazy_values).no_batch()
# Record 100 dot products that reuse the lazy values and the shared value.
for i in range(100):
    lazy_results.append(fold_pool.record(0, mx.nd.dot, lazy_values[i % 10], shared_value))
# Run the pooled computation and collect the actual results
actual_result = fold_pool([lazy_results])[0]
```
For use with Gluon RNN cells, the library also provides the fold_unroll function:
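```python
import random

import fold
import mxnet as mx

cell = mx.gluon.rnn.LSTMCell(20)
cell.initialize()
fold_pool = fold.Fold()
batch_size = 5
# Unroll the cell over 5 sequences of random length (1 to 5 steps);
# each input has shape (length, 1, 10) in TNC layout.
for _ in range(batch_size):
    length = random.randint(1, 5)
    fold.fold_unroll(cell, fold_pool, length,
                     mx.nd.random.uniform(shape=(length, 1, 10)), layout='TNC')
# Show the complete computation graph info
print(fold_pool)
```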
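The benefit for variable-length sequences can be estimated with a short count. This sketch simplifies by assuming that step t of every sequence lands at the same depth, and it treats one cell step as a single operation; the lengths list is a hypothetical draw for the five samples above, not output from fold_unroll.

```python
# Hypothetical sequence lengths for 5 samples (the example draws them with random.randint(1, 5)).
lengths = [3, 1, 5, 2, 4]

# Without batching: one cell call per time step per sample.
unbatched_calls = sum(lengths)

# With depth-wise batching (assumed): one batched call per time step,
# whose batch holds the samples that are still running at step t.
batch_per_step = [sum(1 for n in lengths if n > t) for t in range(max(lengths))]
batched_calls = len(batch_per_step)

print(unbatched_calls, batched_calls, batch_per_step)  # 15 5 [5, 4, 3, 2, 1]
```

Under this assumption the number of cell calls drops from the total number of steps to the length of the longest sequence, while the batch shrinks as shorter sequences finish.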
Performance
The following results were obtained with the MXNet Gluon implementation of treelstm.pytorch, found in the example/tree_lstm folder. It implements the Child-Sum Tree-LSTM of Tai et al. for the semantic-relatedness task on the SICK dataset. Performance is evaluated on:
- EC2 c4.8xlarge with:
  - Intel Xeon E5-2666 v3 (Haswell) processors
  - Ubuntu 16.04
The speed benchmark uses the following settings:
- Batch size: 256
- MXNet:
  - mxnet-mkl==0.12.0b20171030 for the CPU test
Results on the EC2 c4.8xlarge host:
| Implementation | Training (samples/s) | Inference (samples/s) |
|---|---|---|
| MXNet Gluon w/o Fold | 33.77 | 50.46 |
| MXNet Gluon w/o Fold, hybridized | 66.79 | 131.86 |
| MXNet Gluon w/ Fold | 201.11 | 315.54 |
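The relative speedups follow directly from the throughput numbers in the table above; this small script only restates those measurements as ratios.

```python
# Throughput (samples/s) copied from the benchmark table.
throughput = {
    "w/o Fold": {"training": 33.77, "inference": 50.46},
    "w/o Fold, hybridized": {"training": 66.79, "inference": 131.86},
    "w/ Fold": {"training": 201.11, "inference": 315.54},
}

# Speedup of the Fold version relative to each baseline.
speedups = {
    (baseline, phase): throughput["w/ Fold"][phase] / throughput[baseline][phase]
    for baseline in ("w/o Fold", "w/o Fold, hybridized")
    for phase in ("training", "inference")
}

for (baseline, phase), ratio in speedups.items():
    print(f"{phase} vs {baseline}: {ratio:.2f}x")
```

Fold runs roughly 6x faster than plain Gluon in both training and inference, and about 3.0x (training) and 2.4x (inference) faster than the hybridized version.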