From 8767302ab39cafa9568e000b40e5c733573f1983 Mon Sep 17 00:00:00 2001 From: "alexander.varga" Date: Thu, 18 Jul 2024 13:50:40 -0400 Subject: [PATCH] enable writing module code directly to file --- libcst/_nodes/internal.py | 25 +++++++++++++++++++++++++ libcst/_nodes/module.py | 26 +++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/libcst/_nodes/internal.py b/libcst/_nodes/internal.py index 35d897435..67c3f2fd8 100644 --- a/libcst/_nodes/internal.py +++ b/libcst/_nodes/internal.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import io from contextlib import contextmanager from dataclasses import dataclass, field from typing import Iterable, Iterator, List, Optional, Sequence, TYPE_CHECKING, Union @@ -70,6 +71,30 @@ def record_syntactic_position( yield +@add_slots +@dataclass(frozen=False) +class CodegenWriter(CodegenState): + """ + A CodegenState that writes to a file-like object. + """ + + writer: io.TextIOBase = None # need a default value for dataclass + + def __post_init__(self) -> None: + if self.writer is None: + raise TypeError("writer must be provided") + + def add_indent_tokens(self) -> None: + for token in self.indent_tokens: + self.writer.write(token) + + def add_token(self, value: str) -> None: + self.writer.write(value) + + def pop_trailing_newline(self) -> None: + pass + + def visit_required( parent: "CSTNode", fieldname: str, node: CSTNodeT, visitor: "CSTVisitorT" ) -> CSTNodeT: diff --git a/libcst/_nodes/module.py b/libcst/_nodes/module.py index 9ed45716a..d22e77841 100644 --- a/libcst/_nodes/module.py +++ b/libcst/_nodes/module.py @@ -3,12 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import io from dataclasses import dataclass from typing import cast, Optional, Sequence, TYPE_CHECKING, TypeVar, Union from libcst._add_slots import add_slots from libcst._nodes.base import CSTNode -from libcst._nodes.internal import CodegenState, visit_body_sequence, visit_sequence +from libcst._nodes.internal import ( + CodegenState, + CodegenWriter, + visit_body_sequence, + visit_sequence, +) from libcst._nodes.statement import ( BaseCompoundStatement, get_docstring_impl, @@ -136,6 +142,24 @@ def code_for_node(self, node: CSTNode) -> str: node._codegen(state) return "".join(state.tokens) + def write_code(self, writer: io.TextIOBase): + """ + Like :meth:`code`, but writes the code to the given file-like object. + """ + self.write_code_for_node(self, writer) + + def write_code_for_node(self, node: CSTNode, writer: io.TextIOBase): + """ + Like :meth:`code_for_node`, but writes the code to the given file-like object. + """ + + state = CodegenWriter( + default_indent=self.default_indent, + default_newline=self.default_newline, + writer=writer, + ) + node._codegen(state) + @property def config_for_parsing(self) -> "PartialParserConfig": """