# Source code for healpix_resample.nearest

"""
nearest.py

Nearest-neighbour HEALPix resampler built on top of KNeighborsResampler.

Strategy
--------
1. Run the standard KNN (Npt neighbours per source sample) — this gives
   ``hi (N, Npt)`` (cell indices per source) and ``d_m (N, Npt)`` (distances).
2. Compute Gaussian weights  w[n,j] = exp(-2 d²/σ²).
3. For each HEALPix cell k, keep only the source n that maximises w  →  hi_k (K,).
   This is a single ``scatter_reduce('amax')`` pass — O(N·Npt), no sparse matrix.
4. ``resample`` : hval[:, k] = val[:, hi_k[k]]            (direct index)
   ``invert``   : val[:, n]  = mean of hval[:, k] ∀k s.t. hi_k[k]==n  (scatter-mean)

When ``out_cell_ids`` is provided
----------------------------------
``healpix_weighted_nearest`` already intersects the KNN result with ``out_cell_ids``,
but cells beyond ``ring_search_max`` rings may be absent.  After the KNN step,
missing cells are filled via a memory-bounded chunked dot-product fallback.
"""
from __future__ import annotations

from typing import Optional

import healpix_geo
import numpy as np
import torch

from healpix_resample.base import ResampleResults, T_Array
from healpix_resample.knn import KNeighborsResampler, _lonlat_to_xyz


# ─────────────────────────────────────────────────────────────────────────────
# Helper: scatter-mean
# ─────────────────────────────────────────────────────────────────────────────

def _scatter_mean(src: torch.Tensor, idx: torch.Tensor, T: int) -> torch.Tensor:
    """out[:, t] = mean of src[:, s] where idx[s] == t.  Shape (B, S) → (B, T)."""
    B   = src.shape[0]
    out = torch.zeros(B, T, device=src.device, dtype=src.dtype)
    out.scatter_add_(1, idx.unsqueeze(0).expand(B, -1), src)
    count = torch.bincount(idx, minlength=T).to(src.dtype).clamp(min=1)
    out /= count.unsqueeze(0)
    return out


# ─────────────────────────────────────────────────────────────────────────────
# NearestResampler
# ─────────────────────────────────────────────────────────────────────────────

