Source code for ax.modelbridge.transforms.metadata_to_float

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from logging import Logger
from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING

from ax.core import ParameterType

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401


logger: Logger = get_logger(__name__)


[docs] class MetadataToFloat(Transform): """ This transform converts metadata from observation features into range (float) parameters for a search space. It allows the user to specify the `config` with `parameters` as the key, where each entry maps a metadata key to a dictionary of keyword arguments for the corresponding RangeParameter constructor. Transform is done in-place. """ DEFAULT_LOG_SCALE: bool = False DEFAULT_LOGIT_SCALE: bool = False DEFAULT_IS_FIDELITY: bool = False ENFORCE_BOUNDS: bool = False def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, config: TConfig | None = None, ) -> None: if observations is None or not observations: raise DataRequiredError( "`MetadataToRange` transform requires non-empty data." ) config = config or {} self.parameters: dict[str, dict[str, Any]] = assert_is_instance( config.get("parameters", {}), dict ) self._parameter_list: list[RangeParameter] = [] for name in self.parameters: values: list[float] = [] for obs in observations: obsf_metadata = none_throws(obs.features.metadata) value = float(assert_is_instance(obsf_metadata[name], SupportsFloat)) values.append(value) lower: float = self.parameters[name].get("lower", min(values)) upper: float = self.parameters[name].get("upper", max(values)) log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) logit_scale = self.parameters[name].get( "logit_scale", self.DEFAULT_LOGIT_SCALE ) digits = self.parameters[name].get("digits") is_fidelity = self.parameters[name].get( "is_fidelity", self.DEFAULT_IS_FIDELITY ) target_value = self.parameters[name].get("target_value") parameter = RangeParameter( name=name, parameter_type=ParameterType.FLOAT, lower=lower, upper=upper, log_scale=log_scale, logit_scale=logit_scale, digits=digits, is_fidelity=is_fidelity, target_value=target_value, ) self._parameter_list.append(parameter) def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: for parameter in self._parameter_list: search_space.add_parameter(parameter.clone()) return search_space
[docs] def transform_observation_features( self, observation_features: list[ObservationFeatures] ) -> list[ObservationFeatures]: for obsf in observation_features: self._transform_observation_feature(obsf) return observation_features
[docs] def untransform_observation_features( self, observation_features: list[ObservationFeatures] ) -> list[ObservationFeatures]: for obsf in observation_features: obsf.metadata = obsf.metadata or {} _transfer( src=obsf.parameters, dst=obsf.metadata, keys=self.parameters.keys(), ) return observation_features
def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: _transfer( src=none_throws(obsf.metadata), dst=obsf.parameters, keys=self.parameters.keys(), )
def _transfer( src: dict[str, Any], dst: dict[str, Any], keys: Iterable[str], ) -> None: """Transfer items in-place from one dictionary to another.""" for key in keys: dst[key] = src.pop(key)