
# Autograd losses for iceboost == xgboost + torch
# with various Hessian diagonal approaches.
#
# m.mieskolainen@imperial.ac.uk, 2025

import numpy as np
import torch
from torch import Tensor
import xgboost
import time
from tqdm import tqdm

from typing import Callable, Sequence, List, Tuple

# ------------------------------------------
from icenet import print
# ------------------------------------------

class AutogradObjective():
    """
    Custom loss driver class with xgboost + torch (autograd)
    
    hessian_mode: 'hutchinson' (or 'iterative') may make the model converge
    significantly faster (or better) than 'constant' in some cases.
    
    N.B. Remember to set obj.mode = 'train' or obj.mode = 'eval' manually while
    running the training boost iteration loop, otherwise .grad_prev and .preds_prev
    will get mixed up with the 'iterative' hessian mode.
    (See the usage sketch at the end of this module.)
    
    Args:
        loss_func:      Loss function handle
        mode:           'train' or 'eval', see the comment above
        flatten_grad:   For vector valued model output [experimental]
        hessian_mode:   'constant', 'iterative', 'hutchinson', 'exact'
        hessian_const:  Hessian scalar parameter for the 'constant' mode
        hessian_beta:   Hessian EMA smoothing parameter for the 'iterative' mode
        hessian_eps:    Hessian estimate denominator regularization for the 'iterative' mode
        hessian_slices: Hutchinson MC estimator number of slices for the 'hutchinson' mode
        hessian_limit:  Hessian absolute range clip parameters
        device:         Torch device
    """
    def __init__(self,
                 loss_func:      Callable[[Tensor, Tensor], Tensor],
                 mode:           str='train',
                 flatten_grad:   bool=False,
                 hessian_mode:   str='hutchinson',
                 hessian_const:  float=1.0,
                 hessian_beta:   float=0.9,
                 hessian_eps:    float=1e-8,
                 hessian_slices: int=10,
                 hessian_limit:  list=[1e-2, 20],
                 device:         torch.device='cpu'
                 ):
        
        self.mode         = mode
        self.loss_func    = loss_func
        self.device       = device
        self.hessian_mode = hessian_mode
        
        # Constant mode
        self.hessian_const = hessian_const
        
        # Iterative mode
        self.hessian_beta = hessian_beta
        self.hessian_eps  = hessian_eps
        
        # Hutchinson mode
        self.hessian_slices = int(hessian_slices)
        
        self.hessian_limit = hessian_limit
        self.flatten_grad  = flatten_grad
        
        # For the algorithms
        self.hess_diag  = None
        self.grad_prev  = None
        self.preds_prev = None
        
        txt = f'Using device: {self.device} | hessian_mode = {self.hessian_mode} (use "constant" or "iterative" for speed)'
        
        match self.hessian_mode:
            case 'constant':
                print(f'{txt} | hessian_const = {self.hessian_const}')
            case 'iterative':
                print(f'{txt} | hessian_beta = {self.hessian_beta}')
            case 'hutchinson':
                print(f'{txt} | hessian_slices = {self.hessian_slices}')
            case _:
                print(f'{txt}')
    
    def __call__(self, preds: np.ndarray, targets: xgboost.DMatrix):
        
        preds_, targets_, weights_ = self.torch_conversion(preds=preds, targets=targets)
        
        match self.mode:
            case 'train':
                loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
                return self.derivatives(loss=loss, preds=preds_)
            case 'eval':
                loss = self.loss_func(preds=preds_, targets=targets_, weights=weights_)
                return 'custom', loss.detach().cpu().numpy()
            case _:
                raise Exception('Unknown mode (set either "train" or "eval")')
    
    def torch_conversion(self, preds: np.ndarray, targets: xgboost.DMatrix):
        """
        Conversion from an xgboost.DMatrix object
        """
        try:
            weights = targets.get_weight()
            weights = None if len(weights) == 0 else torch.FloatTensor(weights).to(self.device)
        except Exception:
            weights = None
        
        preds   = torch.FloatTensor(preds).requires_grad_().to(self.device)
        targets = torch.FloatTensor(targets.get_label()).to(self.device)
        
        return preds, targets, weights
    
    def regulate_hess(self, hess: Tensor):
        """
        Regulate to be positive definite (H_ii > hessian_limit[0]) as required
        by second order gradient descent. Do not clip to zero, as that might
        result in zero denominators in the Hessian routines inside xgboost.
        """
        hess = torch.abs(hess) # ~ negative weights
        hess = torch.clamp(hess, min=self.hessian_limit[0], max=self.hessian_limit[1])
        
        return hess
    
    @torch.no_grad()
    def iterative_hessian_update(self, grad: Tensor, preds: Tensor):
        """
        Iterative approximation of the Hessian diagonal using finite differences
        based on a previous boost iteration Hessian
        (full batch training ~ only one Hessian stored)
        
        Args:
            grad:  Current gradient vector
            preds: Current prediction vector
        """
        print(f'Computing Hessian diag with iterative finite difference (beta = {self.hessian_beta})')
        
        # Initialize to unit curvature as a neutral default
        # (if sigma_i^2 = 1/H_ii, then this is a Gaussian N(0,1) prior)
        if self.hess_diag is None:
            self.hess_diag = torch.ones_like(grad)
            hess_diag_new  = torch.ones_like(grad)
        
        # H_ii ~ difference in gradients / difference in predictions
        else:
            dg = grad  - self.grad_prev
            ds = preds - self.preds_prev
            
            hess_diag_new = dg / (ds + self.hessian_eps)
            hess_diag_new = self.regulate_hess(hess_diag_new) # regulate
        
        # Exponential Moving Average (EMA), approx filter size ~ 1 / (1 - beta) steps
        self.hess_diag = self.hessian_beta * self.hess_diag + \
            (1 - self.hessian_beta) * hess_diag_new
        
        # Save the gradient vector and predictions
        self.grad_prev  = grad.clone().detach()
        self.preds_prev = preds.clone().detach()
    
    def hessian_hutchinson(self, grad: Tensor, preds: Tensor):
        """
        Hutchinson MC estimator for the Hessian diagonal ~ O(slices) (time)
        (see the standalone numerical check at the end of this module)
        """
        tic = time.time()
        print(f'Computing Hessian diag with Hutchinson MC (slices = {self.hessian_slices}) ... ')
        
        hess = torch.zeros_like(preds)
        
        for _ in range(self.hessian_slices):
            
            # Generate a Rademacher vector (each element +-1 with probability 0.5)
            v = torch.empty_like(preds).uniform_(-1, 1)
            v = torch.sign(v)
            
            # Compute Hessian-vector product H * v
            Hv = torch.autograd.grad(grad, preds, grad_outputs=v, retain_graph=True)[0]
            
            # Accumulate element-wise product v * Hv to estimate the diagonal
            hess += v * Hv
        
        # Average over the slices
        hess = hess / self.hessian_slices
        hess = self.regulate_hess(hess) # regulate
        
        print(f'Took {time.time()-tic:.2f} sec')
        
        return hess
    
    def hessian_exact(self, grad: Tensor, preds: Tensor):
        """
        Hessian diagonal with exact autograd ~ O(data points) (time)
        """
        tic = time.time()
        print('Computing Hessian diagonal with exact autograd ... ')
        
        hess = torch.zeros_like(preds)
        
        for i in tqdm(range(len(preds))):
            
            # A basis vector
            e_i    = torch.zeros_like(preds)
            e_i[i] = 1.0
            
            # Compute the Hessian-vector product H e_i
            hess[i] = torch.autograd.grad(grad, preds, grad_outputs=e_i, retain_graph=True)[0][i]
        
        hess = self.regulate_hess(hess) # regulate
        
        print(f'Took {time.time()-tic:.2f} sec')
        
        return hess
    
    def derivatives(self, loss: Tensor, preds: Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """
        Gradient and Hessian diagonal
        
        Args:
            loss:  loss function values
            preds: model predictions
        
        Returns:
            gradient vector, hessian diagonal vector as numpy arrays
        """
        
        ## Gradient
        grad1 = torch.autograd.grad(loss, preds, create_graph=True)[0]
        
        ## Diagonal elements of the Hessian matrix
        match self.hessian_mode:
            
            # Constant curvature
            case 'constant':
                print(f'Setting Hessian diagonal using a constant (hessian_const = {self.hessian_const})')
                grad2 = self.hessian_const * torch.ones_like(grad1)
            
            # BFGS style iterative updates
            case 'iterative':
                self.iterative_hessian_update(grad=grad1, preds=preds)
                grad2 = self.hess_diag
            
            # Hutchinson based MC estimator
            case 'hutchinson':
                grad2 = self.hessian_hutchinson(grad=grad1, preds=preds)
            
            # Exact autograd (slow)
            case 'exact':
                grad2 = self.hessian_exact(grad=grad1, preds=preds)
            
            # Squared derivative based [uncontrolled] approximation
            case 'squared_approx':
                print(f'Setting Hessian diagonal using grad^2 [DEBUG ONLY]')
                grad2 = grad1 * grad1
            
            case _:
                raise Exception(f'Unknown "hessian_mode" {self.hessian_mode}')
        
        # Return numpy arrays
        grad1, grad2 = grad1.detach().cpu().numpy(), grad2.detach().cpu().numpy()
        
        if self.flatten_grad:
            grad1, grad2 = grad1.flatten("F"), grad2.flatten("F")
        
        return grad1, grad2
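
# ------------------------------------------
# Illustrative usage sketch (an assumption about intended usage, not part of
# the original class API): a manual boosting loop which sets obj.mode = 'train'
# before each boost iteration and obj.mode = 'eval' for the loss evaluation,
# as required by the AutogradObjective docstring. The loss function, the
# xgboost parameters and the synthetic data below are hypothetical placeholders.

def _usage_sketch(num_boost_round: int = 5):
    
    import torch.nn.functional as F
    
    def bce_loss(preds, targets, weights=None):
        # Sum-reduced binary cross-entropy with logits (per-event weights optional)
        return F.binary_cross_entropy_with_logits(preds, targets, weight=weights, reduction='sum')
    
    # Toy data (placeholder only)
    rng    = np.random.default_rng(0)
    X      = rng.normal(size=(256, 4))
    y      = (X[:, 0] > 0).astype(float)
    dtrain = xgboost.DMatrix(X, label=y)
    
    obj   = AutogradObjective(loss_func=bce_loss, hessian_mode='hutchinson')
    model = None
    
    for _ in range(num_boost_round):
        
        # One boost iteration with the custom (autograd) objective
        obj.mode = 'train'
        model = xgboost.train(params={'disable_default_eval_metric': 1}, dtrain=dtrain,
                              num_boost_round=1, obj=obj, xgb_model=model)
        
        # Evaluate the same loss on raw (margin) scores
        obj.mode = 'eval'
        name, value = obj(model.predict(dtrain, output_margin=True), dtrain)
        print(f'Boost iteration loss [{name}] = {float(value):.4f}')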
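
# ------------------------------------------
# Standalone numerical check (illustrative sketch, not part of the original
# module) of the Hutchinson identity used in hessian_hutchinson():
# diag(H) = E[v * (H v)] for a Rademacher vector v. The toy loss below is
# hypothetical; its exact Hessian is 2*ones(n,n) + 2*I, so every diagonal
# element equals 4.

def _hutchinson_diag_check(n: int = 8, slices: int = 2000, seed: int = 0):
    
    torch.manual_seed(seed)
    
    x    = torch.randn(n, requires_grad=True)
    loss = (x.sum())**2 + (x**2).sum()
    grad = torch.autograd.grad(loss, x, create_graph=True)[0]
    
    diag_est = torch.zeros(n)
    
    for _ in range(slices):
        
        # Rademacher vector (+-1 entries)
        v  = torch.sign(torch.empty(n).uniform_(-1, 1))
        
        # Hessian-vector product H * v, accumulated element-wise against v
        Hv = torch.autograd.grad(grad, x, grad_outputs=v, retain_graph=True)[0]
        diag_est += (v * Hv).detach()
    
    diag_est /= slices
    
    # Each component should be close to the exact value 4
    print(f'Estimated diag(H) = {diag_est}')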