Source code for ax.benchmark.utils
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple, Union, cast
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.modelbridge.generation_strategy import GenerationStrategy
def get_problems_and_methods(
    problems: Optional[Union[List[BenchmarkProblem], List[str]]] = None,
    methods: Optional[Union[List[GenerationStrategy], List[str]]] = None,
) -> Tuple[List[BenchmarkProblem], List[GenerationStrategy]]:
    """Check that `problems` and `methods` are ready-to-use object lists.

    Both arguments must be provided, `problems` must contain only
    `BenchmarkProblem` instances and `methods` must contain only
    `GenerationStrategy` instances. Any other input (an omitted argument,
    or entries such as string keys) is not supported by this
    implementation.

    Args:
        problems: Benchmark problems to validate.
        methods: Generation strategies to validate.

    Returns:
        A `(problems, methods)` tuple holding the same two lists that were
        passed in, typed as lists of `BenchmarkProblem` and
        `GenerationStrategy` respectively.

    Raises:
        NotImplementedError: If either argument is `None` or contains an
            element of an unexpected type.
    """
    # TODO (done in D18009570)
    if problems is None or methods is None:
        raise NotImplementedError
    if any(not isinstance(problem, BenchmarkProblem) for problem in problems):
        raise NotImplementedError
    if any(not isinstance(method, GenerationStrategy) for method in methods):
        raise NotImplementedError

    # The checks above guarantee the element types; `cast` only informs the
    # type checker and does nothing at runtime.
    validated_problems = cast(List[BenchmarkProblem], problems)
    validated_methods = cast(List[GenerationStrategy], methods)
    return validated_problems, validated_methods
def get_corresponding(
    value_or_matrix: Union[int, List[List[int]]], row: int, col: int
) -> int:
    """Return the value that applies to the cell at (`row`, `col`).

    If `value_or_matrix` is a matrix (a list of lists), the value stored at
    `value_or_matrix[row][col]` is returned. If it is a scalar int, that
    scalar is returned unchanged and `row` / `col` are ignored.

    Args:
        value_or_matrix: Either a single int or a list of row lists.
        row: Index of the row to read when a matrix is given.
        col: Index of the column to read when a matrix is given.

    Returns:
        The matrix entry at (`row`, `col`), or the scalar itself.

    Raises:
        AssertionError: If `value_or_matrix` is a list whose elements are
            not all lists, or if it is neither a list nor an int.
        IndexError: If `row` or `col` is out of range for the matrix.
    """
    # Validation raises explicitly instead of using `assert` statements so
    # that it still runs when Python is started with -O (which strips
    # asserts). AssertionError is kept so existing callers see the same
    # exception type as before.
    if isinstance(value_or_matrix, list):
        if not all(isinstance(matrix_row, list) for matrix_row in value_or_matrix):
            raise AssertionError(
                "`value_or_matrix` must be a list of lists when passed as a matrix."
            )
        return value_or_matrix[row][col]
    if not isinstance(value_or_matrix, int):
        raise AssertionError(
            "`value_or_matrix` must be either an int or a list of lists, "
            f"got {type(value_or_matrix).__name__}."
        )
    return value_or_matrix