from __future__ import annotations

import enum
import math
from typing import Any

from sqlalchemy import asc
from sqlalchemy import case
from sqlalchemy import CheckConstraint
from sqlalchemy import DateTime
from sqlalchemy import desc
from sqlalchemy import Enum
from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import orm
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy import UniqueConstraint

from optuna import distributions
from optuna.study._study_direction import StudyDirection
from optuna.trial import TrialState


try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    # TODO(c-bata): Remove this after dropping support for SQLAlchemy v1.3 or prior.
    from sqlalchemy.ext.declarative import declarative_base

try:
    from sqlalchemy.orm import mapped_column

    _Column = mapped_column
except ImportError:
    # TODO(Shinichi): Remove this after dropping support for SQLAlchemy<2.0.
    from sqlalchemy import Column as _Column  # type: ignore[assignment]

# Don't modify this version number anymore.
# The schema management functionality has been moved to alembic.
SCHEMA_VERSION = 12

# Length of the `String` columns that are indexed or take part in a unique
# constraint (study names, attribute keys and parameter names).
MAX_INDEXED_STRING_LENGTH = 512
# Length of the `String` column that stores the library version.
MAX_VERSION_LENGTH = 256

# Message carried by the `KeyError`/`ValueError` raised when a looked-up row is missing.
NOT_FOUND_MSG = "Record does not exist."

# Precision passed to the `Float` columns (parameter, objective and intermediate values).
FLOAT_PRECISION = 53

# Declarative base class that the ORM models in this module inherit from.
BaseModel: Any = declarative_base()


class StudyModel(BaseModel):
    """ORM mapping of the ``studies`` table.

    A row consists of an integer primary key and a unique, indexed, non-null
    study name.
    """

    __tablename__ = "studies"
    study_id = _Column(Integer, primary_key=True)
    study_name = _Column(
        String(MAX_INDEXED_STRING_LENGTH), index=True, unique=True, nullable=False
    )

    @classmethod
    def find_or_raise_by_id(
        cls, study_id: int, session: orm.Session, for_update: bool = False
    ) -> "StudyModel":
        """Fetch the study whose primary key is ``study_id``.

        Args:
            study_id: Primary key to look up.
            session: Session used to run the query.
            for_update: If :obj:`True`, ``with_for_update()`` is applied to the query.

        Returns:
            The matching study row.

        Raises:
            KeyError: If no study has the given ID.
        """
        lookup = session.query(cls).filter(cls.study_id == study_id)
        found = (lookup.with_for_update() if for_update else lookup).one_or_none()

        if found is not None:
            return found
        raise KeyError(NOT_FOUND_MSG)

    @classmethod
    def find_by_name(cls, study_name: str, session: orm.Session) -> "StudyModel" | None:
        """Return the study named ``study_name``, or :obj:`None` if there is none."""
        return session.query(cls).filter(cls.study_name == study_name).one_or_none()

    @classmethod
    def find_or_raise_by_name(cls, study_name: str, session: orm.Session) -> "StudyModel":
        """Return the study named ``study_name``.

        Raises:
            KeyError: If no study has the given name.
        """
        found = cls.find_by_name(study_name, session)

        if found is not None:
            return found
        raise KeyError(NOT_FOUND_MSG)


class StudyDirectionModel(BaseModel):
    """ORM mapping of the ``study_directions`` table.

    Each row stores the ``direction`` of one ``objective`` of a study; the pair
    ``(study_id, objective)`` is unique.
    """

    __tablename__ = "study_directions"
    __table_args__: Any = (UniqueConstraint("study_id", "objective"),)
    study_direction_id = _Column(Integer, primary_key=True)
    direction = _Column(Enum(StudyDirection), nullable=False)
    study_id = _Column(Integer, ForeignKey("studies.study_id"), nullable=False)
    objective = _Column(Integer, nullable=False)

    # Exposed on ``StudyModel`` as ``directions``.
    study = orm.relationship(
        StudyModel, backref=orm.backref("directions", cascade="all, delete-orphan")
    )

    @classmethod
    def where_study_id(cls, study_id: int, session: orm.Session) -> list["StudyDirectionModel"]:
        """Return all direction rows whose ``study_id`` equals the given one."""
        matching = session.query(cls).filter(cls.study_id == study_id)
        return matching.all()


class StudyUserAttributeModel(BaseModel):
    """ORM mapping of the ``study_user_attributes`` table.

    Each row holds one ``key`` / ``value_json`` pair attached to a study; the
    pair ``(study_id, key)`` is unique.
    """

    __tablename__ = "study_user_attributes"
    __table_args__: Any = (UniqueConstraint("study_id", "key"),)
    study_user_attribute_id = _Column(Integer, primary_key=True)
    study_id = _Column(Integer, ForeignKey("studies.study_id"))
    key = _Column(String(MAX_INDEXED_STRING_LENGTH))
    value_json = _Column(Text())

    # Exposed on ``StudyModel`` as ``user_attributes``.
    study = orm.relationship(
        StudyModel, backref=orm.backref("user_attributes", cascade="all, delete-orphan")
    )

    @classmethod
    def find_by_study_and_key(
        cls, study: StudyModel, key: str, session: orm.Session
    ) -> "StudyUserAttributeModel" | None:
        """Return the attribute of ``study`` stored under ``key``, or :obj:`None`."""
        matching = session.query(cls).filter(cls.study_id == study.study_id, cls.key == key)
        return matching.one_or_none()

    @classmethod
    def where_study_id(
        cls, study_id: int, session: orm.Session
    ) -> list["StudyUserAttributeModel"]:
        """Return all user-attribute rows whose ``study_id`` equals the given one."""
        matching = session.query(cls).filter(cls.study_id == study_id)
        return matching.all()


class StudySystemAttributeModel(BaseModel):
    """ORM mapping of the ``study_system_attributes`` table.

    Each row holds one ``key`` / ``value_json`` pair attached to a study; the
    pair ``(study_id, key)`` is unique.
    """

    __tablename__ = "study_system_attributes"
    __table_args__: Any = (UniqueConstraint("study_id", "key"),)
    study_system_attribute_id = _Column(Integer, primary_key=True)
    study_id = _Column(Integer, ForeignKey("studies.study_id"))
    key = _Column(String(MAX_INDEXED_STRING_LENGTH))
    value_json = _Column(Text())

    # Exposed on ``StudyModel`` as ``system_attributes``.
    study = orm.relationship(
        StudyModel, backref=orm.backref("system_attributes", cascade="all, delete-orphan")
    )

    @classmethod
    def find_by_study_and_key(
        cls, study: StudyModel, key: str, session: orm.Session
    ) -> "StudySystemAttributeModel" | None:
        """Return the attribute of ``study`` stored under ``key``, or :obj:`None`."""
        matching = session.query(cls).filter(cls.study_id == study.study_id, cls.key == key)
        return matching.one_or_none()

    @classmethod
    def where_study_id(
        cls, study_id: int, session: orm.Session
    ) -> list["StudySystemAttributeModel"]:
        """Return all system-attribute rows whose ``study_id`` equals the given one."""
        matching = session.query(cls).filter(cls.study_id == study_id)
        return matching.all()


class TrialModel(BaseModel):
    """ORM mapping of the ``trials`` table.

    Each row records the number, owning study, state and start/completion
    timestamps of one trial.
    """

    __tablename__ = "trials"
    trial_id = _Column(Integer, primary_key=True)
    # No `UniqueConstraint` is put on the `number` columns although it in practice is constrained
    # to be unique. This is to reduce code complexity as table-level locking would be required
    # otherwise. See https://github.com/optuna/optuna/pull/939#discussion_r387447632.
    number = _Column(Integer)
    study_id = _Column(Integer, ForeignKey("studies.study_id"), index=True)
    state = _Column(Enum(TrialState), nullable=False)
    datetime_start = _Column(DateTime)
    datetime_complete = _Column(DateTime)

    # Exposed on ``StudyModel`` as ``trials``.
    study = orm.relationship(
        StudyModel, backref=orm.backref("trials", cascade="all, delete-orphan")
    )

    @classmethod
    def _find_extreme_value_trial_id(
        cls, study_id: int, objective: int, session: orm.Session, direction: Any
    ) -> int:
        """Return the ID of the completed trial that sorts first for ``objective``.

        Shared implementation of :meth:`find_max_value_trial_id` and
        :meth:`find_min_value_trial_id`.

        Args:
            study_id: Study whose trials are searched.
            objective: Value of ``TrialValueModel.objective`` to compare on.
            session: Session used to run the query.
            direction: Ordering function applied to both sort keys, i.e.
                ``sqlalchemy.asc`` or ``sqlalchemy.desc``.

        Returns:
            The ``trial_id`` of the first row in the requested order.

        Raises:
            ValueError: If the study has no completed trial with a value for
                ``objective``.
        """
        value_type = TrialValueModel.value_type
        kinds = TrialValueModel.TrialValueType
        # Primary sort key: rank the value kinds as -inf < finite < +inf. Infinite values are
        # stored with a NULL `value` (see `TrialValueModel.value_to_stored_repr`), so the
        # `value` column alone cannot order them.
        kind_rank = case(
            (
                value_type == kinds.INF_NEG,
                -1,
            ),
            (
                value_type == kinds.FINITE,
                0,
            ),
            (
                value_type == kinds.INF_POS,
                1,
            ),
        )
        row = (
            session.query(cls)
            .with_entities(cls.trial_id)
            .filter(cls.study_id == study_id)
            .filter(cls.state == TrialState.COMPLETE)
            .join(TrialValueModel)
            .filter(TrialValueModel.objective == objective)
            # Secondary sort key: the stored value itself, in the same direction.
            .order_by(direction(kind_rank), direction(TrialValueModel.value))
            .limit(1)
            .one_or_none()
        )
        if row is None:
            raise ValueError(NOT_FOUND_MSG)
        return row[0]

    @classmethod
    def find_max_value_trial_id(cls, study_id: int, objective: int, session: orm.Session) -> int:
        """Return the ID of the completed trial with the largest value of ``objective``.

        Raises:
            ValueError: If the study has no completed trial with a value for
                ``objective``.
        """
        return cls._find_extreme_value_trial_id(study_id, objective, session, desc)

    @classmethod
    def find_min_value_trial_id(cls, study_id: int, objective: int, session: orm.Session) -> int:
        """Return the ID of the completed trial with the smallest value of ``objective``.

        Raises:
            ValueError: If the study has no completed trial with a value for
                ``objective``.
        """
        return cls._find_extreme_value_trial_id(study_id, objective, session, asc)

    @classmethod
    def find_or_raise_by_id(
        cls, trial_id: int, session: orm.Session, for_update: bool = False
    ) -> "TrialModel":
        """Fetch the trial whose primary key is ``trial_id``.

        Args:
            trial_id: Primary key to look up.
            session: Session used to run the query.
            for_update: If :obj:`True`, ``with_for_update()`` is applied to the query.

        Returns:
            The matching trial row.

        Raises:
            KeyError: If no trial has the given ID.
        """
        query = session.query(cls).filter(cls.trial_id == trial_id)

        # "FOR UPDATE" clause is used for row-level locking.
        # Please note that SQLite3 doesn't support this clause.
        if for_update:
            query = query.with_for_update()

        trial = query.one_or_none()
        if trial is None:
            raise KeyError(NOT_FOUND_MSG)

        return trial

    @classmethod
    def count(
        cls, session: orm.Session, study: StudyModel | None = None, state: TrialState | None = None
    ) -> int:
        """Count trials, optionally restricted to one study and/or one state.

        Args:
            session: Session used to run the query.
            study: If given, only trials of this study are counted.
            state: If given, only trials in this state are counted.

        Returns:
            The number of matching trials.
        """
        trial_count = session.query(func.count(cls.trial_id))
        if study is not None:
            trial_count = trial_count.filter(cls.study_id == study.study_id)
        if state is not None:
            trial_count = trial_count.filter(cls.state == state)

        return trial_count.scalar()

    def count_past_trials(self, session: orm.Session) -> int:
        """Count the trials of the same study whose ``trial_id`` is smaller than this one's."""
        trial_count = session.query(func.count(TrialModel.trial_id)).filter(
            TrialModel.study_id == self.study_id, TrialModel.trial_id < self.trial_id
        )
        return trial_count.scalar()


class TrialUserAttributeModel(BaseModel):
    """ORM mapping of the ``trial_user_attributes`` table.

    Each row holds one ``key`` / ``value_json`` pair attached to a trial; the
    pair ``(trial_id, key)`` is unique.
    """

    __tablename__ = "trial_user_attributes"
    __table_args__: Any = (UniqueConstraint("trial_id", "key"),)
    trial_user_attribute_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"))
    key = _Column(String(MAX_INDEXED_STRING_LENGTH))
    value_json = _Column(Text())

    # Exposed on ``TrialModel`` as ``user_attributes``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("user_attributes", cascade="all, delete-orphan")
    )

    @classmethod
    def find_by_trial_and_key(
        cls, trial: TrialModel, key: str, session: orm.Session
    ) -> "TrialUserAttributeModel" | None:
        """Return the attribute of ``trial`` stored under ``key``, or :obj:`None`."""
        matching = session.query(cls).filter(cls.trial_id == trial.trial_id, cls.key == key)
        return matching.one_or_none()

    @classmethod
    def where_trial_id(
        cls, trial_id: int, session: orm.Session
    ) -> list["TrialUserAttributeModel"]:
        """Return all user-attribute rows whose ``trial_id`` equals the given one."""
        matching = session.query(cls).filter(cls.trial_id == trial_id)
        return matching.all()


class TrialSystemAttributeModel(BaseModel):
    """ORM mapping of the ``trial_system_attributes`` table.

    Each row holds one ``key`` / ``value_json`` pair attached to a trial; the
    pair ``(trial_id, key)`` is unique.
    """

    __tablename__ = "trial_system_attributes"
    __table_args__: Any = (UniqueConstraint("trial_id", "key"),)
    trial_system_attribute_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"))
    key = _Column(String(MAX_INDEXED_STRING_LENGTH))
    value_json = _Column(Text())

    # Exposed on ``TrialModel`` as ``system_attributes``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("system_attributes", cascade="all, delete-orphan")
    )

    @classmethod
    def find_by_trial_and_key(
        cls, trial: TrialModel, key: str, session: orm.Session
    ) -> "TrialSystemAttributeModel" | None:
        """Return the attribute of ``trial`` stored under ``key``, or :obj:`None`."""
        matching = session.query(cls).filter(cls.trial_id == trial.trial_id, cls.key == key)
        return matching.one_or_none()

    @classmethod
    def where_trial_id(
        cls, trial_id: int, session: orm.Session
    ) -> list["TrialSystemAttributeModel"]:
        """Return all system-attribute rows whose ``trial_id`` equals the given one."""
        matching = session.query(cls).filter(cls.trial_id == trial_id)
        return matching.all()


class TrialParamModel(BaseModel):
    """ORM mapping of the ``trial_params`` table.

    Each row stores the float value of one named parameter of a trial together
    with the JSON form of its distribution; ``(trial_id, param_name)`` is
    unique.
    """

    __tablename__ = "trial_params"
    __table_args__: Any = (UniqueConstraint("trial_id", "param_name"),)
    param_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"))
    param_name = _Column(String(MAX_INDEXED_STRING_LENGTH))
    param_value = _Column(Float(precision=FLOAT_PRECISION))
    distribution_json = _Column(Text())

    # Exposed on ``TrialModel`` as ``params``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("params", cascade="all, delete-orphan")
    )

    def check_and_add(self, session: orm.Session, study_id: int) -> None:
        """Run the distribution compatibility check, then add this row to ``session``."""
        self._check_compatibility_with_previous_trial_param_distributions(session, study_id)
        session.add(self)

    def _check_compatibility_with_previous_trial_param_distributions(
        self, session: orm.Session, study_id: int
    ) -> None:
        """Check this parameter's distribution against one already stored in the study.

        The first stored parameter of the study ``study_id`` with the same name
        is looked up. If there is one, both distributions are passed to
        ``distributions.check_distribution_compatibility``; otherwise nothing
        is checked.
        """
        earlier = (
            session.query(TrialParamModel)
            .join(TrialModel)
            .filter(TrialModel.study_id == study_id, TrialParamModel.param_name == self.param_name)
            .first()
        )
        if earlier is None:
            return

        stored_distribution = distributions.json_to_distribution(earlier.distribution_json)
        own_distribution = distributions.json_to_distribution(self.distribution_json)
        distributions.check_distribution_compatibility(stored_distribution, own_distribution)

    @classmethod
    def find_by_trial_and_param_name(
        cls, trial: TrialModel, param_name: str, session: orm.Session
    ) -> "TrialParamModel" | None:
        """Return the parameter ``param_name`` of ``trial``, or :obj:`None` if absent."""
        matching = session.query(cls).filter(
            cls.trial_id == trial.trial_id, cls.param_name == param_name
        )
        return matching.one_or_none()

    @classmethod
    def find_or_raise_by_trial_and_param_name(
        cls, trial: TrialModel, param_name: str, session: orm.Session
    ) -> "TrialParamModel":
        """Return the parameter ``param_name`` of ``trial``.

        Raises:
            KeyError: If the trial has no parameter with that name.
        """
        found = cls.find_by_trial_and_param_name(trial, param_name, session)

        if found is not None:
            return found
        raise KeyError(NOT_FOUND_MSG)

    @classmethod
    def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialParamModel"]:
        """Return all parameter rows whose ``trial_id`` equals the given one."""
        return session.query(cls).filter(cls.trial_id == trial_id).all()


class TrialValueModel(BaseModel):
    """ORM mapping of the ``trial_values`` table.

    Each row stores the value of one ``objective`` of a trial as a nullable
    float plus a ``value_type`` tag; ``(trial_id, objective)`` is unique.
    """

    class TrialValueType(enum.Enum):
        FINITE = 1
        INF_POS = 2
        INF_NEG = 3

    __tablename__ = "trial_values"
    __table_args__: Any = (UniqueConstraint("trial_id", "objective"),)
    trial_value_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"), nullable=False)
    objective = _Column(Integer, nullable=False)
    value = _Column(Float(precision=FLOAT_PRECISION), nullable=True)
    value_type = _Column(Enum(TrialValueType), nullable=False)

    # Exposed on ``TrialModel`` as ``values``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("values", cascade="all, delete-orphan")
    )

    @classmethod
    def value_to_stored_repr(cls, value: float) -> tuple[float | None, TrialValueType]:
        """Split ``value`` into the ``(value, value_type)`` pair kept in the table.

        Positive and negative infinity become ``(None, INF_POS)`` and
        ``(None, INF_NEG)``; every other input is returned unchanged with
        ``FINITE``.
        """
        kinds = cls.TrialValueType
        if value == float("inf"):
            return None, kinds.INF_POS
        if value == float("-inf"):
            return None, kinds.INF_NEG
        return value, kinds.FINITE

    @classmethod
    def stored_repr_to_value(cls, value: float | None, float_type: TrialValueType) -> float:
        """Rebuild the float described by a stored ``(value, float_type)`` pair.

        Inverse of :meth:`value_to_stored_repr`. Assertions check that ``value``
        is :obj:`None` for the infinite types and not :obj:`None` for ``FINITE``.
        """
        kinds = cls.TrialValueType
        if float_type == kinds.INF_POS:
            assert value is None
            return float("inf")
        if float_type == kinds.INF_NEG:
            assert value is None
            return float("-inf")

        assert float_type == kinds.FINITE
        assert value is not None
        return value

    @classmethod
    def find_by_trial_and_objective(
        cls, trial: TrialModel, objective: int, session: orm.Session
    ) -> "TrialValueModel" | None:
        """Return the value row of ``trial`` for ``objective``, or :obj:`None` if absent."""
        matching = session.query(cls).filter(
            cls.trial_id == trial.trial_id, cls.objective == objective
        )
        return matching.one_or_none()

    @classmethod
    def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialValueModel"]:
        """Return the value rows of the trial ``trial_id`` in ascending ``objective`` order."""
        of_trial = session.query(cls).filter(cls.trial_id == trial_id)
        return of_trial.order_by(asc(cls.objective)).all()


class TrialIntermediateValueModel(BaseModel):
    """ORM mapping of the ``trial_intermediate_values`` table.

    Each row stores the intermediate value reported by a trial at one ``step``
    as a nullable float plus an ``intermediate_value_type`` tag;
    ``(trial_id, step)`` is unique.
    """

    class TrialIntermediateValueType(enum.Enum):
        FINITE = 1
        INF_POS = 2
        INF_NEG = 3
        NAN = 4

    __tablename__ = "trial_intermediate_values"
    __table_args__: Any = (UniqueConstraint("trial_id", "step"),)
    trial_intermediate_value_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"), nullable=False)
    step = _Column(Integer, nullable=False)
    intermediate_value = _Column(Float(precision=FLOAT_PRECISION), nullable=True)
    intermediate_value_type = _Column(Enum(TrialIntermediateValueType), nullable=False)

    # Exposed on ``TrialModel`` as ``intermediate_values``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("intermediate_values", cascade="all, delete-orphan")
    )

    @classmethod
    def intermediate_value_to_stored_repr(
        cls, value: float
    ) -> tuple[float | None, TrialIntermediateValueType]:
        """Split ``value`` into the ``(value, type)`` pair kept in the table.

        NaN and the two infinities are stored as :obj:`None` tagged with
        ``NAN``, ``INF_POS`` or ``INF_NEG``; every other input is returned
        unchanged with ``FINITE``.
        """
        kinds = cls.TrialIntermediateValueType
        if math.isnan(value):
            return None, kinds.NAN
        if value == float("inf"):
            return None, kinds.INF_POS
        if value == float("-inf"):
            return None, kinds.INF_NEG
        return value, kinds.FINITE

    @classmethod
    def stored_repr_to_intermediate_value(
        cls, value: float | None, float_type: TrialIntermediateValueType
    ) -> float:
        """Rebuild the float described by a stored ``(value, float_type)`` pair.

        Inverse of :meth:`intermediate_value_to_stored_repr`. Assertions check
        that ``value`` is :obj:`None` for the non-finite types and not
        :obj:`None` for ``FINITE``.
        """
        kinds = cls.TrialIntermediateValueType
        if float_type == kinds.NAN:
            assert value is None
            return float("nan")
        if float_type == kinds.INF_POS:
            assert value is None
            return float("inf")
        if float_type == kinds.INF_NEG:
            assert value is None
            return float("-inf")

        assert float_type == kinds.FINITE
        assert value is not None
        return value

    @classmethod
    def find_by_trial_and_step(
        cls, trial: TrialModel, step: int, session: orm.Session
    ) -> "TrialIntermediateValueModel" | None:
        """Return the intermediate value of ``trial`` at ``step``, or :obj:`None` if absent."""
        matching = session.query(cls).filter(cls.trial_id == trial.trial_id, cls.step == step)
        return matching.one_or_none()

    @classmethod
    def where_trial_id(
        cls, trial_id: int, session: orm.Session
    ) -> list["TrialIntermediateValueModel"]:
        """Return all intermediate-value rows whose ``trial_id`` equals the given one."""
        return session.query(cls).filter(cls.trial_id == trial_id).all()


class TrialHeartbeatModel(BaseModel):
    """ORM mapping of the ``trial_heartbeats`` table.

    ``trial_id`` is unique, so a trial has at most one heartbeat row. The
    ``heartbeat`` column defaults to ``func.current_timestamp()``.
    """

    __tablename__ = "trial_heartbeats"
    __table_args__: Any = (UniqueConstraint("trial_id"),)
    trial_heartbeat_id = _Column(Integer, primary_key=True)
    trial_id = _Column(Integer, ForeignKey("trials.trial_id"), nullable=False)
    heartbeat = _Column(DateTime, nullable=False, default=func.current_timestamp())

    # Exposed on ``TrialModel`` as ``heartbeats``.
    trial = orm.relationship(
        TrialModel, backref=orm.backref("heartbeats", cascade="all, delete-orphan")
    )

    @classmethod
    def where_trial_id(
        cls, trial_id: int, session: orm.Session, for_update: bool = False
    ) -> "TrialHeartbeatModel" | None:
        """Return the heartbeat row of the trial ``trial_id``, or :obj:`None` if absent.

        Args:
            trial_id: Trial whose heartbeat is looked up.
            session: Session used to run the query.
            for_update: If :obj:`True`, ``with_for_update()`` is applied to the query.
        """
        lookup = session.query(cls).filter(cls.trial_id == trial_id)
        return (lookup.with_for_update() if for_update else lookup).one_or_none()


class VersionInfoModel(BaseModel):
    """ORM mapping of the ``version_info`` table.

    Stores the schema version and the library version.
    """

    __tablename__ = "version_info"
    # setting check constraint to ensure the number of rows is at most 1
    __table_args__: Any = (CheckConstraint("version_info_id=1"),)
    version_info_id = _Column(Integer, primary_key=True, autoincrement=False, default=1)
    schema_version = _Column(Integer)
    library_version = _Column(String(MAX_VERSION_LENGTH))

    @classmethod
    def find(cls, session: orm.Session) -> "VersionInfoModel" | None:
        """Return the version-info row, or :obj:`None` if the table is empty."""
        return session.query(cls).one_or_none()
