# CausalTransformer


Causal Transformer for estimating counterfactual outcomes over time.

<img width="1518" alt="Screenshot 2022-06-03 at 16 41 44" src="https://user-images.githubusercontent.com/23198776/171877145-c7cba15e-9787-4594-8f1f-cbb8b337b74a.png">

The project is built with the following Python libraries:

  1. Pytorch-Lightning - deep learning models
  2. Hydra - simplified command-line argument management
  3. MlFlow - experiment tracking

## Installation

First, create a virtual environment and install all the requirements:

```console
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
```

## MlFlow Setup / Connection

To start an experiment tracking server, run:

```console
mlflow server --port=5000
```

To access the MlFlow web UI with all the experiments from a remote machine, forward the port via ssh:

```console
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
```

Then, one can open http://localhost:5000 in a local browser.
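
As a minimal sketch, a training run can then log to this server. The exp.mlflow_uri key used below is an assumption; check config/config.yaml for the actual option name:

```console
# Hypothetical example: log a training run to the MlFlow server started above.
# exp.mlflow_uri is an assumed config key -- verify it in config/config.yaml.
PYTHONPATH=. python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct \
    exp.logging=True exp.mlflow_uri=http://localhost:5000
```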

## Experiments

The main training script is universal across models and datasets. For details on the mandatory arguments, see the main configuration file config/config.yaml and the other files in the config/ folder.

A generic run with logging and a fixed random seed looks as follows, where <training-type> is one of enc_dec, gnet, rmsn, or multi:

```console
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True
```
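
For instance, using the dataset and backbone names from the example in the Datasets section below, a single logged run of the multi-input model could look like this sketch:

```console
# One concrete instantiation of the generic command above
# (dataset/backbone names taken from the cancer_sim + ct example below).
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct exp.seed=10 exp.logging=True
```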

### Backbones (baselines)

One needs to choose a backbone and then fill in its specific hyperparameters (they are left blank in the configs).

The best hyperparameters for each model and dataset are already saved; one can access them via +backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value> or +backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure.
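
For example, a run of CT on real-world MIMIC-III data with the pre-saved hyperparameters might look as follows; the dataset config name mimic3_real is an assumption inferred from the hyperparameter path above:

```console
# Sketch: train CT with the pre-saved MIMIC-III hyperparameters.
# The dataset name `mimic3_real` is assumed from the hparams path -- check config/.
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=mimic3_real +backbone=ct \
    +backbone/ct_hparams/mimic3_real=diastolic_blood_pressure exp.seed=10 exp.logging=True
```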

For CT and EDCT, several adversarial balancing objectives are available; the objective is selected via the <balancing_objective> part of the hyperparameter path (e.g. domain_conf, as in the example below).

To train a decoder (for CRN and RMSNs), use the flag model.train_decoder=True.
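
A minimal sketch of such a decoder run follows; the backbone name crn is an assumption here, so check the config folder for the exact names:

```console
# Sketch: train the decoder of a two-stage model (CRN / RMSNs only).
# `crn` as the backbone name is an assumption -- see config/ for the exact names.
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_enc_dec.py +dataset=cancer_sim +backbone=crn \
    model.train_decoder=True exp.seed=10 exp.logging=True
```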

To perform manual hyperparameter tuning, use the flag model.<sub_model>.tune_hparams=True and see model.<sub_model>.hparams_grid for the search grid. Use model.<sub_model>.tune_range to specify the number of trials for random search.
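
A hypothetical tuning run for the multi-input model might then look like the sketch below; the sub-model name multi and the trial count of 50 are assumptions:

```console
# Sketch: random search over the predefined hparams_grid with 50 trials.
# The sub-model name `multi` is an assumption for train_multi.py.
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct \
    model.multi.tune_hparams=True model.multi.tune_range=50
```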

### Datasets

One needs to specify a dataset / dataset generator and, if needed, additional parameters (e.g. set gamma for cancer_sim with dataset.coeff=1.0).
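
For example, the confounding parameter of the synthetic generator can be overridden directly from the command line:

```console
# Set gamma (dataset.coeff) for the synthetic tumor-growth generator.
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct dataset.coeff=1.0 exp.seed=10
```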

Before running MIMIC-III experiments, place the MIMIC-III-extract dataset (all_hourly_data.h5) in data/processed/.

Example of running the Causal Transformer on the Synthetic Tumor Growth Generator with gamma = [1.0, 2.0, 3.0] and different random seeds (a total of 30 subruns), using the pre-saved hyperparameters:

```console
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf=\'0\',\'1\',\'2\' exp.seed=10,101,1010,10101,101010
```

## Updated results

### Self- and cross-attention bug

New results for the semi-synthetic and real-world experiments after fixing a bug with self- and cross-attention (https://github.com/Valentyn1997/CausalTransformer/issues/7). The bug affected only Tables 1 and 2 and Figure 5 of the paper (https://arxiv.org/pdf/2204.07258.pdf). Nevertheless, the performance of CT after the fix did not change drastically.

Table 1 (updated). Results for semi-synthetic data for $\tau$-step-ahead prediction based on real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| Model | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ | $\tau = 6$ | $\tau = 7$ | $\tau = 8$ | $\tau = 9$ | $\tau = 10$ |
|---|---|---|---|---|---|---|---|---|---|---|
| MSMs | 0.37 ± 0.01 | 0.57 ± 0.03 | 0.74 ± 0.06 | 0.88 ± 0.03 | 1.14 ± 0.10 | 1.95 ± 1.48 | 3.44 ± 4.57 | > 10.0 | > 10.0 | > 10.0 |
| RMSNs | 0.24 ± 0.01 | 0.47 ± 0.01 | 0.60 ± 0.01 | 0.70 ± 0.02 | 0.78 ± 0.04 | 0.84 ± 0.05 | 0.89 ± 0.06 | 0.94 ± 0.08 | 0.97 ± 0.09 | 1.00 ± 0.11 |
| CRN | 0.30 ± 0.01 | 0.48 ± 0.02 | 0.59 ± 0.02 | 0.65 ± 0.02 | 0.68 ± 0.02 | 0.71 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.02 |
| G-Net | 0.34 ± 0.01 | 0.67 ± 0.03 | 0.83 ± 0.04 | 0.94 ± 0.04 | 1.03 ± 0.05 | 1.10 ± 0.05 | 1.16 ± 0.05 | 1.21 ± 0.06 | 1.25 ± 0.06 | 1.29 ± 0.06 |
| EDCT (GR; $\lambda = 1$) | 0.29 ± 0.01 | 0.46 ± 0.01 | 0.56 ± 0.01 | 0.62 ± 0.01 | 0.67 ± 0.01 | 0.70 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.01 |
| CT ($\alpha = 0$) (ours, fixed) | 0.20 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.52 ± 0.01 | 0.54 ± 0.01 | 0.56 ± 0.01 | 0.57 ± 0.01 | 0.59 ± 0.01 | 0.60 ± 0.01 |
| CT (ours, fixed) | 0.21 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.53 ± 0.01 | 0.54 ± 0.01 | 0.55 ± 0.01 | 0.57 ± 0.01 | 0.58 ± 0.01 | 0.59 ± 0.01 |

Table 2 (updated). Results for experiments with real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| Model | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ |
|---|---|---|---|---|---|
| MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
| RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
| CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
| G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
| CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |

Figure 6 (updated). Subnetwork importance scores based on the semi-synthetic benchmark (higher values correspond to higher importance of subnetwork connectivity via cross-attention). Shown: RMSE differences between the model with an isolated subnetwork and the full CT, means ± standard errors.


### Last active entry zeroing bug

New results after fixing a bug in the synthetic tumor-growth simulator: the outcome corresponding to the last active entry of every time series was zeroed.

Table 9 (updated). Normalized RMSE for one-step-ahead prediction. Shown: mean and standard deviation over five runs (lower is better). The parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
|---|---|---|---|---|---|
| MSMs | 1.091 ± 0.115 | 1.202 ± 0.108 | 1.383 ± 0.090 | 1.647 ± 0.121 | 1.981 ± 0.232 |
| RMSNs | 0.834 ± 0.072 | 0.860 ± 0.025 | 1.000 ± 0.134 | 1.131 ± 0.057 | 1.434 ± 0.148 |
| CRN | 0.755 ± 0.059 | 0.788 ± 0.057 | 0.881 ± 0.066 | 1.062 ± 0.088 | 1.358 ± 0.167 |
| G-Net | 0.795 ± 0.066 | 0.841 ± 0.038 | 0.946 ± 0.083 | 1.057 ± 0.146 | 1.319 ± 0.248 |
| CT ($\alpha = 0$) (ours) | 0.772 ± 0.051 | 0.783 ± 0.071 | 0.862 ± 0.052 | 1.062 ± 0.119 | 1.331 ± 0.217 |
| CT (ours) | 0.770 ± 0.049 | 0.783 ± 0.071 | 0.864 ± 0.059 | 1.098 ± 0.097 | 1.413 ± 0.259 |

Table 10 (updated). Normalized RMSE for $\tau$-step-ahead prediction (here: random trajectories setting). Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| $\tau$ | Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
|---|---|---|---|---|---|---|
| 2 | MSMs | 0.975 ± 0.063 | 1.183 ± 0.146 | 1.428 ± 0.274 | 1.673 ± 0.431 | 1.884 ± 0.637 |
| 2 | RMSNs | 0.825 ± 0.057 | 0.851 ± 0.043 | 0.861 ± 0.078 | 0.993 ± 0.126 | 1.269 ± 0.294 |
| 2 | CRN | 0.761 ± 0.058 | 0.760 ± 0.037 | 0.805 ± 0.050 | 2.045 ± 1.491 | 1.209 ± 0.192 |
| 2 | G-Net | 1.006 ± 0.082 | 0.994 ± 0.086 | 1.185 ± 0.077 | 1.083 ± 0.145 | 1.243 ± 0.202 |
| 2 | CT ($\alpha = 0$) (ours) | 0.766 ± 0.029 | 0.781 ± 0.066 | 0.814 ± 0.078 | 0.944 ± 0.144 | 1.191 ± 0.316 |
| 2 | CT (ours) | 0.762 ± 0.028 | 0.781 ± 0.058 | 0.818 ± 0.091 | 1.001 ± 0.150 | 1.163 ± 0.233 |
| 3 | MSMs | 0.937 ± 0.060 | 1.133 ± 0.158 | 1.344 ± 0.262 | 1.525 ± 0.400 | 1.564 ± 0.545 |
| 3 | RMSNs | 0.824 ± 0.043 | 0.871 ± 0.036 | 0.857 ± 0.109 | 1.020 ± 0.140 | 1.267 ± 0.298 |
| 3 | CRN | 0.769 ± 0.057 | 0.777 ± 0.037 | 0.826 ± 0.077 | 1.789 ± 1.108 | 1.356 ± 0.330 |
| 3 | G-Net | 1.103 ± 0.092 | 1.097 ± 0.095 | 1.355 ± 0.107 | 1.225 ± 0.184 | 1.382 ± 0.242 |
| 3 | CT ($\alpha = 0$) (ours) | 0.766 ± 0.037 | 0.806 ± 0.060 | 0.828 ± 0.106 | 0.996 ± 0.185 | 1.335 ± 0.465 |
| 3 | CT (ours) | 0.762 ± 0.036 | 0.807 ± 0.056 | 0.838 ± 0.120 | 1.072 ± 0.196 | 1.283 ± 0.312 |
| 4 | MSMs | 0.845 ± 0.060 | 1.022 ± 0.149 | 1.196 ± 0.233 | 1.325 ± 0.363 | 1.308 ± 0.482 |
| 4 | RMSNs | 0.780 ± 0.046 | 0.834 ± 0.040 | 0.814 ± 0.123 | 0.988 ± 0.146 | 1.169 ± 0.269 |
| 4 | CRN | 0.734 ± 0.061 | 0.743 ± 0.037 | 0.805 ± 0.096 | 1.567 ± 0.825 | 1.327 ± 0.293 |
| 4 | G-Net | 1.092 ± 0.090 | 1.074 ± 0.098 | 1.385 ± 0.117 | 1.212 ± 0.202 | 1.358 ± 0.253 |
| 4 | CT ($\alpha = 0$) (ours) | 0.730 ± 0.042 | 0.776 ± 0.056 | 0.802 ± 0.119 | 0.983 ± 0.208 | 1.394 ± 0.563 |
| 4 | CT (ours) | 0.726 ± 0.041 | 0.777 ± 0.054 | 0.810 ± 0.128 | 1.075 ± 0.220 | 1.302 ± 0.356 |
| 5 | MSMs | 0.747 ± 0.056 | 0.896 ± 0.136 | 1.038 ± 0.210 | 1.128 ± 0.320 | 1.155 ± 0.448 |
| 5 | RMSNs | 0.717 ± 0.053 | 0.775 ± 0.041 | 0.747 ± 0.124 | 0.922 ± 0.141 | 1.057 ± 0.246 |
| 5 | CRN | 0.678 ± 0.062 | 0.692 ± 0.037 | 0.761 ± 0.104 | 1.410 ± 0.604 | 1.242 ± 0.239 |
| 5 | G-Net | 1.033 ± 0.086 | 1.014 ± 0.097 | 1.358 ± 0.118 | 1.160 ± 0.199 | 1.285 ± 0.242 |
| 5 | CT ($\alpha = 0$) (ours) | 0.673 ± 0.044 | 0.722 ± 0.052 | 0.748 ± 0.124 | 0.931 ± 0.213 | 1.405 ± 0.648 |
| 5 | CT (ours) | 0.669 ± 0.043 | 0.723 ± 0.053 | 0.751 ± 0.125 | 1.036 ± 0.238 | 1.264 ± 0.389 |
| 6 | MSMs | 0.647 ± 0.055 | 0.778 ± 0.123 | 0.894 ± 0.188 | 0.952 ± 0.284 | 1.060 ± 0.432 |
| 6 | RMSNs | 0.646 ± 0.058 | 0.702 ± 0.043 | 0.675 ± 0.121 | 0.847 ± 0.132 | 0.947 ± 0.225 |
| 6 | CRN | 0.614 ± 0.057 | 0.631 ± 0.035 | 0.706 ± 0.104 | 1.308 ± 0.438 | 1.132 ± 0.194 |
| 6 | G-Net | 0.963 ± 0.083 | 0.942 ± 0.090 | 1.321 ± 0.118 | 1.092 ± 0.183 | 1.195 ± 0.223 |
| 6 | CT ($\alpha = 0$) (ours) | 0.609 ± 0.042 | 0.657 ± 0.046 | 0.684 ± 0.122 | 0.864 ± 0.201 | 1.383 ± 0.699 |
| 6 | CT (ours) | 0.605 ± 0.040 | 0.657 ± 0.047 | 0.685 ± 0.119 | 0.979 ± 0.249 | 1.201 ± 0.419 |

Table 11 (updated). Normalized RMSE for $\tau$-step-ahead prediction (here: single sliding treatment setting). Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| $\tau$ | Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
|---|---|---|---|---|---|---|
| 2 | MSMs | 1.362 ± 0.109 | 1.612 ± 0.172 | 1.939 ± 0.365 | 2.290 ± 0.545 | 2.468 ± 1.058 |
| 2 | RMSNs | 0.742 ± 0.043 | 0.760 ± 0.047 | 0.827 ± 0.056 | 0.957 ± 0.106 | 1.276 ± 0.240 |
| 2 | CRN | 0.671 ± 0.066 | 0.666 ± 0.052 | 0.741 ± 0.042 | 1.668 ± 1.184 | 1.151 ± 0.166 |
| 2 | G-Net | 1.021 ± 0.067 | 1.009 ± 0.092 | 1.271 ± 0.075 | 1.113 ± 0.149 | 1.257 ± 0.227 |
| 2 | CT ($\alpha = 0$) (ours) | 0.685 ± 0.050 | 0.679 ± 0.044 | 0.714 ± 0.053 | 0.875 ± 0.105 | 1.072 ± 0.315 |
| 2 | CT (ours) | 0.681 ± 0.052 | 0.677 ± 0.044 | 0.713 ± 0.042 | 0.908 ± 0.122 | 1.274 ± 0.366 |
| 3 | MSMs | 1.679 ± 0.132 | 1.953 ± 0.208 | 2.302 ± 0.437 | 2.640 ± 0.639 | 2.622 ± 1.132 |
| 3 | RMSNs | 0.783 ± 0.053 | 0.792 ± 0.047 | 0.889 ± 0.050 | 1.086 ± 0.175 | 1.382 ± 0.286 |
| 3 | CRN | 0.700 ± 0.078 | 0.692 ± 0.046 | 0.818 ± 0.051 | 1.959 ± 1.032 | 1.360 ± 0.225 |
| 3 | G-Net | 1.253 ± 0.079 | 1.226 ± 0.104 | 1.611 ± 0.102 | 1.383 ± 0.200 | 1.574 ± 0.328 |
| 3 | CT ($\alpha = 0$) (ours) | 0.707 ± 0.053 | 0.711 ± 0.038 | 0.770 ± 0.043 | 0.969 ± 0.119 | 1.261 ± 0.462 |
| 3 | CT (ours) | 0.703 ± 0.055 | 0.712 ± 0.040 | 0.770 ± 0.032 | 1.010 ± 0.119 | 1.536 ± 0.450 |
| 4 | MSMs | 1.871 ± 0.145 | 2.145 ± 0.227 | 2.489 ± 0.471 | 2.791 ± 0.681 | 2.615 ± 1.142 |
| 4 | RMSNs | 0.821 ± 0.079 | 0.837 ± 0.058 | 0.963 ± 0.106 | 1.216 ± 0.240 | 1.416 ± 0.304 |
| 4 | CRN | 0.734 ± 0.087 | 0.722 ± 0.041 | 0.898 ± 0.068 | 2.201 ± 0.967 | 1.573 ± 0.255 |
| 4 | G-Net | 1.390 ± 0.087 | 1.347 ± 0.112 | 1.819 ± 0.133 | 1.544 ± 0.243 | 1.769 ± 0.413 |
| 4 | CT ($\alpha = 0$) (ours) | 0.729 ± 0.056 | 0.749 ± 0.033 | 0.826 ± 0.046 | 1.053 ± 0.147 | 1.426 ± 0.574 |
| 4 | CT (ours) | 0.726 ± 0.057 | 0.748 ± 0.036 | 0.822 ± 0.036 | 1.089 ± 0.122 | 1.762 ± 0.523 |
| 5 | MSMs | 1.963 ± 0.155 | 2.221 ± 0.231 | 2.547 ± 0.479 | 2.810 ± 0.684 | 2.542 ± 1.122 |
| 5 | RMSNs | 0.855 ± 0.099 | 0.889 ± 0.074 | 1.030 ± 0.165 | 1.349 ± 0.326 | 1.434 ± 0.299 |
| 5 | CRN | 0.769 ± 0.094 | 0.755 ± 0.039 | 0.976 ± 0.082 | 2.361 ± 1.000 | 1.730 ± 0.292 |
| 5 | G-Net | 1.477 ± 0.092 | 1.430 ± 0.119 | 1.963 ± 0.157 | 1.667 ± 0.275 | 1.907 ± 0.471 |
| 5 | CT ($\alpha = 0$) (ours) | 0.758 ± 0.055 | 0.788 ± 0.036 | 0.875 ± 0.056 | 1.118 ± 0.172 | 1.560 ± 0.663 |
| 5 | CT (ours) | 0.756 ± 0.057 | 0.786 ± 0.039 | 0.870 ± 0.048 | 1.154 ± 0.111 | 1.922 ± 0.569 |
| 6 | MSMs | 1.970 ± 0.155 | 2.205 ± 0.228 | 2.509 ± 0.469 | 2.732 ± 0.662 | 2.422 ± 1.084 |
| 6 | RMSNs | 0.889 ± 0.112 | 0.936 ± 0.091 | 1.081 ± 0.211 | 1.473 ± 0.433 | 1.436 ± 0.290 |
| 6 | CRN | 0.807 ± 0.097 | 0.790 ± 0.035 | 1.047 ± 0.092 | 2.480 ± 1.078 | 1.827 ± 0.326 |
| 6 | G-Net | 1.538 ± 0.091 | 1.493 ± 0.121 | 2.062 ± 0.172 | 1.758 ± 0.286 | 1.994 ± 0.500 |
| 6 | CT ($\alpha = 0$) (ours) | 0.790 ± 0.058 | 0.827 ± 0.036 | 0.915 ± 0.063 | 1.177 ± 0.193 | 1.654 ± 0.704 |
| 6 | CT (ours) | 0.789 ± 0.059 | 0.821 ± 0.034 | 0.909 ± 0.054 | 1.205 ± 0.100 | 2.052 ± 0.608 |
<p><small>Project based on the <a target="_blank" href="https://drivendata.github.io/cookiecutter-data-science/">cookiecutter data science project template</a>. #cookiecutterdatascience</small></p>