# Event sample re-weighting tools
#
# m.mieskolainen@imperial.ac.uk, 2025
import os
import numpy as np
import awkward as ak
import torch
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm
from icenet.tools import aux, io, prints
from icenet.deep import iceboost, predict
# ------------------------------------------
from icenet import print
# ------------------------------------------
[docs]
def histogram_helper(x, y, w, ids, pdf, args, EPS):
    """
    Helper function for histogram based reweighting

    Args:
        x:    input data (awkward array or 2D numpy array)
        y:    class labels per event
        w:    event weights (or None)
        ids:  variable names of the columns of x
        pdf:  pre-computed pdfs (None --> computed here from the data)
        args: steering arguments dictionary
        EPS:  numerical protection epsilon

    Returns:
        weights_doublet: per-class re-weighting coefficient arrays
        pdf:             computed (or passed-through) pdfs
    """
    class_ids       = np.unique(y.astype(int))
    reference_class = args['reweight_param']['reference_class']
    diff_args       = args['reweight_param']['diff_param']
    h_args          = args['reweight_param']['diff_param']['hist_param']
    variables       = diff_args['var']

    print(f"Reference (target) class [{reference_class}] | Found classes: {class_ids} from y")

    ### Collect re-weighting variables
    RV = {}
    for i, var in enumerate(variables):
        if isinstance(x, ak.Array):
            RV[i] = ak.to_numpy(x[var]).astype(np.float32)
        else:
            RV[i] = x[:, ids.index(var)].astype(np.float32)

    ### Pre-transform (values and bin range definitions transformed consistently)
    for i, var in enumerate(variables):
        mode = h_args['transform'][i]

        if h_args['binmode'][i] == 'edges':
            raise Exception(__name__ + '.histogram_helper: Cannot transform "edges" type')

        # NB. 'd' aliases h_args['bins'][i]; in-place writes below mutate the args copy
        d = h_args['bins'][i]

        if mode == 'log10':
            if np.any(RV[i] <= 0):
                ind = (RV[i] <= 0)
                print(f'Variable {var} < 0 (in {np.sum(ind)} elements) in log10 -- truncating to zero', 'red')

            # Transform values
            RV[i] = np.log10(np.maximum(RV[i], EPS))

            # Transform bin range
            h_args['bins'][i][0] = np.log10(max(d[0], EPS))
            h_args['bins'][i][1] = np.log10(d[1])

        elif mode == 'sqrt':
            # Values
            RV[i] = np.sqrt(np.maximum(RV[i], EPS))

            # Bins
            h_args['bins'][i][0] = np.sqrt(max(d[0], EPS))
            h_args['bins'][i][1] = np.sqrt(d[1])

        elif mode == 'square':
            # Values
            RV[i] = RV[i]**2

            # Bins
            h_args['bins'][i][0] = d[0]**2
            h_args['bins'][i][1] = d[1]**2

        elif mode is None:  # fix: identity comparison for None; explicit no-op
            pass

        else:
            raise Exception(__name__ + '.histogram_helper: Unknown pre-transform')

    # Binning setup
    binedges = {}
    for i, var in enumerate(variables):
        d = h_args['bins'][i]

        if h_args['binmode'][i] == 'linear':
            binedges[i] = np.linspace(d[0], d[1], d[2])

        elif h_args['binmode'][i] == 'log10':
            binedges[i] = np.logspace(np.log10(max(d[0], EPS)), np.log10(d[1]), d[2], base=10)

        elif h_args['binmode'][i] == 'edges':
            binedges[i] = np.array(d)

        else:
            raise Exception(__name__ + ': Unknown reweight binning mode')

    rwparam = {
        'y':               y,
        'reference_class': reference_class,
        'max_reg':         h_args['max_reg']
    }

    ### Compute 2D-PDFs for each class
    if diff_args['type'] == '2D':
        print(f'2D re-weighting with: [{variables[0]}, {variables[1]}] ...')

        if pdf is None:  # Not given by user
            pdf = {}
            for c in class_ids:
                ind = (y == c)
                sample_weights = w[ind] if w is not None else None  # Feed in the input weights !
                pdf[c] = pdf_2D_hist(X_A=RV[0][ind], X_B=RV[1][ind], w=sample_weights,
                                     binedges_A=binedges[0], binedges_B=binedges[1])

            pdf['binedges']  = binedges
            pdf['class_ids'] = class_ids

        weights_doublet = reweightcoeff2D(X_A=RV[0], X_B=RV[1], pdf=pdf, **rwparam)

    ### Compute factorized 1D x 1D product
    elif diff_args['type'] == 'pseudo-ND':
        print(f'pseudo-ND (2D for now) reweighting with: [{variables[0]}, {variables[1]}] ...')

        if pdf is None:  # Not given by user
            pdf = {}
            for c in class_ids:
                ind = (y == c)
                sample_weights = w[ind] if w is not None else None  # Feed in the input weights !
                pdf_0 = pdf_1D_hist(X=RV[0][ind], w=sample_weights, binedges=binedges[0])
                pdf_1 = pdf_1D_hist(X=RV[1][ind], w=sample_weights, binedges=binedges[1])

                if h_args['pseudo_type'] == 'geometric_mean':
                    pdf[c] = np.sqrt(np.outer(pdf_0, pdf_1))  # (A,B) order gives normal matrix indexing
                elif h_args['pseudo_type'] == 'product':
                    pdf[c] = np.outer(pdf_0, pdf_1)
                else:
                    raise Exception(__name__ + '.histogram_helper: Unknown "pseudo_type"')

                # Normalize to discrete density
                pdf[c] /= np.sum(pdf[c].flatten())

            pdf['binedges']  = binedges
            pdf['class_ids'] = class_ids

        weights_doublet = reweightcoeff2D(X_A=RV[0], X_B=RV[1], pdf=pdf, **rwparam)

    ### Compute 1D-PDFs for each class
    elif diff_args['type'] == '1D':
        print(f'1D reweighting with: {variables[0]} ...')

        if pdf is None:  # Not given by user
            pdf = {}
            for c in class_ids:
                ind = (y == c)
                sample_weights = w[ind] if w is not None else None  # Feed in the input weights !
                pdf[c] = pdf_1D_hist(X=RV[0][ind], w=sample_weights, binedges=binedges[0])

            pdf['binedges']  = binedges[0]
            pdf['class_ids'] = class_ids

        weights_doublet = reweightcoeff1D(X=RV[0], pdf=pdf, **rwparam)

    else:
        raise Exception(__name__ + f'.histogram_helper: Unsupported dimensionality mode "{diff_args["type"]}"')

    pdf['vars'] = variables

    return weights_doublet, pdf
