"""
groupby.py
GPU-friendly sparse HEALPix regridding from unstructured lon/lat samples
to a subset of HEALPix pixels at a target resolution (nside = 2**level).
Core ideas:
- Group the samples by the HEALPix cell (cell_ids) that contains them.
- Merge the values of samples that fall in the same cell with a reduction:
  "sum", "prod", "mean", "amax" or "amin" (as in torch.scatter_reduce).
This module is designed for large N and batched values (B,N) on CUDA.
"""
from typing import Generic
import math
import numpy as np
import torch
from healpix_resample.base import T_Array, ResampleResults
from healpix_resample.knn import KNeighborsResampler
class GroupByResampler(KNeighborsResampler, Generic[T_Array]):
    """Group unstructured samples by HEALPix cell and reduce them per cell.

    The parent ``KNeighborsResampler`` is configured with ``group_by=True``
    and ``Npt=1``. Each of the N samples is mapped to one of the K output
    cells through ``self.hi``. The values of samples sharing a cell are
    merged with ``torch.Tensor.scatter_reduce_`` using ``self.reduce``.
    """

    #: Reductions accepted by ``torch.Tensor.scatter_reduce_``.
    _REDUCE_OPS = ("sum", "prod", "mean", "amax", "amin")

    def __init__(self,
                 reduce="mean",  # "sum", "prod", "mean", "amax" or "amin", as for torch.scatter_reduce
                 *args, **kwargs):
        """Create a group-by resampler.

        Args:
            reduce: how the values of samples falling in the same cell are
                merged. One of "sum", "prod", "mean", "amax", "amin" (see
                ``torch.Tensor.scatter_reduce_``).
            *args, **kwargs: forwarded to ``KNeighborsResampler.__init__``,
                with ``group_by=True`` and ``Npt=1`` forced. ``reduce`` is
                the first positional parameter. Pass the parent's arguments
                by keyword, unless ``reduce`` itself is given positionally
                first.

        Raises:
            ValueError: if ``reduce`` is not one of the supported reductions.
        """
        # An explicit check (not ``assert``) so validation survives ``python -O``.
        # The isinstance guard avoids ambiguous comparisons for array-like input.
        if not isinstance(reduce, str) or reduce not in self._REDUCE_OPS:
            raise ValueError(
                'reduce should be in "sum", "prod", "mean", "amax", "amin", '
                f"got {reduce!r}"
            )
        super().__init__(*args, group_by=True, Npt=1, **kwargs)
        self.reduce = reduce

    @torch.no_grad()
    def invert(self, hval: T_Array) -> T_Array:
        """Project a HEALPix field back onto the sample locations.

        Each sample receives the value of the cell it was grouped into,
        i.e. ``hval[..., self.hi]``.

        Args:
            hval: (..., K) cell values, e.g. (K,) or (B, K). A torch tensor or
                any array-like. It is moved to ``self.device`` and cast to
                ``self.dtype``.

        Returns:
            (..., N) values at the samples, e.g. (N,) or (B, N). This is a
            torch tensor if ``hval`` is one, otherwise a NumPy array.

        Raises:
            ValueError: if ``hval`` is a scalar (0-d).
        """
        y = hval if isinstance(hval, torch.Tensor) else torch.as_tensor(hval)
        y = y.to(self.device, dtype=self.dtype)
        if y.ndim == 0:
            raise ValueError("hval must have at least one dimension (..., K)")
        # Gather along the last (cell) axis so any number of leading batch dims works.
        res = y.index_select(-1, self.hi)
        if not isinstance(hval, torch.Tensor):
            res = res.cpu().numpy()
        return res

    @torch.no_grad()
    def resample(self, val: T_Array) -> ResampleResults[T_Array]:
        """Aggregate unstructured samples into their HEALPix cells.

        The values of all samples that fall in the same cell are merged with
        ``self.reduce`` (``torch.Tensor.scatter_reduce_`` with
        ``include_self=False``). Cells that receive no sample keep the value 0.

        Args:
            val: (..., N) values at the lon/lat sample points, e.g. (N,) or
                (B, N). A torch tensor or any array-like.

        Returns:
            ResampleResults with
              - cell_data: (..., K) reduced values per cell, e.g. (K,) or (B, K);
              - cell_ids: ``self.cell_ids`` as set up by the parent class.
            Both are NumPy arrays if ``val`` is not a torch tensor.

        Raises:
            ValueError: if ``val`` is a scalar (0-d).
        """
        y = val if isinstance(val, torch.Tensor) else torch.as_tensor(val)
        y = y.to(self.device, dtype=self.dtype)
        if y.ndim == 0:
            raise ValueError("val must have at least one dimension (..., N)")

        # Flatten any leading batch dims into a single axis for the 2-D scatter.
        # math.prod(()) == 1, so 1-D input becomes a batch of one.
        batch_shape = tuple(y.shape[:-1])
        n_samples = y.shape[-1]
        n_batch = math.prod(batch_shape)
        y2d = y.reshape(n_batch, n_samples)

        hval = torch.zeros(n_batch, self.K, device=y.device, dtype=self.dtype)
        hval.scatter_reduce_(
            1,
            self.hi.unsqueeze(0).expand(n_batch, -1),
            y2d,
            reduce=self.reduce,
            include_self=False,  # the zero initialisation is not part of the reduction
        )
        hval = hval.reshape(*batch_shape, self.K)

        cell_ids = self.cell_ids
        if not isinstance(val, torch.Tensor):
            hval = hval.cpu().numpy()
            cell_ids = cell_ids.cpu().numpy()
        return ResampleResults(cell_data=hval, cell_ids=cell_ids)
class CellPointResampler(GroupByResampler):
    """Resample lat/lon samples onto HEALPix "cell-points" at level 29.

    Each input point is assigned to the HEALPix cell of the maximum
    refinement level (29) that contains it. Level-29 cells are roughly
    0.4 milliarcseconds across, so most applications can treat them as
    points. When several input points land in the same cell, their values
    are combined according to the ``reduce`` option of GroupByResampler.
    """

    def __init__(self, *args, **kwargs):
        """Build the resampler with the HEALPix level pinned to 29.

        Every argument is forwarded to ``GroupByResampler.__init__``. The
        first positional argument therefore binds to ``reduce``. Supplying
        ``level`` yourself raises a TypeError (duplicate keyword), because
        it is always set to 29 here.
        """
        super().__init__(*args, level=29, **kwargs)