Awesome
<div align = "center"> <img width=400 src="assets/kernexlogo.svg" align="center"> <h3 align="center">Differentiable Stencil computations in JAX </h3>Installation |Description |Quick example |More Examples |Benchmarking
</div>🛠️ Installation<a id="Installation"></a>
pip install kernex
📖 Description<a id="Description"></a>
Kernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap` and `jax.lax.scan` with `kscan` for general stencil computations.
⏩ Quick Example <a id="QuickExample"></a>
<div align="center"> <table> <tr> <td width="50%" align="center" > kmap </td> <td align="center" > kscan </td> </tr> <tr> <td>import kernex as kex
import jax.numpy as jnp
@kex.kmap(kernel_size=(3,))
def sum_all(x):
    """Return the sum of the entries of one size-3 window."""
    window_total = jnp.sum(x)
    return window_total


x = jnp.array([1, 2, 3, 4, 5])
print(sum_all(x))
# [ 6 9 12]
</td>
<td>
import kernex as kex
import jax.numpy as jnp
@kex.kscan(kernel_size=(3,))
def sum_all(x):
    """Return the sum of the entries of one size-3 window.

    Under ``kscan`` the windows are visited one after another.
    """
    window_total = jnp.sum(x)
    return window_total


x = jnp.array([1, 2, 3, 4, 5])
print(sum_all(x))
# [ 6 13 22]
</td>
</tr>
</table>
<table>
<tr>
<td width="50%">
`jax.vmap` is used to sum each window content.
<img src="assets/kmap_sum.png" width=400px>
</td>
<td>
`lax.scan` is used to update the array, and the window sum is calculated sequentially.
The first three rows represent the three sequential steps used to get the solution in the last row.
<img align="center" src="assets/kscan_sum.png" width=400px>
</td>
</tr>
</table>
</div>
🔢 More examples<a id="MoreExamples"></a>
<details> <summary>1️⃣ Convolution operation</summary>import jax
import jax.numpy as jnp
import kernex as kex
@jax.jit
@kex.kmap(kernel_size=(3, 3, 3), padding=("valid", "same", "same"))
def kernex_conv2d(x, w):
    """JAX channel-first conv2d with a 3x3x3 kernel_size.

    Multiplies ``x`` elementwise by the weights ``w`` and returns the
    sum of the products.
    """
    weighted = x * w
    return jnp.sum(weighted)
</details>
<details>
<summary>2️⃣ Laplacian operation</summary>
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
import jax
import jax.numpy as jnp
import kernex as kex
# `relative`=True enables relative indexing
@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def laplacian(x):
    """Combine the 3x3 window with the Laplacian stencil weights.

    The window is indexed with relative offsets; the terms are accumulated
    row offset by row offset (+1, 0, -1) in the same order as they are listed.
    """
    # row offset +1
    total = 0 * x[1, -1] + 1 * x[1, 0] + 0 * x[1, 1]
    # row offset 0 (contains the -4 weight at [0, 0])
    total = total + 1 * x[0, -1] + -4 * x[0, 0] + 1 * x[0, 1]
    # row offset -1
    total = total + 0 * x[-1, -1] + 1 * x[-1, 0] + 0 * x[-1, 1]
    return total


ones = jnp.ones([10, 10])
print(laplacian(ones))
# [[0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0., 0., 0., 0.]]
</details>
<details><summary>3️⃣ Get Patches of an array</summary>
import jax
import jax.numpy as jnp
import kernex as kex
@kex.kmap(kernel_size=(3, 3), relative=True)
def identity(x):
    """Return the cell at index [0, 0] of the kernel view.

    Similar to numba.stencil: this is the top left cell in the
    padded/unpadded kernel view, or the center cell if `relative`=True.
    """
    return x[0, 0]


# unlike numba.stencil , vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')`
@jax.jit
@kex.kmap(kernel_size=(3, 3), padding="same")
def get_3x3_patches(x):
    """Return the 3x3 kernel view unchanged (a 5x5x3x3 array for a 5x5 input)."""
    return x


values = jnp.arange(1, 26)
mat = values.reshape(5, 5)
print(mat)
# [[ 1 2 3 4 5]
# [ 6 7 8 9 10]
# [11 12 13 14 15]
# [16 17 18 19 20]
# [21 22 23 24 25]]