[docs]
def map_xyw(x, y, w, vars, c, reference_class):
    """
    For AIRW helper

    Select events of class 'c' (source) and 'reference_class' (target),
    relabel them to the binary pair (0, 1) and equalize the class balance.
    """
    # Pick only the source and target populations
    mask = (y == c) | (y == reference_class)
    data = io.IceXYW(x=x[mask], y=y[mask], w=w[mask], ids=vars)

    # ----------------------------
    # Relabel: source class --> 0, reference (target) class --> 1
    relabeled = np.where(data.y == reference_class, 1, 0).astype(np.int32)
    data.y    = relabeled  # !

    # Reference scale for the weight normalization
    n_tilde = max(np.sum(relabeled == 0), np.sum(relabeled == 1))

    # ----------------------------
    # Equalize class balance for the training
    for label in (0, 1):
        sel = (relabeled == label)
        data.w[sel] = n_tilde * data.w[sel] / np.sum(data.w[sel])

    return data
[docs]
def AIRW_helper(x, y, w, ids, pdf, args, x_val, y_val, w_val, EPS=1e-12):
    """
    Helper function for ML based reweighting

    Trains (when pdf is None) a binary xgboost discriminator per class
    against the reference class, then turns the model logits into
    multiplicative event re-weights.

    Args:
        x, y, w:             training sample data, labels and weights
        ids:                 variable names of the columns of x
        pdf:                 pre-trained model dictionary (None --> train here)
        args:                steering arguments dictionary
        x_val, y_val, w_val: validation sample (used only in training mode)
        EPS:                 numerical protection epsilon

    Returns:
        weights_doublet: per-class re-weighted event weight arrays
        pdf:             model dictionary (trained or passed through)
    """
    class_ids = np.unique(y.astype(int))

    reference_class = args['reweight_param']['reference_class']
    diff_args       = args['reweight_param']['diff_param']
    variables       = diff_args['var']

    ID      = diff_args['AIRW_param']['active_model']
    RW_mode = diff_args['AIRW_param']['mode']
    MAX_REG = diff_args['AIRW_param']['max_reg']
    param   = args['models'][ID]

    # ----------------------------------------
    # Conversions and pick variables of interest
    if isinstance(x, ak.Array):
        x = aux.ak2numpy(x=x, fields=variables).astype(np.float32)
        if x_val is not None:
            x_val = aux.ak2numpy(x=x_val, fields=variables).astype(np.float32)
    else:
        x = x[:, io.index_list(ids, variables)].astype(np.float32)
        if x_val is not None:
            x_val = x_val[:, io.index_list(ids, variables)].astype(np.float32)
    # ----------------------------------------

    print('Training N-dim reweighting', 'magenta')
    print(f'x.shape = {x.shape}', 'magenta')
    print(f'variables = {variables}', 'magenta')
    print('')

    # Train model per class pair (will skip if we use this function to evaluate)
    if pdf is None:
        print('Training mode:')

        # All diagnostics logs go under one directory (hoisted, was repeated per call)
        plot_path = os.path.join(args["plotdir"], 'train/AIRW_S1')
        aux.makedir(plot_path)

        # Training sample diagnostics
        print('Training sample RAW weights:', 'magenta')
        output_file = os.path.join(plot_path, 'stats_train_input_weights_raw.log')
        prints.print_weights(weights=w, y=y, output_file=output_file)

        print('Training sample RAW data:', 'magenta')
        output_file = os.path.join(plot_path, 'stats_train_input_data_raw.log')
        prints.print_variables(x, variables, w, output_file=output_file)

        # Validation sample diagnostics
        print('Validation sample RAW weights:', 'magenta')
        output_file = os.path.join(plot_path, 'stats_validate_input_weights_raw.log')
        prints.print_weights(weights=w_val, y=y_val, output_file=output_file)

        print('Validation sample RAW data:', 'magenta')
        output_file = os.path.join(plot_path, 'stats_validate_input_data_raw.log')
        prints.print_variables(x_val, variables, w_val, output_file=output_file)

        # --------------------------------------------------
        # Optimal Dequantization
        if 'optimal_dequantize' in diff_args['AIRW_param'] and diff_args['AIRW_param']['optimal_dequantize'] > 0:
            print('Optimal Dequantization for training data based on mantissa bit depth', 'green')

            for j in tqdm(range(x.shape[1])):
                out = io.infer_precision(arr=x[:,j])
                p   = out.get('mantissa_bits_eff', None)

                if p is None:
                    name = variables[j]
                    print(f"infer_precision failed to return mantissa_bits_eff for variable {name} -- skip")
                else:
                    scale  = diff_args['AIRW_param']['optimal_dequantize']  # Additional scale boost
                    x[:,j] = io.optimal_dequantize(x=x[:,j], p=p, scale=scale)

            print('Training data after dequantization:')
            output_file = os.path.join(plot_path, 'stats_train_input_data_dequant.log')
            prints.print_variables(x, variables, w, output_file=output_file)

        # Here one could add Z-standardization (but not needed with xgboost)

        pdf = {'ID': ID, 'param': param, 'model': {}, 'vars': variables}

        for c in class_ids:
            if c != reference_class:
                data_trn = map_xyw(x=x, y=y, w=w, vars=variables, c=c, reference_class=reference_class)
                data_val = map_xyw(x=x_val, y=y_val, w=w_val, vars=variables, c=c, reference_class=reference_class)

                # Train
                print('Training sample weights going into training:', 'magenta')
                output_file = os.path.join(plot_path, 'stats_train_input_weights.log')
                prints.print_weights(weights=data_trn.w, y=data_trn.y, output_file=output_file)

                # Validate
                print('Validation sample weights going into training:', 'magenta')
                output_file = os.path.join(plot_path, 'stats_validate_input_weights.log')
                prints.print_weights(weights=data_val.w, y=data_val.y, output_file=output_file)

                inputs = {'data_trn':    data_trn,
                          'data_val':    data_val,
                          'args':        args,
                          'data_trn_MI': None,
                          'data_val_MI': None,
                          'param':       param}

                # Take in the output
                # (this is not used but the model is re-loaded from the disk,
                # according the best model [epoch] criteria in the model card)
                pdf['model'][c] = iceboost.train_xgb(**inputs)

    # ------------------------------------------------
    ## Now apply the model

    param = pdf['param']

    # Start from the input weights (or unit weights)
    wnew = copy.deepcopy(w) if w is not None else np.ones(len(y))

    for c in pdf['model'].keys():
        print(f'Applying AIRW reweighter to class [{c}]')

        if param['predict'] == 'xgb':
            func_predict = predict.pred_xgb(args=args, param=param, feature_names=variables)
        elif param['predict'] == 'xgb_logistic':
            func_predict = predict.pred_xgb_logistic(args=args, param=param, feature_names=variables)
        else:
            raise Exception(__name__ + '.AIRW_helper: Unsupported model -- "predict" field should be "xgb" or "xgb_logistic"')

        # -----------------------------------------------
        # Predict for events of this class
        ind  = (y == c)
        pred = func_predict(x[ind])

        # Handle logits vs probabilities (detected from the output value range)
        min_pred, max_pred = np.min(pred), np.max(pred)
        THRESH = 1E-5

        if min_pred < (-THRESH) or max_pred > (1.0 + THRESH):
            print(f'Detected raw logit output [{min_pred:0.4f}, {max_pred:0.4f}] from the model')
            logits = pred
            p = aux.sigmoid(logits)
            print(f'Corresponding probability output [{np.min(p):0.4f}, {np.max(p):0.4f}]')
        else:
            print(f'Detected probability output [{min_pred:0.4f}, {max_pred:0.4f}] from the model')
            logits = aux.inverse_sigmoid(pred)
            print(f'Corresponding logit output [{np.min(logits):0.4f}, {np.max(logits):0.4f}]')

        # Get weights after the re-weighting transform
        AI_w = rw_transform_with_logits(logits=logits, mode=RW_mode)

        # Apply maximum weight regularization
        AI_w = np.clip(AI_w, 0.0, MAX_REG)

        # Apply multiplicatively to event weights
        wnew[ind] = wnew[ind] * AI_w
    # -----------------------------------------------

    # Transform weights into weights doublet
    weights_doublet = {}
    for c in class_ids:
        ind = (y == c)
        weights_doublet[c]      = np.zeros(x.shape[0])
        weights_doublet[c][ind] = wnew[ind]

    return weights_doublet, pdf
