# Source code for xarray_units.operator

__all__ = [
    # generic entry point taking an operator name
    "take",
    # any-units operators: *, **, @, /
    "mul", "pow", "matmul", "truediv",
    # same-units numerical operators: +, -, //, %
    "add", "sub", "floordiv", "mod",
    # same-units relational operators: <, <=, ==, !=, >=, >
    "lt", "le", "eq", "ne", "ge", "gt",
]


# standard library
import operator as opr
from typing import Any, Literal, Union, get_args


# dependencies
from astropy.units import Quantity
from xarray import DataArray
from .quantity import apply_any, set, to, unset
from .utils import TESTER, TDataArray, UnitsConversionError, unitsof


# type hints
# Operators whose operands need not share units; the units of the result
# are derived from both operands by a dry run in ``take``.
AnyUnitsOperator = Literal["mul", "pow", "matmul", "truediv"]  # * ** @ /

# Operators whose right operand is converted to the units of the left
# operand (by ``take``) before the operation is performed.
SameUnitsOperator = Literal[
    "add", "sub", "floordiv", "mod",  # + - // %
    "lt", "le", "eq", "ne", "ge", "gt",  # < <= == != >= >
]

# Any operator name accepted by ``take``.
Operator = Union[AnyUnitsOperator, SameUnitsOperator]


def take(left: TDataArray, operator: Operator, right: Any, /) -> TDataArray:
    """Perform an operation between left and right data considering units.

    The operation is first dry-run on ``TESTER`` quantities carrying the
    units of both operands. This validates the units and determines the
    units of the result. The real operation is then applied to the data.

    Args:
        left: DataArray with units on the left side of the operator.
        operator: Name of the operator (e.g. ``"add"``, ``"gt"``).
        right: Any data on the right side of the operator.

    Returns:
        DataArray of the result of the operation.
        Units are the same as ``left`` in a numerical operation
        (e.g. ``"add"``) or nothing in a relational operation
        (e.g. ``"gt"``).

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
            The underlying error is chained as its ``__cause__``.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    left_units = unitsof(left, strict=True)
    right_units = unitsof(right, strict=False)

    # Choose the dunder method and its argument for the units-only dry run.
    if operator == "pow":
        # The exponent's value (not only its units) affects the result
        # units, so the actual right data is used here instead of TESTER.
        method = "__pow__"
        args = (Quantity(right, right_units),)
    elif operator == "matmul":
        # NOTE(review): presumably ``*`` stands in for ``@`` because TESTER
        # is not a matrix; both give the same resulting units.
        method = "__mul__"
        args = (Quantity(TESTER, right_units),)
    elif operator in ("eq", "ne"):
        # ``<`` is used instead of ``==``/``!=``, presumably because
        # Quantity equality returns False instead of raising on
        # incompatible units; ``<`` has the same units requirement.
        method = "__lt__"
        args = (Quantity(TESTER, right_units),)
    else:
        method = f"__{operator}__"
        args = (Quantity(TESTER, right_units),)

    try:
        test = apply_any(TESTER, left_units, method, *args)
    except Exception as error:
        raise UnitsConversionError(error) from error

    # For same-units operators, express right data in the units of left
    # data so that the plain operation below acts on comparable values.
    if operator in get_args(SameUnitsOperator):
        if isinstance(right, Quantity):
            right = right.to(left_units).value  # type: ignore

        if isinstance(right, DataArray):
            right = to(right, left_units)

    try:
        result = getattr(opr, operator)(left, right)
    except Exception as error:
        raise UnitsConversionError(error) from error

    # The result takes the units of the dry run; if the dry run has none
    # (e.g. a relational operation), the result is returned without units.
    if (units := unitsof(test)) is None:
        return unset(result)

    return set(result, units, overwrite=True)
def mul(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left * right`` with units taken into account.

    Through an accessor this corresponds to ``accessed * right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "mul"
    return take(left, op, right)
def pow(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left ** right`` with units taken into account.

    Through an accessor this corresponds to ``accessed ** right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand (the exponent); any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "pow"
    return take(left, op, right)
def matmul(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left @ right`` with units taken into account.

    Through an accessor this corresponds to ``accessed @ right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "matmul"
    return take(left, op, right)
def truediv(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left / right`` with units taken into account.

    Through an accessor this corresponds to ``accessed / right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "truediv"
    return take(left, op, right)
def add(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left + right`` with units taken into account.

    Through an accessor this corresponds to ``accessed + right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "add"
    return take(left, op, right)
def sub(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left - right`` with units taken into account.

    Through an accessor this corresponds to ``accessed - right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "sub"
    return take(left, op, right)
def floordiv(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left // right`` with units taken into account.

    Through an accessor this corresponds to ``accessed // right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "floordiv"
    return take(left, op, right)
def mod(left: TDataArray, right: Any) -> TDataArray:
    """Compute ``left % right`` with units taken into account.

    Through an accessor this corresponds to ``accessed % right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "mod"
    return take(left, op, right)
def lt(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left < right`` with units taken into account.

    Through an accessor this corresponds to ``accessed < right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "lt"
    return take(left, op, right)
def le(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left <= right`` with units taken into account.

    Through an accessor this corresponds to ``accessed <= right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "le"
    return take(left, op, right)
def eq(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left == right`` with units taken into account.

    Through an accessor this corresponds to ``accessed == right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "eq"
    return take(left, op, right)
def ne(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left != right`` with units taken into account.

    Through an accessor this corresponds to ``accessed != right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "ne"
    return take(left, op, right)
def ge(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left >= right`` with units taken into account.

    Through an accessor this corresponds to ``accessed >= right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "ge"
    return take(left, op, right)
def gt(left: TDataArray, right: Any) -> TDataArray:
    """Evaluate ``left > right`` with units taken into account.

    Through an accessor this corresponds to ``accessed > right``.

    Args:
        left: Left operand; a DataArray that has units.
        right: Right operand; any data.

    Returns:
        DataArray holding the result of the operation.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    """
    op: Operator = "gt"
    return take(left, op, right)