Source code for ax.modelbridge.transforms.rounding
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import random
import numpy as np
def randomized_round(x: float) -> int:
    """Randomly round x to one of its two neighboring integers.

    x is rounded up with probability equal to its fractional part and rounded
    down otherwise, so an integral x is always returned unchanged.

    Args:
        x: The value to round.

    Returns:
        Either floor(x) or floor(x) + 1, as an int.
    """
    z = math.floor(x)
    # random.random() draws from [0, 1). A strict comparison rounds up with
    # probability (x - z) and can never bump an integral x (x - z == 0) to
    # x + 1, which `<=` would do whenever the draw is exactly 0.0.
    return int(z + float(random.random() < (x - z)))
def randomized_onehot_round(x: np.ndarray) -> np.ndarray:
    """Sample a one-hot vector, using the entries of x as relative weights.

    x should be 0 <= x <= 1.

    A length-1 input is not one-hot encoded; its single entry is passed
    through randomized_round instead. If the entries sum to zero, every
    position is equally likely to be selected.

    Args:
        x: Array of non-negative weights (expected to be one-dimensional).

    Returns:
        For inputs longer than one element, a new array shaped and typed like
        x that holds a single 1 and zeros elsewhere. For a length-1 input, a
        one-element array containing the randomized rounding of x[0].
    """
    if len(x) == 1:
        rounded = randomized_round(x[0])
        return np.array([rounded])

    total = sum(x)
    if total == 0:
        # No weight anywhere: fall back to equal weights for every position.
        x = np.ones_like(x)
        total = sum(x)
    probs = x / total

    onehot = np.zeros_like(x)
    # Draw exactly one index, weighted by the normalized entries.
    (chosen,) = np.random.choice(len(probs), size=1, p=probs)
    onehot[chosen] = 1
    return onehot
def strict_onehot_round(x: np.ndarray) -> np.ndarray:
    """Round x to a one-hot vector by selecting the max element.

    Ties broken randomly. The input array is left unmodified.

    Args:
        x: Array of values to round (expected to be one-dimensional).

    Returns:
        np.round(x) for a length-1 input. Otherwise the result of
        randomized_onehot_round applied to a 0/1 indicator of the maximal
        entries of x.
    """
    if len(x) == 1:
        return np.round(x)
    # Build the 0/1 indicator in a fresh array (same dtype as x) instead of
    # overwriting x in place, so the caller's data is not clobbered.
    is_max = x == max(x)
    indicator = is_max.astype(x.dtype)
    # Every maximal entry gets equal weight, so a tie is resolved by a uniform
    # random draw among the tied positions.
    return randomized_onehot_round(indicator)