TensorNets

High-level network definitions with pre-trained weights in TensorFlow (tested with 2.1.0 >= TF >= 1.4.0).

Guiding principles

Installation

You can install TensorNets from PyPI (pip install tensornets) or directly from GitHub (pip install git+https://github.com/taehoonlee/tensornets.git).
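To quickly verify the installation, importing the package and listing a few of its exported models is enough (a minimal sanity check; dir() simply enumerates the module's names):

import tensornets as nets
# A smoke test: print the ResNet variants exported by the package.
print([name for name in dir(nets) if name.startswith('ResNet')])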

A quick example

Each network (see the full list) is not a custom class but a function that takes a tf.Tensor as input and returns a tf.Tensor as output. Here is an example of ResNet50:

import tensorflow as tf
# import tensorflow.compat.v1 as tf  # for TF 2
import tensornets as nets
# tf.disable_v2_behavior()  # for TF 2

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)

assert isinstance(model, tf.Tensor)
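Because model is literally a tf.Tensor (with shape (?, 1000) for the ImageNet classifiers), any TensorFlow op composes with it. For example, the top-2 classes can be taken inside the graph (a minimal sketch; the variable names are just for illustration):

# model behaves like any other tensor, so ordinary TF ops apply to it.
top2_scores, top2_classes = tf.nn.top_k(model, k=2)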

You can load an example image with utils.load_img, which returns a np.ndarray in NHWC format:

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
assert img.shape == (1, 224, 224, 3)

Once the network is created, you can run it with regular TensorFlow APIs 😊 because all the networks in TensorNets return tf.Tensor. Loading pre-trained weights and applying pre-processing are as easy as calling pretrained() and preprocess(), which reproduce the original results:

with tf.Session() as sess:
    img = model.preprocess(img)  # equivalent to img = nets.preprocess(model, img)
    sess.run(model.pretrained())  # equivalent to nets.pretrained(model)
    preds = sess.run(model, {inputs: img})

You can see the most probable classes:

print(nets.utils.decode_predictions(preds, top=2)[0])
[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]

You can also easily obtain the values of intermediate layers with middles() and outputs():

with tf.Session() as sess:
    img = model.preprocess(img)
    sess.run(model.pretrained())
    middles = sess.run(model.middles(), {inputs: img})
    outputs = sess.run(model.outputs(), {inputs: img})

model.print_middles()
assert middles[0].shape == (1, 56, 56, 256)
assert middles[-1].shape == (1, 7, 7, 2048)

model.print_outputs()
assert sum(sum((outputs[-1] - preds) ** 2)) < 1e-8
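Because middles() returns plain tf.Tensor objects as well, you can branch new ops off any intermediate layer, e.g. global-average-pooling the last middle into a 2048-d feature vector for downstream tasks (a minimal sketch; the names are just for illustration):

# The last middle has shape (?, 7, 7, 2048) as asserted above;
# averaging over the spatial axes yields a (?, 2048) feature vector.
features = tf.reduce_mean(model.middles()[-1], axis=[1, 2])

with tf.Session() as sess:
    sess.run(model.pretrained())
    feats = sess.run(features, {inputs: model.preprocess(img)})
    assert feats.shape == (1, 2048)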

With load() and save(), your weight values can be saved and restored:

with tf.Session() as sess:
    model.init()
    # ... your training ...
    model.save('test.npz')

with tf.Session() as sess:
    model.load('test.npz')
    # ... your deployment ...
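The saved file is a standard NumPy .npz archive, so the weights can be inspected or post-processed without TensorFlow (a minimal sketch; only the .npz format itself is assumed here):

import numpy as np

# test.npz is a regular NumPy archive holding one array per weight.
weights = np.load('test.npz')
print(len(weights.files))               # how many weight arrays were saved
print(weights[weights.files[0]].shape)  # shape of the first array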

TensorNets lets you deploy well-known architectures and benchmark them faster ⚡️. For more information, you can check out the lists of utilities, examples, and architectures.

Object detection example

Each object detection model can be coupled with any network in TensorNets (see performance) and takes two arguments: a placeholder and a function acting as a stem layer. Here is an example of YOLOv2 for PASCAL VOC:

import tensorflow as tf
import tensornets as nets

inputs = tf.placeholder(tf.float32, [None, 416, 416, 3])
model = nets.YOLOv2(inputs, nets.Darknet19)

img = nets.utils.load_img('cat.png')

with tf.Session() as sess:
    sess.run(model.pretrained())
    preds = sess.run(model, {inputs: model.preprocess(img)})
    boxes = model.get_boxes(preds, img.shape[1:3])

Like the classification networks, a detection model returns a tf.Tensor as its output. You can obtain the bounding box predictions (x1, y1, x2, y2, score) with model.get_boxes(model_output, original_img_shape) and visualize the results:

from tensornets.datasets import voc
print("%s: %s" % (voc.classnames[7], boxes[7][0]))  # 7 is cat

import numpy as np
import matplotlib.pyplot as plt
box = boxes[7][0]
plt.imshow(img[0].astype(np.uint8))
plt.gca().add_patch(plt.Rectangle(
    (box[0], box[1]), box[2] - box[0], box[3] - box[1],
    fill=False, edgecolor='r', linewidth=2))
plt.show()

More detection examples such as FasterRCNN on VOC2007 are here 😎.

Utilities

Besides pretrained() and preprocess(), the output tf.Tensor provides the following useful methods, all of which are demonstrated in this document:

- logits: the pre-softmax tensor, e.g. for use with tf.losses.softmax_cross_entropy
- middles(): returns the outputs of the main intermediate blocks
- outputs(): returns the outputs of every layer
- init(): initializes the weights
- load(path) and save(path): restore and store the weight values as an .npz file
- get_boxes(preds, original_img_shape): converts raw predictions into bounding boxes (detection models only)
- print_middles(), print_outputs(), print_weights(), and summary(): print diagnostics about the network

Example outputs of the print methods are:
>>> model.print_middles()
Scope: resnet50
conv2/block1/out:0 (?, 56, 56, 256)
conv2/block2/out:0 (?, 56, 56, 256)
conv2/block3/out:0 (?, 56, 56, 256)
conv3/block1/out:0 (?, 28, 28, 512)
conv3/block2/out:0 (?, 28, 28, 512)
conv3/block3/out:0 (?, 28, 28, 512)
conv3/block4/out:0 (?, 28, 28, 512)
conv4/block1/out:0 (?, 14, 14, 1024)
...

>>> model.print_outputs()
Scope: resnet50
conv1/pad:0 (?, 230, 230, 3)
conv1/conv/BiasAdd:0 (?, 112, 112, 64)
conv1/bn/batchnorm/add_1:0 (?, 112, 112, 64)
conv1/relu:0 (?, 112, 112, 64)
pool1/pad:0 (?, 114, 114, 64)
pool1/MaxPool:0 (?, 56, 56, 64)
conv2/block1/0/conv/BiasAdd:0 (?, 56, 56, 256)
conv2/block1/0/bn/batchnorm/add_1:0 (?, 56, 56, 256)
conv2/block1/1/conv/BiasAdd:0 (?, 56, 56, 64)
conv2/block1/1/bn/batchnorm/add_1:0 (?, 56, 56, 64)
conv2/block1/1/relu:0 (?, 56, 56, 64)
...

>>> model.print_weights()
Scope: resnet50
conv1/conv/weights:0 (7, 7, 3, 64)
conv1/conv/biases:0 (64,)
conv1/bn/beta:0 (64,)
conv1/bn/gamma:0 (64,)
conv1/bn/moving_mean:0 (64,)
conv1/bn/moving_variance:0 (64,)
conv2/block1/0/conv/weights:0 (1, 1, 64, 256)
conv2/block1/0/conv/biases:0 (256,)
conv2/block1/0/bn/beta:0 (256,)
conv2/block1/0/bn/gamma:0 (256,)
...

>>> model.summary()
Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712

Examples

Comparison of different networks:

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = [
    nets.MobileNet75(inputs),
    nets.MobileNet100(inputs),
    nets.SqueezeNet(inputs),
]

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
imgs = nets.preprocess(models, img)

with tf.Session() as sess:
    nets.pretrained(models)
    for (model, img) in zip(models, imgs):
        preds = sess.run(model, {inputs: img})
        print(nets.utils.decode_predictions(preds, top=2)[0])

Transfer learning on your own dataset:

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 50])
model = nets.DenseNet169(inputs, is_training=True, classes=50)

loss = tf.losses.softmax_cross_entropy(outputs, model.logits)
train = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)

with tf.Session() as sess:
    nets.pretrained(model)
    for (x, y) in your_NumPy_data:  # the NHWC and one-hot format
        sess.run(train, {inputs: x, outputs: y})
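If you want to update only the classifier head instead of the whole network, you can restrict the optimizer's var_list with plain TensorFlow (a sketch; the 'logits' substring is an assumption about the head's variable names, so check model.print_weights() for the actual ones):

# Assumption: the classifier variables contain 'logits' in their names.
head_vars = [v for v in tf.trainable_variables() if 'logits' in v.name]
train_head = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(
    loss, var_list=head_vars)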

Running different models on different GPUs:

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []

with tf.device('gpu:0'):
    models.append(nets.ResNeXt50(inputs))

with tf.device('gpu:1'):
    models.append(nets.DenseNet201(inputs))

from tensornets.preprocess import fb_preprocess
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
img = fb_preprocess(img)

with tf.Session() as sess:
    nets.pretrained(models)
    preds = sess.run(models, {inputs: img})
    for pred in preds:
        print(nets.utils.decode_predictions(pred, top=2)[0])

Performance

Image classification

| Model | Input | Top-1 | Top-5 | MAC | Size | Stem | Speed | References |
|---|---|---|---|---|---|---|---|---|
| ResNet50 | 224 | 74.874 | 92.018 | 51.0M | 25.6M | 23.6M | 195.4 | [paper] [tf-slim] [torch-fb] [caffe] [keras] |
| ResNet101 | 224 | 76.420 | 92.786 | 88.9M | 44.7M | 42.7M | 311.7 | [paper] [tf-slim] [torch-fb] [caffe] |
| ResNet152 | 224 | 76.604 | 93.118 | 120.1M | 60.4M | 58.4M | 439.1 | [paper] [tf-slim] [torch-fb] [caffe] |
| ResNet50v2 | 299 | 75.960 | 93.034 | 51.0M | 25.6M | 23.6M | 209.7 | [paper] [tf-slim] [torch-fb] |
| ResNet101v2 | 299 | 77.234 | 93.816 | 88.9M | 44.7M | 42.6M | 326.2 | [paper] [tf-slim] [torch-fb] |
| ResNet152v2 | 299 | 78.032 | 94.162 | 120.1M | 60.4M | 58.3M | 455.2 | [paper] [tf-slim] [torch-fb] |
| ResNet200v2 | 224 | 78.286 | 94.152 | 129.0M | 64.9M | 62.9M | 618.3 | [paper] [tf-slim] [torch-fb] |
| ResNeXt50c32 | 224 | 77.740 | 93.810 | 49.9M | 25.1M | 23.0M | 267.4 | [paper] [torch-fb] |
| ResNeXt101c32 | 224 | 78.730 | 94.294 | 88.1M | 44.3M | 42.3M | 427.9 | [paper] [torch-fb] |
| ResNeXt101c64 | 224 | 79.494 | 94.592 | 0.0M | 83.7M | 81.6M | 877.8 | [paper] [torch-fb] |
| WideResNet50 | 224 | 78.018 | 93.934 | 137.6M | 69.0M | 66.9M | 358.1 | [paper] [torch] |
| Inception1 | 224 | 66.840 | 87.676 | 14.0M | 7.0M | 6.0M | 165.1 | [paper] [tf-slim] [caffe-zoo] |
| Inception2 | 224 | 74.680 | 92.156 | 22.3M | 11.2M | 10.2M | 134.3 | [paper] [tf-slim] |
| Inception3 | 299 | 77.946 | 93.758 | 47.6M | 23.9M | 21.8M | 314.6 | [paper] [tf-slim] [keras] |
| Inception4 | 299 | 80.120 | 94.978 | 85.2M | 42.7M | 41.2M | 582.1 | [paper] [tf-slim] |
| InceptionResNet2 | 299 | 80.256 | 95.252 | 111.5M | 55.9M | 54.3M | 656.8 | [paper] [tf-slim] |
| NASNetAlarge | 331 | 82.498 | 96.004 | 186.2M | 93.5M | 89.5M | 2081 | [paper] [tf-slim] |
| NASNetAmobile | 224 | 74.366 | 91.854 | 15.3M | 7.7M | 6.7M | 165.8 | [paper] [tf-slim] |
| PNASNetlarge | 331 | 82.634 | 96.050 | 171.8M | 86.2M | 81.9M | 1978 | [paper] [tf-slim] |
| VGG16 | 224 | 71.268 | 90.050 | 276.7M | 138.4M | 14.7M | 348.4 | [paper] [keras] |
| VGG19 | 224 | 71.256 | 89.988 | 287.3M | 143.7M | 20.0M | 399.8 | [paper] [keras] |
| DenseNet121 | 224 | 74.972 | 92.258 | 15.8M | 8.1M | 7.0M | 202.9 | [paper] [torch] |
| DenseNet169 | 224 | 76.176 | 93.176 | 28.0M | 14.3M | 12.6M | 219.1 | [paper] [torch] |
| DenseNet201 | 224 | 77.320 | 93.620 | 39.6M | 20.2M | 18.3M | 272.0 | [paper] [torch] |
| MobileNet25 | 224 | 51.582 | 75.792 | 0.9M | 0.5M | 0.2M | 34.46 | [paper] [tf-slim] |
| MobileNet50 | 224 | 64.292 | 85.624 | 2.6M | 1.3M | 0.8M | 52.46 | [paper] [tf-slim] |
| MobileNet75 | 224 | 68.412 | 88.242 | 5.1M | 2.6M | 1.8M | 70.11 | [paper] [tf-slim] |
| MobileNet100 | 224 | 70.424 | 89.504 | 8.4M | 4.3M | 3.2M | 83.41 | [paper] [tf-slim] |
| MobileNet35v2 | 224 | 60.086 | 82.432 | 3.3M | 1.7M | 0.4M | 57.04 | [paper] [tf-slim] |
| MobileNet50v2 | 224 | 65.194 | 86.062 | 3.9M | 2.0M | 0.7M | 64.35 | [paper] [tf-slim] |
| MobileNet75v2 | 224 | 69.532 | 89.176 | 5.2M | 2.7M | 1.4M | 88.68 | [paper] [tf-slim] |
| MobileNet100v2 | 224 | 71.336 | 90.142 | 6.9M | 3.5M | 2.3M | 93.82 | [paper] [tf-slim] |
| MobileNet130v2 | 224 | 74.680 | 92.122 | 10.7M | 5.4M | 3.8M | 130.4 | [paper] [tf-slim] |
| MobileNet140v2 | 224 | 75.230 | 92.422 | 12.1M | 6.2M | 4.4M | 132.9 | [paper] [tf-slim] |
| MobileNet75v3large | 224 | 73.754 | 91.618 | 7.9M | 4.0M | 2.7M | 79.73 | [paper] [tf-slim] |
| MobileNet100v3large | 224 | 75.790 | 92.840 | 27.3M | 5.5M | 4.2M | 94.71 | [paper] [tf-slim] |
| MobileNet100v3largemini | 224 | 72.706 | 90.930 | 7.8M | 3.9M | 2.7M | 70.57 | [paper] [tf-slim] |
| MobileNet75v3small | 224 | 66.138 | 86.534 | 4.1M | 2.1M | 1.0M | 37.78 | [paper] [tf-slim] |
| MobileNet100v3small | 224 | 68.318 | 87.942 | 5.1M | 2.6M | 1.5M | 42.00 | [paper] [tf-slim] |
| MobileNet100v3smallmini | 224 | 63.440 | 84.646 | 4.1M | 2.1M | 1.0M | 29.65 | [paper] [tf-slim] |
| EfficientNetB0 | 224 | 77.012 | 93.338 | 26.2M | 5.3M | 4.0M | 147.1 | [paper] [tf-tpu] |
| EfficientNetB1 | 240 | 79.040 | 94.284 | 15.4M | 7.9M | 6.6M | 217.3 | [paper] [tf-tpu] |
| EfficientNetB2 | 260 | 80.064 | 94.862 | 18.1M | 9.2M | 7.8M | 296.4 | [paper] [tf-tpu] |
| EfficientNetB3 | 300 | 81.384 | 95.586 | 24.2M | 12.3M | 10.8M | 482.7 | [paper] [tf-tpu] |
| EfficientNetB4 | 380 | 82.588 | 96.094 | 38.4M | 19.5M | 17.7M | 959.5 | [paper] [tf-tpu] |
| EfficientNetB5 | 456 | 83.496 | 96.590 | 60.4M | 30.6M | 28.5M | 1872 | [paper] [tf-tpu] |
| EfficientNetB6 | 528 | 83.772 | 96.762 | 85.5M | 43.3M | 41.0M | 3503 | [paper] [tf-tpu] |
| EfficientNetB7 | 600 | 84.088 | 96.740 | 131.9M | 66.7M | 64.1M | 6149 | [paper] [tf-tpu] |
| SqueezeNet | 224 | 54.434 | 78.040 | 2.5M | 1.2M | 0.7M | 71.43 | [paper] [caffe] |


Object detection

| PASCAL VOC2007 test | mAP | Size | Speed | FPS | References |
|---|---|---|---|---|---|
| YOLOv3VOC (416) | 0.7423 | 62M | 24.09 | 41.51 | [paper] [darknet] [darkflow] |
| YOLOv2VOC (416) | 0.7320 | 51M | 14.75 | 67.80 | [paper] [darknet] [darkflow] |
| TinyYOLOv2VOC (416) | 0.5303 | 16M | 6.534 | 153.0 | [paper] [darknet] [darkflow] |
| FasterRCNN_ZF_VOC | 0.4466 | 59M | 241.4 | 3.325 | [paper] [caffe] [roi-pooling] |
| FasterRCNN_VGG16_VOC | 0.6872 | 137M | 300.7 | 4.143 | [paper] [caffe] [roi-pooling] |

| MS COCO val2014 | mAP | Size | Speed | FPS | References |
|---|---|---|---|---|---|
| YOLOv3COCO (608) | 0.6016 | 62M | 60.66 | 16.49 | [paper] [darknet] [darkflow] |
| YOLOv3COCO (416) | 0.6028 | 62M | 40.23 | 24.85 | [paper] [darknet] [darkflow] |
| YOLOv2COCO (608) | 0.5189 | 51M | 45.88 | 21.80 | [paper] [darknet] [darkflow] |
| YOLOv2COCO (416) | 0.4922 | 51M | 21.66 | 46.17 | [paper] [darknet] [darkflow] |

News 📰

Future work 🔥