Skip to content

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).

Notifications You must be signed in to change notification settings

ANGLE404/efficient-kan

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

13 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

An Efficient Implementation of Kolmogorov-Arnold Network

This repository contains an efficient implementation of Kolmogorov-Arnold Network (KAN). The original implementation of KAN is available here.

The performance issue of the original implementation is mostly because it needs to expand all intermediate variables to perform the different activation functions. For a layer with in_features input and out_features output, the original implementation needs to expand the input to a tensor with shape (batch_size, out_features, in_features) to perform the activation functions. However, all activation functions are linear combination of a fixed set of basis functions which are B-splines; given that, we can reformulate the computation as activate the input with different basis functions and then combine them linearly. This reformulation can significantly reduce the memory cost and make the computation a straightforward matrix multiplication, and works with both forward and backward pass naturally.

The problem is in the sparsification which is claimed to be critical to KAN's interpretability. The authors proposed a L1 regularization defined on the input samples, which requires non-linear operations on the (batch_size, out_features, in_features) tensor, and is thus not compatible with the reformulation. I instead replace the L1 regularization with a L1 regularization on the weights, which is more common in neural networks and is compatible with the reformulation. The author's implementation indeed include this kind of regularization alongside the one described in the paper as well, so I think it might help. More experiments are needed to verify this; but at least the original approach is infeasible if efficiency is wanted.

Another difference is that, beside the learnable activation functions (B-splines), the original implementation also includes a learnable scale on each activation function. I provided an option enable_standalone_scale_spline that defaults to True to include this feature; disable it will make the model more efficient, but potentially hurts results. It needs more experiments.

2024-05-04 Update: @xiaol hinted that the constant initialization of base_weight parameters can be a problem on MNIST. For now I've changed both the base_weight and spline_scaler matrices to be initialized with kaiming_uniform_, following nn.Linear's initialization. It seems to work much much better on MNIST (~20% to ~97%), but I'm not sure if it's a good idea in general.

About

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%