Source code for healpix_plot.sampling_grid

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict

import numpy as np
from affine import Affine

if TYPE_CHECKING:
    from typing import Self

    from healpix_plot.healpix import HealpixGrid


class SamplingGridParameters(TypedDict, total=False):
    """Sampling parameters as a dict.

    Every key is optional. ``TypedDict`` cannot declare default values, so the
    defaults documented below must be supplied by the code that consumes the
    mapping; they are not stored on this type.

    Parameters
    ----------
    shape : int or tuple of int, default: 1024
        The shape of the array. If an int, the shape is a square of equal size.
    resolution : float or tuple of float, optional
        The resolution or step size of the sampling grid. If a float, expands to
        a 2-tuple of equal values. If missing, derived from the spatial extent
        of the data and the given shape.
    center : tuple of float, optional
        The center of the sampling grid. If missing, this is inferred from the data.
    """

    shape: int | tuple[int, int]
    resolution: float | tuple[float, float] | None
    center: tuple[float, float] | None


class SamplingGrid:
    """Common interface for sampling-grid descriptions.

    Concrete subclasses override :meth:`resolve` to turn HEALPix cell ids
    into a :class:`ConcreteSamplingGrid` of sample coordinates.
    """

    def resolve(
        self,
        cell_ids: np.ndarray,
        parameters: HealpixGrid,
    ) -> ConcreteSamplingGrid:  # pragma: no cover
        """Produce the concrete sample coordinates for ``cell_ids``.

        The base class provides no implementation and always raises
        :class:`NotImplementedError`.
        """
        raise NotImplementedError


