# Source code for ax.modelbridge.transforms.map_unit_x
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.unit_x import UnitX
from ax.models.types import TConfig
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
class MapUnitX(UnitX):
    """A `UnitX` transform for map parameters in observation_features, identified
    as those that are not part of the search space. Since they are not part of the
    search space, the bounds are inferred from the set of observation features. Only
    observation features are transformed; all other objects undergo identity transform.
    """

    target_lb: float = 0.0
    target_range: float = 1.0

    def __init__(
        self,
        search_space: Optional[SearchSpace] = None,
        observations: Optional[List[Observation]] = None,
        modelbridge: Optional[modelbridge_module.base.ModelBridge] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        """Infer bounds for every parameter that appears in the observation
        features but is not defined in the search space.

        Args:
            search_space: Search space whose parameter names decide which observed
                parameters are treated as map parameters. Required.
            observations: Observations whose feature values are scanned to infer
                the (min, max) bounds of each map parameter. Required.
            modelbridge: Not used by this transform.
            config: Not used by this transform.

        Raises:
            AssertionError: If ``observations`` or ``search_space`` is None.
        """
        assert observations is not None, "MapUnitX requires observations"
        assert search_space is not None, "MapUnitX requires search space"
        # `UnitX.__init__` is deliberately not called; presumably because it would
        # derive bounds from the search space, which does not contain the map
        # parameters — TODO confirm against `UnitX`.
        known_parameters = search_space.parameters
        # Collect every observed value of each parameter that is absent from the
        # search space; these values are used to infer the bounds below.
        map_values = defaultdict(list)
        for obs in observations:
            for name, value in obs.features.parameters.items():
                if name not in known_parameters:
                    map_values[name].append(value)

        # Bounds are the (min, max) over all observed values of each map
        # parameter. NOTE(review): a map parameter observed with a single value
        # yields min == max (a zero-width range) — confirm `UnitX` handles that.
        self.bounds: Dict[str, Tuple[float, float]] = {
            name: (min(values), max(values)) for name, values in map_values.items()
        }

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return ``search_space`` unchanged (identity transform).

        Map parameters are not part of the search space, so there is nothing
        to rescale here.
        """
        return search_space

    def _transform_parameter_distributions(self, search_space: SearchSpace) -> None:
        """Delegate to the implementation that follows `UnitX` in the MRO.

        ``super(UnitX, self)`` deliberately skips `UnitX`'s own override of this
        method (presumably because that override rescales search-space parameters,
        which map parameters are not — TODO confirm against `UnitX`).
        """
        return super(UnitX, self)._transform_parameter_distributions(
            search_space=search_space
        )