[docs]
def doublet_helper(x, y, w, class_ids):
    """
    Map per-event weights into a per-class dict of weight vectors
    (zeros outside each class; unit weights when w is None).
    """
    n_events = len(x)
    doublet  = {}

    for c in class_ids:
        vec = np.zeros(n_events)           # init with zeros
        sel = (y == c)
        vec[sel]   = w[sel] if w is not None else 1.0  # Feed in the input weights
        doublet[c] = vec

    return doublet
[docs]
def compute_ND_reweights(x, y, w, ids, args, pdf=None, EPS=1e-12,
                         x_val=None, y_val=None, w_val=None, skip_reweights=False):
    """
    Compute N-dim reweighting coefficients

    Supports 'ML' (ND), 'pseudo-ND' (1D x 1D ... x 1D), '2D', '1D'

    For 'args' dictionary structure, see steering cards.

    Args:
        x    : training sample input
        y    : training sample (class) labels
        w    : training sample weights (None --> unit weights)
        ids  : variable names of columns of x
        pdf  : pre-computed pdfs (default None)
        args : reweighting parameters in a dictionary
        x_val, y_val, w_val : validation sample (used by AIRW differential mode)
        skip_reweights      : train the differential model but do not apply it
    Returns:
        weights : 1D-array of re-weights
        pdf     : computed pdfs
    """
    use_ak = isinstance(x, ak.Array)

    if use_ak:
        y = copy.deepcopy(ak.to_numpy(y).astype(int))
        # Fix: tolerate w=None also for awkward input
        w = copy.deepcopy(ak.to_numpy(w).astype(float)) if w is not None else None

    class_ids = np.unique(y.astype(int))

    # Make sure we make a copy, because we modify args here
    args = copy.deepcopy(args)

    ## Differential reweighting
    if args['reweight_param']['differential']:
        print('Differential reweighting')

        # Histogram based
        if args['reweight_param']['diff_param']['type'] != 'AIRW':
            weights_doublet, pdf = histogram_helper(x=x, y=y, w=w, ids=ids,
                                                    pdf=pdf, args=args, EPS=EPS)
        # AIRW based
        else:
            weights_doublet, pdf = AIRW_helper(x=x, y=y, w=w, ids=ids, pdf=pdf, args=args,
                                               x_val=x_val, y_val=y_val, w_val=w_val, EPS=EPS)

        # Renormalize integral (sum) to the event counts per class
        if args['reweight_param']['diff_param']['renorm_weight_to_count']:
            print('Renormalizing sum(weights) == sum(count) per class')

            for c in class_ids:
                ind = (y == c)
                weights_doublet[c][ind] /= np.sum(weights_doublet[c][ind])
                weights_doublet[c][ind] *= np.sum(ind)

        # ---------------------------------------------------------------
        # Special mode -- allows to only train pdf but then do not apply
        if skip_reweights:
            print('Special mode [skip_reweights] active: differential reweight model not applied', 'red')
            weights_doublet = doublet_helper(x=x, y=y, w=w, class_ids=class_ids)
        # ---------------------------------------------------------------

    # No differential re-weighting
    else:
        print("No differential reweighting")
        print(f"Reference [target] class: [{args['reweight_param']['reference_class']}] | Found classes {class_ids} from y")

        weights_doublet = doublet_helper(x=x, y=y, w=w, class_ids=class_ids)

    # --------------------------------------------------------
    ### Apply class balance equalizing weight
    if (args['reweight_param']['equal_frac'] == True):
        print("Equalizing class fractions", "green")
        weights_doublet = balanceweights(weights_doublet=weights_doublet,
                                         reference_class=args['reweight_param']['reference_class'], y=y)

    ### Finally map back to 1D-array
    # Fix: size by len(y), not len(w) -- w may be None (doublet_helper supports it)
    weights = np.zeros(len(y))
    for c in class_ids:
        weights = weights + weights_doublet[c]

    ### Print weights
    prints.print_weights(weights=weights, y=y)

    if use_ak:
        return ak.Array(weights), pdf
    else:
        return weights, pdf