def crosses_prime_meridian(cell_ids, params):
    """Coarsely detect whether the cells may straddle the prime meridian.

    Every cell id is reduced to its level-0 (base) cell by integer division
    by ``4**params.level``. The check succeeds as soon as any of the base
    cells 0, 3, 4, 8 or 11 is among them.

    NOTE(review): the ``// 4**level`` reduction matches nested-ordered ids;
    confirm callers never pass ring-ordered ids here.

    Parameters
    ----------
    cell_ids : numpy.ndarray
        HEALPix cell ids at level ``params.level``.
    params
        Object exposing the refinement ``level`` of ``cell_ids``.

    Returns
    -------
    numpy.bool_
        Whether any of the listed base cells is present.
    """
    meridian_base_cells = np.array([0, 3, 4, 8, 11], dtype="uint64")
    cells_per_base_cell = 4**params.level
    present = np.isin(meridian_base_cells, cell_ids // cells_per_base_cell)
    return present.any()


def _infer_parameters(
    grid: ParametrizedSamplingGrid, cell_ids: np.ndarray, params: HealpixGrid
) -> (tuple[int, int], tuple[float, float], tuple[float, float]):
    # TODO: figure out how to deal with the difference between source and target grid
    center = grid.center
    shape = grid.shape
    resolution = grid.resolution
    if resolution is None or center is None:
        lon, lat = params.operations.healpix_to_lonlat(
            cell_ids, **params.as_keyword_params()
        )
        if crosses_prime_meridian(cell_ids, params):
            lon = (lon + 180) % 360 - 180

        if center is None:
            center = (np.mean(lon).item(), np.mean(lat).item())

        if resolution is None:
            size_x, size_y = shape
            min_x, max_x = np.min(lon).item(), np.max(lon).item()
            min_y, max_y = np.min(lat).item(), np.max(lat).item()

            dx = (max_x - min_x) / (size_x - 1)
            dy = (max_y - min_y) / (size_y - 1)

            resolution = (dx, dy)

    return shape, resolution, center


@dataclass
class ParametrizedSamplingGrid:
    """Sampling grid described by its shape, resolution and center.

    ``resolution`` and ``center`` may be ``None``; :meth:`resolve` then
    infers them from the cell ids being plotted.
    """

    shape: tuple[int, int]
    resolution: tuple[float, float] | None
    center: tuple[float, float] | None

    @classmethod
    def from_parameters(
        cls,
        shape: int | tuple[int, int] = 1024,
        resolution: float | tuple[float, float] | None = None,
        center: tuple[float, float] | None = None,
    ) -> Self:
        """Build a grid from scalar-or-pair parameters.

        Parameters
        ----------
        shape : int or tuple of int, default: 1024
            Number of samples along (x, y); a scalar is used for both axes.
        resolution : float or tuple of float, optional
            Step between neighbouring samples; a scalar (int or float) is used
            for both axes. ``None`` means it is inferred in :meth:`resolve`.
        center : tuple of float, optional
            Grid center; ``None`` means it is inferred in :meth:`resolve`.
        """
        if isinstance(shape, (int, np.integer)):
            shape = (shape, shape)
        # Expand any real scalar, not just float: an int resolution left
        # unexpanded would fail to unpack into (x, y) in ``resolve``.
        if isinstance(resolution, (int, float, np.integer, np.floating)):
            resolution = (resolution, resolution)
        return cls(shape=shape, resolution=resolution, center=center)

    @classmethod
    def from_dict(cls, mapping: SamplingGridParameters) -> Self:
        """Build a grid from a ``SamplingGridParameters`` mapping.

        Missing keys fall back to the defaults of :meth:`from_parameters`.
        """
        return cls.from_parameters(**mapping)

    @classmethod
    def from_bbox(
        cls,
        bbox: tuple[float, float, float, float],
        shape: int | tuple[int, int],
    ) -> Self:
        """Build a grid whose outermost samples lie on the edges of ``bbox``.

        Parameters
        ----------
        bbox : tuple of float
            ``(xmin, ymin, xmax, ymax)``.
        shape : int or tuple of int
            Number of samples along (x, y); a scalar is used for both axes.
            Each size must be at least 2 (the step divides by ``size - 1``).
        """
        if isinstance(shape, (int, np.integer)):
            shape = (shape, shape)

        xmin, ymin, xmax, ymax = bbox
        center = (float(np.mean([xmin, xmax])), float(np.mean([ymin, ymax])))
        resolution = (
            (xmax - xmin) / (shape[0] - 1),
            (ymax - ymin) / (shape[1] - 1),
        )
        return cls(shape=shape, center=center, resolution=resolution)

    def resolve(
        self, cell_ids: np.ndarray, parameters: HealpixGrid
    ) -> ConcreteSamplingGrid:
        """Compute the sample coordinates, inferring missing parameters.

        Samples are spaced exactly ``resolution`` apart and centered on
        ``center``. Latitudes are clamped to [-90, 90].
        """
        shape, resolution, center = _infer_parameters(self, cell_ids, parameters)

        size_x, size_y = shape
        resolution_x, resolution_y = resolution
        center_x, center_y = center

        # The outermost of ``size`` evenly spaced samples sit (size - 1) / 2
        # steps from the center. Using ``size // 2`` instead makes the
        # linspace step differ from ``resolution`` for even sizes.
        half_span_x = (size_x - 1) / 2 * resolution_x
        half_span_y = (size_y - 1) / 2 * resolution_y

        xmin = center_x - half_span_x
        xmax = center_x + half_span_x
        ymin = np.clip(center_y - half_span_y, -90, 90).item()
        ymax = np.clip(center_y + half_span_y, -90, 90).item()

        if xmin > xmax:
            # prime meridian crossing
            # NOTE(review): with a positive x resolution xmin <= xmax always
            # holds, so this only runs for a negative one -- confirm intent.
            xmin = (xmin + 180) % 360 - 180
            xmax = (xmax + 180) % 360 - 180

        xs = np.linspace(xmin, xmax, size_x, endpoint=True)
        ys = np.linspace(ymin, ymax, size_y, endpoint=True)
        x, y = np.meshgrid(xs, ys)

        extent_x = (xmin, xmax)
        extent_y = (ymin, ymax)
        return ConcreteSamplingGrid(x, y, extent_x, extent_y)
@dataclass
class AffineSamplingGrid(SamplingGrid):
    """Sampling grid defined by an affine pixel-to-coordinate transform.

    ``transform`` is used as the mapping from pixel indices to pixel-center
    coordinates; :attr:`center_transform` returns it unchanged.
    """

    transform: Affine
    shape: tuple[int, int]

    @classmethod
    def from_transform(
        cls,
        transform: Affine,
        shape: int | tuple[int, int],
    ) -> Self:
        """Build a grid from ``transform``; a scalar ``shape`` is used for both axes."""
        if isinstance(shape, (int, np.integer)):
            shape = (shape, shape)
        return cls(transform, shape)

    @property
    def center_transform(self):
        """Transform from pixel indices to pixel-center coordinates (``transform`` itself)."""
        return self.transform

    @property
    def corner_transform(self):
        """``transform`` applied after a (-0.5, -0.5) pixel shift, i.e. pixel index to pixel-corner coordinates."""
        return self.transform * Affine.translation(-0.5, -0.5)

    def resolve(
        self, cell_ids: np.ndarray, parameters: HealpixGrid
    ) -> ConcreteSamplingGrid:
        """Compute the sample coordinates of every pixel.

        ``cell_ids`` and ``parameters`` are accepted for interface
        compatibility with :class:`SamplingGrid` but are not used.
        """
        pixel_x = np.arange(self.shape[0])
        pixel_y = np.arange(self.shape[1])
        # NOTE(review): this yields arrays of shape (shape[0], shape[1]) with
        # x varying along axis 0, whereas ParametrizedSamplingGrid.resolve puts
        # x along axis 1 -- confirm consumers expect this layout.
        py, px = np.meshgrid(pixel_y, pixel_x)
        x, y = self.center_transform * (px, py)

        # NOTE(review): the upper bound is evaluated at pixel index ``shape``
        # (one past the last sample) and corner_transform is not used here --
        # confirm whether the extent should be the pixel-edge bounds.
        (xmin, xmax), (ymin, ymax) = self.center_transform * (
            np.array([0, self.shape[0]]),
            np.array([0, self.shape[1]]),
        )
        # when the x extent contains 0, wrap x into [-180, 180)
        if xmin <= 0 and xmax >= 0:
            x = (x + 180) % 360 - 180

        extent_x = (float(xmin), float(xmax))
        extent_y = (float(ymin), float(ymax))
        return ConcreteSamplingGrid(x, y, extent_x, extent_y)
@dataclass
class ConcreteSamplingGrid:
    """Resolved sample-point coordinates.

    ``x`` and ``y`` hold the coordinate of every sample point. ``extent_x``
    and ``extent_y`` are the bounds produced by the resolver, e.g.
    ``(xmin, xmax)`` and ``(ymin, ymax)``.
    """

    x: np.ndarray
    y: np.ndarray
    extent_x: tuple[float, float]
    extent_y: tuple[float, float]

    @property
    def shape(self):
        """Shape of the coordinate arrays, taken from ``x``."""
        return self.x.shape

    @property
    def extent(self):
        """Return ``extent_x`` wrapped into [-180, 180) followed by ``extent_y``.

        Each x bound is wrapped independently, so an x extent spanning
        +/-180 comes back with its first bound larger than its second.
        """

        def wrap(longitude):
            return (longitude + 180) % 360 - 180

        wrapped_x = tuple(map(wrap, self.extent_x))
        return wrapped_x + self.extent_y