SAGA-based GPU solver for elastic net problems

A package for fitting sparse linear models at deep learning scales. This work was initially created and described in our paper "Leveraging Sparse Linear Layers for Debuggable Deep Networks" with Eric Wong, Shibani Santurkar, and Aleksander Madry. The main repository for the paper can be found at https://github.com/MadryLab/DebuggableDeepNetworks.

This package implements a SAGA-based solver in PyTorch for fitting sparse linear models with elastic net regularization. It combines the path algorithm used by glmnet with a minibatch variant of the SAGA algorithm, making it possible to solve the elastic net at ImageNet scales, where coordinate descent-based elastic net solvers struggle.
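
Concretely, the path algorithm solves a sequence of elastic net problems over a geometrically decreasing grid of regularization strengths, warm-starting each problem from the previous solution. The sketch below (an illustration, not the package's internal code) shows how such a grid relates to the epsilon and k arguments of the solver documented further down:

import math
import torch

def regularization_path(lam_max, epsilon=0.001, k=100):
    # Geometric grid of k regularization values, decreasing from the largest
    # useful regularization lam_max down to epsilon * lam_max (glmnet-style).
    return torch.logspace(math.log10(lam_max),
                          math.log10(epsilon * lam_max), steps=k)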

Citation

If you find this solver useful in your work, please consider citing our paper:

@article{wong2021leveraging,
  title={Leveraging Sparse Linear Layers for Debuggable Deep Networks},
  author={Wong, Eric and Santurkar, Shibani and M{\k{a}}dry, Aleksander},
  journal={arXiv preprint arXiv:2105.04857},
  year={2021}
}

Installation

This package is on PyPI. Install it with pip install glm_saga. The only requirement is PyTorch. Alternatively, the entire solver is implemented in glm_saga/elasticnet.py and can be copied locally if desired.

Usage and documentation

The following function is the main interface for fitting a sequence of sparse linear models. A barebones example that fits a sparse linear model on top of a ResNet18 can be found in resnet18_example.py.

def glm_saga(linear, loader, max_lr, nepochs, alpha, 
             table_device=None, precompute=None, group=False, 
             verbose=None, state=None, n_ex=None, n_classes=None, 
             tol=1e-4, epsilon=0.001, k=100, checkpoint=None, 
             solver='saga', do_zero=True, lr_decay_factor=1, metadata=None, 
             val_loader=None, test_loader=None, lookbehind=None, 
             family='multinomial'):
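
For example, a minimal invocation might look like the following sketch (with made-up data and illustrative hyperparameters; in practice X would hold deep feature representations, and IndexedTensorDataset is one of the helpers documented below):

import torch
from torch import nn
from torch.utils.data import DataLoader
from glm_saga.elasticnet import glm_saga, IndexedTensorDataset

# Synthetic features and labels standing in for deep representations.
X = torch.randn(1000, 128)
y = torch.randint(0, 10, (1000,))

# The solver expects batches of the form (X, y, index).
loader = DataLoader(IndexedTensorDataset(X, y), batch_size=128, shuffle=True)

# Fit a path of sparse 10-class linear models (alpha near 1 is mostly lasso).
linear = nn.Linear(128, 10)
glm_saga(linear, loader, max_lr=0.1, nepochs=100, alpha=0.99)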

Required arguments

- linear: the PyTorch nn.Linear module to fit
- loader: a dataloader whose batches are (X, y, index) triples; the helper functions below can add the index to an existing dataset or dataloader
- max_lr: the starting learning rate for the solver
- nepochs: the maximum number of epochs to run the solver for
- alpha: the elastic net mixing parameter (alpha=1 corresponds to the lasso penalty, alpha=0 to the ridge penalty)

Optional arguments

All remaining keyword arguments in the signature above are optional.

For extremely large datasets like ImageNet, it may be desirable to avoid the multiple initial passes over the dataloader that only compute dataset statistics. This can be done by supplying the metadata argument, which has the following form:

metadata = {
    'X' : {
        'mean': 0, 
        'std': 1
    },
    'y' : {
        'mean': 0, 
        'std': 1
    }, 
    'max_reg': {
        'grouped': 0.5, 
        'ungrouped': 0.5
    }
}

Any metadata supplied through this argument will not be recomputed. Not all fields need to be specified (e.g., it is possible to supply only the mean and standard deviation, and still perform one pass to calculate the maximum regularization).
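
Continuing the minimal sketch from above (and assuming omitted fields are simply left out of the dictionary), supplying only the input and output statistics looks like:

# Supply known statistics; the solver still performs one pass to compute
# the maximum regularization, since 'max_reg' is not given.
metadata = {
    'X': {'mean': 0.0, 'std': 1.0},
    'y': {'mean': 0.0, 'std': 1.0}
}
glm_saga(linear, loader, max_lr=0.1, nepochs=100, alpha=0.99, metadata=metadata)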

Additional helper functions

The package also implements several helper functions for adapting datasets to the format required by the solver, such as adding example indices and normalizing the data.

Adding indices to datasets and dataloaders

class IndexedTensorDataset(TensorDataset):
    def __init__(self, *tensors):

class IndexedDataset(Dataset):
    def __init__(self, ds, sample_weight=None):

def add_index_to_dataloader(loader, sample_weight=None):
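
For example (a sketch), each helper produces batches of the form (X, y, index), either by constructing an indexed dataset directly or by wrapping an existing dataloader:

import torch
from torch.utils.data import DataLoader, TensorDataset
from glm_saga.elasticnet import IndexedTensorDataset, add_index_to_dataloader

X = torch.randn(100, 16)
y = torch.randint(0, 3, (100,))

# Option 1: build an indexed dataset from tensors.
indexed_loader = DataLoader(IndexedTensorDataset(X, y), batch_size=32)

# Option 2: wrap an existing dataloader so its batches also yield indices.
plain_loader = DataLoader(TensorDataset(X, y), batch_size=32)
indexed_loader = add_index_to_dataloader(plain_loader)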

Normalizing datasets

Sometimes a PyTorch dataset or dataloader is not easy to normalize directly. In this case, we can construct a normalizing PyTorch module and pass it into the solver via the precompute argument.

class NormalizedRepresentation(nn.Module): 
    def __init__(self, loader, model=None, do_tqdm=True, mean=None, std=None, metadata=None, device='cuda'): 
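
As a sketch (reusing the loader from the earlier example, with a stand-in feature extractor; the precompute keyword is taken from the signature above):

import torch
from torch import nn
from glm_saga.elasticnet import NormalizedRepresentation, glm_saga

# A stand-in feature extractor; in practice this is a pretrained deep network.
feature_extractor = nn.Sequential(nn.Linear(128, 64), nn.ReLU()).eval()

# Estimates the mean and std of the model's outputs over the loader,
# and normalizes them on the fly.
preprocess = NormalizedRepresentation(loader, model=feature_extractor, device='cpu')

# The linear model now acts on the 64-dimensional normalized features.
linear = nn.Linear(64, 10)
glm_saga(linear, loader, max_lr=0.1, nepochs=100, alpha=0.99, precompute=preprocess)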

Tests

A number of tests are provided in tests.py, which compare the output of the implemented solver against the corresponding solvers in scikit-learn.
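
As a rough illustration of the kind of check involved (a sketch, not the actual test code), scikit-learn's SAGA-based elastic net logistic regression fit on the same data should produce nearly identical weights:

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.randn(200, 20)
y = np.random.randint(0, 3, size=200)

# scikit-learn's SAGA solver with an elastic net penalty on the same problem;
# the weights of the corresponding glm_saga solution would be compared
# against sk.coef_ up to the solver tolerance.
sk = LogisticRegression(penalty='elasticnet', solver='saga',
                        C=1.0, l1_ratio=0.5, max_iter=5000)
sk.fit(X, y)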