[docs]
def reweightcoeff1D(X, y, pdf, reference_class, max_reg = 1e3, EPS=1e-12):
    """ Compute N-class density reweighting coefficients.

    Args:
        X:               Observable of interest (N x 1)
        y:               Class labels (N x 1)
        pdf:             PDF for each class
        reference_class: e.g. 0 (background) or 1 (signal)
        max_reg:         Maximum weight regularization cut-off
        EPS:             Numerical protection epsilon

    Returns:
        weights for each event (dict of per-class arrays)
    """
    class_ids = pdf['class_ids']

    # Re-weighting weights
    weights_doublet = {}  # Init with zeros!!
    for c in class_ids:
        weights_doublet[c] = np.zeros(X.shape[0])

    # Weight each class against the reference class
    for c in class_ids:
        inds = aux.x2ind(X[y == c], pdf['binedges'])

        # Fix: use '!=' (value equality); 'is not' is an identity check which is
        # always True when c is a numpy integer and reference_class a Python int
        if c != reference_class:
            weights_doublet[c][y == c] = pdf[reference_class][inds] / np.clip(pdf[c][inds], EPS, None)
        else:
            weights_doublet[c][y == c] = 1.0  # Reference class stays intact

        # Maximum weight cut-off regularization
        weights_doublet[c][weights_doublet[c] > max_reg] = max_reg

    return weights_doublet
