__all__ = ["asframe", "aspandas", "asseries"]
# standard library
from types import FunctionType
from typing import Any, Callable, Hashable, Iterable, Optional, overload
# dependencies
import numpy as np
import pandas as pd
from pandas.api.types import is_list_like
from typing_extensions import get_origin
from .specs import Field, Fields, Spec
from .tagging import Tag
from .typing import DataClass, DataClassOf, PAny, TFrame, TPandas, TSeries
@overload
def aspandas(obj: DataClassOf[TPandas, PAny], *, factory: None = None) -> TPandas: ...
@overload
def aspandas(obj: DataClass[PAny], *, factory: Callable[..., TPandas]) -> TPandas: ...
[docs]
def aspandas(obj: Any, *, factory: Any = None) -> Any:
    """Create a DataFrame or Series object from a dataclass object.
    Which data structure is created will be determined by a factory
    defined as the ``__pandas_factory__`` attribute in the original
    dataclass of ``obj`` or the ``factory`` argument. If a factory is
    a function, it must have an annotation of the return type.
    Args:
        obj: Dataclass object that should have attribute, column, data,
            and/or index fields. If the original dataclass has the
            ``__pandas_factory__`` attribute, it will be used as a
            factory for the data creation.
    Keyword Args:
        factory: Class or function for the DataFrame or Series creation.
            It must take the same parameters as ``pandas.DataFrame``
            or ``pandas.Series``, and return an object of it or its
            subclass. If it is a function, it must have an annotation
            of the return type. If passed, it will be preferentially
            used even if the original dataclass of ``obj`` has the
            ``__pandas_factory__`` attribute.
    Returns:
        DataFrame or Series object that complies with the original dataclass.
    Raises:
        ValueError: Raised if no factory is found or the return type
            cannot be inferred from a factory when it is a function.
    """
    spec = Spec.from_dataclass(type(obj)) @ obj
    if factory is None:
        factory = spec.factory
    if factory is None:
        raise ValueError("Could not find any factory.")
    if isinstance(factory, FunctionType):
        return_ = factory.__annotations__["return"]
    else:
        return_ = factory
    origin = get_origin(return_) or return_
    if issubclass(origin, pd.DataFrame):
        return asframe(obj, factory=factory)
    elif issubclass(origin, pd.Series):
        return asseries(obj, factory=factory)
    else:
        raise ValueError("Could not infer an object type.") 
@overload
def asframe(obj: DataClassOf[TFrame, PAny], *, factory: None = None) -> TFrame: ...
@overload
def asframe(obj: DataClass[PAny], *, factory: Callable[..., TFrame]) -> TFrame: ...
@overload
def asframe(obj: DataClass[PAny], *, factory: None = None) -> pd.DataFrame: ...
[docs]
def asframe(obj: Any, *, factory: Any = None) -> Any:
    """Create a DataFrame object from a dataclass object.
    The return type will be determined by a factory defined as the
    ``__pandas_factory__`` attribute in the original dataclass of
    ``obj`` or the ``factory`` argument. If neither is specified,
    it defaults to ``pandas.DataFrame``.
    Args:
        obj: Dataclass object that should have attribute, column, data,
            and/or index fields. If the original dataclass has the
            ``__pandas_factory__`` attribute, it will be used as a
            factory for the DataFrame creation.
    Keyword Args:
        factory: Class or function for the DataFrame creation.
            It must take the same parameters as ``pandas.DataFrame``,
            and return an object of it or its subclass. If passed, it
            will be preferentially used even if the original dataclass
            of ``obj`` has the ``__pandas_factory__`` attribute.
    Returns:
        DataFrame object that complies with the original dataclass.
    """
    spec = Spec.from_dataclass(type(obj)) @ obj
    if factory is None:
        factory = spec.factory or pd.DataFrame
    dataframe = factory(
        data=get_data(spec),
        index=get_index(spec),
        columns=get_columns(spec),
    )
    dataframe.attrs.update(get_attrs(spec))
    return squeeze(dataframe) 
