Source code for icefit.transport

# Optimal Transport (Wasserstein) distance measures (tests under construction)
# 
# run tests with: pytest icefit/transport.py -rP -vv
#
# m.mieskolainen@imperial.ac.uk, 2025

import torch
import pytest
import math

def quantile_function(qs: torch.Tensor, cumweights: torch.Tensor, values: torch.Tensor):
    """
    Computes the quantile function of an empirical distribution
    (handles also multiple independent columns, i.e. for vectorized SWD)
    
    Args:
        qs:         quantile positions where the quantile function is evaluated, (n,) or 2-dim (n,d)
        cumweights: cumulative weights of the 1D empirical distribution, (n,) or 2-dim (n,d)
        values:     locations of the 1D empirical distribution
    
    Returns:
        quantiles of the distribution
    """
    n = values.shape[0]
    
    if qs.ndim == 2:
        qs = qs.transpose(0,1).contiguous()
    
    if cumweights.ndim == 2:
        cumweights = cumweights.transpose(0,1).contiguous()
    
    idx = torch.searchsorted(cumweights, qs, right=False)
    
    if idx.ndim == 2:
        idx = idx.transpose(0,1)
    
    return torch.take_along_dim(values, torch.clip(idx, 0, n - 1), dim=0)
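
# Illustrative sketch (not part of the module API; _demo_quantile_function is a
# hypothetical helper). Each requested quantile picks the first sample value whose
# cumulative weight reaches it (searchsorted with right=False):
def _demo_quantile_function():
    values  = torch.tensor([0.0, 1.0, 3.0])
    weights = torch.tensor([0.2, 0.5, 0.3])
    cumw    = torch.cumsum(weights, dim=0)   # [0.2, 0.7, 1.0]
    qs      = torch.tensor([0.1, 0.6, 0.95])
    return quantile_function(qs=qs, cumweights=cumw, values=values)  # tensor([0., 1., 3.])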
def wasserstein_distance_1D(u_values: torch.Tensor, v_values: torch.Tensor,
                            u_weights: torch.Tensor=None, v_weights: torch.Tensor=None,
                            p: int=1, inverse_power: bool=False, norm_weights: bool=True,
                            require_sort: bool=True):
    """
    Wasserstein 1D distance over two empirical samples.
    
    This function computes with quantile functions (not just CDFs as with the
    special case p = 1), thus it is compatible with arbitrary p.
    
    Args:
        u_values:      sample U vectors (n, [possibly multiple independent columns, for vectorized])
        v_values:      sample V vectors (m, [as above])
        u_weights:     sample U weights (n,) (if None, assumed to be unit)
        v_weights:     sample V weights (m,)
        p:             p-norm parameter (p = 1 is 'Earth Mover's', p = 2 is W-2, ...)
        inverse_power: apply final inverse power 1/p
        norm_weights:  normalize per sample (U,V) weights to sum to one
        require_sort:  always by default, unless presorted
    
    Returns:
        distance between the empirical sample distributions
    """
    def zero_pad(a, pad_width, value=0):
        """ Helper zero-padding function """
        how_pad = tuple(element for t in pad_width[::-1] for element in t)
        return torch.nn.functional.pad(a, how_pad, value=value)
    
    n = u_values.shape[0]
    m = v_values.shape[0]
    
    if u_weights is None:
        u_weights = torch.ones((n,), dtype=u_values.dtype, device=u_values.device)
    if v_weights is None:
        v_weights = torch.ones((m,), dtype=v_values.dtype, device=v_values.device)
    
    if norm_weights:
        u_weights = u_weights / torch.sum(u_weights)
        v_weights = v_weights / torch.sum(v_weights)
    
    if u_values.ndim == 2: # need (n,d) view, expand is memory friendly
        u_weights = u_weights.unsqueeze(1).expand(-1, u_values.size(-1))
        v_weights = v_weights.unsqueeze(1).expand(-1, v_values.size(-1))
    
    if require_sort:
        u_sorter  = torch.argsort(u_values, dim=0)
        u_values  = torch.take_along_dim(u_values,  u_sorter, dim=0)
        u_weights = torch.take_along_dim(u_weights, u_sorter, dim=0)
        
        v_sorter  = torch.argsort(v_values, dim=0)
        v_values  = torch.take_along_dim(v_values,  v_sorter, dim=0)
        v_weights = torch.take_along_dim(v_weights, v_sorter, dim=0)
    
    u_cumweights = torch.cumsum(u_weights, dim=0)
    v_cumweights = torch.cumsum(v_weights, dim=0)
    
    # Compute quantile functions
    qs = torch.sort(torch.cat((u_cumweights, v_cumweights), dim=0), dim=0).values
    u_quantiles = quantile_function(qs=qs, cumweights=u_cumweights, values=u_values)
    v_quantiles = quantile_function(qs=qs, cumweights=v_cumweights, values=v_values)
    del u_cumweights, v_cumweights
    
    # Boundary conditions
    qs = zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
    
    # Measure and integrand
    delta = qs[1:, ...] - qs[:-1, ...]
    del qs
    dq = torch.abs(u_quantiles - v_quantiles)
    del u_quantiles, v_quantiles
    
    if p == 1:
        return (delta * dq).sum(dim=0)
    
    if inverse_power:
        dq.pow_(p)
        return (delta * dq).sum(dim=0).pow(1.0 / p)
    else:
        dq.pow_(p)
        return (delta * dq).sum(dim=0)
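
# Minimal usage sketch (illustration only; _demo_wasserstein_1D is a hypothetical
# helper, not part of the module). The function integrates |F_u^{-1}(q) - F_v^{-1}(q)|^p
# over the merged quantile grid, so for two unit point masses at 0 and 5:
def _demo_wasserstein_1D():
    u = torch.tensor([0.0])
    v = torch.tensor([5.0])
    d1  = wasserstein_distance_1D(u, v, p=1)                      # tensor(5.)
    d2  = wasserstein_distance_1D(u, v, p=2)                      # tensor(25.), raw W-2^p (inverse_power=False by default)
    d2r = wasserstein_distance_1D(u, v, p=2, inverse_power=True)  # tensor(5.)
    return d1, d2, d2r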
@torch.no_grad()
def rand_projections(dim: int, N: int=1000, device: str='cpu', dtype=torch.float32):
    """
    Define N random projection directions on the unit sphere S^{dim-1}
    
    Normally distributed components and final normalization
    guarantee (Haar) uniformity.
    """
    projections = torch.randn((N, dim), dtype=dtype, device=device)
    return projections / torch.norm(projections, p=2, dim=1, keepdim=True)
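
# Sanity sketch (illustration only; _demo_rand_projections is hypothetical):
# every generated direction has unit norm by construction.
def _demo_rand_projections():
    P = rand_projections(dim=3, N=4)
    return torch.norm(P, dim=1)  # ~ tensor([1., 1., 1., 1.])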
@torch.no_grad()
def qr_projections(
    dim: int,
    N: int = 1000,
    *,
    device: str = 'cpu',
    dtype = torch.float32,
    block_size: int | None = None,
    random_sign: bool = True,
):
    """
    Orthogonal blocks of random directions, with optional +-1 sign flip.
    """
    if block_size is None:
        block_size = dim
    
    n_blocks = math.ceil(N / block_size)
    rays = []
    n_done = 0 # number of direction vectors generated so far
    
    for _ in range(n_blocks):
        k = min(block_size, N - n_done) # last block may be smaller
        g = torch.randn((dim, k), device=device, dtype=dtype)
        q, _ = torch.linalg.qr(g, mode="reduced") # (dim, k)
        
        if random_sign: # independent sign per vector
            signs = (torch.randint(0, 2, (1, k), device=q.device) * 2 - 1).to(q.dtype)
            q = q * signs
        
        rays.append(q.T) # store k vectors as rows
        n_done += k
    
    return torch.cat(rays, dim=0)[:N] # (N, dim)
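
# Sketch of the orthogonality property (illustration only; _demo_qr_projections
# is hypothetical): rows within one QR block of size <= dim are orthonormal,
# and the random +-1 sign flips do not change the Gram matrix.
def _demo_qr_projections():
    Q = qr_projections(dim=4, N=4)   # a single orthogonal block, rows are directions
    return Q @ Q.T                   # ~ torch.eye(4)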
def sliced_wasserstein_distance(u_values: torch.Tensor, v_values: torch.Tensor,
                                u_weights: torch.Tensor=None, v_weights: torch.Tensor=None,
                                p: int=1, num_slices: int=1000, mode: str='SWD',
                                norm_weights: bool=True, vectorized: bool=True,
                                inverse_power: bool=True):
    """
    Sliced Wasserstein Distance over arbitrary dimensional samples
    
    References:
        https://arxiv.org/abs/1902.00434
        https://arxiv.org/abs/2211.08775
        https://arxiv.org/abs/2304.13586
    
    Notes:
        When using this as a loss function e.g. with neural nets,
        large minibatch sizes may be beneficial or needed.
    
    Args:
        u_values:      sample U vectors (n x dim)
        v_values:      sample V vectors (m x dim)
        u_weights:     sample U weights (n,) (if None, assumed to be unit)
        v_weights:     sample V weights (m,)
        p:             p-norm parameter (p = 1 is 'Earth Mover's', p = 2 is W-2, ...)
        num_slices:    number of random MC projections (slices) (the higher the better)
        mode:          'SWD'    (basic uniform MC random)
                       'SWD-QR' (orthogonal random blocks, lower estimator variance)
        norm_weights:  normalize per sample (U,V) weights to sum to one
        vectorized:    fully vectorized (may take more GPU/CPU memory, but 10x faster)
        inverse_power: apply final inverse power 1/p
    
    Returns:
        distance between the empirical sample distributions
    """
    # Generate random projection directions
    dim = int(u_values.shape[-1])
    
    if mode == 'SWD':
        directions = rand_projections(dim=dim, N=num_slices, device=u_values.device)
    elif mode == 'SWD-QR':
        directions = qr_projections(dim=dim, N=num_slices, device=u_values.device)
    else:
        raise Exception(f'Unknown mode parameter: {mode}')
    
    if vectorized:
        u_proj = torch.matmul(u_values, directions.T)
        v_proj = torch.matmul(v_values, directions.T)
        
        dist = wasserstein_distance_1D(u_values=u_proj, v_values=v_proj,
                                       u_weights=u_weights, v_weights=v_weights,
                                       p=p, norm_weights=norm_weights, inverse_power=False)
    else:
        dist = torch.zeros(num_slices, device=u_values.device, dtype=u_values.dtype)
        
        for i in range(num_slices):
            
            # Project the distributions on the random direction
            u_proj = torch.matmul(u_values, directions[i,:])
            v_proj = torch.matmul(v_values, directions[i,:])
            
            # Calculate the 1-dim Wasserstein on the direction
            dist[i] = wasserstein_distance_1D(u_values=u_proj, v_values=v_proj,
                                              u_weights=u_weights, v_weights=v_weights,
                                              p=p, norm_weights=norm_weights, inverse_power=False)
    
    # Average over the slices
    dist = torch.sum(dist) / num_slices
    
    if inverse_power:
        return dist ** (1.0 / p)
    else:
        return dist
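
# Usage sketch (illustration only; _demo_swd is hypothetical): SWD between two
# shifted Gaussian clouds. 'SWD-QR' uses the same API and typically gives a
# lower-variance estimate at a fixed number of slices.
def _demo_swd():
    torch.manual_seed(0)
    u = torch.randn(256, 3)
    v = torch.randn(256, 3) + 2.0
    d_mc = sliced_wasserstein_distance(u_values=u, v_values=v, p=2, num_slices=500, mode='SWD')
    d_qr = sliced_wasserstein_distance(u_values=u, v_values=v, p=2, num_slices=500, mode='SWD-QR')
    return d_mc, d_qr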
def test_1D(EPS=1e-3):
    """
    Test function (fixed reference values checked against SciPy and POT)
    """
    from scipy.stats import wasserstein_distance
    import numpy as np
    import ot
    
    # -----------------------------------------------
    ## p = 1 case checked against SciPy
    p = 1
    
    u = torch.tensor([0.0, 1.0, 3.0], requires_grad=True)
    v = torch.tensor([5.0, 6.0, 8.0], requires_grad=True)
    
    res       = wasserstein_distance_1D(u, v, p=p)
    res_scipy = wasserstein_distance(u.detach().numpy(), v.detach().numpy())
    
    print(f'1D case 1: p = 1 | {res.item()} {res_scipy}')
    assert res.item() == pytest.approx(res_scipy, abs=EPS)
    
    res       = wasserstein_distance_1D(torch.tensor([0., 1.]), torch.tensor([0., 1.]),
                                        torch.tensor([3., 1.]), torch.tensor([2., 2.]), p=p)
    res_scipy = wasserstein_distance(np.array([0., 1.]), np.array([0., 1.]),
                                     np.array([3., 1.]), np.array([2., 2.]))
    
    print(f'1D case 2: p = 1 | {res.item()} {res_scipy}')
    assert res.item() == pytest.approx(res_scipy, abs=EPS)
    
    res       = wasserstein_distance_1D(torch.tensor([3.4, 3.9, 7.5, 7.8]), torch.tensor([4.5, 1.4]),
                                        torch.tensor([1.4, 0.9, 3.1, 7.2]), torch.tensor([3.2, 3.5]))
    res_scipy = wasserstein_distance(np.array([3.4, 3.9, 7.5, 7.8]), np.array([4.5, 1.4]),
                                     np.array([1.4, 0.9, 3.1, 7.2]), np.array([3.2, 3.5]))
    
    print(f'1D case 3: p = 1 | res = {res.item()} res_scipy = {res_scipy}')
    assert res.item() == pytest.approx(res_scipy, abs=EPS)
    
    # -----------------------------------------------
    ## p = 1,2 against POT
    
    u_values = torch.tensor([1.0, 2.0], requires_grad=True)
    v_values = torch.tensor([3.0, 4.0], requires_grad=True)
    
    u_weights = torch.tensor([0.3, 1.0])
    v_weights = torch.tensor([1.0, 0.5])
    
    for p in [1, 2]:
        # the POT library does not normalize the weights, hence norm_weights=False
        res = wasserstein_distance_1D(u_values, v_values, u_weights, v_weights,
                                      p=p, norm_weights=False)
        pot = ot.wasserstein_1d(u_values, v_values, u_weights, v_weights, p=p)
        
        print(f'1D case 4: p = {p} | res = {res.item()} pot = {pot}')
        
        res_grad = torch.autograd.grad(res, u_values)
        pot_grad = torch.autograd.grad(pot, u_values)
        
        print(f'gradient dW/du: {res_grad}')
        print(f'gradient dW/du (POT): {pot_grad}')
        
        assert res.detach() == pytest.approx(pot.detach(), abs=EPS)
        assert torch.allclose(res.detach(), pot.detach(), atol=EPS)
    
    # -----------------------------------------------
    ## p = 2 case checked against pre-computed values
    p = 2
    
    res = wasserstein_distance_1D(torch.tensor([0, 1, 3]), torch.tensor([5, 6, 8]), p=p).item()
    assert res == pytest.approx(25.0, abs=EPS)
    
    res = wasserstein_distance_1D(torch.tensor([0, 1]), torch.tensor([0, 1]),
                                  torch.tensor([3, 1]), torch.tensor([2, 2]), p=p).item()
    assert res == pytest.approx(0.25, abs=EPS)
    
    res = wasserstein_distance_1D(torch.tensor([3.4, 3.9, 7.5, 7.8]), torch.tensor([4.5, 1.4]),
                                  torch.tensor([1.4, 0.9, 3.1, 7.2]), torch.tensor([3.2, 3.5]), p=p).item()
    assert res == pytest.approx(19.09, abs=EPS)
def test_swd():
    """
    Test sliced Wasserstein distance
    """
    import ot # POT for reference
    
    def set_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU
        #np.random.seed(seed) # Numpy module
        #random.seed(seed)    # Python random module
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    
    from time import time
    
    seed = 42
    
    # ---------------------------------------------------------
    # 2D test
    
    n = 20 # number of samples
    
    # Mean vectors
    mu_u = torch.tensor([0, 0], dtype=torch.float32)
    mu_v = torch.tensor([4, 0], dtype=torch.float32)
    
    # Covariance matrices
    cov_u = torch.tensor([[1.0, 0],    [0, 1.0]],    dtype=torch.float32)
    cov_v = torch.tensor([[1.0, -0.3], [-0.3, 1.0]], dtype=torch.float32)
    
    # Function to generate random vectors
    def generate_random_vectors(mu, cov, num_samples):
        distribution = torch.distributions.MultivariateNormal(mu, cov)
        samples = distribution.sample((num_samples,))
        return samples
    
    # Generate random vectors for both distributions
    set_seed(seed)
    u_values = generate_random_vectors(mu_u, cov_u, n).requires_grad_(True)
    v_values = generate_random_vectors(mu_v, cov_v, n).requires_grad_(True)
    
    print(u_values.shape)
    print(v_values.shape)
    
    # ------------------------------------------------------------
    # Versus POT
    
    num_slices = 200
    
    #exact = wasserstein2_distance_gaussian(mu1=mu_u, cov1=cov_u, mu2=mu_v, cov2=cov_v)
    
    for p in [1, 2]:
        pot = ot.sliced_wasserstein_distance(X_s=u_values, X_t=v_values,
                                             a=torch.ones(n)/n, b=torch.ones(n)/n,
                                             n_projections=num_slices, p=p)
        print(f'gradient dW/du (POT): {torch.autograd.grad(pot, u_values)}')
        
        for mode in ['SWD', 'SWD-QR']:
            for vectorized in [False, True]:
                set_seed(seed)
                res = sliced_wasserstein_distance(u_values=u_values, v_values=v_values,
                                                  p=p, num_slices=num_slices,
                                                  mode=mode, vectorized=vectorized)
                
                print(f'p = {p}: case 2 [{mode}] | res = {res.item()} pot = {pot} (vectorized = {vectorized})')
                print(f'gradient dW/du: {torch.autograd.grad(res, u_values)}')
                
                assert res.detach() == pytest.approx(pot.detach(), abs=0.3)
    
    # ---------------------------------------------------------
    # Fixed values 2D test
    
    u_values  = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    u_weights = torch.tensor([0.2, 0.5, 0.3]) # event weights
    
    v_values  = torch.tensor([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])
    v_weights = torch.tensor([0.3, 0.4, 0.3]) # event weights
    
    p = 2
    num_slices = 200
    
    for mode in ['SWD', 'SWD-QR']:
        set_seed(seed)
        res = sliced_wasserstein_distance(u_values=u_values, v_values=v_values,
                                          u_weights=u_weights, v_weights=v_weights,
                                          p=p, num_slices=num_slices, mode=mode).item()
        print(res)
        assert res == pytest.approx(0.683, abs=0.05)
    
    # -------------------------------------------------------
    # Test vectorized versus non-vectorized slicing implementation
    
    seed = 42
    set_seed(seed)
    
    u_values = torch.randn(100, 4)
    v_values = torch.randn(150, 4)
    
    u_weights = None
    v_weights = None
    
    for p in [1, 2]:
        for mode in ['SWD', 'SWD-QR']:
            
            # vectorized
            set_seed(seed)
            tic = time()
            d = sliced_wasserstein_distance(u_values=u_values, v_values=v_values,
                                            u_weights=u_weights, v_weights=v_weights,
                                            p=p, num_slices=num_slices, mode=mode,
                                            vectorized=True).item()
            toc = time() - tic
            
            # non-vectorized
            set_seed(seed)
            tic_alt = time()
            d_alt = sliced_wasserstein_distance(u_values=u_values, v_values=v_values,
                                                u_weights=u_weights, v_weights=v_weights,
                                                p=p, num_slices=num_slices, mode=mode,
                                                vectorized=False).item()
            toc_alt = time() - tic_alt
            
            print(f'p = {p} ({mode}) || D = {d} (vectorized, {toc:0.2e} sec) | D = {d_alt} (non-vectorized, {toc_alt:0.2e} sec)')
            assert d == pytest.approx(d_alt, abs=1e-4)
# ---------------------------------------------------------------------
# Advanced SWD test: value, variance, runtime, memory and gradients
# ---------------------------------------------------------------------
def test_swd_advanced():
    """
    For each (dim, test-case, #rays) the routine compares two samplers
    
    - SWD:    i.i.d. Gaussian directions
    - SWD-QR: QR-orthogonal blocks
    
    Metrics reported:
    
    - Monte-Carlo mean of SWD
    - Variance across 100 replicates
    - Relative mean-absolute error (RMAE) against a high-ray POT reference
    - Mean relative gradient error vs. POT gradient
    - Peak memory during a forward/backward pass
    - Average wall-time per replicate
    """
    import ot # POT for reference
    import psutil
    import os
    import time
    import gc
    
    def set_seed(seed: int):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    
    def sample_mog(n: int, dim: int, n_components: int = 3):
        """Mixture-of-Gaussians with random means/covariances."""
        # draw mixing weights that sum to 1
        weights = torch.distributions.Dirichlet(torch.ones(n_components)).sample()
        comps = torch.multinomial(weights, n, replacement=True) # (n,)
        
        means = torch.randn(n_components, dim) * 2.0
        covs = []
        for _ in range(n_components):
            A = torch.randn(dim, dim)
            covs.append((A @ A.T) / dim + 0.2 * torch.eye(dim))
        covs = torch.stack(covs) # (k, d, d)
        
        out = torch.empty(n, dim)
        for k in range(n_components):
            idx = (comps == k).nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                continue
            mvn = torch.distributions.MultivariateNormal(means[k], covs[k])
            out[idx] = mvn.sample((idx.numel(),))
        return out
    
    def sample_ring(n: int, dim: int, radius=4.0, noise=0.2):
        """d-dimensional ring/donut distribution."""
        theta = torch.rand(n) * 2 * math.pi
        base = torch.stack((radius * torch.cos(theta), radius * torch.sin(theta)), dim=1)
        if dim > 2:
            extra = torch.randn(n, dim - 2) * radius * 0.1
            base = torch.cat((base, extra), dim=1)
        return base + noise * torch.randn_like(base)
    
    seed      = 123
    ray_grid  = (1, 10, 100, 1000)
    reps      = 100
    n_samples = 2048
    nrays_ref = int(1e5) # POT reference rays
    rtol_grad = 0.30
    atol_grad = 1e-4
    device    = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # -----------------------------------------------------------------
    # helper to draw two independent clouds with a deterministic seed
    # -----------------------------------------------------------------
    def _draw_pair(f1, f2, *, base_seed):
        set_seed(base_seed)
        X = f1(n_samples, dim).to(device)
        set_seed(base_seed + 999)
        Y = f2(n_samples, dim).to(device)
        return X, Y
    
    test_cases = [
        ('mog_vs_ring', lambda bs: _draw_pair(sample_mog, sample_ring, base_seed=bs)),
        ('mog_vs_mog',  lambda bs: _draw_pair(sample_mog, sample_mog,  base_seed=bs)),
    ]
    
    for dim in [2, 4, 16, 64]:
        print('\n' + '=' * 70)
        print(f'** Dimension: {dim} **')
        
        for case_name, make_xy in test_cases:
            print('\n' + '-' * 70)
            print(f'Test case: {case_name}')
            
            # ----------------------- data --------------------------------
            X_raw, Y_raw = make_xy(seed)
            
            # POT reference value (no grad needed)
            with torch.no_grad():
                set_seed(seed)
                exact_val = ot.sliced_wasserstein_distance(
                    X_s=X_raw, X_t=Y_raw, a=None, b=None,
                    n_projections=nrays_ref, p=2
                )
            print(f'  POT reference value ({nrays_ref:0.0E} rays): {exact_val:.4f}')
            
            # POT reference gradient (needs requires_grad)
            set_seed(seed)
            X_ref = X_raw.clone().detach().requires_grad_(True)
            Y_ref = Y_raw.clone().detach()
            pot_val = ot.sliced_wasserstein_distance(
                X_s=X_ref, X_t=Y_ref, n_projections=nrays_ref, p=2
            )
            pot_val.backward()
            grad_ref = X_ref.grad.detach()
            ref_norm = grad_ref.norm().clamp_min(1e-12)
            
            # ---------------------- table header -------------------------
            header = (f"{'rays':>6} | {'sampler':>12} | {'mean':>10} | "
                      f"{'var':>10} | {'RMAE':>10} | {'grad err':>10} | "
                      f"{'peak MB':>8} | {'ms/rep':>8}")
            print('\n' + header)
            print('-' * len(header))
            
            # Loop over #slices and samplers
            for num_slices in ray_grid:
                results = {}
                
                for sampler in ('SWD', 'SWD-QR'):
                    
                    # ---------------------------------------------------
                    # (1) PEAK-MEM measurement, run a few fwd/bwd passes
                    # ---------------------------------------------------
                    if device.startswith('cuda'):
                        torch.cuda.empty_cache()
                        torch.cuda.reset_peak_memory_stats(device)
                        start_mem = torch.cuda.memory_allocated(device)
                    else:
                        process = psutil.Process(os.getpid())
                        gc.collect()
                        start_mem = process.memory_info().rss
                    
                    for _ in range(5):
                        set_seed(seed)
                        X_ = X_raw.clone().detach().requires_grad_(True)
                        Y_ = Y_raw.clone().detach()
                        loss = sliced_wasserstein_distance(
                            X_, Y_, p=2, num_slices=num_slices,
                            mode=sampler, vectorized=True
                        )
                        loss.backward()
                    
                    if device.startswith('cuda'):
                        peak_mb = (torch.cuda.max_memory_allocated(device) - start_mem) / 1024**2
                    else:
                        gc.collect()
                        peak_mb = (psutil.Process(os.getpid()).memory_info().rss - start_mem) / 1024**2
                    
                    # ---------------------------------------------------
                    # (2) Gradient error (single pass, same seed)
                    # ---------------------------------------------------
                    set_seed(seed)
                    X_grad = X_raw.clone().detach().requires_grad_(True)
                    Y_grad = Y_raw.clone().detach()
                    g_loss = sliced_wasserstein_distance(
                        X_grad, Y_grad, p=2, num_slices=num_slices,
                        mode=sampler, vectorized=True
                    )
                    g_loss.backward()
                    grad_err = (X_grad.grad - grad_ref).norm() / ref_norm
                    
                    # ---------------------------------------------------
                    # (3) Monte-Carlo replicates for mean / var / RMAE
                    # ---------------------------------------------------
                    vals = torch.empty(reps, device=device)
                    tic = time.perf_counter()
                    with torch.no_grad():
                        for r in range(reps):
                            set_seed(seed + r)
                            vals[r] = sliced_wasserstein_distance(
                                X_raw, Y_raw, p=2, num_slices=num_slices,
                                mode=sampler, vectorized=True,
                            )
                    elapsed = (time.perf_counter() - tic) * 1e3 / reps # ms/rep
                    
                    results[sampler] = dict(
                        mean     = vals.mean().cpu(),
                        var      = vals.var(unbiased=True).cpu(),
                        rmae     = ((vals - exact_val).abs().mean() / exact_val).cpu(),
                        grad_err = grad_err.cpu(),
                        peak_mb  = peak_mb,
                        elapsed  = elapsed,
                    )
                    
                    print(f"{num_slices:6d} | {sampler:12s} | "
                          f"{results[sampler]['mean']:10.2e} | "
                          f"{results[sampler]['var']:10.2e} | "
                          f"{results[sampler]['rmae']:10.2e} | "
                          f"{results[sampler]['grad_err']:10.2e} | "
                          f"{results[sampler]['peak_mb']:8.2f} | "
                          f"{results[sampler]['elapsed']:8.2f}")