# Based on <github.com/ZiyaoLi/fast-kan/blob/master/fastkan>
#
# m.mieskolainen@imperial.ac.uk, 2024
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional
from icenet.deep.deeptools import Multiply
class SplineLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # Store the init scale first; nn.Linear.__init__ calls reset_parameters() below
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)
    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
class RadialBasisFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: Optional[float] = None, # larger denominators lead to smoother basis
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)
    def forward(self, x):
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)
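
# The basis above is a fixed (non-trainable) Gaussian grid: for each grid point
# g_i and bandwidth h = self.denominator,
#
#   phi_i(x) = exp(-((x - g_i) / h)**2)
#
# so an input of shape (..., input_dim) maps to (..., input_dim, num_grids).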
class FastKANLayer(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, grid_min: float = -2., grid_max: float = 2.,
                 num_grids: int = 8, use_base_update: bool = False, base_activation = F.silu,
                 spline_weight_init_scale: float = 0.1):
        super().__init__()
        self.layernorm     = nn.LayerNorm(input_dim)
        self.rbf           = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear     = nn.Linear(input_dim, output_dim)
    def forward(self, x, time_benchmark=False):
        if not time_benchmark:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret  = ret + base
        return ret
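
# A minimal shape sketch (hypothetical sizes, not part of the original code):
# for x of shape (batch, input_dim), self.rbf(self.layernorm(x)) has shape
# (batch, input_dim, num_grids), which is flattened to (batch, input_dim * num_grids)
# before the spline linear map, e.g.
#
#   layer = FastKANLayer(input_dim=4, output_dim=3, num_grids=8)
#   y     = layer(torch.randn(16, 4))   # y.shape == (16, 3)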
class FastKAN(nn.Module):
    def __init__(self, D, C, mlp_dim: List[int], grid_min: float = -2.0, grid_max: float = 2.0, num_grids: int = 8,
                 use_base_update: bool = False, base_activation = F.silu, spline_weight_init_scale: float = 0.1,
                 out_dim=None, last_tanh=False, last_tanh_scale=10.0, **kwargs):
        super(FastKAN, self).__init__()
        
        self.D = D
        self.C = C
        
        if out_dim is None:
            self.out_dim = C
        else:
            self.out_dim = out_dim
        
        layers_hidden = [D] + mlp_dim + [self.out_dim]
        
        self.layers = nn.ModuleList([
            FastKANLayer(
                input_dim  = in_dim,
                output_dim = out_dim,
                grid_min   = grid_min,
                grid_max   = grid_max,
                num_grids  = num_grids,
                use_base_update = use_base_update,
                base_activation = base_activation,
                spline_weight_init_scale = spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])
        
        # Add extra final squeezing activation and post-scale aka "soft clipping"
        if last_tanh:
            self.layers.add_module("tanh",  nn.Tanh())
            self.layers.add_module("scale", Multiply(last_tanh_scale))
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def softpredict(self, x):
        """ Softmax probability (sigmoid when out_dim == 1)
        """
        if self.out_dim > 1:
            return F.softmax(self.forward(x), dim=-1)
        else:
            return torch.sigmoid(self.forward(x))
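
# A minimal usage sketch (hypothetical shapes and hyperparameters, shown only
# to illustrate the constructor arguments defined above):
if __name__ == "__main__":
    net = FastKAN(D=10, C=3, mlp_dim=[32, 32], num_grids=8, last_tanh=True)
    x   = torch.randn(128, 10)
    
    logits = net(x)              # shape (128, 3)
    probs  = net.softpredict(x)  # softmax probabilities over the 3 output classes
    
    print(logits.shape, probs.shape)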