Source code for stubpy.ast_pass

"""
stubpy.ast_pass
===============

AST pre-pass — harvests structural metadata from source *without executing*
the module.

This module runs a read-only walk over the source file's AST before (or
instead of) importing the module.  Because no code is executed, this pass
is free from import-time side effects.

The harvested data is stored in :class:`ASTSymbols` and fed into
:func:`~stubpy.symbols.build_symbol_table` to construct the
:class:`~stubpy.symbols.SymbolTable`.

What is harvested
-----------------

* **Classes** — name, source line, base class expressions (as strings),
  decorator names, and directly-defined methods.
* **Module-level functions** — name, line, ``async`` flag, decorator names,
  and a flag for ``@overload``-decorated variants.
* **Annotated variables** — ``name: Type = value`` at module scope.
* **``__all__``** — the explicit public API list, when present.
* **Type alias declarations** (all forms):

  - ``Name: TypeAlias = <rhs>``  — explicit PEP 613 annotation
  - ``Name = int | float``        — bare PEP 604 union
  - ``Name = Union[str, int]``    — subscripted generic
  - ``Name = int``                — known built-in or typing type name
  - ``type Name = <rhs>``         — Python 3.12+ PEP 695 soft keyword
  - ``type Stack[T] = list[T]``   — generic alias (PEP 695)

* **TypeVar / ParamSpec / TypeVarTuple / NewType** call-expression declarations.

Ignore directive
----------------

If the source file begins (before any code) with a comment containing
``# stubpy: ignore`` (case-insensitive), the harvester returns an empty
:class:`ASTSymbols` and the caller should skip stub generation for that
file.  Check :attr:`ASTSymbols.skip_file` to detect this.

What is *not* harvested
-----------------------

* Nested functions or classes inside other functions.
* Import statements (handled by :mod:`stubpy.imports`).
* Runtime values — those require the module to be executed.

Examples
--------
>>> from stubpy.ast_pass import ast_harvest
>>> syms = ast_harvest("x: int = 1\\nclass Foo: pass\\n")
>>> syms.variables[0].name
'x'
>>> syms.classes[0].name
'Foo'
"""
from __future__ import annotations

import ast
from dataclasses import dataclass, field


# ---------------------------------------------------------------------------
# Data containers for harvested metadata
# ---------------------------------------------------------------------------

