Source code for ndtools.comparison

__all__ = ["TotalEquality", "TotalOrdering"]


# standard library
from collections.abc import Callable
from typing import Any, TypeVar


# dependencies
from . import operators as op
from .abc import Equatable, Orderable


# type hints
T = TypeVar("T")


# constants
MISSINGS_EQUALITY = {
    "__eq__": {
        "__ne__": op.ne_by_eq,
    },
    "__ne__": {
        "__eq__": op.eq_by_ne,
    },
}
MISSINGS_ORDERING = {
    "__ge__": {
        "__gt__": op.gt_by_ge,
        "__le__": op.le_by_ge,
        "__lt__": op.lt_by_ge,
    },
    "__gt__": {
        "__ge__": op.ge_by_gt,
        "__le__": op.le_by_gt,
        "__lt__": op.lt_by_gt,
    },
    "__le__": {
        "__gt__": op.gt_by_le,
        "__ge__": op.ge_by_le,
        "__lt__": op.lt_by_le,
    },
    "__lt__": {
        "__gt__": op.gt_by_lt,
        "__ge__": op.ge_by_lt,
        "__le__": op.le_by_lt,
    },
}


[docs] class TotalEquality(Equatable): """Implement missing equality operations for multidimensional arrays. Raises: ValueError: Raised if none of the equality operators (==, !=) is defined. Examples: :: import numpy as np from ndtools import TotalEquality class Even(TotalEquality): def __eq__(self, array): return array % 2 == 0 Even() == np.arange(3) # -> array([True, False, True]) np.arange(3) == Even() # -> array([True, False, True]) Even() != np.arange(3) # -> array([False, True, False]) np.arange(3) != Even() # -> array([False, True, False]) """ __eq__: Callable[..., Any] __ne__: Callable[..., Any] __array_ufunc__: Callable[..., Any] def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) defined = [name for name in MISSINGS_EQUALITY if has_usermethod(cls, name)] if not defined: raise ValueError("Define at least one equality operator (==, !=).") for name, operator in MISSINGS_EQUALITY[defined[0]].items(): if not has_usermethod(cls, name): setattr(cls, name, operator)
[docs] class TotalOrdering(Orderable): """Implement missing ordering operations for multidimensional arrays. Raises: ValueError: Raise if none of the ordering operator (>=, >, <=, <) is defined. Examples: :: import numpy as np from dataclasses import dataclass from ndtools import TotalOrdering @dataclass class Range(TotalOrdering): lower: float upper: float def __eq__(self, array): return (array >= self.lower) & (array < self.upper) def __ge__(self, array): return array < self.upper Range(1, 2) == np.arange(3) # -> array([False, True, False]) np.arange(3) == Range(1, 2) # -> array([False, True, False]) Range(1, 2) >= np.arange(3) # -> array([True, True, False]) np.arange(3) <= Range(1, 2) # -> array([True, True, False]) """ __eq__: Callable[..., Any] __ge__: Callable[..., Any] __gt__: Callable[..., Any] __le__: Callable[..., Any] __lt__: Callable[..., Any] __ne__: Callable[..., Any] __array_ufunc__: Callable[..., Any] def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) defined = [name for name in MISSINGS_EQUALITY if has_usermethod(cls, name)] if not defined: raise ValueError("Define at least one equality operator (==, !=).") for name, operator in MISSINGS_EQUALITY[defined[0]].items(): if not has_usermethod(cls, name): setattr(cls, name, operator) defined = [name for name in MISSINGS_ORDERING if has_usermethod(cls, name)] if not defined: raise ValueError("Define at least one ordering operator (>=, >, <=, <).") for name, operator in MISSINGS_ORDERING[defined[0]].items(): if not has_usermethod(cls, name): setattr(cls, name, operator)
def has_usermethod(obj: Any, name: str, /) -> bool: """Check if an object has a user-defined method with given name.""" return ( hasattr(obj, name) and not is_abstractmethod(getattr(obj, name)) and not is_objectmethod(getattr(obj, name)) ) def is_abstractmethod(method: Any, /) -> bool: """Check if given method is an abstract method.""" return bool(getattr(method, "__isabstractmethod__", None)) def is_objectmethod(method: Any, /) -> bool: """Check if given method is defined in the object class.""" return method is getattr(object, method.__name__, None)