ax.utils
Common
Base
Constants
- class ax.utils.common.constants.Keys(value)
Enum of reserved keys in options dicts, etc., listed in alphabetical order.
NOTE: Useful for keys in dicts that correspond to kwargs to classes or functions and/or are used in multiple places.
- ACQF_KWARGS = 'acquisition_function_kwargs'
- AUTOSET_SURROGATE = 'autoset_surrogate'
- BATCH_INIT_CONDITIONS = 'batch_initial_conditions'
- CANDIDATE_SET = 'candidate_set'
- CANDIDATE_SIZE = 'candidate_size'
- COST_AWARE_UTILITY = 'cost_aware_utility'
- COST_INTERCEPT = 'cost_intercept'
- CURRENT_VALUE = 'current_value'
- EXPAND = 'expand'
- EXPECTED_ACQF_VAL = 'expected_acquisition_value'
- FIDELITY_FEATURES = 'fidelity_features'
- FIDELITY_WEIGHTS = 'fidelity_weights'
- FRAC_RANDOM = 'frac_random'
- FULL_PARAMETERIZATION = 'full_parameterization'
- IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = 'immutable_search_space_and_opt_config'
- MAXIMIZE = 'maximize'
- METADATA = 'metadata'
- METRIC_NAMES = 'metric_names'
- NUM_FANTASIES = 'num_fantasies'
- NUM_INNER_RESTARTS = 'num_inner_restarts'
- NUM_RESTARTS = 'num_restarts'
- NUM_TRACE_OBSERVATIONS = 'num_trace_observations'
- OBJECTIVE = 'objective'
- ONLY_SURROGATE = 'only_surrogate'
- OPTIMIZER_KWARGS = 'optimizer_kwargs'
- PAIRWISE_PREFERENCE_QUERY = 'pairwise_pref_query'
- PREFERENCE_DATA = 'preference_data'
- PRIMARY_SURROGATE = 'primary'
- PROJECT = 'project'
- QMC = 'qmc'
- RAW_INNER_SAMPLES = 'raw_inner_samples'
- RAW_SAMPLES = 'raw_samples'
- REFIT_ON_UPDATE = 'refit_on_update'
- SAMPLER = 'sampler'
- SEED_INNER = 'seed_inner'
- SEQUENTIAL = 'sequential'
- STATE_DICT = 'state_dict'
- SUBCLASS = 'subclass'
- SUBSET_MODEL = 'subset_model'
- TASK_FEATURES = 'task_features'
- TRIAL_COMPLETION_TIMESTAMP = 'trial_completion_timestamp'
- WARM_START_REFITTING = 'warm_start_refitting'
- X_BASELINE = 'X_baseline'
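Because each Keys member maps to a plain string, the enum can stand in for hard-coded string literals when building or reading options dicts. The sketch below uses a small local stand-in that mirrors three of the members listed above (it does not import Ax) and always goes through .value, so it does not rely on any assumption about additional base classes of the real Keys enum.

```python
from enum import Enum
from typing import Any, Dict


class Keys(Enum):
    """Local stand-in mirroring a few members of ax.utils.common.constants.Keys."""

    NUM_RESTARTS = "num_restarts"
    OPTIMIZER_KWARGS = "optimizer_kwargs"
    RAW_SAMPLES = "raw_samples"


def build_optimizer_options(num_restarts: int, raw_samples: int) -> Dict[str, Any]:
    # Using the enum's .value as the dict key keeps each spelling in one place.
    return {
        Keys.OPTIMIZER_KWARGS.value: {
            Keys.NUM_RESTARTS.value: num_restarts,
            Keys.RAW_SAMPLES.value: raw_samples,
        }
    }


options = build_optimizer_options(num_restarts=20, raw_samples=1024)
# Reading back through the enum avoids typos such as "num_restart".
restarts = options[Keys.OPTIMIZER_KWARGS.value][Keys.NUM_RESTARTS.value]
```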
Decorator
- class ax.utils.common.decorator.ClassDecorator
Bases: ABC
Template for making a decorator work as a class-level decorator. Such a decorator should extend ClassDecorator and must implement __init__ and decorate_callable; see disable_logger.decorate_callable for an example. decorate_callable should call self._call_func() instead of calling func directly, so that static functions are handled correctly. Note: _call_func is still imperfect, and unit tests should be used to ensure everything works properly. Detecting classmethods and staticmethods and removing the self argument in the right situations is complex; for best results, always use keyword arguments in the decorated class.
DECORATE_PRIVATE determines whether private methods should be decorated. A logging decorator may only need to decorate the methods the user calls, whereas a decorator that disables logging may need to decorate everything to ensure no logs escape.
- DECORATE_PRIVATE = True
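To make the template concrete, here is a deliberately simplified, self-contained sketch of the same idea; it does not import Ax, and the names SimpleClassDecorator and record_calls are hypothetical. It skips staticmethods and classmethods entirely instead of attempting the self-argument handling that _call_func performs, and it shows how DECORATE_PRIVATE controls whether underscore-prefixed methods are wrapped.

```python
import functools
import inspect
from typing import Any, Callable, List


class SimpleClassDecorator:
    """Hypothetical, simplified stand-in for a ClassDecorator-style base.

    Applied to a function, it wraps that function; applied to a class, it
    wraps every plain function defined directly on the class.
    """

    DECORATE_PRIVATE = True

    def __call__(self, target: Any) -> Any:
        if inspect.isclass(target):
            for name, attr in list(vars(target).items()):
                # Only plain functions; staticmethod/classmethod need extra care.
                if not inspect.isfunction(attr):
                    continue
                if name.startswith("_") and not self.DECORATE_PRIVATE:
                    continue
                setattr(target, name, self.decorate_callable(attr))
            return target
        return self.decorate_callable(target)

    def decorate_callable(self, func: Callable) -> Callable:
        raise NotImplementedError


class record_calls(SimpleClassDecorator):
    """Example subclass: records the name of every decorated call."""

    DECORATE_PRIVATE = False  # leave underscore-prefixed methods untouched

    def __init__(self) -> None:
        self.calls: List[str] = []

    def decorate_callable(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            self.calls.append(func.__name__)
            return func(*args, **kwargs)

        return wrapper


recorder = record_calls()


@recorder
class Greeter:
    def greet(self, name: str) -> str:
        return self._format(name)

    def _format(self, name: str) -> str:
        return f"Hello, {name}!"


@recorder
def ping() -> str:
    return "pong"


message = Greeter().greet(name="Ada")  # keyword args, as recommended above
reply = ping()
```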
Docutils
Support functions for Sphinx, etc.
- ax.utils.common.docutils.copy_doc(src: Callable[[...], Any]) Callable[[_T], _T]
A decorator that copies the docstring of another object.
Since Sphinx actually loads the Python modules to grab the docstrings, this works with both Sphinx and the help function:
    class Cat(Mammal):
        @property
        @copy_doc(Mammal.is_feline)
        def is_feline(self) -> bool:
            ...
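A minimal sketch of how such a decorator can work, assuming it only needs to assign src.__doc__ to the decorated object (the real implementation may do more). The classes mirror the example above, with a docstring added to Mammal.is_feline so there is something to copy; because property objects take their docstring from the getter, the copied text also shows up on Cat.is_feline.

```python
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """Sketch: return a decorator that copies src's docstring onto its target."""

    def decorator(dst: _T) -> _T:
        dst.__doc__ = src.__doc__
        return dst

    return decorator


class Mammal:
    @property
    def is_feline(self) -> bool:
        """Whether this mammal is a member of the cat family."""
        return False


class Cat(Mammal):
    @property
    @copy_doc(Mammal.is_feline)
    def is_feline(self) -> bool:
        return True
```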
Equality
- ax.utils.common.equality.dataframe_equals(df1: pandas.DataFrame, df2: pandas.DataFrame) bool
Compare equality of two pandas dataframes.
- ax.utils.common.equality.datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) bool
Compare equality of two datetimes, ignoring microseconds.
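A sketch of this comparison, assuming that two missing values count as equal and a single missing value does not; the signature mirrors the one above, but the body is illustrative rather than Ax's code.

```python
from datetime import datetime
from typing import Optional


def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Sketch: compare two optional datetimes, ignoring microseconds."""
    if dt1 is None or dt2 is None:
        # Equal only if both are missing.
        return dt1 is None and dt2 is None
    return dt1.replace(microsecond=0) == dt2.replace(microsecond=0)


saved = datetime(2024, 1, 1, 12, 0, 0, 123)
reloaded = datetime(2024, 1, 1, 12, 0, 0, 999_999)  # same second, different microseconds
```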
- ax.utils.common.equality.equality_typechecker(eq_func: Callable) Callable
A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type.
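A sketch of what such a wrapper can look like, assuming that a type mismatch should make __eq__ return False rather than raise. The Point class is purely illustrative: without the wrapper, comparing it to a string would fail with an AttributeError on other.x.

```python
import functools
from typing import Any, Callable


def equality_typechecker(eq_func: Callable[[Any, Any], bool]) -> Callable[[Any, Any], bool]:
    """Sketch: make __eq__ return False when `other` is not an instance of self's class."""

    @functools.wraps(eq_func)
    def _type_safe_equals(self: Any, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return eq_func(self, other)

    return _type_safe_equals


class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    @equality_typechecker
    def __eq__(self, other: "Point") -> bool:
        # Safe to access other.x here: the wrapper already checked the type.
        return self.x == other.x and self.y == other.y
```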
- ax.utils.common.equality.object_attribute_dicts_equal(one_dict: Dict[str, Any], other_dict: Dict[str, Any], skip_db_id_check: bool = False) bool
Utility to check if all items in the attribute dicts of two Ax objects are the same.
NOTE: Special-cases some Ax object attributes, like “_experiment” or “_model”, where full equality is hard to check.
- Parameters:
one_dict – First object’s attribute dict (obj.__dict__).
other_dict – Second object’s attribute dict (obj.__dict__).
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids, with which one or both of them are saved to the database (e.g. when comparing an object as it was before being saved to the version reloaded from the DB).
- ax.utils.common.equality.object_attribute_dicts_find_unequal_fields(one_dict: Dict[str, Any], other_dict: Dict[str, Any], fast_return: bool = True, skip_db_id_check: bool = False) Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Tuple[Any, Any]]]
Utility for finding out which attributes of two objects’ attribute dicts are unequal.
- Parameters:
one_dict – First object’s attribute dict (obj.__dict__).
other_dict – Second object’s attribute dict (obj.__dict__).
fast_return – Whether to return as soon as a single unequal attribute is found, or to iterate over all attributes and collect all unequal ones.
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids, with which one or both of them are saved to the database (e.g. when comparing an object as it was before being saved to the version reloaded from the DB).
- Returns:
A mapping from attribute name to attribute values of unequal type (as a tuple), and a mapping from attribute name to attribute values of unequal value (as a tuple).
- Return type:
Two dictionaries
- ax.utils.common.equality.same_elements(list1: List[Any], list2: List[Any]) bool
Compare equality of two lists of core Ax objects.
- Assumptions:
– The contents of each list are types that implement __eq__.
– The lists do not contain duplicates.
Checking equality is then the same as checking that the lists are the same length and that one is a subset of the other.
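The reasoning above translates into a short check. This sketch uses list membership (which relies on __eq__) instead of sets, on the assumption that the objects being compared are not necessarily hashable; it illustrates the idea and is not Ax's own code.

```python
from typing import Any, List


def same_elements(list1: List[Any], list2: List[Any]) -> bool:
    """Sketch: order-insensitive equality for duplicate-free lists."""
    if len(list1) != len(list2):
        return False
    # Membership tests use __eq__, so elements need not be hashable.
    return all(item in list2 for item in list1)
```

The pair [1, 1, 2] and [1, 2, 2] passes this check even though the lists differ, which is exactly why the no-duplicates assumption is required.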
Executils
- ax.utils.common.executils.handle_exceptions_in_retries(no_retry_exceptions: Tuple[Type[Exception], ...], retry_exceptions: Tuple[Type[Exception], ...], suppress_errors: bool, check_message_contains: Optional[str], last_retry: bool, logger: Optional[Logger], wrap_error_message_in: Optional[str]) Generator[None, None, None]
- ax.utils.common.executils.retry_on_exception(exception_types: Optional[Tuple[Type[Exception], ...]] = None, no_retry_on_exception_types: Optional[Tuple[Type[Exception], ...]] = None, check_message_contains: Optional[List[str]] = None, retries: int = 3, suppress_all_errors: bool = False, logger: Optional[Logger] = None, default_return_on_suppression: Optional[Any] = None, wrap_error_message_in: Optional[str] = None, initial_wait_seconds: Optional[int] = None) Optional[Any]
A decorator for instance methods or standalone functions that makes them retry on failure and allows specifying on which types of exceptions the function should and should not retry.
NOTE: If suppress_all_errors is set to True, the error will be suppressed and the default value returned.
- Parameters:
exception_types – A tuple of exception types to catch in the decorated function. If none is provided, the base class Exception will be used.
no_retry_on_exception_types – Exception types to consider non-retryable even if their supertype appears in exception_types, or the only exceptions not to retry on if no exception_types are specified.
check_message_contains – A list of strings against which to match error messages. If the error message contains any one of these strings, the exception will cause a retry. NOTE: This argument works in addition to exception_types; if those are specified, only exceptions of the specified types will be caught, and they will be retried on only if their messages contain one of the strings in check_message_contains.
retries – Number of retries to perform.
suppress_all_errors – If True, the error will still be suppressed after all the retries are exhausted, and default_return_on_suppression will be returned from the function. NOTE: If using this argument, the decorated function may never fully execute if it consistently raises an exception.
logger – A handle for the logger to be used.
default_return_on_suppression – If the error is suppressed after all the retries, this default value will be returned from the function. Defaults to None.
wrap_error_message_in – If raising the error after all the retries, a string wrapper for the error message (useful for making error messages more user-friendly). NOTE: The resulting error will have the format “<wrap_error_message_in>: <original_error_type>: <original_error_msg>”, with the stack trace of the original error.
initial_wait_seconds – Initial length of time to wait between failures, doubled after each failure up to a maximum of 10 minutes. If unspecified, there is no wait between retries.
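The core retry loop can be sketched as follows. This is a hypothetical, reduced version under a made-up name: it omits no_retry_on_exception_types, check_message_contains, logging, and message wrapping, and it treats retries as the total number of attempts, which is an assumption about the exact counting. The wait doubles after each failure and is capped at 600 seconds, matching the 10-minute maximum described above.

```python
import functools
import time
from typing import Any, Callable, Optional, Tuple, Type

MAX_WAIT_SECONDS = 600  # 10 minutes


def retry_on_exception_sketch(
    exception_types: Tuple[Type[Exception], ...] = (Exception,),
    retries: int = 3,
    initial_wait_seconds: Optional[float] = None,
    suppress_all_errors: bool = False,
    default_return_on_suppression: Any = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Simplified, hypothetical version of a retrying decorator."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            wait = initial_wait_seconds
            attempt = 0
            while True:
                attempt += 1
                try:
                    return func(*args, **kwargs)
                except exception_types:
                    if attempt >= retries:
                        if suppress_all_errors:
                            return default_return_on_suppression
                        raise  # re-raise the last error unchanged
                if wait:
                    time.sleep(wait)
                    # Double the wait after each failure, up to the cap.
                    wait = min(wait * 2, MAX_WAIT_SECONDS)

        return wrapper

    return decorator


calls = {"count": 0}


@retry_on_exception_sketch(exception_types=(ConnectionError,), retries=3)
def flaky() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


@retry_on_exception_sketch(retries=2, suppress_all_errors=True, default_return_on_suppression="fallback")
def always_fails() -> str:
    raise RuntimeError("boom")


result = flaky()
```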
Kwargs
- ax.utils.common.kwargs.consolidate_kwargs(kwargs_iterable: Iterable[Optional[Dict[str, Any]]], keywords: Iterable[str]) Dict[str, Any]
Combine an iterable of kwargs into a single dict of kwargs, where kwargs with duplicate keys that appear later in the iterable take priority over those that appear earlier, and only kwargs referenced in keywords are used. This allows combining somewhat redundant sets of kwargs where, for instance, a user-set kwarg needs to override a default kwarg.
    >>> consolidate_kwargs(
    ...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
    ...     keywords=['a', 'b', 'd']
    ... )
    {'a': 1, 'b': 3, 'd': 5}
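A sketch of the merge rule that reproduces the example above: dicts are applied in order so later values overwrite earlier ones, and only keys listed in keywords survive. Skipping None entries is an assumption based on the Optional in the type hint.

```python
from typing import Any, Dict, Iterable, Optional


def consolidate_kwargs(
    kwargs_iterable: Iterable[Optional[Dict[str, Any]]], keywords: Iterable[str]
) -> Dict[str, Any]:
    """Sketch: later dicts override earlier ones; only `keywords` are kept."""
    allowed = set(keywords)
    merged: Dict[str, Any] = {}
    for kwargs in kwargs_iterable:
        if kwargs is None:  # allowed by the type hint; nothing to merge
            continue
        merged.update({k: v for k, v in kwargs.items() if k in allowed})
    return merged
```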
- ax.utils.common.kwargs.filter_kwargs(function: Callable, **kwargs: Any) Any
Filter out kwargs that are not applicable to a given function. Return a copy of the given kwargs dict containing only the applicable kwargs.
- ax.utils.common.kwargs.get_function_argument_names(function: Callable, omit: Optional[List[str]] = None) List[str]
Extract parameter names from a function signature.
- ax.utils.common.kwargs.get_function_default_arguments(function: Callable) Dict[str, Any]
Extract default arguments from a function signature.
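All three helpers can be built on the standard library's inspect.signature. The following is a sketch under that assumption, not Ax's code; fit is a hypothetical target function. Note that this version does not special-case targets that themselves accept **kwargs.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def get_function_argument_names(function: Callable, omit: Optional[List[str]] = None) -> List[str]:
    omit = omit or []
    return [name for name in inspect.signature(function).parameters if name not in omit]


def get_function_default_arguments(function: Callable) -> Dict[str, Any]:
    return {
        name: param.default
        for name, param in inspect.signature(function).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


def filter_kwargs(function: Callable, **kwargs: Any) -> Dict[str, Any]:
    # Keep only the kwargs that appear as parameters of `function`.
    accepted = set(get_function_argument_names(function))
    return {k: v for k, v in kwargs.items() if k in accepted}


def fit(data: list, lr: float = 0.1, epochs: int = 10) -> None:
    """Hypothetical function used to exercise the helpers."""
```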
- ax.utils.common.kwargs.validate_kwarg_typing(typed_callables: List[Callable], **kwargs: Any) None
Check whether each keyword in kwargs exists in any of the typed_callables, and whether the type of each keyword value matches the type of the corresponding argument in one of the callables.
Note: this function expects the typed callables to have unique keywords for their arguments and will raise an error if repeated keywords are found.
- ax.utils.common.kwargs.warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any) None
Log a warning when a decoder function receives unexpected kwargs.
NOTE: This mainly caters to the use case where an older version of Ax is used to decode objects that were serialized to JSON by a newer version of Ax (and therefore potentially contain new fields). In that case, the decoding function should not fail when encountering those additional fields, but rather ignore them and log a warning using this function.
Logger
- class ax.utils.common.logger.AxOutputNameFilter(name='')
Bases: Filter
A filter that sets the record’s output_name if it is not already configured.
- ax.utils.common.logger.build_file_handler(filepath: str, level: int = 20) StreamHandler
Build a file handler that logs entries to the given file, using the same formatting as the stream handler.
- Parameters:
filepath – Location of the file to log output to. If the file exists, output will be appended. If it does not exist, a new file will be created.
level – The log level. Defaults to INFO.
- Returns:
A logging.FileHandler instance
- ax.utils.common.logger.build_stream_handler(level: int = 20) StreamHandler
Build the default stream handler used for most Ax logging. Sets the default level to INFO instead of WARNING.
- Parameters:
level – The log level. Defaults to INFO.
- Returns:
A logging.StreamHandler instance
- class ax.utils.common.logger.disable_logger(name: str, level: int = 40)
Bases: ClassDecorator
- class ax.utils.common.logger.disable_loggers(names: List[str], level: int = 40)
Bases: ClassDecorator
- ax.utils.common.logger.get_logger(name: str, level: int = 20, force_name: bool = False) Logger
Get an Ax logger.
To set a human-readable “output_name” that appears in logger outputs, add {“output_name”: “[MY_OUTPUT_NAME]”} to the logger’s contextual information. By default, the logger’s name is used.
NOTE: To change the log level on particular outputs (e.g. STDERR logs), set the proper log level on the relevant handler instead of the logger, e.g. logger.handlers[0].setLevel(INFO).
- Parameters:
name – The name of the logger.
level – The level at which to actually log. Logs below this level of importance will be discarded.
force_name – If set to False and the module specified by name is not ultimately a descendant of the ax module, “ax.” will be prepended to name.
- Returns:
The logging.Logger object.
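The NOTE above relies on the two levels the standard logging module applies: the logger's level decides which records are created, and each handler's level decides which of those records it emits. The stdlib-only sketch below illustrates this with a plain logging.Logger (the name "ax.example" is made up) rather than get_logger, so it does not reproduce Ax's formatting or output_name handling.

```python
import io
import logging

# Capture output in memory so the effect of handler levels is easy to inspect.
buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))

logger = logging.getLogger("ax.example")  # hypothetical logger name
logger.setLevel(logging.INFO)  # the logger lets INFO and above through...
logger.addHandler(handler)
logger.propagate = False  # keep the example isolated from the root logger

logger.info("visible")
handler.setLevel(logging.WARNING)  # ...but this handler now drops INFO records
logger.info("hidden")
logger.warning("also visible")

lines = buffer.getvalue().splitlines()
```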
Mock Torch
- ax.utils.common.mock.mock_patch_method_original(mock_path: str, original_method: Callable[[...], T]) MagicMock
Context manager for patching a method returning type T on class C, to track calls to it while still executing the original method. There is no native way to do this with mock.patch.
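One way to get this behavior with only unittest.mock is to patch the method with a thin plain-function wrapper that forwards every call to a MagicMock created with wraps=original. The sketch below uses that technique under a hypothetical name (patch_method_keep_original); it is not necessarily how mock_patch_method_original is implemented, and it only handles regular instance methods.

```python
from contextlib import contextmanager
from typing import Any, Iterator
from unittest import mock


@contextmanager
def patch_method_keep_original(cls: type, name: str) -> Iterator[mock.MagicMock]:
    """Hypothetical helper: track calls to cls.<name> while still running it."""
    original = getattr(cls, name)
    tracker = mock.MagicMock(wraps=original)

    # A plain function (unlike a MagicMock) is a descriptor, so `self` is bound
    # normally and forwarded to the tracker, which then calls the original.
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        return tracker(self, *args, **kwargs)

    with mock.patch.object(cls, name, new=wrapper):
        yield tracker


class Calculator:
    def double(self, x: int) -> int:
        return 2 * x


calc = Calculator()
with patch_method_keep_original(Calculator, "double") as tracker:
    result = calc.double(21)
```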
Result
- class ax.utils.common.result.Err(value: E)
Bases: Generic[T, E], Result[T, E]
Contains the error value.
- property err: E
- map(op: Callable[[T], U]) Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- map_err(op: Callable[[E], F]) Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- map_or(default: U, op: Callable[[T], U]) U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) U
Maps a Result[T, E] to U by calling the fallback function default_op (if Err), or by applying the function op to a contained Ok value (if Ok). This function can be used to unpack a successful result while handling an error.
- unwrap() NoReturn
Returns the contained Ok value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_err() E
Returns the contained Err value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_or_else(op: Callable[[E], T]) T
Returns the contained Ok value or computes it from a Callable.
- property value: E
- class ax.utils.common.result.Ok(value: T)
Bases: Generic[T, E], Result[T, E]
Contains the success value.
- map(op: Callable[[T], U]) Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- map_err(op: Callable[[E], F]) Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- map_or(default: U, op: Callable[[T], U]) U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) U
Maps a Result[T, E] to U by calling the fallback function default_op (if Err), or by applying the function op to a contained Ok value (if Ok). This function can be used to unpack a successful result while handling an error.
- property ok: T
- unwrap() T
Returns the contained Ok value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_err() NoReturn
Returns the contained Err value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- unwrap_or_else(op: Callable[[E], T]) T
Returns the contained Ok value or computes it from a Callable.
- property value: T
- class ax.utils.common.result.Result
A minimal implementation of a Rust-style Result monad. See https://doc.rust-lang.org/std/result/enum.Result.html for more information.
- abstract map(op: Callable[[T], U]) Result[U, E]
Maps a Result[T, E] to Result[U, E] by applying a function to a contained Ok value, leaving an Err value untouched. This function can be used to compose the results of two functions.
- abstract map_err(op: Callable[[E], F]) Result[T, F]
Maps a Result[T, E] to Result[T, F] by applying a function to a contained Err value, leaving an Ok value untouched. This function can be used to pass through a successful result while handling an error.
- abstract map_or(default: U, op: Callable[[T], U]) U
Returns the provided default (if Err), or applies a function to the contained value (if Ok).
- abstract map_or_else(default_op: Callable[[], U], op: Callable[[T], U]) U
Maps a Result[T, E] to U by calling the fallback function default_op (if Err), or by applying the function op to a contained Ok value (if Ok). This function can be used to unpack a successful result while handling an error.
- abstract unwrap() T
Returns the contained Ok value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
- abstract unwrap_err() E
Returns the contained Err value.
Because this function may raise an UnwrapError, its use is generally discouraged. Instead, prefer to handle the Err case explicitly, or call unwrap_or, unwrap_or_else, or unwrap_or_default.
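The pattern is easiest to see on a tiny, self-contained version. The sketch below implements only map and unwrap_or_else, uses frozen dataclasses instead of Ax's classes, and models Result as a Union type alias instead of an abstract base class; parse_int is a hypothetical example function.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar, Union

T = TypeVar("T")
E = TypeVar("E")
U = TypeVar("U")


@dataclass(frozen=True)
class Ok(Generic[T]):
    value: T

    def map(self, op: Callable[[T], U]) -> "Ok[U]":
        return Ok(op(self.value))

    def unwrap_or_else(self, op: Callable[[object], T]) -> T:
        return self.value  # success: the fallback is never called


@dataclass(frozen=True)
class Err(Generic[E]):
    value: E

    def map(self, op: Callable[[object], object]) -> "Err[E]":
        return self  # the error passes through untouched

    def unwrap_or_else(self, op: Callable[[E], T]) -> T:
        return op(self.value)  # compute a fallback from the error


Result = Union[Ok[T], Err[E]]


def parse_int(text: str) -> "Result[int, str]":
    try:
        return Ok(int(text))
    except ValueError:
        return Err(f"not an integer: {text!r}")


doubled = parse_int("21").map(lambda n: n * 2).unwrap_or_else(lambda err: -1)
fallback = parse_int("abc").map(lambda n: n * 2).unwrap_or_else(lambda err: -1)
```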
Serialization
- class ax.utils.common.serialization.SerializationMixin
Bases: object
- classmethod deserialize_init_args(args: Dict[str, Any], decoder_registry: Optional[Dict[str, Type]] = None, class_decoder_registry: Optional[Dict[str, Callable[[Dict[str, Any]], Any]]] = None) Dict[str, Any]
Given a dictionary, deserialize the properties needed to initialize the object. Used for storage.
- ax.utils.common.serialization.callable_from_reference(path: str) Callable
Retrieves a callable by its path.
- ax.utils.common.serialization.callable_to_reference(callable: Callable) str
Obtains the path to the callable, of the form <module>.<name>.
- ax.utils.common.serialization.extract_init_args(args: Dict[str, Any], class_: Type) Dict[str, Any]
Given a dictionary, extract the arguments required for the given class’s constructor.
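A round trip between a callable and its <module>.<name> path can be sketched with importlib, assuming the callable is defined at module level (nested functions and methods would need more work). This is an illustration, not Ax's implementation; the parameter is named func here to avoid shadowing the built-in callable.

```python
import importlib
import json
from typing import Callable


def callable_to_reference(func: Callable) -> str:
    """Sketch: build a '<module>.<name>' path for a module-level callable."""
    return f"{func.__module__}.{func.__name__}"


def callable_from_reference(path: str) -> Callable:
    """Sketch: import '<module>' and look up '<name>' on it."""
    module_name, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


reference = callable_to_reference(json.dumps)
```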
Testutils
Support functions for tests
- class ax.utils.common.testutils.TestCase(methodName: str = 'runTest')
Bases: TestCase
The base Ax test case; contains various helper functions for writing unit tests.
- assertAxBaseEqual(first: Base, second: Base, msg: Optional[str] = None, skip_db_id_check: bool = False) None
Check that two Ax objects that subclass Base are equal, or raise an assertion error otherwise.
- Parameters:
first – Base-subclassing object to compare to second.
second – Base-subclassing object to compare to first.
msg – Message to put into the assertion error raised on inequality; if not specified, a default message is used.
skip_db_id_check – If True, will exclude the db_id attributes from the equality check. Useful for ensuring that all attributes of an object are equal except the ids, with which one or both of them are saved to the database (e.g. when comparing an object as it was before being saved to the version reloaded from the DB).
- assertDictsAlmostEqual(a: Dict[str, Any], b: Dict[str, Any], consider_nans_equal: bool = False) None
Testing utility that checks that 1) the keys of a and b are identical, and 2) the values of a and b are almost equal if they have a floating point type, and exactly equal otherwise. NaNs are considered equal only if consider_nans_equal is True.
- Parameters:
a – A dictionary.
b – Another dictionary.
consider_nans_equal – Whether to consider NaNs equal when comparing floating point numbers.
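A sketch of the comparison logic as a standalone unittest.TestCase subclass, assuming (as the consider_nans_equal parameter suggests) that NaNs compare equal only when that flag is set. Only built-in float values are treated as floating point here, and DictAssertions is a hypothetical name.

```python
import math
import unittest
from typing import Any, Dict


class DictAssertions(unittest.TestCase):
    """Hypothetical sketch of an almost-equal dict assertion."""

    def assertDictsAlmostEqual(
        self, a: Dict[str, Any], b: Dict[str, Any], consider_nans_equal: bool = False
    ) -> None:
        self.assertEqual(set(a.keys()), set(b.keys()))
        for key, a_val in a.items():
            b_val = b[key]
            if isinstance(a_val, float) and isinstance(b_val, float):
                if consider_nans_equal and math.isnan(a_val) and math.isnan(b_val):
                    continue
                self.assertAlmostEqual(a_val, b_val, msg=f"Mismatch for key {key!r}")
            else:
                self.assertEqual(a_val, b_val, msg=f"Mismatch for key {key!r}")


checker = DictAssertions()
checker.assertDictsAlmostEqual({"x": 0.1 + 0.2, "tag": "a"}, {"x": 0.3, "tag": "a"})  # passes

try:
    checker.assertDictsAlmostEqual({"x": math.nan}, {"x": math.nan})
    nan_mismatch_detected = False
except AssertionError:
    nan_mismatch_detected = True  # NaNs differ unless consider_nans_equal=True
```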
- assertEqual(first: Any, second: Any, msg: Optional[str] = None) None
Fail if the two objects are unequal as determined by the ‘==’ operator.
- assertRaisesOn(exc: Type[Exception], line: Optional[str] = None, regex: Optional[str] = None) ContextManager[None]
Assert that an exception is raised on a specific line.
Timeutils
- ax.utils.common.timeutils.current_timestamp_in_millis() int
Grab the current timestamp in milliseconds as an int.
Typeutils
- ax.utils.common.typeutils.checked_cast(typ: Type[T], val: V, exception: Optional[Exception] = None) T
Cast a value to a type (with a runtime safety check).
Returns the value unchanged and checks its type at runtime. This signals to the typechecker that the value has the designated type.
Like typing.cast, checked_cast performs no runtime conversion on its argument, but, unlike typing.cast, checked_cast will throw an error if the value is not of the expected type. The type passed as an argument should be a Python class.
- Parameters:
typ – the type to cast to
val – the value that we are casting
exception – override exception to raise if typecheck fails
- Returns:
the val argument, unchanged
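A sketch of the check described above: an isinstance test, no conversion, and the value returned as-is. Raising ValueError by default is an assumption made for this illustration; the exception parameter lets callers substitute their own error, as documented.

```python
from typing import Any, Optional, Type, TypeVar

T = TypeVar("T")


def checked_cast(typ: Type[T], val: Any, exception: Optional[Exception] = None) -> T:
    """Sketch: return val unchanged if it is an instance of typ, else raise."""
    if not isinstance(val, typ):
        if exception is not None:
            raise exception
        raise ValueError(f"Value was not of type {typ}: {val!r}")
    return val  # no conversion happens; only the static type changes


count = checked_cast(int, 3)

try:
    checked_cast(int, "3")  # no conversion: a numeric string is still rejected
    rejected = False
except ValueError:
    rejected = True

try:
    checked_cast(float, None, exception=TypeError("expected a float"))
except TypeError as err:
    custom_message = str(err)
```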
- ax.utils.common.typeutils.checked_cast_dict(key_typ: Type[K], value_typ: Type[V], d: Dict[X, Y]) Dict[K, V]
Calls checked_cast on all keys and values in the dictionary.
- ax.utils.common.typeutils.checked_cast_list(typ: Type[T], old_l: List[V]) List[T]
Calls checked_cast on all items in a list.
- ax.utils.common.typeutils.checked_cast_optional(typ: Type[T], val: Optional[V]) Optional[T]
Calls checked_cast only if the value is not None.
- ax.utils.common.typeutils.checked_cast_to_tuple(typ: Tuple[Type[V], ...], val: V) T
Cast a value to a union of multiple types (with a runtime safety check). This function is similar to checked_cast, but allows the type to be defined as a tuple of types, in which case the value is cast as a union of the types in the tuple.
- Parameters:
typ – the tuple of types to cast to
val – the value that we are casting
- Returns:
the val argument, unchanged
- ax.utils.common.typeutils.not_none(val: Optional[T], message: Optional[str] = None) T
Unbox an optional type.
- Parameters:
val – the value to cast to a non-None type
message – optional override of the default error message
- Returns:
val when val is not None
- Return type:
T
- Throws:
ValueError if val is None
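A sketch of the unboxing logic with the ValueError described above. It tests `is None`, so falsy values such as 0 or an empty string pass through unchanged; the default message here is made up for illustration.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(val: Optional[T], message: Optional[str] = None) -> T:
    """Sketch: unbox an Optional, raising ValueError if it is None."""
    if val is None:
        raise ValueError(message or "Argument to `not_none` was None.")
    return val


config = {"timeout": 30}
timeout = not_none(config.get("timeout"))
zero = not_none(0)  # falsy but not None, so it passes through

try:
    not_none(config.get("retries"), message="retries must be configured")
except ValueError as err:
    error_text = str(err)
```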
- ax.utils.common.typeutils.numpy_type_to_python_type(value: Any) Any
If value is a NumPy int or float, coerce it to a Python int or float. This is necessary because some of our transforms return NumPy values.
- ax.utils.common.typeutils.version_safe_check_type(argname: str, value: T, expected_type: Type[T]) None
Execute the check_type function if it has the expected signature; otherwise, warn. This is done to support newer versions of typeguard with minimal loss of functionality for users that have dependency conflicts.
Typeutils Torch
Flake8 Plugins
Docstring Checker
- class ax.utils.flake8_plugins.docstring_checker.DocstringChecker(tree, filename)
Bases: object
A flake8 plug-in that makes sure all public functions have a docstring.
- class ax.utils.flake8_plugins.docstring_checker.DocstringCheckerVisitor
Bases: NodeVisitor
- visit_FunctionDef(node: FunctionDef) None
- class ax.utils.flake8_plugins.docstring_checker.Error(lineno, col, message, type)
Bases: NamedTuple
- ax.utils.flake8_plugins.docstring_checker.is_copy_doc_call(c)
Tries to guess whether this is a call to the copy_doc decorator. This is a purely syntactic check, so if the decorator was aliased as another name or wrapped in another function, we will fail.
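The core check can be sketched with the standard library's ast module. This simplified visitor (a hypothetical name, not the plugin's real class) treats any function whose name does not start with an underscore as public, ignores async functions, does not special-case copy_doc, and uses a three-field Error tuple instead of the plugin's four-field one.

```python
import ast
from typing import List, NamedTuple


class Error(NamedTuple):
    lineno: int
    col: int
    message: str


class MissingDocstringVisitor(ast.NodeVisitor):
    """Hypothetical simplified visitor: flags public functions without docstrings."""

    def __init__(self) -> None:
        self.errors: List[Error] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        is_public = not node.name.startswith("_")
        if is_public and ast.get_docstring(node) is None:
            self.errors.append(
                Error(node.lineno, node.col_offset, f"public function {node.name!r} has no docstring")
            )
        self.generic_visit(node)  # keep visiting nested definitions


source = '''
def documented():
    """I have a docstring."""

def undocumented():
    pass

def _private():
    pass
'''

visitor = MissingDocstringVisitor()
visitor.visit(ast.parse(source))
```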
Measurement
Synthetic Functions
- class ax.utils.measurement.synthetic_functions.Aug_Branin
Bases: SyntheticFunction
Augmented Branin function (3-dimensional with infinitely many global minima).
- class ax.utils.measurement.synthetic_functions.Aug_Hartmann6
Bases: Hartmann6
Augmented Hartmann6 function (7-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.Branin
Bases: SyntheticFunction
Branin function (2-dimensional with 3 global minima).
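For reference, the standard Branin formula, with the usual constants a = 1, b = 5.1/(4π²), c = 5/π, r = 6, s = 10, t = 1/(8π), can be evaluated directly to confirm its three global minima, each with value 10/(8π) ≈ 0.397887. This is plain Python, not Ax's Branin class, and assumes that class follows the standard definition.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard 2-D Branin function with its usual constants."""
    a = 1.0
    b = 5.1 / (4 * math.pi**2)
    c = 5 / math.pi
    r = 6.0
    s = 10.0
    t = 1 / (8 * math.pi)
    return a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * math.cos(x1) + s


# The three global minimizers commonly reported for this function.
minimizers = [(-math.pi, 12.275), (math.pi, 2.275), (9.42478, 2.475)]
values = [branin(x1, x2) for x1, x2 in minimizers]
```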
- class ax.utils.measurement.synthetic_functions.FromBotorch(botorch_synthetic_function: SyntheticTestFunction)
Bases: SyntheticFunction
- class ax.utils.measurement.synthetic_functions.Hartmann6
Bases: SyntheticFunction
Hartmann6 function (6-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.SyntheticFunction
Bases: ABC
- ax.utils.measurement.synthetic_functions.from_botorch(botorch_synthetic_function: SyntheticTestFunction) SyntheticFunction
Utility to generate Ax synthetic functions from BoTorch synthetic functions.
Notebook
Plotting
- ax.utils.notebook.plotting.init_notebook_plotting(offline: bool = False) None
Initialize plotting in notebooks, either in online or offline mode.
- ax.utils.notebook.plotting.render(plot_config: AxPlotConfig, inject_helpers: bool = False) None
Render plot config.
Report
Render
- ax.utils.report.render.link_html(text: str, href: str) str
Embed text and reference address into a link tag.
- ax.utils.report.render.render_report_elements(experiment_name: str, html_elements: List[str], header: bool = True, offline: bool = False, notebook_env: bool = False) str
Generate an Ax HTML report for a given experiment from HTML elements.
Uses Jinja2 for templating. Injects Plotly JS for graph rendering.
Example:
    html_elements = [
        h2_html("Subsection with plot"),
        p_html("This is an example paragraph."),
        plot_html(plot_fitted(gp_model, 'perf_metric')),
        h2_html("Subsection with table"),
        pandas_html(data.df),
    ]
    html = render_report_elements('My experiment', html_elements)
- Parameters:
experiment_name – the name of the experiment to use for the title.
html_elements – list of HTML strings to render in the report body.
header – if True, render the experiment title as a header. Meant to be used for standalone reports (e.g. via email), as opposed to reports served on the front-end.
offline – if True, the entire Plotly library is bundled with the report.
notebook_env – if True, caps the report width to 700px for viewing in a notebook environment.
- Returns:
HTML string.
- Return type:
str
- ax.utils.report.render.table_cell_html(text: str, width: Optional[str] = None) str
Embed text or an HTML element into a table cell tag.
- ax.utils.report.render.table_heading_cell_html(text: str) str
Embed text or an HTML element into a table heading cell tag.
- ax.utils.report.render.table_html(table_rows: List[str]) str
Embed a list of HTML elements into a table tag.
Sensitivity
Derivative GP
- ax.utils.sensitivity.derivative_gp.get_KXX_inv(gp: Model) Tensor
Get the inverse matrix of K(X,X).
- Parameters:
gp – Botorch model.
- Returns:
The inverse of K(X,X).
- ax.utils.sensitivity.derivative_gp.get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = 'rbf') Tensor
Computes the analytic derivative of the kernel K(x,X) w.r.t. x.
- Parameters:
gp – Botorch model.
x – (n x D) Test points.
kernel_type – Takes “rbf” or “matern”.
- Returns:
Tensor (n x D) The derivative of the kernel K(x,X) w.r.t. x.
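For the RBF case, the derivative computed by get_KxX_dx has a closed form: with k(x, x') = exp(-‖x − x'‖² / (2ℓ²)), each component is ∂k/∂x_d = −(x_d − x'_d)/ℓ² · k(x, x'). The pure-Python sketch below checks that formula against a central finite difference for a single training point. It assumes a unit outputscale and one shared lengthscale ℓ, whereas a fitted BoTorch model typically has per-dimension lengthscales and an outputscale that rescale each term.

```python
import math
from typing import List, Sequence


def rbf(x: Sequence[float], x_prime: Sequence[float], lengthscale: float = 1.0) -> float:
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, x_prime))
    return math.exp(-sq_dist / (2 * lengthscale**2))


def rbf_grad_x(x: Sequence[float], x_prime: Sequence[float], lengthscale: float = 1.0) -> List[float]:
    """Analytic gradient of rbf(x, x_prime) with respect to x."""
    k = rbf(x, x_prime, lengthscale)
    return [-(a - b) / lengthscale**2 * k for a, b in zip(x, x_prime)]


def finite_difference_grad(
    x: Sequence[float], x_prime: Sequence[float], lengthscale: float = 1.0, eps: float = 1e-6
) -> List[float]:
    """Central-difference approximation of the same gradient, for checking."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_plus[i] += eps
        x_minus = list(x)
        x_minus[i] -= eps
        diff = rbf(x_plus, x_prime, lengthscale) - rbf(x_minus, x_prime, lengthscale)
        grad.append(diff / (2 * eps))
    return grad


x, x_train = [0.3, -1.2], [1.0, 0.5]
analytic = rbf_grad_x(x, x_train, lengthscale=0.8)
numeric = finite_difference_grad(x, x_train, lengthscale=0.8)
```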
- ax.utils.sensitivity.derivative_gp.get_Kxx_dx2(gp: Model, kernel_type: str = 'rbf') Tensor
Computes the analytic second derivative of the kernel w.r.t. the training data.
- Parameters:
gp – Botorch model.
kernel_type – Takes “rbf” or “matern”.
- Returns:
Tensor (n x D x D) The second derivative of the kernel w.r.t. the training data.
- ax.utils.sensitivity.derivative_gp.posterior_derivative(gp: Model, x: Tensor, kernel_type: str = 'rbf') MultivariateNormal
Computes the posterior of the derivative of the GP w.r.t. the given test points x. This follows the derivation used by GIBO in Sarah Müller, Alexander von Rohr, Sebastian Trimpe. “Local policy search with Bayesian optimization”, Advances in Neural Information Processing Systems 34, NeurIPS 2021.
- Parameters:
gp – Botorch model.
x – (n x D) Test points.
kernel_type – Takes “rbf” or “matern”.
- Returns:
A Botorch Posterior.
Derivative Measures
- class ax.utils.sensitivity.derivative_measures.GpDGSMGpMean(model: Model, bounds: Tensor, derivative_gp: bool = False, kernel_type: Optional[str] = None, Y_scale: float = 1.0, num_mc_samples: int = 10000, input_qmc: bool = False, dtype: dtype = torch.float64, num_bootstrap_samples: int = 1)
Bases: object
- gradient_absolute_measure() Tensor
Computes the gradient absolute measure:
- Returns:
- if self.num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
- gradient_measure() Tensor
Computes the gradient measure:
- Returns:
- if self.num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
- class ax.utils.sensitivity.derivative_measures.GpDGSMGpSampling(model: Model, bounds: Tensor, num_gp_samples: int, derivative_gp: bool = False, kernel_type: Optional[str] = None, Y_scale: float = 1.0, num_mc_samples: int = 10000, input_qmc: bool = False, gp_sample_qmc: bool = False, dtype: dtype = torch.float64, num_bootstrap_samples: int = 1)
Bases: GpDGSMGpMean
- ax.utils.sensitivity.derivative_measures.compute_derivatives_from_model_list(model_list: List[Model], bounds: Tensor, **kwargs: Any) Tensor
Computes average derivatives of a list of models on a bounded domain. Estimation is according to the GP posterior mean function.
- Parameters:
model_list – A list of m botorch.models.model.Model types for which to compute the average derivative.
bounds – A 2 x d Tensor of lower and upper bounds of the domain of the models.
kwargs – Passed along to GpDGSMGpMean.
- Returns:
A (m x d) tensor of gradient measures.
Sobol Measures
- class ax.utils.sensitivity.sobol_measures.SobolSensitivity(bounds: Tensor, input_function: Optional[Callable[[Tensor], Tensor]] = None, num_mc_samples: int = 10000, input_qmc: bool = False, second_order: bool = False, num_bootstrap_samples: int = 1, bootstrap_array: bool = False)
Bases: object
- evalute_function(f_A_B_ABi: Optional[Tensor] = None) None
Evaluates the objective function and divides the evaluations into the torch.Tensors needed to compute the indices.
- Parameters:
f_A_B_ABi – Function evaluations on the entire grid of size M(d+2).
- first_order_indices() Tensor
Computes the first order Sobol indices:
- Returns:
- if num_bootstrap_samples > 1
Tensor: (values, var_mc, stderr_mc) x dim
- else
Tensor: (values) x dim
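For intuition about what first_order_indices estimates, here is a plain-Python Monte Carlo sketch using a Saltelli-style "pick-freeze" scheme: two base sample matrices A and B plus one hybrid matrix AB_i per dimension (A with column i taken from B), for a total of M(d+2) evaluations, matching the size of the evalute_function grid. It assumes inputs uniform on the unit cube and uses the estimator V_i ≈ mean(f(B)·(f(AB_i) − f(A))); whether SobolSensitivity uses exactly this estimator is not stated here. For f(x) = x1 + 2·x2, the exact first order indices are 1/5 and 4/5, since the two terms contribute variances 1/12 and 4/12.

```python
import random
from typing import Callable, List, Sequence


def first_order_sobol(
    f: Callable[[Sequence[float]], float], dim: int, num_samples: int, seed: int = 0
) -> List[float]:
    """Monte Carlo (Saltelli-style) estimate of first order Sobol indices on [0, 1]^dim."""
    rng = random.Random(seed)
    A = [[rng.random() for _ in range(dim)] for _ in range(num_samples)]
    B = [[rng.random() for _ in range(dim)] for _ in range(num_samples)]
    f_A = [f(a) for a in A]
    f_B = [f(b) for b in B]

    # Total variance estimated from all evaluations of A and B.
    all_vals = f_A + f_B
    mean = sum(all_vals) / len(all_vals)
    total_var = sum((v - mean) ** 2 for v in all_vals) / len(all_vals)

    indices = []
    for i in range(dim):
        # AB_i: rows of A with column i replaced by the matching column of B.
        f_ABi = [f(a[:i] + [b[i]] + a[i + 1:]) for a, b in zip(A, B)]
        v_i = sum(fb * (fab - fa) for fb, fab, fa in zip(f_B, f_ABi, f_A)) / num_samples
        indices.append(v_i / total_var)
    return indices


def linear(x: Sequence[float]) -> float:
    return x[0] + 2 * x[1]


estimates = first_order_sobol(linear, dim=2, num_samples=100_000)
```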
- second_order_indices(first_order_idxs: Optional[Tensor] = None, first_order_idxs_btsp: Optional[Tensor] = None) Tensor [source]¶
Computes the Second order Sobol indices: :param first_order_idxs: Tensor of first order indices. :param first_order_idxs_btsp: Tensor of all first order indices given by bootstrap.
- Returns:
- if num_bootstrap_samples>1
Tensor: (values,var_mc,stderr_mc)x dim
- else
Tensor: (values)x dim
- class ax.utils.sensitivity.sobol_measures.SobolSensitivityGPMean(model: ~botorch.models.model.Model, bounds: ~torch.Tensor, num_mc_samples: int = 10000, second_order: bool = False, input_qmc: bool = False, num_bootstrap_samples: int = 1, link_function: ~typing.Callable[[~torch.Tensor, ~torch.Tensor], ~torch.Tensor] = <function GaussianLinkMean>, mini_batch_size: int = 128)[source]¶
Bases:
object
- first_order_indices() Tensor [source]¶
Computes the first-order Sobol indices.
- Returns:
- if num_bootstrap_samples > 1:
Tensor: (values, var_mc, stderr_mc) x dim
- else:
Tensor: (values) x dim
- class ax.utils.sensitivity.sobol_measures.SobolSensitivityGPSampling(model: Model, bounds: Tensor, num_gp_samples: int = 1000, num_mc_samples: int = 10000, second_order: bool = False, input_qmc: bool = False, gp_sample_qmc: bool = False, num_bootstrap_samples: int = 1)[source]¶
Bases:
object
- first_order_indices() Tensor [source]¶
Computes the first-order Sobol indices.
- Returns:
- if num_bootstrap_samples > 1:
Tensor: (values, var_gp, stderr_gp, var_mc, stderr_mc) x dim
- else:
Tensor: (values, var, stderr) x dim
- ax.utils.sensitivity.sobol_measures.ax_parameter_sens(model_bridge: TorchModelBridge, metrics: Optional[List[str]] = None, order: str = 'first', signed: bool = True, **sobol_kwargs: Any) Dict[str, Dict[str, ndarray]] [source]¶
Compute sensitivity for all metrics on a TorchModelBridge.
Sobol measures are always nonnegative, regardless of the direction in which a parameter influences f. If signed is True, the Sobol measure of each parameter is given the sign of the average gradient of f with respect to that parameter across the search space. Thus, important parameters that decrease f when increased will have large negative values, and unimportant parameters will have values close to 0.
- Parameters:
model_bridge – A ModelBridge object with models that were fit.
metrics – The names of the metrics and outcomes for which to compute sensitivities. These should preferably be metrics with a good model fit. Defaults to model_bridge.outcomes.
order – A string specifying the order of the Sobol indices to be computed. Supports “first” and “total” and defaults to “first”.
signed – A bool for whether the measure should be signed.
sobol_kwargs – keyword arguments passed on to SobolSensitivityGPMean, and if signed, GpDGSMGpMean.
- Returns:
Dictionary {‘metric_name’: {‘parameter_name’: sensitivity_value}}, where the sensitivity value is cast to a Numpy array in order to be compatible with plot_feature_importance_by_feature.
- Return type:
Dict[str, Dict[str, ndarray]]
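The sign rule and the nested return format can be sketched as follows. The sketch assumes that the Sobol indices and per-parameter average gradients for one metric are already available as arrays. signed_sensitivities is a hypothetical helper, not an Ax function, and it does not reproduce how Ax computes the average gradient (via GpDGSMGpMean).

```python
import numpy as np


def signed_sensitivities(sobol_indices, avg_gradients, parameter_names):
    """Hypothetical helper: give each Sobol index the sign of the average gradient."""
    magnitudes = np.abs(np.asarray(sobol_indices, dtype=float))
    signs = np.sign(np.asarray(avg_gradients, dtype=float))  # np.sign(0.0) == 0.0
    # Wrap each value in a NumPy array, mirroring the documented return format.
    return {
        name: np.array(value)
        for name, value in zip(parameter_names, magnitudes * signs)
    }


sens = {"metric_a": signed_sensitivities([0.7, 0.05], [-1.3, 0.2], ["x1", "x2"])}
# sens == {"metric_a": {"x1": array(-0.7), "x2": array(0.05)}}
```

Note that a parameter with an average gradient of exactly zero would be reported as 0 under this particular sign convention.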
- ax.utils.sensitivity.sobol_measures.compute_sobol_indices_from_model_list(model_list: List[Model], bounds: Tensor, order: str = 'first', **sobol_kwargs: Any) Tensor [source]¶
Computes Sobol indices of a list of models on a bounded domain.
- Parameters:
model_list – A list of botorch.models.model.Model types for which to compute the Sobol indices.
bounds – A 2 x d Tensor of lower and upper bounds of the domain of the models.
order – A string specifying the order of the Sobol indices to be computed. Supports “first” and “total” and defaults to “first”.
sobol_kwargs – keyword arguments passed on to SobolSensitivityGPMean.
- Returns:
With m GPs, returns an (m x d) tensor of Sobol indices of the requested order.
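For order="total", one common choice is Jansen's estimator. The sketch below applies it to plain NumPy functions rather than GPs, with one row per function, giving the (m x d) layout described above. The helper name, the use of Jansen's formula, and the function-list input are assumptions. They do not describe how SobolSensitivityGPMean works internally.

```python
import numpy as np


def jansen_total_order(functions, bounds, num_mc_samples=10_000, seed=0):
    """Hypothetical helper: total-order Sobol indices for m functions, shape (m, d)."""
    bounds = np.asarray(bounds, dtype=float)
    lower, upper = bounds[0], bounds[1]
    d = bounds.shape[1]
    rng = np.random.default_rng(seed)
    A = lower + (upper - lower) * rng.random((num_mc_samples, d))
    B = lower + (upper - lower) * rng.random((num_mc_samples, d))
    out = np.empty((len(functions), d))
    for k, f in enumerate(functions):
        f_A, f_B = f(A), f(B)
        var_y = np.var(np.concatenate([f_A, f_B]))
        for i in range(d):
            AB_i = A.copy()
            AB_i[:, i] = B[:, i]
            # Jansen: E[Var(Y | X_~i)] is estimated by mean((f(A) - f(AB_i))^2) / 2.
            out[k, i] = np.mean((f_A - f(AB_i)) ** 2) / (2.0 * var_y)
    return out


ST = jansen_total_order(
    [lambda X: X[:, 0] + 2 * X[:, 1], lambda X: X[:, 0]],
    bounds=[[0.0, 0.0], [1.0, 1.0]],
    num_mc_samples=100_000,
)
```

For the additive first function, total-order and first-order indices coincide at about [0.2, 0.8]. The second function depends only on x0, so its row is about [1, 0].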
Stats¶
Statstools¶
- ax.utils.stats.statstools.agresti_coull_sem(n_numer: Union[pandas.Series, ndarray, int], n_denom: Union[pandas.Series, ndarray, int], prior_successes: int = 2, prior_failures: int = 2) Union[ndarray, float] [source]¶
Compute the Agresti-Coull style standard error for a binomial proportion.
Reference: Agresti, Alan, and Brent A. Coull. “Approximate Is Better than ‘Exact’ for Interval Estimation of Binomial Proportions.” The American Statistician, vol. 52, no. 2, 1998, pp. 119-126. JSTOR, www.jstor.org/stable/2685469.
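The Agresti-Coull idea is to add pseudo-observations before computing a Wald-style standard error. The minimal sketch below assumes the adjusted proportion p = (n_numer + prior_successes) / (n_denom + prior_successes + prior_failures) and places the adjusted count in the denominator of the variance. The helper name is hypothetical, and Ax may differ in which count it divides by.

```python
import numpy as np


def agresti_coull_sem_sketch(n_numer, n_denom, prior_successes=2, prior_failures=2):
    """Hypothetical Agresti-Coull-style standard error of a binomial proportion."""
    n_numer = np.asarray(n_numer, dtype=float)
    n_denom = np.asarray(n_denom, dtype=float)
    # Add pseudo-successes and pseudo-failures to the observed counts.
    n_adj = n_denom + prior_successes + prior_failures
    p_adj = (n_numer + prior_successes) / n_adj
    return np.sqrt(p_adj * (1.0 - p_adj) / n_adj)


sem_empty = agresti_coull_sem_sketch(0, 0)  # p_adj = 0.5 with 4 pseudo-counts
```

With no data, the pseudo-counts still give a finite standard error (0.25 here), which is the practical advantage over the plain Wald formula.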
- ax.utils.stats.statstools.inverse_variance_weight(means: ndarray, variances: ndarray, conflicting_noiseless: str = 'warn') Tuple[float, float] [source]¶
Perform inverse variance weighting.
- Parameters:
means – The means of the observations.
variances – The variances of the observations.
conflicting_noiseless – How to handle the case of multiple observations with zero variance but different means. Options are “warn” (default), “ignore” or “raise”.
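Inverse variance weighting can be sketched as below. The weighted mean is sum(w_i * m_i) / sum(w_i) with w_i = 1 / v_i, and its variance is 1 / sum(w_i). This hypothetical helper makes one assumption about zero-variance observations: it returns the mean of the noiseless observations with variance 0.

```python
import warnings

import numpy as np


def inverse_variance_weight_sketch(means, variances, conflicting_noiseless="warn"):
    """Hypothetical helper returning (weighted mean, variance of the weighted mean)."""
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    noiseless = variances == 0
    if noiseless.any():
        # Zero-variance observations get infinite weight, so only they matter.
        exact = means[noiseless]
        if not np.allclose(exact, exact[0]):
            msg = "Noiseless observations have conflicting means."
            if conflicting_noiseless == "raise":
                raise ValueError(msg)
            if conflicting_noiseless == "warn":
                warnings.warn(msg)
        return float(exact.mean()), 0.0
    weights = 1.0 / variances
    mean = float(np.sum(weights * means) / np.sum(weights))
    return mean, float(1.0 / np.sum(weights))


m, v = inverse_variance_weight_sketch([1.0, 2.0], [1.0, 3.0])  # about (1.25, 0.75)
```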
- ax.utils.stats.statstools.marginal_effects(df: pandas.DataFrame) pandas.DataFrame [source]¶
This method calculates the relative (in %) change in the outcome achieved by using any individual factor level versus randomizing across all factor levels. It does this by estimating a baseline under the experiment by marginalizing over all factors/levels. For each factor level, then, it conditions on that level for the individual factor and then marginalizes over all levels for all other factors.
- Parameters:
df – Dataframe containing columns named mean and sem. All other columns are assumed to be factors for which to calculate marginal effects.
- Returns:
- A dataframe containing columns “Name”, “Level”, “Beta” and “SE”
corresponding to the factor, level, effect and standard error. Results are relativized as percentage changes.
- ax.utils.stats.statstools.positive_part_james_stein(means: Union[ndarray, List[float]], sems: Union[ndarray, List[float]]) Tuple[ndarray, ndarray] [source]¶
Estimation method for the positive-part James-Stein estimator.
This method takes a vector of K means (y_i) and standard errors (sigma_i) and calculates the positive-part James-Stein estimator.
Resulting estimates are the shrunk means and standard errors. The positive-part James-Stein estimator shrinks each constituent average to the grand average:
y_i - phi_i * y_i + phi_i * ybar
The variable phi_i determines the amount of shrinkage. For phi_i = 1, mu_hat is equal to ybar (the mean of all y_i), while for phi_i = 0, mu_hat is equal to y_i. It can be shown that restricting phi_i <= 1 dominates the unrestricted estimator, so this method restricts phi_i in this manner. The amount of shrinkage, phi_i, is determined by:
(K - 3) * sigma2_i / s2
where sigma2_i = sigma_i^2 and s2 is the sum of squared deviations of the y_i from ybar. That is, less shrinkage is applied when individual means are estimated with greater precision, and more shrinkage is applied when individual means are very tightly clustered together. We also restrict phi_i to never be larger than 1.
The variance of the mean estimator is:
(1 - phi_i) * sigma2_i + phi_i * sigma2_i / K + 2 * phi_i ** 2 * (y_i - ybar)^2 / (K - 3)
The first term is the variance component from y_i, the second term is the contribution from the mean of all y_i, and the third term is the contribution from the uncertainty in the sum of squared deviations of y_i from the mean of all y_i.
For more information, see https://ax.dev/docs/models.html#empirical-bayes-and-thompson-sampling.
- Parameters:
means – Means of each arm
sems – Standard errors of each arm
- Returns:
mu_hat_i – Empirical Bayes estimate of each arm’s mean
sem_i – Empirical Bayes estimate of each arm’s sem
- Return type:
Tuple[ndarray, ndarray]
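The shrinkage and variance formulas translate directly into NumPy. This sketch makes three assumptions: sigma2_i = sems_i^2, s2 is the sum of squared deviations of the means from ybar, and inputs with K <= 3 are returned unchanged because the (K - 3) factor is not positive there. The function name is hypothetical.

```python
import numpy as np


def positive_part_james_stein_sketch(means, sems):
    """Hypothetical positive-part James-Stein shrinkage of K arm means."""
    y = np.asarray(means, dtype=float)
    sigma2 = np.asarray(sems, dtype=float) ** 2
    K = y.shape[0]
    if K <= 3:
        # The (K - 3) factor is not positive here; return the inputs unchanged.
        return y.copy(), np.sqrt(sigma2)
    ybar = y.mean()
    s2 = np.sum((y - ybar) ** 2)  # sum of squared deviations from the grand mean
    if s2 == 0:
        phi = np.ones(K)  # all means identical: shrink fully to ybar
    else:
        phi = np.minimum(1.0, (K - 3) * sigma2 / s2)  # positive part: phi_i <= 1
    mu_hat = y - phi * y + phi * ybar
    var_hat = (
        (1 - phi) * sigma2
        + phi * sigma2 / K
        + 2 * phi**2 * (y - ybar) ** 2 / (K - 3)
    )
    return mu_hat, np.sqrt(var_hat)


# K = 5, ybar = 2, s2 = 10, so phi_i = 2 / 10 = 0.2 for every arm.
mu, se = positive_part_james_stein_sketch([0, 1, 2, 3, 4], [1, 1, 1, 1, 1])
```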
- ax.utils.stats.statstools.relativize(means_t: Union[ndarray, List[float], float], sems_t: Union[ndarray, List[float], float], mean_c: float, sem_c: float, bias_correction: bool = True, cov_means: Union[ndarray, List[float], float] = 0.0, as_percent: bool = False, control_as_constant: bool = False) Tuple[ndarray, ndarray] [source]¶
Ratio estimator based on the delta method.
This uses the delta method (i.e. a Taylor series approximation) to estimate the mean and standard deviation of the sampling distribution of the ratio between test and control – that is, the sampling distribution of an estimator of the true population value under the assumption that the means in test and control have a known covariance:
(mu_t / mu_c) - 1.
Under a second-order Taylor expansion, the sampling distribution of the relative change in empirical means, which is m_t / m_c - 1, is approximately normally distributed with mean
[(mu_t - mu_c) / mu_c] - [(sigma_c)^2 * mu_t] / (mu_c)^3
and variance
(sigma_t / mu_c)^2 - 2 * mu_t * sigma_tc / mu_c^3 + [(sigma_c * mu_t)^2 / (mu_c)^4]
where sigma_tc is the covariance between the test and control means. The higher-order terms of the full Taylor series are assumed to be close to zero. To estimate these parameters, we plug in the empirical means and standard errors. This gives us the estimators:
[(m_t - m_c) / m_c] - [(s_c)^2 * m_t] / (m_c)^3
and
(s_t / m_c)^2 - 2 * m_t * s_tc / m_c^3 + [(s_c * m_t)^2 / (m_c)^4]
Note that the delta method does NOT take as input the empirical standard deviation of a metric, but rather the standard error of the mean of that metric – that is, the standard deviation of the metric after division by the square root of the total number of observations.
- Parameters:
means_t – Sample means (test)
sems_t – Sample standard errors of the means (test)
mean_c – Sample mean (control)
sem_c – Sample standard error of the mean (control)
bias_correction – Whether to apply bias correction when computing relativized metric values. Uses a second-order Taylor expansion for approximating the means and standard errors of the ratios.
cov_means – Sample covariance between test and control
as_percent – If true, return results in percent (* 100)
control_as_constant – If true, control is treated as a constant. bias_correction, sem_c, and cov_means are ignored when this is true.
- Returns:
rel_hat – Inferred means of the sampling distribution of the relative change (mean_t - mean_c) / abs(mean_c)
sem_hat – Inferred standard deviation of the sampling distribution of rel_hat, i.e., the standard error.
- Return type:
Tuple[ndarray, ndarray]
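The plug-in estimators for the mean and variance can be written directly in NumPy. This minimal sketch assumes a positive control mean, so abs(mean_c) equals mean_c, and scalar control inputs. relativize_sketch is a hypothetical name, not the Ax function.

```python
import numpy as np


def relativize_sketch(
    means_t,
    sems_t,
    mean_c,
    sem_c,
    bias_correction=True,
    cov_means=0.0,
    as_percent=False,
    control_as_constant=False,
):
    """Delta-method ratio estimator (hypothetical helper, positive mean_c only)."""
    if mean_c <= 0:
        raise ValueError("This sketch only handles a positive control mean.")
    m_t = np.asarray(means_t, dtype=float)
    s_t = np.asarray(sems_t, dtype=float)
    m_c, s_c = float(mean_c), float(sem_c)
    rel_hat = (m_t - m_c) / m_c
    if control_as_constant:
        # Control has no sampling error: only the test SEM contributes.
        sem_hat = s_t / m_c
    else:
        if bias_correction:
            # Second-order correction term: - s_c^2 * m_t / m_c^3.
            rel_hat = rel_hat - s_c**2 * m_t / m_c**3
        var_hat = (
            (s_t / m_c) ** 2
            - 2 * m_t * cov_means / m_c**3
            + (s_c * m_t) ** 2 / m_c**4
        )
        sem_hat = np.sqrt(var_hat)
    if as_percent:
        rel_hat, sem_hat = 100 * rel_hat, 100 * sem_hat
    return rel_hat, sem_hat


rel, sem = relativize_sketch(1.1, 0.01, 1.0, 0.01, bias_correction=False)
```

Here rel is about 0.1 and sem is about sqrt(0.0001 + 0.000121). Treating the control as a constant drops the second term and gives a standard error of 0.01.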
- ax.utils.stats.statstools.relativize_data(data: Data, status_quo_name: str = 'status_quo', as_percent: bool = False, include_sq: bool = False, bias_correction: bool = True, control_as_constant: bool = False) Data [source]¶
Relativize a data object w.r.t. a status_quo arm.
- Parameters:
data – The data object to be relativized.
status_quo_name – The name of the status_quo arm.
as_percent – If True, return results as percentage change.
include_sq – Include status quo in final df.
bias_correction – Whether to apply bias correction when computing relativized metric values. Uses a second-order Taylor expansion for approximating the means and standard errors of the ratios, see ax.utils.stats.statstools.relativize for more details.
control_as_constant – If true, control is treated as a constant. bias_correction is ignored when this is true.
- Returns:
- The new data object with the relativized metrics (excluding the status_quo arm unless include_sq is True)
- ax.utils.stats.statstools.total_variance(means: ndarray, variances: ndarray, sample_sizes: ndarray) float [source]¶
Compute total variance.
- ax.utils.stats.statstools.unrelativize(means_t: Union[ndarray, List[float], float], sems_t: Union[ndarray, List[float], float], mean_c: float, sem_c: float, bias_correction: bool = True, cov_means: Union[ndarray, List[float], float] = 0.0, as_percent: bool = False, control_as_constant: bool = False) Tuple[ndarray, ndarray] [source]¶
Reverse operation of ax.utils.stats.statstools.relativize.
- Parameters:
means_t – Relativized sample means (test) to be unrelativized
sems_t – Relativized sample SEM of the means (test) to be unrelativized
mean_c – Unrelativized control mean
sem_c – Unrelativized control SEM of the mean
bias_correction – Whether means_t and sems_t were obtained with bias_correction=True in ax.utils.stats.statstools.relativize.
cov_means – Sample covariance between the unrelativized test and control
as_percent – If true, assumes means_t and sems_t are percentages (i.e., 1 means 1%).
control_as_constant – If true, control is treated as a constant. bias_correction, sem_c, and cov_means are ignored when this is true.
- Returns:
m_t – Inferred sample (test) means in the unrelativized scale
s_t – Inferred SEM of sample (test) means in the unrelativized scale
- Return type:
Tuple[ndarray, ndarray]
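The reverse mapping can be sketched by inverting the delta-method equations. With bias correction, the mean equation rel = m_t / m_c - 1 - s_c^2 * m_t / m_c^3 is solved for m_t, and the variance equation is then solved for s_t^2. This hypothetical helper assumes a positive control mean and is not Ax's implementation.

```python
import numpy as np


def unrelativize_sketch(
    means_t,
    sems_t,
    mean_c,
    sem_c,
    bias_correction=True,
    cov_means=0.0,
    as_percent=False,
    control_as_constant=False,
):
    """Hypothetical inverse of the delta-method relativization (positive mean_c only)."""
    if mean_c <= 0:
        raise ValueError("This sketch only handles a positive control mean.")
    rel = np.asarray(means_t, dtype=float)
    sem_rel = np.asarray(sems_t, dtype=float)
    m_c, s_c = float(mean_c), float(sem_c)
    if as_percent:
        # Inputs are percentages, e.g. 1 means 1%.
        rel, sem_rel = rel / 100, sem_rel / 100
    if control_as_constant:
        return m_c * (1 + rel), sem_rel * m_c
    if bias_correction:
        # Solve rel = m_t / m_c - 1 - s_c^2 * m_t / m_c^3 for m_t.
        m_t = (rel + 1) / (1 / m_c - s_c**2 / m_c**3)
    else:
        m_t = m_c * (1 + rel)
    c = m_t / m_c
    # Solve sem_rel^2 = (s_t^2 - 2 * c * cov + c^2 * s_c^2) / m_c^2 for s_t^2;
    # clip tiny negative values that can arise from rounding.
    var_t = sem_rel**2 * m_c**2 + 2 * c * cov_means - c**2 * s_c**2
    return m_t, np.sqrt(np.maximum(var_t, 0.0))


m_t, s_t = unrelativize_sketch(0.1, np.sqrt(0.000221), 1.0, 0.01, bias_correction=False)
```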
Model Fit Metrics¶
- class ax.utils.stats.model_fit_stats.ModelFitMetricProtocol(*args, **kwargs)[source]¶
Bases:
Protocol
Structural type for model fit metrics.
- ax.utils.stats.model_fit_stats.coefficient_of_determination(y_obs: ndarray, y_pred: ndarray, se_pred: Optional[ndarray] = None, eps: float = 1e-12) float [source]¶
Computes the coefficient of determination, the proportion of variance in y_obs accounted for by predictions y_pred.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – Not used, kept for API compatibility.
eps – A small constant to add to the denominator for numerical stability.
- Returns:
The scalar coefficient of determination, “R squared”.
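The sketch below assumes the textbook form R^2 = 1 - SS_res / (SS_tot + eps), with eps added to the denominator as described for the eps parameter. The helper name is hypothetical.

```python
import numpy as np


def r_squared(y_obs, y_pred, eps=1e-12):
    """Coefficient of determination, 1 - SS_res / (SS_tot + eps) (hypothetical helper)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_obs - y_pred) ** 2)  # residual sum of squares
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)  # total sum of squares
    return float(1.0 - ss_res / (ss_tot + eps))


perfect = r_squared([1, 2, 3], [1, 2, 3])  # 1.0
```

Predicting the observed mean everywhere gives a value near 0. Predictions worse than the mean give negative values.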
- ax.utils.stats.model_fit_stats.compute_model_fit_metrics(y_obs: Mapping[str, ndarray], y_pred: Mapping[str, ndarray], se_pred: Mapping[str, ndarray], fit_metrics_dict: Mapping[str, ModelFitMetricProtocol]) Dict[str, Dict[str, float]] [source]¶
Computes the model fit metrics for each experimental metric in the input dicts.
- Parameters:
y_obs – A dictionary mapping from experimental metric name to observed values.
y_pred – A dictionary mapping from experimental metric name to predicted values.
se_pred – A dictionary mapping from experimental metric name to predicted standard errors.
fit_metrics_dict – A dictionary mapping from model fit metric name to a ModelFitMetricProtocol function that evaluates a model fit metric.
- Returns:
A nested dictionary mapping from model fit and experimental metric names to their corresponding model fit metrics values.
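Any callable with the (y_obs, y_pred, se_pred) keyword signature fits a structural protocol like ModelFitMetricProtocol. The sketch below shows such a protocol together with the nested-dictionary computation. It assumes the outer key is the model fit metric name and the inner key is the experimental metric name. FitMetric, compute_fit_metrics_sketch, and mean_absolute_error are hypothetical names.

```python
from typing import Dict, Mapping, Protocol

import numpy as np


class FitMetric(Protocol):
    """Structural type: any callable with this keyword signature qualifies."""

    def __call__(
        self, y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
    ) -> float: ...


def compute_fit_metrics_sketch(
    y_obs: Mapping[str, np.ndarray],
    y_pred: Mapping[str, np.ndarray],
    se_pred: Mapping[str, np.ndarray],
    fit_metrics_dict: Mapping[str, FitMetric],
) -> Dict[str, Dict[str, float]]:
    # Outer key: model fit metric name; inner key: experimental metric name.
    return {
        fit_name: {
            metric: fit_fn(
                y_obs=y_obs[metric], y_pred=y_pred[metric], se_pred=se_pred[metric]
            )
            for metric in y_obs
        }
        for fit_name, fit_fn in fit_metrics_dict.items()
    }


def mean_absolute_error(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
    return float(np.mean(np.abs(y_obs - y_pred)))


results = compute_fit_metrics_sketch(
    y_obs={"m1": np.array([1.0, 2.0]), "m2": np.array([0.0, 0.0])},
    y_pred={"m1": np.array([1.0, 3.0]), "m2": np.array([1.0, -1.0])},
    se_pred={"m1": np.ones(2), "m2": np.ones(2)},
    fit_metrics_dict={"mae": mean_absolute_error},
)
# results == {"mae": {"m1": 0.5, "m2": 1.0}}
```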
- ax.utils.stats.model_fit_stats.mean_of_the_standardized_error(y_obs: ndarray, y_pred: ndarray, se_pred: ndarray) float [source]¶
Computes the mean of the error standardized by the predictive standard deviation of the model se_pred. If the model makes good predictions and its uncertainty is quantified well, this mean should be close to 0, and the standardized errors should be approximately normally distributed.
NOTE: This assumes that se_pred is the predictive standard deviation of the observations of the objective y, not the predictive standard deviation of the objective f itself. In practice, this will matter for very noisy observations.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – An array of the standard errors of the predicted values.
- Returns:
The scalar mean of the standardized error.
- ax.utils.stats.model_fit_stats.std_of_the_standardized_error(y_obs: ndarray, y_pred: ndarray, se_pred: ndarray) float [source]¶
Standard deviation of the error standardized by the predictive standard deviation of the model se_pred. If the uncertainty is quantified well, this value should be close to 1.
NOTE: This assumes that se_pred is the predictive standard deviation of the observations of the objective y, not the predictive standard deviation of the objective f itself. In practice, this will matter for very noisy observations.
- Parameters:
y_obs – An array of observations for a single metric.
y_pred – An array of the predicted values corresponding to y_obs.
se_pred – An array of the standard errors of the predicted values.
- Returns:
The scalar standard deviation of the standardized error.
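Both standardized-error statistics can be sketched with one shared helper. The synthetic check below draws observations with exactly the predicted standard deviation, so the statistics should land near 0 and 1. The helper names are hypothetical, and the use of the population standard deviation (ddof=0) is an assumption.

```python
import numpy as np


def standardized_error(y_obs, y_pred, se_pred):
    # (observation - prediction) / predictive standard deviation
    y_obs, y_pred = np.asarray(y_obs, float), np.asarray(y_pred, float)
    return (y_obs - y_pred) / np.asarray(se_pred, float)


def mean_standardized_error(y_obs, y_pred, se_pred):
    return float(np.mean(standardized_error(y_obs, y_pred, se_pred)))


def std_standardized_error(y_obs, y_pred, se_pred):
    # Population standard deviation (ddof=0); the ddof choice is an assumption.
    return float(np.std(standardized_error(y_obs, y_pred, se_pred)))


# Calibrated synthetic case: noise scale equals the predicted standard deviation.
rng = np.random.default_rng(0)
n = 20_000
y_pred = np.zeros(n)
se_pred = np.full(n, 2.0)
y_obs = rng.normal(loc=y_pred, scale=se_pred)
```

If se_pred were systematically too small (overconfident predictions), the standard deviation of the standardized error would rise above 1.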
Storage¶
Deletion¶
Testing¶
Backend Scheduler¶
- class ax.utils.testing.backend_scheduler.AsyncSimulatedBackendScheduler(experiment: Experiment, generation_strategy: GenerationStrategy, max_pending_trials: int, options: SchedulerOptions)[source]¶
Bases:
Scheduler
A Scheduler that uses a simulated backend for Ax asynchronous benchmarks.
- property backend_simulator: BackendSimulator¶
Get the BackendSimulator stored on the runner of the experiment.
- Returns:
The backend simulator.
- experiment: Experiment¶
- generation_strategy: GenerationStrategyInterface¶
- logger: LoggerAdapter¶
- should_stop_trials_early(trial_indices: Set[int]) Dict[int, Optional[str]] [source]¶
Given a set of trial indices, decide whether or not to early-stop running trials using the early_stopping_strategy.
- Parameters:
trial_indices – Indices of trials to consider for early stopping.
- Returns:
A dictionary mapping the indices of trials that should be early-stopped (a subset of the initially passed trials) to optional messages giving the reason.
Backend Simulator¶
- class ax.utils.testing.backend_simulator.BackendSimulator(options: Optional[BackendSimulatorOptions] = None, queued: Optional[List[SimTrial]] = None, running: Optional[List[SimTrial]] = None, failed: Optional[List[SimTrial]] = None, completed: Optional[List[SimTrial]] = None, verbose_logging: bool = True)[source]¶
Bases:
object
Simulator for a backend deployment with concurrent dispatch and a queue.
- classmethod from_state(state: BackendSimulatorState)[source]¶
Construct a simulator from a state.
- Parameters:
state – A BackendSimulatorState to set the simulator to.
- Returns:
A BackendSimulator with the desired state.
- get_sim_trial_by_index(trial_index: int) Optional[SimTrial] [source]¶
Get a SimTrial by trial_index.
- Parameters:
trial_index – The index of the trial to return.
- Returns:
A SimTrial with the index trial_index, or None if not found.
- lookup_trial_index_status(trial_index: int) Optional[TrialStatus] [source]¶
Look up the trial status of a trial_index.
- Parameters:
trial_index – The index of the trial to check.
- Returns:
A TrialStatus.
- new_trial(trial: SimTrial, status: TrialStatus) None [source]¶
Register a trial into the simulator.
- Parameters:
trial – A new trial to add.
status – The status of the new trial, either STAGED (add to self._queued) or RUNNING (add to self._running).
- run_trial(trial_index: int, runtime: float) None [source]¶
Run a simulated trial.
- Parameters:
trial_index – The index of the trial (usually the Ax trial index)
runtime – The runtime of the simulation. Typically sampled from the runtime model of a simulation model.
Internally, the runtime is scaled by the time_scaling factor, so that the simulation can run arbitrarily faster than the underlying evaluation.
- state() BackendSimulatorState [source]¶
Return a BackendSimulatorState containing the state of the simulator.
- status() SimStatus [source]¶
Return the internal status of the simulator.
- Returns:
A SimStatus object representing the current simulator state.
- class ax.utils.testing.backend_simulator.BackendSimulatorOptions(max_concurrency: int = 1, time_scaling: float = 1.0, failure_rate: float = 0.0, internal_clock: Optional[float] = None, use_update_as_start_time: bool = False)[source]¶
Bases:
object
Settings for the BackendSimulator.
- Parameters:
max_concurrency – The maximum number of trials that can be run in parallel.
time_scaling – The factor to scale down the runtime of the tasks by. If runtime is the actual runtime of a trial, the simulation time will be runtime / time_scaling.
failure_rate – The rate at which the trials are failing. For now, trials fail independently with a coin flip based on that rate.
internal_clock – The initial state of the internal clock. If None, the simulator uses time.time() as the clock.
use_update_as_start_time – Whether the start time of a new trial should be logged as the current time (at time of update) or the end time of the previous trial. This makes sense when using the internal clock and the BackendSimulator is simulated forward by an external process (such as Scheduler).
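A small stand-alone simulator can illustrate how max_concurrency, time_scaling, and the queue interact. This is a hypothetical sketch: the class and method names are made up and are not the BackendSimulator API. It models only an internal clock, ignores failure_rate, and starts a queued trial at the previous trial's end time, which is analogous to use_update_as_start_time=False.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class MiniSimTrial:
    trial_index: int
    sim_runtime: float  # runtime already divided by time_scaling
    sim_start_time: Optional[float] = None


class MiniBackendSimulator:
    """Hypothetical queue-plus-concurrency simulator on an internal clock."""

    def __init__(self, max_concurrency: int = 1, time_scaling: float = 1.0,
                 internal_clock: float = 0.0) -> None:
        self.max_concurrency = max_concurrency
        self.time_scaling = time_scaling
        self.clock = internal_clock
        self.queued: List[MiniSimTrial] = []
        self.running: List[MiniSimTrial] = []
        self.completed: List[MiniSimTrial] = []

    def run_trial(self, trial_index: int, runtime: float) -> None:
        # The simulated runtime is the actual runtime scaled down by time_scaling.
        self.queued.append(MiniSimTrial(trial_index, runtime / self.time_scaling))
        self._dispatch()

    def _dispatch(self) -> None:
        # Move queued trials into free concurrency slots, starting them "now".
        while self.queued and len(self.running) < self.max_concurrency:
            trial = self.queued.pop(0)
            trial.sim_start_time = self.clock
            self.running.append(trial)

    def advance_to(self, t: float) -> None:
        # Complete trials in order of end time; each freed slot starts the next
        # queued trial at the previous trial's end time.
        while self.running:
            nxt = min(self.running, key=lambda tr: tr.sim_start_time + tr.sim_runtime)
            end = nxt.sim_start_time + nxt.sim_runtime
            if end > t:
                break
            self.clock = end
            self.running.remove(nxt)
            self.completed.append(nxt)
            self._dispatch()
        self.clock = t

    def status(self) -> Tuple[List[int], List[int], List[int]]:
        # (queued, running, completed) trial indices
        return ([tr.trial_index for tr in self.queued],
                [tr.trial_index for tr in self.running],
                [tr.trial_index for tr in self.completed])


sim = MiniBackendSimulator(max_concurrency=1, time_scaling=10.0)
sim.run_trial(0, runtime=10.0)  # simulated runtime 1.0, starts immediately
sim.run_trial(1, runtime=10.0)  # queued behind trial 0
sim.advance_to(1.5)  # trial 0 finished at 1.0, trial 1 started at 1.0
```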
- class ax.utils.testing.backend_simulator.BackendSimulatorState(options: BackendSimulatorOptions, verbose_logging: bool, queued: List[Dict[str, Optional[float]]], running: List[Dict[str, Optional[float]]], failed: List[Dict[str, Optional[float]]], completed: List[Dict[str, Optional[float]]])[source]¶
Bases:
object
State of the BackendSimulator.
- Parameters:
options – The BackendSimulatorOptions associated with this simulator.
verbose_logging – Whether the simulator is using verbose logging.
queued – Currently queued trials.
running – Currently running trials.
failed – Currently failed trials.
completed – Currently completed trials.
- options: BackendSimulatorOptions¶
- class ax.utils.testing.backend_simulator.SimStatus(queued: List[int], running: List[int], failed: List[int], time_remaining: List[float], completed: List[int])[source]¶
Bases:
object
Container for status of the simulation.
- class ax.utils.testing.backend_simulator.SimTrial(trial_index: int, sim_runtime: float, sim_start_time: Optional[float] = None, sim_queued_time: Optional[float] = None, sim_completed_time: Optional[float] = None)[source]¶
Bases:
object
Container for the simulation tasks.
Benchmark Stubs¶
- ax.utils.testing.benchmark_stubs.get_aggregated_benchmark_result() AggregatedBenchmarkResult [source]¶
- ax.utils.testing.benchmark_stubs.get_benchmark_problem() BenchmarkProblem [source]¶
- ax.utils.testing.benchmark_stubs.get_benchmark_result() BenchmarkResult [source]¶
- ax.utils.testing.benchmark_stubs.get_moo_surrogate() MOOSurrogateBenchmarkProblem [source]¶
- ax.utils.testing.benchmark_stubs.get_multi_objective_benchmark_problem(infer_noise: bool = True, num_trials: int = 4) MultiObjectiveBenchmarkProblem [source]¶
- ax.utils.testing.benchmark_stubs.get_single_objective_benchmark_problem(infer_noise: bool = True, num_trials: int = 4) SingleObjectiveBenchmarkProblem [source]¶
- ax.utils.testing.benchmark_stubs.get_sobol_benchmark_method() BenchmarkMethod [source]¶
- ax.utils.testing.benchmark_stubs.get_sobol_gpei_benchmark_method() BenchmarkMethod [source]¶
- ax.utils.testing.benchmark_stubs.get_soo_surrogate() SOOSurrogateBenchmarkProblem [source]¶
Core Stubs¶
- class ax.utils.testing.core_stubs.CustomTestMetric(name: str, test_attribute: str, lower_is_better: Optional[bool] = None)[source]¶
Bases:
Metric
- class ax.utils.testing.core_stubs.DummyEarlyStoppingStrategy(early_stop_trials: Optional[Dict[int, Optional[str]]] = None)[source]¶
Bases:
BaseEarlyStoppingStrategy
- should_stop_trials_early(trial_indices: Set[int], experiment: Experiment, **kwargs: Dict[str, Any]) Dict[int, Optional[str]] [source]¶
Decide whether to complete trials before evaluation is fully concluded.
Typical examples include stopping a machine learning model’s training, or halting the gathering of samples before some planned number are collected.
- Parameters:
trial_indices – Indices of candidate trials to stop early.
experiment – Experiment that contains the trials and other contextual data.
- Returns:
A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason.
- class ax.utils.testing.core_stubs.DummyGlobalStoppingStrategy(min_trials: int, trial_to_stop: int)[source]¶
Bases:
BaseGlobalStoppingStrategy
A dummy Global Stopping Strategy which stops the optimization after a pre-specified number of trials are completed.
- class ax.utils.testing.core_stubs.SpecialGenerationStrategy[source]¶
Bases:
GenerationStrategyInterface
A subclass of GenerationStrategyInterface to be used for testing how methods respond to subtypes other than GenerationStrategy.
- gen_for_multiple_trials_with_multiple_models(experiment: Experiment, num_generator_runs: int, data: Optional[Data] = None, n: int = 1) List[List[GeneratorRun]] [source]¶
Produce GeneratorRuns for multiple trials at once with the possibility of ensembling, or using multiple models per trial, getting multiple GeneratorRuns per trial.
- Parameters:
experiment – Experiment, for which the generation strategy is producing a new generator run in the course of gen, and to which that generator run will be added as trial(s). Information stored on the experiment (e.g., trial statuses) is used to determine which model will be used to produce the generator run returned from this method.
data – Optional data to be passed to the underlying model’s gen, which is called within this method and actually produces the resulting generator run. By default, data is all data on the experiment.
n – Integer representing how many arms should be in each generator run produced by this method. NOTE: Some underlying models may ignore n and produce a model-determined number of arms. In that case, this method will also output generator runs with a number of arms that can differ from n.
pending_observations – A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated.
- Returns:
A list of lists of generator runs. Each element of the outer list represents a trial being suggested, and each inner list contains the generator runs for that trial.
- class ax.utils.testing.core_stubs.TestTrial(experiment: core.experiment.Experiment, trial_type: Optional[str] = None, ttl_seconds: Optional[int] = None, index: Optional[int] = None)[source]¶
Bases:
BaseTrial
Trial class used to test the error raised for an unsupported trial type.
- ax.utils.testing.core_stubs.get_abandoned_arm() AbandonedArm [source]¶
- ax.utils.testing.core_stubs.get_acquisition_type() Type[Acquisition] [source]¶
- ax.utils.testing.core_stubs.get_and_early_stopping_strategy() AndEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_arm_weights1() MutableMapping[Arm, float] [source]¶
- ax.utils.testing.core_stubs.get_arm_weights2() MutableMapping[Arm, float] [source]¶
- ax.utils.testing.core_stubs.get_arms_from_dict(arm_weights_dict: MutableMapping[Arm, float]) List[Arm] [source]¶
- ax.utils.testing.core_stubs.get_augmented_branin_metric(name: str = 'aug_branin') AugmentedBraninMetric [source]¶
- ax.utils.testing.core_stubs.get_augmented_branin_optimization_config() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_augmented_hartmann_metric(name: str = 'aug_hartmann') AugmentedHartmann6Metric [source]¶
- ax.utils.testing.core_stubs.get_augmented_hartmann_optimization_config() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_batch_trial(abandon_arm: bool = True, experiment: Optional[Experiment] = None, constrain_search_space: bool = True) BatchTrial [source]¶
- ax.utils.testing.core_stubs.get_batch_trial_with_repeated_arms(num_repeated_arms: int) BatchTrial [source]¶
Create a batch that contains both new arms and N arms from the last existing trial in the experiment, where N is equal to the input argument ‘num_repeated_arms’.
- ax.utils.testing.core_stubs.get_botorch_model() BoTorchModel [source]¶
- ax.utils.testing.core_stubs.get_botorch_model_with_default_acquisition_class() BoTorchModel [source]¶
- ax.utils.testing.core_stubs.get_botorch_model_with_surrogate_specs() BoTorchModel [source]¶
- ax.utils.testing.core_stubs.get_branin_data(trial_indices: Optional[Iterable[int]] = None, trials: Optional[Iterable[Trial]] = None) Data [source]¶
- ax.utils.testing.core_stubs.get_branin_data_batch(batch: BatchTrial) Data [source]¶
- ax.utils.testing.core_stubs.get_branin_data_multi_objective(trial_indices: Optional[Iterable[int]] = None, num_objectives: int = 2) Data [source]¶
- ax.utils.testing.core_stubs.get_branin_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_trial: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False, search_space: Optional[SearchSpace] = None, minimize: bool = False, named: bool = True, with_completed_trial: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_branin_experiment_with_multi_objective(has_optimization_config: bool = True, has_objective_thresholds: bool = False, with_batch: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False, num_objectives: int = 2) Experiment [source]¶
- ax.utils.testing.core_stubs.get_branin_experiment_with_status_quo_trials(num_sobol_trials: int = 5, multi_objective: bool = False) Tuple[Experiment, ObservationFeatures] [source]¶
- ax.utils.testing.core_stubs.get_branin_experiment_with_timestamp_map_metric(with_status_quo: bool = False, rate: Optional[float] = None, map_tracking_metric: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_branin_metric(name: str = 'branin', lower_is_better: bool = True) BraninMetric [source]¶
- ax.utils.testing.core_stubs.get_branin_multi_objective_optimization_config(has_objective_thresholds: bool = False, num_objectives: int = 2) MultiObjectiveOptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_branin_objective(name: str = 'branin', minimize: bool = False) Objective [source]¶
- ax.utils.testing.core_stubs.get_branin_optimization_config(minimize: bool = False) OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_branin_search_space(with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_branin_with_multi_task(with_multi_objective: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_choice_parameter() ChoiceParameter [source]¶
- ax.utils.testing.core_stubs.get_data(metric_name: str = 'ax_test_metric', trial_index: int = 0, num_non_sq_arms: int = 4, include_sq: bool = True) Data [source]¶
- ax.utils.testing.core_stubs.get_dataset(num_samples: int = 2, d: int = 2, m: int = 2, has_observation_noise: bool = False, feature_names: Optional[List[str]] = None, outcome_names: Optional[List[str]] = None, tkwargs: Optional[Dict[str, Any]] = None, seed: Optional[int] = None) SupervisedDataset [source]¶
Constructs a SupervisedDataset based on the given arguments.
- Parameters:
num_samples – The number of samples in the dataset.
d – The dimension of the features.
m – The number of outcomes.
has_observation_noise – If True, includes Yvar in the dataset.
feature_names – A list of feature names. Defaults to x0, x1…
outcome_names – A list of outcome names. Defaults to y0, y1…
tkwargs – Optional dictionary of tensor kwargs, such as dtype and device.
seed – An optional seed used to generate the data.
- ax.utils.testing.core_stubs.get_default_scheduler_options() SchedulerOptions [source]¶
- ax.utils.testing.core_stubs.get_dict_lookup_metric() DictLookupMetric [source]¶
- ax.utils.testing.core_stubs.get_experiment(with_status_quo: bool = True, constrain_search_space: bool = True) Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_batch_and_single_trial() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_batch_trial(constrain_search_space: bool = True) Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_custom_runner_and_metric(constrain_search_space: bool = True, immutable: bool = False, multi_objective: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_data() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_map_data() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_map_data_type() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_multi_objective() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_observations(observations: List[List[float]], minimize: bool = False, scalarized: bool = False, constrained: bool = False, with_tracking_metrics: bool = False, search_space: Optional[SearchSpace] = None) Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_repeated_arms(num_repeated_arms: int) Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_scalarized_objective_and_outcome_constraint() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_trial() Experiment [source]¶
- ax.utils.testing.core_stubs.get_experiment_with_trial_with_ttl() Experiment [source]¶
- ax.utils.testing.core_stubs.get_factorial_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_factorial_metric(name: str = 'success_metric') FactorialMetric [source]¶
- ax.utils.testing.core_stubs.get_fixed_parameter() FixedParameter [source]¶
- ax.utils.testing.core_stubs.get_generator_run() GeneratorRun [source]¶
- ax.utils.testing.core_stubs.get_generator_run2() GeneratorRun [source]¶
- ax.utils.testing.core_stubs.get_hartmann_metric(name: str = 'hartmann') Hartmann6Metric [source]¶
- ax.utils.testing.core_stubs.get_hartmann_optimization_config() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_hartmann_search_space(with_fidelity_parameter: bool = False) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_hierarchical_search_space(with_fixed_parameter: bool = False) HierarchicalSearchSpace [source]¶
- ax.utils.testing.core_stubs.get_hierarchical_search_space_experiment(num_observations: int = 0) Experiment [source]¶
- ax.utils.testing.core_stubs.get_high_dimensional_branin_experiment() Experiment [source]¶
- ax.utils.testing.core_stubs.get_hss_trials_with_fixed_parameter(exp: Experiment) Dict[int, BaseTrial] [source]¶
- ax.utils.testing.core_stubs.get_improvement_global_stopping_strategy() ImprovementGlobalStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_l2_reg_weight_parameter() RangeParameter [source]¶
- ax.utils.testing.core_stubs.get_large_factorial_search_space(num_levels: int = 10, num_parameters: int = 6) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_large_ordinal_search_space(n_ordinal_choice_parameters: int, n_continuous_range_parameters: int) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_lr_parameter() RangeParameter [source]¶
- ax.utils.testing.core_stubs.get_many_branin_objective_opt_config(n_objectives: int) MultiObjectiveOptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_map_optimization_config() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_model_parameter(with_fixed_parameter: bool = False) ChoiceParameter [source]¶
- ax.utils.testing.core_stubs.get_model_predictions() Tuple[Dict[str, List[float]], Dict[str, Dict[str, List[float]]]] [source]¶
- ax.utils.testing.core_stubs.get_model_predictions_per_arm() Dict[str, Tuple[Dict[str, float], Optional[Dict[str, Dict[str, float]]]]] [source]¶
- ax.utils.testing.core_stubs.get_multi_objective_optimization_config(custom_metric: bool = False) MultiObjectiveOptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_multi_type_experiment(add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10) MultiTypeExperiment [source]¶
- ax.utils.testing.core_stubs.get_multi_type_experiment_with_multi_objective(add_trials: bool = False) MultiTypeExperiment [source]¶
- ax.utils.testing.core_stubs.get_num_boost_rounds_parameter() RangeParameter [source]¶
- ax.utils.testing.core_stubs.get_objective_threshold(metric_name: str = 'm1', bound: float = -0.25, comparison_op: ComparisonOp = ComparisonOp.GEQ) ObjectiveThreshold [source]¶
- ax.utils.testing.core_stubs.get_optimization_config() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_optimization_config_no_constraints() OptimizationConfig [source]¶
- ax.utils.testing.core_stubs.get_or_early_stopping_strategy() OrEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_order_constraint() OrderConstraint [source]¶
- ax.utils.testing.core_stubs.get_ordered_choice_parameter() ChoiceParameter [source]¶
- ax.utils.testing.core_stubs.get_parameter_constraint(param_x: str = 'x', param_y: str = 'w') ParameterConstraint [source]¶
- ax.utils.testing.core_stubs.get_parameter_distribution() ParameterDistribution [source]¶
- ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy() PercentileEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy_with_non_objective_metric_name() PercentileEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy_with_true_objective_metric_name() PercentileEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_range_parameter() RangeParameter [source]¶
- ax.utils.testing.core_stubs.get_range_parameter2() RangeParameter [source]¶
- ax.utils.testing.core_stubs.get_risk_measure() RiskMeasure [source]¶
- ax.utils.testing.core_stubs.get_robust_branin_experiment(risk_measure: Optional[RiskMeasure] = None, optimization_config: Optional[OptimizationConfig] = None, num_sobol_trials: int = 2) Experiment [source]¶
- ax.utils.testing.core_stubs.get_robust_search_space(lb: float = 0.0, ub: float = 5.0, multivariate: bool = False, use_discrete: bool = False, num_samples: int = 4) RobustSearchSpace [source]¶
- ax.utils.testing.core_stubs.get_robust_search_space_environmental(lb: float = 0.0, ub: float = 5.0) RobustSearchSpace [source]¶
- ax.utils.testing.core_stubs.get_scalarized_outcome_constraint() ScalarizedOutcomeConstraint [source]¶
- ax.utils.testing.core_stubs.get_scheduler_options_batch_trial() SchedulerOptions [source]¶
- ax.utils.testing.core_stubs.get_search_space(constrain_search_space: bool = True) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_search_space_for_range_value(min: float = 3.0, max: float = 6.0) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_search_space_for_range_values(min: float = 3.0, max: float = 6.0) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_search_space_with_choice_parameters(num_ordered_parameters: int = 2, num_unordered_choices: int = 5) SearchSpace [source]¶
- ax.utils.testing.core_stubs.get_sebo_acquisition_class() Type[SEBOAcquisition] [source]¶
- ax.utils.testing.core_stubs.get_sum_constraint1() SumConstraint [source]¶
- ax.utils.testing.core_stubs.get_sum_constraint2() SumConstraint [source]¶
- ax.utils.testing.core_stubs.get_synthetic_runner() SyntheticRunner [source]¶
- ax.utils.testing.core_stubs.get_task_choice_parameter() ChoiceParameter [source]¶
- ax.utils.testing.core_stubs.get_test_map_data_experiment(num_trials: int, num_fetches: int, num_complete: int, map_tracking_metric: bool = False) Experiment [source]¶
- ax.utils.testing.core_stubs.get_threshold_early_stopping_strategy() ThresholdEarlyStoppingStrategy [source]¶
- ax.utils.testing.core_stubs.get_weights_from_dict(arm_weights_dict: MutableMapping[Arm, float]) List[float] [source]¶
- ax.utils.testing.core_stubs.get_winsorization_config() WinsorizationConfig [source]¶
- ax.utils.testing.core_stubs.run_branin_experiment_with_generation_strategy(generation_strategy: GenerationStrategy, num_trials: int = 6, kwargs_for_get_branin_experiment: Optional[Dict[str, Any]] = None) Experiment [source]¶
Gets a Branin experiment using any given kwargs and runs num_trials trials using the given generation strategy.
Modeling Stubs¶
- ax.utils.testing.modeling_stubs.get_experiment_for_value() Experiment [source]¶
- ax.utils.testing.modeling_stubs.get_generation_strategy(with_experiment: bool = False, with_callable_model_kwarg: bool = True, with_completion_criteria: int = 0, with_generation_nodes: bool = False) GenerationStrategy [source]¶
- ax.utils.testing.modeling_stubs.get_legacy_list_surrogate_generation_step_as_dict() Dict[str, Any] [source]¶
For use in ensuring backwards compatibility when loading the now-deprecated ListSurrogate.
- ax.utils.testing.modeling_stubs.get_observation(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation1(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation1trans(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation2(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation2trans(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation_features() ObservationFeatures [source]¶
- ax.utils.testing.modeling_stubs.get_observation_status_quo0(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_observation_status_quo1(first_metric_name: str = 'a', second_metric_name: str = 'b') Observation [source]¶
- ax.utils.testing.modeling_stubs.get_surrogate_as_dict() Dict[str, Any] [source]¶
For use in ensuring backwards compatibility when loading a Surrogate with input_transform and outcome_transform kwargs.
- ax.utils.testing.modeling_stubs.get_surrogate_generation_step() GenerationStep [source]¶
- ax.utils.testing.modeling_stubs.get_surrogate_spec_as_dict(model_class: Optional[str] = None, with_legacy_input_transform: bool = False) Dict[str, Any] [source]¶
For use in ensuring backwards compatibility when loading a SurrogateSpec with input_transform and outcome_transform kwargs.
- ax.utils.testing.modeling_stubs.sobol_gpei_generation_node_gs() GenerationStrategy [source]¶
Returns a basic Sobol + GPEI generation strategy, built from GenerationNodes, for use in testing.
- class ax.utils.testing.modeling_stubs.transform_1(search_space: Optional[SearchSpace] = None, observations: Optional[List[Observation]] = None, modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, config: Optional[TConfig] = None)[source]¶
Bases:
Transform
- config: TConfig¶
- modelbridge: Optional[modelbridge_module.base.ModelBridge]¶
- transform_observation_features(observation_features: List[ObservationFeatures]) List[ObservationFeatures] [source]¶
Transform observation features.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
observation_features – Observation features
Returns: transformed observation features
- transform_optimization_config(optimization_config: OptimizationConfig, modelbridge: Optional[ModelBridge], fixed_features: Optional[ObservationFeatures]) OptimizationConfig [source]¶
Transform optimization config.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
optimization_config – The optimization config
Returns: transformed optimization config.
- transform_search_space(search_space: SearchSpace) SearchSpace [source]¶
Transform search space.
The transforms are typically done in-place. This calls two private methods: _transform_search_space, which transforms the core search space attributes, and _transform_parameter_distributions, which transforms the distributions when using a RobustSearchSpace.
- Parameters:
search_space – The search space
Returns: transformed search space.
- untransform_observation_features(observation_features: List[ObservationFeatures]) List[ObservationFeatures] [source]¶
Untransform observation features.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
observation_features – Observation features in the transformed space
Returns: observation features in the original space
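The docstrings above describe the contract a Transform follows: each transform_* method maps an object into the transformed space (typically in place), untransform_observation_features maps features back, and the base behaviour is the identity. The following is a minimal, Ax-free sketch of that contract. Plain dicts stand in for SearchSpace and ObservationFeatures, and IdentityTransform and ShiftTransform are hypothetical classes, not Ax's implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-ins: bounds map parameter names to (lower, upper),
# and an observation feature is a dict of parameter values.
Bounds = Dict[str, Tuple[float, float]]
Features = Dict[str, float]


class IdentityTransform:
    """Base behaviour: every method returns its input unchanged."""

    def transform_search_space(self, search_space: Bounds) -> Bounds:
        # Template method: delegate to a private hook, loosely mirroring
        # the _transform_search_space split described above.
        return self._transform_search_space(search_space)

    def _transform_search_space(self, search_space: Bounds) -> Bounds:
        return search_space

    def transform_observation_features(self, features: List[Features]) -> List[Features]:
        return features

    def untransform_observation_features(self, features: List[Features]) -> List[Features]:
        return features


class ShiftTransform(IdentityTransform):
    """Hypothetical subclass that shifts every parameter by a constant, in place."""

    def __init__(self, shift: float = 1.0) -> None:
        self.shift = shift

    def _transform_search_space(self, search_space: Bounds) -> Bounds:
        for name, (lower, upper) in search_space.items():
            search_space[name] = (lower + self.shift, upper + self.shift)
        return search_space

    def transform_observation_features(self, features: List[Features]) -> List[Features]:
        for feature in features:
            for name in feature:
                feature[name] += self.shift
        return features

    def untransform_observation_features(self, features: List[Features]) -> List[Features]:
        # Exact inverse of transform_observation_features.
        for feature in features:
            for name in feature:
                feature[name] -= self.shift
        return features
```

A useful property to check in tests is the round trip: untransform_observation_features(transform_observation_features(x)) should recover x.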
- class ax.utils.testing.modeling_stubs.transform_2(search_space: Optional[SearchSpace] = None, observations: Optional[List[Observation]] = None, modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, config: Optional[TConfig] = None)[source]¶
Bases:
Transform
- config: TConfig¶
- modelbridge: Optional[modelbridge_module.base.ModelBridge]¶
- transform_observation_features(observation_features: List[ObservationFeatures]) List[ObservationFeatures] [source]¶
Transform observation features.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
observation_features – Observation features
Returns: transformed observation features
- transform_optimization_config(optimization_config: OptimizationConfig, modelbridge: Optional[ModelBridge], fixed_features: Optional[ObservationFeatures]) OptimizationConfig [source]¶
Transform optimization config.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
optimization_config – The optimization config
Returns: transformed optimization config.
- transform_search_space(search_space: SearchSpace) SearchSpace [source]¶
Transform search space.
The transforms are typically done in-place. This calls two private methods: _transform_search_space, which transforms the core search space attributes, and _transform_parameter_distributions, which transforms the distributions when using a RobustSearchSpace.
- Parameters:
search_space – The search space
Returns: transformed search space.
- untransform_observation_features(observation_features: List[ObservationFeatures]) List[ObservationFeatures] [source]¶
Untransform observation features.
This is typically done in-place. This class implements the identity transform (does nothing).
- Parameters:
observation_features – Observation features in the transformed space
Returns: observation features in the original space
Mocking¶
- ax.utils.testing.mock.fast_botorch_optimize(f: Callable) Callable [source]¶
Wraps f in the fast_botorch_optimize_context_manager for use as a decorator.
- ax.utils.testing.mock.fast_botorch_optimize_context_manager(force: bool = False) Generator[None, None, None] [source]¶
A context manager that forces BoTorch to speed up optimization. Currently, the primary tactic is to force the underlying scipy methods to stop after just one iteration.
- Parameters:
force – If True, no AssertionError is raised when none of the mocks are called.
USE RESPONSIBLY.
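The mechanism described above (patch an expensive optimizer so it stops after one iteration, fail if the patch was never exercised unless force=True, and offer a decorator form) can be sketched with unittest.mock alone. SlowOptimizer, fast_optimize_context_manager, fast_optimize, and fit_model are hypothetical names standing in for the scipy/BoTorch calls that Ax actually patches; this is not Ax's implementation.

```python
import functools
from contextlib import contextmanager
from typing import Any, Callable, Generator
from unittest import mock


class SlowOptimizer:
    """Hypothetical stand-in for an expensive optimizer call (not a BoTorch API)."""

    @staticmethod
    def run(maxiter: int = 200) -> int:
        iterations = 0
        for _ in range(maxiter):
            iterations += 1  # pretend each step is expensive
        return iterations


@contextmanager
def fast_optimize_context_manager(force: bool = False) -> Generator[None, None, None]:
    original = SlowOptimizer.run  # unpatched function, captured before patching

    def one_iteration(maxiter: int = 200) -> int:
        # Ignore the requested budget and stop after a single iteration.
        return original(maxiter=1)

    with mock.patch.object(SlowOptimizer, "run", side_effect=one_iteration) as patched:
        yield
    # Fail loudly if the mock was never used, unless the caller opts out.
    if not force and patched.call_count == 0:
        raise AssertionError("No mocks were called in the context manager.")


def fast_optimize(f: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator form: run f inside fast_optimize_context_manager."""

    @functools.wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with fast_optimize_context_manager():
            return f(*args, **kwargs)

    return wrapper


@fast_optimize
def fit_model() -> int:
    # Would run 500 iterations without the patch.
    return SlowOptimizer.run(maxiter=500)
```

The AssertionError guards against silently stale mocks: if the patched target is renamed, a test would otherwise keep passing while running the slow path.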
Test Init Files¶
Torch Stubs¶
- ax.utils.testing.torch_stubs.get_torch_test_data(dtype: dtype = torch.float32, cuda: bool = False, constant_noise: bool = True, task_features: Optional[List[int]] = None, offset: float = 0.0) Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tuple[float, float]], List[int], List[str], List[str]] [source]¶
Unittest Conventions¶
Utils¶
Test Metrics¶
Backend Simulator Map¶
- class ax.utils.testing.metrics.backend_simulator_map.BackendSimulatorTimestampMapMetric(name: str, param_names: Iterable[str], noise_sd: float = 0.0, lower_is_better: Optional[bool] = None, cache_evaluations: bool = True)[source]¶
Bases:
NoisyFunctionMapMetric
A metric that interfaces with an underlying BackendSimulator and returns timestamp map data.
- convert_to_timestamps(start_time: float, end_time: float) List[float] [source]¶
Given a start time and the current time, get the list of intermediate timestamps at which observations are available.
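As a sketch of what convert_to_timestamps is expected to return, the following assumes observations arrive at a fixed interval delta_t (a parameter that BraninBackendMapMetric accepts). The fixed spacing and the inclusion of both endpoints are assumptions for illustration, not Ax's documented behaviour.

```python
from typing import List


def convert_to_timestamps(start_time: float, end_time: float, delta_t: float = 1.0) -> List[float]:
    """Hypothetical: one observation every delta_t time units, starting at start_time."""
    if delta_t <= 0:
        raise ValueError("delta_t must be positive.")
    if end_time < start_time:
        return []
    num_steps = int((end_time - start_time) // delta_t)
    # Include start_time and every multiple of delta_t up to end_time.
    return [start_time + i * delta_t for i in range(num_steps + 1)]
```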
Branin Backend Map¶
- class ax.utils.testing.metrics.branin_backend_map.BraninBackendMapMetric(name: str, param_names: List[str], noise_sd: float = 0.0, lower_is_better: Optional[bool] = True, cache_evaluations: bool = True, rate: float = 0.5, delta_t: float = 1.0)[source]¶
Bases:
BackendSimulatorTimestampMapMetric, BraninTimestampMapMetric
A Branin BackendSimulatorTimestampMapMetric with a multiplicative factor of 1 - exp(-rate * t), where t is the runtime of the trial.
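The multiplicative factor can be made concrete on the standard Branin function. This plain-Python sketch assumes the factor scales the noiseless Branin value directly. Noise, caching, and the backend simulator are omitted, and branin_at_runtime is a hypothetical helper name.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard Branin function with its usual constants."""
    a, b, c = 1.0, 5.1 / (4 * math.pi**2), 5 / math.pi
    r, s, t = 6.0, 10.0, 1 / (8 * math.pi)
    return a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * math.cos(x1) + s


def branin_at_runtime(x1: float, x2: float, runtime: float, rate: float = 0.5) -> float:
    """Branin value scaled by 1 - exp(-rate * runtime)."""
    return branin(x1, x2) * (1 - math.exp(-rate * runtime))
```

The factor is 0 at t = 0 and approaches 1 as t grows, so the reported value rises from 0 toward the final Branin value over the course of a trial; rate controls how quickly it converges.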
Tutorials¶
Neural Net¶
- class ax.utils.tutorials.cnn_utils.CNN[source]¶
Bases:
Module
Convolutional Neural Network.
- forward(x: Tensor) Tensor [source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
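To illustrate why the note recommends calling the module instance, here is a torch-free toy that loosely mimics the mechanism. ToyModule and its hook signature are hypothetical simplifications; PyTorch's real hooks receive the input as a tuple and support more features.

```python
from typing import Callable, List

Hook = Callable[["ToyModule", float, float], None]


class ToyModule:
    """Toy loosely mimicking why Module.__call__ is preferred over forward()."""

    def __init__(self) -> None:
        self.forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self.forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return 2.0 * x  # the "recipe" a subclass would override

    def __call__(self, x: float) -> float:
        output = self.forward(x)
        for hook in self.forward_hooks:  # hooks only run via __call__
            hook(self, x, output)
        return output


seen: List[float] = []
net = ToyModule()
net.register_forward_hook(lambda module, inp, out: seen.append(out))
net(3.0)          # runs the hook, so seen == [6.0]
net.forward(3.0)  # bypasses the hook, so seen is unchanged
```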
- ax.utils.tutorials.cnn_utils.evaluate(net: Module, data_loader: DataLoader, dtype: dtype, device: device) float [source]¶
Compute classification accuracy on the provided dataset.
- Parameters:
net – trained model
data_loader – DataLoader containing the evaluation set
dtype – torch dtype
device – torch device
- Returns:
classification accuracy
- Return type:
float
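Classification accuracy is the fraction of examples whose predicted class matches the label. The sketch below assumes, as is usual, that the prediction is the argmax over per-class scores. Batching over data_loader and the dtype/device handling of evaluate are omitted, and classification_accuracy is a hypothetical helper.

```python
from typing import Sequence


def classification_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length.")
    if not labels:
        raise ValueError("Cannot compute accuracy on an empty dataset.")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=lambda k: scores[k])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)
```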
- ax.utils.tutorials.cnn_utils.get_partition_data_loaders(train_valid_set: Dataset, test_set: Dataset, downsample_pct: float = 0.5, train_pct: float = 0.8, batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None) Tuple[DataLoader, DataLoader, DataLoader] [source]¶
Helper function for partitioning training data into training and validation sets, downsampling data, and initializing DataLoaders for each partition.
- Parameters:
train_valid_set – torch.dataset containing the training and validation data
test_set – torch.dataset containing the test data
downsample_pct – the proportion of the dataset to use for training and validation
train_pct – the proportion of the downsampled data to use for training
batch_size – how many samples per batch to load
num_workers – number of workers (subprocesses) for loading data
deterministic_partitions – whether to partition data in a deterministic fashion
downsample_pct_test – the proportion of the dataset to use for testing; defaults to downsample_pct
- Returns:
training data DataLoader, validation data DataLoader, and test data DataLoader
- Return type:
Tuple[DataLoader, DataLoader, DataLoader]
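The interaction of downsample_pct, train_pct, and downsample_pct_test reduces to simple arithmetic. The sketch below assumes that partition sizes are truncated to integers and that the validation set receives the remainder of the downsampled pool. partition_sizes is a hypothetical helper, not part of cnn_utils.

```python
from typing import Optional, Tuple


def partition_sizes(
    num_train_valid: int,
    num_test: int,
    downsample_pct: float = 0.5,
    train_pct: float = 0.8,
    downsample_pct_test: Optional[float] = None,
) -> Tuple[int, int, int]:
    """Hypothetical helper: sizes of the train, validation, and test partitions."""
    if downsample_pct_test is None:
        downsample_pct_test = downsample_pct  # documented default
    kept = int(downsample_pct * num_train_valid)  # downsampled train + valid pool
    num_train = int(train_pct * kept)
    num_valid = kept - num_train  # remainder goes to validation
    num_test_kept = int(downsample_pct_test * num_test)
    return num_train, num_valid, num_test_kept
```

With MNIST's 60,000 training and 10,000 test images and the default arguments, this gives 24,000 training, 6,000 validation, and 5,000 test examples.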
- ax.utils.tutorials.cnn_utils.load_mnist(downsample_pct: float = 0.5, train_pct: float = 0.8, data_path: str = './data', batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None) Tuple[DataLoader, DataLoader, DataLoader] [source]¶
Load the MNIST dataset (downloading it if necessary) and split the data into training, validation, and test sets.
- Parameters:
downsample_pct – the proportion of the dataset to use for training, validation, and test
train_pct – the proportion of the downsampled data to use for training
data_path – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.
batch_size – how many samples per batch to load
num_workers – number of workers (subprocesses) for loading data
deterministic_partitions – whether to partition data in a deterministic fashion
downsample_pct_test – the proportion of the dataset to use for testing; defaults to downsample_pct
- Returns:
training data DataLoader, validation data DataLoader, and test data DataLoader
- Return type:
Tuple[DataLoader, DataLoader, DataLoader]
- ax.utils.tutorials.cnn_utils.split_dataset(dataset: Dataset, lengths: List[int], deterministic_partitions: bool = False) List[Dataset] [source]¶
Split a dataset either randomly or deterministically.
- Parameters:
dataset – the dataset to split
lengths – the lengths of each partition
deterministic_partitions – whether to partition data in a deterministic fashion
- Returns:
split datasets
- Return type:
List[Dataset]
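The difference between deterministic and random partitioning can be sketched on a plain list. This version assumes deterministic mode takes contiguous slices in the original order and random mode applies a seeded permutation first. split_items and its seed argument are hypothetical and not part of split_dataset's signature.

```python
import random
from itertools import accumulate
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def split_items(
    items: Sequence[T],
    lengths: List[int],
    deterministic_partitions: bool = False,
    seed: Optional[int] = None,
) -> List[List[T]]:
    if sum(lengths) != len(items):
        raise ValueError("Sum of lengths must equal the number of items.")
    indices = list(range(len(items)))
    if not deterministic_partitions:
        random.Random(seed).shuffle(indices)  # random permutation of positions
    parts: List[List[T]] = []
    # accumulate(lengths) yields each partition's end offset.
    for end, length in zip(accumulate(lengths), lengths):
        parts.append([items[i] for i in indices[end - length : end]])
    return parts
```

Deterministic partitions make runs reproducible without a seed, at the cost of inheriting any ordering bias in the underlying dataset.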
- ax.utils.tutorials.cnn_utils.train(net: Module, train_loader: DataLoader, parameters: Dict[str, float], dtype: dtype, device: device) Module [source]¶
Train the CNN on the provided dataset.
- Parameters:
net – initialized neural network
train_loader – DataLoader containing training set
parameters – dictionary containing parameters to be passed to the optimizer: lr (default 0.001), momentum (default 0.0), weight_decay (default 0.0), and num_epochs (default 1)
dtype – torch dtype
device – torch device
- Returns:
trained CNN.
- Return type:
nn.Module
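The lr, momentum, and weight_decay keys suggest a stochastic gradient descent optimizer with momentum. Assuming that (it is not stated above), the following sketch fills in the documented defaults and applies one update following PyTorch's torch.optim.SGD rule with zero dampening and no Nesterov momentum. num_epochs only controls how many passes over train_loader are made, so it does not appear in the update. read_hyperparameters and sgd_step are hypothetical helpers on plain floats.

```python
from typing import Dict, List, Optional, Tuple


def read_hyperparameters(parameters: Dict[str, float]) -> Dict[str, float]:
    """Fill in the documented defaults for any missing keys."""
    defaults = {"lr": 0.001, "momentum": 0.0, "weight_decay": 0.0, "num_epochs": 1}
    return {key: parameters.get(key, default) for key, default in defaults.items()}


def sgd_step(
    params: List[float],
    grads: List[float],
    buffers: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    weight_decay: float = 0.0,
) -> Tuple[List[float], List[float]]:
    """One SGD-with-momentum step; pass buffers=None on the first step."""
    new_params: List[float] = []
    new_buffers: List[float] = []
    for i, (p, g) in enumerate(zip(params, grads)):
        d_p = g + weight_decay * p  # L2 penalty folded into the gradient
        buf = d_p if buffers is None else momentum * buffers[i] + d_p
        new_buffers.append(buf)
        new_params.append(p - lr * buf)
    return new_params, new_buffers
```

With momentum = 0.0 (the default) the buffer equals the current gradient, so the step reduces to plain gradient descent.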