# Graph data readers and parsers for HGCAL
#
# Mikael Mieskolainen, 2023
# m.mieskolainen@imperial.ac.uk
import numpy as np
from tqdm import tqdm
import torch
from torch_geometric.data import Data
import torch_geometric.transforms as T
import icenet.algo.analytic as analytic
from icenet.tools import aux
from icenet.tools.icevec import vec4
# ------------------------------------------
from icenet import print
# ------------------------------------------
[docs]
def parse_graph_data_trackster(data, graph_param, weights=None, maxevents=int(1e9), null_value=-999.0):
    """
    TRACKSTER LEVEL
    Parse graph data to torch geometric format.

    Args:
        data:        awkward array with per-event fields
                       'x'           : trackster records (barycenter_x/y/z, raw_energy,
                                       raw_em_energy, EV1, EV2, EV3)
                       'edge_index'  : edge index pairs, 2 x num_edges
                       'edge_labels' : edge truth labels, num_edges
        graph_param: graph construction parameter dict (not read by this parser)
        weights:     not used; every edge gets unit weight
        maxevents:   maximum number of events to convert
        null_value:  replacement value for non-finite node features

    Returns:
        list of torch_geometric Data objects, one per event
    """
    # Node feature columns, in output column order
    node_fields = ('barycenter_x', 'barycenter_y', 'barycenter_z',
                   'raw_energy', 'raw_em_energy', 'EV1', 'EV2', 'EV3')

    # Debug printout of the first event; guarded because indexing [0]
    # raises on an empty input.
    if len(data['x']) > 0:
        print(data['x'][0])

    nevents = min(len(data['x']), maxevents)
    graph_dataset = []

    # For all events
    for ev in tqdm(range(nevents)):

        # ** Node features ** (num_nodes x len(node_fields))
        nodes = data['x'][ev]
        x = np.zeros((len(nodes), len(node_fields)))
        for j, field in enumerate(node_fields):
            x[:, j] = getattr(nodes, field).to_numpy()
        x[~np.isfinite(x)] = null_value  # Input protection
        x = torch.tensor(x, dtype=torch.float)

        # ** Edge indices ** (2 x num_edges)
        edge_index = torch.tensor(np.array(data['edge_index'][ev], dtype=int), dtype=torch.long)

        # ** Edge labels ** (training truth, num_edges)
        y = torch.tensor(np.array(data['edge_labels'][ev], dtype=int), dtype=torch.long)

        # Global features (not active)
        u = torch.tensor([], dtype=torch.float)

        # Edge weights: uniform (the `weights` argument is not used)
        w = torch.ones_like(y, dtype=torch.float)

        graph = Data(num_nodes=x.shape[0], x=x, edge_index=edge_index, edge_attr=None, y=y, w=w, u=u)

        # Edge attributes need the assembled graph (x + edge_index)
        graph.edge_attr = compute_edge_attr(graph)
        graph_dataset.append(graph)

    return graph_dataset
[docs]
def compute_edge_attr(data):
    """
    Compute per-edge attributes from node positions.

    Args:
        data: graph object with
                x          : node feature tensor; columns 0:3 are used as
                             spatial coordinates (barycenter x, y, z in the
                             trackster parser)
                edge_index : long tensor of shape (2, num_edges)

    Returns:
        float32 tensor of shape (num_edges, 1) holding the L2 distance
        between the two endpoint nodes of each edge
    """
    src, dst = data.edge_index[0], data.edge_index[1]

    # Vectorized over all edges at once instead of a per-edge Python loop
    diff = data.x[src, 0:3] - data.x[dst, 0:3]
    dist = torch.sqrt(torch.sum(diff ** 2, dim=1, keepdim=True))

    return dist.to(torch.float)
