Source code for healpix_resample.knn

"""
knn.py

GPU-friendly sparse HEALPix regridding from unstructured lon/lat samples
to a subset of HEALPix pixels at a target resolution (nside = 2**level).

Core ideas:
- Use HEALPix local neighbourhoods (healpix_geo.kth_neighbourhood) to avoid N×npix distance matrices.
- Build sparse operators M (samples -> grid) and MT (grid -> samples) with Gaussian weights.
- Solve a damped least-squares problem with Conjugate Gradient (CG) on normal equations.

This module is designed for large N and batched values (B,N) on CUDA.
"""

import math
from typing import Generic, Tuple, Optional, Union

import healpix_geo
import numpy as np
import torch

from healpix_resample.base import ResampleResults, T_Array


def _lonlat_to_xyz(lon_rad: torch.Tensor, lat_rad: torch.Tensor) -> torch.Tensor:
    clat = torch.cos(lat_rad)
    return torch.stack([clat * torch.cos(lon_rad),
                        clat * torch.sin(lon_rad),
                        torch.sin(lat_rad)], dim=-1)  # (...,3)

def _sigma_level_m(level: int, radius: float = 6371000.0) -> float:
    # sigma = sqrt(4*pi / (12*4**level)) * R
    return math.sqrt(4.0 * math.pi / (12.0 * (4.0 ** level))) * radius
    
