Source code for ax.utils.common.logger

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import os
from typing import Any, Optional


def get_logger(
    name: str,
    filepath: Optional[str] = None,
    level: int = logging.INFO,
    output_name: Optional[str] = None,
) -> logging.Logger:
    """Get an Ax logger.

    Sets the default level to INFO (instead of logging's default WARNING) and
    adds timestamps to log messages. Because ``logging.getLogger`` returns the
    same object for the same ``name``, repeated calls configure one shared
    logger. A console handler is attached only when the logger has no handlers
    yet, and at most one file handler is attached per distinct file path.

    Args:
        name: The name of the logger.
        filepath: Location of the file to log output to. If the file exists,
            output will be appended. If it does not exist, a new file will be
            created. If this logger already has a file handler for the same
            path, no additional handler is attached.
        level: The log level. Applied to the logger on every call, and to any
            handler created by this call.
        output_name: The name of the logger to appear in the logged output.
            Useful to abbreviate long logger names. Defaults to ``name``.

    Returns:
        The logging.Logger object.
    """
    if output_name is None:
        output_name = name
    formatter = logging.Formatter(
        fmt=f"[%(levelname)s %(asctime)s] {output_name}: %(message)s",
        datefmt="%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(name)
    logger.setLevel(level=level)
    # Add timestamps to log messages.
    if not logger.handlers:
        console = logging.StreamHandler()
        console.setLevel(level=level)
        console.setFormatter(formatter)
        logger.addHandler(console)
        # This logger now emits records itself; stop passing them to ancestor
        # loggers, whose handlers (if any) would emit them again.
        logger.propagate = False
    if filepath is None:
        return logger
    # FileHandler stores the absolute path of its file in ``baseFilename``.
    # Compare against it so that calling get_logger twice with the same file
    # does not attach a duplicate handler (which would write every record to
    # the file more than once).
    abs_path = os.path.abspath(filepath)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == abs_path
        for handler in logger.handlers
    )
    if already_attached:
        return logger
    if os.path.isfile(filepath):
        logger.warning(f"Log file ({filepath}) already exists, appending logs.")
    logfile = logging.FileHandler(filepath)
    logfile.setLevel(level=level)
    logfile.setFormatter(formatter)
    logger.addHandler(logfile)
    return logger
# pyre-ignore (ignoring Any in argument and output typing) def _round_floats_for_logging(item: Any, decimal_places: int = 2) -> Any: """Round a number or numbers in a mapping to a given number of decimal places. If item or values in dictionary is not a number, returns it as it. """ if isinstance(item, float): return round(item, 2) elif isinstance(item, dict): return { k: _round_floats_for_logging(item=v, decimal_places=decimal_places) for k, v in item.items() } elif isinstance(item, list): return [ _round_floats_for_logging(item=i, decimal_places=decimal_places) for i in item ] elif isinstance(item, tuple): return tuple( _round_floats_for_logging(item=i, decimal_places=decimal_places) for i in item ) return item