class NearestResampler(KNeighborsResampler):
    """Nearest-neighbour HEALPix resampler — no sparse matrices.

    Uses ``Npt`` KNN neighbours (default 9) to robustly find the nearest
    source for every HEALPix cell, even when the output grid is finer than
    the input.

    Parameters
    ----------
    Npt : int
        Number of HEALPix neighbours per source sample used by the KNN.
        Larger values cover more cells at the cost of more memory during
        construction. Default 9 is a good trade-off for most use cases.

    All other parameters are forwarded to ``KNeighborsResampler``.
    """

    def __init__(self, *args, Npt: int = 9, **kwargs):
        # Ensure ring_search_max >= ring_search_init(Npt) so the KNN search
        # loop in healpix_weighted_nearest actually executes.
        #
        # healpix_weighted_nearest computes:
        #     r_min            = ceil((sqrt(Npt) - 1) / 2)
        #     ring_search_init = max(1, r_min + 1)
        #
        # KNeighborsResampler default is ring_search_max=2, which is too small
        # for Npt >= 16 (needs ring_search_init=3). We auto-correct here only
        # when the caller has not supplied ring_search_max explicitly.
        if "ring_search_max" not in kwargs:
            import math as _math
            r_min = int(_math.ceil((_math.sqrt(Npt) - 1.0) / 2.0))
            ring_search_init_needed = max(1, r_min + 1)
            # +2 margin so the loop has room to grow and find Npt candidates
            kwargs["ring_search_max"] = ring_search_init_needed + 2
        # super().__init__ calls self.comp_matrix() at the end — our override
        # will be called there, building hi_k instead of sparse M/MT.
        super().__init__(*args, Npt=Npt, **kwargs)
        # If out_cell_ids was requested, fill any cells the KNN rings missed.
        if self.out_cell_ids is not None:
            self._fill_missing_out_cells()
        # Geometry buffers are only needed during construction — free them.
        # NOTE: _fill_missing_out_cells reads self.xyz_samples, so it must run
        # before this cleanup (it does, just above).
        for attr in ("xyz_samples", "xyz_cells"):
            if hasattr(self, attr):
                delattr(self, attr)
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    # ── comp_matrix: scatter_reduce instead of sparse M / MT ─────────────────
    def comp_matrix(self) -> None:
        """Build hi_k (K,) — nearest source index per HEALPix cell.

        Algorithm
        ---------
        1. Compute Gaussian weights w[n,j] from the KNN distances d_m[n,j].
        2. Flatten (N, Npt) → (N*Npt,) pairs (source_n, cell_k, weight).
        3. scatter_reduce('amax') over cell_k → max weight per cell, O(N·Npt).
        4. Keep the winner per cell (deduplicate ties with stable sort).

        Result: ``self.hi_k`` (K,) long — index into [0..N-1] for each cell.
        """
        # w[n,j] = exp(-2 d²/σ²) ─ same formula as KNeighborsResampler.
        # Since exp is monotonic in -d², the per-cell argmax of w is the
        # per-cell nearest source.
        w = torch.exp(
            (-2.0) * (self.d_m * self.d_m) / (self.sigma_m * self.sigma_m)
        )  # (N, Npt)
        flat_hi = self.hi.reshape(-1)  # (N*Npt,) cell index per pair
        flat_w = w.reshape(-1)         # (N*Npt,) weight per pair
        # Source index n repeated Npt times → aligned with flat_hi / flat_w.
        flat_n = (
            torch.arange(self.N, device=self.device, dtype=torch.long)
            .unsqueeze(1).expand(self.N, self.Npt).reshape(-1)
        )  # (N*Npt,)
        del w

        # Discard invalid links (hi == -1 means no cell was found).
        valid = flat_hi >= 0
        flat_hi_v = flat_hi[valid]
        flat_w_v = flat_w[valid]
        flat_n_v = flat_n[valid]
        del flat_hi, flat_w, flat_n

        # ── Max weight per cell ──────────────────────────────────────────────
        # Initialised to -inf so any real weight wins; include_self=True is
        # then harmless.
        max_w = torch.full(
            (self.K,), float("-inf"), device=self.device, dtype=flat_w_v.dtype
        )
        max_w.scatter_reduce_(
            0, flat_hi_v, flat_w_v, reduce="amax", include_self=True
        )

        # ── Keep winner pairs (weight == max for their cell) ─────────────────
        is_winner = flat_w_v >= max_w[flat_hi_v]
        del max_w, flat_w_v
        win_cell = flat_hi_v[is_winner]
        win_src = flat_n_v[is_winner]
        del flat_hi_v, flat_n_v, is_winner

        # ── Deduplicate ties: stable sort by cell, first occurrence wins ─────
        ord_w = torch.argsort(win_cell, stable=True)
        win_cell = win_cell[ord_w]
        win_src = win_src[ord_w]
        del ord_w
        is_first = torch.ones(len(win_cell), dtype=torch.bool, device=self.device)
        is_first[1:] = win_cell[1:] != win_cell[:-1]

        # ── Write result: hi_k[k] = nearest source index for cell k ──────────
        # -1 means no source reached the cell (should not happen for cells
        # produced by the KNN step, but kept as a defensive sentinel; resample
        # and invert clamp it away).
        hi_k = torch.full((self.K,), -1, dtype=torch.long, device=self.device)
        hi_k[win_cell[is_first]] = win_src[is_first]
        del win_cell, win_src, is_first
        self.hi_k = hi_k  # (K,)

    # ── Fill cells from out_cell_ids missed by the KNN rings ─────────────────
    def _fill_missing_out_cells(self) -> None:
        """Extend cell_ids / hi_k to cover every cell in out_cell_ids.

        Cells in ``out_cell_ids`` not returned by ``healpix_weighted_nearest``
        (i.e., beyond ``ring_search_max`` rings of any source) are handled via
        a memory-bounded chunked dot-product: the nearest source is found
        directly from ``self.xyz_samples`` without pair explosion.
        """
        out_t = self.out_cell_ids
        if not isinstance(out_t, torch.Tensor):
            out_t = torch.as_tensor(out_t)
        out_t = out_t.to(device=self.device, dtype=torch.long).reshape(-1)

        # ── Find missing cells via sorted-membership test ────────────────────
        cell_sorted, _ = torch.sort(self.cell_ids)
        out_sorted, _ = torch.sort(out_t)
        # clamp keeps searchsorted's one-past-the-end result indexable.
        pos = torch.searchsorted(cell_sorted, out_sorted).clamp(0, self.K - 1)
        present = cell_sorted[pos] == out_sorted
        missing = out_sorted[~present]  # (M,)
        if missing.numel() == 0:
            return
        if self.verbose:
            print(
                f"[NearestResampler] {missing.numel():,} out_cell_ids not covered "
                f"by KNN rings → chunked dot-product fallback"
            )

        # ── xyz of missing cell centres ──────────────────────────────────────
        # healpix_geo works on uint64 ids and degrees; convert to the device
        # tensor dtype used for the sample coordinates.
        miss_np = missing.cpu().numpy().astype(np.uint64)
        if self.nest:
            lon_c_deg, lat_c_deg = healpix_geo.nested.healpix_to_lonlat(
                miss_np, self.level, ellipsoid=self.ellipsoid
            )
        else:
            lon_c_deg, lat_c_deg = healpix_geo.ring.healpix_to_lonlat(
                miss_np, self.level, ellipsoid=self.ellipsoid
            )
        src_dtype = self.xyz_samples.dtype
        xyz_miss = _lonlat_to_xyz(
            torch.deg2rad(torch.as_tensor(lon_c_deg, device=self.device, dtype=src_dtype)),
            torch.deg2rad(torch.as_tensor(lat_c_deg, device=self.device, dtype=src_dtype)),
        )  # (M, 3)

        # ── Nearest source per missing cell — chunked, O(chunk × N) memory ──
        hi_miss = self._chunked_nearest(xyz_miss, self.xyz_samples)  # (M,)
        del xyz_miss

        # ── Extend cell_ids and hi_k, then re-sort ──────────────────────────
        # NOTE(review): the ascending re-sort matches out_cell_ids only if the
        # caller supplied out_cell_ids sorted — confirm against base-class
        # contract.
        self.cell_ids = torch.cat([self.cell_ids, missing])
        self.hi_k = torch.cat([self.hi_k, hi_miss])
        self.K = int(self.cell_ids.numel())
        order = torch.argsort(self.cell_ids)
        self.cell_ids = self.cell_ids[order]
        self.hi_k = self.hi_k[order]

    @staticmethod
    def _chunked_nearest(
        xyz_query: torch.Tensor,                       # (Q, 3)
        xyz_src: torch.Tensor,                         # (N, 3)
        mem_budget_bytes: int = 512 * 1024 * 1024,     # 512 MB
    ) -> torch.Tensor:
        """argmax dot(query, src) per query row, memory-bounded by budget.

        Returns hi (Q,) — index of nearest source for each query cell.
        Memory peak = chunk_q × N × bytes_per_element.
        (Maximum dot product of unit vectors == minimum angular distance.)
        """
        N = xyz_src.shape[0]
        Q = xyz_query.shape[0]
        bpe = xyz_src.element_size()
        # Rows per chunk so the (chunk_q, N) dot matrix fits in the budget.
        chunk_q = max(1, mem_budget_bytes // (N * bpe))
        hi_out = torch.empty(Q, dtype=torch.long, device=xyz_src.device)
        for start in range(0, Q, chunk_q):
            end = min(start + chunk_q, Q)
            dots = xyz_query[start:end] @ xyz_src.T  # (chunk_q, N)
            hi_out[start:end] = dots.argmax(dim=1)
            del dots
        return hi_out

    # ── resample ─────────────────────────────────────────────────────────────
    @torch.no_grad()
    def resample(self, val: T_Array, **_kwargs) -> ResampleResults:
        """Source samples → HEALPix cells.

        hval[:, k] = val[:, hi_k[k]] — direct index, zero allocation overhead.
        Accepts a tensor or array-like; returns numpy when the input was not
        a torch.Tensor.
        """
        y = val if isinstance(val, torch.Tensor) else torch.as_tensor(val)
        y = y.to(self.device, dtype=self.dtype)
        squeezed = y.ndim == 1
        if squeezed:
            y = y.unsqueeze(0)  # (1, N)
        # Clamp -1 sentinels (defensive; should not occur after _fill_missing).
        safe_hi = self.hi_k.clamp(min=0)
        hval = y[:, safe_hi]  # (B, K)
        cell_ids = self.cell_ids
        if squeezed:
            hval = hval.squeeze(0)
        if not isinstance(val, torch.Tensor):
            hval = hval.cpu().numpy()
            cell_ids = cell_ids.cpu().numpy()
        return ResampleResults(cell_data=hval, cell_ids=cell_ids)

    # ── invert ───────────────────────────────────────────────────────────────
    @torch.no_grad()
    def invert(self, hval: T_Array) -> T_Array:
        """HEALPix cells → source samples.

        val[:, n] = mean of hval[:, k] for all k s.t. hi_k[k] == n.
        (scatter-mean; multiple cells can share the same nearest source)
        """
        y = hval if isinstance(hval, torch.Tensor) else torch.as_tensor(hval)
        y = y.to(self.device, dtype=self.dtype)
        squeezed = y.ndim == 1
        if squeezed:
            y = y.unsqueeze(0)  # (1, K)
        safe_hi = self.hi_k.clamp(min=0)
        res = _scatter_mean(y, safe_hi, self.N)  # (B, N)
        if squeezed:
            res = res.squeeze(0)
        if not isinstance(hval, torch.Tensor):
            res = res.cpu().numpy()
        return res