Source code for ax.modelbridge.transforms.int_to_float
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from logging import Logger
from typing import Optional, TYPE_CHECKING
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.rounding import (
contains_constrained_integer,
randomized_round_parameters,
)
from ax.modelbridge.transforms.utils import construct_new_search_space
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
# Module-level logger, obtained from Ax's `get_logger` with this module's name.
logger: Logger = get_logger(__name__)
# Default for the "max_round_attempts" config option of `IntToFloat`: the
# maximum number of times randomized rounding is retried for one observation
# before the transform gives up and falls back to strict rounding.
DEFAULT_MAX_ROUND_ATTEMPTS = 10_000
[docs]
class IntToFloat(Transform):
    """Convert a RangeParameter of type int to type float.

    Uses either randomized_rounding or default python rounding,
    depending on 'rounding' flag.

    The `min_choices` config can be used to transform only the parameters
    with cardinality greater than or equal to `min_choices`; with the exception
    of `log_scale` parameters, which are always transformed.

    Recognized ``config`` keys:
        rounding: ``"strict"`` (the default) untransforms with
            ``int(round(x))``; any other value selects randomized rounding.
            The value is overridden to ``"randomized"`` when
            ``contains_constrained_integer`` reports that the transformed
            parameters take part in parameter constraints.
        max_round_attempts: Maximum number of randomized re-rounding attempts
            per observation (default ``DEFAULT_MAX_ROUND_ATTEMPTS``).
        min_choices: Minimum cardinality for a non-log-scale integer range
            parameter to be transformed (default ``0``).

    Transform is done in-place.
    """

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        """Read the config and select the integer parameters to transform.

        Args:
            search_space: Search space whose integer range parameters are
                converted. Required; ``none_throws`` rejects ``None``.
            observations: Not used by this transform.
            modelbridge: Not used by this transform.
            config: Optional settings; see the class docstring for the keys.
        """
        self.search_space: SearchSpace = none_throws(
            search_space, "IntToFloat requires search space"
        )
        config = config or {}
        self.rounding: str = checked_cast(str, config.get("rounding", "strict"))
        self.max_round_attempts: int = checked_cast(
            int, config.get("max_round_attempts", DEFAULT_MAX_ROUND_ATTEMPTS)
        )
        self.min_choices: int = checked_cast(int, config.get("min_choices", 0))

        # Identify parameters that should be transformed
        self.transform_parameters: set[str] = {
            p_name
            for p_name, p in self.search_space.parameters.items()
            if isinstance(p, RangeParameter)
            and p.parameter_type == ParameterType.INT
            and ((p.cardinality() >= self.min_choices) or p.log_scale)
        }

        # Independent strict rounding cannot be retried against the parameter
        # constraints, so constrained integers always use randomized rounding.
        self.contains_constrained_integer: bool = bool(
            contains_constrained_integer(
                self.search_space, self.transform_parameters
            )
        )
        if self.contains_constrained_integer:
            self.rounding = "randomized"

    def transform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Cast every transformed parameter that is present to ``float``.

        The observation features are modified in place and the same list is
        returned.
        """
        for obsf in observation_features:
            for p_name in self.transform_parameters:
                if p_name in obsf.parameters:
                    # pyre: param is declared to have type `int` but is used
                    # pyre-fixme[9]: as type `Optional[typing.Union[bool, float, str]]`.
                    param: int = obsf.parameters[p_name]
                    obsf.parameters[p_name] = float(param)
        return observation_features

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a search space with the selected int ranges turned into floats.

        Parameters that are not selected for transformation are passed through
        unchanged; parameter constraints are cloned against the new parameters.
        """
        transformed_parameters: dict[str, Parameter] = {}
        for p_name, p in search_space.parameters.items():
            if p_name in self.transform_parameters and isinstance(p, RangeParameter):
                transformed_parameters[p_name] = RangeParameter(
                    name=p_name,
                    parameter_type=ParameterType.FLOAT,
                    # +/- 0.5 ensures that sampling
                    # 1) floating point numbers from (quasi-)Uniform(0,1)
                    # 2) unnormalizing to the raw search space
                    # 3) rounding
                    # results in uniform (quasi-)random integers
                    lower=p.lower - 0.49999,
                    upper=p.upper + 0.49999,
                    log_scale=p.log_scale,
                    digits=p.digits,
                    is_fidelity=p.is_fidelity,
                    target_value=p.target_value,  # casting happens in constructor
                )
            else:
                transformed_parameters[p.name] = p
        return construct_new_search_space(
            search_space=search_space,
            parameters=list(transformed_parameters.values()),
            parameter_constraints=[
                pc.clone_with_transformed_parameters(
                    transformed_parameters=transformed_parameters
                )
                for pc in search_space.parameter_constraints
            ],
        )

    @staticmethod
    def _strict_round(obsf: ObservationFeatures, p_names: set[str]) -> None:
        """Round the named parameters of ``obsf`` in place with ``int(round(x))``."""
        for p_name in p_names:
            # pyre: param is declared to have type `float` but is used as
            # pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`.
            param: float = obsf.parameters[p_name]
            obsf.parameters[p_name] = int(round(param))  # TODO: T41938776

    def untransform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Round the transformed parameters back to integers, in place.

        With ``rounding == "strict"`` each present parameter is rounded
        independently. Otherwise randomized rounding is retried up to
        ``max_round_attempts`` times until the rounded point is a member of
        the original search space; if no attempt succeeds, a warning is logged
        and strict rounding is used as a fallback (the result may then violate
        parameter constraints).

        Raises:
            ValueError: If integer parameters are involved in parameter
                constraints and only some of the transformed parameters are
                present in an observation.
        """
        for obsf in observation_features:
            present_params = self.transform_parameters.intersection(
                obsf.parameters.keys()
            )
            if not present_params:
                # Nothing to round (e.g. fixed features that do not mention
                # any transformed parameter).
                continue

            if self.rounding == "strict":
                self._strict_round(obsf, present_params)
                continue

            if self.contains_constrained_integer and len(present_params) != len(
                self.transform_parameters
            ):
                # no parameters being present is allowed to handle fixed
                # features, but all parameters must be present if there
                # are parameter constraints involving integers.
                raise ValueError(
                    "Either all parameters must be provided or no parameters"
                    " should be provided, when there are parameter"
                    " constraints involving integers."
                )

            # Only round parameters that are actually present; asking the
            # helper to round absent ones would operate on missing values.
            rounded_parameters = randomized_round_parameters(
                obsf.parameters, present_params
            )
            is_member = self.search_space.check_membership(
                rounded_parameters, check_all_parameters_present=False
            )
            round_attempts = 0
            # Try to re-round up to max_round_attempts times.
            while not is_member and round_attempts < self.max_round_attempts:
                rounded_parameters = randomized_round_parameters(
                    obsf.parameters, present_params
                )
                is_member = self.search_space.check_membership(
                    rounded_parameters, check_all_parameters_present=False
                )
                round_attempts += 1

            if is_member:
                # Update observation if rounding was successful
                for p_name in present_params:
                    obsf.parameters[p_name] = rounded_parameters[p_name]
            else:
                logger.warning(
                    f"Unable to round {obsf.parameters} "
                    f"to meet parameter constraints of {self.search_space}"
                )
                # This means we failed to randomly round the observation to
                # something that satisfies the search space bounds and parameter
                # constraints. We use strict rounding in order to get a candidate
                # that satisfies the search space bounds, but this candidate may
                # not satisfy the parameter constraints.
                self._strict_round(obsf, present_params)
        return observation_features