# Source code for healpix_resample.bilinear

"""
bilinear.py

GPU-friendly sparse HEALPix regridding from unstructured lon/lat samples
to a subset of HEALPix pixels at a target resolution (nside = 2**level).

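Core idea:
- Use Npt=4: BilinearResampler is a KNeighborsResampler configured with four
  sample->cell links per input sample.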

This module is designed for large N and batched values (B,N) on CUDA.
"""
from healpix_resample.knn import KNeighborsResampler
import math
import numpy as np
import torch


class BilinearResampler(KNeighborsResampler):
    """``KNeighborsResampler`` pinned to ``Npt=4`` links per input sample.

    Constructor arguments are forwarded unchanged to the parent class; only
    ``Npt`` is fixed. :meth:`comp_matrix` builds the two sparse operators
    ``self.M`` and ``self.MT`` from inverse-distance style link weights.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the parent resampler with ``Npt=4``.

        Args:
            *args: Positional arguments forwarded to ``KNeighborsResampler``.
            **kwargs: Keyword arguments forwarded to ``KNeighborsResampler``.
                Supplying ``Npt`` here raises ``TypeError`` (duplicate
                keyword argument), exactly as before.
        """
        # Keyword written after *args so the call reads in binding order;
        # the arguments received by the parent are unchanged.
        super().__init__(*args, Npt=4, **kwargs)

    def comp_matrix(self):
        """Build the sparse operators ``self.M`` (N, K) and ``self.MT`` (K, N).

        Reads ``self.d_m``, ``self.sigma_m``, ``self.hi``, ``self.N``,
        ``self.K``, ``self.Npt``, ``self.device`` and ``self.dtype``; these
        are not set in this class and are presumably provided by
        ``KNeighborsResampler`` -- TODO confirm.

        Assumes ``self.hi`` and ``self.d_m`` each hold ``N * Npt`` entries
        laid out as (N, Npt), with negative ``self.hi`` entries marking links
        to skip -- verify against the parent class.

        Both operators share the same sparsity pattern and raw weights; they
        differ only in normalisation:

        * ``M``:  each column (HEALPix cell) sums to 1 over its links.
        * ``MT``: each row of the (N, K) pattern (input sample) sums to 1,
          stored transposed as (K, N).

        Returns:
            None. The CSR tensors are stored on ``self.M`` and ``self.MT``.
        """
        # Raw weight per sample->cell link: w = 1 / (1e-6 + d_m / sigma_m).
        # The 1e-6 term keeps the weight finite when d_m is zero.
        weights = 1 / (1e-6 + self.d_m / self.sigma_m)

        # Sample index of every link, 0..N-1 repeated Npt times: (N, Npt).
        sample_idx = torch.arange(
            self.N, device=self.device, dtype=torch.long
        )[:, None].expand(self.N, self.Npt)

        flat_cell = self.hi.reshape(-1)
        keep = flat_cell >= 0  # drop links whose cell index is negative

        cell_v = flat_cell[keep]
        sample_v = sample_idx.reshape(-1)[keep]
        weight_v = weights.reshape(-1)[keep]

        # M: (N, K), normalised per column / per HEALPix cell.
        m_csr = self._normalized_csr(
            rows=sample_v,
            cols=cell_v,
            weights=weight_v,
            groups=cell_v,
            n_groups=self.K,
            shape=(self.N, self.K),
        )

        # MT: (K, N), normalised per input sample.
        mt_csr = self._normalized_csr(
            rows=cell_v,
            cols=sample_v,
            weights=weight_v,
            groups=sample_v,
            n_groups=self.N,
            shape=(self.K, self.N),
        )

        # Assign only once both operators were built successfully.
        self.M = m_csr
        self.MT = mt_csr

    def _normalized_csr(self, rows, cols, weights, groups, n_groups, shape):
        """Return a CSR matrix whose link weights are normalised per group.

        Args:
            rows: 1-D integer tensor of row indices, one per link.
            cols: 1-D integer tensor of column indices, one per link.
            weights: 1-D tensor of raw link weights.
            groups: 1-D integer tensor giving the normalisation group of each
                link (the cell index for ``M``, the sample index for ``MT``).
            n_groups: Number of groups; used as ``minlength`` for the sums.
            shape: ``(n_rows, n_cols)`` of the resulting matrix.

        Returns:
            Sparse CSR tensor on ``self.device`` with dtype ``self.dtype``.
        """
        # Sum of weights per group, computed with torch.bincount so the
        # reduction runs on the tensors' own device.
        group_sum = torch.bincount(
            groups, weights=weights, minlength=n_groups
        ).to(self.dtype)
        values = weights / group_sum[groups]

        coo = torch.sparse_coo_tensor(
            torch.stack([rows, cols], dim=0),
            values.to(self.dtype),
            size=shape,
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

        # CSR layout is used for the sparse-dense products downstream.
        return coo.to_sparse_csr()