Source code for ax.modelbridge.transforms.cap_parameter
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, List, Optional
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401 # pragma: no cover
[docs]class CapParameter(Transform):
"""Cap parameter range(s) to given values. Expects a configuration of form
{ parameter_name -> new_upper_range_value }.
This transform only transforms the search space.
"""
def __init__(
    self,
    search_space: SearchSpace,
    observation_features: List[ObservationFeatures],
    observation_data: List[ObservationData],
    modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
    config: Optional[TConfig] = None,
) -> None:
    """Store the capping config and select which parameters it applies to.

    Args:
        search_space: Search space whose parameter names are matched
            against the keys of ``config``.
        observation_features: Not read by this constructor (presumably
            accepted to match the ``Transform`` constructor signature).
        observation_data: Not read by this constructor.
        modelbridge: Not read by this constructor.
        config: Mapping of ``{parameter_name -> new_upper_range_value}``.
            A falsy value (``None`` or an empty mapping) is replaced by a
            fresh empty dict.
    """
    self.config = config if config else {}
    # Restrict the transform to parameters that exist in the search space
    # AND are named in the config; config keys for unknown parameters are
    # simply ignored here.
    capped_names = self.config
    self.transform_parameters = set(
        filter(lambda name: name in capped_names, search_space.parameters)
    )