Source code for ax.modelbridge.transforms.rounding

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math
import random
from copy import copy
from typing import Set

import numpy as np
from ax.core.parameter_constraint import OrderConstraint
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization

[docs]def randomized_round(x: float) -> int: """Randomized round of x""" z = math.floor(x) return int(z + float(random.random() <= (x - z)))
[docs]def randomized_onehot_round(x: np.ndarray) -> np.ndarray: """Randomized rounding of x to a one-hot vector. x should be 0 <= x <= 1. If x includes negative values, they will be rounded to zero. """ neg_x = x < 0 x[neg_x] = 0 if len(x) == 1: return np.array([randomized_round(x[0])]) if sum(x) == 0: x = np.ones_like(x) x[neg_x] = 0 w = x / sum(x) hot = np.random.choice(len(w), size=1, p=w)[0] z = np.zeros_like(x) z[hot] = 1 return z
[docs]def strict_onehot_round(x: np.ndarray) -> np.ndarray: """Round x to a one-hot vector by selecting the max element. Ties broken randomly.""" if len(x) == 1: return np.round(x) argmax = x == max(x) x[argmax] = 1 x[~argmax] = 0 return randomized_onehot_round(x)
[docs]def contains_constrained_integer( search_space: SearchSpace, transform_parameters: Set[str] ) -> bool: """Check if any integer parameters are present in parameter_constraints. Order constraints are ignored since strict rounding preserves ordering. """ for constraint in search_space.parameter_constraints: if isinstance(constraint, OrderConstraint): continue constraint_params = set(constraint.constraint_dict.keys()) if constraint_params.intersection(transform_parameters): return True return False
[docs]def randomized_round_parameters( parameters: TParameterization, transform_parameters: Set[str] ) -> TParameterization: rounded_parameters = copy(parameters) for p_name in transform_parameters: # pyre: param is declared to have type `float` but is used as # pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`. param: float = parameters.get(p_name) rounded_parameters[p_name] = randomized_round(param) return rounded_parameters