[docs] @dataclass class FunctionInfo: """ Metadata for a single function or method definition from the AST. Parameters ---------- name : str lineno : int is_async : bool ``True`` for ``async def`` definitions. decorators : list of str Plain names of all decorators (e.g. ``["classmethod"]``). is_overload : bool ``True`` when ``overload`` appears in *decorators*. raw_arg_annotations : dict Maps parameter name → unparsed annotation string for every annotated parameter. Variadic names are prefixed: ``"*args"``, ``"**kwargs"``. raw_return_annotation : str or None Unparsed return-annotation string, or ``None`` when absent. kwargs_forwarded_to : list of str Names of callables to which ``**kwargs`` is forwarded in the body. Populated by the body scanner in :meth:`ASTHarvester._harvest_function`. Used by :func:`~stubpy.resolver.resolve_function_params` to expand variadic parameters into their concrete counterparts. args_forwarded_to : list of str Names of callables to which ``*args`` is forwarded in the body. Same purpose as *kwargs_forwarded_to* for positional variadics. Examples -------- >>> info = FunctionInfo(name="greet", lineno=5, is_async=False) >>> info.is_overload False >>> info.kwargs_forwarded_to [] """ name: str lineno: int is_async: bool = False decorators: list[str] = field(default_factory=list) is_overload: bool = False raw_arg_annotations: dict[str, str] = field(default_factory=dict) raw_return_annotation: str | None = None kwargs_forwarded_to: list[str] = field(default_factory=list) args_forwarded_to: list[str] = field(default_factory=list)
[docs] @dataclass class ClassInfo: """ Metadata for a single class definition from the AST. Parameters ---------- name : str lineno : int bases : list of str Base class expressions as unparsed strings (e.g. ``["Element"]``). decorators : list of str Plain decorator names. methods : list of FunctionInfo Methods defined directly in the class body. Examples -------- >>> info = ClassInfo(name="Widget", lineno=10, bases=["Element"]) >>> info.decorators [] """ name: str lineno: int bases: list[str] = field(default_factory=list) decorators: list[str] = field(default_factory=list) methods: list[FunctionInfo] = field(default_factory=list)
[docs] @dataclass class VariableInfo: """ Metadata for a module-level variable assignment. Covers both annotated assignments (``name: Type = value``) and plain assignments without annotations (``name = value``). Parameters ---------- name : str lineno : int annotation_str : str or None Unparsed annotation expression, or ``None`` for unannotated assignments. value_repr : str or None Unparsed right-hand side expression, or ``None`` when absent. """ name: str lineno: int annotation_str: str | None = None value_repr: str | None = None
[docs] @dataclass class TypeVarInfo: """ Metadata for a ``TypeVar``, ``ParamSpec``, ``TypeVarTuple``, ``TypeAlias``, or ``NewType`` declaration. Parameters ---------- name : str lineno : int kind : str One of ``"TypeVar"``, ``"ParamSpec"``, ``"TypeVarTuple"``, ``"TypeAlias"``, ``"NewType"``. source_str : str Unparsed right-hand side expression (for TypeVar/NewType) or the aliased type expression (for TypeAlias). """ name: str lineno: int kind: str # "TypeVar" | "ParamSpec" | "TypeVarTuple" | "TypeAlias" | "NewType" source_str: str
[docs] @dataclass class ASTSymbols: """ Container for all metadata harvested from a single source file's AST. Created by :func:`ast_harvest` and consumed by :func:`~stubpy.symbols.build_symbol_table`. Attributes ---------- classes : list of ClassInfo All top-level class definitions, in source order. functions : list of FunctionInfo All top-level function definitions, in source order. variables : list of VariableInfo All top-level annotated (and plain) variable assignments. typevar_decls : list of TypeVarInfo TypeVar / ParamSpec / TypeVarTuple / TypeAlias / NewType declarations. all_exports : list of str or None Contents of ``__all__``, or ``None`` when the module has no ``__all__`` declaration. """ classes: list[ClassInfo] = field(default_factory=list) functions: list[FunctionInfo] = field(default_factory=list) variables: list[VariableInfo] = field(default_factory=list) typevar_decls: list[TypeVarInfo] = field(default_factory=list) all_exports: list[str] | None = None # None = no __all__ found skip_file: bool = False # True when # stubpy: ignore found
# --------------------------------------------------------------------------- # Private helpers # --------------------------------------------------------------------------- _TYPEVAR_CALL_NAMES: frozenset[str] = frozenset( {"TypeVar", "ParamSpec", "TypeVarTuple", "NewType"} ) _OVERLOAD_NAMES: frozenset[str] = frozenset({"overload"}) # Known built-in type names that are always types, never plain values. # Used to recognise implicit type-alias assignments like ``Color = str``. _BUILTIN_TYPE_NAMES: frozenset[str] = frozenset({ "int", "float", "complex", "bool", "str", "bytes", "bytearray", "list", "tuple", "set", "frozenset", "dict", "type", "object", "memoryview", "range", "slice", }) # typing.__all__ names that function as types (not values or decorators). # Populated once at import time; conservative — only the names that would # never be used as plain variable values. try: import typing as _typing_mod _TYPING_TYPE_NAMES: frozenset[str] = frozenset( n for n in _typing_mod.__all__ if not n.startswith("_") and n[0].isupper() ) except Exception: _TYPING_TYPE_NAMES = frozenset({ "Any", "Callable", "ClassVar", "Dict", "Final", "FrozenSet", "Generic", "Iterator", "List", "Literal", "Optional", "Protocol", "Sequence", "Set", "Tuple", "Type", "Union", }) # Combined set of names that are "always a type" and therefore safe to # treat as implicit type aliases when they appear as a bare assignment RHS. _KNOWN_TYPE_NAMES: frozenset[str] = _BUILTIN_TYPE_NAMES | _TYPING_TYPE_NAMES def _is_implicit_alias(node: ast.expr | None) -> bool: """Return ``True`` when *node* looks like an implicit type alias RHS. Three patterns are recognised as unambiguous type alias expressions: 1. **PEP 604 union** — ``int | float`` (``ast.BinOp`` with ``BitOr``). 2. **Subscripted generic** — ``Union[str, int]``, ``list[int]``, ``Literal["a"]``, etc. (any ``ast.Subscript``). 3. **Known-type bare name** — ``int``, ``str``, ``list``, ``Any``, etc. Only names in :data:`_KNOWN_TYPE_NAMES` qualify; arbitrary names such as ``SomeClass`` or ``logger`` do not, to avoid false positives. ``ast.Constant`` (numbers, strings), ``ast.Call`` (function calls), and unrecognised ``ast.Name`` nodes are intentionally excluded. .. note:: ``Name = SomeArbitraryName`` is NOT treated as a TypeAlias because we cannot determine at parse time whether ``SomeArbitraryName`` is a type or a value without executing the module. Use ``Name: TypeAlias = SomeArbitraryName`` or the Python 3.12+ ``type Name = SomeArbitraryName`` form for unambiguous declaration. """ if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): return True if isinstance(node, ast.Subscript): return True if isinstance(node, ast.Name) and node.id in _KNOWN_TYPE_NAMES: return True return False def _decorator_name(node: ast.expr) -> str: """ Return the simple name of a decorator node. Handles both ``@name`` (:class:`ast.Name`) and ``@module.name`` (:class:`ast.Attribute`) forms. Returns ``""`` for arbitrary expressions. """ if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Attribute): return node.attr return "" def _unparse(node: ast.expr | None) -> str | None: """Safely unparse an AST expression to its source string, or ``None``.""" if node is None: return None try: return ast.unparse(node) except Exception: return None def _extract_all_list(node: ast.Assign) -> list[str] | None: """ Return the string elements of ``__all__ = [...]`` / ``__all__ = (...)``. Returns ``None`` if *node* is not an ``__all__`` assignment or the right-hand side is not a literal list/tuple of strings. """ for target in node.targets: if isinstance(target, ast.Name) and target.id == "__all__": if isinstance(node.value, (ast.List, ast.Tuple)): result: list[str] = [] for elt in node.value.elts: if isinstance(elt, ast.Constant) and isinstance(elt.value, str): result.append(elt.value) return result return None def _call_func_name(call: ast.Call) -> str: """Extract the bare function name from a :class:`ast.Call` node, or ``""``.""" func = call.func if isinstance(func, ast.Name): return func.id if isinstance(func, ast.Attribute): return func.attr return "" def _is_typevar_call(node: ast.expr | None) -> str | None: """ Return the TypeVar/ParamSpec/etc. kind if *node* is a call to one of those constructors, or ``None`` otherwise. """ if not isinstance(node, ast.Call): return None name = _call_func_name(node) return name if name in _TYPEVAR_CALL_NAMES else None def _has_ignore_directive(source: str) -> bool: """Return ``True`` if the source begins with a ``# stubpy: ignore`` directive. Only lines that are blank, comment-only, or a module docstring (before the first non-trivial code statement) are inspected. The check is case-insensitive and tolerates extra whitespace around the colon. Parameters ---------- source : str Raw Python source text. Returns ------- bool Examples -------- >>> _has_ignore_directive("# stubpy: ignore\\nclass Foo: pass\\n") True >>> _has_ignore_directive("# STUBPY: IGNORE\\nclass Foo: pass\\n") True >>> _has_ignore_directive("# regular comment\\nclass Foo: pass\\n") False >>> _has_ignore_directive("class Foo: pass\\n# stubpy: ignore") False """ import re as _re _IGNORE_RE = _re.compile(r"#\s*stubpy\s*:\s*ignore\b", _re.IGNORECASE) for line in source.splitlines(): stripped = line.strip() if not stripped: continue # blank line if stripped.startswith("#"): if _IGNORE_RE.search(stripped): return True continue # other comment — keep scanning # First non-blank, non-comment line reached — stop break return False # --------------------------------------------------------------------------- # Harvester # ---------------------------------------------------------------------------
[docs] class ASTHarvester(ast.NodeVisitor): """ Walk the top-level AST of a Python source file and collect structural metadata without executing any code. Only **top-level** definitions are collected (class/function/variable statements that are direct children of the module body). Statements nested inside ``if``, ``with``, or ``try`` blocks at the module level are visited transitively so that patterns like ``if TYPE_CHECKING: ...`` are still partially harvested. Parameters ---------- source : str Raw Python source text. Examples -------- >>> h = ASTHarvester("async def foo(): pass") >>> syms = h.harvest() >>> syms.functions[0].is_async True """
[docs] def __init__(self, source: str) -> None: self._source = source self.result = ASTSymbols()
# ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] def harvest(self) -> ASTSymbols: """ Parse the source and return the populated :class:`ASTSymbols`. Returns an empty (but valid) :class:`ASTSymbols` on :exc:`SyntaxError` without raising. If the source begins (before any code) with a ``# stubpy: ignore`` comment, :attr:`~ASTSymbols.skip_file` is set to ``True`` and the returned :class:`ASTSymbols` is otherwise empty. """ # Check for the ignore directive in leading comments/blank lines. if _has_ignore_directive(self._source): self.result.skip_file = True return self.result try: tree = ast.parse(self._source) except SyntaxError: return self.result # Only visit immediate children of the module node so we don't # accidentally recurse into class bodies from visit_Module itself. for child in ast.iter_child_nodes(tree): self.visit(child) return self.result
# ------------------------------------------------------------------ # Top-level node visitors # ------------------------------------------------------------------
[docs] def visit_ClassDef(self, node: ast.ClassDef) -> None: """Harvest a class definition and its directly-defined methods.""" bases = [_unparse(b) or "" for b in node.bases] decorators = [_decorator_name(d) for d in node.decorator_list] info = ClassInfo( name=node.name, lineno=node.lineno, bases=[b for b in bases if b], decorators=[d for d in decorators if d], ) # Harvest methods defined directly in the class body only for child in node.body: if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): info.methods.append(self._harvest_function(child)) self.result.classes.append(info)
# Do NOT recurse further — nested classes stay out of scope here
[docs] def visit_FunctionDef(self, node: ast.FunctionDef) -> None: """Harvest a top-level synchronous function.""" self.result.functions.append(self._harvest_function(node))
[docs] def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: """Harvest a top-level asynchronous function.""" self.result.functions.append(self._harvest_function(node, is_async=True))
[docs] def visit_Assign(self, node: ast.Assign) -> None: """ Handle: 1. ``__all__ = [...]`` — populates :attr:`~ASTSymbols.all_exports`. 2. ``X = TypeVar(...)`` / ``X = NewType(...)`` — explicit TypeVar declarations. 3. ``X = int | float`` / ``X = Union[int, str]`` — implicit TypeAlias (bare union or subscripted generic RHS without an annotation). 4. Plain ``name = value`` assignments — recorded as :class:`VariableInfo`. """ # 1. __all__ all_names = _extract_all_list(node) if all_names is not None: self.result.all_exports = all_names return if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): target_name = node.targets[0].id # 2. TypeVar / ParamSpec / TypeVarTuple / NewType (X = TypeVar("X")) kind = _is_typevar_call(node.value) if kind: self.result.typevar_decls.append(TypeVarInfo( name=target_name, lineno=node.lineno, kind=kind, source_str=_unparse(node.value) or "", )) return # 3. Bare union / subscripted generic — treat as implicit TypeAlias. # e.g. ``Color = str | tuple[float, ...]`` or # ``Length = Union[str, float, int]`` if _is_implicit_alias(node.value): self.result.typevar_decls.append(TypeVarInfo( name=target_name, lineno=node.lineno, kind="TypeAlias", source_str=_unparse(node.value) or "", )) return # 4. Plain variable assignment (no annotation) for target in node.targets: if isinstance(target, ast.Name): self.result.variables.append(VariableInfo( name=target.id, lineno=node.lineno, annotation_str=None, value_repr=_unparse(node.value), ))
[docs] def visit_TypeAlias(self, node: ast.AST) -> None: """Handle Python 3.12+ ``type Name = ...`` soft-keyword statement (PEP 695). The AST node is ``ast.TypeAlias`` (available from Python 3.12). We access fields by attribute so the code compiles on Python 3.10/3.11 where the class does not exist but the method will never be called. Examples -------- The following source:: type Vector = list[float] produces a ``TypeVarInfo`` with ``kind="TypeAlias"`` and ``source_str="list[float"]``. """ # ast.TypeAlias has: .name (ast.Name), .type_params (list), .value (expr) name_node = getattr(node, "name", None) value_node = getattr(node, "value", None) if name_node is None or not hasattr(name_node, "id"): return self.result.typevar_decls.append(TypeVarInfo( name=name_node.id, lineno=getattr(node, "lineno", 0), kind="TypeAlias", source_str=_unparse(value_node) or "", ))
[docs] def visit_AnnAssign(self, node: ast.AnnAssign) -> None: """ Handle annotated assignments: * ``name: TypeAlias = int | str`` → :class:`TypeVarInfo` * ``name: Type = value`` → :class:`VariableInfo` """ if not isinstance(node.target, ast.Name): return name = node.target.id ann_str = _unparse(node.annotation) # Detect MyType: TypeAlias = ... if ann_str in ("TypeAlias", "typing.TypeAlias"): rhs = _unparse(node.value) if node.value else "" self.result.typevar_decls.append(TypeVarInfo( name=name, lineno=node.lineno, kind="TypeAlias", source_str=rhs or "", )) return self.result.variables.append(VariableInfo( name=name, lineno=node.lineno, annotation_str=ann_str, value_repr=_unparse(node.value) if node.value else None, ))
# ------------------------------------------------------------------ # Transitively visit common wrapper nodes so that top-level # definitions inside ``if TYPE_CHECKING:`` etc. are still harvested. # ------------------------------------------------------------------
[docs] def visit_If(self, node: ast.If) -> None: """Recurse into if/else bodies (handles ``if TYPE_CHECKING:`` blocks).""" for child in node.body + node.orelse: self.visit(child)
[docs] def visit_Try(self, node: ast.Try) -> None: """Recurse into try/except/else/finally bodies.""" for child in node.body + node.orelse + node.finalbody: # type: ignore[attr-defined] self.visit(child) for handler in node.handlers: for child in handler.body: self.visit(child)
[docs] def visit_TryStar(self, node: ast.AST) -> None: # Python 3.11+ ExceptionGroup self.generic_visit(node)
[docs] def visit_With(self, node: ast.With) -> None: """Recurse into ``with`` blocks.""" for child in node.body: self.visit(child)
# Suppress generic recursion for everything else (Import, Expr, etc.)
[docs] def generic_visit(self, node: ast.AST) -> None: # type: ignore[override] pass
# ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ def _harvest_function( self, node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool | None = None, ) -> FunctionInfo: """Build a :class:`FunctionInfo` from a function definition AST node. Also scans the function body to detect variadic-forwarding patterns: any call expression where ``**kwargs_name`` or ``*args_name`` is passed through becomes an entry in :attr:`~stubpy.ast_pass.FunctionInfo.kwargs_forwarded_to` / :attr:`~stubpy.ast_pass.FunctionInfo.args_forwarded_to`. These fields are consumed by :func:`~stubpy.resolver.resolve_function_params` at stub-emission time to expand variadic parameters into their concrete counterparts. """ if is_async is None: is_async = isinstance(node, ast.AsyncFunctionDef) decorator_names = [_decorator_name(d) for d in node.decorator_list if _decorator_name(d)] is_overload = any(n in _OVERLOAD_NAMES for n in decorator_names) # ── Collect annotated parameters ────────────────────────────────── raw_arg_anns: dict[str, str] = {} all_args = ( node.args.posonlyargs + node.args.args + node.args.kwonlyargs ) for arg in all_args: if arg.annotation: s = _unparse(arg.annotation) if s: raw_arg_anns[arg.arg] = s if node.args.vararg and node.args.vararg.annotation: s = _unparse(node.args.vararg.annotation) if s: raw_arg_anns[f"*{node.args.vararg.arg}"] = s if node.args.kwarg and node.args.kwarg.annotation: s = _unparse(node.args.kwarg.annotation) if s: raw_arg_anns[f"**{node.args.kwarg.arg}"] = s # ── Scan body for **kwargs / *args forwarding targets ───────────── kwargs_name = node.args.kwarg.arg if node.args.kwarg else None varargs_name = node.args.vararg.arg if node.args.vararg else None kw_targets: list[str] = [] pos_targets: list[str] = [] if kwargs_name or varargs_name: for body_node in ast.walk(node): if body_node is node: # skip the definition itself continue if not isinstance(body_node, ast.Call): continue fname = _call_func_name(body_node) if not fname: continue if kwargs_name: has_kw_fwd = any( kw.arg is None and isinstance(kw.value, ast.Name) and kw.value.id == kwargs_name for kw in body_node.keywords ) if has_kw_fwd and fname not in kw_targets: kw_targets.append(fname) if varargs_name: has_pos_fwd = any( isinstance(arg, ast.Starred) and isinstance(arg.value, ast.Name) and arg.value.id == varargs_name for arg in body_node.args ) if has_pos_fwd and fname not in pos_targets: pos_targets.append(fname) return FunctionInfo( name=node.name, lineno=node.lineno, is_async=is_async, decorators=decorator_names, is_overload=is_overload, raw_arg_annotations=raw_arg_anns, raw_return_annotation=_unparse(node.returns), kwargs_forwarded_to=kw_targets, args_forwarded_to=pos_targets, )
# --------------------------------------------------------------------------- # Public convenience function # ---------------------------------------------------------------------------
[docs] def ast_harvest(source: str) -> ASTSymbols: """ Parse *source* and return structural metadata without executing any code. This is the main entry point for the AST pre-pass stage. A fresh :class:`ASTHarvester` is created for each call, making this function fully re-entrant. Parameters ---------- source : str Raw Python source text. Returns ------- ASTSymbols Populated container of all harvested metadata. On a :exc:`SyntaxError` the result will be empty but valid — no exception is raised. Examples -------- >>> syms = ast_harvest("") >>> syms.classes [] >>> syms = ast_harvest("class Foo(Bar): pass") >>> syms.classes[0].name, syms.classes[0].bases ('Foo', ['Bar']) >>> syms = ast_harvest("async def fetch(url: str) -> None: ...") >>> fn = syms.functions[0] >>> fn.is_async, fn.name (True, 'fetch') >>> syms = ast_harvest("X = TypeVar('X')") >>> syms.typevar_decls[0].kind 'TypeVar' """ return ASTHarvester(source).harvest()