Source code for xarray_units.accessor

# Public names exported by this module. ``DataArray`` is not defined here at
# runtime; it is the class imported from xarray further down in this file.
__all__ = ["DataArray", "Units", "units"]


# standard library
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, replace
from functools import wraps
from typing import TYPE_CHECKING, Callable, Generic, Union


# dependencies
from typing_extensions import Concatenate, Self
from xarray import DataArray, register_dataarray_accessor  # type: ignore
from . import operator, quantity
from .utils import UNITS, P, TDataArray


# type hints
if TYPE_CHECKING:

    # Static-analysis-only stub: ``TYPE_CHECKING`` is False at runtime, so this
    # class is never created when the module is imported. It shadows the
    # imported ``DataArray`` for type checkers only, declaring a ``units``
    # attribute typed as the accessor defined in this module.
    class DataArray(DataArray):
        """Typing stub of ``xarray.DataArray`` with a ``units`` attribute."""

        units: "Units[Self]"


def to_method(
    func: Callable[Concatenate[TDataArray, P], TDataArray]
) -> Callable[Concatenate["Units[TDataArray]", P], TDataArray]:
    """Turn a DataArray-first function into a method of the units accessor.

    The returned callable takes the accessor in place of the DataArray and
    forwards every other positional and keyword argument to ``func``.

    Args:
        func: Function whose first argument is a DataArray.

    Returns:
        Wrapper that applies ``func`` either to the accessed DataArray
        (when the accessor's ``of`` is None) or to each selected coordinate,
        and that re-wraps the result in a new accessor while the accessor's
        ``chain`` is greater than 1.

    """

    @wraps(func)
    def wrapper(
        units: "Units[TDataArray]",
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> TDataArray:
        target = units.of
        result = units.accessed

        if target is None:
            # No coordinate selected: operate on the DataArray itself.
            result = func(result, *args, **kwargs)
        else:
            # A hashable ``of`` (including a tuple) names one coordinate;
            # anything else is iterated as a collection of names.
            names = (target,) if isinstance(target, Hashable) else target

            # Look up every selected coordinate before any of them is updated.
            selected: dict[Hashable, DataArray] = {
                name: result.coords[name] for name in names  # type: ignore
            }

            for name, coord in selected.items():
                updated = func(coord, *args, **kwargs)  # type: ignore
                result = result.assign_coords({name: updated})  # type: ignore

        remaining = units.chain
        if remaining <= 1:
            return result

        # Chain not exhausted yet: hand back an accessor with one fewer link.
        return Units(result, chain=remaining - 1, of=target)  # type: ignore

    return wrapper


def units(
    accessed: TDataArray,
    /,
    *,
    chain: int = 1,
    of: Union[Sequence[Hashable], Hashable] = None,
) -> "Units[TDataArray]":
    """Build a units accessor for a DataArray.

    Args:
        accessed: DataArray the accessor will operate on (positional-only).
        chain: Length of the method chain. When greater than 1, each accessor
            method (or operation) returns an accessor of the resulting
            DataArray with ``chain - 1`` instead of the DataArray itself.
            This shortens chained notation, but static type checking may not
            work correctly in the middle of such a chain.
        of: Coordinate name(s) of the DataArray to operate on. When given,
            each accessor method is applied to the selected coordinates
            rather than to the DataArray, and the DataArray with the updated
            coordinates is returned.

    Returns:
        Units accessor of the DataArray carrying the given options.

    Raises:
        TypeError: Raised if the accessor options are not valid.

    """
    accessor = Units(accessed, chain=chain, of=of)
    return accessor
@register_dataarray_accessor(UNITS)
@dataclass
class Units(Generic[TDataArray]):
    """DataArray accessor that exposes unit handling as methods.

    Each quantity and operator function wrapped below becomes a method of the
    accessor; the arithmetic and comparison ones are also bound to the
    matching special methods, so the accessor can be used with operators.

    Args:
        accessed: DataArray the accessor operates on.
        chain: Length of the method chain. When greater than 1, each accessor
            method (or operation) returns an accessor of the resulting
            DataArray with ``chain - 1`` instead of the DataArray itself.
            This shortens chained notation, but static type checking may not
            work correctly in the middle of such a chain.
        of: Coordinate name(s) of the DataArray to operate on. When given,
            each accessor method is applied to the selected coordinates
            rather than to the DataArray, and the DataArray with the updated
            coordinates is returned.

    Raises:
        TypeError: Raised if the accessor options are not valid.

    """

    accessed: TDataArray
    """DataArray to be accessed."""

    chain: int = 1
    """Length of method chain."""

    of: Union[Sequence[Hashable], Hashable] = None
    """Coordinate(s) of the DataArray to be accessed."""

    # methods wrapping quantity functions
    apply = to_method(quantity.apply)
    decompose = to_method(quantity.decompose)
    format = to_method(quantity.format)
    like = to_method(quantity.like)
    set = to_method(quantity.set)
    to = to_method(quantity.to)
    unset = to_method(quantity.unset)

    # method wrapping an operator function without a special-method alias
    take = to_method(operator.take)

    # arithmetic operators, each bound together with its special method
    mul = __mul__ = to_method(operator.mul)
    pow = __pow__ = to_method(operator.pow)
    matmul = __matmul__ = to_method(operator.matmul)
    truediv = __truediv__ = to_method(operator.truediv)
    add = __add__ = to_method(operator.add)
    sub = __sub__ = to_method(operator.sub)
    floordiv = __floordiv__ = to_method(operator.floordiv)
    mod = __mod__ = to_method(operator.mod)

    # comparison operators, each bound together with its special method
    lt = __lt__ = to_method(operator.lt)
    le = __le__ = to_method(operator.le)
    eq = __eq__ = to_method(operator.eq)  # type: ignore
    ne = __ne__ = to_method(operator.ne)  # type: ignore
    ge = __ge__ = to_method(operator.ge)
    gt = __gt__ = to_method(operator.gt)

    def __post_init__(self) -> None:
        """Validate the options of the accessor.

        Raises:
            TypeError: If ``chain`` is not a positive integer, or if ``of``
                is neither hashable nor a sequence.

        """
        chain_is_valid = isinstance(self.chain, int) and self.chain > 0

        if not chain_is_valid:
            raise TypeError("Chain must be a positive integer.")

        # None, the default, is hashable and therefore passes this check.
        of_is_valid = isinstance(self.of, (Hashable, Sequence))

        if not of_is_valid:
            raise TypeError("Of must be a (sequence of) hashable.")

    def __call__(
        self,
        /,
        *,
        chain: int = 1,
        of: Union[Sequence[Hashable], Hashable] = None,
    ) -> Self:
        """Return a copy of the accessor with new options.

        Both options are always replaced: one that is not passed is reset to
        its default rather than carried over from the current accessor.

        Args:
            chain: Length of the method chain. When greater than 1, each
                accessor method (or operation) returns an accessor of the
                resulting DataArray with ``chain - 1`` instead of the
                DataArray itself. This shortens chained notation, but static
                type checking may not work correctly in the middle of it.
            of: Coordinate name(s) of the DataArray to operate on. When
                given, each accessor method is applied to the selected
                coordinates rather than to the DataArray, and the DataArray
                with the updated coordinates is returned.

        Returns:
            Units accessor of the same DataArray with the updated options.

        Raises:
            TypeError: Raised if the accessor options are not valid.

        """
        updated = replace(self, chain=chain, of=of)
        return updated