# Source code for xarray_dataclasses.datamodel

"""Submodule for data expression inside the package."""

__all__ = ["DataModel"]


# standard library
from dataclasses import dataclass, field, is_dataclass
from typing import (
    Any,
    Dict,
    Hashable,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)


# dependencies
import numpy as np
import xarray as xr
from typing_extensions import ParamSpec, get_type_hints


# submodules
from .typing import (
    AnyDType,
    AnyField,
    AnyXarray,
    DataClass,
    Dims,
    Role,
    get_annotated,
    get_dataclass,
    get_dims,
    get_dtype,
    get_name,
    get_role,
)


# type hints
# parameter specification used to parametrize ``DataClass`` below
PInit = ParamSpec("PInit")
# either a dataclass type or an object of it
AnyDataClass = Union[Type[DataClass[PInit]], DataClass[PInit]]
# either kind of entry defined in this module (forward references)
AnyEntry = Union["AttrEntry", "DataEntry"]


# runtime classes
class MissingType:
    """Sentinel type whose only instance marks data as missing."""

    _instance = None

    def __new__(cls) -> "MissingType":
        """Return the shared instance, creating it on the first call."""
        existing = cls._instance

        if existing is not None:
            return existing

        created = super().__new__(cls)
        cls._instance = created
        return created

    def __repr__(self) -> str:
        """Return the fixed marker text of the sentinel."""
        return "<MISSING>"


# shared sentinel object; entries in this module test against it with ``is``
MISSING = MissingType()


@dataclass(frozen=True)
class AttrEntry:
    """Model entry that describes one attribute (i.e. metadata)."""

    name: Hashable
    """Key that the attribute is accessed by."""

    tag: Literal["attr", "name"]
    """Kind of the attribute (``attr`` or ``name``)."""

    type: Any = None
    """Type or type hint associated with the attribute."""

    value: Any = MISSING
    """Value of the attribute (``MISSING`` unless one is given)."""

    cast: bool = False
    """Flag that tells whether the value is cast to the type."""

    def __call__(self) -> Any:
        """Return the value stored in the entry.

        Raises:
            ValueError: Raised if the entry holds no value.

        """
        if self.value is not MISSING:
            return self.value

        raise ValueError("Value is missing.")


@dataclass(frozen=True)
class DataEntry:
    """Model entry that describes one data variable or coordinate."""

    name: Hashable
    """Name that the data is accessed by."""

    tag: Literal["coord", "data"]
    """Kind of the data (``coord`` or ``data``)."""

    dims: Dims = cast(Dims, None)
    """Dimensions of the DataArray that the value is converted to."""

    dtype: Optional[AnyDType] = None
    """Data type of the DataArray that the value is converted to."""

    base: Optional[Type[DataClass[Any]]] = None
    """Dataclass used to convert the value to a DataArray, if any."""

    value: Any = MISSING
    """Value of the data (``MISSING`` unless one is given)."""

    cast: bool = True
    """Flag that tells whether the value is cast to the data type."""

    def __post_init__(self) -> None:
        """Take dims, dtype, and name over from the base dataclass."""
        if self.base is not None:
            base_model = DataModel.from_dataclass(self.base)
            primary = base_model.data_vars[0]

            # the dataclass is frozen, so plain assignment is not possible
            assign = object.__setattr__
            assign(self, "dims", primary.dims)
            assign(self, "dtype", primary.dtype)

            name_entries = base_model.names

            if name_entries:
                assign(self, "name", name_entries[0].value)

    def __call__(self, reference: Optional[AnyXarray] = None) -> xr.DataArray:
        """Build a DataArray object from the value of the entry.

        Args:
            reference: DataArray or Dataset object passed on to the
                function that creates the DataArray object.

        Returns:
            DataArray object created from the value.

        Raises:
            ValueError: Raised if the entry holds no value.

        """
        from .dataarray import asdataarray

        if self.value is MISSING:
            raise ValueError("Value is missing.")

        # without a base dataclass, dims and dtype of the entry are used
        if self.base is None:
            return get_typedarray(self.value, self.dims, self.dtype, reference)

        # wrap a raw value in the base dataclass before the conversion
        source = self.value if is_dataclass(self.value) else self.base(self.value)
        return asdataarray(source, reference)


