QGFN: Controllable Greediness with Action Values

Description

This repository contains the code to run the experiments and visualize the results highlighted in the paper QGFN: Controllable Greediness with Action Values.

Overview

Our codebase builds on a fork of the public Recursion gflownet repo, which provides the environment setup to run the gflownet framework on graph domains. Our main edits to the forked repo are found in the following files

Setup

To set up the project, run the following commands to install the required base packages:

pip install --no-index torch-scatter -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
pip install --no-index torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu116.html
pip install --no-index torch-cluster -f https://data.pyg.org/whl/torch-1.13.1+cu116.html

Then, install the gflownet package from local source:

pip install -e . --find-links https://data.pyg.org/whl/torch-1.13.1+cu116.html
# It may take a while to build wheels for `torch-cluster`, `torch-scatter` & `torch-sparse`

Alternatively, follow the setup guidelines from the public Recursion gflownet repo.

How to run

After setting up the virtual environment, you can test out QGFN by running python interactive_script.py. The script has three flags, one for each of the three QGFN variants. Setting one of them to true runs the corresponding variant.

Usage

The current project supports training a mixture policy from scratch and visualizing the results of one or multiple runs in the same plots.

Training from scratch

Our setup is designed for jobs running under the Slurm workload manager. To train a mixture policy from scratch, first edit the utils/template.sh file to customize the Slurm executable script for your GPU environment. Next, to generate the executables for a given job, edit the hyperparameters in the main function of gen.py and run the file from the project root directory (python gen.py), as in the example below.

Example

First, select the task on which you'd like to train the network. Next, set the hyperparameters in the gen.py file. The main hyperparameters of interest are as follows:

BASE_HPS: Config = {
    "log_dir": 'path.to.log.directory',
    "num_training_steps": 10000,
    ...
    "cond": {
        "temperature": {
            "sample_dist": "constant",
            "dist_params": 32.0,        # Set reward exponent beta
        }
    },
    "algo": {
        "p_greedy_sample": True,        # Set to true to run p-greedy sampling
        "p_of_max_sample": False,       # Set to true to run p-of-max sampling
        "p_quantile_sample": False,     # Set to true to run p-quantile sampling
        "p": 0.6,                       # Set value of `p` to control greediness
        "dqn_n_step": 30,               # Set n-step returns for Q learning
        ...
    },
    "task": {
        "qm9": {
            "h5_path": "path.to.dataset/qm9.h5",                 # Set path to qm9 dataset
            "model_path": "path.to.model/mxmnet_gap_model.pt"    # Set path to mxmnet proxy model for qm9 task reward
        },
        "bitseq": {
            "variant": "prepend-append",                         # Set variant of bit sequence generation mode as `prepend-append` or `autoregressive`
            "modes_path": "path.to.ref.sequences/modes.pkl",     # Set path to the pre-determined set of modes for the task 
            ...
        }
    }
}
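To make the three sampling flags and the role of `p` more concrete, here is a minimal sketch of how each variant could combine a forward policy P_F with action values Q for a single state. It follows the variant descriptions in the QGFN paper as we read them and is not the repository's implementation. The function names, the nearest-rank quantile rule, and the assumptions that Q values are non-negative and P_F is strictly positive are ours.

```python
from typing import List, Sequence


def _normalize(weights: Sequence[float]) -> List[float]:
    """Rescale non-negative weights so they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def p_greedy(p_f: Sequence[float], q: Sequence[float], p: float) -> List[float]:
    """Mixture: with probability p take argmax Q, otherwise sample from P_F."""
    best = max(range(len(q)), key=lambda i: q[i])
    return [(1.0 - p) * pf + (p if i == best else 0.0) for i, pf in enumerate(p_f)]


def p_of_max(p_f: Sequence[float], q: Sequence[float], p: float) -> List[float]:
    """Mask actions with Q < p * max(Q), then renormalize P_F (assumes Q >= 0)."""
    cutoff = p * max(q)
    return _normalize([pf if qi >= cutoff else 0.0 for pf, qi in zip(p_f, q)])


def p_quantile(p_f: Sequence[float], q: Sequence[float], p: float) -> List[float]:
    """Mask actions with Q below the p-quantile of Q, then renormalize P_F."""
    ranked = sorted(q)
    # Nearest-rank quantile; this particular rule is an assumption of the sketch.
    cutoff = ranked[min(int(p * len(ranked)), len(ranked) - 1)]
    return _normalize([pf if qi >= cutoff else 0.0 for pf, qi in zip(p_f, q)])


# Toy state with three actions.
p_f = [0.5, 0.3, 0.2]
q = [1.0, 4.0, 2.0]
greedy_probs = p_greedy(p_f, q, p=0.6)
of_max_probs = p_of_max(p_f, q, p=0.6)
quantile_probs = p_quantile(p_f, q, p=0.6)
```

In all three cases `p = 0` recovers sampling from P_F alone, and larger values of `p` shift probability mass toward high-Q actions.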

Finally, you may set a list of hyperparameters in the main() method of gen.py for grid search. To generate the runs, simply run the file:

python gen.py
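As a rough picture of what a grid search over hyperparameters involves, the sketch below expands a base config and a dictionary of candidate values into one config per combination. The helper `expand_grid` and the dotted-key convention are hypothetical and only illustrate the idea. The actual logic lives in gen.py and may differ.

```python
import copy
from itertools import product
from typing import Any, Dict, List


def expand_grid(base: Dict[str, Any], grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Return one deep-copied config per combination of the grid values.

    Keys of `grid` are dotted paths into the nested config, e.g. "algo.p".
    """
    keys = list(grid)
    configs = []
    for values in product(*(grid[k] for k in keys)):
        cfg = copy.deepcopy(base)  # never mutate the shared base config
        for dotted_key, value in zip(keys, values):
            *parents, leaf = dotted_key.split(".")
            node = cfg
            for name in parents:
                node = node.setdefault(name, {})
            node[leaf] = value
        configs.append(cfg)
    return configs


base_hps = {"num_training_steps": 10000, "algo": {"p": 0.6, "dqn_n_step": 30}}
grid = {"algo.p": [0.2, 0.6], "algo.dqn_n_step": [1, 30]}
configs = expand_grid(base_hps, grid)
```

Here two values for each of two hyperparameters yield four configs, which mirrors the one-run-per-combination behavior described for gen.py.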

Running gen.py generates, for every hyperparameter combination you specified, a corresponding run folder at jobs/<current_date>/<run_id-current_time>. This folder will contain the following files:

To submit the job to Slurm, run the command specified in howto.txt from the run config directory. For example, if the command is sbatch --array=0-4 run.sh config, then run the following:

cd jobs/<current_date>/<run_id-current_time>
sbatch --array=0-4 run.sh config

Plotting results

This repository also supports a variety of plotting functions to visualize and compare the results across multiple runs. The main plotting script is vis.py and we currently support the following plot types:

To produce a given plot, provide the paths to the config folders (see above) of the runs you want to compare in the runs list in the main function of vis.py, along with a name and color for each run. For example, to compare the results of two runs with paths jobs/2023-11-19/001-01-54-02 and jobs/2023-12-09/002-23-23-29, you could write:

runs = [
    {
        'path': 'jobs/2023-11-19/001-01-54-02',
        'name': 'run1',
        'color': 'blue'
    },
    {
        'path': 'jobs/2023-12-09/002-23-23-29',
        'name': 'run2',
        'color': 'red'
    },
]

Then, specify the plot type you want to produce in the plot_type variable. Available plot types are listed in the PlotType enum in utils/plotting.py. For example, to produce a number-of-modes plot, you would write:

plot_type = PlotType.NUMBER_OF_MODES

Finally, you can override the default plot parameters in the plot_params dictionary. For example, if you want to change the title of the plot, you could write:

plot_params = {
    **DEFAULT_SHARED_PARAMS,
    **DEFAULT_UNIQUE_PARAMS[plot_type],

    # Add your custom parameters here.
    'title': 'My custom title'
}
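The order of entries in this dictionary matters: with Python's `**` unpacking, later keys override earlier ones, so custom parameters placed last take precedence over both sets of defaults. The short sketch below shows this with stand-in dictionaries. Their contents and the string plot type are hypothetical and are not the actual defaults in vis.py.

```python
# Hypothetical stand-ins for the defaults defined in vis.py.
DEFAULT_SHARED_PARAMS = {"title": "Default title", "dpi": 300}
DEFAULT_UNIQUE_PARAMS = {"number_of_modes": {"ylabel": "Number of modes"}}
plot_type = "number_of_modes"  # stand-in for a PlotType enum member

plot_params = {
    **DEFAULT_SHARED_PARAMS,
    **DEFAULT_UNIQUE_PARAMS[plot_type],
    # Listed last, so this overrides the shared default title.
    "title": "My custom title",
}
```

Keys that are not overridden, such as `dpi` here, keep their default values.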

The full specification of modifiable parameters for each plot type can be found in the DEFAULT_UNIQUE_PARAMS and DEFAULT_SHARED_PARAMS dictionaries in vis.py.

Once you have specified the runs, plot type, and plot parameters, you must run the script from the project root directory, and your plot will be saved in the results folder under a unique identifier based on the current date and plot type.

python vis.py

Plotting for other task environments

The above script generation and plotting code has been written to support the standard molecules environment in the gflownet package. However, the plotting code can be easily adapted to support other environments by inheriting from the PlottableRunObject class in utils/plotting.py. This class defines the interface for a run object that can be plotted, and declares a few simple methods to extract the relevant data from the runs into a universal format, which is then provided to the downstream PlotConfig class for standard plotting. The methods that need to be implemented are:

def load_raw_data(self, sqlite_cols: List[str]):
    """
    Loads the raw run data from the runs. Please override this method 
    in your child class depending on your data format. This implementation supports
    the molecules environment which reads data from a sqlite database using the 
    specified columns.
    """
    pass

def get_average_reward(self, save_df: bool=False) -> pd.DataFrame:
    """Returns a pandas dataframe with the average reward for each worker at each step"""
    pass

@staticmethod
def is_new_mode(obj, modes, sim_threshold=0.7) -> bool:
    """Returns True if obj is a new mode, False otherwise"""
    pass

def get_modes(self, min_reward: float=0.9, sim_threshold: float=0.7, save_df: bool=False) -> pd.DataFrame:
    """Returns a pandas dataframe with the number of modes with reward > min_reward and 
    similarity > sim_threshold for each worker at each step"""
    pass

def get_top_k_reward(self):
    """Returns a pandas dataframe with the top k reward for each worker at each step"""
    pass

def get_top_k_similarity(self):
    pass

def get_reward_distribution(self):
    pass
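As an illustration of what an override might look like for a non-molecule environment, the sketch below implements an `is_new_mode` check for bit sequences. It is written as a free function so that it runs on its own. In a subclass of PlottableRunObject it would be the static method with the same signature. The similarity measure (fraction of matching positions) and the rule that an object is a new mode only if its similarity to every known mode is below `sim_threshold` are assumptions of this sketch, not the repository's definition.

```python
from typing import Sequence


def hamming_similarity(a: Sequence, b: Sequence) -> float:
    """Fraction of positions at which two equal-length sequences agree."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    if len(a) == 0:
        return 1.0
    return sum(x == y for x, y in zip(a, b)) / len(a)


def is_new_mode(obj: Sequence, modes: Sequence[Sequence], sim_threshold: float = 0.7) -> bool:
    """Return True if obj is sufficiently dissimilar from every known mode."""
    return all(hamming_similarity(obj, mode) < sim_threshold for mode in modes)


known_modes = ["0000"]
near_duplicate = is_new_mode("0001", known_modes)  # similarity 0.75, not a new mode
distinct = is_new_mode("1110", known_modes)        # similarity 0.25, a new mode
```

With an empty list of known modes, every object counts as new under this rule.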

Updates

2024-05-15