import numpy as np

from .base_restricted_module import BaseRestrictedModule


class RestrictedNumpy(BaseRestrictedModule):
    """Expose a whitelisted subset of the ``numpy`` namespace.

    Every whitelisted name that the installed NumPy actually provides is
    installed on the instance, passed through ``self._wrap_function``
    (supplied by ``BaseRestrictedModule``). Any other attribute access is
    rejected with ``AttributeError``.
    """

    def __init__(self):
        # NOTE(review): BaseRestrictedModule.__init__ is not called here;
        # confirm the base class needs no initialisation of its own.

        # Kept as a public, per-instance list for backward compatibility:
        # external code may read or extend it.
        self.allowed_attributes = [
            # Array creation
            "array",
            "zeros",
            "ones",
            "empty",
            "full",
            "zeros_like",
            "ones_like",
            "empty_like",
            "full_like",
            "eye",
            "identity",
            "diag",
            "arange",
            "linspace",
            "logspace",
            "geomspace",
            "fromfunction",
            "fromiter",
            # Array manipulation
            "reshape",
            "ravel",
            "flatten",
            "moveaxis",
            "rollaxis",
            "swapaxes",
            "transpose",
            "split",
            "hsplit",
            "vsplit",
            "dsplit",
            "stack",
            "column_stack",
            "dstack",
            "row_stack",
            "concatenate",
            "vstack",
            "hstack",
            "tile",
            "repeat",
            # Mathematical operations
            "add",
            "subtract",
            "multiply",
            "divide",
            "power",
            "mod",
            "remainder",
            "divmod",
            "negative",
            "positive",
            "absolute",
            "fabs",
            "rint",
            "floor",
            "ceil",
            "trunc",
            "exp",
            "expm1",
            "exp2",
            "log",
            "log10",
            "log2",
            "log1p",
            "sqrt",
            "square",
            "cbrt",
            "reciprocal",
            # Trigonometric functions
            "sin",
            "cos",
            "tan",
            "arcsin",
            "arccos",
            "arctan",
            "arctan2",
            "hypot",
            "sinh",
            "cosh",
            "tanh",
            "arcsinh",
            "arccosh",
            "arctanh",
            "deg2rad",
            "rad2deg",
            # Statistical functions
            "mean",
            "average",
            "median",
            "std",
            "var",
            "min",
            "max",
            "argmin",
            "argmax",
            "sum",
            "prod",
            "percentile",
            "quantile",
            "histogram",
            "histogram2d",
            "histogramdd",
            "bincount",
            "digitize",
            # Linear algebra
            "dot",
            "vdot",
            "inner",
            "outer",
            "matmul",
            "tensordot",
            "einsum",
            "trace",
            "diagonal",
            # Sorting and searching
            "sort",
            "argsort",
            "partition",
            "argpartition",
            "searchsorted",
            "nonzero",
            "where",
            "extract",
            # Logic functions
            "all",
            "any",
            "greater",
            "greater_equal",
            "less",
            "less_equal",
            "equal",
            "not_equal",
            "logical_and",
            "logical_or",
            "logical_not",
            "logical_xor",
            "isfinite",
            "isinf",
            "isnan",
            "isneginf",
            "isposinf",
            # Set operations
            "unique",
            "intersect1d",
            "union1d",
            "setdiff1d",
            "setxor1d",
            # Basic array information
            "shape",
            "size",
            "ndim",
            "dtype",
            # Utility functions
            "clip",
            "round",
            "sign",
            "conj",
            "real",
            "imag",
            "copy",
            "asarray",
            "asanyarray",
            "ascontiguousarray",
            "asfortranarray",
        ]

        # Only install names the installed NumPy actually exports. Names it
        # lacks are skipped here and later fall through to __getattr__.
        for attr in self.allowed_attributes:
            if hasattr(np, attr):
                setattr(self, attr, self._wrap_function(getattr(np, attr)))

    def __getattr__(self, name):
        """Handle attribute lookups that found nothing on the instance or class.

        Python calls this only after normal lookup fails. The wrapped
        functions installed by ``__init__`` therefore never reach it.

        Args:
            name: The attribute being looked up.

        Returns:
            ``getattr(np, name)`` for a whitelisted name. Such a name only
            gets here if ``__init__`` did not install it, in which case NumPy
            normally raises its own ``AttributeError``.

        Raises:
            AttributeError: If ``name`` is not whitelisted. This includes
                every name looked up on an instance that was never initialised.
        """
        # Fetch the whitelist without re-entering __getattr__. copy.copy,
        # copy.deepcopy and pickle create the instance via cls.__new__,
        # skipping __init__, and then probe attributes such as
        # ``__setstate__`` on it. With plain ``self.allowed_attributes``, that
        # probe would look up the missing list through __getattr__ again and
        # recurse until RecursionError instead of reporting AttributeError.
        try:
            allowed = object.__getattribute__(self, "allowed_attributes")
        except AttributeError:
            allowed = ()
        if name not in allowed:
            raise AttributeError(f"'{name}' is not allowed in RestrictedNumPy")
        return getattr(np, name)
