Home

This is the code used in the ICML 2020 paper; a newer version of the code was used in the AAAI 2022 paper.

GOSDT Documentation

Implementation of Generalized Optimal Sparse Decision Tree.

Table of Contents


Usage

Guide for end-users who want to use the library without modification.

Describes how to install and use the library as a stand-alone command-line program or as an embedded extension in a larger project. Embedding is currently supported as a Python extension.

Installing Dependencies

Refer to Dependency Installation

As a Stand-Alone Command Line Program

Installation

./autobuild --install

Executing the Program

gosdt dataset.csv config.json
# or 
cat dataset.csv | gosdt config.json >> output.json

For an example dataset file, refer to experiments/datasets/compas/binned.csv. For an example configuration file, refer to experiments/configurations/compas.json. For documentation on the configuration file, refer to the Configuration section below.

As a Python Library with C++ Extensions

Build and Installation

./autobuild --install-python

If you have multiple Python installations, please make sure to build and install using the same Python installation as the one intended for interacting with this library.
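One way to check which interpreter will be used is sketched below (adjust the interpreter name, e.g. `python3`, to your setup):

```shell
# Print the interpreter that `python3` resolves to; build and install
# the extension with this same interpreter.
python3 -c 'import sys; print(sys.executable)'

# Compare against the interpreter on your PATH.
which python3
```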

Importing the C++ Extension

import gosdt

with open("data.csv", "r") as data_file:
    data = data_file.read()

with open("config.json", "r") as config_file:
    config = config_file.read()


print("Config:", config)
print("Data:", data)

gosdt.configure(config)
result = gosdt.fit(data)

print("Result: ", result)
print("Time (seconds): ", gosdt.time())
print("Iterations: ", gosdt.iterations())
print("Graph Size: ", gosdt.size())

Importing the Extension with the Local Python Wrapper

import pandas as pd
import numpy as np
from model.gosdt import GOSDT

dataframe = pd.read_csv("experiments/datasets/monk_2/data.csv")

X = dataframe[dataframe.columns[:-1]]
y = dataframe[dataframe.columns[-1:]]

hyperparameters = {
    "regularization": 0.1,
    "time_limit": 3600,
    "verbose": True,
}

model = GOSDT(hyperparameters)
model.fit(X, y)
print("Execution Time: {}".format(model.time))

prediction = model.predict(X)
training_accuracy = model.score(X, y)
print("Training Accuracy: {}".format(training_accuracy))
print(model.tree)
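For illustration, the same wrapper can also be evaluated on held-out data. The split below is a minimal sketch using a synthetic DataFrame in place of the dataset file; the GOSDT calls, shown as comments, mirror the example above:

```python
import numpy as np
import pandas as pd

# Synthetic binary dataset standing in for a real data file;
# as above, the last column is the target.
rng = np.random.default_rng(0)
dataframe = pd.DataFrame(rng.integers(0, 2, size=(100, 5)),
                         columns=["f1", "f2", "f3", "f4", "target"])

X = dataframe[dataframe.columns[:-1]]
y = dataframe[dataframe.columns[-1:]]

# 80/20 train/test split over a shuffled index
indices = rng.permutation(len(dataframe))
cut = int(0.8 * len(dataframe))
train_idx, test_idx = indices[:cut], indices[cut:]

X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

# With the wrapper from the example above:
# model = GOSDT({"regularization": 0.1, "time_limit": 3600})
# model.fit(X_train, y_train)
# print("Test Accuracy: {}".format(model.score(X_test, y_test)))

print(X_train.shape, X_test.shape)
```

Reporting accuracy on the held-out portion, rather than the training set, gives a fairer estimate of generalization.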

Development

Guide for developers who want to use, modify and test the library.

Describes how to install and use the library with details on project structure.

Repository Structure

Installing Dependencies

Refer to Dependency Installation

Build Process

For a full list of build options, run ./autobuild --help


Configuration

Details on the configuration options.

gosdt dataset.csv config.json
# or
cat dataset.csv | gosdt config.json

Here the file config.json is optional. There is a default configuration which will be used if no such file is specified.

Configuration Description

The configuration file is a JSON object and has the following structure and default values:

{
  "balance": false,
  "cancellation": true,
  "look_ahead": true,
  "similar_support": true,
  "feature_exchange": true,
  "continuous_feature_exchange": true,
  "rule_list": false,

  "diagnostics": false,
  "verbose": false,

  "regularization": 0.05,
  "uncertainty_tolerance": 0.0,
  "upperbound": 0.0,

  "model_limit": 1,
  "precision_limit": 0,
  "stack_limit": 0,
  "tile_limit": 0,
  "time_limit": 0,
  "worker_limit": 1,

  "costs": "",
  "model": "",
  "profile": "",
  "timing": "",
  "trace": "",
  "tree": ""
}
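For example, a minimal configuration file that overrides only a few key parameters might look like the following (assuming, as suggested by the default configuration above, that unspecified options keep their default values):

```json
{
  "regularization": 0.1,
  "time_limit": 3600,
  "verbose": true
}
```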

Key parameters

regularization

time_limit

Flags

balance

cancellation

look_ahead

similar_support

feature_exchange

continuous_feature_exchange

diagnostics

verbose

Tuners

uncertainty_tolerance

Limits

model_limit

precision_limit

stack_limit

tile_limit

worker_limit

Files

costs

model

profile

timing

trace

tree

Optimizing Different Loss Functions

When using the Python interface (python/model/gosdt.py), additional loss functions are available. Here is a list of the implemented loss functions along with descriptions of their hyperparameters.

Accuracy

{ "objective": "acc" }

This optimizes the loss defined as the uniformly weighted number of misclassifications.

Balanced Accuracy

{ "objective": "bacc" }

This optimizes the loss defined as the number of misclassifications, adjusted for imbalanced representation of positive or negative samples.

Weighted Accuracy

{ "objective": "wacc", "w": 0.5 }

This optimizes the loss defined as the number of misclassifications, adjusted so that negative samples have a weight of w while positive samples have a weight of 1.0.
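As a quick numeric illustration (a sketch of the weighting scheme, not the library's implementation): with w = 0.5, each misclassified negative sample contributes 0.5 to the loss and each misclassified positive sample contributes 1.0:

```python
def weighted_misclassifications(y_true, y_pred, w):
    """Count misclassifications, weighting negatives by w and positives by 1.0."""
    total = 0.0
    for yt, yp in zip(y_true, y_pred):
        if yt != yp:
            total += w if yt == 0 else 1.0
    return total

# Two misclassified negatives (0.5 each) and one misclassified positive (1.0)
loss = weighted_misclassifications([0, 0, 1, 1, 0], [1, 1, 0, 1, 0], w=0.5)
print(loss)  # 2.0
```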

F-1 Score

{ "objective": "f1", "w": 0.9 }

This optimizes the loss defined as the F-1 score of the model's predictions.

Area under the Receiver Operating Characteristics Curve

{ "objective": "auc" }

This maximizes the area under the ROC curve formed by varying the prediction of the leaves.

Partial Area under the Receiver Operating Characteristics Curve

{ "objective": "pauc", "theta": 0.1 }

This maximizes the partial area under the ROC curve formed by varying the prediction of the leaves. The area is constrained so that the false-positive rate lies in the closed interval [0, theta].
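Each of these objectives is selected through the hyperparameter dictionary passed to the Python wrapper. The snippet below is a sketch that builds and validates such a fragment, using only the objective names and extra hyperparameters listed above:

```python
import json

# Objectives listed above, mapped to the extra hyperparameters each one takes
KNOWN_OBJECTIVES = {
    "acc": [], "bacc": [], "wacc": ["w"], "f1": ["w"],
    "auc": [], "pauc": ["theta"],
}

def objective_config(objective, **params):
    """Build a JSON-serializable hyperparameter fragment for a given objective."""
    if objective not in KNOWN_OBJECTIVES:
        raise ValueError("unknown objective: {}".format(objective))
    missing = [k for k in KNOWN_OBJECTIVES[objective] if k not in params]
    if missing:
        raise ValueError("missing hyperparameters: {}".format(missing))
    return {"objective": objective, **params}

config = objective_config("pauc", theta=0.1)
print(json.dumps(config))  # {"objective": "pauc", "theta": 0.1}
```

The resulting dictionary can be merged into the hyperparameters passed to GOSDT(...), as in the wrapper example above.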


Dependencies

List of external dependencies

The following dependencies need to be installed to build the program.

Bundled Dependencies

The following dependencies are included as part of the repository, thus requiring no additional installation.

Installation

Install these using your system package manager. There are also installation scripts provided for your convenience: trainer/auto

These scripts currently support installation via brew and apt.


FAQs

If you run into any issues, consult the FAQs first.


License

Licensing information


Inquiries

For general inquiries, send an email to jimmy.projects.lin@gmail.com