Requirements:

All models require a GPU to train and were run on a single GPU with 16 GB of memory. The models were originally run on a server with the Slurm task scheduler, but the run scripts are provided here as plain shell scripts, devoid of Slurm-related commands.

RUN SCRIPTS

runscripts/

helpers.py

         import helpers
         # The results path must be passed as a string (note the quotes).
         helpers.calculate_mse_wmse("/full/path/to/DynamicAttentionNetworks/results/DyAtH/DyAtH_mean/", 90)

source/

DyAtH/
	> train.py  -- Contains the model training logic.
	> test.py  -- Contains the model testing logic.
	> model_train_eval.py  -- Contains helper functions and subroutines for model training and testing.
	> train.sh  -- Shell script that invokes train.py.
	> test.sh  -- Shell script that invokes test.py.
	> rnn.py  -- Contains the sequence-to-sequence architecture definition.

DyAtMaxPoolH/
	> Structure similar to DyAtH/.

datasets/

ghl_small/ - GHL dataset used in the paper.

notebooks/

GHL\ Experiments.ipynb  
    - This notebook shows how, once training and testing have been run for a specific model (DyAt-H, DyAt-Maxpool-H) and a specific sequence length, the results can be evaluated to obtain the average MSE and WMSE values.

results/

- Once a training script for a model has run successfully, the `results/` directory (with the corresponding subdirectories) is created automatically in the root folder of this repository. It holds the trained model, the raw inputs and forecasts (.csv files), and plots for the training, validation, and test sets.
- The directory structure is `results/$DATASETNAME/$MODELNAME/$MODELTYPE/SEQUENCE_LENGTH_$SEQUENCELENGTH_HIDDEN_SIZE_$HIDDENSIZE_..._ITERNUM_$ITERNUM/`
- A separate directory of this form is created for each combination of sequence length, hidden size, and experiment number.
- After training, each results folder contains the training and validation raw inputs and forecasts along with plots, and the trained model is stored in the `model/` folder. After testing, a subdirectory named `results_test/` is created, which holds the raw inputs, forecasts (.csv files), and forecast visualizations.
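
The directory layout described above can be searched programmatically. The following is a minimal sketch, not part of this repository: the function names `find_run_dirs` and `has_test_results` are hypothetical, and it assumes only the layout stated above (three directory levels for dataset, model, and model type, followed by a run directory whose name begins with `SEQUENCE_LENGTH_<n>_HIDDEN_SIZE_`).

```python
from pathlib import Path


def find_run_dirs(results_root, sequence_length):
    """Return the run directories under results_root for one sequence length.

    Hypothetical helper: assumes the layout
    results/$DATASETNAME/$MODELNAME/$MODELTYPE/SEQUENCE_LENGTH_<n>_HIDDEN_SIZE_.../
    """
    root = Path(results_root)
    # One wildcard per level: dataset, model name, model type.
    # The "_HIDDEN_SIZE_" suffix keeps length 9 from also matching length 90.
    pattern = f"*/*/*/SEQUENCE_LENGTH_{sequence_length}_HIDDEN_SIZE_*"
    return sorted(p for p in root.glob(pattern) if p.is_dir())


def has_test_results(run_dir):
    """Return True once testing has created the results_test/ subdirectory."""
    return (Path(run_dir) / "results_test").is_dir()
```

For example, `find_run_dirs("results", 90)` would return every run directory for sequence length 90 across all datasets and models, and `has_test_results` can then be used to skip runs that have been trained but not yet tested.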

Note: the error values may not be reproduced exactly because of differences in CUDA versions, system, and various other configurations.