Source code for healpix_resample.bilinear
"""
bilinear.py
GPU-friendly sparse HEALPix regridding from unstructured lon/lat samples
to a subset of HEALPix pixels at a target resolution (nside = 2**level).
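Core ideas:
- Fix the number of neighbour links per input sample to four (``Npt=4``).
- Weight each sample->cell link by the inverse of its scaled distance,
  ``w = 1 / (1e-6 + d / sigma)``.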
This module is designed for large N and batched values (B,N) on CUDA.
"""
from healpix_resample.knn import KNeighborsResampler
import math
import numpy as np
import torch
[docs]
class BilinearResampler(KNeighborsResampler):
    """``KNeighborsResampler`` specialised to four links per input sample.

    The constructor pins ``Npt=4``; :meth:`comp_matrix` builds the two sparse
    operators from inverse scaled-distance link weights.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``KNeighborsResampler`` with ``Npt=4``.

        Supplying ``Npt`` yourself in ``kwargs`` raises ``TypeError`` because
        the keyword would then be passed twice.
        """
        super().__init__(*args, Npt=4, **kwargs)

    def comp_matrix(self):
        """Build the sparse operators ``self.M`` (N, K) and ``self.MT`` (K, N).

        Reads ``self.d_m``, ``self.sigma_m``, ``self.hi``, ``self.N``,
        ``self.Npt``, ``self.K``, ``self.device`` and ``self.dtype``.
        Links whose entry in ``self.hi`` is negative are discarded.

        * ``self.M``  -- entry (i, k) is the link weight divided by the sum of
          the weights of all links that end in column ``k``.
        * ``self.MT`` -- entry (k, i) is the link weight divided by the sum of
          the weights of all links that start from sample ``i``.

        Both results are stored in CSR layout.
        """
        # NOTE(review): presumably self.d_m holds sample->cell distances and
        # self.hi the matching cell indices, both shaped (N, Npt) -- confirm
        # against KNeighborsResampler.

        def to_csr(row_ids, col_ids, values, shape):
            # Assemble a COO tensor, merge duplicate coordinates, convert to CSR.
            coo = torch.sparse_coo_tensor(
                torch.stack([row_ids, col_ids], dim=0),
                values.to(self.dtype),
                size=shape,
                device=self.device,
                dtype=self.dtype,
            )
            return coo.coalesce().to_sparse_csr()

        # Inverse scaled-distance weight; the 1e-6 offset keeps the weight
        # finite when the distance is zero.
        link_weight = 1 / (1e-6 + self.d_m / self.sigma_m)

        # Sample index of every link: 0,0,..,0,1,1,..,1,... (Npt copies each).
        sample_of_link = (
            torch.arange(self.N, device=self.device, dtype=torch.long)
            .unsqueeze(1)
            .expand(self.N, self.Npt)
            .reshape(-1)
        )
        cell_of_link = self.hi.reshape(-1)

        # Keep only the links that point at a non-negative cell index.
        keep = cell_of_link >= 0
        samples = sample_of_link[keep]
        cells = cell_of_link[keep]
        weights = link_weight.reshape(-1)[keep]

        # Weight totals per target cell (column of M) and per input sample
        # (column of MT); torch.bincount keeps the accumulation on-device.
        per_cell_total = torch.bincount(
            cells, weights=weights, minlength=self.K
        ).to(self.dtype)
        per_sample_total = torch.bincount(
            samples, weights=weights, minlength=self.N
        ).to(self.dtype)

        forward = to_csr(
            samples, cells, weights / per_cell_total[cells], (self.N, self.K)
        )
        transposed = to_csr(
            cells, samples, weights / per_sample_total[samples], (self.K, self.N)
        )

        # Assign only after both operators were built successfully.
        self.M = forward
        self.MT = transposed