Source code for ax.modelbridge.transforms.map_key_to_float

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Optional, TYPE_CHECKING

from ax.core.map_metric import MapMetric
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat
from ax.models.types import TConfig
from pyre_extensions import assert_is_instance

if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401


[docs] class MapKeyToFloat(MetadataToFloat): """ This transform extracts the entry from the metadata field of the observation features corresponding to the default map key (`MapMetric.map_key_info.key`) and inserts it into the parameter field. Inheriting from the `MetadataToFloat` transform, this transform also adds a range (float) parameter to the search space. Similarly, users can override the default behavior by specifying the `config` with `parameters` as the key, where each entry maps a metadata key to a dictionary of keyword arguments for the corresponding RangeParameter constructor. Transform is done in-place. """ DEFAULT_LOG_SCALE: bool = True DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, config: TConfig | None = None, ) -> None: config = config or {} self.parameters: dict[str, dict[str, Any]] = assert_is_instance( config.setdefault("parameters", {}), dict ) self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) super().__init__( search_space=search_space, observations=observations, modelbridge=modelbridge, config=config, ) def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: if not obsf.parameters: for p in self._parameter_list: obsf.parameters[p.name] = p.upper return super()._transform_observation_feature(obsf)