from __future__ import annotations

import copy
from typing import TYPE_CHECKING

from optuna.distributions import BaseDistribution
from optuna.trial import TrialState


if TYPE_CHECKING:
    from optuna.study import Study


class _SearchSpaceGroup:
    def __init__(self) -> None:
        self._search_spaces: list[dict[str, BaseDistribution]] = []

    @property
    def search_spaces(self) -> list[dict[str, BaseDistribution]]:
        return self._search_spaces

    def add_distributions(self, distributions: dict[str, BaseDistribution]) -> None:
        dist_keys = set(distributions.keys())
        next_search_spaces = []

        for search_space in self._search_spaces:
            keys = set(search_space.keys())

            next_search_spaces.append({name: search_space[name] for name in keys & dist_keys})
            next_search_spaces.append({name: search_space[name] for name in keys - dist_keys})

            dist_keys -= keys

        next_search_spaces.append({name: distributions[name] for name in dist_keys})
        self._search_spaces = list(
            filter(lambda search_space: len(search_space) > 0, next_search_spaces)
        )


class _GroupDecomposedSearchSpace:
    def __init__(self, include_pruned: bool = False) -> None:
        self._search_space = _SearchSpaceGroup()
        self._study_id: int | None = None
        self._include_pruned = include_pruned

    def calculate(self, study: Study, use_cache: bool = False) -> _SearchSpaceGroup:
        if self._study_id is None:
            self._study_id = study._study_id
        else:
            # Note that the check below is meaningless when
            # :class:`~optuna.storages.InMemoryStorage` is used because
            # :func:`~optuna.storages.InMemoryStorage.create_new_study`
            # always returns the same study ID.
            if self._study_id != study._study_id:
                raise ValueError("`_GroupDecomposedSearchSpace` cannot handle multiple studies.")

        states_of_interest: tuple[TrialState, ...]
        if self._include_pruned:
            states_of_interest = (TrialState.COMPLETE, TrialState.PRUNED)
        else:
            states_of_interest = (TrialState.COMPLETE,)

        for trial in study._get_trials(
            deepcopy=False, states=states_of_interest, use_cache=use_cache
        ):
            self._search_space.add_distributions(trial.distributions)

        return copy.deepcopy(self._search_space)
