Source code for ax.models.discrete.full_factorial

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
from functools import reduce
from operator import mul

import numpy.typing as npt
from ax.core.types import TGenMetadata, TParamValue, TParamValueList
from ax.models.discrete_base import DiscreteModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger


logger: logging.Logger = get_logger(__name__)


[docs] class FullFactorialGenerator(DiscreteModel): """Generator for full factorial designs. Generates arms for all possible combinations of parameter values, each with weight 1. The value of n supplied to `gen` will be ignored, as the number of arms generated is determined by the list of parameter values. To suppress this warning, use n = -1. """ def __init__( self, max_cardinality: int = 100, check_cardinality: bool = True ) -> None: """ Args: max_cardinality: maximum number of arms allowed if check_cardinality == True. Default is 100. check_cardinality: if True, throw if number of arms exceeds max_cardinality. """ super().__init__() self.max_cardinality = max_cardinality self.check_cardinality = check_cardinality
[docs] @copy_doc(DiscreteModel.gen) def gen( self, n: int, parameter_values: list[TParamValueList], objective_weights: npt.NDArray | None, outcome_constraints: tuple[npt.NDArray, npt.NDArray] | None = None, fixed_features: dict[int, TParamValue] | None = None, pending_observations: list[list[TParamValueList]] | None = None, model_gen_options: TConfig | None = None, ) -> tuple[list[TParamValueList], list[float], TGenMetadata]: if n != -1: logger.warning( "FullFactorialGenerator will ignore the specified value of n. " "The generator automatically determines how many arms to " "generate." ) if fixed_features: for fixed_feature_index, fixed_feature_value in fixed_features.items(): parameter_values[fixed_feature_index] = [fixed_feature_value] num_arms = reduce(mul, [len(values) for values in parameter_values], 1) if self.check_cardinality and num_arms > self.max_cardinality: raise ValueError( f"FullFactorialGenerator generated {num_arms} arms, " f"but the maximum number of arms allowed is " f"{self.max_cardinality}." ) points = [list(x) for x in itertools.product(*parameter_values)] return (points, [1.0 for _ in range(len(points))], {})