Can Graph Learning Improve Planning in LLM-based Agents?

This is the official implementation of our NeurIPS 2024 paper "Can Graph Learning Improve Planning in LLM-based Agents?" [Chinese]

Task planning aims to break down a complex user request into solvable sub-tasks, thereby fulfilling the original request. In this context, the sub-tasks can be naturally viewed as a graph, where nodes represent sub-tasks and edges denote dependencies among them. Consequently, task planning is a decision-making problem that involves selecting a connected path within the corresponding graph and invoking the sub-tasks along it. In this paper, we first provide a theoretical analysis showing that the biases of attention and the auto-regressive loss impede LLMs' ability to effectively solve decision-making problems on graphs. Based on this analysis, we introduce an additional GNN for sub-task retrieval, available in both training-free and training-based variants. Experiments on diverse LLMs and planning benchmarks demonstrate that the proposed method outperforms existing solutions with much less computation time.

Feel free to cite this work if you find it useful! šŸ˜„

@inproceedings{wu2024graph,
  title={Can Graph Learning Improve Planning in LLM-based Agents?},
  author={Xixi Wu and Yifei Shen and Caihua Shan and Kaitao Song and Siwei Wang and Bohang Zhang and Jiarui Feng and Hong Cheng and Wei Chen and Yun Xiong and Dongsheng Li},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}

Environment Setup

pip install -r requirements.txt

Run the above command to install the required Python packages.

Deploy Open-sourced LLMs

To run an LLM's direct inference or GraphSearch, our code deploys open-sourced LLMs as API services with FastChat at the localhost:8008 endpoint.
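
For example, a typical FastChat deployment looks like the following (a sketch only; the model path is illustrative, and ports can be adjusted to your setup):

# Launch the controller, a model worker, and an OpenAI-compatible API server
# listening on localhost:8008 (model path is illustrative).
python3 -m fastchat.serve.controller &
python3 -m fastchat.serve.model_worker --model-path codellama/CodeLlama-13b-Instruct-hf &
python3 -m fastchat.serve.openai_api_server --host localhost --port 8008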

Overview

.
ā”œā”€ā”€ GraphToken                     --> Our implementation of the training-required baseline: GraphToken
ā”œā”€ā”€ README.assets    
ā”œā”€ā”€ README.md       
ā”œā”€ā”€ data                           --> Provide all experimental datasets [HuggingFace, Multimedia, DailyLife, TMDB, and UltraTool]
ā”‚   ā”œā”€ā”€ dailylife
ā”‚   ā”œā”€ā”€ huggingface
ā”‚   ā”œā”€ā”€ multimedia
ā”‚   ā”œā”€ā”€ raw                        --> Original files from RestBench and UltraTool
ā”‚   ā”‚   ā””ā”€ā”€ RestBench                    `https://github.com/Yifan-Song793/RestGPT`
ā”‚Ā Ā  ā”‚Ā Ā  ā””ā”€ā”€ ultratool                    `https://github.com/JoeYing1019/UltraTool`
ā”‚   ā”œā”€ā”€ raw_process_restgpt.py     --> Codes for processing RestBench
ā”‚   ā”œā”€ā”€ raw_process_ultratool.py   --> Codes for processing UltraTool
ā”‚   ā”œā”€ā”€ split_data.py              --> Codes for splitting the test set
ā”‚Ā Ā  ā”œā”€ā”€ tmdb
ā”‚Ā Ā  ā””ā”€ā”€ ultratool
ā”œā”€ā”€ evaluate.py                    --> Codes for evaluation 
ā”œā”€ā”€ finetunellm                    --> Codes for fine-tuning LLMs and running direct inference with the fine-tuned LLMs
ā”œā”€ā”€ finetunellm_script.sh          --> Scripts for fine-tuning LLMs
ā”œā”€ā”€ prediction                     --> Results of Task Planning
ā”œā”€ā”€ requirements.txt     
ā”œā”€ā”€ trainfree                      --> Codes for training-free methods (Direct, GraphSearch, and SGC)
ā”œā”€ā”€ trainfree_script.sh            --> Scripts for training-free methods
ā”œā”€ā”€ traingnn                       --> Codes for training GNNs
ā”œā”€ā”€ traingnn_reproduce.sh          --> Scripts for reproducing reported GNN / LM+GNN results
ā””ā”€ā”€ utils                 

This repo provides both training-free and training-based methods.

We also provide source code for fine-tuning LLMs with LoRA on the split training data. Each component is detailed below.

Datasets

Five experimental datasets are under the data folder: HuggingFace, Multimedia, and Daily Life from TaskBench; TMDB from RestBench; and UltraTool.

Each dataset contains the following files:

As the dataset from RestBench only contains the original requests and ground-truth API sequences, we have reformatted it to align with our experiments: assigning a unique name to each API, constructing a task graph, and finally reformatting the original data samples. Processing details are covered in raw_process_restgpt.py.

To demonstrate scalability with large task graphs, we introduce a new dataset, UltraTool (ACL 2024 Findings). The original data, processing details, and reformatted data are organized in the dataset folder. This dataset includes 260 distinct tasks. Processing involves filtering data samples that invoke at least two tasks, retaining valid tasks that appear at least five times, constructing the task graph from the filtered tasks and trajectories, and finally prompting GPT-4 to fill in the steps.
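
The filtering logic can be summarized by the following sketch (illustrative Python; field names such as sample["tasks"] are hypothetical, see raw_process_ultratool.py for the actual code):

# Sketch of the UltraTool filtering described above (hypothetical field names).
from collections import Counter

def filter_ultratool(samples, min_tasks=2, min_count=5):
    # Keep samples that invoke at least two tasks.
    samples = [s for s in samples if len(s["tasks"]) >= min_tasks]
    # Retain tasks that appear at least five times across all samples.
    counts = Counter(t for s in samples for t in s["tasks"])
    valid = {t for t, c in counts.items() if c >= min_count}
    # Drop filtered-out tasks from each sample's trajectory.
    return [{**s, "tasks": [t for t in s["tasks"] if t in valid]} for s in samples]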

Training-free Methods

Code Intro

Code for the training-free methods is under the trainfree folder. We also provide two improved prompt templates, 2-shot and PlaG, to investigate whether our method yields gains orthogonal to prompt engineering:

ā”œā”€ā”€ trainfree
ā”‚Ā Ā  ā”œā”€ā”€ direct.py               --> LLM's direct inference
ā”‚Ā Ā  ā”œā”€ā”€ direct_diffprompt.py    --> LLM's direct inference under improved prompts, including 1) more in-context learning examples and 2) plan like a graph (PlaG)
ā”‚Ā Ā  ā”œā”€ā”€ graphsearch.py          --> GraphSearch method
ā”‚Ā Ā  ā””ā”€ā”€ sgc.py                  --> SGC method
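
For intuition, the SGC variant amounts to parameter-free feature propagation over the task graph followed by similarity-based retrieval. A minimal sketch, assuming precomputed task and step embeddings (the actual implementation is in trainfree/sgc.py):

# Minimal SGC-style retrieval sketch: propagate task embeddings over the task
# graph (A^k X), then rank tasks by cosine similarity to a decomposed step.
import numpy as np

def sgc_retrieve(task_emb, adj, step_emb, k=2):
    a_hat = adj + np.eye(adj.shape[0])                 # add self-loops
    a_hat = a_hat / a_hat.sum(axis=1, keepdims=True)   # row-normalize
    h = task_emb
    for _ in range(k):                                 # k propagation rounds, no learned weights
        h = a_hat @ h
    h = h / np.linalg.norm(h, axis=1, keepdims=True)
    q = step_emb / np.linalg.norm(step_emb)
    return np.argsort(-(h @ q))                        # task indices, best match first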

Reproducibility

Running scripts can be found in trainfree_script.sh.

Hint: you must first run Direct Inference to obtain an LLM's direct inference results before running SGC or GraphSearch, as sketched below.
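
A typical workflow (flag names are illustrative, not verified; see trainfree_script.sh for the exact commands):

# 1) Direct inference first (illustrative flags).
python direct.py --llm=CodeLlama-13b --dataset=huggingface
# 2) SGC (or GraphSearch) then reuses the direct-inference results.
python sgc.py --llm=CodeLlama-13b --dataset=huggingface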

Training GNNs

Code for the training-based GNNs is under the traingnn folder:

ā”œā”€ā”€ traingnn
ā”‚   ā”œā”€ā”€ gnn.py              --> GNN encoder implementation, including SGC, GCN, GAT, SAGE, GIN, and TransformerConv
ā”‚   ā”œā”€ā”€ main.py             --> Training GNN and then testing the performance
ā”‚   ā”œā”€ā”€ model.py            --> LM+GNN model
ā”‚   ā””ā”€ā”€ sampler.py          --> Sampling object to prepare training triplets `<step, positive task, negative task>`
ā”œā”€ā”€ traingnn_reproduce.sh   --> Scripts for reproducing all experimental results
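
For intuition, the triplet construction performed by sampler.py can be sketched as follows (a simplified sketch with hypothetical names; the --text_negative option below suggests negatives can also be chosen by textual similarity):

# Sketch of training-triplet construction: each decomposed step is paired with
# its ground-truth task (positive) and another task drawn at random (negative).
import random

def make_triplets(steps, gold_tasks, all_tasks):
    triplets = []
    for step, pos in zip(steps, gold_tasks):
        neg = random.choice([t for t in all_tasks if t != pos])
        triplets.append((step, pos, neg))
    return triplets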

Specifically, the core arguments of main.py are:

- lm_frozen: whether the LM's weights are frozen (1 = train the GNN only; 0 = co-train LM and GNN)
- epoch: number of training epochs
- text_negative: whether negative samples are textually similar to the positive task
- gnn_name: which GNN encoder to use (e.g., SGC, GCN, GAT, SAGE, GIN, TransformerConv)
- lr: learning rate
- batch_size: training batch size (reduce it on limited GPU memory)

# HuggingFace - GNN only
python main.py --lm_frozen=1 --epoch=10 --text_negative=1 --gnn_name=SAGE --lr=0.001

# HuggingFace - LM+GNN co-train
python main.py --lm_frozen=0 --epoch=20 --text_negative=1 --gnn_name=SAGE
# HuggingFace - LM+GNN co-train (limited GPU requires smaller batch_size)
python main.py --lm_frozen=0 --epoch=10 --text_negative=1 --gnn_name=SAGE --batch_size=256

More running scripts can be found in traingnn_reproduce.sh.

Fine-tuning LLMs

Code for fine-tuning LLMs is under the finetunellm folder:

ā”œā”€ā”€ finetunellm
ā”‚Ā Ā  ā”œā”€ā”€ inference.py         --> Direct inference with fine-tuned LLMs
ā”‚Ā Ā  ā”œā”€ā”€ main.py              --> Fine-tuning LLM
ā”‚Ā Ā  ā””ā”€ā”€ user_prompt.py       --> Instruction template

Running scripts can be found in finetunellm_script.sh; we use 2 NVIDIA A100-80G GPUs for fine-tuning LLMs.

Evaluation

evaluate.py provides evaluation of task planning results, with metrics including Node-F1, Link-F1, Node-Hallucination (both Macro and Micro), and Link-Hallucination (both Macro and Micro).
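
For reference, a set-level Node-F1 can be sketched as follows (a simplified sketch; the exact metric definitions live in evaluate.py):

# Simplified Node-F1: set-level F1 between predicted and ground-truth task nodes.
def node_f1(pred_nodes, gold_nodes):
    pred, gold = set(pred_nodes), set(gold_nodes)
    tp = len(pred & gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / len(pred), tp / len(gold)
    return 2 * precision * recall / (precision + recall)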

To facilitate reproducibility, we provide the direct inference results of CodeLLaMA-13B and Mistral-7B on the HuggingFace dataset under the prediction folder.

For evaluation, specify the LLM's name, the dataset, and the method (for example, direct denotes the LLM's direct inference):

python evaluate.py --llm=CodeLlama-13b --dataset=huggingface --method=direct

The result is as follows (NF - Node-F1, LF - Link-F1, Acc - Accuracy, NH-1 - Micro Node Hallucination Rate, NH-2 - Macro Node Hallucination Rate, LH-1 - Micro Link Hallucination Rate, LH-2 - Macro Link Hallucination Rate):

+-------------+---------------+-------+--------+--------+--------+--------+--------+--------+--------+
|   Dataset   |      LLM      |  Mode |   NF   |   LF   |  Acc   |  NH-1  |  NH-2  |  LH-1  |  LH-2  |
+-------------+---------------+-------+--------+--------+--------+--------+--------+--------+--------+
| huggingface | CodeLlama-13b | chain | 0.5755 | 0.2888 | 0.1429 | 0.1656 | 0.4306 | 0.4228 | 0.6338 |
+-------------+---------------+-------+--------+--------+--------+--------+--------+--------+--------+

Implementation of Baselines

We provide our reproduction of the baseline method GraphToken in the GraphToken folder. Since this method has no official implementation, we reproduced it from the original paper while tailoring it to the planning scenario. Running scripts are available in GraphToken/run.sh. Feel free to adjust training configurations, e.g., batch_size and eval_batch_size, according to your hardware.

Perozzi, Bryan, et al. "Let Your Graph Do the Talking: Encoding Structured Data for LLMs." arXiv preprint, 2024.
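
Conceptually, GraphToken encodes the task graph with a GNN into a few soft tokens that are prepended to the frozen LLM's input embeddings, so only the GNN and a projection are trained. A minimal sketch (shapes and names illustrative; not our exact reproduction):

# GraphToken-style soft prompting: pool GNN node states into a graph vector,
# project it to a few LM-embedding-sized "graph tokens", and prepend them.
import torch
import torch.nn as nn

class GraphTokenPrefix(nn.Module):
    def __init__(self, gnn, gnn_dim, lm_dim, num_tokens=8):
        super().__init__()
        self.gnn = gnn  # any node-level GNN encoder
        self.proj = nn.Linear(gnn_dim, lm_dim * num_tokens)
        self.num_tokens, self.lm_dim = num_tokens, lm_dim

    def forward(self, node_feats, edge_index, token_embeds):
        g = self.gnn(node_feats, edge_index).mean(dim=0)  # mean-pool node states
        prefix = self.proj(g).view(self.num_tokens, self.lm_dim)
        prefix = prefix.unsqueeze(0).expand(token_embeds.size(0), -1, -1)
        return torch.cat([prefix, token_embeds], dim=1)   # [batch, prefix+seq, dim]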

Acknowledgement

We sincerely thank the following repositories for their valuable insights and contributions to our paper and implementation:


šŸ“® If you still have questions, you can open an issue or contact us via e-mail: xxwu@se.cuhk.edu.hk