Source code for ax.analysis.plotly.surface.utils
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import math
import numpy as np
from ax.core.parameter import (
ChoiceParameter,
FixedParameter,
Parameter,
RangeParameter,
TParamValue,
)
[docs]
def get_parameter_values(parameter: Parameter, density: int = 100) -> list[TParamValue]:
"""
Get a list of parameter values to predict over for a given parameter.
"""
# For RangeParameter use linspace for the range of the parameter
if isinstance(parameter, RangeParameter):
if parameter.log_scale:
return np.logspace(
math.log10(parameter.lower), math.log10(parameter.upper), density
).tolist()
return np.linspace(parameter.lower, parameter.upper, density).tolist()
# For ChoiceParameter use the values of the parameter directly
if isinstance(parameter, ChoiceParameter) and parameter.is_ordered:
return parameter.values
raise ValueError(
f"Parameter {parameter.name} must be a RangeParameter or "
"ChoiceParameter with is_ordered=True to be used in surface plot."
)
[docs]
def select_fixed_value(parameter: Parameter) -> TParamValue:
"""
Select a fixed value for a parameter. Use mean for RangeParameter, "middle" value
for ChoiceParameter, and value for FixedParameter.
"""
if isinstance(parameter, RangeParameter):
return (parameter.lower * 1.0 + parameter.upper) / 2
elif isinstance(parameter, ChoiceParameter):
return parameter.values[len(parameter.values) // 2]
elif isinstance(parameter, FixedParameter):
return parameter.value
else:
raise ValueError(f"Got unexpected parameter type {parameter}.")
[docs]
def is_axis_log_scale(parameter: Parameter) -> bool:
"""
Check if the parameter is log scale.
"""
return isinstance(parameter, RangeParameter) and parameter.log_scale