# Source code for ax.modelbridge.transforms.cast
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, TYPE_CHECKING
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401 # pragma: no cover
class Cast(Transform):
    """Cast each param value to the respective parameter's type/format and
    to a flattened version of the hierarchical search space, if applicable.

    This is a default transform that should run across all models.

    NOTE: In case where search space is hierarchical and this transform is
    configured to flatten it:
        * All calls to `Cast.transform_...` transform Ax objects defined in
          terms of hierarchical search space, to their definitions in terms of
          flattened search space.
        * All calls to `Cast.untransform_...` cast Ax objects back to a
          hierarchical search space.
        * The hierarchical search space is seen as the "original" search
          space, and the flattened search space — as "transformed".

    Transform is done in-place for casting types, but objects are copied
    during flattening of- and casting to the hierarchical search space.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: Optional[List[ObservationFeatures]] = None,
        observation_data: Optional[List[ObservationData]] = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        """Store a copy of the search space and decide whether to flatten it.

        Args:
            search_space: The search space this transform operates on. A
                clone of it (via ``search_space.clone()``) is stored on
                ``self.search_space``.
            observation_features: Not used by this constructor.
            observation_data: Not used by this constructor.
            modelbridge: Not used by this constructor.
            config: Optional transform config. Only the ``"flatten_hss"`` key
                is read here; it defaults to ``True`` when absent (or when
                ``config`` is ``None``). Its value is passed through
                ``checked_cast(bool, ...)``, which presumably raises if the
                value is not a ``bool``.
        """
        self.search_space = search_space.clone()
        # Flattening is only ever enabled for hierarchical search spaces. For
        # those it is on by default and can be switched off through
        # ``config["flatten_hss"]``. For any other search space the
        # ``isinstance`` check forces this to ``False``.
        self.flatten_hss: bool = (
            config is None or checked_cast(bool, config.get("flatten_hss", True))
        ) and isinstance(search_space, HierarchicalSearchSpace)

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Flatten a hierarchical search space if this transform is configured
        to flatten hierarchical search spaces.

        NOTE: All calls to `Cast.transform_...` transform Ax objects defined in
        terms of hierarchical search space, to their definitions in terms of
        flattened search space. All calls to `Cast.untransform_...` cast Ax
        objects back to a hierarchical search space.

        Args:
            search_space: The search space to flatten.

        Returns:
            ``search_space`` itself, unchanged, when ``self.flatten_hss`` is
            ``False``. This is always the case when the search space given to
            the constructor was not hierarchical. Otherwise, the result of
            ``HierarchicalSearchSpace.flatten()`` on ``search_space``.

        NOTE: When ``self.flatten_hss`` is ``True``, ``search_space`` must be
        a ``HierarchicalSearchSpace``. A non-hierarchical argument is not
        passed through; it is rejected by ``checked_cast`` (presumably by
        raising).
        """
        if not self.flatten_hss:
            return search_space
        return checked_cast(HierarchicalSearchSpace, search_space).flatten()