TorchNTK

An Arbitrary PyTorch Architecture Neural Tangent Kernel Library

This code was developed to bridge a gap in NTK computation before the release of PyTorch 1.11. Now that PyTorch 1.11 is out, I advise you to take a look at functorch's NTK page, which will generally see better development and improvements than this repo. In other words, we do not expect to support this repo moving forward.
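For context, here is a minimal sketch of that functorch approach (adapted from the pattern in functorch's NTK tutorial; the toy network, data, and helper name fnet_single are placeholders of mine, not part of this library):

import torch
from functorch import make_functional, vmap, jacrev

#toy network and batch standing in for your own model and data
net = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))
x = torch.randn(8, 10)

fnet, params = make_functional(net)

def fnet_single(params, xi):
    #evaluate one sample at a time so vmap can batch over the data dimension
    return fnet(params, xi.unsqueeze(0)).squeeze(0)

#per-sample Jacobians w.r.t. every parameter, flattened per parameter
jac = vmap(jacrev(fnet_single), (None, 0))(params, x)
jac = [j.flatten(2) for j in jac]

#empirical NTK: sum over parameters of the Jacobian Gram matrices
ntk = torch.stack([torch.einsum('Naf,Mbf->NMab', j, j) for j in jac]).sum(0)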

Installation

git clone https://github.com/pnnl/torchntk
export PYTHONPATH="${PYTHONPATH}:/my/path/TorchNTK/"
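A quick sanity check that the path export above worked (illustrative; run from any directory):

import torchntk
print(torchntk.__file__)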

Basic Usage


import torchntk
import torch

DEVICE = 'cpu' #or 'cuda', let's say

model = Pytorch_Model() #Any architecture -- BUT it must terminate in a single neuron
model.to(DEVICE)

Y = model(X) #for some input data, X

NTK_components = torchntk.autograd.autograd_components_ntk(model,Y)
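Each component is the contribution of a single parameter tensor to the empirical NTK, i.e. Θ(x_i, x_j) = Σ_θ (∂f(x_i)/∂θ) · (∂f(x_j)/∂θ), so summing every component recovers the full kernel. A hedged sketch continuing the example above (assuming the components come back as a dict or list of (N, N) tensors):

#sum the per-parameter contributions into the full (N, N) empirical NTK
parts = NTK_components.values() if isinstance(NTK_components, dict) else NTK_components
full_NTK = torch.stack(list(parts), dim=0).sum(dim=0)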

Alternatively, a generally faster implementation exists if torch.vmap is available (currently in PyTorch nightly builds only):


import torchntk
import torch
from torch.utils.data import DataLoader, TensorDataset

DEVICE = 'cuda'

model = Pytorch_Model() #Any architecture -- BUT it must terminate in a single neuron
model.to(DEVICE)

xloader = DataLoader(TensorDataset(My_data,My_targets),batch_size=64, shuffle=False)

NTK_components = torchntk.autograd.vmap_ntk_loader(model,xloader)

Finally, if you are using a fully connected network (a network composed only of torch.nn.Linear layers), you can use this last method, which is typically much faster:

import torchntk
import torch
import math

DEVICE = 'cuda'

def activation(X):
    return torch.tanh(X)
	
def d_activation(X):
    return torch.cosh(X)**-2

class MLP(torch.nn.Module):
    def __init__(self,):
        super(MLP, self).__init__()
        self.d1 = torch.nn.Linear(784,100,bias=True) 
        self.d2 = torch.nn.Linear(100,100,bias=True)
        self.d3 = torch.nn.Linear(100,1,bias=True) 
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0)) / math.sqrt(100)
        x_2 = activation(self.d2(x_1)) / math.sqrt(100)
        x_3 = activation(self.d3(x_2)) / math.sqrt(1)
        return x_3, x_2, x_1, x_0 


model = MLP()
model.to(DEVICE)

x_3, x_2, x_1, x_0 = model(X) #for some data, X

Xs = [x_0.T.detach(),
      x_1.T.detach(),
      x_2.T.detach()]

layers = [model.d1,
          model.d2,
          model.d3]

#these must match each layer's output width
ds_int = [100, 100, 1]

#these must match what you divided each layer's output by, squared.
#i.e., if you didn't divide each layer by anything, this should be all ones.
ds_float = [100.0, 100.0, 1.0]


config = {'Xs':Xs,
          'layers':layers,
          'ds_int':ds_int,
          'ds_float':ds_float,
          'dactivation_t':d_activation}
 
components = torchntk.explicit.explicit_ntk(**config)
#components is a list of torch.Tensor objects representing each component of
#the NTK from each parameterized operation, in reverse order. Meaning,
#components[0] is the outermost layer's weight matrix NTK component,
#components[1] is the outermost layer's bias vector NTK component,
# ...
#components[-1] is the first layer's bias vector NTK component.

#to get the full NTK, simply sum the components in the list.
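For example, since each component in the list is an (N, N) tensor, the full kernel (here called full_NTK for illustration) is:

full_NTK = torch.stack(components, dim=0).sum(dim=0)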

Logging with TensorBoard

Check the tensorboard.ipynb notebook.

Once installed, TensorBoard can be started on the command line with:

tensorboard --logdir=LOGDIR
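The notebook has the details; as a rough sketch (the writer, tag names, and stand-in NTK below are illustrative, not part of torchntk's API), NTK-derived metrics can be written with torch.utils.tensorboard.SummaryWriter and then viewed with the command above:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='LOGDIR')  #same directory passed to --logdir

#stand-in for a full (N, N) NTK computed as in the examples above
NTK = torch.eye(64)

eigenvalues = torch.linalg.eigvalsh(NTK)
writer.add_scalar('NTK/lambda_max', eigenvalues[-1].item(), global_step=0)
writer.add_histogram('NTK/eigenvalue_spectrum', eigenvalues, global_step=0)
writer.close()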

Possible Metrics of Interest

The condition number is the ratio of the maximum eigenvalue of the NTK to the minimum eigenvalue of the NTK. It is negatively correlated with model performance.
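An illustrative sketch (the random matrix here is a stand-in for an NTK assembled as in the examples above), using torch.linalg.eigvalsh, which returns eigenvalues in ascending order:

import torch

#stand-in for the assembled (N, N) empirical NTK
J = torch.randn(64, 256)
NTK = J @ J.T

eigenvalues = torch.linalg.eigvalsh(NTK)              #ascending order
condition_number = eigenvalues[-1] / eigenvalues[0]   #lambda_max / lambda_min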

Credit

"torchntk.autograd.old_autograd_ntk" was directly adatapted from the TENAS group's code, available here , and you can view their paper on neural architecture seach here; authored by Chen, Wuyang and Gong, Xinyu and Wang, Zhangyang and titled: "Neural Architecture Search on ImageNet in Four GPU Hours: A Theoretically Inspired Perspective"

Some backpropagation functions were originally copied and then heavily modified from this article by Pierre Jaumier, available here.

I've also included some utility functions that I directly copied from the PyTorch source; therefore, their license clause is included in ours.

Experimental autograd operations were adapted from web pages during the pre-release of PyTorch 1.11; now that PyTorch 1.11 is released, I advise you to take a look at functorch's NTK page.

Software TODO (or how you can contribute)