Source code for ax.modelbridge.transforms.int_to_float
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Dict, List, Optional, Set
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.rounding import randomized_round
class IntToFloat(Transform):
    """Convert a RangeParameter of type int to type float.

    On the way back (untransform), float values are converted to ints using
    either ``randomized_round`` or Python's built-in ``round``, depending on
    the ``"rounding"`` entry of ``config``: the value ``"strict"`` (the
    default) selects ``int(round(value))``; any other value selects
    ``randomized_round``.

    Observation features are modified in-place; ``transform_search_space``
    builds and returns a new ``SearchSpace``.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        """Record the rounding mode and the names of int range parameters.

        Args:
            search_space: Search space whose parameters are inspected to
                decide which ones this transform applies to.
            observation_features: Not used by this transform.
            observation_data: Not used by this transform.
            config: Optional settings; only the ``"rounding"`` key is read
                (defaults to ``"strict"``).
        """
        self.rounding = "strict"
        if config is not None:
            self.rounding = config.get("rounding", "strict")
        # Identify parameters that should be transformed
        self.transform_parameters: Set[str] = {
            p_name
            for p_name, p in search_space.parameters.items()
            if isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT
        }

    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Cast the values of transformed parameters to ``float``, in-place.

        Parameters that are not present on a given observation are skipped.

        Args:
            observation_features: Observation features to modify.

        Returns:
            The same list object that was passed in.
        """
        for obsf in observation_features:
            for p_name in self.transform_parameters:
                if p_name in obsf.parameters:
                    # pyre: param is declared to have type `int` but is used
                    # pyre-fixme[9]: as type `Optional[typing.Union[bool, float, str]]`.
                    param: int = obsf.parameters[p_name]
                    obsf.parameters[p_name] = float(param)
        return observation_features

    def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a new search space with int range parameters made float.

        Each parameter selected in ``__init__`` is replaced by a
        ``RangeParameter`` of type ``FLOAT`` that keeps the original name,
        bounds, ``log_scale`` and ``digits``. All other parameters are reused
        as-is. Parameter constraints are cloned against the new parameters.

        Args:
            search_space: Search space to convert.

        Returns:
            A newly constructed ``SearchSpace``.
        """
        transformed_parameters: Dict[str, Parameter] = {}
        for p in search_space.parameters.values():
            # Refine type, since we've only added RangeParameters above.
            if p.name in self.transform_parameters:
                # pyre: p_cast is declared to have type `RangeParameter` but
                # pyre-fixme[9]: is used as type `Parameter`.
                p_cast: RangeParameter = p
                transformed_parameters[p.name] = RangeParameter(
                    name=p_cast.name,
                    parameter_type=ParameterType.FLOAT,
                    lower=p_cast.lower,
                    upper=p_cast.upper,
                    log_scale=p_cast.log_scale,
                    digits=p_cast.digits,
                )
            else:
                transformed_parameters[p.name] = p
        return SearchSpace(
            parameters=list(transformed_parameters.values()),
            parameter_constraints=[
                pc.clone_with_transformed_parameters(
                    transformed_parameters=transformed_parameters
                )
                for pc in search_space.parameter_constraints
            ],
        )

    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Round the values of transformed parameters back to ints, in-place.

        Parameters that are not present on a given observation are skipped,
        mirroring ``transform_observation_features``; without this guard a
        missing parameter would be read as ``None`` and rounding it would
        raise ``TypeError``.

        Args:
            observation_features: Observation features to modify.

        Returns:
            The same list object that was passed in.
        """
        for obsf in observation_features:
            for p_name in self.transform_parameters:
                if p_name not in obsf.parameters:
                    # Nothing to round for this observation; do not add the key.
                    continue
                # pyre: param is declared to have type `float` but is used as
                # pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`.
                param: float = obsf.parameters[p_name]
                if self.rounding == "strict":
                    obsf.parameters[p_name] = int(round(param))  # TODO: T41938776
                else:
                    obsf.parameters[p_name] = randomized_round(param)
        return observation_features