from __future__ import annotations

from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
import copy
import datetime
import gc
import itertools
import os
import sys
from typing import TYPE_CHECKING
import warnings

import optuna
from optuna import exceptions
from optuna import logging
from optuna import progress_bar as pbar_module
from optuna.exceptions import ExperimentalWarning
from optuna.storages._heartbeat import get_heartbeat_thread
from optuna.storages._heartbeat import is_heartbeat_enabled
from optuna.study._tell import _tell_with_warning
from optuna.trial import TrialState


if TYPE_CHECKING:
    from collections.abc import Callable
    from collections.abc import Iterable
    from collections.abc import Sequence
    from concurrent.futures import Future
    from typing import Any

    from optuna.trial import FrozenTrial

_logger = logging.get_logger(__name__)


def _optimize(
    study: "optuna.Study",
    func: "optuna.study.study.ObjectiveFuncType",
    n_trials: int | None = None,
    timeout: float | None = None,
    n_jobs: int = 1,
    catch: tuple[type[Exception], ...] = (),
    callbacks: Iterable[Callable[["optuna.Study", FrozenTrial], None]] | None = None,
    gc_after_trial: bool = False,
    show_progress_bar: bool = False,
) -> None:
    """Drive the optimization loop behind ``Study.optimize``.

    With ``n_jobs == 1`` the trials run in the calling thread via ``_optimize_sequential``.
    Otherwise a thread pool with ``n_jobs`` workers is used (``-1`` means ``os.cpu_count()``).
    Each submission runs exactly one trial, and at most ``n_jobs`` submissions are in flight.
    New submissions stop once any of these holds:

    - the study's stop flag is set;
    - ``timeout`` seconds have passed;
    - ``n_trials`` trials have been submitted.

    Args:
        study: Study to optimize.
        func: Objective function evaluated on each trial.
        n_trials: Maximum number of trials, or ``None`` for no limit.
        timeout: Wall-clock budget in seconds, or ``None`` for no limit.
        n_jobs: Number of worker threads. ``1`` runs sequentially, ``-1`` uses the CPU count.
        catch: Exception types raised by ``func`` that mark the trial as failed without
            stopping the study. Any other exception raised by a trial is propagated,
            including from worker threads.
        callbacks: Called with the study and a copy of each finished trial.
        gc_after_trial: Run ``gc.collect()`` after every trial.
        show_progress_bar: Display a progress bar. It is disabled (with a warning) for a
            timeout-only budget when ``n_jobs != 1``.

    Raises:
        TypeError: If ``catch`` is not a tuple.
        RuntimeError: If called while an optimize loop is already running in this thread.
    """
    if not isinstance(catch, tuple):
        raise TypeError(
            "The catch argument is of type '{}' but must be a tuple.".format(type(catch).__name__)
        )

    if study._thread_local.in_optimize_loop:
        raise RuntimeError("Nested invocation of `Study.optimize` method isn't allowed.")

    if show_progress_bar and n_trials is None and timeout is not None and n_jobs != 1:
        warnings.warn("The timeout-based progress bar is not supported with n_jobs != 1.")
        show_progress_bar = False

    progress_bar = pbar_module._ProgressBar(show_progress_bar, n_trials, timeout)

    study._stop_flag = False

    try:
        if n_jobs == 1:
            _optimize_sequential(
                study,
                func,
                n_trials,
                timeout,
                catch,
                callbacks,
                gc_after_trial,
                reseed_sampler_rng=False,
                time_start=None,
                progress_bar=progress_bar,
            )
        else:
            if n_jobs == -1:
                n_jobs = os.cpu_count() or 1

            # A single start time is shared by all workers so they honour the same deadline.
            time_start = datetime.datetime.now()
            futures: set[Future] = set()

            with ThreadPoolExecutor(max_workers=n_jobs) as executor:
                for n_submitted_trials in itertools.count():
                    if study._stop_flag:
                        break

                    if (
                        timeout is not None
                        and (datetime.datetime.now() - time_start).total_seconds() > timeout
                    ):
                        break

                    if n_trials is not None and n_submitted_trials >= n_trials:
                        break

                    if len(futures) >= n_jobs:
                        completed, futures = wait(futures, return_when=FIRST_COMPLETED)
                        # Raise if exception occurred in executing the completed futures.
                        for f in completed:
                            f.result()

                    futures.add(
                        executor.submit(
                            _optimize_sequential,
                            study,
                            func,
                            1,
                            timeout,
                            catch,
                            callbacks,
                            gc_after_trial,
                            True,
                            time_start,
                            progress_bar,
                        )
                    )

                # The loop above only inspects futures it had to wait on to free a worker.
                # Trials still in flight when it stopped are checked here: the final batch, or
                # those running when the timeout or stop flag was hit. Otherwise their
                # exceptions would be silently discarded when the executor shuts down.
                for f in futures:
                    f.result()
    finally:
        # Reset even on error so that a later `optimize` call is not rejected as nested.
        study._thread_local.in_optimize_loop = False
        progress_bar.close()