[docs]
def reweightcoeff2D(X_A, X_B, y, pdf, reference_class, max_reg = 1e3, EPS=1E-12):
    """
    Compute N-class density reweighting coefficients.

    Operates in full 2D without any factorization.

    Args:
        X_A :             First observable of interest (N x 1)
        X_B :             Second observable of interest (N x 1)
        y   :             Class labels (N x 1)
        pdf :             Density histograms for each class
        reference_class : e.g. Background (0) or signal (1)
        max_reg :         Regularize the maximum reweight coefficient
        EPS :             Numerical protection epsilon

    Returns:
        weights for each event (dict of per-class arrays)
    """
    class_ids = pdf['class_ids']

    # Re-weighting weights
    weights_doublet = {}  # Init with zeros!!
    for c in class_ids:
        weights_doublet[c] = np.zeros(X_A.shape[0])

    # Weight each class against the reference class
    for c in class_ids:
        inds_0 = aux.x2ind(X_A[y == c], pdf['binedges'][0])  # variable 0
        inds_1 = aux.x2ind(X_B[y == c], pdf['binedges'][1])  # variable 1

        # Fix: use '!=' (value equality); 'is not' is an identity check which is
        # always True when c is a numpy integer and reference_class a Python int
        if c != reference_class:
            weights_doublet[c][y == c] = pdf[reference_class][inds_0, inds_1] / np.clip(pdf[c][inds_0, inds_1], EPS, None)
        else:
            weights_doublet[c][y == c] = 1.0  # Reference class stays intact

        # Maximum weight cut-off regularization
        weights_doublet[c][weights_doublet[c] > max_reg] = max_reg

    return weights_doublet