@dataclass(frozen=True)
class DataModel:
    """Data representation (data model) inside the package."""

    entries: Dict[str, AnyEntry] = field(default_factory=dict)
    """Entries of data variable(s) and attribute(s)."""

    @property
    def attrs(self) -> List[AttrEntry]:
        """Return a list of attribute entries (tagged ``attr``)."""
        return [v for v in self.entries.values() if v.tag == "attr"]

    @property
    def coords(self) -> List[DataEntry]:
        """Return a list of coordinate entries (tagged ``coord``)."""
        return [v for v in self.entries.values() if v.tag == "coord"]

    @property
    def data_vars(self) -> List[DataEntry]:
        """Return a list of data variable entries (tagged ``data``)."""
        return [v for v in self.entries.values() if v.tag == "data"]

    @property
    def data_vars_items(self) -> List[Tuple[str, DataEntry]]:
        """Return a list of data variable entries with keys."""
        return [(k, v) for k, v in self.entries.items() if v.tag == "data"]

    @property
    def names(self) -> List[AttrEntry]:
        """Return a list of name entries (tagged ``name``)."""
        return [v for v in self.entries.values() if v.tag == "name"]

    @classmethod
    def from_dataclass(cls, dataclass: AnyDataClass[PInit]) -> "DataModel":
        """Create a data model from a dataclass or its object.

        Args:
            dataclass: Dataclass or its object to be converted.

        Returns:
            Data model whose entries are keyed by the field names.
            Fields for which no entry can be created are left out.

        """
        model = cls()
        eval_dataclass(dataclass)

        for field in dataclass.__dataclass_fields__.values():
            # MISSING is stored if the field name is not an attribute
            value = getattr(dataclass, field.name, MISSING)
            entry = get_entry(field, value)

            if entry is not None:
                model.entries[field.name] = entry

        return model
# runtime functions
def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:
    """Evaluate field types of a dataclass.

    Field types given as strings are replaced in place
    with the types obtained by ``get_type_hints``.

    Args:
        dataclass: Dataclass or its object whose fields are evaluated.

    Raises:
        TypeError: Raised if the argument is not a dataclass or its object.

    """
    if not is_dataclass(dataclass):
        raise TypeError("Not a dataclass or its object.")

    fields = dataclass.__dataclass_fields__.values()

    # do nothing if field types are already evaluated
    if not any(isinstance(field.type, str) for field in fields):
        return

    # otherwise, replace field types with evaluated types
    if not isinstance(dataclass, type):
        dataclass = type(dataclass)

    types = get_type_hints(dataclass, include_extras=True)

    for field in fields:
        field.type = types[field.name]


def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
    """Create an entry from a field and its value.

    Args:
        field: Dataclass field whose type determines the role.
        value: Value to be stored in the entry.

    Returns:
        Attribute entry for the attr or name roles, data entry for
        the coord or data roles, and None for any other role.

    """
    role = get_role(field.type)
    name = get_name(field.type, field.name)

    if role is Role.ATTR or role is Role.NAME:
        return AttrEntry(
            name=name,
            tag=role.value,
            value=value,
            type=get_annotated(field.type),
        )

    if role is Role.COORD or role is Role.DATA:
        try:
            return DataEntry(
                name=name,
                tag=role.value,
                base=get_dataclass(field.type),
                value=value,
            )
        except TypeError:
            # fall back to the dims and dtype of the field type itself
            return DataEntry(
                name=name,
                tag=role.value,
                dims=get_dims(field.type),
                dtype=get_dtype(field.type),
                value=value,
            )


def get_typedarray(
    data: Any,
    dims: Dims,
    dtype: Optional[AnyDType],
    reference: Optional[AnyXarray] = None,
) -> xr.DataArray:
    """Create a DataArray object with given dims and dtype.

    Args:
        data: Data to be converted to the DataArray object.
        dims: Dimensions of the DataArray object.
        dtype: Data type of the DataArray object.
        reference: DataArray or Dataset object as a reference of shape.

    Returns:
        DataArray object with given dims and dtype.

    Raises:
        ValueError: Raised if the number of dimensions of the data
            does not match the dims, unless the data is
            zero-dimensional and a reference is given.

    """
    # convert only objects that do not provide ``__array__``
    try:
        data.__array__
    except AttributeError:
        data = np.asarray(data)

    if dtype is not None:
        data = data.astype(dtype, copy=False)

    if data.ndim == len(dims):
        dataarray = xr.DataArray(data, dims=dims)
    elif data.ndim == 0 and reference is not None:
        # zero-dimensional data is broadcast to the reference below
        dataarray = xr.DataArray(data)
    else:
        raise ValueError(
            "Could not create a DataArray object from data. "
            f"Mismatch between shape {data.shape} and dims {dims}."
        )

    if reference is None:
        return dataarray
    else:
        # take index 0 along reference dimensions that are not in dims
        ddims = set(reference.dims) - set(dims)
        reference = reference.isel({dim: 0 for dim in ddims})
        return dataarray.broadcast_like(reference)