[docs]
def parse_graph_data_candidate(X, ids, features, graph_param, Y=None, weights=None, entry_start=None, entry_stop=None, EPS=1e-12, null_value=-999.0):
    """
    EVENT LEVEL (PROCESSING CANDIDATES)
    Jagged array data into pytorch-geometric style Data format array.

    Args:
        X           : Jagged array of variables, indexed as X[event, variable]
        ids         : Variable names (list of strings, looked up with .index)
        features    : List of active global feature strings
        graph_param : Graph construction parameters dict
                      (keys read: 'global_on', 'coord', 'directed', 'self_loops')
        Y           : Target class array (if any, typically MC only)
        weights     : (Re-)weighting array (if any, typically MC only)
        entry_start : First event index (resolved by aux.slice_range)
        entry_stop  : One-past-last event index (resolved by aux.slice_range)
        EPS         : Not used
        null_value  : Replacement for non-finite node/edge/global feature values

    Returns:
        List of pytorch-geometric Data objects

    Raises:
        ValueError if a requested feature or candidate variable is missing from ids
    """
    global_on  = graph_param['global_on']
    coord      = graph_param['coord']
    directed   = graph_param['directed']
    self_loops = graph_param['self_loops']

    num_node_features   = 4
    num_edge_features   = 4
    num_global_features = 0

    entry_start, entry_stop, num_events = aux.slice_range(start=entry_start, stop=entry_stop, N=len(X))
    dataset = []

    print(__name__ + f'.parse_graph_data_candidate: Converting {num_events} events into graphs ...')

    # Global feature indices. Not consumed below (the global vector u is
    # zero-sized), but the lookup still validates that every requested
    # feature exists in ids.
    feature_ind = np.array([ids.index(name) for name in features], dtype=np.int32)

    # Candidate 4-momentum variable indices
    ind__candidate_energy = ids.index('candidate_energy')
    ind__candidate_px     = ids.index('candidate_px')
    ind__candidate_py     = ids.index('candidate_py')
    ind__candidate_pz     = ids.index('candidate_pz')

    num_empty_HGCAL = 0

    # Loop over events
    for ev in tqdm(range(entry_start, entry_stop)):

        cand_energy = X[ev, ind__candidate_energy]
        cand_px     = X[ev, ind__candidate_px]
        cand_py     = X[ev, ind__candidate_py]
        cand_pz     = X[ev, ind__candidate_pz]

        num_nodes = len(cand_energy)
        num_edges = analytic.count_simple_edges(num_nodes=num_nodes, directed=directed, self_loops=self_loops)

        # 4-vector for each HGCAL candidate (one graph node per candidate)
        p4vec = []
        for k in range(num_nodes):
            v = vec4()
            v.setPxPyPzE(cand_px[k], cand_py[k], cand_pz[k], cand_energy[k])
            p4vec.append(v)

        # Empty HGCAL cluster information is only counted: never skip such
        # events here, do pre-filtering before this function if needed.
        if num_nodes == 0:
            num_empty_HGCAL += 1

        # ====================================================================
        # CONSTRUCT TENSORS

        # Output class and training weight; the [] wrapping is important
        # to get the right dimensions.
        y = torch.tensor([0 if Y is None else Y[ev]], dtype=torch.long)
        w = torch.tensor([1.0 if weights is None else weights[ev]], dtype=torch.float)

        # Node features
        x = get_node_features(p4vec=p4vec, num_nodes=num_nodes, num_node_features=num_node_features, coord=coord)
        x[~np.isfinite(x)] = null_value  # Input protection
        x = torch.tensor(x, dtype=torch.float)

        # Edge features
        edge_attr = analytic.get_Lorentz_edge_features(p4vec=p4vec, num_nodes=num_nodes,
            num_edges=num_edges, num_edge_features=num_edge_features, self_loops=self_loops, directed=directed)
        edge_attr[~np.isfinite(edge_attr)] = null_value  # Input protection
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)

        # Edge connectivity
        edge_index = analytic.get_simple_edge_index(num_nodes=num_nodes, num_edges=num_edges, self_loops=self_loops, directed=directed)
        edge_index = torch.tensor(edge_index, dtype=torch.long)

        # Global features: none active (num_global_features = 0)
        u = np.zeros(num_global_features)
        if global_on:
            u[~np.isfinite(u)] = null_value  # Input protection
        u = torch.tensor(u, dtype=torch.float)

        dataset.append(Data(num_nodes=x.shape[0], x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, w=w, u=u))

    # Guard against ZeroDivisionError when the requested slice is empty
    frac_empty = num_empty_HGCAL / num_events if num_events > 0 else 0.0
    print(__name__ + f'.parse_graph_data_candidate: Empty HGCAL events: {num_empty_HGCAL} / {num_events} = {frac_empty:0.5f} (using only global data u)')

    return dataset
[docs]
def get_node_features(p4vec, num_nodes, num_node_features, coord):
    """
    Build the node feature matrix from per-node 4-vectors.

    Args:
        p4vec:             list of 4-vector objects, one per real node
        num_nodes:         number of graph nodes (rows of the output)
        num_node_features: number of feature columns (at least 4)
        coord:             'ptetaphim' -> columns (pt, eta, phi, m)
                           'pxpypze'   -> columns (px, py, pz, e)

    Returns:
        float numpy array of shape (num_nodes, num_node_features). Rows with no
        matching 4-vector (num_nodes > len(p4vec), e.g. a trailing placeholder
        node for the empty event case) and columns beyond the 4th stay zero.

    Raises:
        Exception for an unknown coord value
    """
    if coord == 'ptetaphim':
        attrs = ('pt', 'eta', 'phi', 'm')
    elif coord == 'pxpypze':
        attrs = ('px', 'py', 'pz', 'e')
    else:
        raise Exception(__name__ + f'.get_node_features: Unknown coordinate representation: {coord}')

    # Node feature matrix
    x = np.zeros((num_nodes, num_node_features), dtype=float)

    # Fill one row per available 4-vector. Looping only to num_nodes-1 would
    # drop the last real node whenever num_nodes == len(p4vec).
    for i in range(min(num_nodes, len(p4vec))):
        for j, attr in enumerate(attrs):
            x[i, j] = getattr(p4vec[i], attr)

    return x