#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Set, TYPE_CHECKING

from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import Parameter, ParameterType, RangeParameter
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.rounding import (
contains_constrained_integer,
randomized_round_parameters,
)
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger

if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401  # pragma: no cover


logger = get_logger(__name__)
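
# Cap on re-rounding attempts when randomized rounding must produce an
# integer point that satisfies the parameter constraints.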
DEFAULT_MAX_ROUND_ATTEMPTS = 10000


class IntToFloat(Transform):
"""Convert a RangeParameter of type int to type float.

    Uses either randomized rounding or default Python rounding,
    depending on the ``rounding`` config flag. If any integer parameter
    appears in a parameter constraint, randomized rounding is always used,
    since strict rounding could yield constraint-violating points.

    Transform is done in-place.
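
    Example (a minimal usage sketch; assumes the ``transform_search_space``
    entry point inherited from ``Transform``, with empty observation lists):

        >>> ss = SearchSpace(
        ...     parameters=[
        ...         RangeParameter(
        ...             name="x", parameter_type=ParameterType.INT, lower=1, upper=5
        ...         )
        ...     ]
        ... )
        >>> t = IntToFloat(
        ...     search_space=ss, observation_features=[], observation_data=[]
        ... )
        >>> tf_ss = t.transform_search_space(ss)
        >>> tf_ss.parameters["x"].parameter_type == ParameterType.FLOAT
        True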
"""
def __init__(
self,
search_space: SearchSpace,
observation_features: List[ObservationFeatures],
observation_data: List[ObservationData],
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
config: Optional[TConfig] = None,
) -> None:
self.search_space = search_space
self.rounding = "strict"
if config is not None:
self.rounding = config.get("rounding", "strict")
self.max_round_attempts = config.get(
"max_round_attempts", DEFAULT_MAX_ROUND_ATTEMPTS
)
else:
self.max_round_attempts = DEFAULT_MAX_ROUND_ATTEMPTS
# Identify parameters that should be transformed
self.transform_parameters: Set[str] = {
p_name
for p_name, p in search_space.parameters.items()
if isinstance(p, RangeParameter) and p.parameter_type == ParameterType.INT
}
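        # Strict rounding can map a feasible relaxed point to an integer
        # point that violates a parameter constraint, so force randomized
        # rounding whenever constrained integer parameters are present.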
if contains_constrained_integer(self.search_space, self.transform_parameters):
self.rounding = "randomized"

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
transformed_parameters: Dict[str, Parameter] = {}
for p_name, p in search_space.parameters.items():
if p_name in self.transform_parameters and isinstance(p, RangeParameter):
transformed_parameters[p_name] = RangeParameter(
name=p_name,
parameter_type=ParameterType.FLOAT,
# +/- 0.5 ensures that sampling
# 1) floating point numbers from (quasi-)Uniform(0,1)
# 2) unnormalizing to the raw search space
# 3) rounding
# results in uniform (quasi-)random integers
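                    # For example, an int parameter on [1, 5] becomes a float
                    # parameter on [0.50001, 5.49999]: rounding a uniform draw
                    # then yields each of 1..5 with (nearly) equal probability,
                    # and the 0.49999 epsilon keeps rounded values inside the
                    # original integer bounds.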
lower=p.lower - 0.49999,
upper=p.upper + 0.49999,
log_scale=p.log_scale,
digits=p.digits,
is_fidelity=p.is_fidelity,
target_value=p.target_value, # casting happens in constructor
)
else:
transformed_parameters[p.name] = p
new_kwargs = {
"parameters": list(transformed_parameters.values()),
"parameter_constraints": [
pc.clone_with_transformed_parameters(
transformed_parameters=transformed_parameters
)
for pc in search_space.parameter_constraints
],
}
if isinstance(search_space, RobustSearchSpace):
new_kwargs["environmental_variables"] = list(
search_space._environmental_variables.values()
)
# pyre-ignore Incompatible parameter type [6]
new_kwargs["parameter_distributions"] = search_space.parameter_distributions
# pyre-ignore Incompatible parameter type [6]
return search_space.__class__(**new_kwargs)
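
    # A further usage sketch (hypothetical config values): randomized
    # rounding can also be requested explicitly, e.g.
    #
    #     IntToFloat(
    #         search_space=ss,
    #         observation_features=[],
    #         observation_data=[],
    #         config={"rounding": "randomized", "max_round_attempts": 100},
    #     )
    #
    # in which case candidate points are re-rounded (up to
    # ``max_round_attempts`` times) until the parameter constraints hold.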