# Input data containers and memory management
#
# m.mieskolainen@imperial.ac.uk, 2025
import numpy as np
import awkward as ak
from collections import Counter
from typing import Literal, Optional
import numba
import copy
import torch
import os
import psutil
import subprocess
import re
import pathlib
from natsort import natsorted
# MVA imputation
from sklearn.impute import KNNImputer
from sklearn.impute import SimpleImputer
from sklearn.experimental import enable_iterative_imputer # Required to enable IterativeImputer
from sklearn.impute import IterativeImputer
# Command line arguments
from glob import glob
from braceexpand import braceexpand
import time
import hashlib
import base64
from icenet.tools import aux
from icenet.tools import stx
# ------------------------------------------
from icenet import print
# ------------------------------------------
def get_file_timestamp(file_path: str):
"""
Return file timestamp as a string
"""
if os.path.exists(file_path):
# Get the file last modification time
timestamp = os.path.getmtime(file_path)
# Convert it to a readable format
readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))
return readable_time
else:
return f"File '{file_path}' does not exist."
def rootsafe(txt):
"""
    Replace characters ('-', '+', '/', '*') which are problematic in ROOT object names
"""
return txt.replace('-', '_').replace('+','_').replace('/','_').replace('*','_')
def safetxt(txt):
"""
    Replace '/' characters (protection for filesystem paths)
"""
if type(txt) is str:
return txt.replace('/', '|')
else:
return txt
def count_files_in_dir(path):
"""
Count the number of files in a path
"""
return len([name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))])
def make_hash_sha256_file(filename):
"""
Create SHA256 hash from a file
"""
h = hashlib.sha256()
b = bytearray(128*1024)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def make_hash_sha256_object(o):
"""
Create SHA256 hash from an object
Args:
o: python object (e.g. dictionary)
Returns:
hash
"""
hasher = hashlib.sha256()
hasher.update(repr(make_hashable(o)).encode())
hash_str = base64.b64encode(hasher.digest()).decode()
# May cause problems with directories
hash_str = hash_str.replace('/', '_')
hash_str = hash_str.replace('\\', '__')
hash_str = hash_str.replace('.', '___')
return hash_str
def make_hashable(o):
"""
Turn a python object into hashable type (recursively)
"""
if isinstance(o, (tuple, list)):
return tuple((make_hashable(e) for e in o))
if isinstance(o, dict):
return tuple(sorted((k,make_hashable(v)) for k,v in o.items()))
if isinstance(o, (set, frozenset)):
return tuple(sorted(make_hashable(e) for e in o))
return o
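# Example (illustrative sketch; the dictionary and filename below are hypothetical):
#
#   config = {'lr': 1e-3, 'layers': [64, 64], 'tag': 'nominal'}
#   key    = make_hash_sha256_object(config)      # filesystem-safe base64 SHA256 string
#   fhash  = make_hash_sha256_file('events.root') # hex digest of the file contents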
def glob_expand_files(datasets, datapath, recursive_glob=False):
"""
    Do glob / brace expansion of files
Args:
datasets: dataset filename with glob syntax (can be a list of files)
        datapath: root path to the files
        recursive_glob: if True, pass recursive=True to glob (enables '**' patterns)
Returns:
files: full filenames including the path
"""
print("")
print(f"Supported syntax: <filename_*>, <filename_0>, <filename_[0-99]>, <filename_{{0,3,4}}>")
print("See https://docs.python.org/3/library/glob.html and brace expansion (be careful, do not use [,] brackets in your filenames)")
print("")
# Remove unnecessary []
if type(datasets) is list and len(datasets) == 1:
datasets = datasets[0]
# Try first to brace expand
try:
datasets = list(braceexpand(datasets))
    except Exception:
        pass
#print(__name__ + f'.glob_expand_files: After braceexpand: {datasets}')
if (len(datasets) == 1) and ('[' in datasets[0]) and (']' in datasets[0]):
print(f'Parsing of range [first-last] ...')
res = re.findall(r'\[.*?\]', datasets[0])[0]
temp = res[1:-1]
numbers = temp.split('-')
first = int(numbers[0])
last = int(numbers[1])
print(f'Obtained range of files: [{first}, {last}]')
# Split and add
parts = datasets[0].split(res)
datasets[0] = parts[0] + '{'
for i in range(first, last+1):
datasets[0] += f'{i}'
if i != last:
datasets[0] += ','
datasets[0] += '}' + parts[1]
datasets = list(braceexpand(datasets[0]))
#print(__name__ + f'.glob_expand_files: After expanding the range: {datasets}')
# Parse input files into a list
files = []
for data in datasets:
# This does e.g. _*.root expansion (finds the files)
x = str(pathlib.Path(datapath) / data)
expanded_files = glob(x, recursive=recursive_glob)
files.extend(expanded_files)
if files == []:
files = [datapath]
# Normalize e.g. for accidental multiple slashes
files = [os.path.normpath(f) for f in files]
# Make them unique and natural sorted
files = natsorted(set(files))
#print(__name__ + f'.glob_expand_files: Final files: {files}')
return files
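# Example (illustrative sketch; paths and filenames are hypothetical):
#
#   files = glob_expand_files(datasets='output_{1,3,4}.root', datapath='/data/mc')
#   files = glob_expand_files(datasets='output_[0-99].root',  datapath='/data/mc')
#   files = glob_expand_files(datasets='output_*.root',       datapath='/data/mc')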
def showmem(color='red'):
print(f"""Process RAM: {process_memory_use():0.2f} GB [total RAM in use {psutil.virtual_memory()[2]} %]""", color)
def showmem_cuda(device='cuda:0', color='red'):
print(f"Process RAM: {process_memory_use():0.2f} GB [total RAM in use {psutil.virtual_memory()[2]} %] | VRAM usage: {get_gpu_memory_map()} GB [total VRAM {torch_cuda_total_memory(device):0.2f} GB]", color)
def get_gpu_memory_map():
"""Get the GPU VRAM use in GB.
Returns:
dictionary with keys as device ids [integers]
and values the memory used by the GPU.
"""
try:
result = subprocess.check_output(
[
'nvidia-smi', '--query-gpu=memory.used',
'--format=csv,nounits,noheader'
], encoding='utf-8')
# into dictionary
gpu_memory = [int(x)/1024.0 for x in result.strip().split('\n')]
gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
return gpu_memory_map
except Exception as e:
print(f'Error: Could not run nvidia-smi: {e}')
return None
def torch_cuda_total_memory(device='cuda:0'):
"""
Return CUDA device VRAM available in GB.
"""
return torch.cuda.get_device_properties(device).total_memory / 1024.0**3
def process_memory_use():
"""
Return system memory (RAM) used by the process in GB.
"""
pid = os.getpid()
py = psutil.Process(pid)
return py.memory_info()[0]/2.**30
def checkinfnan(x, value = 0):
""" Check inf and Nan values and replace with a default value.
"""
inf_ind = np.isinf(x)
nan_ind = np.isnan(x)
x[inf_ind] = value
x[nan_ind] = value
if np.sum(inf_ind) > 0:
print(f'Inf found, replacing with {value}', 'red')
if np.sum(nan_ind) > 0:
print(f'NaN found, replacing with {value}', 'red')
return x
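# Example (illustrative sketch):
#
#   x = np.array([1.0, np.inf, np.nan, 4.0])
#   x = checkinfnan(x, value=0)  # --> [1.0, 0.0, 0.0, 4.0]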
class fastarray1:
""" 1D pre-memory occupied buffer arrays for histogramming etc.
"""
def __init__(self, capacity = 32):
self.capacity = capacity
self.data = np.zeros((self.capacity))
self.size = 0
# use with x.update([1,3,5,1])
#
def update(self, row):
for x in row:
self.add(x)
# use with x.add(32)
#
def add(self, x):
if self.size == self.capacity:
print(f'increasing current capacity = {self.capacity} to 2x')
self.capacity *= 2
newdata = np.zeros((self.capacity))
newdata[:self.size] = self.data
self.data = newdata
self.data[self.size] = x
self.size += 1
# Get values
#
def values(self):
return self.data[0:self.size]
# Reset index, keep the buffersize
def reset(self):
self.size = 0
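# Example (illustrative sketch):
#
#   buf = fastarray1(capacity=4)
#   buf.update([1, 3, 5, 1])  # grows automatically beyond the initial capacity
#   buf.add(32)
#   vals = buf.values()       # --> array([ 1.,  3.,  5.,  1., 32.])
#   buf.reset()               # rewind the index, keep the buffer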
def index_list(target_list, keys):
"""
Use e.g. x_subset = x[:, io.index_list(ids, variables)]
"""
index = []
for key in keys:
index.append(target_list.index(key))
return index
def pick_vars(data, set_of_vars):
""" Choose the active set of input variables.
Args:
data: IceXYW type object
set_of_vars: Variables to pick
Returns:
newind: Chosen indices
newvars: Chosen variables
"""
newind = np.where(np.isin(data.ids, set_of_vars))
newind = np.array(newind).flatten()
newvars = []
for i in newind :
newvars.append(data.ids[i])
return newind, newvars
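# Example (illustrative sketch; 'data' is a hypothetical IceXYW object with data.ids = ['pt', 'eta', 'phi']):
#
#   ind, chosen = pick_vars(data, set_of_vars=['pt', 'phi'])
#   x_subset    = data.x[:, ind]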
class IceXYW:
"""
Args:
x : data object
y : target output data
        w : weights
        ids : variable (column) names
    """
# constructor
def __init__(self, x = np.array([]), y = np.array([]), w = None, ids = None):
self.x = x
self.y = y
self.w = w
self.ids = ids
if isinstance(x, np.ndarray):
self.concat = np.concatenate
elif isinstance(x, ak.Array):
self.concat = ak.concatenate
else:
raise Exception(__name__ + f'.IceXYW.__init__: Unknown input array type')
def find_ind(self, key):
"""
Return column index corresponding to key
"""
return int(np.where(np.array(self.ids, dtype=np.object_) == key)[0])
def __getitem__(self, key):
""" Advanced indexing with a variable or a list of variables """
if type(key) is not list: # access with a single variable name
try:
select = key
new = IceXYW(x=self.x[select, ...], y=self.y[select], w=self.w[select], ids=self.ids)
return new
            except Exception:
                pass
if key in self.ids: # direct access
col = self.ids.index(key)
ids = [key]
elif isinstance(self.x, np.ndarray): # might be a cut string, try that
select = stx.eval_boolean_syntax(expr=key, X=self.x, ids=self.ids, verbose=True)
return IceXYW(x=self.x[select, ...], y=self.y[select], w=self.w[select], ids=self.ids)
else:
raise Exception(__name__ + f'[operator]: Cannot execute')
else: # list of variables
col,ids = pick_vars(data=self, set_of_vars=key)
if isinstance(self.x, np.ndarray):
return IceXYW(x=self.x[..., col], y=self.y, w=self.w, ids=ids)
else:
return IceXYW(x=self.x[col], y=self.y, w=self.w, ids=ids)
# length operator
def __len__(self):
return len(self.x)
# + operator
def __add__(self, other):
if (len(self.x) == 0): # still empty
return other
x = self.concat((self.x, other.x), axis=0)
y = self.concat((self.y, other.y), axis=0)
        if self.w is not None:
            w = self.concat((self.w, other.w), axis=0)
        else:
            w = None
        return IceXYW(x, y, w, ids=self.ids)
# += operator
def __iadd__(self, other):
if (len(self.x) == 0): # still empty
return other
self.x = self.concat((self.x, other.x), axis=0)
self.y = self.concat((self.y, other.y), axis=0)
if self.w is not None:
self.w = self.concat((self.w, other.w), axis=0)
return self
# filter operator
def classfilter(self, classid):
ind = (self.y == classid)
x = self.x[ind]
y = self.y[ind]
if self.w is not None:
w = self.w[ind]
else:
w = self.w
return IceXYW(x=x, y=y, w=w, ids=self.ids)
# Permute events
def permute(self, permutation):
self.x = self.x[permutation]
self.y = self.y[permutation]
if self.w is not None:
self.w = self.w[permutation]
return self
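# Example (illustrative sketch; arrays and variable names below are hypothetical):
#
#   X   = np.random.randn(1000, 3)
#   Y   = np.random.randint(0, 2, size=1000)
#   W   = np.ones(1000)
#   obj = IceXYW(x=X, y=Y, w=W, ids=['pt', 'eta', 'phi'])
#
#   pt_phi = obj[['pt', 'phi']]    # column subset via a list of variable names
#   sig    = obj.classfilter(1)    # events with y == 1
#   obj    = obj.permute(np.random.permutation(len(obj)))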
def split_data_simple(X, frac, permute=True):
""" Split machine learning data into train, validation, test sets
Args:
X: data as a list of event objects (such as torch geometric Data)
frac: split fraction
"""
### Permute events to have random mixing between events
if permute:
randi = np.random.permutation(len(X)).tolist()
X = [X[i] for i in randi]
N = len(X)
N_A = round(N * frac)
N_B = N - N_A
N_trn = N_A
N_val = round(N_B / 2)
N_tst = N - N_trn - N_val
# Split
X_trn = X[0:N_trn]
X_val = X[N_trn:N_trn + N_val]
X_tst = X[N - N_tst:]
print(f"fractions [train: {len(X_trn)/N:0.3f}, validate: {len(X_val)/N:0.3f}, test: {len(X_tst)/N:0.3f}]")
return X_trn, X_val, X_tst
def split_data(X, Y, W, ids, frac=[0.5, 0.1, 0.4], permute=True):
""" Split machine learning data into train, validation, test sets
Args:
X: data matrix
Y: target matrix
W: weight matrix
ids: variable names of columns
        frac: fractions [train, validate, test] (normalized to sum to 1)
        permute: if True, randomly permute events before splitting
    """
### Permute events to have random mixing between classes
if permute:
randi = np.random.permutation(len(X))
X = X[randi]
Y = Y[randi]
if W is not None:
W = W[randi]
# --------------------------------------------------------------------
frac = np.array(frac)
frac = frac / np.sum(frac)
# Get event counts
N = len(X)
N_trn = int(round(N * frac[0]))
N_tst = int(round(N * frac[2]))
N_val = N - N_trn - N_tst
# 1. Train
X_trn = X[0:N_trn]
Y_trn = Y[0:N_trn]
if W is not None:
W_trn = W[0:N_trn]
else:
W_trn = None
# 2. Validation
X_val = X[N_trn:N_trn + N_val]
Y_val = Y[N_trn:N_trn + N_val]
if W is not None:
W_val = W[N_trn:N_trn + N_val]
else:
W_val = None
# 3. Test
X_tst = X[N - N_tst:]
Y_tst = Y[N - N_tst:]
if W is not None:
W_tst = W[N - N_tst:]
else:
W_tst = None
# --------------------------------------------------------
# Construct
trn = IceXYW(x = X_trn, y = Y_trn, w=W_trn, ids=ids)
val = IceXYW(x = X_val, y = Y_val, w=W_val, ids=ids)
tst = IceXYW(x = X_tst, y = Y_tst, w=W_tst, ids=ids)
# --------------------------------------------------------
print(f"fractions [train: {len(X_trn)/N:0.3f}, validate: {len(X_val)/N:0.3f}, test: {len(X_tst)/N:0.3f}]")
return trn, val, tst
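# Example (illustrative sketch; X, Y, W below are hypothetical arrays):
#
#   X = np.random.randn(1000, 3); Y = np.random.randint(0, 2, size=1000); W = np.ones(1000)
#   trn, val, tst = split_data(X=X, Y=Y, W=W, ids=['pt', 'eta', 'phi'], frac=[0.6, 0.1, 0.3])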
def impute_data(X, imputer=None, dim=None, values=[-999], labels=None, algorithm='iterative', fill_value=0, knn_k=6):
""" Data imputation (treatment of missing values, Nan and Inf).
NOTE: This function can impute only fixed dimensional input currently (not Jagged numpy arrays)
Args:
X : Input data matrix [N vectors x D dimensions]
        imputer : Pre-trained imputer, default None
        dim : Array of active dimensions to impute
        values : List of special (sentinel) values indicating the need for imputation
        labels : List containing a textual description of the input variables
        algorithm : 'constant', 'mean', 'median', 'iterative', 'knn'
        fill_value : Constant used with algorithm 'constant'
        knn_k : knn k-nearest neighbour parameter
Returns:
X : Imputed output data
"""
if dim is None:
dim = np.arange(X.shape[1])
if labels is None:
labels = np.zeros(X.shape[1])
N = X.shape[0]
# Count NaN
for j in dim:
nan_ind = np.isnan(np.array(X[:,j], dtype=np.float32))
found = np.sum(nan_ind)
if found > 0:
            print(f'Column {j}: {found} NaN values found (fraction {found/len(X):0.3E}) [{labels[j]}]', 'red')
# Loop over dimensions
for j in dim:
# Set NaN for special values
M_tot = 0
for z in values:
ind = np.isclose(np.array(X[:,j], dtype=np.float32), z)
X[ind, j] = np.nan
M = np.sum(ind)
M_tot += M
if (M/N > 0):
print(f'Column {j} fraction [{M/N:0.3E}] with value {z} [{labels[j]}]')
if (M_tot == N): # Protection, if all are now NaN
# Set to zero so Imputer Function below does not remove the full column!!
X[:,j] = 0.0
# Treat infinities (inf)
for j in dim:
inf_ind = np.isinf(np.array(X[:,j], dtype=np.float32))
X[inf_ind, j] = np.nan
found = np.sum(inf_ind)
if found > 0:
            print(f'Column {j}: {found} Inf values found (fraction {found/len(X):0.3E}) [{labels[j]}]', 'red')
    if imputer is None:
# Fill missing values
if algorithm == 'constant':
imputer = SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=fill_value)
elif algorithm == 'mean':
imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
elif algorithm == 'median':
imputer = SimpleImputer(missing_values=np.nan, strategy='median')
elif algorithm == 'iterative':
imputer = IterativeImputer(missing_values=np.nan)
elif algorithm == 'knn':
imputer = KNNImputer(missing_values=np.nan, n_neighbors=knn_k)
else:
raise Exception(__name__ + '.impute_data: Unknown algorithm chosen')
# Fit and transform
imputer.fit(X[:,dim])
X[:,dim] = imputer.transform(X[:,dim])
print('[done] \n')
return X, imputer
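# Example (illustrative sketch; X is a hypothetical fixed-dimensional matrix):
#
#   X = np.array([[1.0, -999.0], [2.0, 5.0], [np.nan, 7.0]])
#   X_imp, imputer = impute_data(X, values=[-999], algorithm='median')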
def calc_madscore(X : np.array):
""" Calculate robust normalization.
Args:
X : Input with [# vectors x # dimensions]
Returns:
X_m : Median vector
        X_mad : Median absolute deviation (MAD) vector
"""
X_m = np.zeros((X.shape[1]))
X_mad = np.zeros((X.shape[1]))
    # Calculate the median and MAD based on the training data
for i in range(X.shape[1]) :
X_m[i] = np.median(X[:,i])
X_mad[i] = np.median(np.abs(X[:,i] - X_m[i]))
if (np.isnan(X_mad[i])):
raise Exception(__name__ + f': Fatal error with MAD[index = {i}] is NaN')
if (np.isinf(X_mad[i])):
raise Exception(__name__ + f': Fatal error with MAD[index = {i}] is Inf')
return X_m, X_mad
def calc_zscore_tensor(T):
"""
Compute z-score normalization for tensors.
Args:
T : input tensor data (events x channels x rows x cols, ...)
Returns:
mu, std tensors
"""
Y = copy.deepcopy(T)
Y[~np.isfinite(Y)] = 0
mu = np.mean(Y, axis=0)
std = np.std(Y, axis=0)
return mu, std
def apply_zscore_tensor(T, mu, std, EPS=1E-12):
"""
Apply z-score normalization for tensors.
"""
Y = copy.deepcopy(T)
Y[~np.isfinite(Y)] = 0
# Over all events
for i in range(T.shape[0]):
        Y[i,...] = (Y[i,...] - mu) / np.maximum(std, EPS)
return Y
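# Example (illustrative sketch; T_train and T_eval are hypothetical tensors of shape [events, channels, rows, cols]):
#
#   mu, std = calc_zscore_tensor(T_train)
#   T_train = apply_zscore_tensor(T_train, mu, std)
#   T_eval  = apply_zscore_tensor(T_eval,  mu, std)  # apply the training statistics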
def calc_zscore(X: np.array, weights: np.array = None):
""" Calculate 0-mean & unit-variance normalization.
Args:
X : Input with [N x dim]
weights : Event weights
Returns:
X_mu : Mean vector
X_std : Standard deviation vector
"""
X_mu = np.zeros((X.shape[1]))
X_std = np.zeros((X.shape[1]))
# Calculate mean and std based on the training data
for i in range(X.shape[1]):
xval = X[:,i]
if weights is not None:
X_mu[i], X_std[i] = aux.weighted_avg_and_std(xval, weights)
else:
X_mu[i], X_std[i] = np.mean(xval), np.std(xval)
if (np.isnan(X_std[i])):
raise Exception(__name__ + f': Fatal error with std[index = {i}] is NaN')
if (np.isinf(X_std[i])):
raise Exception(__name__ + f': Fatal error with std[index = {i}] is Inf')
return X_mu, X_std
@numba.njit(parallel=True)
def apply_zscore(X : np.array, X_mu, X_std, EPS=1E-12):
""" Z-score normalization
"""
Y = np.zeros(X.shape)
for i in range(len(X_mu)):
Y[:,i] = (X[:,i] - X_mu[i]) / max(X_std[i], EPS)
return Y
@numba.njit(parallel=True)
def reverse_zscore(Y: np.array, X_mu, X_std, EPS=1E-12):
""" Reverse Z-score normalization
"""
X = np.zeros(Y.shape)
for i in range(len(X_mu)):
X[:, i] = Y[:, i] * max(X_std[i], EPS) + X_mu[i]
return X
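# Example (illustrative sketch; X_train and X_eval are hypothetical [N x dim] matrices):
#
#   X_mu, X_std = calc_zscore(X_train)
#   Z_train     = apply_zscore(X_train, X_mu, X_std)
#   Z_eval      = apply_zscore(X_eval,  X_mu, X_std)    # apply the training statistics
#   X_back      = reverse_zscore(Z_train, X_mu, X_std)  # inverse transform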
@numba.njit(parallel=True)
def apply_madscore(X : np.array, X_m, X_mad, EPS=1E-12):
""" MAD-score normalization
"""
Y = np.zeros(X.shape)
    scale = 0.6745 # approx. Phi^{-1}(0.75), i.e. 75th percentile of the standard normal
for i in range(len(X_m)):
Y[:,i] = scale * (X[:,i] - X_m[i]) / max(X_mad[i], EPS)
return Y
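# Example (illustrative sketch; X_train is a hypothetical [N x dim] matrix):
#
#   X_m, X_mad = calc_madscore(X_train)
#   Z          = apply_madscore(X_train, X_m, X_mad)  # robust (outlier tolerant) normalization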
def infer_precision(
arr: np.ndarray,
*,
small_spacing_quantile: float = 0.20,
max_pairs: int = 100_000,
min_pairs: int = 500,
grid_coarse: int = 64,
grid_fine: int = 128,
assume_ieee: bool = True,
return_debug: bool = False,
):
"""
GPT5 driven (briefly tested)
Estimate effective mantissa (fraction) bits p from a float array whose values are
on a binary quantization grid (possibly reduced precision). Invariant to affine
transforms y = a*x + b with a > 0.
Method:
- Take unique sorted values u (in float64 for stability).
- Center once with a robust location m = median(u).
- Spacings: d = u[i+1] - u[i] (shift-invariant)
- Centered midpoints: c0 = (u[i]+u[i+1])/2 - m (shift-invariant)
- For r in [0,1): z_r = log2(d) - ( floor(log2|c0| + r) - 1 )
For a binary quantizer with p fraction bits, z_r clusters near -p.
- Pick r minimizing MAD(z_r), then p = round(-median(z_r)).
Notes:
- Positive scaling y = a*x shifts both log2(d) and log2|c0| by log2(a); the r-search re-aligns
exponent bins, leaving z_r (and thus p) unchanged. Centering cancels b.
- For true full-precision float64 random arrays, you typically won't recover 52 without
huge N. This is intended for *reduced* precision data.
Returns:
dict with keys:
mantissa_bits_eff : int | None
mad_bits : float | None
samples_used : int
pairs_used : int
notes : str
debug : dict (if return_debug)
"""
x = np.asarray(arr)
if not np.issubdtype(x.dtype, np.floating):
raise TypeError("Input must be a floating dtype array.")
orig_dtype = x.dtype
# 1) Finite filter
x = x[np.isfinite(x)]
if x.size < 3:
return {'mantissa_bits_eff': None, 'mad_bits': None,
'samples_used': int(x.size), 'pairs_used': 0,
'notes': "Too few finite samples."}
# 2) Unique sorted (use float64 arithmetic)
u = np.unique(x.astype(np.float64, copy=False))
if u.size < 3:
return {'mantissa_bits_eff': None, 'mad_bits': None,
'samples_used': int(x.size), 'pairs_used': 0,
'notes': "Too few unique samples."}
# 3) Center once (shift invariance), build spacings and centered midpoints
m = np.median(u)
d = np.diff(u) # spacings (positive if unique-sorted)
c0 = 0.5 * (u[:-1] + u[1:]) - m # centered midpoints
# 4) Clean
mask = (d > 0) & np.isfinite(c0) & (c0 != 0.0)
if assume_ieee:
# use original dtype tiny to drop (near-)subnormal midpoints after centering
mask &= (np.abs(c0) >= np.finfo(orig_dtype).tiny)
d = d[mask]
c0 = c0[mask]
if d.size < min_pairs:
return {'mantissa_bits_eff': None, 'mad_bits': None,
'samples_used': int(x.size), 'pairs_used': int(d.size),
'notes': "Not enough spacing pairs after filtering."}
pairs_total = int(d.size)
log2d_all = np.log2(d) # d > 0 by construction
lc_all = np.log2(np.abs(c0)) # |c0| > 0 by mask
# Helpers
def stats_over_r(log2d, lc, r_vec):
r = r_vec[:, None] # [R,1]
# Key formula (affine-invariant): no '+r' on log2d; '- 1' inside to align bins
z = log2d[None, :] - (np.floor(lc[None, :] + r) - 1.0)
med = np.median(z, axis=1)
mad = np.median(np.abs(z - med[:, None]), axis=1)
return med, mad
def eval_subset(log2d, lc):
# coarse
R1 = int(max(16, grid_coarse))
r1 = np.linspace(0.0, 1.0, R1, endpoint=False)
med1, mad1 = stats_over_r(log2d, lc, r1)
i1 = int(np.argmin(mad1))
r_best = float(r1[i1])
# fine around best (wrap)
R2 = int(max(32, grid_fine))
halfw = 1.0 / R1
r2 = (r_best + np.linspace(-halfw, halfw, R2, endpoint=True)) % 1.0
med2, mad2 = stats_over_r(log2d, lc, r2)
i2 = int(np.argmin(mad2))
return float(r2[i2]), float(med2[i2]), float(mad2[i2])
def select_smallest(d_all, log2d_all, lc_all, q):
# pick k smallest spacings by d (fast, stable)
kq = int(np.ceil(q * d_all.size))
kq = max(kq, min_pairs)
if max_pairs is not None:
kq = min(kq, max_pairs)
if kq < d_all.size:
idx = np.argpartition(d_all, kq - 1)[:kq]
return log2d_all[idx], lc_all[idx], int(kq)
return log2d_all, lc_all, int(d_all.size)
# 5) Adaptive quantile sweep to find the tightest cluster
q0 = small_spacing_quantile if 0.0 < small_spacing_quantile < 1.0 else 1.0
tried = []
q = q0
for _ in range(6): # q, q/2, q/4, ...
log2d, lc, k = select_smallest(d, log2d_all, lc_all, q)
r_star, med_star, mad_star = eval_subset(log2d, lc)
tried.append((q, k, r_star, med_star, mad_star))
# Early exit if perfectly clustered
if mad_star == 0.0:
break
q *= 0.5
if q < 1e-3:
break
# choose the attempt with minimal MAD
q_best, k_best, r_star, med_star, mad_star = min(tried, key=lambda t: t[4])
p = int(round(-med_star)) # z ≈ -p
out = {
'mantissa_bits_eff': p,
'mad_bits': mad_star,
'samples_used': int(x.size),
'pairs_used': int(k_best),
'notes': (
f"Affine-invariant (y=a*x+b, a>0). Used {k_best}/{pairs_total} pairs "
f"(q≈{q_best:.4f}). Best r={r_star:.6f}, MAD={mad_star:.3f} bits."
),
}
if return_debug:
out['debug'] = {
'attempts': [
{'q': float(a[0]), 'pairs': int(a[1]), 'r': float(a[2]),
'med': float(a[3]), 'mad': float(a[4])}
for a in tried
]
}
return out
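# Example (illustrative sketch; simulate float32 data which was stored with a reduced (float16) mantissa):
#
#   x   = np.random.randn(100_000).astype(np.float16).astype(np.float32)
#   res = infer_precision(x)
#   print(res['mantissa_bits_eff'], res['notes'])  # expect roughly 10 mantissa bits for float16-quantized data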
def optimal_dequantize(
x: np.ndarray,
p: float,
scale: float = 1.0,
zero_mode: Literal["median", "min_nonzero"] = "min_nonzero",
*,
rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
"""
Optimal dequantization based on effective mantissa bits and a uniform
quantization model. Adds heteroscedastic Gaussian noise with relative STD.
Args:
x: float array
p: effective mantissa bits
scale: extra multiplier for the dequantization strength
zero_mode: how to set the *reference magnitude* used for x == 0 elements:
- "median": use median(|x_nonzero|)
- "min_nonzero":use min(|x_nonzero|) (more conservative near zero)
rng: optional numpy Generator (for reproducibility)
Returns:
Dequantized array (same shape as x).
"""
x = np.asarray(x)
if not np.issubdtype(x.dtype, np.floating):
raise TypeError("x must be a floating dtype")
# Relative STD implied by p-bit rounding noise (uniform -> Gaussian match)
rel_sigma = scale * (2.0 ** (-p)) / np.sqrt(12.0)
absx = np.abs(x)
nonzero_idx = np.flatnonzero(absx)
if nonzero_idx.size:
# Absolute
nz_vals = absx[nonzero_idx]
if zero_mode == "median":
zero_ref = np.median(nz_vals)
elif zero_mode == "min_nonzero":
zero_ref = float(np.min(nz_vals))
# guard against pathological tiny values
zero_ref = max(zero_ref, np.finfo(x.dtype).tiny)
else:
raise ValueError("zero_mode must be 'median' or 'min_nonzero'")
else:
# All zeros -> fall back to 1.0 (dimensionless default)
zero_ref = 1.0
# Heteroscedastic sigma: proportional to |x|, with zero treated by zero_mode
sigma = rel_sigma * np.where(absx != 0.0, absx, zero_ref)
rng = np.random.default_rng() if rng is None else rng
noise = rng.normal(loc=0.0, scale=1.0, size=x.shape) * sigma
return x + noise
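# Example (illustrative sketch; continues the hypothetical float16-quantized array 'x' from above):
#
#   p     = infer_precision(x)['mantissa_bits_eff']
#   x_deq = optimal_dequantize(x, p=p, rng=np.random.default_rng(42))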