[docs]
def pdf_1D_hist(X, w, binedges):
    """
    Compute re-weighting 1D pdfs.

    Args:
        X:        values of the observable
        w:        event weights (or None for unit weights)
        binedges: histogram bin edges

    Returns:
        normalized (discrete density) histogram array
    """
    # Fix: use np.histogram instead of plt.hist -- identical counts, but no
    # side effect of drawing into the current matplotlib figure
    counts, _ = np.histogram(X, bins=binedges, weights=w)

    # Make them densities (float cast: unweighted counts are integers)
    pdf = counts.astype(np.float64)
    pdf /= np.sum(pdf)

    return pdf
[docs]
def pdf_2D_hist(X_A, X_B, w, binedges_A, binedges_B):
    """
    Compute re-weighting 2D pdfs.

    Args:
        X_A, X_B:               values of the two observables
        w:                      event weights (or None for unit weights)
        binedges_A, binedges_B: histogram bin edges per observable

    Returns:
        normalized (discrete density) 2D histogram array
    """
    # Fix: use np.histogram2d instead of plt.hist2d -- identical counts
    # (plt.hist2d wraps np.histogram2d), but no figure drawing side effect
    counts, _, _ = np.histogram2d(x=X_A, y=X_B, bins=[binedges_A, binedges_B], weights=w)

    # Make them densities
    pdf = counts.astype(np.float64)
    pdf /= np.sum(pdf)

    return pdf
[docs]
def balanceweights(weights_doublet, reference_class, y, EPS=1e-12):
    """ Balance N-class weights to sum to equal counts.

    Args:
        weights_doublet: N-class event weights (events x classes)
        reference_class: which class gives the reference (integer)
        y:               class targets
        EPS:             numerical protection epsilon

    Returns:
        weights doublet with new weights per event
    """
    class_ids = np.unique(y).astype(int)
    ref_sum   = np.sum(weights_doublet[reference_class][y == reference_class])

    for c in class_ids:
        # Fix: use '!=' (value equality); 'is not' is an identity check which is
        # always True when c is a numpy integer and reference_class a Python int
        if c != reference_class:
            EQ = ref_sum / np.clip(np.sum(weights_doublet[c][y == c]), EPS, None)
            weights_doublet[c][y == c] *= EQ

    return weights_doublet