#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import warnings
from typing import Any
from ax.core.experiment import Experiment
from ax.core.search_space import SearchSpace
from ax.exceptions.storage import SQADecodeError
from ax.utils.common.base import Base, SortableBase
# Attributes that hold JSON; `copy_db_ids` does not recurse into these.
JSON_ATTRS = ["baseline_workflow_inputs"]

# Attribute names that `copy_db_ids` never recurses into:
#   * `_experiment` is skipped to prevent infinite loops.
#   * Most generator run and generation strategy metadata is skipped because
#     no Base objects are nested in there, and there are no guarantees about
#     the structure of some of that data, so the recursion could fail
#     somewhere.
COPY_DB_IDS_ATTRS_TO_SKIP = {
    "_best_arm_predictions",
    "_bridge_kwargs",
    "_candidate_metadata_by_arm_signature",
    "_curr",
    "_experiment",
    "_gen_metadata",
    "_memo_df",
    "_model",
    "_model_kwargs",
    "_model_predictions",
    "_model_state_after_gen",
    "_nodes",
    "_seen_trial_indices_by_status",
    "_steps",
    "analysis_scheduler",
    # No updates are expected for auxiliary experiments, so there is no need
    # to recur into them during `copy_db_ids`.
    "auxiliary_experiments_by_purpose",
}

# Hint appended to type-mismatch errors raised by `copy_db_ids`.
SKIP_ATTRS_ERROR_SUFFIX = (
    "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
)
[docs]
def is_foreign_key_field(field: str) -> bool:
    """Tell whether ``field`` is named like a foreign key field.

    A name qualifies when it ends in ``_id`` and has at least one character
    in front of that suffix, so the bare string ``"_id"`` does not count.
    """
    suffix = "_id"
    if len(field) <= len(suffix):
        return False
    return field.endswith(suffix)
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
[docs]
def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None:
    """Recursively copy ``_db_id`` attributes from ``source`` onto ``target``.

    ``source`` and ``target`` are expected to be identical, except that
    ``source`` has its db ids set and ``target`` does not. This is meant to be
    used when returning a new user-facing object after saving.

    Args:
        source: Object (``Base`` instance, list, set, dict, or leaf value)
            carrying the db ids.
        target: Structurally matching object that receives the db ids.
        path: Attribute names / indices / keys traversed so far. Used in error
            messages and as a recursion-depth guard; callers normally omit it.

    Raises:
        SQADecodeError: If the two objects diverge (different types, list
            lengths, dict keys, or JSON attribute values), if a list of
            ``Base`` objects cannot be sorted, or if the traversal becomes
            deeper than 15 levels.
    """
    if not path:
        path = []
    error_message_prefix = (
        f"Error encountered while traversing source {path + [str(source)]} and "
        f"target {path + [str(target)]}: "
    )

    # This shouldn't happen, but is a precaution against accidentally
    # introducing infinite loops.
    max_path_length = 15
    if len(path) > max_path_length:
        raise SQADecodeError(
            error_message_prefix
            + f"Encountered path of length > {max_path_length}."
        )

    # `target` may be an instance of a subclass of `source`'s type; anything
    # else is a mismatch.
    if type(source) is not type(target) and not issubclass(
        type(target), type(source)
    ):
        if source is None and isinstance(target, SearchSpace):
            # Tolerated special case: warn and fall through. `None` matches
            # none of the container branches below, so nothing is copied.
            warnings.warn(
                error_message_prefix + "Encountered two objects of different "
                f"types: {type(source)} and {type(target)}. Continuing in the "
                "special case that `source is None` and `target` is a "
                "`SearchSpace`."
            )
        else:
            raise SQADecodeError(
                error_message_prefix + "Encountered two objects of different "
                f"types: {type(source)} and {type(target)}. "
                + SKIP_ATTRS_ERROR_SUFFIX
            )

    if isinstance(source, Base):
        for attr, val in source.__dict__.items():
            if attr.endswith("_db_id"):
                # "Leaf" attribute: copy the db id and move on to the next one.
                setattr(target, attr, val)
                continue

            # Skip attrs that are doubly private or in
            # COPY_DB_IDS_ATTRS_TO_SKIP.
            if attr.startswith("__") or attr in COPY_DB_IDS_ATTRS_TO_SKIP:
                continue

            # For Json attributes we would like to simply test for equality
            # and not recurse through the Json.
            # TODO: Add json_attrs as an argument and plumb through to
            # `copy_db_ids`.
            if attr in JSON_ATTRS:
                source_json = getattr(source, attr)
                target_json = getattr(target, attr)
                if source_json != target_json:
                    raise SQADecodeError(
                        error_message_prefix
                        + f"Json attribute {attr} not matching "
                        f"between source: {source_json} and target: "
                        f"{target_json}."
                    )
                continue

            # Arms are referenced twice on an Experiment object; once in
            # experiment.arms_by_name/signature and once in
            # trial.arms_by_name/signature. When copying db_ids, we should
            # ignore the former, since it will "collapse" arms of the same
            # name/signature that appear in more than one trial.
            if isinstance(source, Experiment) and attr in {
                "_arms_by_name",
                "_arms_by_signature",
            }:
                continue

            copy_db_ids(val, getattr(target, attr), path + [attr])

    elif isinstance(source, (list, set)):
        source = list(source)
        target = list(target)
        if len(source) != len(target):
            raise SQADecodeError(
                error_message_prefix + "Encountered lists of different lengths."
            )

        # Nothing to copy for empty collections.
        if len(source) == 0:
            return

        if isinstance(source[0], Base) and not isinstance(source[0], SortableBase):
            raise SQADecodeError(
                error_message_prefix
                + f"Cannot sort instances of {type(source[0])}; "
                "sorting is only defined on instances of SortableBase."
            )

        # Sort both sides so that elements are paired up by sort order rather
        # than by their original position.
        try:
            source = sorted(source)
            target = sorted(target)
        except TypeError as e:
            if any(isinstance(o, Base) for o in source + target):
                raise SQADecodeError(
                    error_message_prefix
                    + f"TypeError encountered during sorting: {e}"
                ) from e
            # source and target are not lists of things that need to be saved
            return

        for index, x in enumerate(source):
            copy_db_ids(x, target[index], path + [str(index)])

    elif isinstance(source, dict):
        for k, v in source.items():
            if k not in target:
                raise SQADecodeError(
                    error_message_prefix + "Encountered key only present "
                    f"in source dictionary: {k}."
                )
            # Keys are not necessarily strings; stringify to keep `path` a
            # list of str.
            copy_db_ids(v, target[k], path + [str(k)])
    # Any other value is a plain leaf with no db ids to copy.