# get the view at array index = (0,0)
patches = get_3x3_patches(mat)
print(patches[0, 0])
# [[0 0 0]
# [0 1 2]
# [0 6 7]]
</details>
<details>
<summary>4️⃣ Linear convection </summary>
<div align ="center">
<table>
<tr>
<td> Problem setup </td> <td> Stencil view </td>
</tr>
<tr>
<td>
<img src="assets/linear_convection_init.png" width="500px">
</td>
<td>
<img src="assets/linear_convection_view.png" width="500px">
</td>
</tr>
</table>
</div>
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb
# domain extents, grid sizes and the resulting step sizes
tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax / (nt - 1), xmax / (nx - 1)
u = jnp.ones([nt, nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in-place using lax.scan.
F = kex.kscan(
    kernel_size=(3, 3),
    padding=((1, 1), (1, 1)),
    # n for time axis , i for spatial axis (optional naming)
    named_axis={0: "n", 1: "i"},
    relative=True,
)


# boundary condition as a function
def bc(u):
    """Boundary value: always 1."""
    return 1


# initial condition as a function
def ic1(u):
    """Initial value outside the square wave: always 1."""
    return 1


def ic2(u):
    """Initial value inside the square wave: always 2."""
    return 2


def linear_convection(u):
    """Update one cell from the previous time row.

    Uses the value at the same spatial index and its left neighbour, both
    taken at time index n-1, scaled by c*dt/dx.
    """
    return u["i", "n-1"] - (c * dt / dx) * (u["i", "n-1"] - u["i-1", "n-1"])


F[:, 0] = F[:, -1] = bc  # assign 1 for left and right boundary for all t

# square wave initial condition
F[:, : int((nx - 1) / 4) + 1] = F[:, int((nx - 1) / 2) :] = ic1
F[0:1, int((nx - 1) / 4) + 1 : int((nx - 1) / 2)] = ic2

# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0 [1:]
F[1:, 1:-1] = linear_convection

kx_solution = F(jnp.array(u))

# plot every 20th time row of the solution
plt.figure(figsize=(20, 7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0, xmax, nx), line)
<img src="assets/linear_convection.svg">
</details>
<details><summary>5️⃣ Gaussian blur</summary>
import jax
import jax.numpy as jnp
import kernex as kex
def gaussian_blur(image, sigma, kernel_size):
    """Blur ``image`` with a normalized 2D Gaussian kernel.

    Args:
        image: array to blur; the stencil is two-dimensional, so this is
            presumably a 2D array (TODO confirm for batched/channel inputs).
        sigma: standard deviation of the Gaussian, in cells.
        kernel_size: side length of the square kernel.

    Returns:
        The result of applying the ``kernel_size`` x ``kernel_size`` stencil
        to ``image`` with ``"same"`` padding.
    """
    # sample positions centered on zero, e.g. [-1, 0, 1] for kernel_size=3
    half_width = (kernel_size - 1) / 2.0
    x = jnp.linspace(-half_width, half_width, kernel_size)

    # Gaussian weights exp(-x^2 / (2 * sigma^2))
    w = jnp.exp(-0.5 * jnp.square(x / sigma))

    # separable kernel: the outer product gives the 2D weights, normalized to sum to 1
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        # weighted sum of one window
        return jnp.sum(x * w)

    return conv(image)
</details>
<details > <summary>6️⃣ Depthwise convolution </summary>
import jax
import jax.numpy as jnp
import kernex as kex
@jax.jit
@jax.vmap
@kex.kmap(kernel_size=(3, 3), padding=("same", "same"))
def kernex_depthwise_conv2d(x, w):
    """Return the sum of the elementwise product of ``x`` and ``w``.

    ``jax.vmap`` maps the 3x3 stencil over the leading axis of both arguments.
    """
    weighted = x * w
    return jnp.sum(weighted)


height, width, channels = 5, 5, 2
kernel = 3
x = jnp.arange(1, height * width * channels + 1).reshape(channels, height, width)
w = jnp.arange(1, kernel * kernel * channels + 1).reshape(channels, kernel, kernel)
print(kernex_depthwise_conv2d(x, w))
</details>
<details> <summary>7️⃣ Average pooling 2D </summary>
import jax
import jax.numpy as jnp
import kernex as kex


@jax.vmap  # vectorize over the channel dimension
@kex.kmap(kernel_size=(3, 3), strides=(2, 2))
def avgpool_2d(x):
    """Return the mean of one 3x3 window.

    Defines the kernel for the average pool operation over the spatial
    dimensions; the windows are taken with a stride of 2 along each of them.
    """
    return jnp.mean(x)
</details>
<details><summary>8️⃣ Runge-Kutta integration</summary>
# lets solve dydt = y, where y0 = 1 and y(t)=e^t
# using Runge-Kutta 4th order method
# f(t,y) = y
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
# row 0 holds y, row 1 holds t
x = jnp.stack([y, t], axis=0)
dt = t[1] - t[0]  # 0.25


def f(tn, yn):
    """Right-hand side of the ODE: f(t, y) = y."""
    return yn


def ic(x):
    """ initial condition y0 = 1 """
    return 1.


def rk4(x):
    """ runge kutta 4th order integration step """
    # ┌────┬────┬────┐ ┌──────┬──────┬──────┐
    # │ y0 │*y1*│ y2 │ │[0,-1]│[0, 0]│[0, 1]│
    # ├────┼────┼────┤ ==> ├──────┼──────┼──────┤
    # │ t0 │ t1 │ t2 │ │[1,-1]│[1, 0]│[1, 1]│
    # └────┴────┴────┘ └──────┴──────┴──────┘
    # read the previous step (column offset -1) from the relative view
    t0 = x[1, -1]
    y0 = x[0, -1]
    k1 = dt * f(t0, y0)
    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
    k4 = dt * f(t0 + dt, y0 + k3)
    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return yn_1


F = kex.kscan(kernel_size=(2, 3), relative=True, padding=(0, 1))  # kernel size = 3
F[0:1, 1:] = rk4
F[0, 0] = ic

# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]

plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()
</details>