diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a6476c49..b14d38d96 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,8 @@ jobs: # github.repository == 'UXARRAY/uxarray' name: Python (${{ matrix.python-version }}, ${{ matrix.os }}) runs-on: ${{ matrix.os }} + env: + MPLBACKEND: Agg defaults: run: shell: bash -l {0} diff --git a/.github/workflows/yac-optional.yml b/.github/workflows/yac-optional.yml new file mode 100644 index 000000000..ab0f7003b --- /dev/null +++ b/.github/workflows/yac-optional.yml @@ -0,0 +1,142 @@ +name: YAC Optional CI + +on: + pull_request: + paths: + - ".github/workflows/yac-optional.yml" + - "uxarray/remap/**" + - "test/test_remap_yac.py" + workflow_dispatch: + +jobs: + yac-optional: + name: YAC core v3.14.0_p1 (Ubuntu) + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + env: + YAC_VERSION: v3.14.0_p1 + YAXT_VERSION: v0.11.5.1 + MPIEXEC: /usr/bin/mpirun + MPIRUN: /usr/bin/mpirun + MPICC: /usr/bin/mpicc + MPIFC: /usr/bin/mpif90 + MPIF90: /usr/bin/mpif90 + OMPI_ALLOW_RUN_AS_ROOT: 1 + OMPI_ALLOW_RUN_AS_ROOT_CONFIRM: 1 + steps: + - name: checkout + uses: actions/checkout@v4 + with: + token: ${{ github.token }} + + - name: conda_setup + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: uxarray_build + channel-priority: strict + python-version: "3.11" + channels: conda-forge + environment-file: ci/environment.yml + miniforge-variant: Miniforge3 + miniforge-version: latest + + - name: Install build dependencies (apt) + run: | + sudo apt-get update + sudo apt-get install -y \ + autoconf \ + automake \ + gawk \ + gfortran \ + libopenmpi-dev \ + libtool \ + make \ + openmpi-bin \ + pkg-config + - name: Verify MPI tools + run: | + which mpirun + which mpicc + which mpif90 + mpirun --version + mpicc --version + mpif90 --version + - name: Install Python build dependencies + run: | + python -m pip install --upgrade pip + python -m pip install cython wheel + - name: Build and install YAXT + run: | + set -euxo pipefail + YAC_PREFIX="${GITHUB_WORKSPACE}/yac_prefix" + echo "YAC_PREFIX=${YAC_PREFIX}" >> "${GITHUB_ENV}" + git clone --depth 1 --branch "${YAXT_VERSION}" https://gitlab.dkrz.de/dkrz-sw/yaxt.git + if [ ! -x yaxt/configure ]; then + if [ -x yaxt/autogen.sh ]; then + (cd yaxt && ./autogen.sh) + else + (cd yaxt && autoreconf -i) + fi + fi + mkdir -p yaxt/build + cd yaxt/build + ../configure \ + --prefix="${YAC_PREFIX}" \ + --without-regard-for-quality \ + CC="${MPICC}" \ + FC="${MPIF90}" + make -j2 + make install + - name: Build and install YAC + run: | + set -euxo pipefail + git clone --depth 1 --branch "${YAC_VERSION}" https://gitlab.dkrz.de/dkrz-sw/yac.git + if [ ! -x yac/configure ]; then + if [ -x yac/autogen.sh ]; then + (cd yac && ./autogen.sh) + else + (cd yac && autoreconf -i) + fi + fi + mkdir -p yac/build + cd yac/build + ../configure \ + --prefix="${YAC_PREFIX}" \ + --with-yaxt-root="${YAC_PREFIX}" \ + --disable-mci \ + --disable-utils \ + --disable-examples \ + --disable-tools \ + --disable-netcdf \ + --enable-python-bindings \ + CC="${MPICC}" \ + FC="${MPIF90}" + make -j2 + make install + - name: Configure YAC runtime paths + run: | + set -euxo pipefail + PY_VER="$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')" + echo "LD_LIBRARY_PATH=${YAC_PREFIX}/lib:${LD_LIBRARY_PATH:-}" >> "${GITHUB_ENV}" + echo "PYTHONPATH=${YAC_PREFIX}/lib/python${PY_VER}/site-packages:${YAC_PREFIX}/lib/python${PY_VER}/dist-packages:${PYTHONPATH:-}" >> "${GITHUB_ENV}" + - name: Verify YAC core Python bindings + run: | + python - <<'PY' + from pathlib import Path + import sys + candidates = [] + for entry in sys.path: + pkg = Path(entry) / "yac" + candidates.extend(pkg.glob("core*.so")) + candidates.extend(pkg.glob("core*.pyd")) + assert candidates, "yac.core extension not found on sys.path" + print("Found yac.core extension:", candidates[0]) + PY + - name: Install uxarray + run: | + python -m pip install . --no-deps + - name: Run tests (uxarray with YAC) + run: | + python -m pytest test/test_remap_yac.py diff --git a/test/test_remap_yac.py b/test/test_remap_yac.py new file mode 100644 index 000000000..59fcce41f --- /dev/null +++ b/test/test_remap_yac.py @@ -0,0 +1,221 @@ +import numpy as np +import pytest + +import uxarray as ux +from uxarray.remap.yac import YacNotAvailableError, _import_yac + + +try: + _import_yac() +except YacNotAvailableError: + pytest.skip("yac.core is not available", allow_module_level=True) + + +def test_yac_nnn_node_remap(gridpath, datasetpath): + grid_path = gridpath("ugrid", "geoflow-small", "grid.nc") + uxds = ux.open_dataset(grid_path, datasetpath("ugrid", "geoflow-small", "v1.nc")) + dest = ux.open_grid(grid_path) + + out = uxds["v1"].remap.nearest_neighbor( + destination_grid=dest, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + assert out.size > 0 + assert "n_node" in out.dims + + +def test_yac_conservative_face_remap(gridpath): + mesh_path = gridpath("mpas", "QU", "mesh.QU.1920km.151026.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + dest = ux.open_grid(mesh_path) + + out = uxds["latCell"].remap( + destination_grid=dest, + remap_to="faces", + backend="yac", + yac_method="conservative", + yac_options={"order": 1}, + ) + assert out.size == dest.n_face + + +def test_yac_matches_uxarray_nearest_neighbor(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + ux_out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="uxarray", + ) + yac_out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + assert ux_out.shape == yac_out.shape + assert (ux_out.values == yac_out.values).all() + + +def test_yac_call_defaults_to_nnn(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + out = da.remap( + destination_grid=grid, + remap_to="nodes", + backend="yac", + ) + + assert out.shape == da.shape + np.testing.assert_array_equal(out.values, da.values) + + +def test_yac_invalid_backend_raises(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(ValueError, match="Invalid backend"): + da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="bogus", + ) + + +def test_yac_idw_not_implemented(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(NotImplementedError, match="inverse_distance_weighted"): + da.remap.inverse_distance_weighted( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + +def test_yac_bilinear_not_implemented(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(NotImplementedError, match="bilinear"): + da.remap.bilinear( + destination_grid=grid, + remap_to="nodes", + backend="yac", + ) + + +def test_yac_conservative_rejects_non_face_data(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={"n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + with pytest.raises(ValueError, match="face-centered"): + da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="conservative", + yac_options={"order": 1}, + ) + + +def test_yac_preserves_spatial_coordinate_remap(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([1.0, 2.0, 3.0]), + dims=["n_node"], + coords={ + "n_node": [0, 1, 2], + "node_lon": ( + "n_node", + np.array([0.0, -180.0, 0.0]), + {"standard_name": "longitude", "units": "degrees_east"}, + ), + "node_lat": ( + "n_node", + np.array([90.0, 0.0, -90.0]), + {"standard_name": "latitude", "units": "degrees_north"}, + ), + }, + uxgrid=grid, + ) + + out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + np.testing.assert_array_equal(out.values, da.values) + assert "node_lon" in out.coords + assert "node_lat" in out.coords + + +def test_yac_batched_remap_with_extra_dimension(): + verts = np.array([(0.0, 90.0), (-180.0, 0.0), (0.0, -90.0)]) + grid = ux.open_grid(verts) + da = ux.UxDataArray( + np.asarray([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]), + dims=["time", "n_node"], + coords={"time": [0, 1], "n_node": [0, 1, 2]}, + uxgrid=grid, + ) + + out = da.remap.nearest_neighbor( + destination_grid=grid, + remap_to="nodes", + backend="yac", + yac_method="nnn", + yac_options={"n": 1}, + ) + + assert out.shape == da.shape + np.testing.assert_array_equal(out.values, da.values) diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index ebf74ffa4..28d6f1e84 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -11,6 +11,15 @@ from uxarray.remap.inverse_distance_weighted import _inverse_distance_weighted_remap from uxarray.remap.nearest_neighbor import _nearest_neighbor_remap +_VALID_BACKENDS = ("uxarray", "yac") + + +def _validate_backend(backend: str) -> None: + if backend not in _VALID_BACKENDS: + raise ValueError( + f"Invalid backend '{backend}'. Expected one of {_VALID_BACKENDS}." + ) + class RemapAccessor: """Expose remapping methods on UxDataArray and UxDataset objects.""" @@ -27,17 +36,33 @@ def __repr__(self) -> str: + " • inverse_distance_weighted(destination_grid, remap_to='faces', power=2, k=8)\n" ) - def __call__(self, *args, **kwargs) -> UxDataArray | UxDataset: + def __call__( + self, + *args, + backend: str = "uxarray", + yac_method: str | None = None, + yac_options: dict | None = None, + **kwargs, + ) -> UxDataArray | UxDataset: """ Shortcut for nearest-neighbor remapping. Calling `.remap(...)` with no explicit method will invoke `nearest_neighbor(...)`. """ - return self.nearest_neighbor(*args, **kwargs) + nn_kwargs: dict = {"backend": backend, "yac_options": yac_options} + if yac_method is not None: + nn_kwargs["yac_method"] = yac_method + return self.nearest_neighbor(*args, **nn_kwargs, **kwargs) def nearest_neighbor( - self, destination_grid: Grid, remap_to: str = "faces", **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + backend: str = "uxarray", + yac_method: str | None = "nnn", + yac_options: dict | None = None, + **kwargs, ) -> UxDataArray | UxDataset: """ Perform nearest-neighbor remapping. @@ -51,16 +76,40 @@ def nearest_neighbor( remap_to : {'nodes', 'edges', 'faces'}, default='faces' Which grid element receives the remapped values. + backend : {'uxarray', 'yac'}, default='uxarray' + Remapping backend to use. When set to 'yac', requires YAC to be + available on PYTHONPATH. + yac_method : {'nnn', 'conservative'}, optional + YAC interpolation method. Defaults to 'nnn' when backend='yac'. + yac_options : dict, optional + YAC interpolation configuration options. + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + from uxarray.remap.yac import _yac_remap + + yac_kwargs = yac_options or {} + return _yac_remap( + self.ux_obj, destination_grid, remap_to, yac_method, yac_kwargs + ) return _nearest_neighbor_remap(self.ux_obj, destination_grid, remap_to) def inverse_distance_weighted( - self, destination_grid: Grid, remap_to: str = "faces", power=2, k=8, **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + power=2, + k=8, + backend: str = "uxarray", + yac_method: str | None = None, + yac_options: dict | None = None, + **kwargs, ) -> UxDataArray | UxDataset: """ Perform inverse-distance-weighted (IDW) remapping. @@ -80,18 +129,39 @@ def inverse_distance_weighted( k : int, default=8 Number of nearest source points to include in the weighted average. + backend : {'uxarray', 'yac'}, default='uxarray' + Remapping backend to use. When set to 'yac', requires YAC to be + available on PYTHONPATH. + yac_method : {'nnn', 'conservative'}, optional + YAC interpolation method. Required when backend='yac'. + yac_options : dict, optional + YAC interpolation configuration options. + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + raise NotImplementedError( + "inverse_distance_weighted with backend='yac' is not implemented. " + "The YAC backend currently supports only 'nnn' and 'conservative' " + "methods and will not perform inverse-distance-weighted remapping. " + "Use backend='uxarray' for IDW, or choose a different remapping " + "method that is supported by YAC." + ) return _inverse_distance_weighted_remap( self.ux_obj, destination_grid, remap_to, power, k ) def bilinear( - self, destination_grid: Grid, remap_to: str = "faces", **kwargs + self, + destination_grid: Grid, + remap_to: str = "faces", + backend: str = "uxarray", + **kwargs, ) -> UxDataArray | UxDataset: """ Perform bilinear remapping. @@ -103,10 +173,23 @@ def bilinear( remap_to : {'nodes', 'edges', 'faces'}, default='faces' Which grid element receives the remapped values. + backend : {'uxarray'}, default='uxarray' + Remapping backend to use. The YAC backend does not support bilinear + remapping; use ``backend='uxarray'`` (the default). + Returns ------- UxDataArray or UxDataset A new object with data mapped onto `destination_grid`. """ + _validate_backend(backend) + if backend == "yac": + raise NotImplementedError( + "bilinear with backend='yac' is not implemented. " + "The YAC backend currently supports only 'nnn' and 'conservative' " + "methods and will not perform bilinear remapping. " + "Use backend='uxarray' for bilinear, or choose a different remapping " + "method that is supported by YAC." + ) return _bilinear(self.ux_obj, destination_grid, remap_to) diff --git a/uxarray/remap/yac.py b/uxarray/remap/yac.py new file mode 100644 index 000000000..1a6833bb7 --- /dev/null +++ b/uxarray/remap/yac.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import importlib +import importlib.util +import sys +from dataclasses import dataclass +from pathlib import Path +from types import ModuleType +from typing import Any +from uuid import uuid4 + +import numpy as np + +import uxarray.core.dataarray +from uxarray.remap.utils import ( + LABEL_TO_COORD, + _assert_dimension, + _construct_remapped_ds, + _get_remap_dims, + _to_dataset, +) + + +class YacNotAvailableError(RuntimeError): + """Raised when the YAC backend is requested but unavailable.""" + + +@dataclass +class _YacOptions: + method: str + kwargs: dict[str, Any] + + +def _load_yac_core_from_file() -> ModuleType | None: + if "yac.core" in sys.modules: + return sys.modules["yac.core"] + + for path_entry in sys.path: + pkg_dir = Path(path_entry) / "yac" + if not pkg_dir.is_dir(): + continue + + matches = sorted(pkg_dir.glob("core*.so")) + if not matches: + matches = sorted(pkg_dir.glob("core*.pyd")) + if not matches: + continue + + pkg = sys.modules.get("yac") + if pkg is None: + pkg = ModuleType("yac") + sys.modules["yac"] = pkg + pkg.__path__ = [str(pkg_dir)] + + spec = importlib.util.spec_from_file_location("yac.core", matches[0]) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules["yac.core"] = module + spec.loader.exec_module(module) + setattr(pkg, "core", module) + return module + + return None + + +def _import_yac(): + module = _load_yac_core_from_file() + if module is not None: + return module + + try: + return importlib.import_module("yac.core") + except Exception as exc: # pragma: no cover - fallback depends on local install + raise YacNotAvailableError( + "YAC backend requested but 'yac.core' is not available. " + "Build YAC with Python bindings and ensure its site-packages and " + "shared libraries are discoverable." + ) from exc + + +def _normalize_yac_method(yac_method: str | None) -> _YacOptions: + if not yac_method: + raise ValueError( + "backend='yac' requires yac_method to be set to 'nnn' or 'conservative'." + ) + method = yac_method.lower() + if method not in {"nnn", "conservative"}: + raise ValueError(f"Unsupported YAC method: {yac_method!r}") + return _YacOptions(method=method, kwargs={}) + + +def _get_location(yac_core, dim: str): + mapping = { + "n_face": yac_core.yac_location.YAC_LOC_CELL, + "n_node": yac_core.yac_location.YAC_LOC_CORNER, + "n_edge": yac_core.yac_location.YAC_LOC_EDGE, + } + try: + return mapping[dim] + except KeyError as exc: + raise ValueError(f"Unsupported remap dimension for YAC: {dim!r}") from exc + + +def _coerce_enum(enum_type, value: Any): + if not isinstance(value, str): + return value + + normalized = value.upper() + for member in enum_type: + if member.name == normalized or member.name.endswith(f"_{normalized}"): + return member + + raise ValueError(f"Unsupported value {value!r} for enum {enum_type.__name__}.") + + +class _YacRemapper: + def __init__( + self, + src_grid, + tgt_grid, + src_dim: str, + tgt_dim: str, + yac_method: str, + yac_kwargs: dict[str, Any], + ): + yac_core = _import_yac() + self._src_location = _get_location(yac_core, src_dim) + self._tgt_location = _get_location(yac_core, tgt_dim) + + define_edges = "n_edge" in (src_dim, tgt_dim) + unique = uuid4().hex + self._src_grid = yac_core.BasicGrid.from_uxgrid( + f"uxarray_src_{unique}", + src_grid, + def_edges=define_edges, + ) + self._tgt_grid = yac_core.BasicGrid.from_uxgrid( + f"uxarray_tgt_{unique}", + tgt_grid, + def_edges=define_edges, + ) + + self._src_field = yac_core.InterpField( + self._src_grid.add_coordinates(self._src_location) + ) + self._tgt_field = yac_core.InterpField( + self._tgt_grid.add_coordinates(self._tgt_location) + ) + + stack = yac_core.InterpolationStack() + if yac_method == "nnn": + weight_type = _coerce_enum( + yac_core.yac_interp_nnn_weight_type, + yac_kwargs.get("reduction_type", yac_kwargs.get("nnn_type")), + ) + if weight_type is None: + weight_type = yac_core.yac_interp_nnn_weight_type.YAC_INTERP_NNN_AVG + stack.add_nnn( + nnn_type=weight_type, + n=yac_kwargs.get("n", 1), + max_search_distance=yac_kwargs.get("max_search_distance", 0.0), + scale=yac_kwargs.get("scale", 1.0), + ) + elif yac_method == "conservative": + normalisation = _coerce_enum( + yac_core.yac_interp_method_conserv_normalisation, + yac_kwargs.get("normalisation"), + ) + if normalisation is None: + normalisation = yac_core.yac_interp_method_conserv_normalisation.YAC_INTERP_CONSERV_DESTAREA + stack.add_conservative( + order=yac_kwargs.get("order", 1), + enforced_conserv=yac_kwargs.get("enforced_conserv", False), + partial_coverage=yac_kwargs.get("partial_coverage", False), + normalisation=normalisation, + ) + + self._weights = yac_core.compute_weights( + stack, + self._src_field, + self._tgt_field, + ) + self._interpolations: dict[int, Any] = {} + self._src_size = self._src_grid.get_data_size(self._src_location) + self._tgt_size = self._tgt_grid.get_data_size(self._tgt_location) + + def apply(self, values: np.ndarray) -> np.ndarray: + """Apply the pre-computed interpolation weights to *values*. + + The interpolation method (NNN or conservative) is determined by + *yac_method* passed to the constructor and is fixed for the lifetime of + this remapper instance. This method simply executes the weight + application; it does not select or alter the interpolation algorithm. + + Parameters + ---------- + values : np.ndarray + 1-D or 2-D array of source-grid values. The trailing dimension must + equal the number of source points registered with YAC + (``self._src_size``). When 2-D, the leading dimension is treated as + the YAC collection size and is remapped in one batched call. + + Returns + ------- + np.ndarray + Array of remapped values on the destination grid with the same + number of leading collections as the input. + """ + values = np.ascontiguousarray(values, dtype=np.float64) + if values.ndim == 1: + values = values.reshape(1, -1) + elif values.ndim != 2: + raise ValueError( + f"YAC remap expects a 1-D or 2-D array, got {values.ndim}-D input." + ) + if values.shape[1] != self._src_size: + raise ValueError( + f"YAC remap expects {self._src_size} values, got {values.shape[1]}." + ) + + collection_size = values.shape[0] + interpolation = self._interpolations.get(collection_size) + if interpolation is None: + interpolation = self._weights.get_interpolation( + collection_size=collection_size + ) + self._interpolations[collection_size] = interpolation + + out = interpolation(values) + return np.asarray(out, dtype=np.float64) + + +def _yac_remap(source, destination_grid, remap_to: str, yac_method: str, yac_kwargs): + _assert_dimension(remap_to) + destination_dim = LABEL_TO_COORD[remap_to] + options = _normalize_yac_method(yac_method) + options.kwargs.update(yac_kwargs or {}) + ds, is_da, name = _to_dataset(source) + dims_to_remap = _get_remap_dims(ds) + + if options.method == "conservative": + if destination_dim != "n_face": + raise ValueError( + "YAC conservative remapping requires the destination to be " + "face-centered (remap_to='faces'). " + f"Got remap_to={remap_to!r} which maps to dimension {destination_dim!r}." + ) + non_face_src = dims_to_remap - {"n_face"} + if non_face_src: + raise ValueError( + "YAC conservative remapping requires all source data to be " + f"face-centered (dimension 'n_face'). " + f"Found non-face source dimension(s): {non_face_src}. " + "Use yac_method='nnn' for node- or edge-centered data." + ) + remappers: dict[str, _YacRemapper] = {} + remapped_vars = {} + + for src_dim in dims_to_remap: + remappers[src_dim] = _YacRemapper( + ds.uxgrid, + destination_grid, + src_dim, + destination_dim, + options.method, + options.kwargs, + ) + + for var_name, da in ds.data_vars.items(): + src_dim = next((d for d in da.dims if d in dims_to_remap), None) + if src_dim is None: + remapped_vars[var_name] = da + continue + + other_dims = [d for d in da.dims if d != src_dim] + da_t = da.transpose(*other_dims, src_dim) + src_values = np.asarray(da_t.values) + flat_src = src_values.reshape(-1, src_values.shape[-1]) + remapper = remappers[src_dim] + out_flat = remapper.apply(flat_src) + + out_shape = src_values.shape[:-1] + (remapper._tgt_size,) + out_values = out_flat.reshape(out_shape) + coords = {dim: da.coords[dim] for dim in other_dims if dim in da.coords} + da_out = uxarray.core.dataarray.UxDataArray( + out_values, + dims=other_dims + [destination_dim], + coords=coords, + name=da.name, + attrs=da.attrs, + uxgrid=destination_grid, + ) + remapped_vars[var_name] = da_out + + ds_remapped = _construct_remapped_ds( + source, remapped_vars, destination_grid, remap_to + ) + return ds_remapped[name] if is_da else ds_remapped