Dynamic-memory-networks-plus-Pytorch

DMN+ (Dynamic Memory Network Plus) implementation in PyTorch for question answering on the bAbI 10k dataset.

Contents

| file | description |
| --- | --- |
| babi_loader.py | declaration of bAbI PyTorch Dataset class |
| babi_main.py | contains DMN+ model and training code |
| fetch_data.sh | shell script to fetch bAbI tasks (from DMNs in Theano) |

Usage

Install PyTorch v0.1.12 and Python 3.6.x (required for literal string interpolation, i.e. f-strings)

Run the included shell script to fetch the data

chmod +x fetch_data.sh
./fetch_data.sh

Run the main python code

python babi_main.py
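The Python 3.6 requirement above comes from f-strings (literal string interpolation), which are a syntax error on older interpreters. As a minimal sketch of what that dependency looks like, the helper below formats a training log line with an f-string. The function name and the log layout are hypothetical illustrations and are not taken from babi_main.py.

```python
import sys

# f-strings were introduced in Python 3.6; on older interpreters this file
# would fail to parse before any runtime check could run.
MIN_PYTHON = (3, 6)


def python_is_supported(version_info=sys.version_info):
    """Return True if the interpreter is new enough for f-strings."""
    return tuple(version_info[:2]) >= MIN_PYTHON


def format_epoch_log(epoch, loss, acc):
    """Hypothetical log line; `acc` is a fraction in [0, 1]."""
    return f"[Epoch {epoch}] loss: {loss:.4f}, acc: {acc:.2%}"


if __name__ == "__main__":
    print(python_is_supported())
    print(format_epoch_log(3, 0.5, 0.968))
```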

Benchmarks

On some tasks the accuracy is lower than that reported by Xiong et al. This may be due to a different weight decay setting or to the model's instability.

On some tasks, the accuracy was not stable across multiple runs. This was particularly problematic on QA3, QA17, and QA18. To solve this, we repeated training 10 times using random initializations and evaluated the model that achieved the lowest validation set loss.
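The selection procedure described above (train several times from random initializations, keep the run with the lowest validation loss) can be sketched as follows. This is an illustration under stated assumptions, not the code in babi_main.py: `select_best_run`, `train_once`, and the seed scheme are hypothetical names, and `fake_train_once` is only a stand-in for a real training run.

```python
import random
from typing import Any, Callable, Tuple


def select_best_run(
    train_once: Callable[[int], Tuple[Any, float]],
    n_runs: int = 10,
    base_seed: int = 0,
) -> Tuple[Any, float, int]:
    """Train n_runs times with different seeds and keep the model
    whose validation loss is lowest.

    `train_once(seed)` is assumed to return (model, validation_loss).
    """
    best_model, best_loss, best_seed = None, float("inf"), -1
    for run in range(n_runs):
        seed = base_seed + run  # a different random initialization per run
        model, val_loss = train_once(seed)
        if val_loss < best_loss:
            best_model, best_loss, best_seed = model, val_loss, seed
    return best_model, best_loss, best_seed


def fake_train_once(seed: int) -> Tuple[dict, float]:
    """Stand-in for real training: returns a dummy model and a random loss."""
    rng = random.Random(seed)
    return {"seed": seed}, rng.uniform(0.1, 1.0)


if __name__ == "__main__":
    model, loss, seed = select_best_run(fake_train_once, n_runs=10)
    print(f"best seed: {seed}, validation loss: {loss:.4f}")
```

Selecting on validation loss rather than test accuracy keeps the test set out of the model selection loop, so the reported test numbers remain an unbiased estimate for the chosen model.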

You can find pretrained models here

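| Task ID | This Repo | Xiong et al |
| --- | --- | --- |
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |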
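To make the comparison easier to read, the short script below summarizes the table: it computes the mean accuracy of each column and lists the tasks where this repo trails Xiong et al by more than one percentage point. The numbers are copied from the table above. The helper names and the one-point margin are arbitrary choices for illustration.

```python
# Accuracies (%) copied from the benchmark table, ordered by task ID 1..20.
THIS_REPO = [100, 96.8, 89.2, 100, 99.5, 100, 97.8, 100, 100, 100,
             100, 100, 100, 99, 100, 51.6, 86.4, 97.9, 99.7, 100]
XIONG = [100, 99.7, 98.9, 100, 99.5, 100, 97.6, 100, 100, 100,
         100, 100, 100, 99.8, 100, 54.7, 95.8, 97.9, 100, 100]


def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


def tasks_trailing(ours, reference, margin=1.0):
    """Task IDs (1-based) where `ours` is more than `margin` points below `reference`."""
    return [
        task
        for task, (a, b) in enumerate(zip(ours, reference), start=1)
        if b - a > margin
    ]


if __name__ == "__main__":
    print(f"mean accuracy, this repo:   {mean(THIS_REPO):.2f}%")
    print(f"mean accuracy, Xiong et al: {mean(XIONG):.2f}%")
    print("tasks trailing by more than 1 point:", tasks_trailing(THIS_REPO, XIONG))
```

With the table values, the mean accuracy is about 95.9% for this repo against about 97.2% for Xiong et al, and the tasks trailing by more than one point are 2, 3, 16, and 17.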