@overload
def asseries(obj: DataClassOf[TSeries, PAny], *, factory: None = None) -> TSeries: ...
@overload
def asseries(obj: DataClass[PAny], *, factory: Callable[..., TSeries]) -> TSeries: ...
@overload
def asseries(obj: DataClass[PAny], *, factory: None = None) -> "pd.Series[Any]": ...
[docs]
def asseries(obj: Any, *, factory: Any = None) -> Any:
    """Create a Series object from a dataclass object.
    The return type will be determined by a factory defined as the
    ``__pandas_factory__`` attribute in the original dataclass of
    ``obj`` or the ``factory`` argument. If neither is specified,
    it defaults to ``pandas.Series``.
    Args:
        obj: Dataclass object that should have attribute, column, data,
            and/or index fields. If the original dataclass has the
            ``__pandas_factory__`` attribute, it will be used as a
            factory for the Series creation.
    Keyword Args:
        factory: Class or function for the Series creation.
            It must take the same parameters as ``pandas.Series``,
            and return an object of it or its subclass. If passed, it
            will be preferentially used even if the original dataclass
            of ``obj`` has the ``__pandas_factory__`` attribute.
    Returns:
        Series object that complies with the original dataclass.
    """
    spec = Spec.from_dataclass(type(obj)) @ obj
    if factory is None:
        factory = spec.factory or pd.Series
    data = get_data(spec)
    index = get_index(spec)
    if not data:
        series = factory(index=index)
    else:
        name, data = next(iter(data.items()))
        series = factory(data=data, index=index, name=name)
    series.attrs.update(get_attrs(spec))
    return squeeze(series) 
def get_attrs(spec: Spec) -> dict[Hashable, Any]:
    """Derive attributes from a specification."""
    data: dict[Hashable, Any] = {}
    for field in spec.fields.of(Tag.ATTR):
        data.update(items(field))
    return data
def get_columns(spec: Spec) -> Optional[pd.MultiIndex]:
    """Derive columns from a specification."""
    if not (fields := spec.fields.of(Tag.DATA)):
        return None
    if (names := name(fields)) is None:
        return None
    return pd.MultiIndex.from_tuples(
        map(name, fields),
        names=names,
    )
def get_data(spec: Spec) -> dict[Hashable, Any]:
    """Derive data from a specification."""
    data: dict[Hashable, Any] = {}
    for field in spec.fields.of(Tag.DATA):
        for key, val in items(field):
            data[key] = ensure(val, field.dtype)
    return data
def get_index(spec: Spec) -> Optional[pd.MultiIndex]:
    """Derive index from a specification."""
    if not (fields := spec.fields.of(Tag.INDEX)):
        return None
    data: dict[Hashable, Any] = {}
    for field in fields:
        for key, val in items(field):
            data[key] = ensure(val, field.dtype)
    return pd.MultiIndex.from_arrays(
        np.broadcast_arrays(*data.values()),
        names=data.keys(),
    )
def ensure(data: Any, dtype: Optional[str]) -> Any:
    """Ensure data to be 1D and have given data type."""
    if not is_list_like(data):
        data = [data]
    if isinstance(data, (pd.Index, pd.Series)):
        return type(data)(data, dtype=dtype, copy=False)  # type: ignore
    else:
        return pd.array(data, dtype=dtype, copy=False)
def items(field: Field) -> Iterable[tuple[Hashable, Any]]:
    """Generate default(s) of a field specification."""
    if field.has(Tag.MULTIPLE):
        yield from field.default.items()
    else:
        yield (name(field), field.default)
@overload
def name(fields: Field) -> Hashable: ...
@overload
def name(fields: Fields) -> Optional[Hashable]: ...
def name(fields: Any) -> Any:
    """Derive name of a field(s) specification."""
    if isinstance(fields, Field):
        if isinstance(name := fields.name, dict):
            return tuple(name.values())
        else:
            return name
    if isinstance(fields, Fields):
        for field in fields:
            if isinstance(name := field.name, dict):
                return tuple(name.keys())
def squeeze(data: TPandas) -> TPandas:
    """Drop levels of an index and columns if possible."""
    if data.index.nlevels == 1:
        data.index = data.index.get_level_values(0)
    if isinstance(data, pd.Series):
        return data  # type: ignore
    if data.columns.nlevels == 1:
        data.columns = data.columns.get_level_values(0)
    return data