Home

Awesome

ICDM 2023 解决方案

基于掩模图自编码器的预训练-微调方法 队伍:Echoch

运行环境依赖

代码复现流程

可直接sh run.sh进行一步复现,也可按照下述步骤一步步进行

注:即使是同样的随机种子,不同的机器Kmeans最终运行结果也差异较大,因此如果结果差异较大,可以在最后聚类的步骤多运行几次

python cluster.py --embedding_path embedding_arxiv.pt --output submit.txt --k 15 --seed -1 --runs 5

其中--seed -1为不指定随机种子

(Step 1 & 2) 预训练 (arXiv数据集)

运行代码后会自动从ogb下载arxiv数据集

python pretrain-sup.py --dataset arxiv --save_path encoder_arxiv_sup.pt --root data

参数说明:

运行结束后目录下会出现encoder_arxiv_sup.pt模型参数文件

python pretrain.py --dataset arxiv --save_path encoder_arxiv.pt --pretrain_path encoder_arxiv_sup.pt --lr 0.005 --root data

参数说明:

运行结束后目录下会出现encoder_arxiv.pt模型参数文件

(Step 3) 加载预训练参数+微调 (ICDM数据集)

需下载数据集到data/icdm2023_session1_test文件夹内

python finetune.py --epochs 5 --pretrain_path encoder_arxiv.pt --embedding_save_path embedding_arxiv.pt --root data/icdm2023_session1_test

参数说明:

运行结束后目录下会出现embedding_arxiv.pt的节点embedding文件

(Step 4) 聚类+集成

python cluster.py --embedding_path embedding_arxiv.pt --output submit.txt --k 15 --seed 666 --runs 5

参数说明:

运行结束后目录下会出现submit.txt结果文件

文件夹目录


├── fast_pytorch_kmeans
│   ├── init_methods.py
│   ├── __init__.py
│   ├── kmeans.py
│   ├── multi_kmeans.py
│   └── util.py
├── maskgae
│   ├── loss.py
│   ├── mask.py
│   ├── model.py
├── pretrain.py
├── pretrain-sup.py
├── finetune.py
├── cluster.py
├── README.md
├── run.sh
├── data
    └── icdm2023_session1_test
        ├── icdm2023_session1_test_edge.txt
        └── icdm2023_session1_test_node_feat.txt
    └── ogbn_arxiv
        ├── mapping
        │   ├── labelidx2arxivcategeory.csv.gz
        │   ├── nodeidx2paperid.csv.gz
        │   └── README.md
        ├── processed
        │   ├── data_processed
        │   ├── geometric_data_processed.pt
        │   ├── pre_filter.pt
        │   └── pre_transform.pt
        ├── raw
        │   ├── edge.csv.gz
        │   ├── node-feat.csv.gz
        │   ├── node-label.csv.gz
        │   ├── node_year.csv.gz
        │   ├── num-edge-list.csv.gz
        │   └── num-node-list.csv.gz
        ├── RELEASE_v1.txt
        └── split
            └── time
                ├── test.csv.gz
                ├── train.csv.gz
                └── valid.csv.gz