# Source code for ax.modelbridge.transforms.logit
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Optional, TYPE_CHECKING
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from scipy.special import expit, logit
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
# [docs]
class Logit(Transform):
    """Move logit-scaled float range parameters onto the logit scale.

    The parameters to transform are chosen once, at construction time: every
    float ``RangeParameter`` in the supplied search space whose
    ``logit_scale`` flag is ``True``. Their names are stored in
    ``transform_parameters``.

    The search space is modified in place rather than copied.
    """

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        """Record which parameters of ``search_space`` are logit-scaled.

        Args:
            search_space: Search space whose parameters are inspected.
                Required despite the ``None`` default.
            observations: Not used by this constructor.
            modelbridge: Not used by this constructor.
            config: Not used by this constructor.

        Raises:
            AssertionError: If ``search_space`` is ``None``.
        """
        assert search_space is not None, "Logit requires search space"
        # Collect the names of float range parameters flagged as logit-scale.
        selected: set[str] = set()
        for name, param in search_space.parameters.items():
            if not isinstance(param, RangeParameter):
                continue
            if param.parameter_type == ParameterType.FLOAT and (
                param.logit_scale is True
            ):
                selected.add(name)
        self.transform_parameters: set[str] = selected

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Replace the bounds of the selected parameters with their logits.

        For each parameter recorded in ``transform_parameters`` that is a
        ``RangeParameter``, the ``logit_scale`` flag is cleared, the lower and
        upper bounds are replaced by ``logit(bound)``, and a non-``None``
        target value is replaced by its logit as well.

        Args:
            search_space: Search space to modify in place.

        Returns:
            The same ``search_space`` object that was passed in.
        """
        for name, param in search_space.parameters.items():
            if name not in self.transform_parameters or not isinstance(
                param, RangeParameter
            ):
                continue
            # Clear the flag first, then read the bounds, as the chained
            # call did.
            unscaled = param.set_logit_scale(False)
            # ``.item()`` unwraps the scalar returned by scipy's ``logit``.
            new_lower = logit(param.lower).item()
            new_upper = logit(param.upper).item()
            unscaled.update_range(lower=new_lower, upper=new_upper)
            if param.target_value is None:
                continue
            param._target_value = logit(param.target_value).item()
        return search_space