Source code for ndtools.comparison

__all__ = ["total_equality", "total_ordering"]


# standard library
from typing import Any, TypeVar


# dependencies
import numpy as np
from . import operators as op


# type hints
T = TypeVar("T")


# constants
MISSINGS_EQUALITY = {
    "__eq__": {
        "__ne__": op.ne_by_eq,
    },
    "__ne__": {
        "__eq__": op.eq_by_ne,
    },
}
MISSINGS_ORDERING = {
    "__ge__": {
        "__gt__": op.gt_by_ge,
        "__le__": op.le_by_ge,
        "__lt__": op.lt_by_ge,
    },
    "__gt__": {
        "__ge__": op.ge_by_gt,
        "__le__": op.le_by_gt,
        "__lt__": op.lt_by_gt,
    },
    "__le__": {
        "__gt__": op.gt_by_le,
        "__ge__": op.ge_by_le,
        "__lt__": op.lt_by_le,
    },
    "__lt__": {
        "__gt__": op.gt_by_lt,
        "__ge__": op.ge_by_lt,
        "__le__": op.le_by_lt,
    },
}


def has_userattr(obj: Any, name: str, /) -> bool:
    """Check if an object has a used-defined attribute with given name."""
    return getattr(obj, name, None) is not getattr(object, name, None)


[docs] def total_equality(cls: type[T], /) -> type[T]: """Class decorator that fills in missing multidimensional equality methods. Args: cls: Class to be decorated. Returns: The same class with missing multidimensional equality methods. Examples: :: import numpy as np from ndtools import total_equality @total_equality class Even: def __eq__(self, array): return ~((array % 2).astype(bool)) result = (np.arange(3) == Even()) expected = np.array([True, False, True]) assert (result == expected).all() """ defined = [name for name in MISSINGS_EQUALITY if has_userattr(cls, name)] if not defined: raise ValueError("Define at least one equality operator (==, !=).") for name, operator in MISSINGS_EQUALITY[defined[0]].items(): if not has_userattr(cls, name): setattr(cls, name, operator) def __array_ufunc__( self: Any, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any, ) -> Any: if ufunc is np.equal: return self == inputs[0] elif ufunc is np.not_equal: return self != inputs[0] else: return NotImplemented setattr(cls, "__array_ufunc__", __array_ufunc__) return cls
[docs] def total_ordering(cls: type[T], /) -> type[T]: """Class decorator that fills in missing multidimensional ordering methods. Args: cls: Class to be decorated. Returns: The same class with missing multidimensional ordering methods. Examples: :: import numpy as np from dataclasses import dataclass from ndtools import total_ordering @dataclass @total_ordering class Interval: lower: float upper: float def __eq__(self, array): return (array >= self.lower) & (array < self.upper) def __ge__(self, array): return array < self.upper result = (np.arange(3) == Interval(1, 2)) expected = np.array([False, True, False]) assert (result == expected).all() result = (np.arange(3) < Interval(1, 2)) expected = np.array([True, False, False]) assert (result == expected).all() """ cls = total_equality(cls) defined = [name for name in MISSINGS_ORDERING if has_userattr(cls, name)] if not defined: raise ValueError("Define at least one ordering operator (>=, >, <=, <).") for name, operator in MISSINGS_ORDERING[defined[0]].items(): if not has_userattr(cls, name): setattr(cls, name, operator) def __array_ufunc__( self: Any, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any, ) -> Any: if ufunc is np.equal: return self == inputs[0] elif ufunc is np.greater: return self < inputs[0] elif ufunc is np.greater_equal: return self <= inputs[0] elif ufunc is np.less: return self > inputs[0] elif ufunc is np.less_equal: return self >= inputs[0] elif ufunc is np.not_equal: return self != inputs[0] else: return NotImplemented setattr(cls, "__array_ufunc__", __array_ufunc__) return cls