# Source code for ax.utils.common.random
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import random
from collections.abc import Generator
from contextlib import contextmanager
from typing import Optional
import numpy as np
import torch
def set_rng_seed(seed: int) -> None:
    """Seed every random number generator this module knows about.

    The same seed is applied, in order, to Python's built-in ``random``
    module, NumPy's global generator, and PyTorch.

    Args:
        seed: The value used to seed each generator.
    """
    # Order matters only for parity with callers that may rely on which
    # generator gets seeded first if a later call raises.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
@contextmanager
def with_rng_seed(seed: Optional[int]) -> Generator[None, None, None]:
    """Temporarily seed the random number generators for a ``with`` block.

    While the block runs, the native ``random`` module, NumPy, and PyTorch
    are seeded with ``seed`` (via ``set_rng_seed``). On exit, the native
    and NumPy generator states captured on entry are put back, and the
    PyTorch state is handled by ``torch.random.fork_rng``.

    Passing ``None`` turns the context manager into a no-op, so call sites
    can wrap code unconditionally whether or not a seed was provided.

    Args:
        seed: The seed to apply inside the block, or ``None`` to leave all
            generators untouched.
    """
    # No seed requested: run the body without touching any generator.
    if seed is None:
        yield
        return

    # Snapshot the generators that have no built-in "fork" helper.
    saved_native_state = random.getstate()
    saved_numpy_state = np.random.get_state()
    try:
        # fork_rng scopes the torch seeding to this block.
        with torch.random.fork_rng():
            set_rng_seed(seed)
            yield
    finally:
        # Runs even if the body raised, so the snapshots are always restored.
        random.setstate(saved_native_state)
        np.random.set_state(saved_numpy_state)