# Names exported by a star import of this module.
# NOTE(review): "format" is listed here, but no ``format`` definition or import
# is visible in this file -- confirm it exists elsewhere; otherwise a star
# import fails on the missing name.
__all__ = [
    "apply",
    "decompose",
    "format",
    "like",
    "set",
    "to",
    "unset",
]
# standard library
from types import MethodType, MethodWrapperType
from typing import Any, Optional
# dependencies
from astropy.units import Equivalency, Quantity
from xarray import DataArray
from .utils import (
    TESTER,
    UNITS,
    TDataArray,
    UnitsConversionError,
    UnitsExistError,
    UnitsLike,
    unitsof,
)
def apply(
    da: TDataArray,
    method: str,
    /,
    *args: Any,
    **kwargs: Any,
) -> TDataArray:
    """Apply a method of Astropy Quantity to a DataArray.

    When called from an accessor, it runs ``apply(accessed, method, ...)``.

    Args:
        da: Input DataArray with units.
        method: Method (or property) name of Astropy Quantity.
        *args: Positional arguments of the method.
        **kwargs: Keyword arguments of the method.

    Returns:
        DataArray with the method (or property) applied.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
            The original exception is attached as ``__cause__``.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    See Also:
        https://docs.astropy.org/en/stable/units/quantity.html

    """
    units = unitsof(da, strict=True)

    def per_block(block: TDataArray) -> TDataArray:
        # Replace the block's data with the outcome of the Quantity method.
        data = apply_any(block, units, method, *args, **kwargs)
        return block.copy(data=data)

    # Run the method once on TESTER first: a failure is reported before
    # ``da`` is touched, and the resulting units are read from it below.
    try:
        test = apply_any(TESTER, units, method, *args, **kwargs)
    except Exception as error:
        raise UnitsConversionError(error) from error

    try:
        result = da.map_blocks(per_block)
    except Exception as error:
        raise UnitsConversionError(error) from error

    # ``set`` is this module's units setter (it shadows the builtin here).
    return set(result, unitsof(test, strict=True), overwrite=True)
def apply_any(
    data: Any,
    units: UnitsLike,
    method: str,
    /,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Apply a method of Astropy Quantity to any data.

    The data are wrapped as ``Quantity(data, units)`` and the attribute
    named ``method`` is looked up on it.

    Args:
        data: Any data accepted by the Quantity constructor.
        units: Units given to the Quantity constructor.
        method: Method (or property) name of Astropy Quantity.
        *args: Positional arguments of the method.
        **kwargs: Keyword arguments of the method.

    Returns:
        Result of calling the attribute with the given arguments if it is
        a bound method or a method wrapper; otherwise the attribute itself
        (the arguments are then unused).

    """
    quantity = Quantity(data, units)
    member = getattr(quantity, method)

    # Anything that is not a bound method / method wrapper is treated as
    # a plain value (e.g. a property) and handed back without calling it.
    if not isinstance(member, (MethodType, MethodWrapperType)):
        return member

    return member(*args, **kwargs)
def decompose(da: TDataArray, /) -> TDataArray:
    """Convert a DataArray with units to decomposed ones.

    This is a shorthand for ``apply(da, "decompose")``.
    When called from an accessor, it runs ``decompose(accessed)``.

    Args:
        da: Input DataArray with units.

    Returns:
        DataArray with the decomposed units.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    See Also:
        https://docs.astropy.org/en/stable/units/decomposing_and_composing.html

    """
    decomposed = apply(da, "decompose")
    return decomposed
def like(
    da: TDataArray,
    other: DataArray,
    /,
    equivalencies: Optional[Equivalency] = None,
) -> TDataArray:
    """Convert a DataArray with units to those of the other.

    The units are read from ``other`` and the conversion is delegated to
    ``apply(da, "to", ...)``.
    When called from an accessor, it runs ``like(accessed, other, ...)``.

    Args:
        da: Input DataArray with units.
        other: DataArray with units to which the input is converted.
        equivalencies: Optional Astropy equivalencies. The value
            (including the default ``None``) is forwarded unchanged.

    Returns:
        DataArray with the converted units.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    See Also:
        https://docs.astropy.org/en/stable/units/quantity.html

    """
    # strict=True: ``other`` is required to carry units.
    target_units = unitsof(other, strict=True)
    converted = apply(da, "to", target_units, equivalencies)
    return converted
def set(
    da: TDataArray,
    units: UnitsLike,
    /,
    *,
    overwrite: bool = False,
) -> TDataArray:
    """Set units to a DataArray.

    When called from an accessor, it runs ``set(accessed, units, ...)``.

    Args:
        da: Input DataArray.
        units: Units to be set to the input.
        overwrite: Whether to overwrite existing units.

    Returns:
        DataArray with given units in ``attrs["units"]``.

    Raises:
        UnitsExistError: Raised if units already exist.
            Not raised when ``overwrite`` is ``True``.
        UnitsNotValidError: Raised if units are not valid.

    """
    # Existing units are only looked up when overwriting is not allowed.
    if overwrite or unitsof(da) is None:
        new_attrs = {UNITS: units}
        return da.assign_attrs(new_attrs)

    raise UnitsExistError(repr(da))
def to(
    da: TDataArray,
    units: UnitsLike,
    /,
    equivalencies: Optional[Equivalency] = None,
) -> TDataArray:
    """Convert a DataArray with units to other units.

    This is a shorthand for ``apply(da, "to", units, equivalencies)``.
    When called from an accessor, it runs ``to(accessed, units, ...)``.

    Args:
        da: Input DataArray with units.
        units: Units to which the input is converted.
        equivalencies: Optional Astropy equivalencies. The value
            (including the default ``None``) is forwarded unchanged.

    Returns:
        DataArray with the converted units.

    Raises:
        UnitsConversionError: Raised if units cannot be converted.
        UnitsNotFoundError: Raised if units are not found.
        UnitsNotValidError: Raised if units are not valid.

    See Also:
        https://docs.astropy.org/en/stable/units/quantity.html

    """
    converted = apply(da, "to", units, equivalencies)
    return converted
def unset(da: TDataArray, /) -> TDataArray:
    """Remove units from a DataArray.

    When called from an accessor, it runs ``unset(accessed)``.

    Args:
        da: Input DataArray.

    Returns:
        DataArray with units removed. It is a copy made with
        ``da.copy(data=da.data)``; a missing units attribute is ignored.

    """
    stripped = da.copy(data=da.data)
    # Default of None makes the removal a no-op when no units are set.
    stripped.attrs.pop(UNITS, None)
    return stripped