Source code for ax.utils.flake8_plugins.docstring_checker
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import itertools
from pathlib import Path
from typing import Callable, NamedTuple
class Error(NamedTuple):
    """One violation record produced by the docstring checker.

    Instances are plain tuples, so they unpack positionally as
    ``(lineno, col, message, type)``.
    """

    # Line number of the offending AST node.
    lineno: int
    # Column offset of the offending AST node.
    col: int
    # Full message text, starting with the error id (e.g. ``A000 ...``).
    message: str
    # Class of the checker that raised the error.
    type: type
def should_check(filename: str) -> bool:
    """Decide whether the docstring check applies to ``filename``.

    Files are checked by default. A file whose immediate parent directory is
    named ``tests``, ``experimental`` or ``flake8_plugins`` is skipped unless
    one of its first five lines is exactly the opt-in marker
    ``# check-docstrings``.

    Args:
        filename: Path of the file being linted.

    Returns:
        ``True`` if the file should be checked, ``False`` otherwise.
    """
    # Getting options for plugins in flake8 is a bit of a hassle so we just
    # hardcode our conventions.
    path = Path(filename)
    if path.parent.name not in ("tests", "experimental", "flake8_plugins"):
        return True
    # Only an ASCII marker is compared, so undecodable bytes are replaced
    # instead of letting a decoding error abort the lint run.
    with path.open(encoding="utf-8", errors="replace") as fd:
        for line in itertools.islice(fd, 5):
            # Strip just the line terminator so the marker is also recognized
            # on a final line that has no trailing newline.
            if line.rstrip("\r\n") == "# check-docstrings":
                return True
    return False
class DocstringChecker:
    """
    A flake8 plug-in that makes sure all public classes, functions and methods
    have a docstring.
    """

    name: str = "docstring checker"
    version: str = "1.0.0"
    filename: str
    tree: ast.Module

    def __init__(self, tree: ast.Module, filename: str) -> None:
        """Remember the parsed module and the path it came from.

        Args:
            tree: Parsed AST of the file being linted.
            filename: Path of the file being linted.
        """
        self.filename = filename
        self.tree = tree

    def run(self):
        """Yield the errors found in ``self.tree``.

        Nothing is yielded when ``should_check`` rejects ``self.filename``;
        otherwise the tree is walked with a ``DocstringCheckerVisitor`` and
        every error it collected is yielded in order.
        """
        if not should_check(self.filename):
            return
        visitor = DocstringCheckerVisitor()
        visitor.visit(self.tree)
        yield from visitor.errors
def is_copy_doc_call(c):
    """Guess whether ``c`` is a call to the ``copy_doc`` decorator.

    The check is purely syntactic: ``c`` must be a call node whose callee is
    either the bare name ``copy_doc`` or an attribute access ending in
    ``.copy_doc``. A decorator that was aliased under another name or wrapped
    in another function is not recognized.
    """
    if isinstance(c, ast.Call):
        callee = c.func
        if isinstance(callee, ast.Attribute):
            return callee.attr == "copy_doc"
        if isinstance(callee, ast.Name):
            return callee.id == "copy_doc"
    return False
class DocstringCheckerVisitor(ast.NodeVisitor):
    """AST visitor that records an ``A000`` error for undocumented definitions.

    Every function, async function and class definition reached during the
    walk is checked, including definitions nested inside other definitions.
    Results accumulate in ``errors`` in visiting order.
    """

    errors: list[Error]

    def __init__(self) -> None:
        self.errors = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check a ``def`` and then descend into its body."""
        self.check_A000(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Check a ``class`` and then descend into its body."""
        self.check_A000(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Check an ``async def`` and then descend into its body."""
        self.check_A000(node)
        self.generic_visit(node)

    def check_A000(self, node: ast.AST) -> None:
        """Append an ``A000`` error if ``node`` is public and undocumented.

        ``node`` must be a definition node, i.e. one that has ``name`` and
        ``decorator_list`` attributes. No error is recorded when the name
        starts with an underscore (which also covers dunder names), when the
        node has a docstring, or when one of its decorators looks like a
        ``copy_doc(...)`` call.
        """
        if node.name.startswith("_"):
            return
        if ast.get_docstring(node) is not None:
            return
        # A copy_doc decorator supplies the docstring, so the definition
        # counts as documented.
        if any(is_copy_doc_call(dec) for dec in node.decorator_list):
            return
        self.errors.append(A000(node))
# Error classes E, C, W and F are used by flake8, T by mypy and B by bugbear;
# this plugin reports its errors under the A prefix (see A000 below).
def new_error(errorid: str, msg: str) -> Callable[[ast.AST], Error]:
    """Build a factory that turns an AST node into an ``Error``.

    Args:
        errorid: Error code; it prefixes the message and becomes the
            ``__name__`` of the returned factory.
        msg: Human-readable description appended after the error code.

    Returns:
        A function taking an AST node and returning an ``Error`` located at
        the node's ``lineno`` / ``col_offset``, with the message
        ``"<errorid> <msg>"`` and ``DocstringChecker`` as its type.
    """
    text = f"{errorid} {msg}"

    def mk_error(node: ast.AST) -> Error:
        # Positional order follows the Error fields: lineno, col, message, type.
        return Error(node.lineno, node.col_offset, text, DocstringChecker)

    mk_error.__name__ = errorid
    return mk_error
# Factory for the "missing docstring" error; call it with the offending node.
A000 = new_error(
    errorid="A000",
    msg=(
        "Missing docstring. All public classes, functions and methods should have "
        "docstrings (cf https://fburl.com/wiki/wbcrsoeo)."
    ),
)