#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
"""
Classes for optimizing yields from chemical reactions.
References
.. [Perera2018]
D. Perera, J. W. Tucker, S. Brahmbhatt, C. Helal, A. Chong, W. Farrell,
P. Richardson, N. W. Sach. A platform for automated nanomole-scale
reaction screening and micromole-scale synthesis in flow. Science, 26.
2018.
.. [Shields2021]
B. J. Shields, J. Stevens, J. Li, et al. Bayesian reaction optimization
as a tool for chemical synthesis. Nature 590, 89–96 (2021).
"SUZUKI" involves optimizing solvent, ligand, and base combinations
in a Suzuki-Miyaura coupling to optimize carbon-carbon bond formation.
See [Perera2018]_ for details.
"DIRECT_ARYLATION" involves optimizing the solvent, base, and ligand chemicals
as well as the temperature and concentration for a direct arylation reaction.
See [Shields2021]_ for details.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any
from zipfile import ZipFile
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.types import TParameterization, TParamValue
from ax.utils.common.result import Err, Ok
from pyre_extensions import none_throws
class ChemistryProblemType(Enum):
    """Tabulated chemistry benchmark problems available in this module.

    Each member's value is the stem of the CSV file (``<value>.csv``) that
    ``_get_data`` reads for that problem.
    """

    # Members are intentionally left unannotated: annotating an enum member
    # (e.g. ``SUZUKI: str = ...``) is rejected by type checkers, which infer
    # the member type themselves.
    SUZUKI = "suzuki"
    DIRECT_ARYLATION = "direct_arylation"
@dataclass(frozen=True)
class ChemistryData:
    """Immutable lookup table from parameter combinations to an objective.

    Attributes:
        param_names: Names of the parameters, in the order their values
            appear in each key of ``objective_dict``.
        objective_dict: Maps a tuple of parameter values (ordered as in
            ``param_names``) to the objective value for that combination.
    """

    param_names: list[str]
    objective_dict: dict[tuple[TParamValue, ...], float]

    def evaluate(self, params: TParameterization) -> float:
        """Look up the tabulated objective for a parameterization.

        Args:
            params: Mapping from parameter name to value. It must contain
                every name in ``param_names``; other entries are ignored.

        Returns:
            The value stored in ``objective_dict`` for this combination.

        Raises:
            KeyError: If a name in ``param_names`` is missing from ``params``
                or the combination is not present in ``objective_dict``.
        """
        # Build the key in ``param_names`` order so that it lines up with the
        # table's keys regardless of how ``params`` itself is ordered.
        key = tuple(params[name] for name in self.param_names)
        return self.objective_dict[key]
@lru_cache(maxsize=8)
def _get_data(problem_type: ChemistryProblemType) -> ChemistryData:
    """Load the tabulated yields for ``problem_type``.

    Reads ``<problem_type.value>.csv`` from the ``chemistry_data.zip`` archive
    located next to this module. Every column other than ``"yield"`` is
    treated as a parameter. Results are memoized by ``lru_cache``.

    Args:
        problem_type: Which problem's table to load.

    Returns:
        A ``ChemistryData`` whose ``param_names`` are the sorted non-yield
        column names and whose ``objective_dict`` maps the parameter values
        (indexed in that order) to the ``"yield"`` column.
    """
    archive_path = (Path(__file__).parent / "chemistry_data.zip").absolute()

    with ZipFile(archive_path) as archive, archive.open(
        f"{problem_type.value}.csv"
    ) as csv_file:
        # The first CSV column becomes the row index, not a parameter.
        table = pd.read_csv(csv_file, index_col=0)

    names = [column for column in table.columns if column != "yield"]
    names.sort()
    yields_by_params = table.set_index(names)["yield"].to_dict()
    return ChemistryData(param_names=names, objective_dict=yields_by_params)
class ChemistryMetric(Metric):
    """Metric for modeling chemical reactions.

    Metric describing the outcomes of chemical reactions. Based on tabulated
    data. Problems typically contain many discrete and categorical parameters.

    Args:
        name: The name of the metric.
        noiseless: If True, consider observations noiseless, otherwise
            assume unknown Gaussian observation noise.
        problem_type: The problem type.
        lower_is_better: If True, the metric should be minimized.

    Attributes:
        noiseless: If True, consider observations noiseless, otherwise
            assume unknown Gaussian observation noise.
        problem_type: The problem type.
        lower_is_better: If True, the metric should be minimized.
    """

    def __init__(
        self,
        name: str,
        noiseless: bool = False,
        problem_type: ChemistryProblemType = ChemistryProblemType.SUZUKI,
        lower_is_better: bool = False,
    ) -> None:
        self.noiseless = noiseless
        self.problem_type = problem_type
        super().__init__(name=name, lower_is_better=lower_is_better)

    def clone(self) -> ChemistryMetric:
        """Return a new instance of the same class with the same settings."""
        return self.__class__(
            name=self._name,
            noiseless=self.noiseless,
            problem_type=self.problem_type,
            lower_is_better=none_throws(self.lower_is_better),
        )

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Look up the tabulated outcome for every arm of ``trial``.

        Args:
            trial: The trial whose arms are evaluated.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            ``Ok`` wrapping a ``Data`` object with one row per arm and the
            columns ``arm_name``, ``metric_name``, ``mean``, ``sem`` and
            ``trial_index``. ``sem`` is ``0.0`` when ``noiseless`` is set and
            NaN otherwise. If anything raises, ``Err`` wrapping a
            ``MetricFetchE`` that carries the original exception is returned.
        """
        # Failures are reported through the returned result instead of being
        # raised, hence the deliberately broad ``except`` below.
        try:
            noise_sd = 0.0 if self.noiseless else float("nan")
            data = _get_data(self.problem_type)
            arms = list(trial.arms_by_name.items())
            arm_names = [arm_name for arm_name, _ in arms]
            mean = [data.evaluate(params=arm.parameters) for _, arm in arms]
            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": mean,
                    "sem": noise_sd,
                    "trial_index": trial.index,
                }
            )
            return Ok(value=Data(df=df))
        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )