from __future__ import annotations

from collections.abc import Callable
from collections.abc import Container
from collections.abc import Sequence
import copy
import threading
from typing import Any

import optuna
from optuna import distributions
from optuna._typing import JSONSerializable
from optuna.storages import BaseStorage
from optuna.storages._heartbeat import BaseHeartbeat
from optuna.storages._rdb.storage import RDBStorage
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


class _StudyInfo:
    def __init__(self) -> None:
        # Trial number to corresponding FrozenTrial.
        self.trials: dict[int, FrozenTrial] = {}
        # A list of trials and the last trial number which require storage access to read latest
        # attributes.
        self.unfinished_trial_ids: set[int] = set()
        self.last_finished_trial_id: int = -1
        self.directions: list[StudyDirection] | None = None
        self.name: str | None = None


class _CachedStorage(BaseStorage, BaseHeartbeat):
    """A wrapper class of storage backends.

    This class is used in :func:`~optuna.get_storage` function and automatically
    wraps :class:`~optuna.storages.RDBStorage` class.

    Every write method of this class forwards the call to the wrapped backend. The cache is
    only consulted by read methods: study names, study directions, the mapping between trial
    IDs and ``(study_id, trial_number)`` pairs, and trials that were observed in a finished
    state.

    :class:`~optuna.storages._CachedStorage` meets the following **Data persistence** requirements.

    **Data persistence**

    :class:`~optuna.storages._CachedStorage` does not guarantee that write operations are logged
    into a persistent storage, even when write methods succeed.
    Thus, when process failure occurs, some writes might be lost.
    As exceptions, when a persistent storage is available, any writes on any attributes
    of `Study` and writes on `state` of `Trial` are guaranteed to be persistent.
    Additionally, any preceding writes on any attributes of `Trial` are guaranteed to
    be written into a persistent storage before writes on `state` of `Trial` succeed.
    The same applies for `param`, `user_attrs`, `system_attrs` and `intermediate_values`
    attributes.

    Args:
        backend:
            :class:`~optuna.storages.RDBStorage` class instance to wrap.
    """

    def __init__(self, backend: RDBStorage) -> None:
        self._backend = backend
        # Study ID to the cached information about that study.
        self._studies: dict[int, _StudyInfo] = {}
        # Two-way mapping between a trial ID and its (study ID, trial number) pair.
        self._trial_id_to_study_id_and_number: dict[int, tuple[int, int]] = {}
        self._study_id_and_number_to_trial_id: dict[tuple[int, int], int] = {}
        # Guards every access to the three dictionaries above.
        self._lock = threading.Lock()

    def __getstate__(self) -> dict[Any, Any]:
        """Return the instance state without the lock, which cannot be pickled."""
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore the instance state and attach a fresh lock."""
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def _get_or_create_study_info(self, study_id: int) -> _StudyInfo:
        """Return the cache entry of ``study_id``, creating an empty one if it is missing.

        The caller must hold ``self._lock``.
        """
        study = self._studies.get(study_id)
        if study is None:
            study = self._studies[study_id] = _StudyInfo()
        return study

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        """Create a study in the backend and cache its name and directions.

        Returns:
            The ID of the created study, as returned by the backend.
        """
        study_id = self._backend.create_new_study(directions=directions, study_name=study_name)
        with self._lock:
            study = _StudyInfo()
            study.name = study_name
            # Copy so that the cache does not share a list with the caller.
            study.directions = list(directions)
            self._studies[study_id] = study

        return study_id

    def delete_study(self, study_id: int) -> None:
        """Drop every cache entry of the study, then delete the study from the backend."""
        with self._lock:
            study = self._studies.pop(study_id, None)
            if study is not None:
                for trial_number in study.trials:
                    trial_id = self._study_id_and_number_to_trial_id.pop(
                        (study_id, trial_number), None
                    )
                    if trial_id is not None:
                        self._trial_id_to_study_id_and_number.pop(trial_id, None)

        self._backend.delete_study(study_id)

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        """Forward the write to the backend; study attributes are not cached."""
        self._backend.set_study_user_attr(study_id, key, value)

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        """Forward the write to the backend; study attributes are not cached."""
        self._backend.set_study_system_attr(study_id, key, value)

    def get_study_id_from_name(self, study_name: str) -> int:
        """Look the study ID up in the backend."""
        return self._backend.get_study_id_from_name(study_name)

    def get_study_name_from_id(self, study_id: int) -> str:
        """Return the study name, reading the backend only when the name is not cached yet."""
        with self._lock:
            study = self._studies.get(study_id)
            if study is not None and study.name is not None:
                return study.name

        # The backend is queried without holding the lock.
        name = self._backend.get_study_name_from_id(study_id)
        with self._lock:
            self._get_or_create_study_info(study_id).name = name
        return name

    def get_study_directions(self, study_id: int) -> list[StudyDirection]:
        """Return the study directions, reading the backend only when they are not cached yet.

        The returned list is never the list object held by the cache, so callers may modify
        it without affecting later calls.
        """
        with self._lock:
            study = self._studies.get(study_id)
            if study is not None and study.directions is not None:
                return list(study.directions)

        # The backend is queried without holding the lock.
        directions = self._backend.get_study_directions(study_id)
        with self._lock:
            # Keep a private copy; the list obtained from the backend is handed to the caller.
            self._get_or_create_study_info(study_id).directions = list(directions)
        return directions

    def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
        """Read the study user attributes from the backend."""
        return self._backend.get_study_user_attrs(study_id)

    def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
        """Read the study system attributes from the backend."""
        return self._backend.get_study_system_attrs(study_id)

    def get_all_studies(self) -> list[FrozenStudy]:
        """Read all studies from the backend."""
        return self._backend.get_all_studies()

    def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
        """Create a trial in the backend and register the returned trial in the cache.

        Returns:
            The ID of the created trial.
        """
        frozen_trial = self._backend._create_new_trial(study_id, template_trial)
        with self._lock:
            self._get_or_create_study_info(study_id)
            self._add_trials_to_cache(study_id, [frozen_trial])
        return frozen_trial._trial_id

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: distributions.BaseDistribution,
    ) -> None:
        """Forward the write to the backend."""
        self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution)

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        """Return the trial ID for ``(study_id, trial_number)``, preferring the cached mapping."""
        with self._lock:
            trial_id = self._study_id_and_number_to_trial_id.get((study_id, trial_number))
        if trial_id is not None:
            return trial_id

        return self._backend.get_trial_id_from_study_id_trial_number(study_id, trial_number)

    def get_best_trial(self, study_id: int) -> FrozenTrial:
        """Return the best trial of a single-objective study.

        Raises:
            RuntimeError:
                If the study has more than one direction.
        """
        _directions = self.get_study_directions(study_id)
        if len(_directions) > 1:
            raise RuntimeError(
                "Best trial can be obtained only for single-objective optimization."
            )
        direction = _directions[0]
        trial_id = self._backend._get_best_trial_id(study_id, direction)
        return self.get_trial(trial_id)

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
    ) -> bool:
        """Forward the write to the backend and return the backend's result."""
        return self._backend.set_trial_state_values(trial_id, state=state, values=values)

    def set_trial_intermediate_value(
        self, trial_id: int, step: int, intermediate_value: float
    ) -> None:
        """Forward the write to the backend."""
        self._backend.set_trial_intermediate_value(trial_id, step, intermediate_value)

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        """Forward the write to the backend."""
        self._backend.set_trial_user_attr(trial_id, key=key, value=value)

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        """Forward the write to the backend."""
        self._backend.set_trial_system_attr(trial_id, key=key, value=value)

    def _get_cached_trial(self, trial_id: int) -> FrozenTrial | None:
        """Return the cached trial if it is known and finished, otherwise ``None``.

        The caller must hold ``self._lock``.
        """
        location = self._trial_id_to_study_id_and_number.get(trial_id)
        if location is None:
            return None
        study_id, number = location
        trial = self._studies[study_id].trials[number]
        # A cached copy of an unfinished trial is not served; the caller reads the backend.
        return trial if trial.state.is_finished() else None

    def get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the trial, served from the cache when a finished copy of it is cached.

        A cache hit returns the cached object itself, not a copy.
        """
        with self._lock:
            trial = self._get_cached_trial(trial_id)
            if trial is not None:
                return trial

        return self._backend.get_trial(trial_id)

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Container[TrialState] | None = None,
    ) -> list[FrozenTrial]:
        """Refresh the cache from the backend and return the trials of the study.

        Args:
            study_id:
                ID of the study.
            deepcopy:
                If :obj:`True`, deep copies of the cached trials are returned. Otherwise the
                returned list contains the cached trial objects themselves.
            states:
                If given, only trials whose state is contained in ``states`` are returned.

        Returns:
            The matching trials, sorted by trial number.
        """
        self._read_trials_from_remote_storage(study_id)

        with self._lock:
            study = self._studies[study_id]
            # We need to sort trials by their number because some samplers assume this behavior.
            # The selection and the sort below are latency-sensitive.
            if states is None:
                trials = list(study.trials.values())
            else:
                trials = [t for t in study.trials.values() if t.state in states]
            trials.sort(key=lambda t: t.number)
            return copy.deepcopy(trials) if deepcopy else trials

    def _read_trials_from_remote_storage(self, study_id: int) -> None:
        """Fetch new or possibly changed trials from the backend and update the cache.

        The backend is asked for the trials recorded as unfinished plus every trial whose ID
        is greater than ``last_finished_trial_id``. The lock is held for the whole call,
        including the backend query.
        """
        with self._lock:
            study = self._get_or_create_study_info(study_id)
            trials = self._backend._get_trials(
                study_id,
                states=None,
                included_trial_ids=study.unfinished_trial_ids,
                trial_id_greater_than=study.last_finished_trial_id,
            )
            if not trials:
                return

            self._add_trials_to_cache(study_id, trials)
            for trial in trials:
                if not trial.state.is_finished():
                    study.unfinished_trial_ids.add(trial._trial_id)
                    continue

                # Updates to last_finished_trial_id should only be performed here because they must
                # be executed only when all trials have been considered.
                study.last_finished_trial_id = max(study.last_finished_trial_id, trial._trial_id)
                study.unfinished_trial_ids.discard(trial._trial_id)

    def _add_trials_to_cache(self, study_id: int, trials: list[FrozenTrial]) -> None:
        """Store ``trials`` in the cache, replacing entries with the same trial number.

        The caller must hold ``self._lock`` and the cache entry of ``study_id`` must exist.
        """
        study = self._studies[study_id]
        for trial in trials:
            key = (study_id, trial.number)
            self._trial_id_to_study_id_and_number[trial._trial_id] = key
            self._study_id_and_number_to_trial_id[key] = trial._trial_id
            study.trials[trial.number] = trial

    def record_heartbeat(self, trial_id: int) -> None:
        """Forward the heartbeat to the backend."""
        self._backend.record_heartbeat(trial_id)

    def _get_stale_trial_ids(self, study_id: int) -> list[int]:
        """Ask the backend for the IDs of stale trials."""
        return self._backend._get_stale_trial_ids(study_id)

    def get_heartbeat_interval(self) -> int | None:
        """Return the heartbeat interval configured on the backend."""
        return self._backend.get_heartbeat_interval()

    def get_failed_trial_callback(self) -> Callable[["optuna.Study", FrozenTrial], None] | None:
        """Return the failed-trial callback configured on the backend."""
        return self._backend.get_failed_trial_callback()
