# CoSE: Compositional Stroke Embeddings
We haven't refactored our code repository yet, so it still contains unused functionality. Please do not rely on the default values of the command-line arguments.
The `master` branch is used to train the model reported in the paper's tables and figures. The `development` branch implements some ideas for improving our model.
## Environment
Our codebase is in Python 3 and uses TensorFlow 2.1. We suggest creating a new virtual environment.
- The required packages can be installed by running `pip install -r requirements.txt`.
- Update `PYTHONPATH` by running `export PYTHONPATH="${PYTHONPATH}:<CODE_PATH>"`.
- You can optionally set the following environment variables, or alternatively use the corresponding FLAGS (see the consolidated example after this list):
  - `COSE_DATA_DIR` or `--data_dir`: Path to data files.
  - `COSE_LOG_DIR` or `--experiment_dir`: Path to experiment/model files.
  - `COSE_EVAL_DIR` or `--eval_dir`: Path to save model evaluation results. This is required only when the evaluation scripts are called.
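For convenience, a minimal shell sketch of the one-time setup; the `<...>` paths are placeholders you should replace with your own locations:

```bash
# Make the repository importable (replace <CODE_PATH> with the repository location).
export PYTHONPATH="${PYTHONPATH}:<CODE_PATH>"

# Optional: set the directories once instead of passing --data_dir,
# --experiment_dir and --eval_dir on every invocation.
export COSE_DATA_DIR="<DATA_PATH>"
export COSE_LOG_DIR="<LOG_PATH>"
export COSE_EVAL_DIR="<EVAL_PATH>"
```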
## Dataset
We use the DiDi dataset of diagram drawings without text. Please skip the preprocessing steps described there, as we provide our own (summarized in the sketch after the steps below):
1. Move the `.ndjson` file to `COSE_DATA_DIR/didi_wo_text/`.
2. Run `data_scripts/didi_json_to_tfrecords.py` to create the TFRecord files required for model training and evaluation. Set the `DATA_DIR` variable in this script to `COSE_DATA_DIR/didi_wo_text/`.
3. Run `python data_scripts/calculate_data_statistics.py` to create the data statistics file the model requires for data normalization. Set the `DATA_DIR` variable in this script to `COSE_DATA_DIR/didi_wo_text/`.
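Put together, a minimal shell sketch of the data preparation; `<NDJSON_FILE>` is a placeholder for the downloaded DiDi file, and note that `DATA_DIR` is edited inside the two scripts rather than passed as a flag:

```bash
# Place the raw DiDi drawings where the preprocessing scripts expect them.
mkdir -p "${COSE_DATA_DIR}/didi_wo_text"
mv <NDJSON_FILE> "${COSE_DATA_DIR}/didi_wo_text/"

# Set the DATA_DIR variable inside each script to COSE_DATA_DIR/didi_wo_text/
# before running them.
python data_scripts/didi_json_to_tfrecords.py       # writes the TFRecord files
python data_scripts/calculate_data_statistics.py    # writes the statistics file
```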
Similarly, the QuickDraw dataset can also be used. Note that our model requires the raw files.
## Training
In the `training_commands.json` file, we provide the commands for training our main model and some of the ablation models.
For example, our model can be trained by running:

```bash
python ink_training_eager_predictive.py --experiment_id <UNIQUE_ID> --gt_targets --use_start_pos --num_pred_inputs 32 --stop_predictive_grad --pred_input_type hybrid --stroke_loss nll_gmm --n_t_samples 4 --batch_size 128 --affine_prob 0.3 --resampling_factor 2 --scale_factor 0 --grad_clip_norm 1 --encoder_model transformer --transformer_scale --transformer_pos_encoding --transformer_layers 6 --transformer_heads 4 --transformer_dmodel 64 --transformer_hidden_units 256 --transformer_dropout 0.0 --latent_units 8 --decoder_model t_emb --decoder_dropout 0.0 --decoder_layers 4 --decoder_hidden_units 512,512,512,512 --predictive_model transformer --learning_rate_type transformer --p_transformer_layers 6 --p_transformer_heads 4 --p_transformer_dmodel 64 --p_transformer_hidden_units 256 --p_transformer_dropout 0.0 --p_transformer_scale --embedding_loss nll_gmm --embedding_gmm_components 10 --loss_predicted_embedding --loss_reconstructed_ink --position_model transformer --data_name didi_wo_text --metadata_type position --disable_pen_loss --mask_encoder_pen
```
where you are expected to pass a unique identifier via `--experiment_id`. We recommend using a timestamp (i.e., the output of `date +%s`), as in the sketch below.
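For instance, a minimal sketch that derives the identifier from the current timestamp (the remaining training flags are the ones listed above, elided here for brevity):

```bash
UNIQUE_ID=$(date +%s)   # e.g. 1600000000
python ink_training_eager_predictive.py --experiment_id "${UNIQUE_ID}" \
    --data_name didi_wo_text   # ...plus the remaining flags shown above
```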
## Evaluation
Qualitative and quantitative evaluation can be done by running:

```bash
python eval.py --model_ids <UNIQUE_ID> --qualitative --quantitative --embedding_analysis
```
where `<UNIQUE_ID>` is the same identifier used for training.
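Continuing the sketch above (this assumes `COSE_EVAL_DIR` or `--eval_dir` is set so the results can be written):

```bash
# Re-use the identifier of the training run.
python eval.py --model_ids "${UNIQUE_ID}" --qualitative --quantitative --embedding_analysis
```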
## Pre-trained Models
You can download our main model and run the evaluation script as explained above.
## Demo
An interactive demo is provided in the `smarting-js` folder, using the pre-trained models shared above.
## Citation
```
@article{aksan2020cose,
  title={CoSE: Compositional Stroke Embeddings},
  author={Aksan, Emre and Deselaers, Thomas and Tagliasacchi, Andrea and Hilliges, Otmar},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

@article{gervais2020didi,
  title={The DIDI dataset: Digital Ink Diagram data},
  author={Gervais, Philippe and Deselaers, Thomas and Aksan, Emre and Hilliges, Otmar},
  journal={arXiv preprint arXiv:2002.09303},
  year={2020}
}
```