CellTypeGraph Benchmark

CellTypeGraph is a new graph benchmark for node classification.

Benchmark Overview

The dataset

The benchmark is distilled from 84 Arabidopsis ovule segmentations, and the task is to classify each cell by its specific cell type. We represent each specimen as a graph in which each cell is a node and any two adjacent cells are connected by an edge. This Python package comes with a PyTorch DataLoader and pre-computed node and edge features; the features can be fully customized and modified. The source data for the CellTypeGraph Benchmark can also be downloaded manually from zenodo.org.
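To make the graph representation concrete, here is a minimal sketch of how a cell-adjacency graph could be derived from a labelled segmentation. It is an illustration only, not the pipeline used to build the benchmark: it assumes a tiny 2D label grid (the real segmentations are volumetric), treats two cells as adjacent when their pixels share a side, and assumes label 0 marks background. The function name `adjacency_edges` is hypothetical.

```python
from typing import List, Set, Tuple


def adjacency_edges(labels: List[List[int]], background: int = 0) -> Set[Tuple[int, int]]:
    """Return undirected edges (a, b) with a < b between labels that touch in a 2D grid."""
    found = set()
    rows, cols = len(labels), len(labels[0])
    for r in range(rows):
        for c in range(cols):
            a = labels[r][c]
            # Looking only right and down visits every side-sharing pixel pair once.
            for dr, dc in ((0, 1), (1, 0)):
                rr, cc = r + dr, c + dc
                if rr < rows and cc < cols:
                    b = labels[rr][cc]
                    # Assumption: `background` is not a cell and gets no node or edge.
                    if a != b and background not in (a, b):
                        found.add((min(a, b), max(a, b)))
    return found


# Toy segmentation: each integer is the id of the cell covering that pixel.
toy_segmentation = [
    [1, 1, 2],
    [1, 3, 2],
    [4, 3, 3],
]

nodes = sorted({label for row in toy_segmentation for label in row})
edges = sorted(adjacency_edges(toy_segmentation))
print(nodes)  # [1, 2, 3, 4]
print(edges)  # [(1, 2), (1, 3), (1, 4), (2, 3), (3, 4)]
```

Cells 2 and 4 never touch in the toy grid, so no edge connects them; every other pair of cells shares at least one pixel side.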

Evaluation

The package also includes evaluation code and examples.

Results

To see our most recent results, check out the leaderboard page in the repository wiki.

Requirements

Dependencies

Optional Dependencies (for running the examples):

Install CellTypeGraph Benchmark using conda

Run one of the following commands, depending on your hardware (CUDA 11.3, CUDA 10.2, or CPU only):

conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cudatoolkit=11.3
conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cudatoolkit=10.2
conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cpuonly

Simple training example

Basic usage

from ctg_benchmark.loaders import get_cross_validation_loaders
loaders_dict = get_cross_validation_loaders(root='./ctg_data/')

Here, loaders_dict is a dictionary containing 5 pairs of training and validation data loaders, one pair per cross-validation split.

for split, loader_dict in loaders_dict.items():
    train_loader = loader_dict['train'] 
    val_loader = loader_dict['val']
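One possible way to use the per-split loaders is to train and validate once per split and then report the mean and standard deviation of the validation scores. The sketch below shows only that aggregation pattern. `cross_validate` and `dummy_run_split` are hypothetical helpers, and `fake_loaders_dict` is a stand-in built from plain lists (with integer keys chosen for illustration) so that the example runs without downloading the dataset.

```python
from statistics import mean, stdev
from typing import Callable, Dict, Tuple


def cross_validate(loaders_dict: Dict, run_split: Callable) -> Tuple[float, float]:
    """Call run_split(train_loader, val_loader) on every split; return (mean, std) of the scores."""
    scores = []
    for split, loader_dict in loaders_dict.items():
        scores.append(run_split(loader_dict['train'], loader_dict['val']))
    return mean(scores), stdev(scores)


def dummy_run_split(train_loader, val_loader) -> float:
    # Hypothetical placeholder: a real version would train a model on train_loader
    # and then score it on val_loader. Here val_loader is a list of
    # (prediction, target) pairs and the score is the fraction that match.
    correct = sum(pred == target for pred, target in val_loader)
    return correct / len(val_loader)


# Stand-in for the dictionary returned by get_cross_validation_loaders (assumed layout).
fake_loaders_dict = {
    split: {'train': [], 'val': [(0, 0), (1, 1), (2, split % 3), (3, 3)]}
    for split in range(5)
}

mean_score, std_score = cross_validate(fake_loaders_dict, dummy_run_split)
print(f"validation score: {mean_score:.3f} +/- {std_score:.3f}")
```

With the real loaders, `dummy_run_split` would be replaced by a function that trains a model and returns a validation metric.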

# Alternatively, load a single train/val/test split
from ctg_benchmark.loaders import get_split_loaders
loader = get_split_loaders(root='./ctg_data/')
print(loader['train'], loader['val'], loader['test'])

import torch
from ctg_benchmark.evaluation import NodeClassificationMetrics, aggregate_class

eval_metrics = NodeClassificationMetrics(num_classes=9)

# Random predictions and targets, used only to demonstrate the evaluation calls
predictions = torch.randint(9, (1000,))
target = torch.randint(9, (1000,))
results = eval_metrics.compute_metrics(predictions, target)
class_average_accuracy, _ = aggregate_class(results['accuracy_class'], index=7)

print(f"global accuracy: {results['accuracy_micro']: .3f}")
print(f"class average accuracy: {class_average_accuracy: .3f}")
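The two numbers printed above can differ considerably when classes are imbalanced. The following pure-Python sketch illustrates the usual definitions of global (micro) accuracy and class-average accuracy on a toy example. It is not the implementation behind `NodeClassificationMetrics` or `aggregate_class`, and it does not model the `index` argument; `accuracy_summary` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def accuracy_summary(preds: List[int], targets: List[int]) -> Tuple[float, float, Dict[int, float]]:
    """Return (global accuracy, class-average accuracy, per-class accuracy)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, t in zip(preds, targets):
        total[t] += 1
        correct[t] += int(p == t)
    per_class = {c: correct[c] / total[c] for c in total}
    # Global accuracy weights every node equally.
    global_acc = sum(correct.values()) / len(targets)
    # Class-average accuracy weights every class equally, regardless of its size.
    class_avg = sum(per_class.values()) / len(per_class)
    return global_acc, class_avg, per_class


# Imbalanced toy labels: six nodes of class 0, two nodes of class 1.
toy_targets = [0, 0, 0, 0, 0, 0, 1, 1]
toy_preds = [0, 0, 0, 0, 0, 0, 1, 0]

global_acc, class_avg, per_class = accuracy_summary(toy_preds, toy_targets)
print(f"global accuracy: {global_acc:.3f}")        # 0.875
print(f"class average accuracy: {class_avg:.3f}")  # 0.750
```

A single mistake on the rare class lowers the class-average accuracy much more than the global accuracy, which is why reporting both is informative for imbalanced cell types.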

Advanced usage examples

Reproducibility

Cite

@inproceedings{cerrone2022celltypegraph,
  title={CellTypeGraph: A New Geometric Computer Vision Benchmark},
  author={Cerrone, Lorenzo and Vijayan, Athul and Mody, Tejasvinee and Schneitz, Kay and Hamprecht, Fred A},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={20897--20907},
  year={2022}
}