Source code for ax.models.discrete.full_factorial

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
from functools import reduce
from operator import mul
from typing import Optional

import numpy as np
from ax.core.types import TGenMetadata, TParamValue, TParamValueList
from ax.models.discrete_base import DiscreteModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger


logger: logging.Logger = get_logger(__name__)


class FullFactorialGenerator(DiscreteModel):
    """Generator for full factorial designs.

    Generates arms for all possible combinations of parameter values, each
    with weight 1.

    The value of n supplied to `gen` will be ignored, as the number of arms
    generated is determined by the list of parameter values. To suppress
    this warning, use n = -1.
    """

    def __init__(
        self, max_cardinality: int = 100, check_cardinality: bool = True
    ) -> None:
        """
        Args:
            max_cardinality: maximum number of arms allowed if
                check_cardinality == True. Default is 100.
            check_cardinality: if True, throw if number of arms exceeds
                max_cardinality.
        """
        super().__init__()
        self.max_cardinality = max_cardinality
        self.check_cardinality = check_cardinality

    # NOTE: no literal docstring here on purpose -- the docstring is supplied
    # by the `copy_doc` decorator from `DiscreteModel.gen`.
    #
    # Behavior specific to this generator:
    #   * `n` is ignored; any value other than -1 only triggers a warning.
    #   * Each index in `fixed_features` collapses the corresponding parameter
    #     to the single fixed value before the product is taken.
    #   * Raises ValueError when `check_cardinality` is set and the number of
    #     combinations exceeds `max_cardinality`.
    #   * Returns (points, weights, gen_metadata): every combination of the
    #     parameter values, a weight of 1.0 per point, and an empty dict.
    @copy_doc(DiscreteModel.gen)
    def gen(
        self,
        n: int,
        parameter_values: list[TParamValueList],
        objective_weights: Optional[np.ndarray],
        outcome_constraints: Optional[tuple[np.ndarray, np.ndarray]] = None,
        fixed_features: Optional[dict[int, TParamValue]] = None,
        pending_observations: Optional[list[list[TParamValueList]]] = None,
        model_gen_options: Optional[TConfig] = None,
    ) -> tuple[list[TParamValueList], list[float], TGenMetadata]:
        if n != -1:
            # As documented on the class, n is ignored rather than rejected:
            # warn and carry on instead of failing the whole generation call.
            logger.warning(
                "FullFactorialGenerator will ignore the specified value of n. "
                "The generator automatically determines how many arms to "
                "generate."
            )

        if fixed_features:
            # Work on a shallow copy so the caller's list is not modified
            # when fixed values replace a parameter's candidate values.
            parameter_values = list(parameter_values)
            for fixed_feature_index, fixed_feature_value in fixed_features.items():
                parameter_values[fixed_feature_index] = [fixed_feature_value]

        # Size of the cartesian product, computed before materializing it so
        # that an oversized design is rejected without building every arm.
        num_arms = reduce(mul, [len(values) for values in parameter_values], 1)
        if self.check_cardinality and num_arms > self.max_cardinality:
            raise ValueError(
                f"FullFactorialGenerator generated {num_arms} arms, "
                f"but the maximum number of arms allowed is "
                f"{self.max_cardinality}."
            )

        points = [list(x) for x in itertools.product(*parameter_values)]
        return (points, [1.0] * len(points), {})