[docs] @torch.no_grad() def healpix_weighted_nearest( longitude1: torch.Tensor, # (N,) degrés latitude1: torch.Tensor, # (N,) degrés level: int, Npt: int, *, nest: bool = True, threshold: float = 0.1, radius: float = 6371000.0, ellipsoid: str = "WGS84", sigma: float | None = None, # sous-ensemble de pixels de sortie autorisés (en ids healpix au même "level") out_cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, # voisinage utilisé pour estimer les poids (construction cell_ids) ring_weight: Optional[int] = None, # voisinage utilisé pour trouver Npt voisins parmi les pixels gardés (peut être augmenté automatiquement) ring_search_init: Optional[int] = None, ring_search_max: int = 20, num_threads: int = 0, device_for_dist: Optional[torch.device] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Retourne: cell_ids: (K,) pixels HEALPix (au level) retenus par le seuil de poids (et éventuellement intersectés avec out_cell_ids) idx_k : (N, Npt) indices dans cell_ids (0..K-1), -1 si insuffisant dist_k : (N, Npt) distances (m) vers centres des pixels correspondants, inf si insuffisant Notes: - On utilise la distance géodésique via angle = acos(dot(xyz)) ; dist = R * angle. - healpix_geo est utilisé pour lonlat_to_healpix, kth_neighbourhood, healpix_to_lonlat. - La précision est très correcte pour du “pixel center matching”. """ assert longitude1.ndim == latitude1.ndim == 1 N = int(longitude1.numel()) assert Npt >= 1 # --- choix rings par défaut # ring minimal pour avoir >= Npt candidats dans un carré (2r+1)^2, + marge r_min = int(math.ceil((math.sqrt(Npt) - 1.0) / 2.0)) if ring_weight is None: ring_weight = max(2, r_min + 2) # plus large: mieux pour estimer les poids globaux if ring_search_init is None: ring_search_init = max(1, r_min + 1) # --- distances: on peut les faire sur GPU si dispo dev = device_for_dist if device_for_dist is not None else longitude1.device # --- CPU numpy pour healpix_geo lon_np = longitude1.detach().cpu().numpy().astype(np.float64) lat_np = latitude1.detach().cpu().numpy().astype(np.float64) if nest: ipix1 = healpix_geo.nested.lonlat_to_healpix(lon_np, lat_np, level, num_threads=num_threads,ellipsoid=ellipsoid) ipix_u, inv = np.unique(ipix1.astype(np.uint64), return_inverse=True) if Npt==1: return torch.from_numpy(ipix_u.astype(np.int64)).to(dev), torch.from_numpy(inv.astype(np.int64)).to(dev),0 neigh_u_w = healpix_geo.nested.kth_neighbourhood(ipix_u, level, ring_weight, num_threads=num_threads) else: ipix1 = healpix_geo.ring.lonlat_to_healpix(lon_np, lat_np, level, num_threads=num_threads,ellipsoid=ellipsoid) ipix_u, inv = np.unique(ipix1.astype(np.uint64), return_inverse=True) if Npt==1: return torch.from_numpy(ipix_u.astype(np.int64)).to(dev), torch.from_numpy(inv.astype(np.int64)).to(dev),0 neigh_u_w = healpix_geo.ring.kth_neighbourhood(ipix_u, level, ring_weight, num_threads=num_threads) # healpix_geo.kth_neighbourhood may return -1 for "invalid" neighbours. # We replace invalid entries by a valid neighbour (the last valid in the row; fallback to the center pixel). # This keeps arrays dense (no masking) and avoids uint64 overflow issues. # Duplicates are later handled naturally by weight accumulation + normalization. neigh_u_w = neigh_u_w.astype(np.int64, copy=False) valid = neigh_u_w >= 0 # position du dernier valide par ligne last_valid_pos = valid[:, ::-1].argmax(axis=1) last_valid_pos = (neigh_u_w.shape[1] - 1) - last_valid_pos last_valid = neigh_u_w[np.arange(neigh_u_w.shape[0]), last_valid_pos] # fallback si toute la ligne est invalide -> centre all_invalid = ~valid.any(axis=1) last_valid[all_invalid] = ipix_u[all_invalid].astype(np.int64, copy=False) # remplace les -1 (broadcast explicite) mask = neigh_u_w < 0 neigh_u_w[mask] = np.broadcast_to(last_valid[:, None], neigh_u_w.shape)[mask] # neigh_w : (N, Kw) neigh_w = neigh_u_w[inv] Kw = neigh_w.shape[1] # --- centres lon/lat des pixels du voisinage (Kw*N potentiellement grand) => on unique neigh_w_flat = neigh_w.reshape(-1).astype(np.uint64) neigh_w_uniq, back = np.unique(neigh_w_flat, return_inverse=True) if nest: lon_c_deg, lat_c_deg = healpix_geo.nested.healpix_to_lonlat(neigh_w_uniq, level,ellipsoid=ellipsoid) else: lon_c_deg, lat_c_deg = healpix_geo.ring.healpix_to_lonlat(neigh_w_uniq, level,ellipsoid=ellipsoid) lon1 = torch.deg2rad(longitude1.to(dev)) lat1 = torch.deg2rad(latitude1.to(dev)) xyz1 = _lonlat_to_xyz(lon1, lat1) # (N,3) lon_c = torch.deg2rad(torch.from_numpy(np.asarray(lon_c_deg)).to(dev)) lat_c = torch.deg2rad(torch.from_numpy(np.asarray(lat_c_deg)).to(dev)) xyz_c = _lonlat_to_xyz(lon_c, lat_c) # (Kuniq,3) # Remap neighbours vers indices uniques (Kuniq) back_t = torch.from_numpy(back.astype(np.int64)).to(dev) # (N*Kw,) back_t = back_t.view(N, Kw) # (N,Kw) xyz_c_n = xyz_c[back_t] # (N,Kw,3) dot = (xyz_c_n * xyz1[:, None, :]).sum(dim=-1) # (N,Kw) dot = torch.clamp(dot, -1.0, 1.0) ang = torch.acos(dot) dist = radius * ang # (N,Kw) # --- poids et somme par pixel # w = exp(-2*d^2/sigma^2) w = torch.exp((-2.0) * (dist * dist) / (sigma * sigma)) # (N,Kw) # scatter_add sur pixels uniques # On accumule w sur chaque pixel (Kuniq) sums = torch.zeros((xyz_c.shape[0],), device=dev, dtype=w.dtype) # Thresholding stage: # We compute Gaussian weights from each sample to its neighbourhood pixel centers, # then accumulate a global weight sum per HEALPix cell. Only cells whose total # influence exceeds 'threshold' are kept in cell_ids_keep (size K). sums.scatter_add_(0, back_t.reshape(-1), w.reshape(-1)) keep = sums >= threshold if not torch.any(keep): # aucun pixel ne passe le seuil -> on retourne vide + -1/inf cell_ids = torch.empty((0,), device=dev, dtype=torch.long) idx_k = torch.full((N, Npt), -1, device=dev, dtype=torch.long) dist_k = torch.full((N, Npt), float("inf"), device=dev, dtype=lon1.dtype) return cell_ids, idx_k, dist_k # cell_ids retenus (en id healpix) cell_ids_np = neigh_w_uniq # np.uint64, taille Kuniq keep_idx = torch.nonzero(keep, as_tuple=False).squeeze(1) # indices dans [0..Kuniq-1] cell_ids_keep = torch.from_numpy(cell_ids_np.astype(np.int64)).to(dev)[keep_idx] # (K,) # xyz des pixels retenus (K,3) xyz_keep = xyz_c[keep_idx] # --- optionnel: restreindre explicitement les pixels de sortie # out_cell_ids peut être une liste/array/torch tensor d'ids HEALPix (au même level). if out_cell_ids is not None: out_t = out_cell_ids if isinstance(out_cell_ids, torch.Tensor) else torch.as_tensor(out_cell_ids) out_t = out_t.to(device=dev, dtype=torch.long).reshape(-1) if out_t.numel() == 0: # sous-ensemble vide demandé -> rien à regriller cell_ids = torch.empty((0,), device=dev, dtype=torch.long) idx_k = torch.full((N, Npt), -1, device=dev, dtype=torch.long) dist_k = torch.full((N, Npt), float("inf"), device=dev, dtype=lon1.dtype) return cell_ids, idx_k, dist_k out_sorted = torch.unique(out_t) out_sorted, _ = torch.sort(out_sorted) # test d'appartenance robuste via searchsorted (évite dépendre de torch.isin) pos_out = torch.searchsorted(out_sorted, cell_ids_keep) pos_out = torch.clamp(pos_out, 0, out_sorted.numel() - 1) in_mask = (out_sorted[pos_out] == cell_ids_keep) if not torch.any(in_mask): cell_ids = torch.empty((0,), device=dev, dtype=torch.long) idx_k = torch.full((N, Npt), -1, device=dev, dtype=torch.long) dist_k = torch.full((N, Npt), float("inf"), device=dev, dtype=lon1.dtype) return cell_ids, idx_k, dist_k keep_idx = keep_idx[in_mask] cell_ids_keep = cell_ids_keep[in_mask] xyz_keep = xyz_c[keep_idx] # --- pour chaque point: trouver Npt pixels retenus les plus proches # Stratégie: voisinage healpix autour du pixel contenant le point, ring qui s’agrandit # On mappe les pixels du voisinage -> indices dans cell_ids_keep via tri+searchsorted. cell_sorted, order = torch.sort(cell_ids_keep) # (K,), (K,) K = int(cell_ids_keep.numel()) idx_out = torch.full((N, Npt), -1, device=dev, dtype=torch.long) dist_out = torch.full((N, Npt), float("inf"), device=dev, dtype=lon1.dtype) r = ring_search_init done = torch.zeros((N,), device=dev, dtype=torch.bool) # On travaille côté CPU pour kth_neighbourhood, mais on ne fait ça que pour les rings nécessaires. # On recalcule sur pixels uniques pour limiter les appels. while r <= ring_search_max and not bool(torch.all(done).item()): if nest: neigh_u = healpix_geo.nested.kth_neighbourhood(ipix_u, level, r, num_threads=num_threads) else: neigh_u = healpix_geo.ring.kth_neighbourhood(ipix_u, level, r, num_threads=num_threads) neigh = neigh_u[inv] # (N, Ks) Ks = neigh.shape[1] neigh_t = torch.from_numpy(neigh.astype(np.int64)).to(dev) # map neigh_t -> index dans cell_ids_keep pos = torch.searchsorted(cell_sorted, neigh_t) pos = torch.clamp(pos, 0, K - 1) hit = (cell_sorted[pos] == neigh_t) cand_idx_keep = torch.where(hit, order[pos], torch.full_like(pos, -1)) # indices dans cell_ids_keep # distances pour candidats valides safe = torch.clamp(cand_idx_keep, 0, K - 1) xyz2 = xyz_keep[safe] # (N,Ks,3) dot2 = (xyz2 * xyz1[:, None, :]).sum(dim=-1) dot2 = torch.clamp(dot2, -1.0, 1.0) dist2 = radius * torch.acos(dot2) dist2 = torch.where(cand_idx_keep >= 0, dist2, torch.full_like(dist2, float("inf"))) # topk parmi ces candidats k = min(Npt, Ks) d_k, p_k = torch.topk(dist2, k=k, dim=1, largest=False, sorted=True) i_k = torch.gather(cand_idx_keep, 1, p_k) # pour les points pas encore "done", on accepte si on a au moins Npt valides (i.e. d_k[:, Npt-1] < inf) valid_enough = (d_k[:, -1] < float("inf")) update = (~done) & valid_enough if torch.any(update): idx_out[update] = i_k[update] dist_out[update] = d_k[update] done[update] = True r += 1 # Si certains points n’ont jamais trouvé Npt pixels gardés, on laisse -1/inf (ou on pourrait rendre moins que Npt) return cell_ids_keep, idx_out, dist_out
[docs] class KNeighborsResampler(Generic[T_Array]): """GPU-friendly sparse HEALPix regridding via local Gaussian weights + CG deconvolution. This class builds two sparse operators from unstructured lon/lat samples to a subset of HEALPix pixels at a target resolution (nside = 2**level). Notation (matching your notebook): - N: number of samples (lon/lat) - K: number of kept HEALPix cells (cell_ids) - M: operator of shape (N, K) (named ``M`` here) - MT: operator of shape (K, N) (named ``MT`` here) The solver estimates ``hval`` (B,K) such that ``M @ hval.T`` matches ``val`` (B,N), by solving a damped normal equation around a reference field ``x_ref = val @ M``. """
[docs] def __init__( self, lon_deg: T_Array, lat_deg: T_Array, Npt: int, level: int, *, nest: bool = True, radius: float = 6371000.0, ellipsoid: str = "WGS84", dtype: torch.dtype = torch.float64, device: torch.device | str | None = None, ring_weight: Optional[int] = None, ring_search_init: Optional[int] = None, ring_search_max: int = 2, num_threads: int = 0, threshold: float = 0.1, sigma_m: float | None = None, verbose: bool = True, out_cell_ids: T_Array | None = None, group_by: bool = False, ) -> None: """Pre-compute sparse operators. Args: lon_deg, lat_deg: unstructured sample coordinates in degrees, shape (N,) Npt: number of nearest HEALPix cells used per sample level: HEALPix level, nside = 2**level sigma_m: Gaussian length scale (meters). If None, uses the HEALPix pixel scale sigma = sqrt(4*pi/(12*4**level))*R. threshold: keep only HEALPix cells whose global weight sum >= threshold nest: HEALPix indexing scheme dtype/device: torch dtype/device for all matrices and computations """ self.level = int(level) self.nside = 2 ** int(level) self.group_by = bool(group_by) if not self.group_by: self.Npt = int(Npt) self.nest = bool(nest) self.radius = float(radius) self.ellipsoid = str(ellipsoid) self.dtype = dtype if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" else: if device.startswith("cuda") and not torch.cuda.is_available(): raise RuntimeError("CUDA requested but not available.") self.device = torch.device(device) self.threshold = float(threshold) self.out_cell_ids = out_cell_ids # --- sigma in meters (controls the Gaussian weights used for thresholding) sigma = float(_sigma_level_m(level, radius=radius) if sigma_m is None else sigma_m) self.sigma_m = sigma self.verbose = verbose # --- move lon/lat to torch on target device (but healpix_geo needs CPU numpy internally) lon_t = lon_deg if isinstance(lon_deg, torch.Tensor) else torch.as_tensor(lon_deg) lat_t = lat_deg if isinstance(lat_deg, torch.Tensor) else torch.as_tensor(lat_deg) if self.group_by: if self.nest: cell_ids = healpix_geo.nested.lonlat_to_healpix(lon_t,lat_t, self.level, ellipsoid=self.ellipsoid) else: cell_ids = healpix_geo.ring.lonlat_to_healpix(lon_t,lat_t, self.level, ellipsoid=self.ellipsoid) cell_ids,hi = torch.unique(torch.tensor(cell_ids,dtype=torch.long,device=self.device), return_inverse=True) self.cell_ids = cell_ids self.hi = hi else: lon_t = lon_t.to(self.device) lat_t = lat_t.to(self.device) if self.out_cell_ids is not None: self.xyz_samples = _lonlat_to_xyz(torch.deg2rad(lon_t),torch.deg2rad(lat_t)) # (N,3) # --- get kept healpix cells + per-sample nearest indices + distances cell_ids, hi, d = healpix_weighted_nearest( lon_t, lat_t, level=self.level, Npt=self.Npt, nest=self.nest, threshold=self.threshold, radius=self.radius, ellipsoid=self.ellipsoid, sigma=self.sigma_m, out_cell_ids=self.out_cell_ids, ring_weight=ring_weight, ring_search_init=ring_search_init, ring_search_max=ring_search_max, num_threads=num_threads, device_for_dist=self.device, ) if cell_ids.numel() == 0: raise RuntimeError( "No HEALPix cell passed the threshold. " "Lower 'threshold' or increase neighbourhood rings." ) # Store geometry outputs self.cell_ids = cell_ids.to(torch.long).to(self.device) # (K,) self.hi = hi.to(torch.long).to(self.device) # (N,Npt) indices into cell_ids if Npt>1: self.d_m = d.to(self.dtype).to(self.device) # (N,Npt) meters self.K = int(self.cell_ids.numel()) self.N = int(lon_t.numel()) if self.out_cell_ids is not None: # --- geometry buffers for optional fallbacks (e.g. when out_cell_ids forces empty columns) # unit vectors for output HEALPix cell centers (K,3) cell_np = self.cell_ids.detach().cpu().numpy().astype(np.uint64) if self.nest: lon_c_deg, lat_c_deg = healpix_geo.nested.healpix_to_lonlat(cell_np, self.level, ellipsoid=self.ellipsoid) else: lon_c_deg, lat_c_deg = healpix_geo.ring.healpix_to_lonlat(cell_np, self.level, ellipsoid=self.ellipsoid) lon_c = torch.deg2rad(torch.as_tensor(lon_c_deg, device=self.device, dtype=self.xyz_samples.dtype)) lat_c = torch.deg2rad(torch.as_tensor(lat_c_deg, device=self.device, dtype=self.xyz_samples.dtype)) self.xyz_cells = _lonlat_to_xyz(lon_c, lat_c) # (K,3) if not self.group_by: self.comp_matrix()
def comp_matrix(self): # --- weights per sample->cell link # w = exp(-2*d^2/sigma^2) w = torch.exp((-2.0) * (self.d_m * self.d_m) / (self.sigma_m * self.sigma_m)) # Build (N,K) operator M and (K,N) operator MT. # We avoid numpy bincount; use torch.bincount on GPU. # idx: (N,Npt) row indices 0..N-1 idx = torch.arange(self.N, device=self.device, dtype=torch.long)[:, None].expand(self.N, self.Npt) # -------- M : (N,K) (normalized per column / per healpix cell) # norm_col[k] = sum_{i links to k} w[i,k] flat_hi = self.hi.reshape(-1) flat_w = w.reshape(-1) valid = flat_hi >= 0 flat_hi_v = flat_hi[valid] flat_w_v = flat_w[valid] norm_col = torch.bincount(flat_hi_v, weights=flat_w_v, minlength=self.K).to(self.dtype) # weight divided by column sum wM = flat_w_v / norm_col[flat_hi_v] rowsM = idx.reshape(-1)[valid] colsM = flat_hi_v indicesM = torch.stack([rowsM, colsM], dim=0) M_coo = torch.sparse_coo_tensor( indicesM, wM.to(self.dtype), size=(self.N, self.K), device=self.device, dtype=self.dtype, ).coalesce() # -------- MT : (K,N) (normalized per row / per input sample) # norm_row[i] = sum_{k links from i} w[i,k] flat_idx = idx.reshape(-1) flat_idx_v = flat_idx[valid] norm_row = torch.bincount(flat_idx_v, weights=flat_w_v, minlength=self.N).to(self.dtype) wMT = flat_w_v / norm_row[flat_idx_v] indicesMT = torch.stack([colsM, rowsM], dim=0) # (hi, idx) MT_coo = torch.sparse_coo_tensor( indicesMT, wMT.to(self.dtype), size=(self.K, self.N), device=self.device, dtype=self.dtype, ).coalesce() # Convert to CSR for faster spMM (recommended on GPU) self.M = M_coo.to_sparse_csr() self.MT = MT_coo.to_sparse_csr() @torch.no_grad() def invert(self, hval: T_Array) -> T_Array: """Project HEALPix field back to the sample locations. Args: hval: (B,K) or (K,) on same device Returns: val_hat: (B,N) or (N,) """ y = hval if isinstance(hval, torch.Tensor) else torch.as_tensor(hval) y = y.to(self.device, dtype=self.dtype) if hval.ndim == 1: res = (y[None, :] @ self.MT)[0] else: res = y @ self.MT if not isinstance(hval, torch.Tensor): res=res.cpu().numpy() return res @torch.no_grad() def resample( self, val: T_Array, *, lam: float = 0.0, max_iter: int = 100, tol: float = 1e-8, x0: Optional[torch.Tensor] = None, return_info: bool = False, ) -> ResampleResults[T_Array]: """Estimate the HEALPix field from unstructured samples. Args: val: (B,N) or (N,) values at lon/lat sample points lam: Tikhonov regularization strength (damping) used in CG max_iter, tol: CG parameters x0: optional initial guess for the *delta* around x_ref, shape (B,K) return_info: whether to return CG diagnostics Returns: hval: (B,K) or (K,) (optional) info: CG information dict """ y = val if isinstance(val, torch.Tensor) else torch.as_tensor(val) y = y.to(self.device, dtype=self.dtype) clean_shape=False if y.ndim == 1: clean_shape=True y = y[None, :] # reference field (B,K) hval = y @ self.M cell_ids = self.cell_ids if not isinstance(val, torch.Tensor): hval= hval.cpu().numpy() cell_ids = cell_ids.cpu().numpy() if clean_shape: hval = hval[0] return ResampleResults(cell_data=hval, cell_ids=cell_ids) def get_cell_ids(self): return self.cell_ids.cpu().numpy()