Reading notes: Kolmogorov-Arnold Networks
This paper (Liu et al. 2024) proposes an alternative to multi-layer perceptrons (MLPs) in machine learning.
The basic idea is that MLPs place fixed activation functions on the nodes of the computation graph and learn linear weights and biases, whereas KANs put the learnable parameters on the edges: each edge carries a learnable univariate activation function parameterized as a spline, and each node simply sums its incoming edges.
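To make the "parameters on the edges" idea concrete, here is a minimal sketch of a single KAN layer in plain Python. It is not the paper's implementation: it assumes piecewise-linear (order-1) splines on a uniform grid for simplicity, and the names `hat_basis`, `KANLayer`, `grid_size`, `lo` and `hi` are made up for this illustration.

```python
import random


def hat_basis(x, grid):
    """Values at x of the piecewise-linear (order-1 B-spline) basis on a uniform grid."""
    h = grid[1] - grid[0]
    return [max(0.0, 1.0 - abs(x - g) / h) for g in grid]


class KANLayer:
    """One KAN layer: a learnable univariate function on every edge, sums on nodes."""

    def __init__(self, n_in, n_out, grid_size=5, lo=-1.0, hi=1.0, seed=0):
        rng = random.Random(seed)
        # Assumption: inputs lie in [lo, hi]; values outside the grid are not handled.
        self.grid = [lo + (hi - lo) * i / grid_size for i in range(grid_size + 1)]
        # coef[j][i] holds the spline coefficients of the edge from input i to output j.
        # These coefficients are the learnable parameters of the layer.
        self.coef = [
            [[rng.gauss(0.0, 0.1) for _ in self.grid] for _ in range(n_in)]
            for _ in range(n_out)
        ]

    def edge(self, j, i, x):
        """Evaluate the activation function on edge i -> j at the scalar x."""
        basis = hat_basis(x, self.grid)
        return sum(c * b for c, b in zip(self.coef[j][i], basis))

    def forward(self, xs):
        """Each output node just sums its incoming edge activations."""
        return [
            sum(self.edge(j, i, x) for i, x in enumerate(xs))
            for j in range(len(self.coef))
        ]


layer = KANLayer(n_in=2, n_out=3)
print(layer.forward([0.3, -0.5]))  # three values, one per output node
```

With order-1 splines, setting an edge's coefficients equal to the grid values makes that edge the identity function, which is a handy sanity check. The paper itself uses higher-order B-splines (cubic by default) plus a fixed residual activation on each edge, neither of which is reproduced here.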
The network is learned at two levels, which allows for “adjusting locally”:
- the overall shape of the computation graph and its connections (external degrees of freedom, to learn the compositional structure),
- the parameters of each activation function (internal degrees of freedom).
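It is based on the Kolmogorov-Arnold representation theorem, which says that any continuous multivariate function on a bounded domain can be written as a finite composition of continuous univariate functions and addition: a sum of outer univariate functions, each applied to a sum of inner univariate functions of a single input variable. We recover the distinction between the compositional structure (how the sums are wired together) and the shape of each individual univariate function.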
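For reference, the theorem is usually stated as follows for a continuous function f on the unit cube [0,1]^n. The symbols n, \Phi_q and \phi_{q,p} are notation introduced here, not taken from the notes above.

```latex
f(x_1, \dots, x_n) \;=\; \sum_{q=1}^{2n+1} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right),
\qquad
\phi_{q,p} : [0,1] \to \mathbb{R}, \quad \Phi_q : \mathbb{R} \to \mathbb{R} \ \text{continuous}.
```

Read as a network, the inner functions \phi_{q,p} form a first layer with n inputs and 2n+1 outputs, and the outer functions \Phi_q form a second layer with a single output.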
The theorem can be interpreted as a two-layer network, and the paper generalizes it to multiple layers of arbitrary width. In the theorem, the univariate functions are arbitrary and can be badly behaved (non-smooth, even fractal), so the hope is that allowing arbitrary depth and width will make it possible to use only smooth splines. They derive an approximation theorem: when the arbitrary continuous functions of the Kolmogorov-Arnold representation are replaced with splines, the error can be bounded at a rate that does not depend on the dimension. (However, the bound contains a constant which depends on the function and its representation, and therefore implicitly on the dimension…) The theoretical scaling laws in the number of parameters are much better than for MLPs, and moreover, experiments show that KANs come much closer to their theoretical bounds than MLPs do.
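As far as I can reconstruct it from the paper, the bound and the resulting scaling law have the following shape. Here G is the number of grid intervals, k the spline order, m the order of the derivatives in which the error is measured, C the function-dependent constant, \ell the test loss and N the number of parameters. The notation is the paper's as I remember it, so it is worth checking against the source.

```latex
\left\| f - \left( \Phi^{G}_{L-1} \circ \cdots \circ \Phi^{G}_{0} \right) x \right\|_{C^m}
\;\le\; C \, G^{-k-1+m},
\qquad
\ell \;\propto\; N^{-\alpha}, \quad \alpha = k + 1 .
```

With cubic splines (k = 3) this gives an exponent of 4, and the exponent does not involve the input dimension; the dependence on dimension is hidden in C.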
KANs have interesting properties:
- The splines are defined on a grid of points which can be iteratively refined (the finer spline is initialized by fitting it to the coarser one). This notion of “fine-grainedness” is very interesting: it makes it possible to add parameters without having to retrain everything.
- Larger is not always better: the quality of the reconstruction depends on finding the optimal shape of the network, which should match the structure of the function we want to approximate. This optimal shape is found via sparsification, pruning, and regularization (non-trivial).
- We can have a “human in the loop” during training, guiding pruning and “symbolifying” some activations (e.g. recognizing that an activation function is actually a cosine and replacing it directly with one). This symbolic discovery can be assisted by a symbolic system that recognizes some functions. It is therefore a mix of symbolic regression and numerical regression.
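The “symbolifying” step in the last point can be sketched as a small matching problem: sample a learned activation, try to explain the samples as c * f(x) + d for each f in a library of candidate functions, and keep the candidate with the best R². This is only an illustration under assumptions: the candidate library, the helper names `fit_affine` and `best_symbolic_match`, and the restriction to an outer affine fit are all choices made here, not the paper's procedure.

```python
import math

# Hypothetical candidate library, for illustration only.
CANDIDATES = {
    "identity": lambda x: x,
    "square": lambda x: x * x,
    "sin": math.sin,
    "cos": math.cos,
    "exp": math.exp,
}


def fit_affine(us, ys):
    """Least-squares fit ys ~ c * us + d. Returns (c, d, r2)."""
    n = len(us)
    mean_u = sum(us) / n
    mean_y = sum(ys) / n
    s_uu = sum((u - mean_u) ** 2 for u in us)
    s_uy = sum((u - mean_u) * (y - mean_y) for u, y in zip(us, ys))
    s_yy = sum((y - mean_y) ** 2 for y in ys)
    if s_uu == 0.0 or s_yy == 0.0:
        return 0.0, mean_y, 0.0  # degenerate case: nothing to explain
    c = s_uy / s_uu
    d = mean_y - c * mean_u
    ss_res = sum((y - (c * u + d)) ** 2 for u, y in zip(us, ys))
    return c, d, 1.0 - ss_res / s_yy


def best_symbolic_match(xs, ys):
    """Return (name, c, d, r2) for the candidate f maximizing R^2 of ys ~ c*f(xs)+d."""
    best = None
    for name, f in CANDIDATES.items():
        c, d, r2 = fit_affine([f(x) for x in xs], ys)
        if best is None or r2 > best[3]:
            best = (name, c, d, r2)
    return best


# Samples standing in for a learned activation (here secretly 2*cos(x) + 0.5).
xs = [-2.0 + 0.1 * i for i in range(41)]
ys = [2.0 * math.cos(x) + 0.5 for x in xs]
name, c, d, r2 = best_symbolic_match(xs, ys)
print(name, round(c, 3), round(d, 3))
```

On these samples the cosine candidate should win with c close to 2 and d close to 0.5. A fuller version would also fit an inner affine transform, i.e. c * f(a * x + b) + d, and would leave the final decision to the human in the loop when several candidates score similarly.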
They test mostly with scientific applications in mind: reconstructing equations from physics and pure maths. Conceptually, it has a lot of overlap with Neural Differential Equations (Chen et al. 2018; Ruthotto 2024) and “scientific ML” in general.
There is an interesting discussion at the end about KANs as the model of choice for the “language of science”. The idea is that LLMs are important because they are useful for natural language, and KANs could fill the same role for the language of functions. Interpretability and adaptability (the ability to be manipulated and guided during training by a domain expert) are thus core features that traditional deep learning models lack.
There are still challenges: mostly, it is unclear how KANs perform on other types of data and other modalities, but the results are very encouraging. There are also computational challenges: KANs are currently much slower to train than MLPs, but there has been almost no engineering work to optimize them, so this is expected. The fact that the operations are not easily batchable (compared to matrix multiplication) is however worrying for scalability to large networks.