def _optimize_sequential(
    study: "optuna.Study",
    func: "optuna.study.study.ObjectiveFuncType",
    n_trials: int | None,
    timeout: float | None,
    catch: tuple[type[Exception], ...],
    callbacks: Iterable[Callable[["optuna.Study", FrozenTrial], None]] | None,
    gc_after_trial: bool,
    reseed_sampler_rng: bool,
    time_start: datetime.datetime | None,
    progress_bar: pbar_module._ProgressBar | None,
) -> None:
    """Run trials one after another in the current thread.

    The loop ends when any of these holds:

    - the study's stop flag is set;
    - ``n_trials`` trials have been started (when ``n_trials`` is given);
    - ``timeout`` seconds have elapsed since ``time_start`` (when ``timeout`` is given).

    ``time_start`` defaults to "now". The parallel driver passes its own start time so every
    worker measures against the same origin.

    After each trial, every callback gets its own deep copy of the trial as read back from
    storage, and the progress bar (if any) is advanced. The storage session is removed once
    the loop finishes without raising.
    """
    # `in_optimize_loop` is thread-local, so it must be set here, in the thread that actually
    # runs the trials, rather than at the top of `_optimize()`: with `n_jobs` this function runs
    # on pool threads.
    study._thread_local.in_optimize_loop = True
    if reseed_sampler_rng:
        study.sampler.reseed_rng()

    started_trials = 0
    origin = time_start if time_start is not None else datetime.datetime.now()

    def _seconds_since_origin() -> float:
        return (datetime.datetime.now() - origin).total_seconds()

    while not study._stop_flag:
        if n_trials is not None:
            if started_trials >= n_trials:
                break
            started_trials += 1

        if timeout is not None and _seconds_since_origin() >= timeout:
            break

        try:
            trial_id = _run_trial(study, func, catch)
        finally:
            # Collecting garbage after each trial mitigates memory problems seen in some
            # environments (e.g., services that use computing containers such as GitHub Actions).
            # See https://github.com/optuna/optuna/pull/325 for details.
            if gc_after_trial:
                gc.collect()

        if callbacks is not None:
            finished_trial = study._storage.get_trial(trial_id)
            for callback in callbacks:
                # A private copy per callback, so one callback cannot alter what the next sees.
                callback(study, copy.deepcopy(finished_trial))

        if progress_bar is not None:
            progress_bar.update(_seconds_since_origin(), study)

    study._storage.remove_session()


