# Source code for ax.modelbridge.transforms.remove_fixed
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Optional, TYPE_CHECKING
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import construct_new_search_space
from ax.models.types import TConfig
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
[docs]
class RemoveFixed(Transform):
"""Remove fixed parameters.
Fixed parameters should not be included in the SearchSpace.
This transform removes these parameters, leaving only tunable parameters.
Transform is done in-place for observation features.
"""
def __init__(
    self,
    search_space: SearchSpace | None = None,
    observations: list[Observation] | None = None,
    modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
    config: TConfig | None = None,
) -> None:
    """Record which parameters of ``search_space`` are fixed.

    Args:
        search_space: Search space scanned for ``FixedParameter``
            instances. Must not be ``None``.
        observations: Not read by this constructor.
        modelbridge: Not read by this constructor.
        config: Not read by this constructor.

    Raises:
        AssertionError: If ``search_space`` is ``None``.
    """
    assert search_space is not None, "RemoveFixed requires search space"
    # Collect every FixedParameter, keyed by the name under which it is
    # registered in the search space.
    fixed: dict[str, FixedParameter] = {}
    for name, param in search_space.parameters.items():
        if isinstance(param, FixedParameter):
            fixed[name] = param
    self.fixed_parameters: dict[str, FixedParameter] = fixed
def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
    """Build a search space without the recorded fixed parameters.

    Args:
        search_space: Search space to filter.

    Returns:
        The result of ``construct_new_search_space`` called with the
        parameters whose names are not in ``self.fixed_parameters`` and
        with clones of the original parameter constraints.
    """
    # Anything whose name was not recorded as fixed is treated as tunable.
    # pyre: the elements are declared as `Union[ChoiceParameter,
    # pyre: RangeParameter]` but are used as type `ax.core.
    # pyre-fixme[9]: parameter.Parameter`.
    kept: list[ChoiceParameter | RangeParameter] = [
        param
        for param in search_space.parameters.values()
        if param.name not in self.fixed_parameters
    ]
    # Constraints are cloned, not passed through as the same objects.
    cloned_constraints = [
        constraint.clone() for constraint in search_space.parameter_constraints
    ]
    return construct_new_search_space(
        search_space=search_space,
        # pyre-ignore Incompatible parameter type [6]
        parameters=kept,
        parameter_constraints=cloned_constraints,
    )