Awesome
Torch-MTS
A simple PyTorch framework for multivariate time series forecasting, for personal use.
All the codes are based on my personal coding habit and style.
Performance Table
Environment: PyTorch 1.13.1 & NVIDIA RTX 5000 Ada 32G
Tips
- 不要用 sklearn 的
StandardScaler
- 构建 train-val-test 的时候必须只使用 x_train 的分布构建 scaler
- 读取原始数据的时候
.astype(np.float32)
,保证全程的数据不管是 ndarray 还是 tensor 都是 float32 类型 - [重点重点重点] 千万不要 transform 任何 y,这一点可以解决各种各样的奇怪问题,例如 MAPE 爆炸
- 在 forward 过程中
inverse_transform(y_pred)
之后再和y_true
算 loss。如果算 loss 时y_pred, y_true
是 transform 后的,会严重影响 performance - metrics 中不要修改
y_pred, y_true
的任何值,例如把特别小的值置为 0。引起这个操作的情况是 MAPE 分母的y_true
中有一些特别小的值导致算出来的值爆炸,只要按照上面的 1234 的要求做,就不会出现这种情况 - Learning scheduler:
MultiStepLR
必须有而且要细调,而且要和课程学习的设置联动
Changelog
- v1.1: 第一个能用的版本
- v1.2: 参考 BasicTS,重构了整个 dataset 的创建和读取
- v1.3: 添加单独的 loss 和 runner 模块
- v1.4: 交通预测的稳定 (或许是完全 bug-free) 版本,使用半年以上