def _run_trial(
    study: "optuna.Study",
    func: "optuna.study.study.ObjectiveFuncType",
    catch: tuple[type[Exception], ...],
) -> int:
    """Ask for one trial, evaluate ``func`` on it, tell the result and log the outcome.

    Args:
        study: Study the trial belongs to.
        func: Objective function evaluated on the trial.
        catch: Exception types that mark the trial as failed without being re-raised.

    Returns:
        The storage id (``trial._trial_id``) of the evaluated trial.

    Raises:
        Exception | KeyboardInterrupt: The error raised by ``func`` when the trial ends up in
            ``TrialState.FAIL`` and the error is not an instance of ``catch``.
        Exception: Any error raised by ``_tell_with_warning``. It is re-raised after the
            state recorded in storage has been logged.
    """
    if is_heartbeat_enabled(study._storage):
        with warnings.catch_warnings():
            # Ignore ExperimentalWarning when using fail_stale_trials internally.
            warnings.simplefilter("ignore", ExperimentalWarning)
            optuna.storages.fail_stale_trials(study)

    trial = study.ask()

    state: TrialState | None = None
    value_or_values: float | Sequence[float] | None = None
    func_err: Exception | KeyboardInterrupt | None = None
    func_err_fail_exc_info: Any | None = None

    with get_heartbeat_thread(trial._trial_id, study._storage):
        try:
            value_or_values = func(trial)
        except exceptions.TrialPruned as e:
            # TODO(mamu): Handle multi-objective cases.
            state = TrialState.PRUNED
            func_err = e
        except (Exception, KeyboardInterrupt) as e:
            # KeyboardInterrupt is captured as well so the trial is still told with FAIL.
            state = TrialState.FAIL
            func_err = e
            func_err_fail_exc_info = sys.exc_info()

    # `_tell_with_warning` may raise during trial post-processing. In that case, log the state
    # the storage ended up with and re-raise. Logging happens on two explicit paths instead of
    # in a `finally` clause. A `finally` would read unbound locals if `_tell_with_warning`
    # raised a BaseException such as KeyboardInterrupt, or if reading the trial back failed.
    # The resulting UnboundLocalError would replace the real error.
    try:
        updated_state, values, warning_message = _tell_with_warning(
            study=study,
            trial=trial,
            value_or_values=value_or_values,
            state=state,
            suppress_warning=True,
        )
    except Exception:
        frozen_trial = study._storage.get_trial(trial._trial_id)
        _log_trial_result(
            study,
            trial,
            frozen_trial.state,
            frozen_trial.values,
            None,
            func_err,
            func_err_fail_exc_info,
            value_or_values,
        )
        raise

    _log_trial_result(
        study,
        trial,
        updated_state,
        values,
        warning_message,
        func_err,
        func_err_fail_exc_info,
        value_or_values,
    )

    if (
        updated_state == TrialState.FAIL
        and func_err is not None
        and not isinstance(func_err, catch)
    ):
        raise func_err
    return trial._trial_id


def _log_trial_result(
    study: "optuna.Study",
    trial: "optuna.Trial",
    updated_state: TrialState,
    values: Sequence[float] | None,
    warning_message: str | Warning | None,
    func_err: Exception | KeyboardInterrupt | None,
    func_err_fail_exc_info: Any | None,
    value_or_values: float | Sequence[float] | None,
) -> None:
    """Log the final outcome of one trial according to its post-tell state.

    Handles the three terminal states:

    - COMPLETE: delegated to ``study._log_completed_trial``.
    - PRUNED: logged at INFO level together with the pruning exception.
    - FAIL: delegated to ``_log_failed_trial``, using the objective's error when there is
      one and the tell warning message otherwise.

    With assertions enabled, any other state, or a FAIL with neither an error nor a
    warning message, raises ``AssertionError``.
    """
    if updated_state == TrialState.COMPLETE:
        assert values is not None
        study._log_completed_trial(values, trial.number, trial.params)
    elif updated_state == TrialState.PRUNED:
        _logger.info("Trial {} pruned. {}".format(trial.number, str(func_err)))
    elif updated_state == TrialState.FAIL:
        if func_err is not None:
            _log_failed_trial(
                trial.number,
                trial.params,
                repr(func_err),
                exc_info=func_err_fail_exc_info,
                value_or_values=value_or_values,
            )
        elif warning_message is not None:
            _log_failed_trial(
                trial.number,
                trial.params,
                warning_message,
                value_or_values=value_or_values,
            )
        else:
            assert False, "Should not reach."
    else:
        assert False, "Should not reach."


def _log_failed_trial(
    trial_number: int,
    trial_params: dict[str, Any],
    message: str | Warning,
    exc_info: Any = None,
    value_or_values: Any = None,
) -> None:
    """Emit the two WARNING records that describe a failed trial.

    Args:
        trial_number: Number of the failed trial, included in both records.
        trial_params: Parameters the trial was evaluated with.
        message: Reason for the failure. Callers in this module pass either the ``repr`` of
            the objective's exception or the message returned by ``_tell_with_warning``.
        exc_info: Forwarded to the first record so the logger can attach exception details.
        value_or_values: Whatever the objective returned; rendered with ``repr``.
    """
    reason = (
        f"Trial {trial_number} failed with parameters: {trial_params} "
        f"because of the following error: {message}."
    )
    _logger.warning(reason, exc_info=exc_info)

    _logger.warning(f"Trial {trial_number} failed with value {value_or_values!r}.")
