Source code for ax.modelbridge.transforms.cast
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Optional, TYPE_CHECKING
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws
if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401
class Cast(Transform):
"""Cast each param value to the respective parameter's type/format and
to a flattened version of the hierarchical search space, if applicable.
This is a default transform that should run across all models.
NOTE: In case where searh space is hierarchical and this transform is
configured to flatten it:
* All calls to `Cast.transform_...` transform Ax objects defined in
terms of hierarchical search space, to their definitions in terms of
flattened search space.
* All calls to `Cast.untransform_...` cast Ax objects back to a
hierarchical search space.
* The hierarchical search space is seen as the "original" search space,
and the flattened search space –– as "transformed".
Transform is done in-place for casting types, but objects are copied
during flattening of- and casting to the hierarchical search space.
"""

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        self.search_space: SearchSpace = none_throws(search_space).clone()
        config = (config or {}).copy()
        self.flatten_hss: bool = checked_cast(
            bool,
            config.pop(
                "flatten_hss", isinstance(search_space, HierarchicalSearchSpace)
            ),
        )
        self.inject_dummy_values_to_complete_flat_parameterization: bool = checked_cast(
            bool,
            config.pop("inject_dummy_values_to_complete_flat_parameterization", True),
        )
        self.use_random_dummy_values: bool = checked_cast(
            bool, config.pop("use_random_dummy_values", False)
        )
        if config:
            raise UserInputError(
                f"Unexpected config parameters for `Cast` transform: {config}."
            )

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Flattens the hierarchical search space and returns the flat
        ``SearchSpace`` if this transform is configured to flatten hierarchical
        search spaces. Does nothing if the search space is not hierarchical.

        NOTE: All calls to `Cast.transform_...` transform Ax objects defined in
        terms of the hierarchical search space to their definitions in terms of
        the flattened search space. All calls to `Cast.untransform_...` cast Ax
        objects back to the hierarchical search space.

        Args:
            search_space: The search space to flatten.

        Returns: transformed search space.
        """
        if not self.flatten_hss:
            return search_space
        return checked_cast(HierarchicalSearchSpace, search_space).flatten()
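
A minimal usage sketch, assuming the constructor and ``_transform_search_space``
shown above and the standard Ax ``SearchSpace``/``RangeParameter`` APIs; the toy
search space below is illustrative and not part of this module. In normal use the
transform is instantiated by the modelbridge rather than by user code.

# Sketch (not part of the module): constructing `Cast` directly with an
# explicit config.
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.cast import Cast

search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)
cast = Cast(
    search_space=search_space,
    config={
        # Only the three keys popped in `__init__` are accepted; any other
        # key is left behind in `config` and triggers the `UserInputError`
        # raised above.
        "flatten_hss": False,
        "inject_dummy_values_to_complete_flat_parameterization": True,
        "use_random_dummy_values": False,
    },
)
# With `flatten_hss=False`, a non-hierarchical search space passes through
# `_transform_search_space` unchanged.
assert cast._transform_search_space(search_space) is search_space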