From 3ba527958c8908fef729f3af42d7848c7addab9c Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 24 Mar 2026 21:41:04 +0800 Subject: [PATCH 01/53] feat: allow # fmt: off --- snakefmt/formatter.py | 32 +++- snakefmt/parser/parser.py | 331 +++++++++++++++++++++++++++++++++++--- snakefmt/parser/syntax.py | 35 +++- tests/test_formatter.py | 244 +++++++++++++++++++++------- 4 files changed, 559 insertions(+), 83 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index a2eb3a3..9193c97 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -65,6 +65,7 @@ def __init__( self.lagging_comments: str = "" self.no_formatting_yet: bool = True self.sort_directives = sort_directives + self.fmt_off_sort_next: bool = False # for # fmt: off[sort] self.previous_result: str = "" self.keyword_spec: list[str] = [] self.keywords: dict[str, str] = {} # cache to sort @@ -89,10 +90,13 @@ def flush_buffer( from_python: bool = False, final_flush: bool = False, in_global_context: bool = False, + exiting_keywords: bool = False, ) -> None: if len(self.buffer) == 0 or self.buffer.isspace(): self.result += self.buffer self.buffer = "" + if exiting_keywords and self.no_formatting_yet and self.result.rstrip("\n"): + self.no_formatting_yet = False return if not from_python: @@ -162,7 +166,7 @@ def process_keyword_context(self, in_global_context: bool): else: # not a PythonCode context, collect keywords to sort self.previous_result += self.result + formatted self.result = "" - self.keyword_spec = self.vocab.ordered() + self.keyword_spec = [] if self.fmt_off_sort_next else self.vocab.ordered() def process_keyword_param( self, param_context: ParameterSyntax, in_global_context: bool @@ -194,6 +198,23 @@ def post_process_keyword(self): ) self.result = self.previous_result + self.result self.previous_result = "" + self.fmt_off_sort_next = False # reset after each rule/context + + def handle_fmt_off_region(self, verbatim: str) -> None: + if self.no_formatting_yet: + self.result = self.result.lstrip("\n") + self.result += self.buffer + self.buffer = "" + if not verbatim: + return + if self.lagging_comments: + self.result += self.lagging_comments + self.lagging_comments = "" + self.result += verbatim + # Treat the verbatim region as transparent to separator logic: + # resume formatting as if nothing preceded (no blank-line separator added). + self.no_formatting_yet = True + self.last_recognised_keyword = "" def run_black_format_str( self, @@ -215,7 +236,9 @@ def run_black_format_str( and len(string.strip().splitlines()) > 1 and not no_nesting ) - + if self.fmt_off and self.fmt_off_applied: + # a `fmt: off` in previous block also affects here, make it work + string = "# fmt: off\n" + string if artificial_nest: string = f"if x:\n{textwrap.indent(string, TAB)}" @@ -267,6 +290,11 @@ def run_black_format_str( lines = fmted.splitlines(keepends=True)[1:] s = "".join(lines).lstrip("\n") fmted = textwrap.dedent(s) + if self.fmt_off: + if self.fmt_off_applied: + fmted = fmted.split("# fmt: off\n", 1)[1] + else: + self.fmt_off_applied = True return fmted def align_strings(self, string: str, target_indent: int) -> str: diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index cd0dc4a..6027ed8 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -1,6 +1,6 @@ import tokenize from abc import ABC, abstractmethod -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Literal from snakefmt.exceptions import UnsupportedSyntax from snakefmt.parser.grammar import PythonCode, SnakeGlobal @@ -16,6 +16,28 @@ from snakefmt.types import TAB, Token, TokenIterator, col_nb +FMT_OFF_REGION = frozenset({"# fmt: off"}) +FMT_OFF_ONE = frozenset({"# fmt: off[one]"}) +FMT_OFF_SORT = frozenset({"# fmt: off[sort]"}) +FMT_OFF = FMT_OFF_REGION | FMT_OFF_ONE | FMT_OFF_SORT +FMT_ON = frozenset({"# fmt: on"}) + + +def split_token_lines(token: tokenize.TokenInfo): + """Token can be multiline. + e.g., `f'''\\nplaintext\\n'''` has these tokens: + + TokenInfo(type=61 (FSTRING_START), string="f'''", start=(21, 0), end=(21, 4), line="f'''\\n") + TokenInfo(type=62 (FSTRING_MIDDLE), string='\\ncccccccc\\n', start=(21, 4), end=(23, 0), line="f'''\\ncccccccc\\n'''\\n") + TokenInfo(type=63 (FSTRING_END), string="'''", start=(23, 0), end=(23, 3), line="'''\\n") + + lines should be split to drop overlapping lines and keep unique ones. + """ + return zip( + range(token.start[0], token.end[0] + 1), token.line.splitlines(keepends=True) + ) + + def not_a_comment_related_token(token: Token): return token.type not in { tokenize.COMMENT, @@ -26,6 +48,14 @@ def not_a_comment_related_token(token: Token): } +def check_indent(line: str, indents: list[str]) -> int: + indents_len = len(indents) + for i, indent in enumerate(reversed(indents), 1): + if line.startswith(indent): + return indents_len - i + raise SyntaxError("Unexpected indent") + + class Snakefile(TokenIterator): """ Adapted from snakemake.parser.Snakefile @@ -97,6 +127,13 @@ def __init__(self, snakefile: Snakefile): self.queriable = True self.in_fstring = False self.last_token: Optional[Token] = None + self.fmt_off_sort_next: bool = False # for `# fmt: off[sort]` + # for `# fmt: off`, (indent, ) + self.fmt_off: Literal[False] | tuple[int] = False + # True if a new block should be formatted as fmt: off due to a preceding fmt directive + self.fmt_off_applied: bool = False + + self.indents: list[str] = [""] status = self.get_next_queriable() self.buffer = status.buffer @@ -112,13 +149,34 @@ def __init__(self, snakefile: Snakefile): break keyword = status.token.string - - if self.vocab.recognises(keyword): + if status.token.string in FMT_ON: + self.fmt_off = False + self.fmt_off_sort_next = False + elif status.token.string in FMT_OFF: + self.fmt_off = (status.cur_indent,) + self.fmt_off_sort_next = False + elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: + self.fmt_off = False + self.fmt_off_applied = False + + if self.vocab.recognises(keyword) and self.fmt_off: + if self.fmt_off: + self.fmt_off_applied = True + self._consume_fmt_off(status.token, min_indent=self.keyword_indent) + status = self.get_next_queriable() + if self.last_block_was_snakecode and not status.eof: + self.block_indent = status.block_indent + self.last_block_was_snakecode = False + self.buffer = status.buffer.lstrip() + elif self.vocab.recognises(keyword): + new_vocab, new_syntax_cls = self.vocab.get(keyword) + is_context_kw = new_vocab is not None and issubclass( + new_syntax_cls, KeywordSyntax + ) if status.cur_indent > self.keyword_indent: - in_if_else = self.buffer.startswith(("if", "else", "elif")) - if self.syntax.from_python or status.pythonable or in_if_else: + if self.syntax.from_python or status.pythonable: self.from_python = True - elif self.from_python: + elif self.from_python and not is_context_kw: # We are exiting python context, so force spacing out keywords self.last_recognised_keyword = "" self.from_python = self.syntax.from_python @@ -135,12 +193,25 @@ def __init__(self, snakefile: Snakefile): f"L{status.token.start[0]}: Unrecognised keyword '{keyword}' " f"in {self.syntax.keyword_name} definition" ) - else: - self.buffer += f"{keyword}" + elif keyword in FMT_OFF_REGION: + self.flush_buffer( + from_python=self.from_python, + in_global_context=self.in_global_context, + ) + self._consume_fmt_off(status.token, min_indent=self.keyword_indent) + self.buffer = "" status = self.get_next_queriable() if self.last_block_was_snakecode and not status.eof: self.block_indent = status.block_indent self.last_block_was_snakecode = False + self.buffer = status.buffer.lstrip() + if self.keyword_indent: + self.last_block_was_snakecode = True + else: + source, status = self._consume_python(status.token) + self.buffer += source + if self.last_block_was_snakecode and not status.eof: + self.last_block_was_snakecode = False self.buffer += status.buffer if ( self.from_python @@ -193,6 +264,7 @@ def flush_buffer( from_python: bool = False, final_flush: bool = False, in_global_context: bool = False, + exiting_keywords: bool = False, ) -> None: """Processes the text in :self.buffer:""" @@ -211,6 +283,141 @@ def post_process_keyword(self) -> None: """Sort params when exiting a keyword context, eg after finishing parsing a 'rule:'""" + def _consume_python( + self, start_token: Token, vocab_recognises=True, added_indent: str = "" + ) -> tuple[str, Status]: + """Collect Python source lines until a snakemake keyword at correct indent, + or dedent below min_indent, or EOF. + Returns (source_text, next_status) where next_status carries the stopping token. + """ + origin_indent = start_token.start[1] + + lines: dict[int, str] = {start_token.start[0]: start_token.line} + # Lines that are interior to a multiline token (string / f-string body). + # Their content must not be reindented. + string_interior_lines: set[int] = set() + self.queriable = False + prev_token = None + last_indent_token = None + min_indent = -1 + + def _init_min_indent(token: Token): + nonlocal min_indent + if token.string.lstrip()[:1] != "#": + while not token.line.startswith(self.indents[-1]): + self.indents.pop() + min_indent = len(self.indents) - 1 + + _init_min_indent(start_token) + while True: + try: + token = next(self.snakefile) + except StopIteration: + eof_token = Token(tokenize.ENDMARKER, "", (0, 0), (0, 0), "") + self.snakefile.denext(eof_token) + break + if min_indent == -1: + _init_min_indent(token) + elif token.line[:origin_indent].strip(): + # non-whitespace before origin indent: stop + self.snakefile.denext(token) + break + self.last_token = token + self.in_fstring = fstring_processing(token, prev_token, self.in_fstring) + prev_token = token + if token.type == tokenize.ENDMARKER: + self.snakefile.denext(token) + break + if token.type == tokenize.INDENT: + self._handle_indent(token) + self.syntax.cur_indent = len(self.indents) - 1 + last_indent_token = token + continue + if token.type == tokenize.DEDENT: + saved_indents = list(self.indents) + self._handle_indent(token) + new_indent = len(self.indents) - 1 + last_indent_token = None + if new_indent < min_indent: + # let get_next_queriable handle dedent below min_indent + self.indents = saved_indents + self.snakefile.denext(token) + break + self.syntax.cur_indent = new_indent + continue + if is_newline(token): + self.queriable = True + lines.update(split_token_lines(token)) + continue + if vocab_recognises: + if ( + (token.type == tokenize.NAME or token.string == "@") + and self.queriable + and not self.in_fstring + ): + if self.vocab.recognises(token.string): + # snakemake keyword: stop, let main loop handle it + self.snakefile.denext(token) + if last_indent_token is not None: + self.snakefile.denext(last_indent_token) + self.indents.pop() + self.syntax.cur_indent = len(self.indents) - 1 + break + else: + if token.type == tokenize.COMMENT and token.string in FMT_ON: + lines.update(split_token_lines(token)) + self.fmt_off = False + self.fmt_off_sort_next = False + break + + self.queriable = False + lines.update(split_token_lines(token)) + # Mark interior lines of any multiline token as string content. + if token.start[0] != token.end[0]: + string_interior_lines.update( + range(token.start[0] + 1, token.end[0] + 1) + ) + + verbatim = self._reindent( + lines, string_interior_lines, origin_indent, added_indent + ) + next_status = self.get_next_queriable() + return verbatim, next_status._replace( + pythonable=next_status.pythonable or bool(verbatim.strip()) + ) + + @abstractmethod + def handle_fmt_off_region(self, verbatim: str) -> None: + """handle unformatted text (just update indent).""" + + def _consume_fmt_off(self, start_token: Token, min_indent: int): + verbatim, next_status = self._consume_python( + start_token, vocab_recognises=False, added_indent=TAB * min_indent + ) + self.handle_fmt_off_region(verbatim) + self.snakefile.denext(next_status.token) + self.queriable = True + + def _reindent( + self, + lines: dict[int, str], + string_interior_lines: set[int], + origin_indent: int, + added_indent: str = "", + ) -> str: + newlines = [] + for i in sorted(lines): + line = lines[i] + if i in string_interior_lines: + newlines.append(line) + elif line.strip(): + newline = line.rsplit("\n", 1) + newline[0] = added_indent + newline[0][origin_indent:] + newlines.append("\n".join(newline)) + else: + newlines.append(line[origin_indent:]) + return "".join(newlines) + def process_keyword(self, status: Status, from_python: bool = False) -> Status: """Called when a snakemake keyword has been found. @@ -261,10 +468,13 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: ) self.process_keyword_param(param_context, self.in_global_context) self.syntax.add_processed_keyword(status.token, status.token.string) + cur_indent = param_context.cur_indent + if param_context.token.type == tokenize.COMMENT and not param_context.eof: + cur_indent = self._determe_comment_indent(param_context.token) return Status( param_context.token, - param_context.cur_indent, - param_context.cur_indent, + cur_indent, + cur_indent, status.buffer, param_context.eof, self.from_python, @@ -278,7 +488,8 @@ def context_exit(self, status: Status) -> None: while self.keyword_indent > status.cur_indent: callback_context: Context = self.context_stack.pop() if callback_context.syntax.accepts_python_code: - self.flush_buffer() # Flushes any code inside 'run' directive + # Flushes any code inside 'run' directive + self.flush_buffer(exiting_keywords=True) else: callback_context.syntax.check_empty() self.context = self.context_stack[-1] @@ -289,6 +500,65 @@ def context_exit(self, status: Status) -> None: if self.keyword_indent > 0: self.syntax.keyword_indent = status.cur_indent + 1 + def _determe_comment_indent(self, token: Token) -> int: + """ + Treat each line of single-line comment separately, + it is determined by the following real code line and previous self.indents. + + follow_indent = indent of the following real code line + if EOF: + follow_indent = 0 + rule 1 (always): + indent of comments >= follow_indent + rule 2 (if follow_indent < self.indents[-1]): + indent of comments = max(i for i in self.indents if i <= comment_indent) + epsilon. + + next(self.snakefile) until follow_indent is determined, then put all peeked tokens back. + """ + # ── Step 1: peek ahead to find follow_indent ──────────────────────── + peeked: list[Token] = [] + saved_indents = list(self.indents) + follow_indent = len(self.indents) - 1 + try: + while True: + t = next(self.snakefile) + peeked.append(t) + if self._handle_indent(t): + pass + elif t.type not in {tokenize.NEWLINE, tokenize.NL, tokenize.COMMENT}: + follow_indent = check_indent(t.line, self.indents) + break + except StopIteration: + follow_indent = 0 + # restore indent stack and token stream unchanged + self.indents = saved_indents + for t in reversed(peeked): + self.snakefile.denext(t) + + # Rule 1 (always): comment must not be indented below following code. + if len(self.indents) - 1 <= follow_indent: + return follow_indent + # Rule 2 (dedent is happening, standalone only): snap comment to the + # highest indent level fitting within the comment's column. + return max(check_indent(token.line, self.indents), follow_indent) + + def _handle_indent(self, token: Token) -> bool: + if token.type == tokenize.INDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + if indent not in self.indents: + self.indents.append(indent) + elif token.type == tokenize.DEDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + while self.indents and self.indents[-1] != indent: + self.indents.pop() + if not self.indents: + raise SyntaxError("Unexpected dedent") + else: + return False + return True + def get_next_queriable(self) -> Status: """Produces the next word that could be a snakemake keyword, and additional information in a :Status: @@ -307,24 +577,37 @@ def get_next_queriable(self) -> Status: self.in_fstring = fstring_processing(token, prev_token, self.in_fstring) if block_indent == -1 and not_a_comment_related_token(token): block_indent = self.cur_indent - if token.type == tokenize.INDENT: - self.syntax.cur_indent += 1 - prev_token = None - newline = True - continue - elif token.type == tokenize.DEDENT: - if self.cur_indent > 0: - self.syntax.cur_indent -= 1 + if self._handle_indent(token): prev_token = None newline = True + self.syntax.cur_indent = len(self.indents) - 1 continue elif token.type == tokenize.ENDMARKER: return Status( token, block_indent, self.cur_indent, buffer, True, pythonable ) elif token.type == tokenize.COMMENT: - if col_nb(token) == 0: - return Status(token, block_indent, 0, buffer, False, pythonable) + if ( + not self.last_block_was_snakecode + and token.string in FMT_OFF + or token.string in FMT_ON + ): + # col-0 comments report cur_indent=0 to trigger context_exit; + # fmt directives at other columns report actual cur_indent. + return Status( + token, block_indent, self.cur_indent, buffer, False, pythonable + ) + # Comments arrive in the token stream *before* any following + # INDENT/DEDENT tokens, so self.cur_indent still reflects the + # previous (potentially higher) level. Delegate to + # _determe_comment_indent which peeks ahead and applies the + # two snapping rules. + effective_indent = self._determe_comment_indent(token) + self.syntax.cur_indent = effective_indent + if effective_indent < max(self.keyword_indent, self.block_indent): + return Status( + token, block_indent, effective_indent, buffer, False, pythonable + ) elif is_newline(token): self.queriable, newline = True, True @@ -343,7 +626,11 @@ def get_next_queriable(self) -> Status: else: buffer += TAB * self.effective_indent - if (token.type == tokenize.NAME or token.string == "@") and self.queriable: + if ( + (token.type == tokenize.NAME or token.string == "@") + and self.queriable + and not self.in_fstring + ): self.queriable = False return Status( token, block_indent, self.cur_indent, buffer, False, pythonable diff --git a/snakefmt/parser/syntax.py b/snakefmt/parser/syntax.py index 3cee140..3e0900f 100644 --- a/snakefmt/parser/syntax.py +++ b/snakefmt/parser/syntax.py @@ -516,21 +516,46 @@ def parse_params(self, snakefile: TokenIterator): self.flush_param(cur_param, skip_empty=True) self.eof = True break - if self.check_exit(cur_param): + if self.check_exit(cur_param, snakefile): break if self.num_params() == 0: raise NoParametersError(f"{self.line_nb}In {self.keyword_name} definition.") - def check_exit(self, cur_param: Parameter): + def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): exit = False - if not self.found_newline: + if not self.found_newline or not self.token: return exit if not_empty(self.token): - # Special condition for comments: they appear before indents/dedents. if self.token.type == tokenize.COMMENT: if not cur_param.is_empty() and col_nb(self.token) < cur_param.col_nb: - exit = True + # comment appears before INDENT/DEDENT in the token stream; + # peek ahead with a temp counter so self.cur_indent stays + # untouched — the real processing will update it once tokens + # are put back. + temp_indent = self.cur_indent + cached_tokens: list[Token] = [] + try: + while True: + t = next(snakefile) + cached_tokens.append(t) + if t.type == tokenize.INDENT: + temp_indent += 1 + elif t.type == tokenize.DEDENT: + temp_indent = max(temp_indent - 1, 0) + elif t.type not in { + tokenize.NEWLINE, + tokenize.NL, + tokenize.COMMENT, + }: + # stop here; this token stays in cached_tokens + # and will be put back below — no double-denext + break + except StopIteration: + pass + for t in reversed(cached_tokens): + snakefile.denext(t) # type: ignore[attr-defined] + exit = temp_indent < self.keyword_indent else: exit = self.cur_indent < self.keyword_indent if exit: diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 05d6c0c..c150e47 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -45,60 +45,57 @@ def test_single_param_keyword_stays_on_same_line(self): assert actual == expected + example_shell_newline = ( + "rule a:\n" + f'{TAB * 1}shell: "for i in $(seq 1 5);"\n' + f'{TAB * 2}"do echo $i;"\n' + f'{TAB * 2}"done"', + "rule a:\n" + f"{TAB * 1}shell:\n" + f'{TAB * 2}"for i in $(seq 1 5);"\n' + f'{TAB * 2}"do echo $i;"\n' + f'{TAB * 2}"done"\n', + ) + def test_shell_param_newline_indented(self): - formatter = setup_formatter( - "rule a:\n" - f'{TAB * 1}shell: "for i in $(seq 1 5);"\n' - f'{TAB * 2}"do echo $i;"\n' - f'{TAB * 2}"done"' - ) - expected = ( - "rule a:\n" - f"{TAB * 1}shell:\n" - f'{TAB * 2}"for i in $(seq 1 5);"\n' - f'{TAB * 2}"do echo $i;"\n' - f'{TAB * 2}"done"\n' - ) - assert formatter.get_formatted() == expected + formatter = setup_formatter(self.example_shell_newline[0]) + assert formatter.get_formatted() == self.example_shell_newline[1] + + example_params_newline = ( + f"rule a: \n" + f'{TAB * 1}input: "a", "b",\n' + f'{TAB * 4}"c"\n' + f'{TAB * 1}wrapper: "mywrapper"', + "rule a:\n" + f"{TAB * 1}input:\n" + f'{TAB * 2}"a",\n' + f'{TAB * 2}"b",\n' + f'{TAB * 2}"c",\n' + f"{TAB * 1}wrapper:\n" + f'{TAB * 2}"mywrapper"\n', + ) def test_single_param_keyword_in_rule_gets_newline_indented(self): - formatter = setup_formatter( - f"rule a: \n" - f'{TAB * 1}input: "a", "b",\n' - f'{TAB * 4}"c"\n' - f'{TAB * 1}wrapper: "mywrapper"' - ) - - actual = formatter.get_formatted() - expected = ( - "rule a:\n" - f"{TAB * 1}input:\n" - f'{TAB * 2}"a",\n' - f'{TAB * 2}"b",\n' - f'{TAB * 2}"c",\n' - f"{TAB * 1}wrapper:\n" - f'{TAB * 2}"mywrapper"\n' - ) - - assert actual == expected + formatter = setup_formatter(self.example_params_newline[0]) + assert formatter.get_formatted() == self.example_params_newline[1] + + example_input_threads_newline = ( + "rule a: \n" + f'{TAB * 1}input: "c"\n' + f"{TAB * 1}threads:\n" + f"{TAB * 2}20\n" + f"{TAB * 1}default_target:\n" + f"{TAB * 2}True\n", + f"rule a:\n" + f"{TAB * 1}input:\n" + f'{TAB * 2}"c",\n' + f"{TAB * 1}threads: 20\n" + f"{TAB * 1}default_target: True\n", + ) def test_single_numeric_param_keyword_in_rule_stays_on_same_line(self): - formatter = setup_formatter( - "rule a: \n" - f'{TAB * 1}input: "c"\n' - f"{TAB * 1}threads:\n" - f"{TAB * 2}20\n" - f"{TAB * 1}default_target:\n" - f"{TAB * 2}True\n" - ) - - actual = formatter.get_formatted() - expected = ( - f'rule a:\n{TAB * 1}input:\n{TAB * 2}"c",\n{TAB * 1}threads: 20\n' - f"{TAB * 1}default_target: True\n" - ) - - assert actual == expected + formatter = setup_formatter(self.example_input_threads_newline[0]) + assert formatter.get_formatted() == self.example_input_threads_newline[1] class TestModuleFormatting: @@ -566,7 +563,8 @@ def test_python_code_after_nested_snakecode_gets_formatted(self): setup_formatter(snakecode) assert mock_m.call_count == 3 assert mock_m.call_args_list[1] == mock.call('"a"', 0, 0, no_nesting=True) - assert mock_m.call_args_list[2] == mock.call("b = 2\n", 0) + # now python codes parsed as-is + assert mock_m.call_args_list[2] == mock.call("b=2\n", 0) formatter = setup_formatter(snakecode) expected = ( @@ -1662,17 +1660,24 @@ def test_storage(self): class TestRunBlockFormatting: - def test_issue_267_comment_indentation_in_run_block(self): + def test_comment_indentation_in_run_block(self): """https://github.com/snakemake/snakefmt/issues/267""" - snakecode = ( + expected = ( "rule fmt_bug_repro:\n" f"{TAB * 1}run:\n" f'{TAB * 2}if "something nested":\n' f"{TAB * 3}pass\n" f"{TAB * 2}# Comment gets indented\n" ) - formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert setup_formatter(expected).get_formatted() == expected + snakecode = ( + "rule fmt_bug_repro:\n" + f" run:\n" + f' if "something nested":\n' + f" pass\n" + f" # Comment gets indented\n" + ) + assert setup_formatter(snakecode).get_formatted() == expected def test_double_block_comment(self): """https://github.com/snakemake/snakefmt/issues/196""" @@ -1937,3 +1942,134 @@ def test_use_parameters_with(self): ) formatter = setup_formatter(snakecode) assert formatter.get_formatted() == snakecode + + +class TestFmtOffOn: + """Tests for # fmt: off / # fmt: on directives.""" + + def test_fmt_off_at_start(self): + for code, formatted in ( + TestSimpleParamFormatting.example_shell_newline, + TestSimpleParamFormatting.example_params_newline, + TestSimpleParamFormatting.example_input_threads_newline, + ): + expected = "# fmt: off\n" + code + assert setup_formatter(expected).get_formatted() == expected + + def test_fmt_off_at_middle(self): + for code, formatted in ( + TestSimpleParamFormatting.example_shell_newline, + TestSimpleParamFormatting.example_params_newline, + TestSimpleParamFormatting.example_input_threads_newline, + ): + code1 = code + "\n\n\n# fmt: off\n" + code + expected = formatted.strip() + "\n# fmt: off\n" + code + assert setup_formatter(code1).get_formatted() == expected + + def test_fmt_off_on(self): + # TODO: the action after `# fmt: on` should be consistent, should be fixed in the future. + for code, formatted in ( + TestSimpleParamFormatting.example_shell_newline, + TestSimpleParamFormatting.example_params_newline, + TestSimpleParamFormatting.example_input_threads_newline, + ): + code1 = "\n# fmton\n" + code + expected = "# fmton\n" + formatted + assert setup_formatter(code1).get_formatted() == expected + code1 = "\n\n# fmt: on\n" + code + expected = "# fmt: on\n" + formatted + assert setup_formatter(code1).get_formatted() == expected + # TODO: trailing comments like `# fmt: off # comment` are not currently supported, but should be in the future + code1 = "\n# fmt: off\n" + code + "\n# fmt: on\n" + code + expected = "# fmt: off\n" + code + "\n# fmt: on\n" + formatted + assert setup_formatter(code1).get_formatted() == expected + + def test_fmt_off_on_in_run(self): + """# fmt: off inside Python code is handled by Black.""" + code = ( + "# ?\n" + "x = [1,2,3]\n" + "# fmt: off\n" + "y = [ 1, 2]\n" + "s = f'''\n" + " {y} \n" + " '''\n" + "# fmt: on\n" + "z = [4,5,6]\n" + ) + expected = ( + "# ?\n" + "x = [1, 2, 3]\n" + "# fmt: off\n" + "y = [ 1, 2]\n" + "s = f'''\n" + " {y} \n" + " '''\n" + "# fmt: on\n" + "z = [4, 5, 6]\n" + ) + assert setup_formatter(code).get_formatted() == expected + snakecode = "rule:\n" f" run:\n" + ( + "".join(f" {i}\n" for i in code.splitlines()) + ) + snakexpected = "rule:\n" f"{TAB * 1}run:\n" + ( + f"{TAB * 2}# ?\n" + f"{TAB * 2}x = [1, 2, 3]\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}y = [ 1, 2]\n" + f"{TAB * 2}s = f'''\n" + f"{' '} {{y}} \n" + f"{' '} '''\n" + f"{TAB * 2}# fmt: on\n" + f"{TAB * 2}z = [4, 5, 6]\n" + ) + assert setup_formatter(snakecode).get_formatted() == snakexpected + + def test_fmt_off_on_in_run_complex(self): + code, formatted = TestSimpleParamFormatting.example_shell_newline + formatter = setup_formatter( + f"rule:\n" + f" run:\n" + f" # fmt: off\n" + f" x = [ 1,2,3]\n" + f" # fmt: on\n" + f"\n" + f"sth=1\n" + f"{code}" + ) + expected = ( + "rule:\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n" + f"{TAB * 2}# fmt: on\n" + f"\n" + f"\n" + f"sth = 1\n" + f"\n" + f"\n" + f"{formatted}" + ) + assert formatter.get_formatted() == expected + formatter = setup_formatter( + f"rule:\n" + f" run:\n" + f" # fmt: off\n" + f" x = [ 1,2,3]\n" + f"\n" + f"sth=1\n" + f"{code}" + ) + expected = ( + "rule:\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n" + f"\n" + f"\n" + f"sth = 1\n" + f"\n" + f"\n" + f"{formatted}" + ) + assert formatter.get_formatted() == expected From 25ffce6c9e10d95b6a9ecc78763b952cc9b463a8 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 24 Mar 2026 22:32:27 +0800 Subject: [PATCH 02/53] fix: tests --- tests/test_formatter.py | 1 + uv.lock | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 01c2194..732774d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2015,6 +2015,7 @@ def side_effect(*args, **kwargs): formatter.snakefile = smk formatter.black_mode = black.Mode() formatter.from_python = False + formatter.fmt_off = False from snakefmt.parser.parser import Context from snakefmt.parser.syntax import KeywordSyntax diff --git a/uv.lock b/uv.lock index 0db0086..d659540 100644 --- a/uv.lock +++ b/uv.lock @@ -996,7 +996,7 @@ wheels = [ [[package]] name = "snakefmt" -version = "0.11.5" +version = "1.0.0" source = { editable = "." } dependencies = [ { name = "black" }, From ed1c25800e8c013e6ad2ecb32f928b45fd0a4be3 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 24 Mar 2026 23:54:56 +0800 Subject: [PATCH 03/53] fix: strict --- snakefmt/formatter.py | 2 + snakefmt/parser/parser.py | 22 ++++--- tests/test_formatter.py | 126 ++++++++++++++++++++++++++++++++------ 3 files changed, 125 insertions(+), 25 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index 43166bb..99da96c 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -200,6 +200,8 @@ def post_process_keyword(self): self.result = self.previous_result + self.result self.previous_result = "" self.fmt_off_sort_next = False # reset after each rule/context + if self.no_formatting_yet and self.result.rstrip("\n"): + self.no_formatting_yet = False def handle_fmt_off_region(self, verbatim: str) -> None: if self.no_formatting_yet: diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 6027ed8..1682299 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -149,7 +149,7 @@ def __init__(self, snakefile: Snakefile): break keyword = status.token.string - if status.token.string in FMT_ON: + if self._check_fmt_off_on(status.token): self.fmt_off = False self.fmt_off_sort_next = False elif status.token.string in FMT_OFF: @@ -198,6 +198,8 @@ def __init__(self, snakefile: Snakefile): from_python=self.from_python, in_global_context=self.in_global_context, ) + if self.keyword_indent > 0: + self.syntax.add_processed_keyword(status.token, keyword) self._consume_fmt_off(status.token, min_indent=self.keyword_indent) self.buffer = "" status = self.get_next_queriable() @@ -303,7 +305,7 @@ def _consume_python( def _init_min_indent(token: Token): nonlocal min_indent - if token.string.lstrip()[:1] != "#": + if not comment_start(token.string): while not token.line.startswith(self.indents[-1]): self.indents.pop() min_indent = len(self.indents) - 1 @@ -364,10 +366,10 @@ def _init_min_indent(token: Token): self.syntax.cur_indent = len(self.indents) - 1 break else: - if token.type == tokenize.COMMENT and token.string in FMT_ON: - lines.update(split_token_lines(token)) + if self._check_fmt_off_on(token): self.fmt_off = False self.fmt_off_sort_next = False + lines.update(split_token_lines(token)) break self.queriable = False @@ -542,6 +544,13 @@ def _determe_comment_indent(self, token: Token) -> int: # highest indent level fitting within the comment's column. return max(check_indent(token.line, self.indents), follow_indent) + def _check_fmt_off_on(self, token: Token) -> bool: + if token.type == tokenize.COMMENT and self.fmt_off: + if token.string in FMT_ON: + if self._determe_comment_indent(token) == self.fmt_off[0]: + return True + return False + def _handle_indent(self, token: Token) -> bool: if token.type == tokenize.INDENT: line = token.line @@ -589,9 +598,8 @@ def get_next_queriable(self) -> Status: elif token.type == tokenize.COMMENT: if ( not self.last_block_was_snakecode - and token.string in FMT_OFF - or token.string in FMT_ON - ): + and (token.string in FMT_OFF or token.string in FMT_ON) + ) and col_nb(token) == 0: # col-0 comments report cur_indent=0 to trigger context_exit; # fmt directives at other columns report actual cur_indent. return Status( diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 732774d..7438b5b 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1790,7 +1790,6 @@ def test_sorting_of_params(self): def test_sorting_comprehensive(self): snakecode = ( "rule all:\n" - f"{TAB}shell: 'echo done'\n" f"{TAB}params: p=1\n" f"{TAB}resources: mem_mb=100\n" f"{TAB}threads: 4\n" @@ -1801,6 +1800,7 @@ def test_sorting_comprehensive(self): f"{TAB}# Important input\n" f"{TAB}input: 'in.txt'\n" f"{TAB}name: 'myrule'\n" + f"{TAB}shell: 'echo done'\n" ) formatter = setup_formatter(snakecode, sort_params=True) expected = ( @@ -2128,7 +2128,6 @@ def test_fmt_off_at_middle(self): assert setup_formatter(code1).get_formatted() == expected def test_fmt_off_on(self): - # TODO: the action after `# fmt: on` should be consistent, should be fixed in the future. for code, formatted in ( TestSimpleParamFormatting.example_shell_newline, TestSimpleParamFormatting.example_params_newline, @@ -2140,11 +2139,30 @@ def test_fmt_off_on(self): code1 = "\n\n# fmt: on\n" + code expected = "# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected - # TODO: trailing comments like `# fmt: off # comment` are not currently supported, but should be in the future code1 = "\n# fmt: off\n" + code + "\n# fmt: on\n" + code expected = "# fmt: off\n" + code + "\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected + def test_fmt_off_not_on(self): + for code, formatted in ( + TestSimpleParamFormatting.example_shell_newline, + TestSimpleParamFormatting.example_params_newline, + TestSimpleParamFormatting.example_input_threads_newline, + ): + code1 = ( + "\n# fmt: off\n" + + code + + "\nif 1:\n a=1\n # fmt: on\n b=2\n" + + code + ) + expected = ( + "# fmt: off\n" + + code + + "\nif 1:\n a=1\n # fmt: on\n b=2\n" + + code + ) + assert setup_formatter(code1).get_formatted() == expected + def test_fmt_off_on_in_run(self): """# fmt: off inside Python code is handled by Black.""" code = ( @@ -2193,8 +2211,7 @@ def test_fmt_off_on_in_run_complex(self): f" run:\n" f" # fmt: off\n" f" x = [ 1,2,3]\n" - f" # fmt: on\n" - f"\n" + f" # fmt: on\n\n" f"sth=1\n" f"{code}" ) @@ -2203,12 +2220,8 @@ def test_fmt_off_on_in_run_complex(self): f"{TAB * 1}run:\n" f"{TAB * 2}# fmt: off\n" f"{TAB * 2}x = [ 1,2,3]\n" - f"{TAB * 2}# fmt: on\n" - f"\n" - f"\n" - f"sth = 1\n" - f"\n" - f"\n" + f"{TAB * 2}# fmt: on\n\n\n" + f"sth = 1\n\n\n" f"{formatted}" ) assert formatter.get_formatted() == expected @@ -2216,21 +2229,98 @@ def test_fmt_off_on_in_run_complex(self): f"rule:\n" f" run:\n" f" # fmt: off\n" + f" x = [ 1,2,3]\n\n" + f"sth=1\n" + f"{code}" + ) + expected = ( + "rule:\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n\n\n" + f"sth = 1\n\n\n" + f"{formatted}" + ) + assert formatter.get_formatted() == expected + + def test_fmt_off_on_in_rule(self): + code, formatted = TestSimpleParamFormatting.example_shell_newline + formatter = setup_formatter( + f"rule:\n" + f" # fmt: off\n" + f" run:\n" f" x = [ 1,2,3]\n" - f"\n" f"sth=1\n" f"{code}" ) expected = ( "rule:\n" + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}x = [ 1,2,3]\n\n\n" + f"sth = 1\n\n\n" + f"{formatted}" + ) + assert formatter.get_formatted() == expected + formatter = setup_formatter( + f"rule:\n" + f" message: 'finishing'\n" + f" # Important input\n" + f" input: 'in.txt'\n" + f" # fmt: off\n" + f" log: 'log.txt'\n" + f" name: 'myrule'\n" + f" # fmt: on\n" + f" output: 'out.txt'\n" + f" run:\n" + f" # fmt: off\n" + f" x = [ 1,2,3]\n\n" + f"sth=1\n" + f"{code}" + ) + expected = ( + "rule:\n" + f"{TAB}message:\n" + f'{TAB}{TAB}"finishing"\n' + f"{TAB}# Important input\n" + f"{TAB}input:\n" + f'{TAB}{TAB}"in.txt",\n' + f"{TAB}# fmt: off\n" + f"{TAB}log: 'log.txt'\n" + f"{TAB}name: 'myrule'\n" + f"{TAB}# fmt: on\n" + f"{TAB}output:\n" + f'{TAB}{TAB}"out.txt",\n' f"{TAB * 1}run:\n" f"{TAB * 2}# fmt: off\n" - f"{TAB * 2}x = [ 1,2,3]\n" - f"\n" - f"\n" - f"sth = 1\n" - f"\n" - f"\n" + f"{TAB * 2}x = [ 1,2,3]\n\n\n" + f"sth = 1\n\n\n" f"{formatted}" ) assert formatter.get_formatted() == expected + + def test_fmt_off_on_in_other(self): + formatter = setup_formatter( + "module a: \n" + f'{TAB * 1}snakefile: "other.smk"\n' + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}config: config\n" + f'{TAB * 1}prefix: "testmodule"\n' + f"{TAB * 1}# fmt: on\n" + f'{TAB * 1}replace_prefix: {{"results/": "results/testmodule/"}}\n' + f'{TAB * 1}meta_wrapper: "0.72.0/meta/bio/bwa_mapping"\n' + ) + expected = ( + "module a:\n" + f"{TAB * 1}snakefile:\n" + f'{TAB * 2}"other.smk"\n' + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}config: config\n" + f'{TAB * 1}prefix: "testmodule"\n' + f"{TAB * 1}# fmt: on\n" + f"{TAB * 1}replace_prefix:\n" + f'{TAB * 2}{{"results/": "results/testmodule/"}}\n' + f"{TAB * 1}meta_wrapper:\n" + f'{TAB * 2}"0.72.0/meta/bio/bwa_mapping"\n' + ) + assert formatter.get_formatted() == expected From df1eda6f9136ad2a38e3fd832d01cb0f8a7d1169 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 25 Mar 2026 21:57:27 +0800 Subject: [PATCH 04/53] feat: off[sort] --- snakefmt/formatter.py | 22 +-- snakefmt/parser/parser.py | 117 ++++++++++----- tests/test_formatter.py | 296 +++++++++++++++++++++----------------- 3 files changed, 256 insertions(+), 179 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index 99da96c..abb82ee 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -65,8 +65,7 @@ def __init__( self.result: str = "" self.lagging_comments: str = "" self.no_formatting_yet: bool = True - self.sort_directives = sort_directives - self.fmt_off_sort_next: bool = False # for # fmt: off[sort] + self.fmt_sort_off = None if sort_directives else -1 self.previous_result: str = "" self.keyword_spec: list[str] = [] self.keywords: dict[str, str] = {} # cache to sort @@ -167,7 +166,7 @@ def process_keyword_context(self, in_global_context: bool): else: # not a PythonCode context, collect keywords to sort self.previous_result += self.result + formatted self.result = "" - self.keyword_spec = [] if self.fmt_off_sort_next else self.vocab.ordered() + self.keyword_spec = self.vocab.ordered() def process_keyword_param( self, param_context: ParameterSyntax, in_global_context: bool @@ -178,7 +177,7 @@ def process_keyword_param( context=param_context, ) param_formatted = self.format_params(param_context) - if self.sort_directives and not in_global_context and self.keyword_spec: + if self.fmt_sort_off is None and not in_global_context and self.keyword_spec: self.keywords[param_context.keyword_name] = self.result + param_formatted self.result = "" else: @@ -192,14 +191,15 @@ def post_process_keyword(self): for keyword in self.keyword_spec: res = self.keywords.pop(keyword, "") self.previous_result += res - if self.keywords: - raise InvalidParameterSyntax( - "Unexpected keywords when sorted keywords: " - + (", ".join(self.keywords)) - ) + assert not self.keywords, ( + "All directives should have been consumed; " + "if not, this is a bug in snakefmt's handling of snakemake syntax. " + "It must be the coder's fault, not the user's. " + "So please report this to the developers with the code so we can fix it: " + "https://github.com/snakemake/snakefmt/issues" + ) self.result = self.previous_result + self.result self.previous_result = "" - self.fmt_off_sort_next = False # reset after each rule/context if self.no_formatting_yet and self.result.rstrip("\n"): self.no_formatting_yet = False @@ -293,7 +293,7 @@ def run_black_format_str( "\n\n(Note reported line number may be incorrect, as" " snakefmt could not determine the true line number)" ) - err_msg = f"Black error:\n```\n{str(err_msg)}\n```\n" + err_msg = f"Black error:\n```\n{str(err_msg)}\n``` from\n```\n{str(string)}\n```\n" raise InvalidPython(err_msg) from None if artificial_nest: diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 1682299..d4b5b37 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -1,6 +1,7 @@ +import re import tokenize from abc import ABC, abstractmethod -from typing import NamedTuple, Optional, Literal +from typing import Literal, NamedTuple, Optional from snakefmt.exceptions import UnsupportedSyntax from snakefmt.parser.grammar import PythonCode, SnakeGlobal @@ -16,11 +17,34 @@ from snakefmt.types import TAB, Token, TokenIterator, col_nb -FMT_OFF_REGION = frozenset({"# fmt: off"}) -FMT_OFF_ONE = frozenset({"# fmt: off[one]"}) -FMT_OFF_SORT = frozenset({"# fmt: off[sort]"}) -FMT_OFF = FMT_OFF_REGION | FMT_OFF_ONE | FMT_OFF_SORT -FMT_ON = frozenset({"# fmt: on"}) +_FMT_DIRECTIVE_RE = re.compile( + r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" +) + + +class FMT_DIRECTIVE(NamedTuple): + disable: bool + modifiers: list[str] + + @classmethod + def from_token(cls, token: Token): + if token.type != tokenize.COMMENT: + return None + return cls.from_str(token.string) + + @classmethod + def from_str(cls, token_string: str): + """Parse a fmt directive comment. + Returns (disable, modifiers) or None if not a fmt directive. + disable: True | False + modifiers: e.g. [] | ['sort'] | ['next'] | ['sort', 'next'] + """ + m = _FMT_DIRECTIVE_RE.match(token_string) + if m is None: + return None + disable = m.group(1) == "off" + mods = [s.strip() for s in m.group(2).split(",")] if m.group(2) else [] + return cls(disable, mods) # type: ignore[arg-type] def split_token_lines(token: tokenize.TokenInfo): @@ -127,9 +151,9 @@ def __init__(self, snakefile: Snakefile): self.queriable = True self.in_fstring = False self.last_token: Optional[Token] = None - self.fmt_off_sort_next: bool = False # for `# fmt: off[sort]` - # for `# fmt: off`, (indent, ) - self.fmt_off: Literal[False] | tuple[int] = False + self.fmt_sort_off: Optional[int] + # for `# fmt: off`, (indent, kind); kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" + self.fmt_off: Optional[tuple[int, Literal["next", "region", "sort"]]] = None # True if a new block should be formatted as fmt: off due to a preceding fmt directive self.fmt_off_applied: bool = False @@ -149,20 +173,31 @@ def __init__(self, snakefile: Snakefile): break keyword = status.token.string - if self._check_fmt_off_on(status.token): - self.fmt_off = False - self.fmt_off_sort_next = False - elif status.token.string in FMT_OFF: - self.fmt_off = (status.cur_indent,) - self.fmt_off_sort_next = False + if fmt_label := FMT_DIRECTIVE.from_token(status.token): + if fmt_label.disable: + if not fmt_label.modifiers: + self.fmt_off = (status.cur_indent, "region") + elif "next" in fmt_label.modifiers: + self.fmt_off = (status.cur_indent, "next") + elif "sort" in fmt_label.modifiers: + self.fmt_sort_off = status.cur_indent + elif fmt_on := self._check_fmt_on(status.token): + if fmt_on == "region": + self.fmt_off = None + self.fmt_off_applied = False + elif fmt_on == "sort": + self.fmt_sort_off = None + continue elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: - self.fmt_off = False + self.fmt_off = None self.fmt_off_applied = False if self.vocab.recognises(keyword) and self.fmt_off: - if self.fmt_off: - self.fmt_off_applied = True + self.fmt_off_applied = True self._consume_fmt_off(status.token, min_indent=self.keyword_indent) + if self.fmt_off and self.fmt_off[1] == "next": + self.fmt_off = None + self.fmt_off_applied = False status = self.get_next_queriable() if self.last_block_was_snakecode and not status.eof: self.block_indent = status.block_indent @@ -193,7 +228,11 @@ def __init__(self, snakefile: Snakefile): f"L{status.token.start[0]}: Unrecognised keyword '{keyword}' " f"in {self.syntax.keyword_name} definition" ) - elif keyword in FMT_OFF_REGION: + elif ( + (fmt_label := FMT_DIRECTIVE.from_token(status.token)) + and fmt_label.disable + and ("sort" not in fmt_label.modifiers) + ): self.flush_buffer( from_python=self.from_python, in_global_context=self.in_global_context, @@ -365,12 +404,15 @@ def _init_min_indent(token: Token): self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 break - else: - if self._check_fmt_off_on(token): - self.fmt_off = False - self.fmt_off_sort_next = False + elif fmt_on := self._check_fmt_on(token): + if fmt_on == "region": + self.fmt_off = None + self.fmt_off_applied = False lines.update(split_token_lines(token)) break + elif fmt_on == "sort": + self.fmt_sort_off = None + continue self.queriable = False lines.update(split_token_lines(token)) @@ -544,12 +586,24 @@ def _determe_comment_indent(self, token: Token) -> int: # highest indent level fitting within the comment's column. return max(check_indent(token.line, self.indents), follow_indent) - def _check_fmt_off_on(self, token: Token) -> bool: - if token.type == tokenize.COMMENT and self.fmt_off: - if token.string in FMT_ON: - if self._determe_comment_indent(token) == self.fmt_off[0]: - return True - return False + def _check_fmt_on(self, token: Token): + """Return True if token ends the current fmt:off region.""" + if not (fmt_dir := FMT_DIRECTIVE.from_token(token)) or fmt_dir.disable: + return + if self.fmt_off: + # `# fmt: on[sort]` no effect + if "sort" in fmt_dir.modifiers: + return + token_indent = self._determe_comment_indent(token) + if token_indent == self.fmt_off[0]: + return "region" + return + if self.fmt_sort_off is not None: + if "sort" not in (fmt_dir.modifiers or ["sort"]): + return + token_indent = self._determe_comment_indent(token) + if token_indent == self.fmt_sort_off: + return "sort" def _handle_indent(self, token: Token) -> bool: if token.type == tokenize.INDENT: @@ -596,10 +650,7 @@ def get_next_queriable(self) -> Status: token, block_indent, self.cur_indent, buffer, True, pythonable ) elif token.type == tokenize.COMMENT: - if ( - not self.last_block_was_snakecode - and (token.string in FMT_OFF or token.string in FMT_ON) - ) and col_nb(token) == 0: + if FMT_DIRECTIVE.from_token(token) and col_nb(token) == 0: # col-0 comments report cur_indent=0 to trigger context_exit; # fmt directives at other columns report actual cur_indent. return Status( diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 7438b5b..0ca5156 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1729,15 +1729,29 @@ def test_double_block_comment_mid_run(self): class TestSortFormatting: + sort_simple = ( + "rule a:\n" + f"{TAB * 1}# annots\n" + f"{TAB * 1}threads: 1\n" + f'{TAB * 1}log: "b",\n' + f'{TAB * 1}output: "a", "fsdfdsdfd", "ccc"\n' + f"{TAB * 1}run:\n" + f'{TAB * 2}print("hello world")\n', + "rule a:\n" + f"{TAB * 1}output:\n" + f'{TAB * 2}"a",\n' + f'{TAB * 2}"fsdfdsdfd",\n' + f'{TAB * 2}"ccc",\n' + f"{TAB * 1}log:\n" + f'{TAB * 2}"b",\n' + f"{TAB * 1}# annots\n" + f"{TAB * 1}threads: 1\n" + f"{TAB * 1}run:\n" + f'{TAB * 2}print("hello world")\n', + ) + def test_sorting_of_params(self): - snakecode = ( - "rule a:\n" - f"{TAB * 1}# annots\n" - f"{TAB * 1}threads: 1\n" - f'{TAB * 1}log: "b",\n' - f'{TAB * 1}output: "a", "fsdfdsdfd", "ccc"\n' - f"{TAB * 1}run:\n" - f'{TAB * 2}print("hello world")\n' + snakecode = self.sort_simple[0] + ( "if 2:\n" f"{TAB * 1}rule b:\n" f'{TAB * 2}output: "b",\n' @@ -1753,18 +1767,8 @@ def test_sorting_of_params(self): f'{TAB * 1}print("error")\n' ) formatter = setup_formatter(snakecode, sort_params=True) - expected = ( - "rule a:\n" - f"{TAB * 1}output:\n" - f'{TAB * 2}"a",\n' - f'{TAB * 2}"fsdfdsdfd",\n' - f'{TAB * 2}"ccc",\n' - f"{TAB * 1}log:\n" - f'{TAB * 2}"b",\n' - f"{TAB * 1}# annots\n" - f"{TAB * 1}threads: 1\n" - f"{TAB * 1}run:\n" - f'{TAB * 2}print("hello world")\n\n\n' + expected = self.sort_simple[1] + ( + f"\n\n" "if 2:\n" "\n" f"{TAB * 1}rule b:\n" @@ -1787,127 +1791,121 @@ def test_sorting_of_params(self): ) assert formatter.get_formatted() == expected + sorting_comprehensive = ( + "rule all:\n" + f"{TAB}params: p=1\n" + f"{TAB}resources: mem_mb=100\n" + f"{TAB}threads: 4\n" + f"{TAB}conda: 'env.yaml'\n" + f"{TAB}message: 'finishing'\n" + f"{TAB}log: 'log.txt'\n" + f"{TAB}output: 'out.txt'\n" + f"{TAB}# Important input\n" + f"{TAB}input: 'in.txt'\n" + f"{TAB}name: 'myrule'\n" + f"{TAB}shell: 'echo done'\n", + "rule all:\n" + f"{TAB}name:\n" + f'{TAB*2}"myrule"\n' + f"{TAB}# Important input\n" + f"{TAB}input:\n" + f'{TAB*2}"in.txt",\n' + f"{TAB}output:\n" + f'{TAB*2}"out.txt",\n' + f"{TAB}log:\n" + f'{TAB*2}"log.txt",\n' + f"{TAB}conda:\n" + f'{TAB*2}"env.yaml"\n' + f"{TAB}threads: 4\n" + f"{TAB}resources:\n" + f"{TAB*2}mem_mb=100,\n" + f"{TAB}params:\n" + f"{TAB*2}p=1,\n" + f"{TAB}message:\n" + f'{TAB*2}"finishing"\n' + f"{TAB}shell:\n" + f'{TAB*2}"echo done"\n', + ) + def test_sorting_comprehensive(self): - snakecode = ( - "rule all:\n" - f"{TAB}params: p=1\n" - f"{TAB}resources: mem_mb=100\n" - f"{TAB}threads: 4\n" - f"{TAB}conda: 'env.yaml'\n" - f"{TAB}message: 'finishing'\n" - f"{TAB}log: 'log.txt'\n" - f"{TAB}output: 'out.txt'\n" - f"{TAB}# Important input\n" - f"{TAB}input: 'in.txt'\n" - f"{TAB}name: 'myrule'\n" - f"{TAB}shell: 'echo done'\n" - ) - formatter = setup_formatter(snakecode, sort_params=True) - expected = ( - "rule all:\n" - f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' - f"{TAB}# Important input\n" - f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' - f"{TAB}output:\n" - f'{TAB*2}"out.txt",\n' - f"{TAB}log:\n" - f'{TAB*2}"log.txt",\n' - f"{TAB}conda:\n" - f'{TAB*2}"env.yaml"\n' - f"{TAB}threads: 4\n" - f"{TAB}resources:\n" - f"{TAB*2}mem_mb=100,\n" - f"{TAB}params:\n" - f"{TAB*2}p=1,\n" - f"{TAB}message:\n" - f'{TAB*2}"finishing"\n' - f"{TAB}shell:\n" - f'{TAB*2}"echo done"\n' - ) - assert formatter.get_formatted() == expected + formatter = setup_formatter(self.sorting_comprehensive[0], sort_params=True) + assert formatter.get_formatted() == self.sorting_comprehensive[1] + + sort_with_coments = ( + "rule complex:\n" + f"{TAB}# Action comment\n" + f"{TAB}shell: 'do something'\n" + f"{TAB}# Resource comment\n" + f"{TAB}resources: res=1\n" + f"{TAB}# Input comment\n" + f"{TAB}input: 'i'\n", + "rule complex:\n" + f"{TAB}# Input comment\n" + f"{TAB}input:\n" + f'{TAB*2}"i",\n' + f"{TAB}# Resource comment\n" + f"{TAB}resources:\n" + f"{TAB*2}res=1,\n" + f"{TAB}# Action comment\n" + f"{TAB}shell:\n" + f'{TAB*2}"do something"\n', + ) def test_sorting_with_comments_preservation(self): - snakecode = ( - "rule complex:\n" - f"{TAB}# Action comment\n" - f"{TAB}shell: 'do something'\n" - f"{TAB}# Resource comment\n" - f"{TAB}resources: res=1\n" - f"{TAB}# Input comment\n" - f"{TAB}input: 'i'\n" - ) - formatter = setup_formatter(snakecode, sort_params=True) - # Comments stay with their keywords - expected = ( - "rule complex:\n" - f"{TAB}# Input comment\n" - f"{TAB}input:\n" - f'{TAB*2}"i",\n' - f"{TAB}# Resource comment\n" - f"{TAB}resources:\n" - f"{TAB*2}res=1,\n" - f"{TAB}# Action comment\n" - f"{TAB}shell:\n" - f'{TAB*2}"do something"\n' - ) - actual = formatter.get_formatted() - assert actual == expected + """Comments stay with their keywords""" + formatter = setup_formatter(self.sort_with_coments[0], sort_params=True) + assert formatter.get_formatted() == self.sort_with_coments[1] + + sort_inline_comments = ( + "rule inline_comments:\n" + f"{TAB}shell: 'echo'\n" + f"{TAB}params:\n" + f"{TAB*2}p=1, # parameter comment\n" + f"{TAB}input: 'i'\n", + "rule inline_comments:\n" + f"{TAB}input:\n" + f'{TAB*2}"i",\n' + f"{TAB}params:\n" + f"{TAB*2}p=1, # parameter comment\n" + f"{TAB}shell:\n" + f'{TAB*2}"echo"\n', + ) def test_sorting_with_inline_parameter_comments(self): - snakecode = ( - "rule inline_comments:\n" - f"{TAB}shell: 'echo'\n" - f"{TAB}params:\n" - f"{TAB*2}p=1, # parameter comment\n" - f"{TAB}input: 'i'\n" - ) - formatter = setup_formatter(snakecode, sort_params=True) - expected = ( - "rule inline_comments:\n" - f"{TAB}input:\n" - f'{TAB*2}"i",\n' - f"{TAB}params:\n" - f"{TAB*2}p=1, # parameter comment\n" - f"{TAB}shell:\n" - f'{TAB*2}"echo"\n' - ) - actual = formatter.get_formatted() - assert actual == expected + formatter = setup_formatter(self.sort_inline_comments[0], sort_params=True) + assert formatter.get_formatted() == self.sort_inline_comments[1] + + sort_module = ( + "module other:\n" + f"{TAB}meta_wrapper: 'wrapper'\n" + f"{TAB}replace_prefix: 'rp'\n" + f"{TAB}prefix: 'p'\n" + f"{TAB}skip_validation: True\n" + f"{TAB}config: 'c'\n" + f"{TAB}snakefile: 's'\n" + f"{TAB}pathvars: ['pv']\n" + f"{TAB}name: 'n'\n", + "module other:\n" + f'{TAB}name: "n"\n' + f"{TAB}pathvars:\n" + f'{TAB*2}["pv"],\n' + f"{TAB}snakefile:\n" + f'{TAB*2}"s"\n' + f"{TAB}config:\n" + f'{TAB*2}"c"\n' + f"{TAB}skip_validation:\n" + f"{TAB*2}True\n" + f"{TAB}prefix:\n" + f'{TAB*2}"p"\n' + f"{TAB}replace_prefix:\n" + f'{TAB*2}"rp"\n' + f"{TAB}meta_wrapper:\n" + f'{TAB*2}"wrapper"\n', + ) def test_sorting_module(self): - snakecode = ( - "module other:\n" - f"{TAB}meta_wrapper: 'wrapper'\n" - f"{TAB}replace_prefix: 'rp'\n" - f"{TAB}prefix: 'p'\n" - f"{TAB}skip_validation: True\n" - f"{TAB}config: 'c'\n" - f"{TAB}snakefile: 's'\n" - f"{TAB}pathvars: ['pv']\n" - f"{TAB}name: 'n'\n" - ) - formatter = setup_formatter(snakecode, sort_params=True) - expected = ( - "module other:\n" - f'{TAB}name: "n"\n' - f"{TAB}pathvars:\n" - f'{TAB*2}["pv"],\n' - f"{TAB}snakefile:\n" - f'{TAB*2}"s"\n' - f"{TAB}config:\n" - f'{TAB*2}"c"\n' - f"{TAB}skip_validation:\n" - f"{TAB*2}True\n" - f"{TAB}prefix:\n" - f'{TAB*2}"p"\n' - f"{TAB}replace_prefix:\n" - f'{TAB*2}"rp"\n' - f"{TAB}meta_wrapper:\n" - f'{TAB*2}"wrapper"\n' - ) - assert formatter.get_formatted() == expected + formatter = setup_formatter(self.sort_module[0], sort_params=True) + assert formatter.get_formatted() == self.sort_module[1] def test_sorting_checkpoint(self): snakecode = ( @@ -2015,7 +2013,7 @@ def side_effect(*args, **kwargs): formatter.snakefile = smk formatter.black_mode = black.Mode() formatter.from_python = False - formatter.fmt_off = False + formatter.fmt_off = None from snakefmt.parser.parser import Context from snakefmt.parser.syntax import KeywordSyntax @@ -2139,6 +2137,12 @@ def test_fmt_off_on(self): code1 = "\n\n# fmt: on\n" + code expected = "# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected + code1 = code + "\n\n# fmt: on\n" + code + expected = formatted + "\n\n# fmt: on\n" + formatted + assert setup_formatter(code1).get_formatted() == expected + code1 = code + "\n\n# fmt: on\n" + code + expected = formatted + "\n\n# fmt: on\n" + formatted + assert setup_formatter(code1).get_formatted() == expected code1 = "\n# fmt: off\n" + code + "\n# fmt: on\n" + code expected = "# fmt: off\n" + code + "\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected @@ -2324,3 +2328,25 @@ def test_fmt_off_on_in_other(self): f'{TAB * 2}"0.72.0/meta/bio/bwa_mapping"\n' ) assert formatter.get_formatted() == expected + + +class TestFmtOffSort: + def test_fmt_off_sort(self): + for code, formatted in ( + TestSortFormatting.sorting_comprehensive, + TestSortFormatting.sort_with_coments, + TestSortFormatting.sort_inline_comments, + TestSortFormatting.sort_module, + ): + code1 = code + "\n\n# fmt: on\n" + code + expected = formatted + "\n\n# fmt: on\n" + formatted + assert setup_formatter(code1, sort_params=True).get_formatted() == expected + code1 = "# fmt: off[sort]\n" + code + expected = "# fmt: off[sort]\n" + setup_formatter(code).get_formatted() + assert setup_formatter(code1, sort_params=True).get_formatted() == expected + code2 = code1 + "\n\n# fmt: on[sort]\n" + code + expected2 = expected + "\n\n# fmt: on[sort]\n" + formatted + assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + code2 = code1 + "\n\n# fmt: on\n" + code + expected2 = expected + "\n\n# fmt: on\n" + formatted + assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 From 3c79aa50bed444f3447127b3e655c3e7cbb585a7 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 26 Mar 2026 11:51:24 +0800 Subject: [PATCH 05/53] fix: ugly reset indents before process keywords --- snakefmt/parser/parser.py | 9 ++- tests/test_formatter.py | 162 +++++++++++++++++++++++++++++++++++++- 2 files changed, 166 insertions(+), 5 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index d4b5b37..d5ffa1b 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -153,7 +153,7 @@ def __init__(self, snakefile: Snakefile): self.last_token: Optional[Token] = None self.fmt_sort_off: Optional[int] # for `# fmt: off`, (indent, kind); kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" - self.fmt_off: Optional[tuple[int, Literal["next", "region", "sort"]]] = None + self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None # True if a new block should be formatted as fmt: off due to a preceding fmt directive self.fmt_off_applied: bool = False @@ -486,6 +486,11 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: accepts_py=new_vocab is PythonCode, ), ) + # should reset index here + line = status.token.line + indent = line[: len(line) - len(line.lstrip())] + while self.indents and self.indents[-1] != indent: + self.indents.pop() self.process_keyword_context(in_global_context) if self.syntax.enter_context: self.context_stack.append(self.context) @@ -610,6 +615,8 @@ def _handle_indent(self, token: Token) -> bool: line = token.line indent = line[: len(line) - len(line.lstrip())] if indent not in self.indents: + if len(indent) <= len(self.indents[-1]): + breakpoint() self.indents.append(indent) elif token.type == tokenize.DEDENT: line = token.line diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 0ca5156..1b1872c 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -63,11 +63,11 @@ def test_shell_param_newline_indented(self): assert formatter.get_formatted() == self.example_shell_newline[1] example_params_newline = ( - f"rule a: \n" + f"rule b: \n" f'{TAB * 1}input: "a", "b",\n' f'{TAB * 4}"c"\n' f'{TAB * 1}wrapper: "mywrapper"', - "rule a:\n" + f"rule b:\n" f"{TAB * 1}input:\n" f'{TAB * 2}"a",\n' f'{TAB * 2}"b",\n' @@ -81,13 +81,13 @@ def test_single_param_keyword_in_rule_gets_newline_indented(self): assert formatter.get_formatted() == self.example_params_newline[1] example_input_threads_newline = ( - "rule a: \n" + f"rule c: \n" f'{TAB * 1}input: "c"\n' f"{TAB * 1}threads:\n" f"{TAB * 2}20\n" f"{TAB * 1}default_target:\n" f"{TAB * 2}True\n", - f"rule a:\n" + f"rule c:\n" f"{TAB * 1}input:\n" f'{TAB * 2}"c",\n' f"{TAB * 1}threads: 20\n" @@ -2350,3 +2350,157 @@ def test_fmt_off_sort(self): code2 = code1 + "\n\n# fmt: on\n" + code expected2 = expected + "\n\n# fmt: on\n" + formatted assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + + +class TestFmtOffNext: + # def test_fmt_off_next(self): + # for code, formatted in ( + # TestSimpleParamFormatting.example_shell_newline, + # TestSimpleParamFormatting.example_params_newline, + # TestSimpleParamFormatting.example_input_threads_newline, + # ): + # code1 = "\n\n# fmt: off[next]\n" + code + "\n" + code + # expected = "# fmt: off[next]\n" + code.strip("\n") + "\n\n\n" + formatted + # assert setup_formatter(code1).get_formatted() == expected + # code1 = code + "\n# fmt: off[next]\n" + code + "\n\n\n" + code + # expected = ( + # formatted + # + "# fmt: off[next]\n" + # + code.strip("\n") + # + "\n\n\n" + # + formatted + # ) + # assert setup_formatter(code1).get_formatted() == expected + # code1 = code + "\n# fmt: off[next]\n" + code + # expected = formatted + "# fmt: off[next]\n" + code + # assert setup_formatter(code1).get_formatted() == expected + # code1 = code + "\n# fmt: off[next]\n" + code + "\n\n" + # expected = formatted + "# fmt: off[next]\n" + code + "\n\n" + # assert setup_formatter(code1).get_formatted() == expected + + def test_rule_if_rule(self): + code1, format1 = TestSimpleParamFormatting.example_shell_newline + code2, format2 = TestSimpleParamFormatting.example_params_newline + code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + formatter = setup_formatter( + code1 + f"\n" + f"if 1:\n" + + "".join(f" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + f"{code3}" + ) + expected = ( + format1 + + f"\n\n" + + f"if 1:\n\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + f"\n\n" + + format3 + ) + assert formatter.get_formatted() == expected + + def test_rule_if2_rule(self): + code1, format1 = TestSimpleParamFormatting.example_shell_newline + code2, format2 = TestSimpleParamFormatting.example_params_newline + code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + formatter = setup_formatter( + code1 + f"\n" + f"if 1:\n" + f" if 2:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + f"{code3}" + ) + expected = ( + format1 + + f"\n\n" + + f"if 1:\n" + + f"{TAB * 1}if 2:\n\n" + + "".join(f"{TAB * 2}" + i for i in format2.splitlines(keepends=True)) + + f"\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + f"\n\n" + + format3 + ) + assert formatter.get_formatted() == expected + + # def test_fmt_off_next_in_if(self): + # code1, format1 = TestSimpleParamFormatting.example_shell_newline + # code2, format2 = TestSimpleParamFormatting.example_params_newline + # code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + # formatter = setup_formatter( + # code1 + f"\n# fmt: \n" + # f"if 1:\n" + # + "".join(" " + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # f"{code3}" + # ) + # expected = ( + # format1 + # + f"\n\n# fmt:\n" + # + f"if 1:\n\n" + # + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + # + f"\n" + # + f"\n" + # + format3 + # ) + # assert formatter.get_formatted() == expected + # formatter = setup_formatter( + # code1 + f"\n# fmt: off[next]\n" + # f"if 1:\n" + # + "".join(" " + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # f"{code3}" + # ) + # expected = ( + # format1 + f"\n\n# fmt: off[next]\n" + # f"if 1:\n" + # + "".join(" " + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # + f"\n" + # + format3 + # ) + # assert formatter.get_formatted() == expected + + # def test_fmt_off_next_in_2if(self): + # code1, format1 = TestSimpleParamFormatting.example_shell_newline + # code2, format2 = TestSimpleParamFormatting.example_params_newline + # code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + # formatter = setup_formatter( + # code1 + f"\n" + # f"if 1:\n" + # f" \n# fmt:\n" + # + "".join(" " + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # + "".join(" " + i for i in code3.splitlines(keepends=True)) + # ) + # expected = ( + # format1 + # + f"\n\n" + # + f"if 1:\n\n" + # + f"{TAB * 1}# fmt:\n" + # + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + # + f"\n" + # + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) + # ) + # assert formatter.get_formatted() == expected + # formatter = setup_formatter( + # code1 + f"\n" + # f"if 1:\n" + # f" \n# fmt: off[next]\n" + # + "".join(" " + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # + "".join(" " + i for i in code3.splitlines(keepends=True)) + # ) + # expected = ( + # format1 + # + f"\n\n" + # + f"if 1:\n\n" + # + f"{TAB * 1}# fmt: off[next]\n" + # + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) + # + f"\n" + # + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) + # ) + # assert formatter.get_formatted() == expected From 0b355d0ea5faaaf7ce8bf751688b292165c7e258 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 26 Mar 2026 16:58:14 +0800 Subject: [PATCH 06/53] feat: off[next] --- snakefmt/formatter.py | 18 ++- snakefmt/parser/parser.py | 45 +++++-- tests/test_formatter.py | 259 +++++++++++++++++++++++--------------- 3 files changed, 209 insertions(+), 113 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index abb82ee..c793fc8 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -210,13 +210,25 @@ def handle_fmt_off_region(self, verbatim: str) -> None: self.buffer = "" if not verbatim: return + # When fmt:off[next] is inside a Python block (e.g. `if 1:`), the + # directive ends up as a lagging_comment after flushing that block. + is_nested_next = self.fmt_off and self.fmt_off[1] == "next" if self.lagging_comments: + # For nested fmt:off[next], add the same \n separator that + # process_keyword_context/add_newlines would normally provide + # before the first keyword inside the Python block. + if is_nested_next and not self.no_formatting_yet: + self.result += "\n" self.result += self.lagging_comments self.lagging_comments = "" self.result += verbatim - # Treat the verbatim region as transparent to separator logic: - # resume formatting as if nothing preceded (no blank-line separator added). - self.no_formatting_yet = True + # For fmt: off[next], mark that we've emitted content so the following + # block gets its normal blank-line separator. + # For fmt: off regions, treat verbatim as transparent to separator logic. + if is_nested_next: + self.no_formatting_yet = bool(self.lagging_comments) + else: + self.no_formatting_yet = True self.last_recognised_keyword = "" def run_black_format_str( diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index d5ffa1b..aeea051 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -193,8 +193,16 @@ def __init__(self, snakefile: Snakefile): self.fmt_off_applied = False if self.vocab.recognises(keyword) and self.fmt_off: + if self.keyword_indent < status.cur_indent and ( + self.syntax.from_python or status.pythonable + ): + self.from_python = True + self.flush_buffer( + from_python=self.from_python, + in_global_context=self.in_global_context, + ) self.fmt_off_applied = True - self._consume_fmt_off(status.token, min_indent=self.keyword_indent) + self._consume_fmt_off(status.token, min_indent=status.cur_indent) if self.fmt_off and self.fmt_off[1] == "next": self.fmt_off = None self.fmt_off_applied = False @@ -239,7 +247,10 @@ def __init__(self, snakefile: Snakefile): ) if self.keyword_indent > 0: self.syntax.add_processed_keyword(status.token, keyword) - self._consume_fmt_off(status.token, min_indent=self.keyword_indent) + self._consume_fmt_off(status.token, min_indent=status.cur_indent) + if self.fmt_off and self.fmt_off[1] == "next": + self.fmt_off = None + self.fmt_off_applied = False self.buffer = "" status = self.get_next_queriable() if self.last_block_was_snakecode and not status.eof: @@ -341,6 +352,10 @@ def _consume_python( prev_token = None last_indent_token = None min_indent = -1 + # If stop_at_min is True, also stop when dedenting back to min_indent level + # (used for fmt: off[next] to consume exactly one block). + to_consume_next = self.fmt_off and self.fmt_off[1] == "next" + consuming_next = False # used with stop_at_min def _init_min_indent(token: Token): nonlocal min_indent @@ -373,13 +388,18 @@ def _init_min_indent(token: Token): self._handle_indent(token) self.syntax.cur_indent = len(self.indents) - 1 last_indent_token = token + if to_consume_next and len(self.indents) - 1 > min_indent: + consuming_next = True + to_consume_next = False continue if token.type == tokenize.DEDENT: saved_indents = list(self.indents) self._handle_indent(token) new_indent = len(self.indents) - 1 last_indent_token = None - if new_indent < min_indent: + if new_indent < min_indent or ( + consuming_next and new_indent == min_indent + ): # let get_next_queriable handle dedent below min_indent self.indents = saved_indents self.snakefile.denext(token) @@ -404,6 +424,10 @@ def _init_min_indent(token: Token): self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 break + # `# fmt: off` may within Python code, apply it to the next snakemake keyword. + if fmt_label := FMT_DIRECTIVE.from_token(token): + if fmt_label.disable and "next" in fmt_label.modifiers: + self.fmt_off = (self.syntax.cur_indent, "next") elif fmt_on := self._check_fmt_on(token): if fmt_on == "region": self.fmt_off = None @@ -426,6 +450,11 @@ def _init_min_indent(token: Token): lines, string_interior_lines, origin_indent, added_indent ) next_status = self.get_next_queriable() + if consuming_next and verbatim: + # Strip extra trailing blank lines; the following block's separator + # logic (add_newlines) will provide the correct spacing. + while verbatim.endswith("\n\n"): + verbatim = verbatim[:-1] return verbatim, next_status._replace( pythonable=next_status.pythonable or bool(verbatim.strip()) ) @@ -486,11 +515,6 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: accepts_py=new_vocab is PythonCode, ), ) - # should reset index here - line = status.token.line - indent = line[: len(line) - len(line.lstrip())] - while self.indents and self.indents[-1] != indent: - self.indents.pop() self.process_keyword_context(in_global_context) if self.syntax.enter_context: self.context_stack.append(self.context) @@ -548,6 +572,11 @@ def context_exit(self, status: Status) -> None: self.block_indent = self.cur_indent if self.keyword_indent > 0: self.syntax.keyword_indent = status.cur_indent + 1 + # ParameterSyntax consumes INDENT/DEDENT tokens without updating + # Parser.indents, leaving stale deeper-level entries. Trim them now + # so get_next_queriable computes the correct cur_indent for the next block. + while len(self.indents) - 1 > status.cur_indent: + self.indents.pop() def _determe_comment_indent(self, token: Token) -> int: """ diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 1b1872c..f52058a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2353,30 +2353,30 @@ def test_fmt_off_sort(self): class TestFmtOffNext: - # def test_fmt_off_next(self): - # for code, formatted in ( - # TestSimpleParamFormatting.example_shell_newline, - # TestSimpleParamFormatting.example_params_newline, - # TestSimpleParamFormatting.example_input_threads_newline, - # ): - # code1 = "\n\n# fmt: off[next]\n" + code + "\n" + code - # expected = "# fmt: off[next]\n" + code.strip("\n") + "\n\n\n" + formatted - # assert setup_formatter(code1).get_formatted() == expected - # code1 = code + "\n# fmt: off[next]\n" + code + "\n\n\n" + code - # expected = ( - # formatted - # + "# fmt: off[next]\n" - # + code.strip("\n") - # + "\n\n\n" - # + formatted - # ) - # assert setup_formatter(code1).get_formatted() == expected - # code1 = code + "\n# fmt: off[next]\n" + code - # expected = formatted + "# fmt: off[next]\n" + code - # assert setup_formatter(code1).get_formatted() == expected - # code1 = code + "\n# fmt: off[next]\n" + code + "\n\n" - # expected = formatted + "# fmt: off[next]\n" + code + "\n\n" - # assert setup_formatter(code1).get_formatted() == expected + def test_fmt_off_next(self): + for code, formatted in ( + TestSimpleParamFormatting.example_shell_newline, + TestSimpleParamFormatting.example_params_newline, + TestSimpleParamFormatting.example_input_threads_newline, + ): + code1 = "\n\n# fmt: off[next]\n" + code + "\n" + code + expected = "# fmt: off[next]\n" + code.strip("\n") + "\n\n\n" + formatted + assert setup_formatter(code1).get_formatted() == expected + code1 = code + "\n# fmt: off[next]\n" + code + "\n\n\n" + code + expected = ( + formatted + + "# fmt: off[next]\n" + + code.strip("\n") + + "\n\n\n" + + formatted + ) + assert setup_formatter(code1).get_formatted() == expected + code1 = code + "\n# fmt: off[next]\n" + code + expected = formatted + "# fmt: off[next]\n" + code + assert setup_formatter(code1).get_formatted() == expected + code1 = code + "\n# fmt: off[next]\n" + code + "\n\n" + expected = formatted + "# fmt: off[next]\n" + code.rstrip("\n") + "\n" + assert setup_formatter(code1).get_formatted() == expected def test_rule_if_rule(self): code1, format1 = TestSimpleParamFormatting.example_shell_newline @@ -2425,82 +2425,137 @@ def test_rule_if2_rule(self): + format3 ) assert formatter.get_formatted() == expected + formatter = setup_formatter( + code1 + f"\n" + f"if 1:\n" + f" if 2:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") + + f"\n" + f" # fmt: off[next]\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + f"{code3}" + ) + expected = ( + format1 + + f"\n\n" + + f"if 1:\n" + + f"{TAB * 1}if 2:\n\n" + + "".join( + f"{TAB * 2}" + i for i in format2.splitlines(keepends=True) + ).rstrip("\n") + + f"\n" + f"{TAB * 1}# fmt: off[next]\n" + + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) + + "\n" + + f"\n\n" + + format3 + ) + assert formatter.get_formatted() == expected + formatter = setup_formatter( + code1 + f"\n" + f"if 1:\n" + f" if 2:\n" + f" # fmt: off[next]\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") + + f"\n" + f"{code3}" + ) + expected = ( + format1 + f"\n\n" + f"if 1:\n" + f"{TAB * 1}if 2:\n\n" + f"{TAB * 2}# fmt: off[next]\n" + + "".join(f"{TAB * 2}" + i for i in code2.splitlines(keepends=True)).rstrip( + "\n" + ) + + "\n" + + "\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + f"\n\n" + + format3 + ) + assert formatter.get_formatted() == expected - # def test_fmt_off_next_in_if(self): - # code1, format1 = TestSimpleParamFormatting.example_shell_newline - # code2, format2 = TestSimpleParamFormatting.example_params_newline - # code3, format3 = TestSimpleParamFormatting.example_input_threads_newline - # formatter = setup_formatter( - # code1 + f"\n# fmt: \n" - # f"if 1:\n" - # + "".join(" " + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # f"{code3}" - # ) - # expected = ( - # format1 - # + f"\n\n# fmt:\n" - # + f"if 1:\n\n" - # + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - # + f"\n" - # + f"\n" - # + format3 - # ) - # assert formatter.get_formatted() == expected - # formatter = setup_formatter( - # code1 + f"\n# fmt: off[next]\n" - # f"if 1:\n" - # + "".join(" " + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # f"{code3}" - # ) - # expected = ( - # format1 + f"\n\n# fmt: off[next]\n" - # f"if 1:\n" - # + "".join(" " + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # + f"\n" - # + format3 - # ) - # assert formatter.get_formatted() == expected - - # def test_fmt_off_next_in_2if(self): - # code1, format1 = TestSimpleParamFormatting.example_shell_newline - # code2, format2 = TestSimpleParamFormatting.example_params_newline - # code3, format3 = TestSimpleParamFormatting.example_input_threads_newline - # formatter = setup_formatter( - # code1 + f"\n" - # f"if 1:\n" - # f" \n# fmt:\n" - # + "".join(" " + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # + "".join(" " + i for i in code3.splitlines(keepends=True)) - # ) - # expected = ( - # format1 - # + f"\n\n" - # + f"if 1:\n\n" - # + f"{TAB * 1}# fmt:\n" - # + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - # + f"\n" - # + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) - # ) - # assert formatter.get_formatted() == expected - # formatter = setup_formatter( - # code1 + f"\n" - # f"if 1:\n" - # f" \n# fmt: off[next]\n" - # + "".join(" " + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # + "".join(" " + i for i in code3.splitlines(keepends=True)) - # ) - # expected = ( - # format1 - # + f"\n\n" - # + f"if 1:\n\n" - # + f"{TAB * 1}# fmt: off[next]\n" - # + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) - # + f"\n" - # + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) - # ) - # assert formatter.get_formatted() == expected + def test_fmt_off_next_in_if(self): + code1, format1 = TestSimpleParamFormatting.example_shell_newline + code2, format2 = TestSimpleParamFormatting.example_params_newline + code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + formatter = setup_formatter( + code1 + f"\n# fmt: \n" + f"if 1:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + f"{code3}" + ) + expected = ( + format1 + + f"\n\n# fmt:\n" + + f"if 1:\n\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + f"\n" + + f"\n" + + format3 + ) + assert formatter.get_formatted() == expected + formatter = setup_formatter( + code1.rstrip("\n") + f"\n# fmt: off[next]\n" + f"if 1:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + + code3 + ) + expected = ( + format1.rstrip("\n") + f"\n# fmt: off[next]\n" + f"if 1:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n\n\n" + + format3 + ) + assert formatter.get_formatted() == expected + + def test_fmt_off_next_in_2if(self): + code1, format1 = TestSimpleParamFormatting.example_shell_newline + code2, format2 = TestSimpleParamFormatting.example_params_newline + code3, format3 = TestSimpleParamFormatting.example_input_threads_newline + formatter = setup_formatter( + code1.rstrip("\n") + f"\n" + f"if 1:\n" + f" \n# fmt:\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + + "".join(" " + i for i in code3.splitlines(keepends=True)) + ) + expected = ( + format1.rstrip("\n") + f"\n" + f"\n\n" + f"if 1:\n\n" + f"{TAB * 1}# fmt:\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + f"\n" + + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) + ) + assert formatter.get_formatted() == expected + formatter = setup_formatter( + code1.rstrip("\n") + f"\n" + f"if 1:\n" + f" \n# fmt: off[next]\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)) + + f"\n" + + "".join(" " + i for i in code3.splitlines(keepends=True)) + ) + expected = ( + format1.rstrip("\n") + f"\n" + f"\n\n" + f"if 1:\n\n" + f"{TAB * 1}# fmt: off[next]\n" + + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( + "\n" + ) + + f"\n" + + f"\n" + + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) + ) + assert formatter.get_formatted() == expected From b0746178e0391fcdd2fe358f2111e74df0d5191c Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 02:27:37 +0800 Subject: [PATCH 07/53] fix: clean --- snakefmt/formatter.py | 24 +++--- snakefmt/parser/parser.py | 176 +++++++++++++++++++++----------------- snakefmt/parser/syntax.py | 2 +- tests/test_formatter.py | 54 +++++++++--- 4 files changed, 151 insertions(+), 105 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index c793fc8..f601725 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -106,6 +106,9 @@ def flush_buffer( else: # Invalid python syntax, eg lone 'else:' between two rules, can occur. # Below constructs valid code statements and formats them. + if self.fmt_off_expected_index: + self.buffer += self.fmt_off_expected_index + self.fmt_off_expected_index = "" re_match = contextual_matcher.match(self.buffer) if re_match is not None: callback_keyword = re_match.group(2) @@ -122,11 +125,12 @@ def flush_buffer( ) formatted = self.run_black_format_str(to_format, self.block_indent) re_rematch = contextual_matcher.match(formatted) - if re_rematch is None: - raise ValueError( - "contextual_matcher failed to match for the given " - f"formatted string: {formatted}" - ) + assert re_rematch, ( + "This should always match as we just formatted it with the same regex. " + "If this error is raised, it's a bug in snakefmt's handling of snakemake syntax. " + "Please report this to the developers with the code so we can fix it: " + "https://github.com/snakemake/snakefmt/issues" + ) if condition != "": callback_keyword += re_rematch.group(3) formatted = ( @@ -251,9 +255,6 @@ def run_black_format_str( and len(string.strip().splitlines()) > 1 and not no_nesting ) - if self.fmt_off and self.fmt_off_applied: - # a `fmt: off` in previous block also affects here, make it work - string = "# fmt: off\n" + string if artificial_nest: string = f"if x:\n{textwrap.indent(string, TAB)}" @@ -305,18 +306,13 @@ def run_black_format_str( "\n\n(Note reported line number may be incorrect, as" " snakefmt could not determine the true line number)" ) - err_msg = f"Black error:\n```\n{str(err_msg)}\n``` from\n```\n{str(string)}\n```\n" + err_msg = f"Black error:\n```\n{str(err_msg)}\n```\n" raise InvalidPython(err_msg) from None if artificial_nest: lines = fmted.splitlines(keepends=True)[1:] s = "".join(lines).lstrip("\n") fmted = textwrap.dedent(s) - if self.fmt_off: - if self.fmt_off_applied: - fmted = fmted.split("# fmt: off\n", 1)[1] - else: - self.fmt_off_applied = True return fmted def align_strings(self, string: str, target_indent: int) -> str: diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index aeea051..9d1a3db 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -154,8 +154,7 @@ def __init__(self, snakefile: Snakefile): self.fmt_sort_off: Optional[int] # for `# fmt: off`, (indent, kind); kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None - # True if a new block should be formatted as fmt: off due to a preceding fmt directive - self.fmt_off_applied: bool = False + self.fmt_off_expected_index: str = "" self.indents: list[str] = [""] @@ -177,39 +176,51 @@ def __init__(self, snakefile: Snakefile): if fmt_label.disable: if not fmt_label.modifiers: self.fmt_off = (status.cur_indent, "region") + self.fmt_off_expected_index = status.token.line[ + : col_nb(status.token) + ] elif "next" in fmt_label.modifiers: self.fmt_off = (status.cur_indent, "next") + self.fmt_off_expected_index = status.token.line[ + : col_nb(status.token) + ] elif "sort" in fmt_label.modifiers: self.fmt_sort_off = status.cur_indent - elif fmt_on := self._check_fmt_on(status.token): - if fmt_on == "region": - self.fmt_off = None - self.fmt_off_applied = False - elif fmt_on == "sort": - self.fmt_sort_off = None - continue + elif self._check_fmt_on(status.token) == "sort": + self.fmt_sort_off = None + continue elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: self.fmt_off = None - self.fmt_off_applied = False - if self.vocab.recognises(keyword) and self.fmt_off: - if self.keyword_indent < status.cur_indent and ( - self.syntax.from_python or status.pythonable - ): - self.from_python = True - self.flush_buffer( - from_python=self.from_python, - in_global_context=self.in_global_context, - ) - self.fmt_off_applied = True - self._consume_fmt_off(status.token, min_indent=status.cur_indent) - if self.fmt_off and self.fmt_off[1] == "next": - self.fmt_off = None - self.fmt_off_applied = False - status = self.get_next_queriable() - if self.last_block_was_snakecode and not status.eof: - self.block_indent = status.block_indent - self.last_block_was_snakecode = False + if self.fmt_off: + if self.vocab.recognises(keyword): + if self.keyword_indent < status.cur_indent and ( + self.syntax.from_python or status.pythonable + ): + self.from_python = True + self.flush_buffer( + from_python=self.from_python, + in_global_context=self.in_global_context, + ) + status = self._consume_fmt_off( + status.token, min_indent=status.cur_indent + ) + else: + self.flush_buffer( + from_python=True, + in_global_context=self.in_global_context, + ) + if self.keyword_indent > 0: + self.syntax.add_processed_keyword(status.token, keyword) + status = self._consume_fmt_off( + status.token, min_indent=status.cur_indent + ) + self.buffer = "" + if self.last_block_was_snakecode and not status.eof: + self.block_indent = status.block_indent + self.last_block_was_snakecode = False + if self.keyword_indent: + self.last_block_was_snakecode = True self.buffer = status.buffer.lstrip() elif self.vocab.recognises(keyword): new_vocab, new_syntax_cls = self.vocab.get(keyword) @@ -236,29 +247,6 @@ def __init__(self, snakefile: Snakefile): f"L{status.token.start[0]}: Unrecognised keyword '{keyword}' " f"in {self.syntax.keyword_name} definition" ) - elif ( - (fmt_label := FMT_DIRECTIVE.from_token(status.token)) - and fmt_label.disable - and ("sort" not in fmt_label.modifiers) - ): - self.flush_buffer( - from_python=self.from_python, - in_global_context=self.in_global_context, - ) - if self.keyword_indent > 0: - self.syntax.add_processed_keyword(status.token, keyword) - self._consume_fmt_off(status.token, min_indent=status.cur_indent) - if self.fmt_off and self.fmt_off[1] == "next": - self.fmt_off = None - self.fmt_off_applied = False - self.buffer = "" - status = self.get_next_queriable() - if self.last_block_was_snakecode and not status.eof: - self.block_indent = status.block_indent - self.last_block_was_snakecode = False - self.buffer = status.buffer.lstrip() - if self.keyword_indent: - self.last_block_was_snakecode = True else: source, status = self._consume_python(status.token) self.buffer += source @@ -354,8 +342,9 @@ def _consume_python( min_indent = -1 # If stop_at_min is True, also stop when dedenting back to min_indent level # (used for fmt: off[next] to consume exactly one block). - to_consume_next = self.fmt_off and self.fmt_off[1] == "next" + is_next_mode = self.fmt_off and self.fmt_off[1] == "next" consuming_next = False # used with stop_at_min + seen_next_block_keyword = False def _init_min_indent(token: Token): nonlocal min_indent @@ -384,15 +373,14 @@ def _init_min_indent(token: Token): if token.type == tokenize.ENDMARKER: self.snakefile.denext(token) break - if token.type == tokenize.INDENT: + elif token.type == tokenize.INDENT: self._handle_indent(token) self.syntax.cur_indent = len(self.indents) - 1 last_indent_token = token - if to_consume_next and len(self.indents) - 1 > min_indent: + if is_next_mode and len(self.indents) - 1 > min_indent: consuming_next = True - to_consume_next = False continue - if token.type == tokenize.DEDENT: + elif token.type == tokenize.DEDENT: saved_indents = list(self.indents) self._handle_indent(token) new_indent = len(self.indents) - 1 @@ -406,37 +394,52 @@ def _init_min_indent(token: Token): break self.syntax.cur_indent = new_indent continue - if is_newline(token): + elif is_newline(token): self.queriable = True lines.update(split_token_lines(token)) continue - if vocab_recognises: - if ( - (token.type == tokenize.NAME or token.string == "@") - and self.queriable - and not self.in_fstring - ): - if self.vocab.recognises(token.string): - # snakemake keyword: stop, let main loop handle it + elif ( + (token.type == tokenize.NAME or token.string == "@") + and self.queriable + and not self.in_fstring + and self.vocab.recognises(token.string) + ): + if is_next_mode: + if seen_next_block_keyword: + # fmt: off[next] consumed one whole keyword block; + # hand the next same-level block back to main loop. self.snakefile.denext(token) if last_indent_token is not None: self.snakefile.denext(last_indent_token) self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 break - # `# fmt: off` may within Python code, apply it to the next snakemake keyword. - if fmt_label := FMT_DIRECTIVE.from_token(token): - if fmt_label.disable and "next" in fmt_label.modifiers: - self.fmt_off = (self.syntax.cur_indent, "next") - elif fmt_on := self._check_fmt_on(token): - if fmt_on == "region": + else: + seen_next_block_keyword = True + if vocab_recognises: + # snakemake keyword: stop, let main loop handle it + self.snakefile.denext(token) + if last_indent_token is not None: + self.snakefile.denext(last_indent_token) + self.indents.pop() + self.syntax.cur_indent = len(self.indents) - 1 + break + # `# fmt: off[next]` within Python code: stop and let main loop handle it. + elif fmt_label := FMT_DIRECTIVE.from_token(token): + if fmt_label.disable: + if fmt_label.modifiers: + # `# fmt: off[` is not actual format diabler, it affects limited + if not self.fmt_off or ( + # two following [next] + self.fmt_off[1] != "region" + and self._determe_comment_indent(token) == self.fmt_off[0] + ): + self.snakefile.denext(token) + break + elif self._check_fmt_on(token) == "region": self.fmt_off = None - self.fmt_off_applied = False lines.update(split_token_lines(token)) break - elif fmt_on == "sort": - self.fmt_sort_off = None - continue self.queriable = False lines.update(split_token_lines(token)) @@ -470,6 +473,9 @@ def _consume_fmt_off(self, start_token: Token, min_indent: int): self.handle_fmt_off_region(verbatim) self.snakefile.denext(next_status.token) self.queriable = True + if self.fmt_off and self.fmt_off[1] == "next": + self.fmt_off = None + return self.get_next_queriable() def _reindent( self, @@ -485,7 +491,10 @@ def _reindent( newlines.append(line) elif line.strip(): newline = line.rsplit("\n", 1) - newline[0] = added_indent + newline[0][origin_indent:] + if newline[0][:origin_indent].strip(): + newline[0] = added_indent + newline[0].lstrip() + else: + newline[0] = added_indent + newline[0][origin_indent:] newlines.append("\n".join(newline)) else: newlines.append(line[origin_indent:]) @@ -644,8 +653,6 @@ def _handle_indent(self, token: Token) -> bool: line = token.line indent = line[: len(line) - len(line.lstrip())] if indent not in self.indents: - if len(indent) <= len(self.indents[-1]): - breakpoint() self.indents.append(indent) elif token.type == tokenize.DEDENT: line = token.line @@ -686,7 +693,12 @@ def get_next_queriable(self) -> Status: token, block_indent, self.cur_indent, buffer, True, pythonable ) elif token.type == tokenize.COMMENT: - if FMT_DIRECTIVE.from_token(token) and col_nb(token) == 0: + fmt_dir = FMT_DIRECTIVE.from_token(token) + if ( + fmt_dir + and col_nb(token) == 0 + and not (fmt_dir.disable and "next" in (fmt_dir.modifiers or [])) + ): # col-0 comments report cur_indent=0 to trigger context_exit; # fmt directives at other columns report actual cur_indent. return Status( @@ -703,6 +715,12 @@ def get_next_queriable(self) -> Status: return Status( token, block_indent, effective_indent, buffer, False, pythonable ) + # A `# fmt: off[next]` directive at any indent always triggers verbatim + # mode for the next snakemake block — return it so the main loop can act. + if fmt_dir and fmt_dir.disable and "next" in (fmt_dir.modifiers or []): + return Status( + token, block_indent, effective_indent, buffer, False, pythonable + ) elif is_newline(token): self.queriable, newline = True, True diff --git a/snakefmt/parser/syntax.py b/snakefmt/parser/syntax.py index 3e0900f..ae90fc5 100644 --- a/snakefmt/parser/syntax.py +++ b/snakefmt/parser/syntax.py @@ -309,7 +309,7 @@ def __init__( self.keyword_indent = keyword_indent self.cur_indent = max(self.keyword_indent - 1, 0) self.comment = "" - self.token = None + self.token: Token if snakefile is not None: self.validate_keyword_line(snakefile) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index f52058a..3da638c 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2463,10 +2463,8 @@ def test_rule_if2_rule(self): + f"\n" f"{code3}" ) - expected = ( - format1 + f"\n\n" - f"if 1:\n" - f"{TAB * 1}if 2:\n\n" + expected1 = format1 + f"\n\n" f"if 1:\n" f"{TAB * 1}if 2:\n" + expected2 = ( f"{TAB * 2}# fmt: off[next]\n" + "".join(f"{TAB * 2}" + i for i in code2.splitlines(keepends=True)).rstrip( "\n" @@ -2477,7 +2475,8 @@ def test_rule_if2_rule(self): + f"\n\n" + format3 ) - assert formatter.get_formatted() == expected + formatted = formatter.get_formatted() + assert formatted.startswith(expected1) and formatted.endswith(expected2) def test_fmt_off_next_in_if(self): code1, format1 = TestSimpleParamFormatting.example_shell_newline @@ -2541,15 +2540,13 @@ def test_fmt_off_next_in_2if(self): formatter = setup_formatter( code1.rstrip("\n") + f"\n" f"if 1:\n" - f" \n# fmt: off[next]\n" + f"\n # fmt: off[next]\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) + f"\n" + "".join(" " + i for i in code3.splitlines(keepends=True)) ) - expected = ( - format1.rstrip("\n") + f"\n" - f"\n\n" - f"if 1:\n\n" + expected1 = format1.rstrip("\n") + f"\n" f"\n\n" f"if 1:\n" + expected2 = ( f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( "\n" @@ -2558,4 +2555,39 @@ def test_fmt_off_next_in_2if(self): + f"\n" + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) ) - assert formatter.get_formatted() == expected + formatted = formatter.get_formatted() + assert formatted.startswith(expected1) and formatted.endswith(expected2) + + def test_fmt_off_2(self): + fomatter = setup_formatter( + "if 1:\n" + " rule a:\n" + ' input: "foo"\n' + " # fmt: off[next]\n" + " rule b:\n" + ' input: "bar"\n' + "\n" + " # fmt: off[next]\n" + " rule c:\n" + ' input: "baz"\n' + "rule d:\n" + ' input: "qux"\n' + ) + assert fomatter.get_formatted() == ( + f"if 1:\n" + f"\n" + f"{TAB}rule a:\n" + f"{TAB}{TAB}input:\n" + f'{TAB}{TAB}{TAB}"foo",\n' + f"{TAB}# fmt: off[next]\n" + f"{TAB}rule b:\n" + f'{TAB} input: "bar"\n' + f"{TAB}# fmt: off[next]\n" + f"{TAB}rule c:\n" + f'{TAB} input: "baz"\n' + f"\n" + f"\n" + f"rule d:\n" + f"{TAB}input:\n" + f'{TAB}{TAB}"qux",\n' + ) From a24b5bfd83fcadec59312eef6b2c2a8128b80f43 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 09:52:59 +0800 Subject: [PATCH 08/53] docs: fmt: off --- README.md | 123 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 6588309..03cc2eb 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,10 @@ design and specifications of [Black][black]. > `--diff` or `--check` options. See [Usage](#usage) for more details. > [!IMPORTANT] -> **Recent Changes:** +> **Recent Changes:** > 1. **Rule and module directives are now sorted by default:** `snakefmt` will automatically sort the order of directives inside rules (e.g. `input`, `output`, `shell`) and modules into a consistent order. You can opt out of this by using the `--no-sort` CLI flag. > 2. **Black upgraded to v26:** The underlying `black` formatter has been upgraded to v26. You will see changes in how implicitly concatenated strings are wrapped (they are now collapsed onto a single line if they fit within the line limit) and other minor adjustments compared to previous versions. -> +> > **Example of expected differences:** > ```python > # Before (Snakefmt older versions) @@ -33,7 +33,7 @@ design and specifications of [Black][black]. > "b.txt", > input: > "a.txt", -> +> > # After (Directives sorted, strings collapsed by Black 26) > rule example: > input: @@ -47,25 +47,34 @@ design and specifications of [Black][black]. [TOC]: # # Table of Contents -- [Install](#install) - - [PyPi](#pypi) - - [Conda](#conda) - - [Containers](#containers) - - [Local](#local) -- [Example File](#example-file) -- [Usage](#usage) - - [Basic Usage](#basic-usage) - - [Full Usage](#full-usage) -- [Configuration](#configuration) - - [Directive Sorting](#directive-sorting) -- [Integration](#integration) - - [Editor Integration](#editor-integration) - - [Version Control Integration](#version-control-integration) - - [Github Actions](#github-actions) -- [Plug Us](#plug-us) -- [Changes](#changes) -- [Contributing](#contributing) -- [Cite](#cite) +1. [Install](#install) + 1. [PyPi](#pypi) + 2. [Conda](#conda) + 3. [Containers](#containers) + 1. [Docker](#docker) + 2. [Singularity](#singularity) + 4. [Local](#local) +2. [Example File](#example-file) +3. [Usage](#usage) + 1. [Basic Usage](#basic-usage) + 2. [Full Usage](#full-usage) +4. [Configuration](#configuration) + 1. [Directive Sorting](#directive-sorting) + 2. [Format Directives](#format-directives) + 1. [`# fmt: off` / `# fmt: on`](#-fmt-off---fmt-on) + 2. [`# fmt: off[sort]`](#-fmt-offsort) + 3. [`# fmt: off[next]`](#-fmt-offnext) + 4. [Example](#example) +5. [Integration](#integration) + 1. [Editor Integration](#editor-integration) + 2. [Version Control Integration](#version-control-integration) + 3. [GitHub Actions](#github-actions) +6. [Plug Us](#plug-us) + 1. [Markdown](#markdown) + 2. [ReStructuredText](#restructuredtext) +7. [Changes](#changes) +8. [Contributing](#contributing) +9. [Cite](#cite) ## Install @@ -313,6 +322,76 @@ This ordering ensures that the directives most frequently used in execution bloc You can disable this feature using the `--no-sort` flag. +### Format Directives + +`snakefmt` supports inline comment directives to control formatting behaviour for specific regions of code. + +#### `# fmt: off` / `# fmt: on` + +Disables all formatting for the region between the two directives. The directives must appear at the same indentation level. A `# fmt: on` at a deeper indent than the matching `# fmt: off` has no effect. + +```python +rule a: + input: + "a.txt", + + +# fmt: off +rule b: + input: "b.txt" + output: + "c.txt" +# fmt: on + + +rule c: + input: + "d.txt", +``` + +Note: inside `run:` blocks and other Python code, `# fmt: off` / `# fmt: on` is passed through to [Black][black] which handles it natively. + +#### `# fmt: off[sort]` + +Disables only directive sorting for the region, while still applying all other formatting. Useful when you want to preserve a custom directive order for a specific rule. + +```python +# fmt: off[sort] +rule keep_my_order: + output: + "result.txt", + input: + "source.txt", + shell: + "cp {input} {output}" +# fmt: on[sort] +``` + +A plain `# fmt: on` (without `[sort]`) also ends a `# fmt: off[sort]` region. + +#### `# fmt: off[next]` + +Disables formatting for the single next Snakemake keyword block (e.g. `rule`, `checkpoint`, `use rule`). Only that one block is left unformatted; subsequent blocks are formatted normally. + +```python +rule formatted: + input: + "a.txt", + output: + "b.txt", + + +# fmt: off[next] +rule unformatted: + input: "a.txt" + output: "b.txt" + + +rule also_formatted: + input: + "a.txt", +``` + #### Example `pyproject.toml` From c89dcde7f403aa8f53ebf4ad2757f902b8c81bee Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 09:58:16 +0800 Subject: [PATCH 09/53] style: sort imports --- snakefmt/formatter.py | 7 +- snakefmt/parser/parser.py | 21 +++--- tests/test_formatter.py | 133 +++++++++++++++++++------------------- 3 files changed, 82 insertions(+), 79 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index f601725..44fab0d 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -126,9 +126,10 @@ def flush_buffer( formatted = self.run_black_format_str(to_format, self.block_indent) re_rematch = contextual_matcher.match(formatted) assert re_rematch, ( - "This should always match as we just formatted it with the same regex. " - "If this error is raised, it's a bug in snakefmt's handling of snakemake syntax. " - "Please report this to the developers with the code so we can fix it: " + "This should always match as we just formatted it with the same " + "regex. If this error is raised, it's a bug in snakefmt's " + "handling of snakemake syntax. Please report this to the " + "developers with the code so we can fix it: " "https://github.com/snakemake/snakefmt/issues" ) if condition != "": diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 9d1a3db..925cad2 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -16,7 +16,6 @@ ) from snakefmt.types import TAB, Token, TokenIterator, col_nb - _FMT_DIRECTIVE_RE = re.compile( r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" ) @@ -51,9 +50,12 @@ def split_token_lines(token: tokenize.TokenInfo): """Token can be multiline. e.g., `f'''\\nplaintext\\n'''` has these tokens: - TokenInfo(type=61 (FSTRING_START), string="f'''", start=(21, 0), end=(21, 4), line="f'''\\n") - TokenInfo(type=62 (FSTRING_MIDDLE), string='\\ncccccccc\\n', start=(21, 4), end=(23, 0), line="f'''\\ncccccccc\\n'''\\n") - TokenInfo(type=63 (FSTRING_END), string="'''", start=(23, 0), end=(23, 3), line="'''\\n") + TokenInfo(type=61 (FSTRING_START), string="f'''", + start=(21, 0), end=(21, 4), line="f'''\\n") + TokenInfo(type=62 (FSTRING_MIDDLE), string='\\ncccccccc\\n', + start=(21, 4), end=(23, 0), line="f'''\\ncccccccc\\n'''\\n") + TokenInfo(type=63 (FSTRING_END), string="'''", + start=(23, 0), end=(23, 3), line="'''\\n") lines should be split to drop overlapping lines and keep unique ones. """ @@ -152,7 +154,8 @@ def __init__(self, snakefile: Snakefile): self.in_fstring = False self.last_token: Optional[Token] = None self.fmt_sort_off: Optional[int] - # for `# fmt: off`, (indent, kind); kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" + # for `# fmt: off`, (indent, kind) + # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None self.fmt_off_expected_index: str = "" @@ -598,9 +601,11 @@ def _determe_comment_indent(self, token: Token) -> int: rule 1 (always): indent of comments >= follow_indent rule 2 (if follow_indent < self.indents[-1]): - indent of comments = max(i for i in self.indents if i <= comment_indent) + epsilon. + indent of comments = max(i for i in self.indents + if i <= comment_indent) + epsilon. - next(self.snakefile) until follow_indent is determined, then put all peeked tokens back. + next(self.snakefile) until follow_indent is determined, + then put all peeked tokens back. """ # ── Step 1: peek ahead to find follow_indent ──────────────────────── peeked: list[Token] = [] @@ -716,7 +721,7 @@ def get_next_queriable(self) -> Status: token, block_indent, effective_indent, buffer, False, pythonable ) # A `# fmt: off[next]` directive at any indent always triggers verbatim - # mode for the next snakemake block — return it so the main loop can act. + # mode for the next snakemake block, return it so the main loop can act. if fmt_dir and fmt_dir.disable and "next" in (fmt_dir.modifiers or []): return Status( token, block_indent, effective_indent, buffer, False, pythonable diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 3da638c..952854c 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1695,10 +1695,10 @@ def test_comment_indentation_in_run_block(self): assert setup_formatter(expected).get_formatted() == expected snakecode = ( "rule fmt_bug_repro:\n" - f" run:\n" - f' if "something nested":\n' - f" pass\n" - f" # Comment gets indented\n" + " run:\n" + ' if "something nested":\n' + " pass\n" + " # Comment gets indented\n" ) assert setup_formatter(snakecode).get_formatted() == expected @@ -2192,7 +2192,7 @@ def test_fmt_off_on_in_run(self): "z = [4, 5, 6]\n" ) assert setup_formatter(code).get_formatted() == expected - snakecode = "rule:\n" f" run:\n" + ( + snakecode = "rule:\n" " run:\n" + ( "".join(f" {i}\n" for i in code.splitlines()) ) snakexpected = "rule:\n" f"{TAB * 1}run:\n" + ( @@ -2383,18 +2383,16 @@ def test_rule_if_rule(self): code2, format2 = TestSimpleParamFormatting.example_params_newline code3, format3 = TestSimpleParamFormatting.example_input_threads_newline formatter = setup_formatter( - code1 + f"\n" - f"if 1:\n" - + "".join(f" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + code1 + "\n" + "if 1:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) + "\n" f"{code3}" ) expected = ( format1 - + f"\n\n" - + f"if 1:\n\n" + + "\n\n" + + "if 1:\n\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - + f"\n\n" + + "\n\n" + format3 ) assert formatter.get_formatted() == expected @@ -2404,66 +2402,66 @@ def test_rule_if2_rule(self): code2, format2 = TestSimpleParamFormatting.example_params_newline code3, format3 = TestSimpleParamFormatting.example_input_threads_newline formatter = setup_formatter( - code1 + f"\n" - f"if 1:\n" - f" if 2:\n" + code1 + "\n" + "if 1:\n" + " if 2:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" f"{code3}" ) expected = ( format1 - + f"\n\n" - + f"if 1:\n" - + f"{TAB * 1}if 2:\n\n" + + "\n\n" + + "if 1:\n" + + "{TAB * 1}if 2:\n\n" + "".join(f"{TAB * 2}" + i for i in format2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - + f"\n\n" + + "\n\n" + format3 ) assert formatter.get_formatted() == expected formatter = setup_formatter( - code1 + f"\n" - f"if 1:\n" - f" if 2:\n" + code1 + "\n" + "if 1:\n" + " if 2:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") - + f"\n" - f" # fmt: off[next]\n" + + "\n" + " # fmt: off[next]\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" - f"{code3}" + + "\n" + + code3 ) expected = ( format1 - + f"\n\n" - + f"if 1:\n" + + "\n\n" + + "if 1:\n" + f"{TAB * 1}if 2:\n\n" + "".join( f"{TAB * 2}" + i for i in format2.splitlines(keepends=True) ).rstrip("\n") - + f"\n" + + "\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) + "\n" - + f"\n\n" + + "\n\n" + format3 ) assert formatter.get_formatted() == expected formatter = setup_formatter( - code1 + f"\n" - f"if 1:\n" - f" if 2:\n" - f" # fmt: off[next]\n" + code1 + "\n" + "if 1:\n" + " if 2:\n" + " # fmt: off[next]\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") - + f"\n" + + "\n" f"{code3}" ) - expected1 = format1 + f"\n\n" f"if 1:\n" f"{TAB * 1}if 2:\n" + expected1 = format1 + "\n\n" "if 1:\n" "{TAB * 1}if 2:\n" expected2 = ( f"{TAB * 2}# fmt: off[next]\n" + "".join(f"{TAB * 2}" + i for i in code2.splitlines(keepends=True)).rstrip( @@ -2472,7 +2470,7 @@ def test_rule_if2_rule(self): + "\n" + "\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - + f"\n\n" + + "\n\n" + format3 ) formatted = formatter.get_formatted() @@ -2483,34 +2481,34 @@ def test_fmt_off_next_in_if(self): code2, format2 = TestSimpleParamFormatting.example_params_newline code3, format3 = TestSimpleParamFormatting.example_input_threads_newline formatter = setup_formatter( - code1 + f"\n# fmt: \n" - f"if 1:\n" + code1 + "\n# fmt: \n" + "if 1:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) + f"\n" f"{code3}" ) expected = ( format1 - + f"\n\n# fmt:\n" - + f"if 1:\n\n" + + "\n\n# fmt:\n" + + "if 1:\n\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - + f"\n" - + f"\n" + + "\n" + + "\n" + format3 ) assert formatter.get_formatted() == expected formatter = setup_formatter( - code1.rstrip("\n") + f"\n# fmt: off[next]\n" - f"if 1:\n" + code1.rstrip("\n") + "\n# fmt: off[next]\n" + "if 1:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" + code3 ) expected = ( - format1.rstrip("\n") + f"\n# fmt: off[next]\n" - f"if 1:\n" + format1.rstrip("\n") + "\n# fmt: off[next]\n" + "if 1:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n\n\n" + + "\n\n\n" + format3 ) assert formatter.get_formatted() == expected @@ -2520,39 +2518,38 @@ def test_fmt_off_next_in_2if(self): code2, format2 = TestSimpleParamFormatting.example_params_newline code3, format3 = TestSimpleParamFormatting.example_input_threads_newline formatter = setup_formatter( - code1.rstrip("\n") + f"\n" - f"if 1:\n" - f" \n# fmt:\n" + code1.rstrip("\n") + "\n" + "if 1:\n" + " \n# fmt:\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(" " + i for i in code3.splitlines(keepends=True)) ) expected = ( - format1.rstrip("\n") + f"\n" - f"\n\n" - f"if 1:\n\n" + format1.rstrip("\n") + "\n" + "\n\n" + "if 1:\n\n" f"{TAB * 1}# fmt:\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) ) assert formatter.get_formatted() == expected formatter = setup_formatter( - code1.rstrip("\n") + f"\n" - f"if 1:\n" - f"\n # fmt: off[next]\n" + code1.rstrip("\n") + "\n" + "if 1:\n" + "\n # fmt: off[next]\n" + "".join(" " + i for i in code2.splitlines(keepends=True)) - + f"\n" + + "\n" + "".join(" " + i for i in code3.splitlines(keepends=True)) ) - expected1 = format1.rstrip("\n") + f"\n" f"\n\n" f"if 1:\n" + expected1 = format1.rstrip("\n") + "\n" "\n\n" "if 1:\n" expected2 = ( f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( "\n" ) - + f"\n" - + f"\n" + + "\n\n" + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) ) formatted = formatter.get_formatted() From 6b96b3097000c02497928242ccfc080c4d6fbb8c Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 10:22:47 +0800 Subject: [PATCH 10/53] fix --- README.md | 52 ++++++++++++++++++--------------------- snakefmt/parser/parser.py | 4 +-- tests/test_formatter.py | 22 ++++++----------- 3 files changed, 33 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 03cc2eb..9900502 100644 --- a/README.md +++ b/README.md @@ -47,34 +47,30 @@ design and specifications of [Black][black]. [TOC]: # # Table of Contents -1. [Install](#install) - 1. [PyPi](#pypi) - 2. [Conda](#conda) - 3. [Containers](#containers) - 1. [Docker](#docker) - 2. [Singularity](#singularity) - 4. [Local](#local) -2. [Example File](#example-file) -3. [Usage](#usage) - 1. [Basic Usage](#basic-usage) - 2. [Full Usage](#full-usage) -4. [Configuration](#configuration) - 1. [Directive Sorting](#directive-sorting) - 2. [Format Directives](#format-directives) - 1. [`# fmt: off` / `# fmt: on`](#-fmt-off---fmt-on) - 2. [`# fmt: off[sort]`](#-fmt-offsort) - 3. [`# fmt: off[next]`](#-fmt-offnext) - 4. [Example](#example) -5. [Integration](#integration) - 1. [Editor Integration](#editor-integration) - 2. [Version Control Integration](#version-control-integration) - 3. [GitHub Actions](#github-actions) -6. [Plug Us](#plug-us) - 1. [Markdown](#markdown) - 2. [ReStructuredText](#restructuredtext) -7. [Changes](#changes) -8. [Contributing](#contributing) -9. [Cite](#cite) +- [Install](#install) + - [PyPi](#pypi) + - [Conda](#conda) + - [Containers](#containers) + - [Docker](#docker) + - [Singularity](#singularity) + - [Local](#local) +- [Example File](#example-file) +- [Usage](#usage) + - [Basic Usage](#basic-usage) + - [Full Usage](#full-usage) +- [Configuration](#configuration) + - [Directive Sorting](#directive-sorting) + - [Format Directives](#format-directives) +- [Integration](#integration) + - [Editor Integration](#editor-integration) + - [Version Control Integration](#version-control-integration) + - [GitHub Actions](#github-actions) +- [Plug Us](#plug-us) + - [Markdown](#markdown) + - [ReStructuredText](#restructuredtext) +- [Changes](#changes) +- [Contributing](#contributing) +- [Cite](#cite) ## Install diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 925cad2..cb39b95 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -706,9 +706,7 @@ def get_next_queriable(self) -> Status: ): # col-0 comments report cur_indent=0 to trigger context_exit; # fmt directives at other columns report actual cur_indent. - return Status( - token, block_indent, self.cur_indent, buffer, False, pythonable - ) + return Status(token, block_indent, 0, buffer, False, pythonable) # Comments arrive in the token stream *before* any following # INDENT/DEDENT tokens, so self.cur_indent still reflects the # previous (potentially higher) level. Delegate to diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 952854c..87aeb9a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1374,8 +1374,7 @@ def test_repeated_parameter_keyword_code_in_between_spacing(self): assert setup_formatter(snakecode).get_formatted() == snakecode def test_double_spacing_for_rules(self): - formatter = setup_formatter( - f"""above_rule = "2spaces" + formatter = setup_formatter(f"""above_rule = "2spaces" rule a: {TAB * 1}threads: 1 @@ -1384,8 +1383,7 @@ def test_double_spacing_for_rules(self): rule b: {TAB * 1}threads: 2 below_rule = "2spaces" -""" - ) +""") expected = f"""above_rule = "2spaces" @@ -1484,13 +1482,11 @@ def test_comment_inside_python_code_sticks_to_rule(self): assert setup_formatter(snakecode).get_formatted() == expected def test_comment_below_keyword_gets_spaced(self): - formatter = setup_formatter( - f"""# Rules + formatter = setup_formatter(f"""# Rules rule all: {TAB * 1}input: output_files # Comment -""" - ) +""") actual = formatter.get_formatted() expected = f"""# Rules @@ -1670,13 +1666,11 @@ def test_shell_indention_long_line(self): class TestStorage: def test_storage(self): - code = textwrap.dedent( - """ + code = textwrap.dedent(""" storage http_local: provider="http", keep_local=True, - """ - ) + """) formatter = setup_formatter(code) assert formatter.get_formatted() == code @@ -2415,7 +2409,7 @@ def test_rule_if2_rule(self): format1 + "\n\n" + "if 1:\n" - + "{TAB * 1}if 2:\n\n" + + f"{TAB * 1}if 2:\n\n" + "".join(f"{TAB * 2}" + i for i in format2.splitlines(keepends=True)) + "\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) @@ -2461,7 +2455,7 @@ def test_rule_if2_rule(self): + "\n" f"{code3}" ) - expected1 = format1 + "\n\n" "if 1:\n" "{TAB * 1}if 2:\n" + expected1 = format1 + "\n\n" "if 1:\n" f"{TAB * 1}if 2:\n" expected2 = ( f"{TAB * 2}# fmt: off[next]\n" + "".join(f"{TAB * 2}" + i for i in code2.splitlines(keepends=True)).rstrip( From 41bb18117d4add6a62ccc6fc9fdfc5ad168d9bfb Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 11:23:33 +0800 Subject: [PATCH 11/53] fix: prompt --- snakefmt/parser/parser.py | 59 +++++++++++++++++---------------------- tests/test_formatter.py | 44 +++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 34 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index cb39b95..6ffa948 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -195,37 +195,7 @@ def __init__(self, snakefile: Snakefile): elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: self.fmt_off = None - if self.fmt_off: - if self.vocab.recognises(keyword): - if self.keyword_indent < status.cur_indent and ( - self.syntax.from_python or status.pythonable - ): - self.from_python = True - self.flush_buffer( - from_python=self.from_python, - in_global_context=self.in_global_context, - ) - status = self._consume_fmt_off( - status.token, min_indent=status.cur_indent - ) - else: - self.flush_buffer( - from_python=True, - in_global_context=self.in_global_context, - ) - if self.keyword_indent > 0: - self.syntax.add_processed_keyword(status.token, keyword) - status = self._consume_fmt_off( - status.token, min_indent=status.cur_indent - ) - self.buffer = "" - if self.last_block_was_snakecode and not status.eof: - self.block_indent = status.block_indent - self.last_block_was_snakecode = False - if self.keyword_indent: - self.last_block_was_snakecode = True - self.buffer = status.buffer.lstrip() - elif self.vocab.recognises(keyword): + if self.vocab.recognises(keyword): new_vocab, new_syntax_cls = self.vocab.get(keyword) is_context_kw = new_vocab is not None and issubclass( new_syntax_cls, KeywordSyntax @@ -244,6 +214,23 @@ def __init__(self, snakefile: Snakefile): status = self.process_keyword(status, self.from_python) self.block_indent = status.cur_indent self.last_block_was_snakecode = True + elif self.fmt_off: + self.flush_buffer( + from_python=True, + in_global_context=self.in_global_context, + ) + if self.keyword_indent > 0: + self.syntax.add_processed_keyword(status.token, keyword) + status = self._consume_fmt_off( + status.token, min_indent=status.cur_indent + ) + self.buffer = "" + if self.last_block_was_snakecode and not status.eof: + self.block_indent = status.block_indent + self.last_block_was_snakecode = False + if self.keyword_indent: + self.last_block_was_snakecode = True + self.buffer = status.buffer.lstrip() else: if not self.syntax.accepts_python_code and not comment_start(keyword): raise SyntaxError( @@ -439,9 +426,13 @@ def _init_min_indent(token: Token): ): self.snakefile.denext(token) break - elif self._check_fmt_on(token) == "region": - self.fmt_off = None - lines.update(split_token_lines(token)) + elif fmt_on := self._check_fmt_on(token): + if fmt_on == "region": + self.fmt_off = None + lines.update(split_token_lines(token)) + elif fmt_on == "sort": + self.fmt_sort_off = None + self.snakefile.denext(token) break self.queriable = False diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 87aeb9a..b2468c3 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2345,6 +2345,50 @@ def test_fmt_off_sort(self): expected2 = expected + "\n\n# fmt: on\n" + formatted assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + def test_fmt_off_sort_between_directive(self): + code = ( + "rule all:\n" + f"{TAB}params: p=1\n" + f"{TAB}resources: mem_mb=100\n" + f"{TAB}threads: 4\n" + f"{TAB}conda: 'env.yaml'\n" + f"{TAB}message: 'finishing'\n" + f"{TAB}log: 'log.txt'\n" + f"{TAB}# fmt: off[sort]\n" + f"{TAB}output: 'out.txt'\n" + f"{TAB}# fmt: on[sort]\n" + f"{TAB}# Important input\n" + f"{TAB}input: 'in.txt'\n" + f"{TAB}name: 'myrule'\n" + f"{TAB}shell: 'echo done'\n" + ) + expected = ( + "rule all:\n" + f"{TAB}name:\n" + f'{TAB*2}"myrule"\n' + f"{TAB}# fmt: off[sort]\n" + f"{TAB}output:\n" + f'{TAB*2}"out.txt",\n' + f"{TAB}# fmt: on[sort]\n" + f"{TAB}# Important input\n" + f"{TAB}input:\n" + f'{TAB*2}"in.txt",\n' + f"{TAB}log:\n" + f'{TAB*2}"log.txt",\n' + f"{TAB}conda:\n" + f'{TAB*2}"env.yaml"\n' + f"{TAB}threads: 4\n" + f"{TAB}resources:\n" + f"{TAB*2}mem_mb=100,\n" + f"{TAB}params:\n" + f"{TAB*2}p=1,\n" + f"{TAB}message:\n" + f'{TAB*2}"finishing"\n' + f"{TAB}shell:\n" + f'{TAB*2}"echo done"\n' + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected + class TestFmtOffNext: def test_fmt_off_next(self): From afd273c8a99da3f207181177b1caf6c6fac3d4e8 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 15:08:06 +0800 Subject: [PATCH 12/53] fix: more tests --- snakefmt/formatter.py | 14 ++++++--- snakefmt/parser/parser.py | 64 ++++++++++++++++++++++++--------------- tests/test_formatter.py | 55 +++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 30 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index 44fab0d..e87da52 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -205,7 +205,10 @@ def post_process_keyword(self): ) self.result = self.previous_result + self.result self.previous_result = "" - if self.no_formatting_yet and self.result.rstrip("\n"): + # Keep no_formatting_yet when there is pending buffered content. + # This prevents premature separator insertion after fmt: off/on + # verbatim regions before the next flush occurs. + if self.no_formatting_yet and self.result.rstrip("\n") and not self.buffer: self.no_formatting_yet = False def handle_fmt_off_region(self, verbatim: str) -> None: @@ -226,14 +229,15 @@ def handle_fmt_off_region(self, verbatim: str) -> None: self.result += "\n" self.result += self.lagging_comments self.lagging_comments = "" + if self.fmt_off_preceded_by_blank_line: + if self.result and not self.result.endswith("\n\n"): + self.result += "\n" + self.fmt_off_preceded_by_blank_line = False self.result += verbatim # For fmt: off[next], mark that we've emitted content so the following # block gets its normal blank-line separator. # For fmt: off regions, treat verbatim as transparent to separator logic. - if is_nested_next: - self.no_formatting_yet = bool(self.lagging_comments) - else: - self.no_formatting_yet = True + self.no_formatting_yet = not is_nested_next self.last_recognised_keyword = "" def run_black_format_str( diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 6ffa948..f2e2929 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -158,6 +158,7 @@ def __init__(self, snakefile: Snakefile): # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None self.fmt_off_expected_index: str = "" + self.fmt_off_preceded_by_blank_line: bool = False self.indents: list[str] = [""] @@ -189,11 +190,14 @@ def __init__(self, snakefile: Snakefile): ] elif "sort" in fmt_label.modifiers: self.fmt_sort_off = status.cur_indent - elif self._check_fmt_on(status.token) == "sort": - self.fmt_sort_off = None + elif self._check_fmt_on(fmt_label, status.token) == "sort": continue elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: self.fmt_off = None + elif ( + self.fmt_sort_off is not None and status.cur_indent < self.fmt_sort_off + ): + self.fmt_sort_off = None if self.vocab.recognises(keyword): new_vocab, new_syntax_cls = self.vocab.get(keyword) @@ -426,13 +430,17 @@ def _init_min_indent(token: Token): ): self.snakefile.denext(token) break - elif fmt_on := self._check_fmt_on(token): + elif self.in_global_context: + # In global Python context, plain `# fmt: off` starts a parser + # verbatim region. In non-global Python contexts (e.g. run:), it + # stays inside Python and is handled by Black. + last_line = lines[max(lines)] if lines else "" + self.fmt_off_preceded_by_blank_line = not last_line.strip() + self.snakefile.denext(token) + break + elif fmt_on := self._check_fmt_on(fmt_label, token): if fmt_on == "region": - self.fmt_off = None lines.update(split_token_lines(token)) - elif fmt_on == "sort": - self.fmt_sort_off = None - self.snakefile.denext(token) break self.queriable = False @@ -625,24 +633,21 @@ def _determe_comment_indent(self, token: Token) -> int: # highest indent level fitting within the comment's column. return max(check_indent(token.line, self.indents), follow_indent) - def _check_fmt_on(self, token: Token): + def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): """Return True if token ends the current fmt:off region.""" - if not (fmt_dir := FMT_DIRECTIVE.from_token(token)) or fmt_dir.disable: - return if self.fmt_off: # `# fmt: on[sort]` no effect - if "sort" in fmt_dir.modifiers: - return - token_indent = self._determe_comment_indent(token) - if token_indent == self.fmt_off[0]: - return "region" - return - if self.fmt_sort_off is not None: - if "sort" not in (fmt_dir.modifiers or ["sort"]): - return - token_indent = self._determe_comment_indent(token) - if token_indent == self.fmt_sort_off: - return "sort" + if "sort" not in fmt_label.modifiers: + token_indent = self._determe_comment_indent(token) + if token_indent == self.fmt_off[0]: + self.fmt_off = None + return "region" + elif self.fmt_sort_off is not None: + if "sort" in (fmt_label.modifiers or ["sort"]): + token_indent = self._determe_comment_indent(token) + if token_indent == self.fmt_sort_off: + self.fmt_sort_off = None + return "sort" def _handle_indent(self, token: Token) -> bool: if token.type == tokenize.INDENT: @@ -709,9 +714,18 @@ def get_next_queriable(self) -> Status: return Status( token, block_indent, effective_indent, buffer, False, pythonable ) - # A `# fmt: off[next]` directive at any indent always triggers verbatim - # mode for the next snakemake block, return it so the main loop can act. - if fmt_dir and fmt_dir.disable and "next" in (fmt_dir.modifiers or []): + # `# fmt: off[next]` always needs parser-level handling. + # Plain `# fmt: off` is parser-level only in global context; in other + # Python contexts it is handled by Black. + if ( + fmt_dir + and fmt_dir.disable + and ( + "next" in fmt_dir.modifiers + or "sort" in fmt_dir.modifiers + or (not fmt_dir.modifiers and self.in_global_context) + ) + ): return Status( token, block_indent, effective_indent, buffer, False, pythonable ) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index b2468c3..e6342b8 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2323,6 +2323,29 @@ def test_fmt_off_on_in_other(self): ) assert formatter.get_formatted() == expected + def test_fmt_off_lagging_comments(self): + expected = ( + "if 1:\n" + " lagging_comments\n" + "\n" + " # fmt: off\n" + " rule a:\n" + ' input: "sth"\n' + ' name: "sth"\n' + " # fmt: on\n" + ) + assert setup_formatter(expected).get_formatted() == expected + expected = ( + "if 1:\n" + " # lagging_comments\n" + " # fmt: off\n" + " rule a:\n" + ' input: "sth"\n' + ' name: "sth"\n' + " # fmt: on\n" + ) + assert setup_formatter(expected).get_formatted() == expected + class TestFmtOffSort: def test_fmt_off_sort(self): @@ -2345,6 +2368,38 @@ def test_fmt_off_sort(self): expected2 = expected + "\n\n# fmt: on\n" + formatted assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + def test_fmt_off_sort_dedent(self): + code1, formatted1 = TestSortFormatting.sorting_comprehensive + code2, formatted2 = TestSortFormatting.sort_with_coments + formatted2 = setup_formatter(code2).get_formatted() + code3, formatted3 = TestSortFormatting.sort_inline_comments + code = ( + code1.rstrip() + "\n" + "\n" + "if 1:\n" + " # fmt: off[sort]\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip() + + "\n" + "\n\n" + code3 + ) + expected = ( + formatted1 + "\n\n" + "if 1:\n" + "\n" + f"{TAB}# fmt: off[sort]\n" + + "".join(TAB + i for i in formatted2.splitlines(keepends=True)) + + "\n\n" + + formatted3 + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected + + def test_fmt_off_sort_nothing(self): + code1, formatted1 = TestSortFormatting.sorting_comprehensive + code3, formatted3 = TestSortFormatting.sort_inline_comments + code = code1.rstrip() + "\n" "\n" "if 1:\n" " pass\n" "\n\n" + code3 + expected = formatted1 + "\n\n" "if 1:\n" f"{TAB}pass\n" "\n\n" + formatted3 + assert setup_formatter(code, sort_params=True).get_formatted() == expected + def test_fmt_off_sort_between_directive(self): code = ( "rule all:\n" From 5814007c58ca893822a16d7fbca30aad70b7498d Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 15:18:15 +0800 Subject: [PATCH 13/53] typo --- snakefmt/parser/parser.py | 16 ++++++++-------- tests/test_formatter.py | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index f2e2929..da83250 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -422,11 +422,11 @@ def _init_min_indent(token: Token): elif fmt_label := FMT_DIRECTIVE.from_token(token): if fmt_label.disable: if fmt_label.modifiers: - # `# fmt: off[` is not actual format diabler, it affects limited + # `# fmt: off[` is not actual format disabler, it affects limited if not self.fmt_off or ( # two following [next] self.fmt_off[1] != "region" - and self._determe_comment_indent(token) == self.fmt_off[0] + and self._determine_comment_indent(token) == self.fmt_off[0] ): self.snakefile.denext(token) break @@ -554,7 +554,7 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: self.syntax.add_processed_keyword(status.token, status.token.string) cur_indent = param_context.cur_indent if param_context.token.type == tokenize.COMMENT and not param_context.eof: - cur_indent = self._determe_comment_indent(param_context.token) + cur_indent = self._determine_comment_indent(param_context.token) return Status( param_context.token, cur_indent, @@ -589,7 +589,7 @@ def context_exit(self, status: Status) -> None: while len(self.indents) - 1 > status.cur_indent: self.indents.pop() - def _determe_comment_indent(self, token: Token) -> int: + def _determine_comment_indent(self, token: Token) -> int: """ Treat each line of single-line comment separately, it is determined by the following real code line and previous self.indents. @@ -638,13 +638,13 @@ def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): if self.fmt_off: # `# fmt: on[sort]` no effect if "sort" not in fmt_label.modifiers: - token_indent = self._determe_comment_indent(token) + token_indent = self._determine_comment_indent(token) if token_indent == self.fmt_off[0]: self.fmt_off = None return "region" elif self.fmt_sort_off is not None: if "sort" in (fmt_label.modifiers or ["sort"]): - token_indent = self._determe_comment_indent(token) + token_indent = self._determine_comment_indent(token) if token_indent == self.fmt_sort_off: self.fmt_sort_off = None return "sort" @@ -706,9 +706,9 @@ def get_next_queriable(self) -> Status: # Comments arrive in the token stream *before* any following # INDENT/DEDENT tokens, so self.cur_indent still reflects the # previous (potentially higher) level. Delegate to - # _determe_comment_indent which peeks ahead and applies the + # _determine_comment_indent which peeks ahead and applies the # two snapping rules. - effective_indent = self._determe_comment_indent(token) + effective_indent = self._determine_comment_indent(token) self.syntax.cur_indent = effective_indent if effective_indent < max(self.keyword_indent, self.block_indent): return Status( diff --git a/tests/test_formatter.py b/tests/test_formatter.py index e6342b8..1e45477 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1825,7 +1825,7 @@ def test_sorting_comprehensive(self): formatter = setup_formatter(self.sorting_comprehensive[0], sort_params=True) assert formatter.get_formatted() == self.sorting_comprehensive[1] - sort_with_coments = ( + sort_with_comments = ( "rule complex:\n" f"{TAB}# Action comment\n" f"{TAB}shell: 'do something'\n" @@ -1847,8 +1847,8 @@ def test_sorting_comprehensive(self): def test_sorting_with_comments_preservation(self): """Comments stay with their keywords""" - formatter = setup_formatter(self.sort_with_coments[0], sort_params=True) - assert formatter.get_formatted() == self.sort_with_coments[1] + formatter = setup_formatter(self.sort_with_comments[0], sort_params=True) + assert formatter.get_formatted() == self.sort_with_comments[1] sort_inline_comments = ( "rule inline_comments:\n" @@ -2351,7 +2351,7 @@ class TestFmtOffSort: def test_fmt_off_sort(self): for code, formatted in ( TestSortFormatting.sorting_comprehensive, - TestSortFormatting.sort_with_coments, + TestSortFormatting.sort_with_comments, TestSortFormatting.sort_inline_comments, TestSortFormatting.sort_module, ): @@ -2370,7 +2370,7 @@ def test_fmt_off_sort(self): def test_fmt_off_sort_dedent(self): code1, formatted1 = TestSortFormatting.sorting_comprehensive - code2, formatted2 = TestSortFormatting.sort_with_coments + code2, formatted2 = TestSortFormatting.sort_with_comments formatted2 = setup_formatter(code2).get_formatted() code3, formatted3 = TestSortFormatting.sort_inline_comments code = ( @@ -2649,7 +2649,7 @@ def test_fmt_off_next_in_2if(self): assert formatted.startswith(expected1) and formatted.endswith(expected2) def test_fmt_off_2(self): - fomatter = setup_formatter( + formatter = setup_formatter( "if 1:\n" " rule a:\n" ' input: "foo"\n' @@ -2663,7 +2663,7 @@ def test_fmt_off_2(self): "rule d:\n" ' input: "qux"\n' ) - assert fomatter.get_formatted() == ( + assert formatter.get_formatted() == ( f"if 1:\n" f"\n" f"{TAB}rule a:\n" From b02aaabedb589bb6820af9a81da513903ac470a8 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 15:35:07 +0800 Subject: [PATCH 14/53] style: flake8 --- snakefmt/parser/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index da83250..2e0f4b2 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -422,7 +422,7 @@ def _init_min_indent(token: Token): elif fmt_label := FMT_DIRECTIVE.from_token(token): if fmt_label.disable: if fmt_label.modifiers: - # `# fmt: off[` is not actual format disabler, it affects limited + # `# fmt: off[` isn't actual format disabler, affects limited if not self.fmt_off or ( # two following [next] self.fmt_off[1] != "region" From a418503fa35a08855d8eddc368ef08c83b2514f2 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 18:46:23 +0800 Subject: [PATCH 15/53] fix: docs --- snakefmt/formatter.py | 12 ++- snakefmt/parser/parser.py | 201 ++++++++++++++++++++++---------------- 2 files changed, 123 insertions(+), 90 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index e87da52..fa807bc 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -65,7 +65,9 @@ def __init__( self.result: str = "" self.lagging_comments: str = "" self.no_formatting_yet: bool = True - self.fmt_sort_off = None if sort_directives else -1 + # sorting can be initially disabled, + # but will be enabled in contexts with `# fmt: on[sort]` + self.sort_off_indent = None if sort_directives else -1 self.previous_result: str = "" self.keyword_spec: list[str] = [] self.keywords: dict[str, str] = {} # cache to sort @@ -106,9 +108,9 @@ def flush_buffer( else: # Invalid python syntax, eg lone 'else:' between two rules, can occur. # Below constructs valid code statements and formats them. - if self.fmt_off_expected_index: - self.buffer += self.fmt_off_expected_index - self.fmt_off_expected_index = "" + if self.fmt_off_expected_indent: + self.buffer += self.fmt_off_expected_indent + self.fmt_off_expected_indent = "" re_match = contextual_matcher.match(self.buffer) if re_match is not None: callback_keyword = re_match.group(2) @@ -182,7 +184,7 @@ def process_keyword_param( context=param_context, ) param_formatted = self.format_params(param_context) - if self.fmt_sort_off is None and not in_global_context and self.keyword_spec: + if self.sort_off_indent is None and not in_global_context and self.keyword_spec: self.keywords[param_context.keyword_name] = self.result + param_formatted self.result = "" else: diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 2e0f4b2..b62de49 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -82,6 +82,24 @@ def check_indent(line: str, indents: list[str]) -> int: raise SyntaxError("Unexpected indent") +def token_indents_updated(token: Token, indents: list[str]) -> bool: + if token.type == tokenize.INDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + if indent not in indents: + indents.append(indent) + elif token.type == tokenize.DEDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + while indents and indents[-1] != indent: + indents.pop() + if not indents: + raise SyntaxError("Unexpected dedent") + else: + return False + return True + + class Snakefile(TokenIterator): """ Adapted from snakemake.parser.Snakefile @@ -153,11 +171,13 @@ def __init__(self, snakefile: Snakefile): self.queriable = True self.in_fstring = False self.last_token: Optional[Token] = None - self.fmt_sort_off: Optional[int] + # None: sorting enabled (no active off[sort]). + # >=0 : disabled at that indent level and below due to active off[sort] + self.sort_off_indent: Optional[int] # for `# fmt: off`, (indent, kind) # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None - self.fmt_off_expected_index: str = "" + self.fmt_off_expected_indent: str = "" self.fmt_off_preceded_by_blank_line: bool = False self.indents: list[str] = [""] @@ -180,24 +200,25 @@ def __init__(self, snakefile: Snakefile): if fmt_label.disable: if not fmt_label.modifiers: self.fmt_off = (status.cur_indent, "region") - self.fmt_off_expected_index = status.token.line[ + self.fmt_off_expected_indent = status.token.line[ : col_nb(status.token) ] elif "next" in fmt_label.modifiers: self.fmt_off = (status.cur_indent, "next") - self.fmt_off_expected_index = status.token.line[ + self.fmt_off_expected_indent = status.token.line[ : col_nb(status.token) ] elif "sort" in fmt_label.modifiers: - self.fmt_sort_off = status.cur_indent + self.sort_off_indent = status.cur_indent elif self._check_fmt_on(fmt_label, status.token) == "sort": continue elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: self.fmt_off = None elif ( - self.fmt_sort_off is not None and status.cur_indent < self.fmt_sort_off + self.sort_off_indent is not None + and status.cur_indent < self.sort_off_indent ): - self.fmt_sort_off = None + self.sort_off_indent = None if self.vocab.recognises(keyword): new_vocab, new_syntax_cls = self.vocab.get(keyword) @@ -368,7 +389,7 @@ def _init_min_indent(token: Token): self.snakefile.denext(token) break elif token.type == tokenize.INDENT: - self._handle_indent(token) + token_indents_updated(token, self.indents) self.syntax.cur_indent = len(self.indents) - 1 last_indent_token = token if is_next_mode and len(self.indents) - 1 > min_indent: @@ -376,7 +397,7 @@ def _init_min_indent(token: Token): continue elif token.type == tokenize.DEDENT: saved_indents = list(self.indents) - self._handle_indent(token) + token_indents_updated(token, self.indents) new_indent = len(self.indents) - 1 last_indent_token = None if new_indent < min_indent or ( @@ -400,48 +421,18 @@ def _init_min_indent(token: Token): ): if is_next_mode: if seen_next_block_keyword: - # fmt: off[next] consumed one whole keyword block; - # hand the next same-level block back to main loop. - self.snakefile.denext(token) - if last_indent_token is not None: - self.snakefile.denext(last_indent_token) - self.indents.pop() - self.syntax.cur_indent = len(self.indents) - 1 + # fmt: off[next] consumed one whole keyword block. + self._detent_last_indent(token, last_indent_token) break else: seen_next_block_keyword = True if vocab_recognises: # snakemake keyword: stop, let main loop handle it - self.snakefile.denext(token) - if last_indent_token is not None: - self.snakefile.denext(last_indent_token) - self.indents.pop() - self.syntax.cur_indent = len(self.indents) - 1 + self._detent_last_indent(token, last_indent_token) break # `# fmt: off[next]` within Python code: stop and let main loop handle it. - elif fmt_label := FMT_DIRECTIVE.from_token(token): - if fmt_label.disable: - if fmt_label.modifiers: - # `# fmt: off[` isn't actual format disabler, affects limited - if not self.fmt_off or ( - # two following [next] - self.fmt_off[1] != "region" - and self._determine_comment_indent(token) == self.fmt_off[0] - ): - self.snakefile.denext(token) - break - elif self.in_global_context: - # In global Python context, plain `# fmt: off` starts a parser - # verbatim region. In non-global Python contexts (e.g. run:), it - # stays inside Python and is handled by Black. - last_line = lines[max(lines)] if lines else "" - self.fmt_off_preceded_by_blank_line = not last_line.strip() - self.snakefile.denext(token) - break - elif fmt_on := self._check_fmt_on(fmt_label, token): - if fmt_on == "region": - lines.update(split_token_lines(token)) - break + elif self._comsume_fmt_off_in_python(token, lines): + break self.queriable = False lines.update(split_token_lines(token)) @@ -464,6 +455,54 @@ def _init_min_indent(token: Token): pythonable=next_status.pythonable or bool(verbatim.strip()) ) + def _detent_last_indent(self, token: Token, last_indent_token: Optional[Token]): + """ + A whole keyword block consumed, + hand the next same-level block back to main loop. + """ + self.snakefile.denext(token) + if last_indent_token is not None: + self.snakefile.denext(last_indent_token) + self.indents.pop() + self.syntax.cur_indent = len(self.indents) - 1 + + def _comsume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): + """ + Consume `# fmt: off/on` directives within Python code. + lines is needed to: + 1. determine the effective indent of the comment token + (when fmt: off in global context, or fmt: off[next] in any context) + 2. record the lines of a fmt: off region (when fmt: on[region]) + Returns True if a fmt directive was consumed, + which should be handled by the main loop (and break there) + """ + fmt_label = FMT_DIRECTIVE.from_token(token) + if not fmt_label: + return False + if fmt_label.disable: + if fmt_label.modifiers: + # `# fmt: off[` isn't actual format disabler, affects limited + if not self.fmt_off or ( + # two following [next] + self.fmt_off[1] != "region" + and self._determine_comment_indent(token) == self.fmt_off[0] + ): + self.snakefile.denext(token) + return True + elif self.in_global_context: + # In global Python context, plain `# fmt: off` starts a parser + # verbatim region. In non-global Python contexts (e.g. run:), it + # stays inside Python and is handled by Black. + last_line = lines[max(lines)] if lines else "" + self.fmt_off_preceded_by_blank_line = not last_line.strip() + self.snakefile.denext(token) + return True + elif fmt_on := self._check_fmt_on(fmt_label, token): + if fmt_on == "region": + lines.update(split_token_lines(token)) + return True + return False + @abstractmethod def handle_fmt_off_region(self, verbatim: str) -> None: """handle unformatted text (just update indent).""" @@ -591,17 +630,23 @@ def context_exit(self, status: Status) -> None: def _determine_comment_indent(self, token: Token) -> int: """ - Treat each line of single-line comment separately, - it is determined by the following real code line and previous self.indents. - - follow_indent = indent of the following real code line - if EOF: - follow_indent = 0 - rule 1 (always): - indent of comments >= follow_indent - rule 2 (if follow_indent < self.indents[-1]): - indent of comments = max(i for i in self.indents - if i <= comment_indent) + epsilon. + This function returns the real indent level of a comment token and + update self.indents if needed, + which is determined by the following real code line and previous indents. + + Durning parsing self.snakefile, when a comment token is encountered, + its effective indent level is not directly knowable. + + principles: + follow_indent = indent of the following real code line + if EOF: + follow_indent = 0 + rule 1 (always): + indent of comments >= follow_indent + rule 2 (if follow_indent < self.indents[-1]): + indent of comments = epsilon + max( + i for i in self.indents if i <= comment_indent + ) next(self.snakefile) until follow_indent is determined, then put all peeked tokens back. @@ -612,20 +657,23 @@ def _determine_comment_indent(self, token: Token) -> int: follow_indent = len(self.indents) - 1 try: while True: - t = next(self.snakefile) - peeked.append(t) - if self._handle_indent(t): + token = next(self.snakefile) + peeked.append(token) + if token_indents_updated(token, self.indents): pass - elif t.type not in {tokenize.NEWLINE, tokenize.NL, tokenize.COMMENT}: - follow_indent = check_indent(t.line, self.indents) + elif token.type not in { + tokenize.NEWLINE, + tokenize.NL, + tokenize.COMMENT, + }: + follow_indent = check_indent(token.line, self.indents) break except StopIteration: follow_indent = 0 # restore indent stack and token stream unchanged self.indents = saved_indents - for t in reversed(peeked): - self.snakefile.denext(t) - + for token in reversed(peeked): + self.snakefile.denext(token) # Rule 1 (always): comment must not be indented below following code. if len(self.indents) - 1 <= follow_indent: return follow_indent @@ -634,7 +682,7 @@ def _determine_comment_indent(self, token: Token) -> int: return max(check_indent(token.line, self.indents), follow_indent) def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): - """Return True if token ends the current fmt:off region.""" + """Determine which fmt: on can turn on formatting""" if self.fmt_off: # `# fmt: on[sort]` no effect if "sort" not in fmt_label.modifiers: @@ -642,30 +690,13 @@ def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): if token_indent == self.fmt_off[0]: self.fmt_off = None return "region" - elif self.fmt_sort_off is not None: + elif self.sort_off_indent is not None: if "sort" in (fmt_label.modifiers or ["sort"]): token_indent = self._determine_comment_indent(token) - if token_indent == self.fmt_sort_off: - self.fmt_sort_off = None + if token_indent == self.sort_off_indent: + self.sort_off_indent = None return "sort" - def _handle_indent(self, token: Token) -> bool: - if token.type == tokenize.INDENT: - line = token.line - indent = line[: len(line) - len(line.lstrip())] - if indent not in self.indents: - self.indents.append(indent) - elif token.type == tokenize.DEDENT: - line = token.line - indent = line[: len(line) - len(line.lstrip())] - while self.indents and self.indents[-1] != indent: - self.indents.pop() - if not self.indents: - raise SyntaxError("Unexpected dedent") - else: - return False - return True - def get_next_queriable(self) -> Status: """Produces the next word that could be a snakemake keyword, and additional information in a :Status: @@ -684,7 +715,7 @@ def get_next_queriable(self) -> Status: self.in_fstring = fstring_processing(token, prev_token, self.in_fstring) if block_indent == -1 and not_a_comment_related_token(token): block_indent = self.cur_indent - if self._handle_indent(token): + if token_indents_updated(token, self.indents): prev_token = None newline = True self.syntax.cur_indent = len(self.indents) - 1 From 2ed3f9dc93a80a9deab92eef476d2c795df896af Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 27 Mar 2026 21:29:56 +0800 Subject: [PATCH 16/53] fix: typo --- snakefmt/formatter.py | 5 +---- snakefmt/parser/parser.py | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index fa807bc..f239870 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -65,9 +65,6 @@ def __init__( self.result: str = "" self.lagging_comments: str = "" self.no_formatting_yet: bool = True - # sorting can be initially disabled, - # but will be enabled in contexts with `# fmt: on[sort]` - self.sort_off_indent = None if sort_directives else -1 self.previous_result: str = "" self.keyword_spec: list[str] = [] self.keywords: dict[str, str] = {} # cache to sort @@ -77,7 +74,7 @@ def __init__( if line_length is not None: self.black_mode.line_length = line_length - super().__init__(snakefile) # Call to parse snakefile + super().__init__(snakefile, sort_directives=sort_directives) def get_formatted(self) -> str: return self.result diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index b62de49..3363466 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -158,7 +158,7 @@ class Parser(ABC): and the alternation in `:self.last_block_was_snakecode`. """ - def __init__(self, snakefile: Snakefile): + def __init__(self, snakefile: Snakefile, sort_directives=False): self.context = Context( SnakeGlobal(), KeywordSyntax("Global", keyword_indent=0, accepts_py=True) ) @@ -171,14 +171,16 @@ def __init__(self, snakefile: Snakefile): self.queriable = True self.in_fstring = False self.last_token: Optional[Token] = None - # None: sorting enabled (no active off[sort]). - # >=0 : disabled at that indent level and below due to active off[sort] - self.sort_off_indent: Optional[int] # for `# fmt: off`, (indent, kind) # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None self.fmt_off_expected_indent: str = "" self.fmt_off_preceded_by_blank_line: bool = False + # None: sorting enabled (no active off[sort]). + # >=0 : disabled at that indent level and below due to active off[sort] + # sorting can be initially disabled (-1), + # but will be enabled in contexts with `# fmt: on[sort]` + self.sort_off_indent = None if sort_directives else -1 self.indents: list[str] = [""] @@ -431,7 +433,7 @@ def _init_min_indent(token: Token): self._detent_last_indent(token, last_indent_token) break # `# fmt: off[next]` within Python code: stop and let main loop handle it. - elif self._comsume_fmt_off_in_python(token, lines): + elif self._consume_fmt_off_in_python(token, lines): break self.queriable = False @@ -466,7 +468,7 @@ def _detent_last_indent(self, token: Token, last_indent_token: Optional[Token]): self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 - def _comsume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): + def _consume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): """ Consume `# fmt: off/on` directives within Python code. lines is needed to: @@ -628,7 +630,7 @@ def context_exit(self, status: Status) -> None: while len(self.indents) - 1 > status.cur_indent: self.indents.pop() - def _determine_comment_indent(self, token: Token) -> int: + def _determine_comment_indent(self, comment_token: Token) -> int: """ This function returns the real indent level of a comment token and update self.indents if needed, @@ -679,7 +681,7 @@ def _determine_comment_indent(self, token: Token) -> int: return follow_indent # Rule 2 (dedent is happening, standalone only): snap comment to the # highest indent level fitting within the comment's column. - return max(check_indent(token.line, self.indents), follow_indent) + return max(check_indent(comment_token.line, self.indents), follow_indent) def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): """Determine which fmt: on can turn on formatting""" @@ -691,7 +693,7 @@ def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): self.fmt_off = None return "region" elif self.sort_off_indent is not None: - if "sort" in (fmt_label.modifiers or ["sort"]): + if not fmt_label.modifiers or "sort" in fmt_label.modifiers: token_indent = self._determine_comment_indent(token) if token_indent == self.sort_off_indent: self.sort_off_indent = None @@ -729,7 +731,7 @@ def get_next_queriable(self) -> Status: if ( fmt_dir and col_nb(token) == 0 - and not (fmt_dir.disable and "next" in (fmt_dir.modifiers or [])) + and not (fmt_dir.disable and "next" in fmt_dir.modifiers) ): # col-0 comments report cur_indent=0 to trigger context_exit; # fmt directives at other columns report actual cur_indent. From 41ea8c306fdbcf92b991235fbea8b2872be14926 Mon Sep 17 00:00:00 2001 From: hwrn Date: Sat, 28 Mar 2026 16:58:03 +0800 Subject: [PATCH 17/53] fix: improve handling of fmt: off directives and update related tests --- README.md | 12 +++--- snakefmt/formatter.py | 45 +++++++++++++++++----- snakefmt/parser/parser.py | 16 +++----- tests/test_formatter.py | 78 +++++++++++++++++++++++++-------------- 4 files changed, 98 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 9900502..bafa4d2 100644 --- a/README.md +++ b/README.md @@ -51,8 +51,6 @@ design and specifications of [Black][black]. - [PyPi](#pypi) - [Conda](#conda) - [Containers](#containers) - - [Docker](#docker) - - [Singularity](#singularity) - [Local](#local) - [Example File](#example-file) - [Usage](#usage) @@ -66,8 +64,6 @@ design and specifications of [Black][black]. - [Version Control Integration](#version-control-integration) - [GitHub Actions](#github-actions) - [Plug Us](#plug-us) - - [Markdown](#markdown) - - [ReStructuredText](#restructuredtext) - [Changes](#changes) - [Contributing](#contributing) - [Cite](#cite) @@ -321,6 +317,10 @@ You can disable this feature using the `--no-sort` flag. ### Format Directives `snakefmt` supports inline comment directives to control formatting behaviour for specific regions of code. +Format directives are scope-local. +The design principle is: +- Only the region selected by `# fmt: off`/`# fmt: on` (or the single block selected by `# fmt: off[next]`) is left untouched. +- Code before and after that region follows normal `snakefmt` formatting and spacing behavior, equivalent to replacing the directive with a regular comment line. #### `# fmt: off` / `# fmt: on` @@ -490,13 +490,13 @@ in your project. [![Code style: snakefmt](https://img.shields.io/badge/code%20style-snakefmt-000000.svg)](https://github.com/snakemake/snakefmt) -#### Markdown +### Markdown ```md [![Code style: snakefmt](https://img.shields.io/badge/code%20style-snakefmt-000000.svg)](https://github.com/snakemake/snakefmt) ``` -#### ReStructuredText +### ReStructuredText ```rst .. image:: https://img.shields.io/badge/code%20style-snakefmt-000000.svg diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index f239870..188133a 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -210,16 +210,40 @@ def post_process_keyword(self): if self.no_formatting_yet and self.result.rstrip("\n") and not self.buffer: self.no_formatting_yet = False - def handle_fmt_off_region(self, verbatim: str) -> None: + def flush_fmt_off_region(self, verbatim: str): + """Blank-line rules: + + applied before the verbatim block: + - At global indent (fmt_off[0] == 0) and result not empty: + result should end with exactly 2 blank lines (``\\n\\n\\n``) + (standard separation between top-level constructs). + - When the preceding Python code had a blank line before ``# fmt: off`` + (``fmt_off_preceded_by_blank_line``): + result should end with >= 1 blank line. + - ``# fmt: off[next]`` nested inside a Python block: + another ``\\n`` is prepended to any lagging comment + so the following keyword gets its normal blank-line separator. + + applied after the verbatim block: + - ``# fmt: off[next]``: sets ``no_formatting_yet := False``, + so the next formatted block gets its normal blank-line separator. + - Plain ``# fmt: off`` regions: sets ``no_formatting_yet := True``, + suppressing blank-line insertion in the next ``add_newlines`` call. + """ + if self.no_formatting_yet: self.result = self.result.lstrip("\n") self.result += self.buffer self.buffer = "" - if not verbatim: - return - # When fmt:off[next] is inside a Python block (e.g. `if 1:`), the - # directive ends up as a lagging_comment after flushing that block. - is_nested_next = self.fmt_off and self.fmt_off[1] == "next" + if self.fmt_off: + if self.fmt_off[0] == 0 and not self.no_formatting_yet: + if self.fmt_off and not self.result.endswith("\n\n\n"): + self.result += "\n\n" + # When fmt:off[next] is inside a Python block (e.g. `if 1:`), the + # directive ends up as a lagging_comment after flushing that block. + is_nested_next = self.fmt_off[1] == "next" + else: + is_nested_next = False if self.lagging_comments: # For nested fmt:off[next], add the same \n separator that # process_keyword_context/add_newlines would normally provide @@ -228,15 +252,12 @@ def handle_fmt_off_region(self, verbatim: str) -> None: self.result += "\n" self.result += self.lagging_comments self.lagging_comments = "" + self.no_formatting_yet = not is_nested_next if self.fmt_off_preceded_by_blank_line: if self.result and not self.result.endswith("\n\n"): self.result += "\n" self.fmt_off_preceded_by_blank_line = False self.result += verbatim - # For fmt: off[next], mark that we've emitted content so the following - # block gets its normal blank-line separator. - # For fmt: off regions, treat verbatim as transparent to separator logic. - self.no_formatting_yet = not is_nested_next self.last_recognised_keyword = "" def run_black_format_str( @@ -515,6 +536,10 @@ def add_newlines( if comment_matches > 0: self.lagging_comments = "\n".join(all_lines[comment_break:]) + "\n" if final_flush: + # Preserve one intentional blank line before trailing + # comments at EOF (e.g. indented # fmt-like comments). + if comment_break > 0 and all_lines[comment_break - 1] == "": + self.result += "\n" self.result += self.lagging_comments else: self.result += formatted_string diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 3363466..866661f 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -251,12 +251,9 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): status = self._consume_fmt_off( status.token, min_indent=status.cur_indent ) - self.buffer = "" if self.last_block_was_snakecode and not status.eof: self.block_indent = status.block_indent - self.last_block_was_snakecode = False - if self.keyword_indent: - self.last_block_was_snakecode = True + self.last_block_was_snakecode = self.keyword_indent > 0 self.buffer = status.buffer.lstrip() else: if not self.syntax.accepts_python_code and not comment_start(keyword): @@ -506,14 +503,14 @@ def _consume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): return False @abstractmethod - def handle_fmt_off_region(self, verbatim: str) -> None: - """handle unformatted text (just update indent).""" + def flush_fmt_off_region(self, verbatim: str) -> None: + """Flush unformatted region introduced by a fmt: off directive into result""" def _consume_fmt_off(self, start_token: Token, min_indent: int): verbatim, next_status = self._consume_python( start_token, vocab_recognises=False, added_indent=TAB * min_indent ) - self.handle_fmt_off_region(verbatim) + self.flush_fmt_off_region(verbatim) self.snakefile.denext(next_status.token) self.queriable = True if self.fmt_off and self.fmt_off[1] == "next": @@ -554,7 +551,7 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: new_vocab, new_syntax = self.vocab.get(keyword) if new_vocab is not None and issubclass(new_syntax, KeywordSyntax): in_global_context = self.in_global_context - saved_context = self.context + saved_context: Context = self.context # 'use' keyword can not enter a new context self.context = Context( new_vocab(), @@ -576,8 +573,7 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: self.queriable = True self.block_indent = self.syntax.keyword_indent + 1 status = self.get_next_queriable() - # lstrip forces the formatter deal with newlines - if self.context.syntax.accepts_python_code: # type: ignore + if self.context.syntax.accepts_python_code: self.buffer += status.buffer.lstrip("\n\r") else: self.buffer += status.buffer.lstrip() diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 1e45477..f92b61f 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2115,8 +2115,11 @@ def test_fmt_off_at_middle(self): TestSimpleParamFormatting.example_params_newline, TestSimpleParamFormatting.example_input_threads_newline, ): + code1 = code + "\n\n\n# fmtoff\n" + code + expected = formatted.strip() + "\n\n\n# fmtoff\n" + formatted + assert setup_formatter(code1).get_formatted() == expected code1 = code + "\n\n\n# fmt: off\n" + code - expected = formatted.strip() + "\n# fmt: off\n" + code + expected = formatted.strip() + "\n\n\n# fmt: off\n" + code assert setup_formatter(code1).get_formatted() == expected def test_fmt_off_on(self): @@ -2134,9 +2137,6 @@ def test_fmt_off_on(self): code1 = code + "\n\n# fmt: on\n" + code expected = formatted + "\n\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected - code1 = code + "\n\n# fmt: on\n" + code - expected = formatted + "\n\n# fmt: on\n" + formatted - assert setup_formatter(code1).get_formatted() == expected code1 = "\n# fmt: off\n" + code + "\n# fmt: on\n" + code expected = "# fmt: off\n" + code + "\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected @@ -2147,19 +2147,13 @@ def test_fmt_off_not_on(self): TestSimpleParamFormatting.example_params_newline, TestSimpleParamFormatting.example_input_threads_newline, ): - code1 = ( - "\n# fmt: off\n" - + code - + "\nif 1:\n a=1\n # fmt: on\n b=2\n" - + code - ) expected = ( "# fmt: off\n" + code + "\nif 1:\n a=1\n # fmt: on\n b=2\n" + code ) - assert setup_formatter(code1).get_formatted() == expected + assert setup_formatter(expected).get_formatted() == expected def test_fmt_off_on_in_run(self): """# fmt: off inside Python code is handled by Black.""" @@ -2324,6 +2318,8 @@ def test_fmt_off_on_in_other(self): assert formatter.get_formatted() == expected def test_fmt_off_lagging_comments(self): + expected = "if 1:\n" " lagging_comments\n" "\n" " # fmtany\n" + assert setup_formatter(expected).get_formatted() == expected expected = ( "if 1:\n" " lagging_comments\n" @@ -2455,9 +2451,15 @@ def test_fmt_off_next(self): code1 = "\n\n# fmt: off[next]\n" + code + "\n" + code expected = "# fmt: off[next]\n" + code.strip("\n") + "\n\n\n" + formatted assert setup_formatter(code1).get_formatted() == expected - code1 = code + "\n# fmt: off[next]\n" + code + "\n\n\n" + code + code1 = code.rstrip() + "\n\n# fmtnext\n" + "\n\n\n" + code expected = ( - formatted + formatted.rstrip() + "\n\n\n" + "# fmtnext\n" + "\n\n" + formatted + ) + assert setup_formatter(code1).get_formatted() == expected + code1 = code.rstrip() + "\n\n# fmt: off[next]\n" + code + "\n\n\n" + code + expected = ( + formatted.rstrip() + + "\n\n\n" + "# fmt: off[next]\n" + code.strip("\n") + "\n\n\n" @@ -2465,10 +2467,10 @@ def test_fmt_off_next(self): ) assert setup_formatter(code1).get_formatted() == expected code1 = code + "\n# fmt: off[next]\n" + code - expected = formatted + "# fmt: off[next]\n" + code + expected = formatted + "\n\n# fmt: off[next]\n" + code assert setup_formatter(code1).get_formatted() == expected code1 = code + "\n# fmt: off[next]\n" + code + "\n\n" - expected = formatted + "# fmt: off[next]\n" + code.rstrip("\n") + "\n" + expected = formatted + "\n\n# fmt: off[next]\n" + code.rstrip("\n") + "\n" assert setup_formatter(code1).get_formatted() == expected def test_rule_if_rule(self): @@ -2543,6 +2545,27 @@ def test_rule_if2_rule(self): + format3 ) assert formatter.get_formatted() == expected + formatter = setup_formatter( + code1 + "\n" + "if 1:\n" + " if 2:\n" + " sth\n" + + "\n" + + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") + + "\n" + f"{code3}" + ) + expected = ( + format1 + "\n\n" + "if 1:\n" + f"{TAB * 1}if 2:\n" + f"{TAB * 2}sth\n" + + "\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + "\n\n" + + format3 + ) + assert formatter.get_formatted() == expected formatter = setup_formatter( code1 + "\n" "if 1:\n" @@ -2554,20 +2577,20 @@ def test_rule_if2_rule(self): + "\n" f"{code3}" ) - expected1 = format1 + "\n\n" "if 1:\n" f"{TAB * 1}if 2:\n" - expected2 = ( + expected = ( + format1 + "\n\n" + "if 1:\n" + f"{TAB * 1}if 2:\n" f"{TAB * 2}# fmt: off[next]\n" + "".join(f"{TAB * 2}" + i for i in code2.splitlines(keepends=True)).rstrip( "\n" ) - + "\n" - + "\n" + + "\n\n" + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + "\n\n" + format3 ) - formatted = formatter.get_formatted() - assert formatted.startswith(expected1) and formatted.endswith(expected2) + assert formatter.get_formatted() == expected def test_fmt_off_next_in_if(self): code1, format1 = TestSimpleParamFormatting.example_shell_newline @@ -2598,9 +2621,9 @@ def test_fmt_off_next_in_if(self): + code3 ) expected = ( - format1.rstrip("\n") + "\n# fmt: off[next]\n" + format1 + "\n\n# fmt: off[next]\n" "if 1:\n" - + "".join(" " + i for i in code2.splitlines(keepends=True)) + + "".join(" " + i for i in code2.splitlines(keepends=True)).rstrip("\n") + "\n\n\n" + format3 ) @@ -2636,8 +2659,10 @@ def test_fmt_off_next_in_2if(self): + "\n" + "".join(" " + i for i in code3.splitlines(keepends=True)) ) - expected1 = format1.rstrip("\n") + "\n" "\n\n" "if 1:\n" - expected2 = ( + expected = ( + format1.rstrip("\n") + "\n" + "\n\n" + "if 1:\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( "\n" @@ -2645,8 +2670,7 @@ def test_fmt_off_next_in_2if(self): + "\n\n" + "".join(f"{TAB * 1}" + i for i in format3.splitlines(keepends=True)) ) - formatted = formatter.get_formatted() - assert formatted.startswith(expected1) and formatted.endswith(expected2) + assert formatter.get_formatted() == expected def test_fmt_off_2(self): formatter = setup_formatter( From bfa6158f10ec8cf468528bec2eeba8f2efda4fb6 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 31 Mar 2026 00:12:13 +0800 Subject: [PATCH 18/53] fix: address review --- snakefmt/formatter.py | 25 ++++ snakefmt/parser/parser.py | 43 ++++++- tests/test_formatter.py | 234 +++++++++++++++++++++++++++++--------- 3 files changed, 242 insertions(+), 60 deletions(-) diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index 188133a..078c47e 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -260,6 +260,31 @@ def flush_fmt_off_region(self, verbatim: str): self.result += verbatim self.last_recognised_keyword = "" + def flush_sort_signal(self, verbatim): + """ + If "fmt: on sort" directive is in the keyword syntax, e.g.: + + rule: + directive1: ... + # fmt: off[sort] + directive2: ... + # fmt: on[sort] <- + # other comments + directive3: ... + + the "other comments" should be kept with directive3. + This function is called when "fmt: on[sort]" reached, + and it flushes the pending comments into self.result. + """ + if self.keywords: + pending = "" + for keyword in self.keyword_spec: + pending += self.keywords.pop(keyword, "") + self.previous_result += pending + self.previous_result += self.result + verbatim + self.result = "" + self.last_recognised_keyword = "" + def run_black_format_str( self, string: str, diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 866661f..3566865 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -213,6 +213,15 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): elif "sort" in fmt_label.modifiers: self.sort_off_indent = status.cur_indent elif self._check_fmt_on(fmt_label, status.token) == "sort": + if not self.from_python and self.keyword_indent: + # multiline string is impossible here + # and we assume that origin_indent is the same indent + # of this comment + token_indent = status.cur_indent + sort_on = token_indent * TAB + status.token.line.strip() + "\n" + self.flush_sort_signal(sort_on) + status = self.get_next_queriable() + self.buffer = status.buffer continue elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: self.fmt_off = None @@ -344,7 +353,7 @@ def _consume_python( or dedent below min_indent, or EOF. Returns (source_text, next_status) where next_status carries the stopping token. """ - origin_indent = start_token.start[1] + origin_indent = col_nb(start_token) lines: dict[int, str] = {start_token.start[0]: start_token.line} # Lines that are interior to a multiline token (string / f-string body). @@ -496,16 +505,36 @@ def _consume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): self.fmt_off_preceded_by_blank_line = not last_line.strip() self.snakefile.denext(token) return True - elif fmt_on := self._check_fmt_on(fmt_label, token): - if fmt_on == "region": - lines.update(split_token_lines(token)) - return True + else: + sort_off_indent = self.sort_off_indent + if fmt_on := self._check_fmt_on(fmt_label, token): + if fmt_on == "region": + lines.update(split_token_lines(token)) + elif fmt_on == "sort": + if not self.from_python and self.keyword_indent: + # multiline string is impossible here + # and we assume that origin_indent is the same indent + # of this comment + token_indent = sort_off_indent or 0 + lines.update(split_token_lines(token)) + verbatim = self._reindent( + lines, set(), col_nb(token), token_indent * TAB + ) + self.flush_sort_signal(verbatim) + lines.clear() + else: + self.snakefile.denext(token) + return True return False @abstractmethod def flush_fmt_off_region(self, verbatim: str) -> None: """Flush unformatted region introduced by a fmt: off directive into result""" + @abstractmethod + def flush_sort_signal(self, verbatim: str) -> None: + """Commit fmt:on sort signal directly.""" + def _consume_fmt_off(self, start_token: Token, min_indent: int): verbatim, next_status = self._consume_python( start_token, vocab_recognises=False, added_indent=TAB * min_indent @@ -689,8 +718,12 @@ def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): self.fmt_off = None return "region" elif self.sort_off_indent is not None: + # `fmt: on[sort]` will turn on sorting + # `fmt: on` will also turn on sorting if no `fmt: off` set if not fmt_label.modifiers or "sort" in fmt_label.modifiers: token_indent = self._determine_comment_indent(token) + # but if sort is globally off, only `# fmt: on[sort]` + # can turn it on (self.sort_off_indent := -1) if token_indent == self.sort_off_indent: self.sort_off_indent = None return "sort" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index f92b61f..274885d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1677,24 +1677,17 @@ def test_storage(self): class TestRunBlockFormatting: - def test_comment_indentation_in_run_block(self): + def test_issue_267_comment_indentation_in_run_block(self): """https://github.com/snakemake/snakefmt/issues/267""" - expected = ( + snakecode = ( "rule fmt_bug_repro:\n" f"{TAB * 1}run:\n" f'{TAB * 2}if "something nested":\n' f"{TAB * 3}pass\n" f"{TAB * 2}# Comment gets indented\n" ) - assert setup_formatter(expected).get_formatted() == expected - snakecode = ( - "rule fmt_bug_repro:\n" - " run:\n" - ' if "something nested":\n' - " pass\n" - " # Comment gets indented\n" - ) - assert setup_formatter(snakecode).get_formatted() == expected + formatter = setup_formatter(snakecode) + assert formatter.get_formatted() == snakecode def test_double_block_comment(self): """https://github.com/snakemake/snakefmt/issues/196""" @@ -2115,9 +2108,12 @@ def test_fmt_off_at_middle(self): TestSimpleParamFormatting.example_params_newline, TestSimpleParamFormatting.example_input_threads_newline, ): + # baseline code1 = code + "\n\n\n# fmtoff\n" + code expected = formatted.strip() + "\n\n\n# fmtoff\n" + formatted assert setup_formatter(code1).get_formatted() == expected + + # before `# fmt: off`, new lines are added as usual code1 = code + "\n\n\n# fmt: off\n" + code expected = formatted.strip() + "\n\n\n# fmt: off\n" + code assert setup_formatter(code1).get_formatted() == expected @@ -2128,20 +2124,32 @@ def test_fmt_off_on(self): TestSimpleParamFormatting.example_params_newline, TestSimpleParamFormatting.example_input_threads_newline, ): + # baseline code1 = "\n# fmton\n" + code expected = "# fmton\n" + formatted assert setup_formatter(code1).get_formatted() == expected + + # before `# fmt: on`, empty lines are removed as usual code1 = "\n\n# fmt: on\n" + code expected = "# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected + + # also assert in `test_fmt_off_sort` code1 = code + "\n\n# fmt: on\n" + code expected = formatted + "\n\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected + + # fmt on can enable formatting after fmt off code1 = "\n# fmt: off\n" + code + "\n# fmt: on\n" + code expected = "# fmt: off\n" + code + "\n# fmt: on\n" + formatted assert setup_formatter(code1).get_formatted() == expected def test_fmt_off_not_on(self): + """` + - `# fmt: on` at a deeper indentation level than `# fmt: off` has no effect + - `# fmt: off` keeps the rest of the code unformatted until a same-indent + `# fmt: on` found + """ for code, formatted in ( TestSimpleParamFormatting.example_shell_newline, TestSimpleParamFormatting.example_params_newline, @@ -2200,10 +2208,10 @@ def test_fmt_off_on_in_run_complex(self): code, formatted = TestSimpleParamFormatting.example_shell_newline formatter = setup_formatter( f"rule:\n" - f" run:\n" - f" # fmt: off\n" - f" x = [ 1,2,3]\n" - f" # fmt: on\n\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n" + f"{TAB * 2}# fmt: on\n\n" f"sth=1\n" f"{code}" ) @@ -2219,9 +2227,9 @@ def test_fmt_off_on_in_run_complex(self): assert formatter.get_formatted() == expected formatter = setup_formatter( f"rule:\n" - f" run:\n" - f" # fmt: off\n" - f" x = [ 1,2,3]\n\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n\n" f"sth=1\n" f"{code}" ) @@ -2239,9 +2247,9 @@ def test_fmt_off_on_in_rule(self): code, formatted = TestSimpleParamFormatting.example_shell_newline formatter = setup_formatter( f"rule:\n" - f" # fmt: off\n" - f" run:\n" - f" x = [ 1,2,3]\n" + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}x = [ 1,2,3]\n" f"sth=1\n" f"{code}" ) @@ -2256,34 +2264,34 @@ def test_fmt_off_on_in_rule(self): assert formatter.get_formatted() == expected formatter = setup_formatter( f"rule:\n" - f" message: 'finishing'\n" - f" # Important input\n" - f" input: 'in.txt'\n" - f" # fmt: off\n" - f" log: 'log.txt'\n" - f" name: 'myrule'\n" - f" # fmt: on\n" - f" output: 'out.txt'\n" - f" run:\n" - f" # fmt: off\n" - f" x = [ 1,2,3]\n\n" + f"{TAB * 1}message: 'finishing'\n" + f"{TAB * 1}# Important input\n" + f"{TAB * 1}input: 'in.txt'\n" + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}log: 'log.txt'\n" + f"{TAB * 1}name: 'myrule'\n" + f"{TAB * 1}# fmt: on\n" + f"{TAB * 1}output: 'out.txt'\n" + f"{TAB * 1}run:\n" + f"{TAB * 2}# fmt: off\n" + f"{TAB * 2}x = [ 1,2,3]\n\n" f"sth=1\n" f"{code}" ) expected = ( "rule:\n" f"{TAB}message:\n" - f'{TAB}{TAB}"finishing"\n' + f'{TAB * 2}"finishing"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" - f'{TAB}{TAB}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}# fmt: off\n" f"{TAB}log: 'log.txt'\n" f"{TAB}name: 'myrule'\n" f"{TAB}# fmt: on\n" f"{TAB}output:\n" - f'{TAB}{TAB}"out.txt",\n' - f"{TAB * 1}run:\n" + f'{TAB * 2}"out.txt",\n' + f"{TAB}run:\n" f"{TAB * 2}# fmt: off\n" f"{TAB * 2}x = [ 1,2,3]\n\n\n" f"sth = 1\n\n\n" @@ -2318,30 +2326,45 @@ def test_fmt_off_on_in_other(self): assert formatter.get_formatted() == expected def test_fmt_off_lagging_comments(self): - expected = "if 1:\n" " lagging_comments\n" "\n" " # fmtany\n" + expected = "if 1:\n" f"{TAB * 1}lagging_comments\n" "\n" f"{TAB * 1}# fmtany\n" assert setup_formatter(expected).get_formatted() == expected expected = ( "if 1:\n" - " lagging_comments\n" + f"{TAB * 1}lagging_comments\n" "\n" - " # fmt: off\n" - " rule a:\n" - ' input: "sth"\n' - ' name: "sth"\n' - " # fmt: on\n" + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}rule a:\n" + f'{TAB * 2}input: "sth"\n' + f'{TAB * 2}name: "sth"\n' + f"{TAB * 1}# fmt: on\n" ) assert setup_formatter(expected).get_formatted() == expected expected = ( "if 1:\n" - " # lagging_comments\n" - " # fmt: off\n" - " rule a:\n" - ' input: "sth"\n' - ' name: "sth"\n' - " # fmt: on\n" + f"{TAB * 1}# lagging_comments\n" + f"{TAB * 1}# fmt: off\n" + f"{TAB * 1}rule a:\n" + f'{TAB * 2}input: "sth"\n' + f'{TAB * 2}name: "sth"\n' + f"{TAB * 1}# fmt: on\n" ) assert setup_formatter(expected).get_formatted() == expected + def test_fmt_skip_in_python(self): + code = ( + "if 1:\n" + f"{TAB}x = [ 1,2,3] # fmt: skip\n" + f"{TAB}sth=1 # comment no skip\n" + f"{TAB}y = [4,5,6]" + ) + expected = ( + "if 1:\n" + f"{TAB}x = [ 1,2,3] # fmt: skip\n" + f"{TAB}sth = 1 # comment no skip\n" + f"{TAB}y = [4, 5, 6]\n" + ) + assert setup_formatter(code).get_formatted() == expected + class TestFmtOffSort: def test_fmt_off_sort(self): @@ -2351,15 +2374,23 @@ def test_fmt_off_sort(self): TestSortFormatting.sort_inline_comments, TestSortFormatting.sort_module, ): + # baseline: `# fmt: on` without a preceding `# fmt: off*` is a no-op + # and act as a normal comment code1 = code + "\n\n# fmt: on\n" + code expected = formatted + "\n\n# fmt: on\n" + formatted assert setup_formatter(code1, sort_params=True).get_formatted() == expected + + # `# fmt: off[sort]` disables sorting for the rest of the rule code1 = "# fmt: off[sort]\n" + code expected = "# fmt: off[sort]\n" + setup_formatter(code).get_formatted() assert setup_formatter(code1, sort_params=True).get_formatted() == expected + + # `# fmt: on[sort]` re-enables sorting after `# fmt: off[sort]` code2 = code1 + "\n\n# fmt: on[sort]\n" + code expected2 = expected + "\n\n# fmt: on[sort]\n" + formatted assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + + # plain `# fmt: on` also re-enables sorting after `# fmt: off[sort]` code2 = code1 + "\n\n# fmt: on\n" + code expected2 = expected + "\n\n# fmt: on\n" + formatted assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 @@ -2397,6 +2428,12 @@ def test_fmt_off_sort_nothing(self): assert setup_formatter(code, sort_params=True).get_formatted() == expected def test_fmt_off_sort_between_directive(self): + """ + if you turn off sorting around one directive half way through the rule, + you would sort the half above it and the half below it, + the directive(s) that is surrounded by `# fmt: off` remain + at the same index within the rule. + """ code = ( "rule all:\n" f"{TAB}params: p=1\n" @@ -2404,9 +2441,10 @@ def test_fmt_off_sort_between_directive(self): f"{TAB}threads: 4\n" f"{TAB}conda: 'env.yaml'\n" f"{TAB}message: 'finishing'\n" - f"{TAB}log: 'log.txt'\n" f"{TAB}# fmt: off[sort]\n" + f"{TAB}log: 'log.txt'\n" f"{TAB}output: 'out.txt'\n" + f"{TAB}# before fmt\n" f"{TAB}# fmt: on[sort]\n" f"{TAB}# Important input\n" f"{TAB}input: 'in.txt'\n" @@ -2415,17 +2453,96 @@ def test_fmt_off_sort_between_directive(self): ) expected = ( "rule all:\n" - f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' + f"{TAB}conda:\n" + f'{TAB*2}"env.yaml"\n' + f"{TAB}threads: 4\n" + f"{TAB}resources:\n" + f"{TAB*2}mem_mb=100,\n" + f"{TAB}params:\n" + f"{TAB*2}p=1,\n" + f"{TAB}message:\n" + f'{TAB*2}"finishing"\n' f"{TAB}# fmt: off[sort]\n" + f"{TAB}log:\n" + f'{TAB*2}"log.txt",\n' f"{TAB}output:\n" f'{TAB*2}"out.txt",\n' + f"{TAB}# before fmt\n" f"{TAB}# fmt: on[sort]\n" + f"{TAB}name:\n" + f'{TAB*2}"myrule"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" f'{TAB*2}"in.txt",\n' + f"{TAB}shell:\n" + f'{TAB*2}"echo done"\n' + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected + + def test_fmt_off_sort_between_directive2(self): + """ + In this case, the `# fmt: on` is directly parsed from `Parser.process_keyword` + """ + code = ( + "rule all:\n" + " params: p=1\n" + " resources: mem_mb=100\n" + " threads: 4\n" + " conda: 'env.yaml'\n" + " message: 'finishing'\n" + " # fmt: off[sort]\n" + " log: 'log.txt'\n" + " output: 'out.txt'\n" + " # fmt: on[sort]\n" + " # Important input\n" + " input: 'in.txt'\n" + " name: 'myrule'\n" + " shell: 'echo done'\n" + ) + expected = ( + "rule all:\n" + f"{TAB}conda:\n" + f'{TAB*2}"env.yaml"\n' + f"{TAB}threads: 4\n" + f"{TAB}resources:\n" + f"{TAB*2}mem_mb=100,\n" + f"{TAB}params:\n" + f"{TAB*2}p=1,\n" + f"{TAB}message:\n" + f'{TAB*2}"finishing"\n' + f"{TAB}# fmt: off[sort]\n" f"{TAB}log:\n" f'{TAB*2}"log.txt",\n' + f"{TAB}output:\n" + f'{TAB*2}"out.txt",\n' + f"{TAB}# fmt: on[sort]\n" + f"{TAB}name:\n" + f'{TAB*2}"myrule"\n' + f"{TAB}# Important input\n" + f"{TAB}input:\n" + f'{TAB*2}"in.txt",\n' + f"{TAB}shell:\n" + f'{TAB*2}"echo done"\n' + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected + + def test_fmt_off_sort_between_directive_empty(self): + code = ( + "rule all:\n" + f"{TAB}params: p=1\n" + f"{TAB}resources: mem_mb=100\n" + f"{TAB}threads: 4\n" + f"{TAB}conda: 'env.yaml'\n" + f"{TAB}message: 'finishing'\n" + f"{TAB}# fmt: off[sort]\n" + f"{TAB}# fmt: on\n" + f"{TAB}# Important input\n" + f"{TAB}input: 'in.txt'\n" + f"{TAB}name: 'myrule'\n" + f"{TAB}shell: 'echo done'\n" + ) + expected = ( + "rule all:\n" f"{TAB}conda:\n" f'{TAB*2}"env.yaml"\n' f"{TAB}threads: 4\n" @@ -2435,6 +2552,13 @@ def test_fmt_off_sort_between_directive(self): f"{TAB*2}p=1,\n" f"{TAB}message:\n" f'{TAB*2}"finishing"\n' + f"{TAB}# fmt: off[sort]\n" + f"{TAB}# fmt: on\n" + f"{TAB}name:\n" + f'{TAB*2}"myrule"\n' + f"{TAB}# Important input\n" + f"{TAB}input:\n" + f'{TAB*2}"in.txt",\n' f"{TAB}shell:\n" f'{TAB*2}"echo done"\n' ) @@ -2691,8 +2815,8 @@ def test_fmt_off_2(self): f"if 1:\n" f"\n" f"{TAB}rule a:\n" - f"{TAB}{TAB}input:\n" - f'{TAB}{TAB}{TAB}"foo",\n' + f"{TAB * 2}input:\n" + f'{TAB * 3}"foo",\n' f"{TAB}# fmt: off[next]\n" f"{TAB}rule b:\n" f'{TAB} input: "bar"\n' @@ -2703,5 +2827,5 @@ def test_fmt_off_2(self): f"\n" f"rule d:\n" f"{TAB}input:\n" - f'{TAB}{TAB}"qux",\n' + f'{TAB * 2}"qux",\n' ) From 04bd28a54db849bdb00702a72262530bd2c567c9 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 31 Mar 2026 11:00:20 +0800 Subject: [PATCH 19/53] fix: docs --- README.md | 61 +++++++++------- snakefmt/formatter.py | 4 +- tests/test_formatter.py | 156 +++++++++++++++++++++++++--------------- 3 files changed, 137 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index bafa4d2..83c3673 100644 --- a/README.md +++ b/README.md @@ -56,14 +56,16 @@ design and specifications of [Black][black]. - [Usage](#usage) - [Basic Usage](#basic-usage) - [Full Usage](#full-usage) -- [Configuration](#configuration) - [Directive Sorting](#directive-sorting) - [Format Directives](#format-directives) + - [Configuration](#configuration) - [Integration](#integration) - [Editor Integration](#editor-integration) - [Version Control Integration](#version-control-integration) - [GitHub Actions](#github-actions) - [Plug Us](#plug-us) + - [Markdown](#markdown) + - [ReStructuredText](#restructuredtext) - [Changes](#changes) - [Contributing](#contributing) - [Cite](#cite) @@ -281,20 +283,6 @@ Options: -v, --verbose Turns on debug-level logger. ``` -## Configuration - -`snakefmt` is able to read project-specific default values for its command line options -from a `pyproject.toml` file. In addition, it will also load any [`black` -configurations][black-config] you have in the same file. - -By default, `snakefmt` will search in the parent directories of the formatted file(s) -for a file called `pyproject.toml` and use any configuration there. -If your configuration file is located somewhere else or called something different, -specify it using `--config`. - -Any options you pass on the command line will take precedence over default values in the -configuration file. - ### Directive Sorting By default, `snakefmt` sorts rule and module directives (like `input`, `output`, `shell`, etc.) into a consistent order. This makes rules easier to read and allows for quicker cross-referencing between inputs, outputs, and the resources used by the execution command. @@ -317,14 +305,12 @@ You can disable this feature using the `--no-sort` flag. ### Format Directives `snakefmt` supports inline comment directives to control formatting behaviour for specific regions of code. -Format directives are scope-local. -The design principle is: -- Only the region selected by `# fmt: off`/`# fmt: on` (or the single block selected by `# fmt: off[next]`) is left untouched. -- Code before and after that region follows normal `snakefmt` formatting and spacing behavior, equivalent to replacing the directive with a regular comment line. +All directives are scope-local: only the region they select is affected, while code before and after follows normal `snakefmt` formatting and spacing rules (equivalent to replacing the directive with a plain comment line). #### `# fmt: off` / `# fmt: on` -Disables all formatting for the region between the two directives. The directives must appear at the same indentation level. A `# fmt: on` at a deeper indent than the matching `# fmt: off` has no effect. +Disables all formatting for the region between the two directives. +Both directives *must* appear at the same indentation level; a `# fmt: on` at a deeper indent than the matching `# fmt: off` has no effect. ```python rule a: @@ -345,11 +331,13 @@ rule c: "d.txt", ``` -Note: inside `run:` blocks and other Python code, `# fmt: off` / `# fmt: on` is passed through to [Black][black] which handles it natively. +> **Note:** inside `run:` blocks and other Python contexts, `# fmt: off` / `# fmt: on` is passed through to [Black][black], which handles it natively. #### `# fmt: off[sort]` -Disables only directive sorting for the region, while still applying all other formatting. Useful when you want to preserve a custom directive order for a specific rule. +Disables directive sorting for the enclosed region while still applying all other formatting. +Directives between `# fmt: off[sort]` and `# fmt: on[sort]` are kept in their original order. +A plain `# fmt: on` also closes a `# fmt: off[sort]` region. ```python # fmt: off[sort] @@ -363,11 +351,10 @@ rule keep_my_order: # fmt: on[sort] ``` -A plain `# fmt: on` (without `[sort]`) also ends a `# fmt: off[sort]` region. - #### `# fmt: off[next]` -Disables formatting for the single next Snakemake keyword block (e.g. `rule`, `checkpoint`, `use rule`). Only that one block is left unformatted; subsequent blocks are formatted normally. +Disables formatting for the single next Snakemake keyword block (e.g. `rule`, `checkpoint`, `use rule`). +Only that block is left unformatted; all subsequent blocks are formatted normally. ```python rule formatted: @@ -388,9 +375,30 @@ rule also_formatted: "a.txt", ``` +#### `# fmt: skip` + +`# fmt: skip` preserves a single line exactly as written, without any formatting (see [Black's documentation][black-skip] for details). + +> **Note:** `# fmt: skip` is not yet supported for lines containing Snakemake directives (e.g. `input:`, `output:`). +> It currently applies only to plain Python lines. + +### Configuration + +`snakefmt` is able to read project-specific default values for its command line options +from a `pyproject.toml` file. In addition, it will also load any [`black` +configurations][black-config] you have in the same file. + +By default, `snakefmt` will search in the parent directories of the formatted file(s) +for a file called `pyproject.toml` and use any configuration there. +If your configuration file is located somewhere else or called something different, +specify it using `--config`. + +Any options you pass on the command line will take precedence over default values in the +configuration file. + #### Example -`pyproject.toml` +[`pyproject.toml`][pyproject] ```toml [tool.snakefmt] @@ -534,6 +542,7 @@ See [CONTRIBUTING.md][contributing]. [snakemake]: https://snakemake.readthedocs.io/ [black]: https://black.readthedocs.io/en/stable/ [black-config]: https://github.com/psf/black#pyprojecttoml +[black-skip]: https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#ignoring-sections [pyproject]: https://github.com/snakemake/snakefmt/blob/master/pyproject.toml [contributing]: CONTRIBUTING.md [changes]: CHANGELOG.md diff --git a/snakefmt/formatter.py b/snakefmt/formatter.py index 078c47e..3ea3d20 100644 --- a/snakefmt/formatter.py +++ b/snakefmt/formatter.py @@ -237,8 +237,8 @@ def flush_fmt_off_region(self, verbatim: str): self.buffer = "" if self.fmt_off: if self.fmt_off[0] == 0 and not self.no_formatting_yet: - if self.fmt_off and not self.result.endswith("\n\n\n"): - self.result += "\n\n" + while not self.result.endswith("\n\n\n"): + self.result += "\n" # When fmt:off[next] is inside a Python block (e.g. `if 1:`), the # directive ends up as a lagging_comment after flushing that block. is_nested_next = self.fmt_off[1] == "next" diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 274885d..37d0bc1 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1793,25 +1793,25 @@ def test_sorting_of_params(self): f"{TAB}shell: 'echo done'\n", "rule all:\n" f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' + f'{TAB * 2}"myrule"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}output:\n" - f'{TAB*2}"out.txt",\n' + f'{TAB * 2}"out.txt",\n' f"{TAB}log:\n" - f'{TAB*2}"log.txt",\n' + f'{TAB * 2}"log.txt",\n' f"{TAB}conda:\n" - f'{TAB*2}"env.yaml"\n' + f'{TAB * 2}"env.yaml"\n' f"{TAB}threads: 4\n" f"{TAB}resources:\n" - f"{TAB*2}mem_mb=100,\n" + f"{TAB * 2}mem_mb=100,\n" f"{TAB}params:\n" - f"{TAB*2}p=1,\n" + f"{TAB * 2}p=1,\n" f"{TAB}message:\n" - f'{TAB*2}"finishing"\n' + f'{TAB * 2}"finishing"\n' f"{TAB}shell:\n" - f'{TAB*2}"echo done"\n', + f'{TAB * 2}"echo done"\n', ) def test_sorting_comprehensive(self): @@ -1829,13 +1829,13 @@ def test_sorting_comprehensive(self): "rule complex:\n" f"{TAB}# Input comment\n" f"{TAB}input:\n" - f'{TAB*2}"i",\n' + f'{TAB * 2}"i",\n' f"{TAB}# Resource comment\n" f"{TAB}resources:\n" - f"{TAB*2}res=1,\n" + f"{TAB * 2}res=1,\n" f"{TAB}# Action comment\n" f"{TAB}shell:\n" - f'{TAB*2}"do something"\n', + f'{TAB * 2}"do something"\n', ) def test_sorting_with_comments_preservation(self): @@ -1847,15 +1847,15 @@ def test_sorting_with_comments_preservation(self): "rule inline_comments:\n" f"{TAB}shell: 'echo'\n" f"{TAB}params:\n" - f"{TAB*2}p=1, # parameter comment\n" + f"{TAB * 2}p=1, # parameter comment\n" f"{TAB}input: 'i'\n", "rule inline_comments:\n" f"{TAB}input:\n" - f'{TAB*2}"i",\n' + f'{TAB * 2}"i",\n' f"{TAB}params:\n" - f"{TAB*2}p=1, # parameter comment\n" + f"{TAB * 2}p=1, # parameter comment\n" f"{TAB}shell:\n" - f'{TAB*2}"echo"\n', + f'{TAB * 2}"echo"\n', ) def test_sorting_with_inline_parameter_comments(self): @@ -1875,19 +1875,19 @@ def test_sorting_with_inline_parameter_comments(self): "module other:\n" f'{TAB}name: "n"\n' f"{TAB}pathvars:\n" - f'{TAB*2}["pv"],\n' + f'{TAB * 2}["pv"],\n' f"{TAB}snakefile:\n" - f'{TAB*2}"s"\n' + f'{TAB * 2}"s"\n' f"{TAB}config:\n" - f'{TAB*2}"c"\n' + f'{TAB * 2}"c"\n' f"{TAB}skip_validation:\n" - f"{TAB*2}True\n" + f"{TAB * 2}True\n" f"{TAB}prefix:\n" - f'{TAB*2}"p"\n' + f'{TAB * 2}"p"\n' f"{TAB}replace_prefix:\n" - f'{TAB*2}"rp"\n' + f'{TAB * 2}"rp"\n' f"{TAB}meta_wrapper:\n" - f'{TAB*2}"wrapper"\n', + f'{TAB * 2}"wrapper"\n', ) def test_sorting_module(self): @@ -1905,11 +1905,11 @@ def test_sorting_checkpoint(self): expected = ( "checkpoint map_reads:\n" f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}output:\n" - f'{TAB*2}"out.txt",\n' + f'{TAB * 2}"out.txt",\n' f"{TAB}shell:\n" - f'{TAB*2}"echo"\n' + f'{TAB * 2}"echo"\n' ) assert formatter.get_formatted() == expected @@ -2188,8 +2188,9 @@ def test_fmt_off_on_in_run(self): "z = [4, 5, 6]\n" ) assert setup_formatter(code).get_formatted() == expected + bad_indent = " " snakecode = "rule:\n" " run:\n" + ( - "".join(f" {i}\n" for i in code.splitlines()) + "".join(f"{bad_indent}{i}\n" for i in code.splitlines()) ) snakexpected = "rule:\n" f"{TAB * 1}run:\n" + ( f"{TAB * 2}# ?\n" @@ -2197,8 +2198,8 @@ def test_fmt_off_on_in_run(self): f"{TAB * 2}# fmt: off\n" f"{TAB * 2}y = [ 1, 2]\n" f"{TAB * 2}s = f'''\n" - f"{' '} {{y}} \n" - f"{' '} '''\n" + f"{bad_indent} {{y}} \n" + f"{bad_indent} '''\n" f"{TAB * 2}# fmt: on\n" f"{TAB * 2}z = [4, 5, 6]\n" ) @@ -2351,7 +2352,7 @@ def test_fmt_off_lagging_comments(self): assert setup_formatter(expected).get_formatted() == expected def test_fmt_skip_in_python(self): - code = ( + formatter = setup_formatter( "if 1:\n" f"{TAB}x = [ 1,2,3] # fmt: skip\n" f"{TAB}sth=1 # comment no skip\n" @@ -2363,7 +2364,24 @@ def test_fmt_skip_in_python(self): f"{TAB}sth = 1 # comment no skip\n" f"{TAB}y = [4, 5, 6]\n" ) - assert setup_formatter(code).get_formatted() == expected + assert formatter.get_formatted() == expected + + def test_fmt_skip_in_directive(self): + formatter = setup_formatter( + "rule a:\n" + f" params:\n" + f" x = [ 1,2,3] # fmt: skip\n" + f" input: a= 'sth' # fmt: skip\n" + ) + expected = ( + "rule a:\n" + f"{TAB}params:\n" + f"{TAB * 2}x=[1, 2, 3], # fmt: skip\n" + f"{TAB}input:\n" + f'{TAB * 2}a="sth", # fmt: skip\n' + ) + # TODO: currently `# fmt: skip` in directives is not supported + # assert formatter.get_formatted() == expected class TestFmtOffSort: @@ -2396,6 +2414,32 @@ def test_fmt_off_sort(self): assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 def test_fmt_off_sort_dedent(self): + """`# fmt: on` or `on[sort]` at a deeper indentation level than `off[sort]` + has no effect""" + code1, formatted1 = TestSortFormatting.sorting_comprehensive + formatted1 = setup_formatter(code1).get_formatted() + code2, formatted2 = TestSortFormatting.sort_with_comments + formatted2 = setup_formatter(code2).get_formatted() + code = ( + "# fmt: off[sort]\n" + "if 1:\n" + " # fmt: on\n" + + "".join(" " + i for i in code1.splitlines(keepends=True)).rstrip() + + "\n" + + code2.rstrip() + ) + expected = ( + "# fmt: off[sort]\n" + "if 1:\n" + "\n" + f"{TAB}# fmt: on\n" + + "".join(TAB + i for i in formatted1.splitlines(keepends=True)).rstrip() + + "\n" + "\n\n" + formatted2 + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected + + def test_fmt_off_sort_on_noeffect(self): code1, formatted1 = TestSortFormatting.sorting_comprehensive code2, formatted2 = TestSortFormatting.sort_with_comments formatted2 = setup_formatter(code2).get_formatted() @@ -2454,28 +2498,28 @@ def test_fmt_off_sort_between_directive(self): expected = ( "rule all:\n" f"{TAB}conda:\n" - f'{TAB*2}"env.yaml"\n' + f'{TAB * 2}"env.yaml"\n' f"{TAB}threads: 4\n" f"{TAB}resources:\n" - f"{TAB*2}mem_mb=100,\n" + f"{TAB * 2}mem_mb=100,\n" f"{TAB}params:\n" - f"{TAB*2}p=1,\n" + f"{TAB * 2}p=1,\n" f"{TAB}message:\n" - f'{TAB*2}"finishing"\n' + f'{TAB * 2}"finishing"\n' f"{TAB}# fmt: off[sort]\n" f"{TAB}log:\n" - f'{TAB*2}"log.txt",\n' + f'{TAB * 2}"log.txt",\n' f"{TAB}output:\n" - f'{TAB*2}"out.txt",\n' + f'{TAB * 2}"out.txt",\n' f"{TAB}# before fmt\n" f"{TAB}# fmt: on[sort]\n" f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' + f'{TAB * 2}"myrule"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}shell:\n" - f'{TAB*2}"echo done"\n' + f'{TAB * 2}"echo done"\n' ) assert setup_formatter(code, sort_params=True).get_formatted() == expected @@ -2502,27 +2546,27 @@ def test_fmt_off_sort_between_directive2(self): expected = ( "rule all:\n" f"{TAB}conda:\n" - f'{TAB*2}"env.yaml"\n' + f'{TAB * 2}"env.yaml"\n' f"{TAB}threads: 4\n" f"{TAB}resources:\n" - f"{TAB*2}mem_mb=100,\n" + f"{TAB * 2}mem_mb=100,\n" f"{TAB}params:\n" - f"{TAB*2}p=1,\n" + f"{TAB * 2}p=1,\n" f"{TAB}message:\n" - f'{TAB*2}"finishing"\n' + f'{TAB * 2}"finishing"\n' f"{TAB}# fmt: off[sort]\n" f"{TAB}log:\n" - f'{TAB*2}"log.txt",\n' + f'{TAB * 2}"log.txt",\n' f"{TAB}output:\n" - f'{TAB*2}"out.txt",\n' + f'{TAB * 2}"out.txt",\n' f"{TAB}# fmt: on[sort]\n" f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' + f'{TAB * 2}"myrule"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}shell:\n" - f'{TAB*2}"echo done"\n' + f'{TAB * 2}"echo done"\n' ) assert setup_formatter(code, sort_params=True).get_formatted() == expected @@ -2544,23 +2588,23 @@ def test_fmt_off_sort_between_directive_empty(self): expected = ( "rule all:\n" f"{TAB}conda:\n" - f'{TAB*2}"env.yaml"\n' + f'{TAB * 2}"env.yaml"\n' f"{TAB}threads: 4\n" f"{TAB}resources:\n" - f"{TAB*2}mem_mb=100,\n" + f"{TAB * 2}mem_mb=100,\n" f"{TAB}params:\n" - f"{TAB*2}p=1,\n" + f"{TAB * 2}p=1,\n" f"{TAB}message:\n" - f'{TAB*2}"finishing"\n' + f'{TAB * 2}"finishing"\n' f"{TAB}# fmt: off[sort]\n" f"{TAB}# fmt: on\n" f"{TAB}name:\n" - f'{TAB*2}"myrule"\n' + f'{TAB * 2}"myrule"\n' f"{TAB}# Important input\n" f"{TAB}input:\n" - f'{TAB*2}"in.txt",\n' + f'{TAB * 2}"in.txt",\n' f"{TAB}shell:\n" - f'{TAB*2}"echo done"\n' + f'{TAB * 2}"echo done"\n' ) assert setup_formatter(code, sort_params=True).get_formatted() == expected From bddeae10146ba35a5d9c8d93b5351a804f8e7729 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 31 Mar 2026 11:03:52 +0800 Subject: [PATCH 20/53] make flask8 happy --- tests/test_formatter.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 37d0bc1..99f9685 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2369,9 +2369,9 @@ def test_fmt_skip_in_python(self): def test_fmt_skip_in_directive(self): formatter = setup_formatter( "rule a:\n" - f" params:\n" - f" x = [ 1,2,3] # fmt: skip\n" - f" input: a= 'sth' # fmt: skip\n" + " params:\n" + " x = [ 1,2,3] # fmt: skip\n" + " input: a= 'sth' # fmt: skip\n" ) expected = ( "rule a:\n" @@ -2381,7 +2381,8 @@ def test_fmt_skip_in_directive(self): f'{TAB * 2}a="sth", # fmt: skip\n' ) # TODO: currently `# fmt: skip` in directives is not supported - # assert formatter.get_formatted() == expected + assert formatter.get_formatted() # == expected + assert expected class TestFmtOffSort: From 972860c9a24d9bdbe4e0164c8f63c48ce340cf62 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 1 Apr 2026 18:16:44 +0800 Subject: [PATCH 21/53] docs: clear --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 83c3673..222ad23 100644 --- a/README.md +++ b/README.md @@ -304,7 +304,8 @@ You can disable this feature using the `--no-sort` flag. ### Format Directives -`snakefmt` supports inline comment directives to control formatting behaviour for specific regions of code. +`snakefmt` supports comment directives to control formatting behaviour for specific regions of code. +Directives should appear as standalone comment lines, an inline occurrence (e.g. `input: # fmt: off`) is treated as a plain comment and has no effect. All directives are scope-local: only the region they select is affected, while code before and after follows normal `snakefmt` formatting and spacing rules (equivalent to replacing the directive with a plain comment line). #### `# fmt: off` / `# fmt: on` @@ -379,8 +380,8 @@ rule also_formatted: `# fmt: skip` preserves a single line exactly as written, without any formatting (see [Black's documentation][black-skip] for details). -> **Note:** `# fmt: skip` is not yet supported for lines containing Snakemake directives (e.g. `input:`, `output:`). -> It currently applies only to plain Python lines. +> **Note:** `# fmt: skip` is not yet supported within Snakemake rule blocks. +> It currently applies only to plain Python lines outside of rules, checkpoints, and similar Snakemake constructs. ### Configuration From 06d090d317120b453ea9962a753952f4067ebc98 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 2 Apr 2026 11:11:20 +0800 Subject: [PATCH 22/53] chore: donot rename --- snakefmt/parser/parser.py | 43 +++++++++++++++++++---------------- snakefmt/parser/syntax.py | 47 ++++++++++++++++++--------------------- snakefmt/types.py | 11 ++++----- 3 files changed, 50 insertions(+), 51 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 3566865..7445034 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -1,6 +1,7 @@ import re import tokenize from abc import ABC, abstractmethod +from tokenize import TokenInfo from typing import Literal, NamedTuple, Optional from snakefmt.exceptions import UnsupportedSyntax @@ -14,7 +15,7 @@ is_newline, re_add_curly_bracket_if_needed, ) -from snakefmt.types import TAB, Token, TokenIterator, col_nb +from snakefmt.types import TAB, TokenIterator, col_nb _FMT_DIRECTIVE_RE = re.compile( r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" @@ -26,7 +27,7 @@ class FMT_DIRECTIVE(NamedTuple): modifiers: list[str] @classmethod - def from_token(cls, token: Token): + def from_token(cls, token: TokenInfo): if token.type != tokenize.COMMENT: return None return cls.from_str(token.string) @@ -46,7 +47,7 @@ def from_str(cls, token_string: str): return cls(disable, mods) # type: ignore[arg-type] -def split_token_lines(token: tokenize.TokenInfo): +def split_token_lines(token: TokenInfo): """Token can be multiline. e.g., `f'''\\nplaintext\\n'''` has these tokens: @@ -64,7 +65,7 @@ def split_token_lines(token: tokenize.TokenInfo): ) -def not_a_comment_related_token(token: Token): +def not_a_comment_related_token(token: TokenInfo): return token.type not in { tokenize.COMMENT, tokenize.NEWLINE, @@ -82,7 +83,7 @@ def check_indent(line: str, indents: list[str]) -> int: raise SyntaxError("Unexpected indent") -def token_indents_updated(token: Token, indents: list[str]) -> bool: +def token_indents_updated(token: TokenInfo, indents: list[str]) -> bool: if token.type == tokenize.INDENT: line = token.line indent = line[: len(line) - len(line.lstrip())] @@ -116,12 +117,12 @@ def __init__(self, fpath_or_stream, rulecount=0): self.rulecount = rulecount self.lines = 0 - def __next__(self) -> Token: + def __next__(self) -> TokenInfo: if self._buffered_tokens: return self._buffered_tokens.pop() return next(self._live_tokens) - def denext(self, token: Token) -> None: + def denext(self, token: TokenInfo) -> None: self._buffered_tokens.append(token) @@ -132,7 +133,7 @@ def comment_start(string: str) -> bool: class Status(NamedTuple): """Communicates the result of parsing a chunk of code""" - token: Token + token: TokenInfo block_indent: int # indent of the start of the parsed block cur_indent: int # indent of the end of the parsed block buffer: str @@ -170,7 +171,7 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): self.block_indent = 0 self.queriable = True self.in_fstring = False - self.last_token: Optional[Token] = None + self.last_token: Optional[TokenInfo] = None # for `# fmt: off`, (indent, kind) # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None @@ -347,7 +348,7 @@ def post_process_keyword(self) -> None: eg after finishing parsing a 'rule:'""" def _consume_python( - self, start_token: Token, vocab_recognises=True, added_indent: str = "" + self, start_token: TokenInfo, vocab_recognises=True, added_indent: str = "" ) -> tuple[str, Status]: """Collect Python source lines until a snakemake keyword at correct indent, or dedent below min_indent, or EOF. @@ -369,7 +370,7 @@ def _consume_python( consuming_next = False # used with stop_at_min seen_next_block_keyword = False - def _init_min_indent(token: Token): + def _init_min_indent(token: TokenInfo): nonlocal min_indent if not comment_start(token.string): while not token.line.startswith(self.indents[-1]): @@ -381,7 +382,7 @@ def _init_min_indent(token: Token): try: token = next(self.snakefile) except StopIteration: - eof_token = Token(tokenize.ENDMARKER, "", (0, 0), (0, 0), "") + eof_token = TokenInfo(tokenize.ENDMARKER, "", (0, 0), (0, 0), "") self.snakefile.denext(eof_token) break if min_indent == -1: @@ -463,7 +464,9 @@ def _init_min_indent(token: Token): pythonable=next_status.pythonable or bool(verbatim.strip()) ) - def _detent_last_indent(self, token: Token, last_indent_token: Optional[Token]): + def _detent_last_indent( + self, token: TokenInfo, last_indent_token: Optional[TokenInfo] + ): """ A whole keyword block consumed, hand the next same-level block back to main loop. @@ -474,7 +477,7 @@ def _detent_last_indent(self, token: Token, last_indent_token: Optional[Token]): self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 - def _consume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): + def _consume_fmt_off_in_python(self, token: TokenInfo, lines: dict[int, str]): """ Consume `# fmt: off/on` directives within Python code. lines is needed to: @@ -535,7 +538,7 @@ def flush_fmt_off_region(self, verbatim: str) -> None: def flush_sort_signal(self, verbatim: str) -> None: """Commit fmt:on sort signal directly.""" - def _consume_fmt_off(self, start_token: Token, min_indent: int): + def _consume_fmt_off(self, start_token: TokenInfo, min_indent: int): verbatim, next_status = self._consume_python( start_token, vocab_recognises=False, added_indent=TAB * min_indent ) @@ -655,7 +658,7 @@ def context_exit(self, status: Status) -> None: while len(self.indents) - 1 > status.cur_indent: self.indents.pop() - def _determine_comment_indent(self, comment_token: Token) -> int: + def _determine_comment_indent(self, comment_token: TokenInfo) -> int: """ This function returns the real indent level of a comment token and update self.indents if needed, @@ -679,7 +682,7 @@ def _determine_comment_indent(self, comment_token: Token) -> int: then put all peeked tokens back. """ # ── Step 1: peek ahead to find follow_indent ──────────────────────── - peeked: list[Token] = [] + peeked: list[TokenInfo] = [] saved_indents = list(self.indents) follow_indent = len(self.indents) - 1 try: @@ -708,7 +711,7 @@ def _determine_comment_indent(self, comment_token: Token) -> int: # highest indent level fitting within the comment's column. return max(check_indent(comment_token.line, self.indents), follow_indent) - def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): + def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: TokenInfo): """Determine which fmt: on can turn on formatting""" if self.fmt_off: # `# fmt: on[sort]` no effect @@ -739,7 +742,9 @@ def get_next_queriable(self) -> Status: newline = False pythonable = False block_indent = -1 - prev_token: Optional[Token] = Token(tokenize.NAME, "", (-1, -1), (-1, -1), "") + prev_token: Optional[TokenInfo] = TokenInfo( + tokenize.NAME, "", (-1, -1), (-1, -1), "" + ) while True: token = next(self.snakefile) self.last_token = token diff --git a/snakefmt/parser/syntax.py b/snakefmt/parser/syntax.py index ae90fc5..f42a75e 100644 --- a/snakefmt/parser/syntax.py +++ b/snakefmt/parser/syntax.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from re import match as re_match +from tokenize import TokenInfo from typing import ClassVar, NamedTuple, Optional, Type from snakefmt import fstring_tokeniser_in_use @@ -20,14 +21,7 @@ SyntaxFormError, TooManyParameters, ) -from snakefmt.types import ( - COMMENT_SPACING, - Token, - TokenIterator, - col_nb, - line_nb, - not_empty, -) +from snakefmt.types import COMMENT_SPACING, TokenIterator, col_nb, line_nb, not_empty # ___Token parsing___# BRACKETS_OPEN = {"(", "[", "{"} @@ -110,7 +104,7 @@ def _extract_line_mid( return t -def re_add_curly_bracket_if_needed(token: Token) -> str: +def re_add_curly_bracket_if_needed(token: TokenInfo) -> str: result = "" if ( fstring_tokeniser_in_use @@ -125,7 +119,7 @@ def re_add_curly_bracket_if_needed(token: Token) -> str: def fstring_processing( - token: Token, prev_token: Optional[Token], in_fstring: bool + token: TokenInfo, prev_token: Optional[TokenInfo], in_fstring: bool ) -> bool: """ Returns True if we are entering, or have already entered and not exited, @@ -140,7 +134,7 @@ def fstring_processing( def operator_skip_spacing( - prev_token: Token, token: Token, in_fstring: bool = False + prev_token: TokenInfo, token: TokenInfo, in_fstring: bool = False ) -> bool: # Check for f-string conversion specifiers: ! followed by r, s, or a if ( @@ -170,7 +164,7 @@ def operator_skip_spacing( def add_token_space( - prev_token: Optional[Token], token: Token, in_fstring: bool = False + prev_token: Optional[TokenInfo], token: TokenInfo, in_fstring: bool = False ) -> bool: result = False if prev_token is not None: @@ -183,27 +177,27 @@ def add_token_space( return result -def is_colon(token: Token): +def is_colon(token: TokenInfo): return token.type == tokenize.OP and token.string == ":" -def is_newline(token: Token): +def is_newline(token: TokenInfo): return token.type == tokenize.NEWLINE or token.type == tokenize.NL -def brack_open(token: Token): +def brack_open(token: TokenInfo): return token.type == tokenize.OP and token.string in BRACKETS_OPEN -def brack_close(token: Token): +def brack_close(token: TokenInfo): return token.type == tokenize.OP and token.string in BRACKETS_CLOSE -def is_equal_sign(token: Token): +def is_equal_sign(token: TokenInfo): return token.type == tokenize.OP and token.string == "=" -def is_comma_sign(token: Token): +def is_comma_sign(token: TokenInfo): return token.type == tokenize.OP and token.string == "," @@ -212,7 +206,7 @@ class Parameter: Holds the value of a parameter-accepting keyword """ - def __init__(self, token: Token): + def __init__(self, token: TokenInfo): self.line_nb = line_nb(token) self.col_nb = col_nb(token) self.key = "" @@ -247,7 +241,10 @@ def has_value(self) -> bool: return len(self.value) > 0 def add_elem( - self, prev_token: Optional[Token], token: Token, in_fstring: bool = False + self, + prev_token: Optional[TokenInfo], + token: TokenInfo, + in_fstring: bool = False, ): if add_token_space(prev_token, token, in_fstring) and len(self.value) > 0: self.value += " " @@ -257,7 +254,7 @@ def add_elem( self.value += token.string - def to_key_val_mode(self, token: Token): + def to_key_val_mode(self, token: TokenInfo): if not self.has_value(): raise InvalidParameterSyntax( f"L{token.start[0]}:Operator = used with no preceding key" @@ -309,7 +306,7 @@ def __init__( self.keyword_indent = keyword_indent self.cur_indent = max(self.keyword_indent - 1, 0) self.comment = "" - self.token: Token + self.token: TokenInfo if snakefile is not None: self.validate_keyword_line(snakefile) @@ -412,7 +409,7 @@ def validate_rulelike_syntax(self, snakefile: TokenIterator): ColonError(self.line_nb, self.token.string, self.keyword_line) self.token = next(snakefile) - def add_processed_keyword(self, token: Token, keyword: str): + def add_processed_keyword(self, token: TokenInfo, keyword: str): self.processed_keywords.add(keyword) def check_empty(self): @@ -534,7 +531,7 @@ def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): # untouched — the real processing will update it once tokens # are put back. temp_indent = self.cur_indent - cached_tokens: list[Token] = [] + cached_tokens: list[TokenInfo] = [] try: while True: t = next(snakefile) @@ -563,7 +560,7 @@ def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): return exit def process_token( - self, cur_param: Parameter, prev_token: Optional[Token] + self, cur_param: Parameter, prev_token: Optional[TokenInfo] ) -> Parameter: token_type = self.token.type # f-string treatment (since python 3.12) diff --git a/snakefmt/types.py b/snakefmt/types.py index f39b92a..f0cb34a 100644 --- a/snakefmt/types.py +++ b/snakefmt/types.py @@ -5,19 +5,16 @@ COMMENT_SPACING = " " # PEP8, minimum of two spaces for inline comments -Token = TokenInfo - - -def line_nb(token: Token) -> int: +def line_nb(token: TokenInfo) -> int: return token.start[0] -def col_nb(token: Token) -> int: +def col_nb(token: TokenInfo) -> int: return token.start[1] -def not_empty(token: Token): +def not_empty(token: TokenInfo): return len(token.string) > 0 and not token.string.isspace() -TokenIterator = Iterator[Token] +TokenIterator = Iterator[TokenInfo] From dc919b7920c51d6e3b7b6aa9d0b1477db739d3a8 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 3 Apr 2026 23:59:15 +0800 Subject: [PATCH 23/53] feat: powerful TokenIterator --- snakefmt/blocken.py | 642 ++++++++++++++++++++++++++++++++++++++++++ tests/test_blocken.py | 229 +++++++++++++++ 2 files changed, 871 insertions(+) create mode 100644 snakefmt/blocken.py create mode 100644 tests/test_blocken.py diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py new file mode 100644 index 0000000..929694d --- /dev/null +++ b/snakefmt/blocken.py @@ -0,0 +1,642 @@ +import sys +import tokenize +from abc import ABC, abstractmethod +from typing import Callable, Iterator, NamedTuple, Optional +from tokenize import TokenInfo + + +from snakefmt.exceptions import UnsupportedSyntax + +if sys.version_info < (3, 12): + is_fstring_start = lambda token: False +else: + is_fstring_start = lambda token: token.type == tokenize.FSTRING_START + + def consume_fstring(tokens: TokenIterator): + finished: list[TokenInfo] = [] + isin_fstring = 1 + while True: + token = next(tokens) + finished.append(token) + if token.type == tokenize.FSTRING_START: + isin_fstring += 1 + elif token.type == tokenize.FSTRING_END: + isin_fstring -= 1 + if isin_fstring == 0: + break + return finished + + +def extract_indent(token: TokenInfo) -> str: + line = token.line + return line[: len(line) - len(line.lstrip())] + + +class TokenIterator: + def __init__(self, name, tokens: Iterator[TokenInfo]): + self.name = name + self._live_tokens = tokens + self._buffered_tokens: list[TokenInfo] = list() + self.tokens = tokens + self.lines = 0 + self.rulecount = 0 + self._overwrite_cmd: Optional[str] = None + self._last_token: Optional[TokenInfo] = None + + def __iter__(self): + return self + + def next_new_line(self): + """Returns contents of a entire logical lines (including continued lines), + also include indent tokens before it. + + the tokens yield like: + + [NL/COMMENT_LINE] -> [INDENT] -> (real content tokens) -> NEWLINE -> (repeat) + """ + head_empty_lines: list[TokenInfo] = [] + indents: list[TokenInfo] = [] + contents: list[TokenInfo] = [] + while True: + token = next(self) + if token.type == tokenize.NEWLINE or token.type == tokenize.ENDMARKER: + return head_empty_lines, indents, contents, token + elif not (contents or indents) and ( + token.type == tokenize.NL or token.type == tokenize.COMMENT + ): + head_empty_lines.append(token) + elif token.type == tokenize.INDENT or token.type == tokenize.DEDENT: + assert not contents, "Never expect indent after any content" + indents.append(token) + else: + contents.append(token) + + def next_component(self): + """Returns the next component, should not break string/bracket pairs""" + contents: list[TokenInfo] = [] + expect_brackets: list = [] + paired_brackets = {"(": ")", "[": "]", "{": "}"} + while expect_brackets or not contents: + token = next(self) + contents.append(token) + if token.type == tokenize.OP: + if token.string in paired_brackets: + expect_brackets.append(paired_brackets[token.string]) + elif token.string in ")]}": + if not expect_brackets or expect_brackets[-1] != token.string: + raise UnsupportedSyntax( + f"Unexpected closing bracket {token.string!r} at line {token.start[0]}" + ) + expect_brackets.pop() + elif is_fstring_start(token): + contents.extend(consume_fstring(self)) + return contents + + def next_block(self): + """Returns a entire block, just consume until the end of the block. + Donot care if there are nested blocks inside or snakemake keywords inside. + + it could be INDEDT -> [any content] -> DEDENT, or [any content] -> DEDENT + """ + block_contents: list[TokenInfo] = [] + head_empty_lines, indents, contents, token = self.next_new_line() + assert not indents or ( + [i.type for i in indents] == [tokenize.INDENT] + ), f"Unexpected indents {indents!r}" + assert contents, "Unexpected empty line" + block_contents.extend(head_empty_lines + indents + contents + [token]) + indent_level = 1 + while True: + # read entire line, dedent if needed + head_empty_lines, indents, contents, token = self.next_new_line() + if indents: + if [i.type for i in indents] == [tokenize.INDENT]: + indent_level += 1 + else: + assert {i.type for i in indents} == { + tokenize.DEDENT + }, f"Unexpected indents {indents!r}" + indent_level -= len(indents) + if indent_level <= 0: + # now it is used to represent `DEDENTs to keep` + # e.g. indent_level=1, 2 DEDENTs -> went 1 too deep -> keep 1 + indent_level += len(indents) + self.denext( + token, + *reversed(contents), + *reversed(indents[indent_level:]), + ) + break + elif token.type == tokenize.ENDMARKER and indent_level == 1: + # no indents (guaranteed) and no content (should be) + self.denext(token, *reversed(contents)) + break + block_contents.extend(head_empty_lines + indents + contents + [token]) + # there must be somewhere a DEDENT token to end the block, otherwise raise from __next__ + # now check comments + indent = extract_indent(block_contents[0]) + block_contents.extend(self.dedent_tail_noncoding(head_empty_lines, indent)) + block_contents.extend(indents[:indent_level]) + return block_contents + + def dedent_tail_noncoding(self, tokens: list[TokenInfo], block_indent: str): + """Call at the end of a block, + split comments belong to this block from those belong to parent blocks, + and reorder . + Dedent the tail_noncoding tokens of a block, and return the dedented tokens. + The indent level of the tail_noncoding tokens should be the same as the block_indent. + + Should control tail_noncoding of the block: + - all NL belongs to this block + - if block_indent <= extract_indent(comments): + - this COMMENT belongs to this block + - else: afterwards, all COMMENT belongs to parent (or grand-parents) block + """ + for i, token in enumerate(tokens): + if token.type == tokenize.COMMENT: + if not extract_indent(token).startswith(block_indent): + break + else: + assert token.type == tokenize.NL, f"Unexpected token {token!r}" + self.denext(*reversed(tokens[i:])) + return tokens[:i] + + def __next__(self) -> TokenInfo: + if self._buffered_tokens: + token = self._buffered_tokens.pop() + else: + try: + token = next(self._live_tokens) + except StopIteration as e: + if self._last_token is None: + raise UnsupportedSyntax( + f"Unexpected content of '{self.name}'" + ) from e + else: + raise UnsupportedSyntax( + f"Unexpected end of file after symbol[{self._last_token}] while parsing '{self.name}'" + ) from e + self._last_token = token + return token + + def denext(self, *tokens: TokenInfo) -> None: + """.denext(a, b, c): next(token) will return c, then b, then a. + pull back tokens so they can be pushed in the correct order when .next() + + .denext(token, previous_token, ...) + == .denext(token); .denext(previous_token); ; .denext(...) + => list(zip(self, range(3))) == [(..., 0), (previous_token, 1), (token, 2)] + """ + self._buffered_tokens.extend(tokens) + + +PYTHON_INDENT_KEYWORDS = { + i + for j in ("if elif else", "for while", "try except finally", "with") + for i in j.split() +} + + +def split_token_lines(token: TokenInfo): + """Token can be multiline. + e.g., `f'''\\nplaintext\\n'''` has these tokens: + + TokenInfo(type=61 (FSTRING_START), string="f'''", + start=(21, 0), end=(21, 4), line="f'''\\n") + TokenInfo(type=62 (FSTRING_MIDDLE), string='\\ncccccccc\\n', + start=(21, 4), end=(23, 0), line="f'''\\ncccccccc\\n'''\\n") + TokenInfo(type=63 (FSTRING_END), string="'''", + start=(23, 0), end=(23, 3), line="'''\\n") + + lines should be split to drop overlapping lines and keep unique ones. + """ + return zip( + range(token.start[0], token.end[0] + 1), token.line.splitlines(keepends=True) + ) + + +class Block(ABC): + """ + A block can be: + a continuous python code of lines with the same indentation level. + Also include functions, classes and decoraters (`@` lines) + a single block identifed by keywords in `{PYTHON_INDENT_KEYWORDS}` + and all the code under it, until the next block of the same or lower indent level. + a snakemake keyword block (rule, module, config, etc.) + and all the code under it, until the next block of the same or lower indent level. + snakemake keywords should NEVER in functions or classes + comments between blocks + (exclude the comment right before the indenting keyword, which is considered part of the block) + + Starting of blocks (file or new indent): + the space and comments until the first indenting keyword are considered a block of their own. + All other spaces are considered part of the previous block's trailing empty lines. + + Comment belongness: + Only comments with neither empty lines between/after the next block nor different indent levels + are considered part of the same block. + e.g.: + sth # block 1 + # comment 1 -> block 1 + + # comment 2 -> block 1 + + # comment 3 -> block 2 + def func(): # block 2 + pass # block 2.1 + # comment 4 -> block 2.1 + # comment 5 -> block 2 + + rule example: # block 3 + input: "data.txt" # block 3.1 and 3.1.1 + # comment 6 -> block 3.1 + output: # block 3.2 + "result.txt" # block 3.2.1 + # comment 7 -> block 3.2.1 + # comment 8 -> block 3.3 + + Indent of comments: + determined by the following real code line and previous indents. + + Durning parsing tokens, when a comment token is encountered, + its effective indent level is not directly knowable. + + principles: + follow_indent = indent of the following real code line + if EOF: + follow_indent = 0 + rule 1 (always): + indent of comments >= follow_indent + rule 2 (if follow_indent < self.indents[-1]): + indent of comments = epsilon + max( + i for i in self.indents if i <= comment_indent + ) + """ + + __slots__ = ( + "indent_level", + "head_noncoding", + "head_tokens", + "sub_blocks", + "tail_noncoding", + ) + + def __init__( + self, tokens: TokenIterator, indent_level: int, head_tokens: list[TokenInfo] + ) -> None: + self.sub_blocks: list["Block"] = [] + self.head_noncoding: list[TokenInfo] = [] + self.tail_noncoding: list[TokenInfo] = [] + self.head_tokens = head_tokens + self.indent_level = indent_level + self.consume(tokens) + + def extend_tail_noncoding(self, tokens: list[TokenInfo]): + self.tail_noncoding.extend(tokens) + del self.extend_tail_noncoding # should never be called again + return [] + + def extend_head_noncoding(self, tokens: list[TokenInfo]): + """Test if the tokens are all non-coding, and if so, extend head_noncoding with them and return True. + Otherwise, return False and do not modify head_noncoding. + """ + if {i.type for i in tokens} <= {tokenize.NL, tokenize.COMMENT}: + self.head_noncoding = tokens + return True + return False + + @abstractmethod + def consume(self, tokens: TokenIterator) -> None: ... + + @property + def start_token(self): + if not self.head_tokens: + raise UnsupportedSyntax("Unexpected empty block") + return self.head_tokens[0] + + @property + def raw_indent(self) -> str: + "tell the raw indent of the block" + assert self.start_token is not None, "start_token should be set after consume()" + return self.start_token.line[: self.start_token.start[1]] + + def block_lines(self) -> list[str]: + lines: dict[int, str] = {} + # Lines that are interior to a multiline token (string / f-string body). + # Their content must not be reindented. + string_interior_lines: set[int] = set() + for token in self.head_tokens: + if token.end[0] not in lines: + lines.update(split_token_lines(token)) + if token.start[0] != token.end[0]: + string_interior_lines.update( + range(token.start[0] + 1, token.end[0] + 1) + ) + newlines: list[str] = [] + for i in sorted(lines): + line = lines[i] + if i in string_interior_lines: + assert newlines, "block cannot start inner a multiline-string" + newlines[-1] += line + else: + newlines.append(line) + return newlines + + def raw(self) -> list[str]: + """return the code splited by lines, but should keep multiline-string or multiline-f-string complete, + to make trimming and reformatting easier. + + Should and Only should be rewrite for pure python blocks. + """ + + lines = ( + [comment.line for comment in self.head_noncoding] + + self.block_lines() + + [line for block in self.sub_blocks for line in block.raw()] + + [token.line for token in self.tail_noncoding] + ) + return lines + + def components(self) -> "Iterator[DocumentSymbol]": + """ + - position := (file, line number, column number) + - type := name / rule, input, output / function, class / etc. + if not a name, then that's the definition of the name (should link blank names to here) + - identifier := the identifier of the block, e.g. rule `a`, `input`, input `b`, etc. + when iterating sub-blocks in rule, identifier should modified to reflect the parent block, e.g. `rules.a.input.b` + (`b` may be difficult to identify, but at least we know the content of `input` block) + - content := "self.raw()", e.g. `"data.txt"` for input `b` in rule `a`, + and the whole content of the block for rule `a` + + Idealy, it should recognize sth like: + rules.a.input.b + - enable `rules.a` to the position of `rule a:` + - enable `~~~~~~~.input` to the position of `input:` of `rule a` + - enable `~~~~~~~~~~~~~.b` to the position of `b=` in `input:` of `rule a` + """ + for block in self.sub_blocks: + yield from block.components() + + @abstractmethod + def formatted(self) -> str: + """return formatted code of the block""" + + @abstractmethod + def compilation(self) -> str: + """return pure python code compiled from the block, without snakemake keywords and comments""" + + +class DocumentSymbol(NamedTuple): + name: str + detail: str + symbol_kind: str + position_start: tuple[int, int] + position_end: tuple[int, int] + block: "Block" + + +class PythonBlock(Block): + def consume(self, tokens): + "Do nothing, win" + + def formatted(self) -> str: + raise NotImplementedError + + def compilation(self) -> str: + raise NotImplementedError + + def components(self): + yield from [] + + +class ColonBlock(Block): + @classmethod + def _keyword(cls): + return cls.__name__.lower() + + @property + def keyword(self) -> str: + return self._keyword() + + __slots__ = ("post_colon_coding",) + + def __init__(self, tokens, indent_level, head_tokens): + self.post_colon_coding: list[TokenInfo] = [] + super().__init__(tokens, indent_level, head_tokens) + + def consume(self, tokens): + """Consume tokens until the end of the block head line (the line with `:`)""" + token = next(tokens) + if token.type != tokenize.INDENT: + tokens.denext(token) + token_iter = TokenIterator("", iter(self.head_tokens)) + colon_index = 0 + while True: + token, *rest = token_iter.next_component() + if not rest and token.type == tokenize.OP and token.string == ":": + break + colon_index += 1 + len(rest) + self.post_colon_coding = self.head_tokens[colon_index + 1 :] + else: + self.consume_body(tokens) + + @abstractmethod + def consume_body(self, tokens) -> None: ... + + def recognises(self, token: TokenInfo) -> bool: + return token.type == tokenize.NAME and token.string == self.keyword + + +class FunctionClassBlock(ColonBlock): ... + + +function_class_blocks: dict[str, type[FunctionClassBlock]] = { + i.lower(): type(i.capitalize(), (FunctionClassBlock,), {}) for i in ("def", "class") +} + + +class IfForTryWithBlock(ColonBlock): + def consume_body(self, tokens): + """Consume tokens until the end of the block head line (the line with `:`)""" + global_block = GlobalBlock(tokens, self.indent_level + 1, []) + self.sub_blocks.append(global_block) + + +if_for_try_with_blocks: dict[str, type[IfForTryWithBlock]] = { + i.lower(): type(i.capitalize(), (IfForTryWithBlock,), {}) + for i in PYTHON_INDENT_KEYWORDS +} + + +class SnakemakeBlock(ColonBlock): + __slots__ = ("name",) + name: str + + subautomata: dict[str, Block] = {} + deprecated: dict[str, str] = {} + + def components(self): + this_symbol = DocumentSymbol( + name=self.name, + detail="\n".join(i.rstrip() for i in self.block_lines()).strip("\n"), + symbol_kind=self._keyword(), + position_start=self.start_token.start, + position_end=self.head_tokens[-1].end, + block=self, + ) + yield this_symbol + + +global_snakemake_blocks: dict[str, type[SnakemakeBlock]] = {} + + +class CommentBlock(Block): ... + + +class GlobalBlock(Block): + subautomata = ( + function_class_blocks | if_for_try_with_blocks | global_snakemake_blocks + ) + + def consume(self, tokens): + """pass through all tokens until the next indenting keyword, + and check if there is any non-comment content. + """ + plain_python_tokens: list[TokenInfo] = [] + end_token: Optional[TokenInfo] = None + block_depth = 0 + indent_str = "[TBD]" + while not end_token or end_token.type != tokenize.ENDMARKER: + head_empty_lines, indents_, contents_, end_token = tokens.next_new_line() + if indents_: + if indents_[0].type == tokenize.INDENT: + assert len(indents_) == 1, f"Unexpected INDENTs {indents_!r}" + # there should be only one INDENT token at the beginning of the block + if block_depth == 0 and indent_str == "[TBD]": + indent_str = extract_indent(indents_[0]) + else: + block_depth += 1 + else: + assert {t.type for t in indents_} == { + tokenize.DEDENT + }, f"Unexpected DEDENTs {indents_!r}" + if block_depth: + block_depth -= 1 + else: + # get out of the block + tokens.denext( + end_token, + *reversed(contents_), + *reversed(indents_[1:]), + ) + head_empty_ = iter(head_empty_lines) + for token in head_empty_: + if token.type == tokenize.COMMENT: + if extract_indent(token).startswith(indent_str): + self.tail_noncoding.append(token) + else: + break + else: + self.tail_noncoding.append(token) + head_empty_lines1 = list(head_empty_) + tokens.denext(*reversed(list(head_empty_))) + head_empty_lines = head_empty_lines1 + break + if head_empty_lines: + if self.sub_blocks and not plain_python_tokens: + plain_python_tokens = self.sub_blocks[-1].extend_tail_noncoding( + head_empty_lines + ) + else: + plain_python_tokens.extend(head_empty_lines) + if contents_: + token = contents_[0] + if token.type == tokenize.NAME and token.string in self.subautomata: + indent_level = self.indent_level + block_depth + 1 + colon_block = self.subautomata[token.string]( + tokens, indent_level, [*contents_, end_token] + ) + if not colon_block.extend_head_noncoding(head_empty_lines): + self.sub_blocks.append( + PythonBlock(tokens, indent_level, plain_python_tokens) + ) + self.sub_blocks.append(colon_block) + plain_python_tokens = [] + else: + plain_python_tokens.extend((*contents_, end_token)) + else: + plain_python_tokens.append(end_token) + if plain_python_tokens: + self.sub_blocks.append( + PythonBlock(tokens, self.indent_level, plain_python_tokens) + ) + + def formatted(self) -> str: + raise NotImplementedError + + def compilation(self) -> str: + raise NotImplementedError + + +def parse(input: str | Callable[[], str], name: str = "") -> GlobalBlock: + if isinstance(input, str): + tokens = tokenize.generate_tokens( + iter(input.splitlines(keepends=True)).__next__ + ) + else: + tokens = tokenize.generate_tokens(input) + return GlobalBlock(TokenIterator(name, tokens), 0, []) + + +def _determine_comment_indent( + comment_token: TokenInfo, previous_indents: list[str], current_len: int +) -> int: + """ + This function returns the real indent level of a comment token and + update self.indents if needed, + which is determined by the following real code line and previous indents. + + Durning parsing self.snakefile, when a comment token is encountered, + its effective indent level is not directly knowable. + + principles: + follow_indent = indent of the following real code line + if EOF: + follow_indent = 0 + rule 1 (always): + indent of comments >= follow_indent + rule 2 (if follow_indent < self.indents[-1]): + indent of comments = epsilon + max( + i for i in self.indents if i <= comment_indent + ) + + next(self.snakefile) until follow_indent is determined, + then put all peeked tokens back. + """ + previous_len = len(previous_indents) - 1 + if previous_len <= current_len: + return current_len + # Rule 2 (dedent is happening, standalone only): snap comment to the + # highest indent level fitting within the comment's column. + for i, indent in enumerate(reversed(previous_indents)): + if comment_token.line.startswith(indent): + break + else: + raise SyntaxError("Unexpected indent") + return max(previous_len - i, current_len) + + +def token_indents_updated(token: TokenInfo, indents: list[str]) -> bool: + if token.type == tokenize.INDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + if indent not in indents: + indents.append(indent) + elif token.type == tokenize.DEDENT: + line = token.line + indent = line[: len(line) - len(line.lstrip())] + while indents and indents[-1] != indent: + indents.pop() + else: + return False + return True diff --git a/tests/test_blocken.py b/tests/test_blocken.py new file mode 100644 index 0000000..f2a73b0 --- /dev/null +++ b/tests/test_blocken.py @@ -0,0 +1,229 @@ +import pytest + +from snakefmt.blocken import ( + consume_fstring, + TokenIterator, + tokenize, + is_fstring_start, + UnsupportedSyntax, +) + + +def generate_tokens(input: str): + return list( + tokenize.generate_tokens(iter(input.splitlines(keepends=True)).__next__) + ) + + +class TestTokenIterator: + def test_fstring1(self): + input = 'f"hello world"' + tokens = generate_tokens(input) + token_iter = TokenIterator("", iter(tokens)) + # region test the classic useage of `consume_fstring`, + # togather with `is_fstring_start` + for t in token_iter: + if is_fstring_start(t): + contents = consume_fstring(token_iter) + break + # endregion test + assert t == tokens[0] + assert contents == tokens[1:-2] + assert [i.type for i in contents] == [ + tokenize.FSTRING_MIDDLE, + tokenize.FSTRING_END, + ] + + def test_fstring_with_bracket(self): + input = 'a = f"hello {world}"' + tokens = generate_tokens(input) + token_iter = TokenIterator("", iter(tokens)) + for t in token_iter: + if is_fstring_start(t): + contents = consume_fstring(token_iter) + assert t == tokens[2] + assert contents == tokens[3:-2] + assert [i.type for i in contents] == [ + tokenize.FSTRING_MIDDLE, + tokenize.OP, + tokenize.NAME, + tokenize.OP, + tokenize.FSTRING_END, + ] + break + + def test_consum_all(self): + input = "sth" + tokens = generate_tokens(input) + token_iter = TokenIterator("", iter(tokens)) + with pytest.raises(UnsupportedSyntax): + for t in token_iter: + pass + assert t.type == tokenize.ENDMARKER + + example1 = ( + "def f():\n" # + " return 1\n" + "\n" + "\n" + "b = f'''\n" + "{b =} f'''\n" + "# comment\n" + "with d: # comment\n" + " pass" + ) + + def test_next_new_line(self): + tokens = generate_tokens(self.example1) + token_iter = TokenIterator("", iter(tokens)) + # return: `def f():` + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == indents == [] + assert contents == tokens[:5] + assert [i.string for i in contents] == ["def", "f", "(", ")", ":"] + assert {token.line} == {t.line for t in contents} + # return: `return 1` with indent + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == [] + assert indents == [tokens[6]] + assert contents == tokens[7:9] + assert [i.string for i in contents] == ["return", "1"] + assert {token.line} == {t.line for t in contents} + # return: the full `b = f'''\n...` f-string, with dedent and empty lines + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == tokens[10:12] + assert indents == [tokens[12]] + assert contents == tokens[13:23] + assert [i.string for i in contents] == [ + *("b", "=", "f'''", "\n", "{", "b", "=", "}", " f", "'''") + ] + assert token.line == contents[-1].line + # return: `with d:`, with empty lines and inline comment + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == tokens[24:26] + assert indents == [] + assert contents == tokens[26:30] + assert [i.string for i in contents] == ["with", "d", ":", "# comment"] + assert {token.line} == {t.line for t in contents} + # return: `pass`, with indent but no `\n` at the end + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == [] + assert indents == [tokens[31]] + assert contents == tokens[32:33] + assert [i.string for i in contents] == ["pass"] + assert {token.line} == {t.line for t in contents} + assert token.string == "" and token.type == tokenize.NEWLINE + # return: the ENDMARKER, with dedent and no content + head_empty_lines, indents, contents, token = token_iter.next_new_line() + assert head_empty_lines == contents == [] + assert indents == [tokens[34]] + assert token == tokens[35] == tokens[-1] + assert token.type == tokenize.ENDMARKER + + example2 = ( + "def components(self):\n" # + " this_symbol: DocumentSymbol = DocumentSymbol(\n" + " name=self.name,\n" + " detail='\\n'.join(i.rstrip() for i in self.block_lines()).strip('\\n'),\n" + " symbol_kind=self._keyword(),\n" + " position_start=self.start_token.start,\n" + " position_end=self.head_tokens[-1].end,\n" + " block=self,\n" + " )\n" + " yield this_symbol\n" + ) + + def test_next_component(self): + tokens = generate_tokens(self.example2) + token_iter = TokenIterator("", iter(tokens)) + index = 0 + + def _check_single_component(*components: str): + nonlocal index + for string in components: + contents = token_iter.next_component() + assert contents == tokens[index : index + 1] + assert [i.string for i in contents] == [string] + index += 1 + + _check_single_component("def", "components") + contents = token_iter.next_component() + assert contents == tokens[2:5] + assert [i.string for i in contents] == ["(", "self", ")"] + index = 5 + _check_single_component( + *(":", "\n"), + *(" ", "this_symbol", ":", "DocumentSymbol", "=", "DocumentSymbol"), + ) + contents = token_iter.next_component() + assert contents == tokens[13:][:73] + assert [i.string for i in contents] == [ + *("(", "\n"), + *("name", "=", "self", ".", "name", ",", "\n"), + *("detail", "=", "'\\n'", ".", "join", "("), + *("i", ".", "rstrip", "(", ")"), + *("for", "i", "in", "self", ".", "block_lines", "(", ")"), + *(")", ".", "strip", "(", "'\\n'", ")", ",", "\n"), + *("symbol_kind", "=", "self", ".", "_keyword", "(", ")", ",", "\n"), + "position_start", + *("=", "self", ".", "start_token", ".", "start", ",", "\n"), + *("position_end", "=", "self", ".", "head_tokens", "["), + *("-", "1", "]", ".", "end", ",", "\n"), + *("block", "=", "self", ",", "\n"), + ")", + ] + index = 86 + _check_single_component("\n", "yield", "this_symbol", "\n", "") + contents = token_iter.next_component() + assert contents == tokens[91:][:1] == tokens[-1:] + + example3 = ( + "with a as b:\n" # + " b\n" + " # 0\n" + " while c:\n" + " d\n" + " # 1\n" + " # 2\n" + "\n" + " # 3\n" + " # 4\n" + " \n" + " # 5\n" + " # 6\n" + "7# 7\n" + "\n" + ) + + def test_next_block(self): + tokens = generate_tokens(self.example3) + assert [i for i, t in enumerate(tokens) if t.type == tokenize.INDENT] == [6, 15] + # from the first line to the last content line + contents = TokenIterator("", iter(tokens[3:])).next_block() + assert contents[0].line == "with a as b:\n" + assert contents == tokens[3:][:34] + assert contents[-1].type == tokenize.NEWLINE + assert contents[-1].line == "7# 7\n" + # from the second line, to the last line before + # ` # 5\n`, whose indent out of the block + contents = contents_ = TokenIterator("", iter(tokens[6:])).next_block() + assert contents[0].line == " b\n" and contents[0].type == tokenize.INDENT + assert contents == tokens[6:][:22] + tokens[32:][:2] + assert {t.type for t in contents[-2:]} == {tokenize.DEDENT} + assert contents[:-2][-1].line == " \n" + # even skip the heading indent, block ends at the same line + contents = TokenIterator("", iter(tokens[7:])).next_block() + assert contents == contents_[1:] + # so does the COMMENT line + contents = TokenIterator("", iter(tokens[9:])).next_block() + assert contents[0].line == " # 0\n" and contents[0].type == tokenize.COMMENT + assert contents == contents_[3:] + # enter the third block: exit before ` # 3\n` with 1 DEDENT only + contents = TokenIterator("", iter(tokens[15:])).next_block() + assert contents[0].line == " d\n" and tokens[14].type == tokenize.NEWLINE + assert contents == tokens[15:][:8] + tokens[32:][:1] + assert [t.type for t in contents[-4:]] == [ + *(tokenize.COMMENT, tokenize.NL, tokenize.NL, tokenize.DEDENT) + ] + assert contents[-4].line == contents[-3].line == " # 2\n" + assert contents[-2].line == "\n" From ee1c298a7c384a1f7df7d6489944e6fab7edc8ad Mon Sep 17 00:00:00 2001 From: hwrn Date: Sat, 4 Apr 2026 01:54:00 +0800 Subject: [PATCH 24/53] feat: parse_python_block --- snakefmt/blocken.py | 142 ++++++++++++++++++++---------------------- tests/test_blocken.py | 19 ++++++ 2 files changed, 85 insertions(+), 76 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 929694d..0bcc57c 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,7 +1,7 @@ import sys import tokenize from abc import ABC, abstractmethod -from typing import Callable, Iterator, NamedTuple, Optional +from typing import Callable, Iterable, Iterator, NamedTuple, Optional from tokenize import TokenInfo @@ -127,8 +127,7 @@ def next_block(self): *reversed(indents[indent_level:]), ) break - elif token.type == tokenize.ENDMARKER and indent_level == 1: - # no indents (guaranteed) and no content (should be) + if token.type == tokenize.ENDMARKER and indent_level == 1: self.denext(token, *reversed(contents)) break block_contents.extend(head_empty_lines + indents + contents + [token]) @@ -152,6 +151,8 @@ def dedent_tail_noncoding(self, tokens: list[TokenInfo], block_indent: str): - this COMMENT belongs to this block - else: afterwards, all COMMENT belongs to parent (or grand-parents) block """ + if not tokens: + return [] for i, token in enumerate(tokens): if token.type == tokenize.COMMENT: if not extract_indent(token).startswith(block_indent): @@ -215,6 +216,33 @@ def split_token_lines(token: TokenInfo): ) +def block_lines(tokens: Iterator[TokenInfo]): + lines: dict[int, str] = {} + # Lines that are interior to a multiline token (string / f-string body). + # Their content must not be reindented. + string_interior_lines: set[int] = set() + for token in tokens: + if not_indent(token) and token.end[0] not in lines: + lines.update(split_token_lines(token)) + if token.start[0] != token.end[0]: + string_interior_lines.update( + range(token.start[0] + 1, token.end[0] + 1) + ) + newlines: list[str] = [] + for i in sorted(lines): + line = lines[i] + if i in string_interior_lines: + assert newlines, "block cannot start inner a multiline-string" + newlines[-1] += line + else: + newlines.append(line) + return newlines + + +def not_indent(token: TokenInfo) -> bool: + return token.type != tokenize.INDENT and token.type != tokenize.DEDENT + + class Block(ABC): """ A block can be: @@ -293,13 +321,13 @@ def __init__( def extend_tail_noncoding(self, tokens: list[TokenInfo]): self.tail_noncoding.extend(tokens) - del self.extend_tail_noncoding # should never be called again return [] def extend_head_noncoding(self, tokens: list[TokenInfo]): """Test if the tokens are all non-coding, and if so, extend head_noncoding with them and return True. Otherwise, return False and do not modify head_noncoding. """ + assert not self.head_noncoding, "head_noncoding should be empty before extend" if {i.type for i in tokens} <= {tokenize.NL, tokenize.COMMENT}: self.head_noncoding = tokens return True @@ -320,40 +348,20 @@ def raw_indent(self) -> str: assert self.start_token is not None, "start_token should be set after consume()" return self.start_token.line[: self.start_token.start[1]] - def block_lines(self) -> list[str]: - lines: dict[int, str] = {} - # Lines that are interior to a multiline token (string / f-string body). - # Their content must not be reindented. - string_interior_lines: set[int] = set() - for token in self.head_tokens: - if token.end[0] not in lines: - lines.update(split_token_lines(token)) - if token.start[0] != token.end[0]: - string_interior_lines.update( - range(token.start[0] + 1, token.end[0] + 1) - ) - newlines: list[str] = [] - for i in sorted(lines): - line = lines[i] - if i in string_interior_lines: - assert newlines, "block cannot start inner a multiline-string" - newlines[-1] += line - else: - newlines.append(line) - return newlines + def block_lines(self): + return block_lines(iter(self.head_tokens)) - def raw(self) -> list[str]: + def raw(self): """return the code splited by lines, but should keep multiline-string or multiline-f-string complete, to make trimming and reformatting easier. Should and Only should be rewrite for pure python blocks. """ - lines = ( - [comment.line for comment in self.head_noncoding] + block_lines(filter(not_indent, self.head_noncoding)) + self.block_lines() + [line for block in self.sub_blocks for line in block.raw()] - + [token.line for token in self.tail_noncoding] + + block_lines(filter(not_indent, self.tail_noncoding)) ) return lines @@ -396,13 +404,15 @@ class DocumentSymbol(NamedTuple): class PythonBlock(Block): + """Hold `head_tokens` only, no tokens comments, no sub-blocks""" + def consume(self, tokens): "Do nothing, win" - def formatted(self) -> str: + def formatted(self): raise NotImplementedError - def compilation(self) -> str: + def compilation(self): raise NotImplementedError def components(self): @@ -441,13 +451,22 @@ def consume(self, tokens): self.consume_body(tokens) @abstractmethod - def consume_body(self, tokens) -> None: ... + def consume_body(self, tokens: TokenIterator) -> None: ... - def recognises(self, token: TokenInfo) -> bool: + def recognises(self, token: TokenInfo): return token.type == tokenize.NAME and token.string == self.keyword -class FunctionClassBlock(ColonBlock): ... +class FunctionClassBlock(ColonBlock): + def consume_body(self, tokens): + contents = tokens.next_block() + self.sub_blocks.append(PythonBlock(tokens, self.indent_level + 1, contents)) + + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError function_class_blocks: dict[str, type[FunctionClassBlock]] = { @@ -461,6 +480,12 @@ def consume_body(self, tokens): global_block = GlobalBlock(tokens, self.indent_level + 1, []) self.sub_blocks.append(global_block) + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + if_for_try_with_blocks: dict[str, type[IfForTryWithBlock]] = { i.lower(): type(i.capitalize(), (IfForTryWithBlock,), {}) @@ -542,6 +567,7 @@ def consume(self, tokens): tokens.denext(*reversed(list(head_empty_))) head_empty_lines = head_empty_lines1 break + had_plain_python = len(plain_python_tokens) if head_empty_lines: if self.sub_blocks and not plain_python_tokens: plain_python_tokens = self.sub_blocks[-1].extend_tail_noncoding( @@ -556,9 +582,11 @@ def consume(self, tokens): colon_block = self.subautomata[token.string]( tokens, indent_level, [*contents_, end_token] ) - if not colon_block.extend_head_noncoding(head_empty_lines): + if colon_block.extend_head_noncoding(head_empty_lines): + plain_python_tokens = plain_python_tokens[:had_plain_python] + if plain_python_tokens: self.sub_blocks.append( - PythonBlock(tokens, indent_level, plain_python_tokens) + PythonBlock(tokens, self.indent_level, plain_python_tokens) ) self.sub_blocks.append(colon_block) plain_python_tokens = [] @@ -571,10 +599,10 @@ def consume(self, tokens): PythonBlock(tokens, self.indent_level, plain_python_tokens) ) - def formatted(self) -> str: + def formatted(self): raise NotImplementedError - def compilation(self) -> str: + def compilation(self): raise NotImplementedError @@ -588,44 +616,6 @@ def parse(input: str | Callable[[], str], name: str = "") -> GlobalBlock return GlobalBlock(TokenIterator(name, tokens), 0, []) -def _determine_comment_indent( - comment_token: TokenInfo, previous_indents: list[str], current_len: int -) -> int: - """ - This function returns the real indent level of a comment token and - update self.indents if needed, - which is determined by the following real code line and previous indents. - - Durning parsing self.snakefile, when a comment token is encountered, - its effective indent level is not directly knowable. - - principles: - follow_indent = indent of the following real code line - if EOF: - follow_indent = 0 - rule 1 (always): - indent of comments >= follow_indent - rule 2 (if follow_indent < self.indents[-1]): - indent of comments = epsilon + max( - i for i in self.indents if i <= comment_indent - ) - - next(self.snakefile) until follow_indent is determined, - then put all peeked tokens back. - """ - previous_len = len(previous_indents) - 1 - if previous_len <= current_len: - return current_len - # Rule 2 (dedent is happening, standalone only): snap comment to the - # highest indent level fitting within the comment's column. - for i, indent in enumerate(reversed(previous_indents)): - if comment_token.line.startswith(indent): - break - else: - raise SyntaxError("Unexpected indent") - return max(previous_len - i, current_len) - - def token_indents_updated(token: TokenInfo, indents: list[str]) -> bool: if token.type == tokenize.INDENT: line = token.line diff --git a/tests/test_blocken.py b/tests/test_blocken.py index f2a73b0..dd57d40 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -6,6 +6,7 @@ tokenize, is_fstring_start, UnsupportedSyntax, + parse, ) @@ -227,3 +228,21 @@ def test_next_block(self): ] assert contents[-4].line == contents[-3].line == " # 2\n" assert contents[-2].line == "\n" + + +class TestBlock: + example1 = ( + "def f():\n" # + " return 1\n" + "\n" + "\n" + "b = f'''\n" + "{b =} f'''\n" + "# comment\n" + "with d: # comment\n" + " pass" + ) + + def test_parse_python_block(self): + block = parse(self.example1) + assert "".join(block.raw()) == self.example1 From c62059d306e2e953cb5ca783d2e626dfb706f851 Mon Sep 17 00:00:00 2001 From: hwrn Date: Sun, 5 Apr 2026 16:30:39 +0800 Subject: [PATCH 25/53] refactor: update TokenIterator and LogicalLine for improved line handling --- snakefmt/blocken.py | 511 ++++++++++++++++++++++-------------------- tests/test_blocken.py | 72 +++++- 2 files changed, 335 insertions(+), 248 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 0bcc57c..dcd7a8e 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,7 +1,7 @@ import sys import tokenize from abc import ABC, abstractmethod -from typing import Callable, Iterable, Iterator, NamedTuple, Optional +from typing import Callable, Iterator, NamedTuple, Optional from tokenize import TokenInfo @@ -12,11 +12,10 @@ else: is_fstring_start = lambda token: token.type == tokenize.FSTRING_START - def consume_fstring(tokens: TokenIterator): + def consume_fstring(tokens: Iterator[TokenInfo]): finished: list[TokenInfo] = [] isin_fstring = 1 - while True: - token = next(tokens) + for token in tokens: finished.append(token) if token.type == tokenize.FSTRING_START: isin_fstring += 1 @@ -27,7 +26,7 @@ def consume_fstring(tokens: TokenIterator): return finished -def extract_indent(token: TokenInfo) -> str: +def extract_deindents(token: TokenInfo) -> str: line = token.line return line[: len(line) - len(line.lstrip())] @@ -47,29 +46,7 @@ def __iter__(self): return self def next_new_line(self): - """Returns contents of a entire logical lines (including continued lines), - also include indent tokens before it. - - the tokens yield like: - - [NL/COMMENT_LINE] -> [INDENT] -> (real content tokens) -> NEWLINE -> (repeat) - """ - head_empty_lines: list[TokenInfo] = [] - indents: list[TokenInfo] = [] - contents: list[TokenInfo] = [] - while True: - token = next(self) - if token.type == tokenize.NEWLINE or token.type == tokenize.ENDMARKER: - return head_empty_lines, indents, contents, token - elif not (contents or indents) and ( - token.type == tokenize.NL or token.type == tokenize.COMMENT - ): - head_empty_lines.append(token) - elif token.type == tokenize.INDENT or token.type == tokenize.DEDENT: - assert not contents, "Never expect indent after any content" - indents.append(token) - else: - contents.append(token) + return LogicalLine.from_token(self) def next_component(self): """Returns the next component, should not break string/bracket pairs""" @@ -98,69 +75,70 @@ def next_block(self): it could be INDEDT -> [any content] -> DEDENT, or [any content] -> DEDENT """ - block_contents: list[TokenInfo] = [] - head_empty_lines, indents, contents, token = self.next_new_line() - assert not indents or ( - [i.type for i in indents] == [tokenize.INDENT] - ), f"Unexpected indents {indents!r}" - assert contents, "Unexpected empty line" - block_contents.extend(head_empty_lines + indents + contents + [token]) - indent_level = 1 + line = self.next_new_line() + assert line.deindelta >= 0, "Unexpected DEDENT at the beginning of a block" + assert line.body, "Unexpected empty line at the beginning of a block" + lines = [line] + deindelta = 1 while True: # read entire line, dedent if needed - head_empty_lines, indents, contents, token = self.next_new_line() - if indents: - if [i.type for i in indents] == [tokenize.INDENT]: - indent_level += 1 - else: - assert {i.type for i in indents} == { - tokenize.DEDENT - }, f"Unexpected indents {indents!r}" - indent_level -= len(indents) - if indent_level <= 0: - # now it is used to represent `DEDENTs to keep` - # e.g. indent_level=1, 2 DEDENTs -> went 1 too deep -> keep 1 - indent_level += len(indents) - self.denext( - token, - *reversed(contents), - *reversed(indents[indent_level:]), - ) - break - if token.type == tokenize.ENDMARKER and indent_level == 1: - self.denext(token, *reversed(contents)) + line = self.next_new_line() + deindelta += line.deindelta + if deindelta <= 0: + deindelta -= line.deindelta + break + elif line.end.type == tokenize.ENDMARKER: + assert deindelta == 1 break - block_contents.extend(head_empty_lines + indents + contents + [token]) + lines.append(line) # there must be somewhere a DEDENT token to end the block, otherwise raise from __next__ # now check comments - indent = extract_indent(block_contents[0]) - block_contents.extend(self.dedent_tail_noncoding(head_empty_lines, indent)) - block_contents.extend(indents[:indent_level]) - return block_contents - - def dedent_tail_noncoding(self, tokens: list[TokenInfo], block_indent: str): - """Call at the end of a block, - split comments belong to this block from those belong to parent blocks, - and reorder . - Dedent the tail_noncoding tokens of a block, and return the dedented tokens. - The indent level of the tail_noncoding tokens should be the same as the block_indent. + indent = extract_deindents(lines[0].body[0]) + tail_noncoding = self.denext_by_indent(line, indent, deindelta) + return lines, tail_noncoding + + def denext_by_indent(self, line: LogicalLine, indent: str, deindelta=1): + """Call when a block is ended by a DEDENT token, + to split comments belong to this block from those belong to parent blocks, + and reorder tokens so that the next block can be parsed correctly. + + Parameters: + - line: the line after the block, with DEDENT out of the block + - indent: the indent string of the ending block, + used to determine the belongness of comments + - deindelta: the number of DEDENT tokens to pop, + should be >1 if the block ends at deeper indent levels + + Return: the head_noncoding tokens belongs to the ending block + according to indents: + - if block_indent <= extract_deindents(comments): + - this COMMENT belongs to this block + - else: afterwards, all COMMENT belongs to parent (or grand-parents) block + - all NL before this COMMENT belongs to this block - Should control tail_noncoding of the block: - - all NL belongs to this block - - if block_indent <= extract_indent(comments): - - this COMMENT belongs to this block - - else: afterwards, all COMMENT belongs to parent (or grand-parents) block + Dedent the tail_noncoding tokens of a block, and return the dedented tokens. + The indent level of the tail_noncoding tokens should be the same (or deeper) + as the block_indent. """ - if not tokens: - return [] - for i, token in enumerate(tokens): + head, dedents, body, end = line + self.denext(end, *reversed(body), *reversed(dedents[1:])) + if body: + assert not body[0].line.startswith(indent), ( + f"indent of ending block(`{indent!r}`) should longer " + f"than the next line(`{body[0].line!r}`)" + ) + if not head: + return dedents[:deindelta] + for i, token in enumerate(head): if token.type == tokenize.COMMENT: - if not extract_indent(token).startswith(block_indent): + if not extract_deindents(token).startswith(indent): break else: assert token.type == tokenize.NL, f"Unexpected token {token!r}" - self.denext(*reversed(tokens[i:])) - return tokens[:i] + else: + i += 1 # == len(head), push all head tokens back + self.denext(*reversed(head[i:])) + return head[:i] + dedents[:deindelta] def __next__(self) -> TokenInfo: if self._buffered_tokens: @@ -191,11 +169,75 @@ def denext(self, *tokens: TokenInfo) -> None: self._buffered_tokens.extend(tokens) -PYTHON_INDENT_KEYWORDS = { - i - for j in ("if elif else", "for while", "try except finally", "with") - for i in j.split() -} +class LogicalLine(NamedTuple): + head_noncoding: list[TokenInfo] + deindents: list[TokenInfo] + body: list[TokenInfo] + end: TokenInfo + + @property + def end_op(self): + body_size = len(self.body) + if body_size < 2: # single op line make no sense + return None + last_token = self.body[-1] + if last_token.type == tokenize.COMMENT: + last_token = self.body[-2] + if last_token.type != tokenize.OP: + return None + return last_token.string + + @property + def deindelta(self): + if not self.deindents: + return 0 + if [i.type for i in self.deindents] == [tokenize.INDENT]: + return 1 + assert {i.type for i in self.deindents} == {tokenize.DEDENT} + return -len(self.deindents) + + @property + def linestrs(self): + if not self.head_noncoding and self.body: + if self.body[0].start[0] == self.end.end[0]: + return [self.body[0].line] + return tokens2linestrs(iter(self.iter)) + + @property + def iter(self): + yield from self.head_noncoding + yield from self.deindents + yield from self.body + yield self.end + + @classmethod + def from_token(cls, tokens: Iterator[TokenInfo]): + """Returns contents of a entire logical lines (including continued lines), + also include deindent tokens before it. + + the tokens yield like: + + [NL/COMMENT_LINE] -> [indeents] -> (real content tokens) -> NEWLINE -> (repeat) + or + [NL/COMMENT_LINE] -> [DEDENT] -> () -> ENDMARKER + """ + + head_empty_lines: list[TokenInfo] = [] + deindents: list[TokenInfo] = [] + contents: list[TokenInfo] = [] + for token in tokens: + if token.type == tokenize.NEWLINE or token.type == tokenize.ENDMARKER: + break + elif not (contents or deindents) and ( + token.type == tokenize.NL or token.type == tokenize.COMMENT + ): + head_empty_lines.append(token) + elif token.type == tokenize.INDENT or token.type == tokenize.DEDENT: + assert not contents, "Never expect deindents after any content" + deindents.append(token) + else: + contents.append(token) + return cls(head_empty_lines, deindents, contents, token) def split_token_lines(token: TokenInfo): @@ -216,13 +258,17 @@ def split_token_lines(token: TokenInfo): ) -def block_lines(tokens: Iterator[TokenInfo]): +def tokens2linestrs(tokens: Iterator[TokenInfo]): + """Convert a sequence of tokens into a list of strings, one for each line. + ignore deindents (may be reorganized from next few lines) + """ + lines: dict[int, str] = {} # Lines that are interior to a multiline token (string / f-string body). # Their content must not be reindented. string_interior_lines: set[int] = set() for token in tokens: - if not_indent(token) and token.end[0] not in lines: + if not_deindent(token) and token.end[0] not in lines: lines.update(split_token_lines(token)) if token.start[0] != token.end[0]: string_interior_lines.update( @@ -239,7 +285,7 @@ def block_lines(tokens: Iterator[TokenInfo]): return newlines -def not_indent(token: TokenInfo) -> bool: +def not_deindent(token: TokenInfo) -> bool: return token.type != tokenize.INDENT and token.type != tokenize.DEDENT @@ -301,67 +347,59 @@ def func(): # block 2 ) """ - __slots__ = ( - "indent_level", - "head_noncoding", - "head_tokens", - "sub_blocks", - "tail_noncoding", - ) + __slots__ = ("deindent_level", "head_lines", "body_blocks", "tail_noncoding") def __init__( - self, tokens: TokenIterator, indent_level: int, head_tokens: list[TokenInfo] - ) -> None: - self.sub_blocks: list["Block"] = [] - self.head_noncoding: list[TokenInfo] = [] + self, + deindent_level: int, + tokens: TokenIterator, + lines: list[LogicalLine] | None = None, + ): + self.deindent_level = deindent_level + self.head_lines = [] if lines is None else lines + self.body_blocks: list[Block] = [] self.tail_noncoding: list[TokenInfo] = [] - self.head_tokens = head_tokens - self.indent_level = indent_level self.consume(tokens) def extend_tail_noncoding(self, tokens: list[TokenInfo]): self.tail_noncoding.extend(tokens) return [] - def extend_head_noncoding(self, tokens: list[TokenInfo]): - """Test if the tokens are all non-coding, and if so, extend head_noncoding with them and return True. - Otherwise, return False and do not modify head_noncoding. - """ - assert not self.head_noncoding, "head_noncoding should be empty before extend" - if {i.type for i in tokens} <= {tokenize.NL, tokenize.COMMENT}: - self.head_noncoding = tokens - return True - return False - @abstractmethod def consume(self, tokens: TokenIterator) -> None: ... @property - def start_token(self): - if not self.head_tokens: - raise UnsupportedSyntax("Unexpected empty block") - return self.head_tokens[0] + def start_token(self) -> TokenInfo | None: + for line in self.head_lines: + if line.body: + return line.body[0] + for block in self.body_blocks: + token = block.start_token + if token: + return token + return None @property - def raw_indent(self) -> str: + def indent_str(self) -> str: "tell the raw indent of the block" assert self.start_token is not None, "start_token should be set after consume()" return self.start_token.line[: self.start_token.start[1]] - def block_lines(self): - return block_lines(iter(self.head_tokens)) + @property + def head_linestrs(self): + return [i for line in self.head_lines for i in line.linestrs] - def raw(self): + @property + def full_linestrs(self) -> list[str]: """return the code splited by lines, but should keep multiline-string or multiline-f-string complete, to make trimming and reformatting easier. Should and Only should be rewrite for pure python blocks. """ lines = ( - block_lines(filter(not_indent, self.head_noncoding)) - + self.block_lines() - + [line for block in self.sub_blocks for line in block.raw()] - + block_lines(filter(not_indent, self.tail_noncoding)) + self.head_linestrs + + [line for block in self.body_blocks for line in block.full_linestrs] + + tokens2linestrs(filter(not_deindent, self.tail_noncoding)) ) return lines @@ -382,7 +420,7 @@ def components(self) -> "Iterator[DocumentSymbol]": - enable `~~~~~~~.input` to the position of `input:` of `rule a` - enable `~~~~~~~~~~~~~.b` to the position of `b=` in `input:` of `rule a` """ - for block in self.sub_blocks: + for block in self.body_blocks: yield from block.components() @abstractmethod @@ -404,7 +442,7 @@ class DocumentSymbol(NamedTuple): class PythonBlock(Block): - """Hold `head_tokens` only, no tokens comments, no sub-blocks""" + """Hold `head_lines` and `tail_noncoding`, no `body_blocks`""" def consume(self, tokens): "Do nothing, win" @@ -420,6 +458,16 @@ def components(self): class ColonBlock(Block): + """ + Hold `head_lines`, `body_blocks`, `tail_noncoding` for: + "`subautomata` ...`:` [COMMENT]" <- headlines + `line` <- body_blocks[0] + [...] <- body_blocks[1:] + or + "`subautomata` ...`:` `inline`" <- headlines + body_blocks is empty + """ + @classmethod def _keyword(cls): return cls.__name__.lower() @@ -428,27 +476,17 @@ def _keyword(cls): def keyword(self) -> str: return self._keyword() - __slots__ = ("post_colon_coding",) - - def __init__(self, tokens, indent_level, head_tokens): - self.post_colon_coding: list[TokenInfo] = [] - super().__init__(tokens, indent_level, head_tokens) + @property + def colon_line(self): + assert self.head_lines, "ColonBlock should have head lines" + return self.head_lines[-1] def consume(self, tokens): """Consume tokens until the end of the block head line (the line with `:`)""" - token = next(tokens) - if token.type != tokenize.INDENT: - tokens.denext(token) - token_iter = TokenIterator("", iter(self.head_tokens)) - colon_index = 0 - while True: - token, *rest = token_iter.next_component() - if not rest and token.type == tokenize.OP and token.string == ":": - break - colon_index += 1 + len(rest) - self.post_colon_coding = self.head_tokens[colon_index + 1 :] - else: - self.consume_body(tokens) + if self.colon_line.end_op != ":": + # single line indent such as `else: pass` or `except: pass` + return + self.consume_body(tokens) @abstractmethod def consume_body(self, tokens: TokenIterator) -> None: ... @@ -458,9 +496,14 @@ def recognises(self, token: TokenInfo): class FunctionClassBlock(ColonBlock): + """A block starting with `def` or `class`, and only has a single body PythonBlock + Also contain heading decorators (`@` lines) + """ + def consume_body(self, tokens): - contents = tokens.next_block() - self.sub_blocks.append(PythonBlock(tokens, self.indent_level + 1, contents)) + lines, tail_noncoding = tokens.next_block() + self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) + self.extend_tail_noncoding(tail_noncoding) def formatted(self): raise NotImplementedError @@ -477,8 +520,9 @@ def compilation(self): class IfForTryWithBlock(ColonBlock): def consume_body(self, tokens): """Consume tokens until the end of the block head line (the line with `:`)""" - global_block = GlobalBlock(tokens, self.indent_level + 1, []) - self.sub_blocks.append(global_block) + global_block = GlobalBlock(self.deindent_level + 1, tokens, []) + self.body_blocks.extend(global_block.body_blocks) + self.extend_tail_noncoding(global_block.tail_noncoding) def formatted(self): raise NotImplementedError @@ -487,6 +531,19 @@ def compilation(self): raise NotImplementedError +class UnknownIndentBlock(IfForTryWithBlock): + """Although I cannot imadge why an INDENT occurs + without the control of existing colon keywords, but just in case, + I will treat the contents as a global block + """ + + +PYTHON_INDENT_KEYWORDS = { + i + for j in ("if elif else", "for while", "try except finally", "with") + for i in j.split() +} + if_for_try_with_blocks: dict[str, type[IfForTryWithBlock]] = { i.lower(): type(i.capitalize(), (IfForTryWithBlock,), {}) for i in PYTHON_INDENT_KEYWORDS @@ -503,10 +560,10 @@ class SnakemakeBlock(ColonBlock): def components(self): this_symbol = DocumentSymbol( name=self.name, - detail="\n".join(i.rstrip() for i in self.block_lines()).strip("\n"), + detail="\n".join(i.rstrip() for i in self.head_linestrs).strip("\n"), symbol_kind=self._keyword(), - position_start=self.start_token.start, - position_end=self.head_tokens[-1].end, + position_start=self.colon_line.body[0].start, + position_end=self.colon_line.body[-1].end, block=self, ) yield this_symbol @@ -515,89 +572,81 @@ def components(self): global_snakemake_blocks: dict[str, type[SnakemakeBlock]] = {} -class CommentBlock(Block): ... +class GlobalBlock(Block): + """Hold `body_blocks` only, no `head_lines` nor `tail_noncoding` + all blocks in `body_blocks` should in the + same deindent level as GlobalBlock itself + so tail_noncoding always updated to the last body_block + """ -class GlobalBlock(Block): subautomata = ( function_class_blocks | if_for_try_with_blocks | global_snakemake_blocks ) def consume(self, tokens): - """pass through all tokens until the next indenting keyword, - and check if there is any non-comment content. + """Split all lines of same indent into plain Python blocks and indent blocks, + until the end of file or DEDENT out. + + - select subautomata to consume indent blocks + - denext_by_indent when DEDENT out """ - plain_python_tokens: list[TokenInfo] = [] - end_token: Optional[TokenInfo] = None - block_depth = 0 + + plain_python_lines: list[LogicalLine] = [] + tail_noncoding: list[TokenInfo] = [] indent_str = "[TBD]" - while not end_token or end_token.type != tokenize.ENDMARKER: - head_empty_lines, indents_, contents_, end_token = tokens.next_new_line() - if indents_: - if indents_[0].type == tokenize.INDENT: - assert len(indents_) == 1, f"Unexpected INDENTs {indents_!r}" - # there should be only one INDENT token at the beginning of the block - if block_depth == 0 and indent_str == "[TBD]": - indent_str = extract_indent(indents_[0]) - else: - block_depth += 1 - else: - assert {t.type for t in indents_} == { - tokenize.DEDENT - }, f"Unexpected DEDENTs {indents_!r}" - if block_depth: - block_depth -= 1 - else: - # get out of the block - tokens.denext( - end_token, - *reversed(contents_), - *reversed(indents_[1:]), - ) - head_empty_ = iter(head_empty_lines) - for token in head_empty_: - if token.type == tokenize.COMMENT: - if extract_indent(token).startswith(indent_str): - self.tail_noncoding.append(token) - else: - break - else: - self.tail_noncoding.append(token) - head_empty_lines1 = list(head_empty_) - tokens.denext(*reversed(list(head_empty_))) - head_empty_lines = head_empty_lines1 - break - had_plain_python = len(plain_python_tokens) - if head_empty_lines: - if self.sub_blocks and not plain_python_tokens: - plain_python_tokens = self.sub_blocks[-1].extend_tail_noncoding( - head_empty_lines - ) - else: - plain_python_tokens.extend(head_empty_lines) - if contents_: - token = contents_[0] - if token.type == tokenize.NAME and token.string in self.subautomata: - indent_level = self.indent_level + block_depth + 1 - colon_block = self.subautomata[token.string]( - tokens, indent_level, [*contents_, end_token] - ) - if colon_block.extend_head_noncoding(head_empty_lines): - plain_python_tokens = plain_python_tokens[:had_plain_python] - if plain_python_tokens: - self.sub_blocks.append( - PythonBlock(tokens, self.indent_level, plain_python_tokens) - ) - self.sub_blocks.append(colon_block) - plain_python_tokens = [] - else: - plain_python_tokens.extend((*contents_, end_token)) + + def append_sub(block_type: type[ColonBlock], header_lines: list[LogicalLine]): + if plain_python_lines: + self.body_blocks.append( + PythonBlock(self.deindent_level, tokens, list(plain_python_lines)) + ) + plain_python_lines.clear() + self.body_blocks.append( + block_type(self.deindent_level, tokens, header_lines) + ) + + while True: + line = tokens.next_new_line() + if line.deindelta > 0 and indent_str != "[TBD]": + tokens.denext(*reversed(list(line.iter))) + assert plain_python_lines, "Unexpected INDENT without any content" + header_line = plain_python_lines.pop() + append_sub(UnknownIndentBlock, [header_line]) + continue + elif line.deindelta < 0: + assert indent_str != "[TBD]" + tail_noncoding = tokens.denext_by_indent(line, indent_str, 1) + break + elif line.end.type == tokenize.ENDMARKER: + plain_python_lines.append( + LogicalLine(line.head_noncoding, [], [], line.end) + ) + self.body_blocks.append( + PythonBlock(self.deindent_level, tokens, plain_python_lines) + ) + plain_python_lines = [] + break else: - plain_python_tokens.append(end_token) - if plain_python_tokens: - self.sub_blocks.append( - PythonBlock(tokens, self.indent_level, plain_python_tokens) + if indent_str == "[TBD]": + assert ( + line.body + ), "Unexpected empty line at the beginning of a block" + indent_str = extract_deindents(line.body[0]) + if ( + line.body[0].type == tokenize.NAME + and line.body[0].string in self.subautomata + ): + append_sub(self.subautomata[line.body[0].string], [line]) + else: + plain_python_lines.append(line) + if plain_python_lines: + self.body_blocks.append( + PythonBlock(self.deindent_level, tokens, plain_python_lines) ) + if tail_noncoding: + assert self.body_blocks + self.body_blocks[-1].extend_tail_noncoding(tail_noncoding) def formatted(self): raise NotImplementedError @@ -613,20 +662,4 @@ def parse(input: str | Callable[[], str], name: str = "") -> GlobalBlock ) else: tokens = tokenize.generate_tokens(input) - return GlobalBlock(TokenIterator(name, tokens), 0, []) - - -def token_indents_updated(token: TokenInfo, indents: list[str]) -> bool: - if token.type == tokenize.INDENT: - line = token.line - indent = line[: len(line) - len(line.lstrip())] - if indent not in indents: - indents.append(indent) - elif token.type == tokenize.DEDENT: - line = token.line - indent = line[: len(line) - len(line.lstrip())] - while indents and indents[-1] != indent: - indents.pop() - else: - return False - return True + return GlobalBlock(0, TokenIterator(name, tokens), []) diff --git a/tests/test_blocken.py b/tests/test_blocken.py index dd57d40..6f15873 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -1,6 +1,10 @@ import pytest from snakefmt.blocken import ( + FunctionClassBlock, + GlobalBlock, + IfForTryWithBlock, + PythonBlock, consume_fstring, TokenIterator, tokenize, @@ -200,27 +204,33 @@ def test_next_block(self): tokens = generate_tokens(self.example3) assert [i for i, t in enumerate(tokens) if t.type == tokenize.INDENT] == [6, 15] # from the first line to the last content line - contents = TokenIterator("", iter(tokens[3:])).next_block() + lines, tail_noncoding = TokenIterator("", iter(tokens[3:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding assert contents[0].line == "with a as b:\n" - assert contents == tokens[3:][:34] - assert contents[-1].type == tokenize.NEWLINE - assert contents[-1].line == "7# 7\n" + assert contents == tokens[3:][:35] + assert contents[-1].type == tokenize.NL + assert contents[-2].type == tokenize.NEWLINE + assert contents[-2].line == "7# 7\n" # from the second line, to the last line before # ` # 5\n`, whose indent out of the block - contents = contents_ = TokenIterator("", iter(tokens[6:])).next_block() + lines, tail_noncoding = TokenIterator("", iter(tokens[6:])).next_block() + contents = contents_ = [t for line in lines for t in line.iter] + tail_noncoding assert contents[0].line == " b\n" and contents[0].type == tokenize.INDENT assert contents == tokens[6:][:22] + tokens[32:][:2] assert {t.type for t in contents[-2:]} == {tokenize.DEDENT} assert contents[:-2][-1].line == " \n" # even skip the heading indent, block ends at the same line - contents = TokenIterator("", iter(tokens[7:])).next_block() + lines, tail_noncoding = TokenIterator("", iter(tokens[7:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding assert contents == contents_[1:] # so does the COMMENT line - contents = TokenIterator("", iter(tokens[9:])).next_block() + lines, tail_noncoding = TokenIterator("", iter(tokens[9:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding assert contents[0].line == " # 0\n" and contents[0].type == tokenize.COMMENT assert contents == contents_[3:] # enter the third block: exit before ` # 3\n` with 1 DEDENT only - contents = TokenIterator("", iter(tokens[15:])).next_block() + lines, tail_noncoding = TokenIterator("", iter(tokens[15:])).next_block() + contents = [t for line in lines for t in line.iter] + tail_noncoding assert contents[0].line == " d\n" and tokens[14].type == tokenize.NEWLINE assert contents == tokens[15:][:8] + tokens[32:][:1] assert [t.type for t in contents[-4:]] == [ @@ -245,4 +255,48 @@ class TestBlock: def test_parse_python_block(self): block = parse(self.example1) - assert "".join(block.raw()) == self.example1 + assert "".join(block.full_linestrs) == self.example1 + assert isinstance(block, GlobalBlock) + assert not block.head_lines + assert not block.tail_noncoding + assert ( + {block.deindent_level} + == {i.deindent_level for i in block.body_blocks} + == {0} + ) + assert ["".join(i.full_linestrs) for i in block.body_blocks] == [ + "def f():\n return 1\n\n\n", + "b = f'''\n{b =} f'''\n", + "# comment\nwith d: # comment\n pass", + "", + ] + fun1 = block.body_blocks[0] + assert isinstance(fun1, FunctionClassBlock) + assert [i.string for i in fun1.colon_line.body] == ["def", "f", "(", ")", ":"] + assert [tuple(i) for i in fun1.tail_noncoding] == [ + (tokenize.NL, "\n", (3, 0), (3, 1), "\n"), + (tokenize.NL, "\n", (4, 0), (4, 1), "\n"), + (tokenize.DEDENT, "", (5, 0), (5, 0), "b = f'''\n"), + ] + assert ["".join(i.full_linestrs) for i in fun1.body_blocks] == [ + " return 1\n" + ] + fun11 = fun1.body_blocks[0] + assert isinstance(fun11, PythonBlock) + assert [line.linestrs for line in fun11.head_lines] == [[" return 1\n"]] + assert not fun11.body_blocks + assert not fun11.tail_noncoding + if3 = block.body_blocks[2] + assert isinstance(if3, IfForTryWithBlock) + assert [i.string for i in if3.colon_line.body] == [ + *("with", "d", ":", "# comment"), + ] + assert not if3.tail_noncoding + assert ["".join(i.full_linestrs) for i in if3.body_blocks] == [" pass"] + if31 = if3.body_blocks[0] + assert isinstance(if31, PythonBlock) + assert [line.linestrs for line in if31.head_lines] == [[" pass"]] + assert not if31.body_blocks + assert [tuple(i) for i in if31.tail_noncoding] == [ + (tokenize.DEDENT, "", (10, 0), (10, 0), "") + ] From fa0b253ab4c9d826be6a1cdc42de389fb93ed363 Mon Sep 17 00:00:00 2001 From: hwrn Date: Mon, 6 Apr 2026 01:02:59 +0800 Subject: [PATCH 26/53] feat: SnakemakeBlock --- snakefmt/blocken.py | 555 ++++++++++++++++++++++++++++++++++++------ tests/test_blocken.py | 47 ++++ 2 files changed, 521 insertions(+), 81 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index dcd7a8e..837cd30 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,8 +1,9 @@ import sys import tokenize from abc import ABC, abstractmethod -from typing import Callable, Iterator, NamedTuple, Optional +from typing import Callable, Iterator, NamedTuple, Optional, Mapping from tokenize import TokenInfo +from collections import OrderedDict from snakefmt.exceptions import UnsupportedSyntax @@ -121,8 +122,8 @@ def denext_by_indent(self, line: LogicalLine, indent: str, deindelta=1): as the block_indent. """ head, dedents, body, end = line - self.denext(end, *reversed(body), *reversed(dedents[1:])) - if body: + self.denext(end, *reversed(body), *reversed(dedents[deindelta:])) + if body and indent: assert not body[0].line.startswith(indent), ( f"indent of ending block(`{indent!r}`) should longer " f"than the next line(`{body[0].line!r}`)" @@ -348,6 +349,8 @@ def func(): # block 2 """ __slots__ = ("deindent_level", "head_lines", "body_blocks", "tail_noncoding") + subautomata: Mapping[str, "type[ColonBlock]"] = {} + deprecated: Mapping[str, str] = {} def __init__( self, @@ -368,6 +371,84 @@ def extend_tail_noncoding(self, tokens: list[TokenInfo]): @abstractmethod def consume(self, tokens: TokenIterator) -> None: ... + def recognize(self, token: TokenInfo): + """Whether the block can be recognized by the first token of its head lines""" + if token.type == tokenize.NAME: + if token.string in self.subautomata: + return self.subautomata[token.string] + if token.string in self.deprecated: + raise UnsupportedSyntax( + f"Keyword {token.string!r} is deprecated, " + f"{self.deprecated[token.string]!r}." + ) + + def consume_subblocks(self, tokens: TokenIterator, ender_subblock=False): + """Split all lines of same indent into plain Python blocks and indent blocks, + until the end of file or DEDENT out. + + - select subautomata to consume indent blocks + - denext_by_indent when DEDENT out + + Used in GlobalBlock and SnakemakeKeywordBlock, to consume their body blocks. + """ + deindent_level = self.deindent_level + int(ender_subblock) + blocks: list[Block] = [] + + plain_python_lines: list[LogicalLine] = [] + tail_noncoding: list[TokenInfo] = [] + indent_str = "[TBD]" + + def append_sub(block_type: type[ColonBlock], header_lines: list[LogicalLine]): + if plain_python_lines: + blocks.append( + PythonBlock(deindent_level, tokens, list(plain_python_lines)) + ) + plain_python_lines.clear() + blocks.append(block_type(deindent_level, tokens, header_lines)) + + while True: + line = tokens.next_new_line() + if line.deindelta > 0 and indent_str != "[TBD]": + tokens.denext(*reversed(list(line.iter))) + assert plain_python_lines, "Unexpected INDENT without any content" + header_line = plain_python_lines.pop() + append_sub(UnknownIndentBlock, [header_line]) + continue + elif line.deindelta < 0: + assert indent_str and indent_str != "[TBD]" + tail_noncoding = tokens.denext_by_indent(line, indent_str, 1) + break + elif line.end.type == tokenize.ENDMARKER: + plain_python_lines.append( + LogicalLine(line.head_noncoding, [], [], line.end) + ) + blocks.append(PythonBlock(deindent_level, tokens, plain_python_lines)) + plain_python_lines = [] + break + else: + if indent_str == "[TBD]": + assert ( + line.body + ), "Unexpected empty line at the beginning of a block" + indent_str = extract_deindents(line.body[0]) + if block := self.recognize(line.body[0]): + append_sub(block, [line]) + elif line.body[0].string == "@": + headers = [line] + while True: + headers.append(tokens.next_new_line()) + if block := self.recognize(headers[-1].body[0]): + break + append_sub(block, headers) + else: + plain_python_lines.append(line) + if plain_python_lines: + blocks.append(PythonBlock(deindent_level, tokens, plain_python_lines)) + if tail_noncoding: + assert blocks + blocks[-1].extend_tail_noncoding(tail_noncoding) + return blocks + @property def start_token(self) -> TokenInfo | None: for line in self.head_lines: @@ -474,8 +555,15 @@ def _keyword(cls): @property def keyword(self) -> str: + """Used such as `yield f"workflow.{self.keyword}("`""" return self._keyword() + @property + def prior_colon(self): ... + + @property + def post_colon(self): ... + @property def colon_line(self): assert self.head_lines, "ColonBlock should have head lines" @@ -483,10 +571,9 @@ def colon_line(self): def consume(self, tokens): """Consume tokens until the end of the block head line (the line with `:`)""" - if self.colon_line.end_op != ":": - # single line indent such as `else: pass` or `except: pass` - return - self.consume_body(tokens) + if self.colon_line.end_op == ":": + self.consume_body(tokens) + # else: single line indent such as `else: pass` or `except: pass` @abstractmethod def consume_body(self, tokens: TokenIterator) -> None: ... @@ -519,10 +606,8 @@ def compilation(self): class IfForTryWithBlock(ColonBlock): def consume_body(self, tokens): - """Consume tokens until the end of the block head line (the line with `:`)""" - global_block = GlobalBlock(self.deindent_level + 1, tokens, []) - self.body_blocks.extend(global_block.body_blocks) - self.extend_tail_noncoding(global_block.tail_noncoding) + blocks = GlobalBlock(self.deindent_level + 1, tokens, []).body_blocks + self.body_blocks.extend(blocks) def formatted(self): raise NotImplementedError @@ -550,13 +635,10 @@ class UnknownIndentBlock(IfForTryWithBlock): } -class SnakemakeBlock(ColonBlock): +class NamedBlock(ColonBlock): __slots__ = ("name",) name: str - subautomata: dict[str, Block] = {} - deprecated: dict[str, str] = {} - def components(self): this_symbol = DocumentSymbol( name=self.name, @@ -569,7 +651,377 @@ def components(self): yield this_symbol -global_snakemake_blocks: dict[str, type[SnakemakeBlock]] = {} +class SnakemakeBlock(ColonBlock): + subautomata = {} + deprecated = {} + + def components(self) -> Iterator[DocumentSymbol]: + yield from [] + + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + + +class PythonArgumentsBlock(PythonBlock): + """Block inside snakemake directives, + such as `data.txt` in `input: \n "data.txt"` + + Only allow: + - simple expressions on the right, e.g. `"data.txt",` + - assignment with simple names on the left, e.g. `a = 1,` + - Specally, allow `*args` and `**kwargs` as normal function + + Enhancement could be done: accepth expressions without trailing comma, + because each expression is already splitted by lines, + and we can add a trailing comma only if needed. + If we want to support expressions without trailing comma, + cases where two lines can makesense without a comma between them + should be carefully considered, + e.g.: + input: + "data.txt" + "data2.txt" + params: + sth + (a, b) + Although in our view this is naturally two expressions, + the action do change with the proposed enhancement. + """ + + +class PythonOneLineArgument(PythonArgumentsBlock): + """Only allow simple expressions on the right""" + + +class PythonListArguments(PythonArgumentsBlock): + """Only allow simple expressions on the right, and the whole block should be a list""" + + +class PythonListDictArguments(PythonArgumentsBlock): + """Parsed as *args, **kwargs""" + + +class SnakemakeOneLineArgumentsBlock(SnakemakeBlock): + def consume_body(self, tokens): + lines, tail_noncoding = tokens.next_block() + self.body_blocks.append( + PythonOneLineArgument(self.deindent_level + 1, tokens, lines) + ) + self.extend_tail_noncoding(tail_noncoding) + + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + + +def init_block_register(): + def register_block(name: Optional[str] = None): + def decorator(type_: type[SnakemakeBlock]): + keyword = name or type_._keyword() + namespace[keyword] = type_ + return type_ + + return decorator + + namespace: OrderedDict[str, type[SnakemakeBlock]] = OrderedDict() + return namespace, register_block + + +global_snakemake_subautomata, _register = init_block_register() + + +@_register() +class Include(SnakemakeOneLineArgumentsBlock): ... + + +@_register() +class Workdir(SnakemakeOneLineArgumentsBlock): ... + + +@_register() +class Configfile(SnakemakeOneLineArgumentsBlock): ... + + +@_register("pepfile") +class Set_Pepfile(SnakemakeOneLineArgumentsBlock): ... + + +@_register() +class Pepschema(SnakemakeOneLineArgumentsBlock): ... + + +@_register() +class Report(SnakemakeOneLineArgumentsBlock): ... + + +@_register() +class Ruleorder(SnakemakeOneLineArgumentsBlock): ... + + +@_register("singularity") +@_register("container") +class Global_Container(SnakemakeOneLineArgumentsBlock): ... + + +@_register("containerized") +class Global_Containerized(SnakemakeOneLineArgumentsBlock): ... + + +@_register("conda") +class Global_Conda(SnakemakeOneLineArgumentsBlock): ... + + +class SnakemakeListArgumentsBlock(SnakemakeBlock): + def consume_body(self, tokens): + lines, tail_noncoding = tokens.next_block() + self.body_blocks.append( + PythonListArguments(self.deindent_level + 1, tokens, lines) + ) + self.extend_tail_noncoding(tail_noncoding) + + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + + +@_register("envvars") +class Register_Envvars(SnakemakeListArgumentsBlock): ... + + +@_register() +class Localrules(SnakemakeListArgumentsBlock): ... + + +@_register() +class InputFlags(SnakemakeListArgumentsBlock): ... + + +@_register() +class OutputFlags(SnakemakeListArgumentsBlock): ... + + +class SnakemakeListDictArgumentsBlock(SnakemakeBlock): + """Block of snakemake directives, such as `input:`, `output:`, etc. + The content is pure python. + """ + + def consume_body(self, tokens): + lines, tail_noncoding = tokens.next_block() + self.body_blocks.append( + PythonListDictArguments(self.deindent_level + 1, tokens, lines) + ) + self.extend_tail_noncoding(tail_noncoding) + + def formatted(self): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + + +@_register("wildcard_constraints") +class Global_Wildcard_Constraints(SnakemakeListDictArgumentsBlock): ... + + +@_register() +class Scattergather(SnakemakeListDictArgumentsBlock): ... + + +@_register("resource_scope") +class ResourceScope(SnakemakeListDictArgumentsBlock): ... + + +@_register("storage") +class Storage(SnakemakeListDictArgumentsBlock): ... + + +@_register("pathvars") +class Register_Pathvars(SnakemakeListDictArgumentsBlock): ... + + +class SnakemakeExecutableBlock(SnakemakeBlock): + """Block of snakemake directives, such as `run:`, `onstart:`, etc. + The content is pure python. + """ + + def consume_body(self, tokens): + lines, tail_noncoding = tokens.next_block() + self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) + self.extend_tail_noncoding(tail_noncoding) + + +@_register() +class OnStart(SnakemakeExecutableBlock): ... + + +@_register() +class OnSuccess(SnakemakeExecutableBlock): ... + + +@_register() +class OnError(SnakemakeExecutableBlock): ... + + +class SnakemakeKeywordBlock(SnakemakeBlock): + """Block of snakemake directives, such as `rule:`, `module:`, etc. + The contents are other snakemake blocks. + """ + + def consume_body(self, tokens): + blocks = self.consume_subblocks(tokens, ender_subblock=True) + if any(not isinstance(i, SnakemakeBlock) for i in blocks): + raise UnsupportedSyntax( + f"Unexpected content in {self.keyword} block: " + f"only snakemake blocks are allowed, but got {blocks}" + ) + self.body_blocks = blocks + + +@_register() +class Module(NamedBlock, SnakemakeKeywordBlock): + subautomata, _register = init_block_register() + + @_register() + class Name(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Snakefile(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Meta_Wrapper(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Skip_Validation(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Config(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Pathvars(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Prefix(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Replace_Prefix(SnakemakeOneLineArgumentsBlock): ... + + +@_register("use") +class UseRule(NamedBlock, SnakemakeKeywordBlock): + subautomata, _register = init_block_register() + + @_register() + class Name(SnakemakeOneLineArgumentsBlock): ... + + @_register("default_target") + class Default_Target_Rule(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Input(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Output(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Log(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Benchmark(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class RulePathvars(SnakemakeListDictArgumentsBlock): ... + + @_register("wildcard_constraints") + class Register_Wildcard_Constraints(SnakemakeListDictArgumentsBlock): ... + + @_register("cache") + class Cache_Rule(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Priority(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Retries(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Group(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class LocalRule(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Handover(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Shadow(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Conda(SnakemakeOneLineArgumentsBlock): ... + + @_register("singularity") + @_register() + class Container(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Containerized(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class EnvModules(SnakemakeListArgumentsBlock): ... + + @_register() + class Threads(SnakemakeOneLineArgumentsBlock): ... + + @_register() + class Resources(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Params(SnakemakeListDictArgumentsBlock): ... + + @_register() + class Message(SnakemakeOneLineArgumentsBlock): ... + + deprecated = {"version": "Use conda or container directive instead (see docs)."} + + +@_register() +class Rule(UseRule): + exec_subautomata, _register = init_block_register() + + @_register() + class Run(SnakemakeExecutableBlock): ... + + class AbstractCmd(SnakemakeOneLineArgumentsBlock, Run): ... + + @_register() + class Shell(AbstractCmd): ... + + @_register() + class Script(AbstractCmd): ... + + @_register() + class Notebook(Script): ... + + @_register() + class Wrapper(Script): ... + + @_register("template_engine") + class TemplateEngine(Script): ... + + @_register() + class CWL(Script): ... + + subautomata = {**UseRule.subautomata, **exec_subautomata} + + +@_register() +class Checkpoint(Rule): ... class GlobalBlock(Block): @@ -580,73 +1032,14 @@ class GlobalBlock(Block): so tail_noncoding always updated to the last body_block """ - subautomata = ( - function_class_blocks | if_for_try_with_blocks | global_snakemake_blocks - ) + subautomata = { + **function_class_blocks, + **if_for_try_with_blocks, + **global_snakemake_subautomata, + } def consume(self, tokens): - """Split all lines of same indent into plain Python blocks and indent blocks, - until the end of file or DEDENT out. - - - select subautomata to consume indent blocks - - denext_by_indent when DEDENT out - """ - - plain_python_lines: list[LogicalLine] = [] - tail_noncoding: list[TokenInfo] = [] - indent_str = "[TBD]" - - def append_sub(block_type: type[ColonBlock], header_lines: list[LogicalLine]): - if plain_python_lines: - self.body_blocks.append( - PythonBlock(self.deindent_level, tokens, list(plain_python_lines)) - ) - plain_python_lines.clear() - self.body_blocks.append( - block_type(self.deindent_level, tokens, header_lines) - ) - - while True: - line = tokens.next_new_line() - if line.deindelta > 0 and indent_str != "[TBD]": - tokens.denext(*reversed(list(line.iter))) - assert plain_python_lines, "Unexpected INDENT without any content" - header_line = plain_python_lines.pop() - append_sub(UnknownIndentBlock, [header_line]) - continue - elif line.deindelta < 0: - assert indent_str != "[TBD]" - tail_noncoding = tokens.denext_by_indent(line, indent_str, 1) - break - elif line.end.type == tokenize.ENDMARKER: - plain_python_lines.append( - LogicalLine(line.head_noncoding, [], [], line.end) - ) - self.body_blocks.append( - PythonBlock(self.deindent_level, tokens, plain_python_lines) - ) - plain_python_lines = [] - break - else: - if indent_str == "[TBD]": - assert ( - line.body - ), "Unexpected empty line at the beginning of a block" - indent_str = extract_deindents(line.body[0]) - if ( - line.body[0].type == tokenize.NAME - and line.body[0].string in self.subautomata - ): - append_sub(self.subautomata[line.body[0].string], [line]) - else: - plain_python_lines.append(line) - if plain_python_lines: - self.body_blocks.append( - PythonBlock(self.deindent_level, tokens, plain_python_lines) - ) - if tail_noncoding: - assert self.body_blocks - self.body_blocks[-1].extend_tail_noncoding(tail_noncoding) + self.body_blocks = self.consume_subblocks(tokens) def formatted(self): raise NotImplementedError diff --git a/tests/test_blocken.py b/tests/test_blocken.py index 6f15873..ce5fd7d 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -300,3 +300,50 @@ def test_parse_python_block(self): assert [tuple(i) for i in if31.tail_noncoding] == [ (tokenize.DEDENT, "", (10, 0), (10, 0), "") ] + + example2 = ( + "rule A:\n" # L1 + " input:\n" + " a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print(1)\n" + "\n" + "\n" + "checkpoint:\n" + " name: 'check'\n" # L11 + " params:\n" + " c = '''\n" + " c = '''\n" + " conda: 'conda.yaml'\n" + " shell: 'touch d'\n" + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" # L21 + "\n" + "\n" + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + " sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n" + "\n" + "\n" + "Report:\n" + " 'report'\n" # L31 + ) + + def test_parse_snakefile(self): + block = parse(self.example2) + assert "".join(block.full_linestrs) == self.example2 + assert isinstance(block, GlobalBlock) + assert ["".join(i.full_linestrs) for i in block.body_blocks] == [ + "rule A:\n input:\n a = '1'\n output:\n 'b = 2'\n run:\n print(1)\n\n\n", + "checkpoint:\n name: 'check'\n params:\n c = '''\n c = '''\n conda: 'conda.yaml'\n shell: 'touch d'\n\n\n", + "onsuccess:\n for i in range(10):\n print(i)\n\n\n", + "wildcard_constraints:\n sth = r'a|b|c',\n sth2 = r'a|b|c',\n sth3 = r'a|b|c'\n\n\n", + "Report:\n 'report'\n", + "", + ] From 753c73382aa378bbdde226de185bea62f33f3a0b Mon Sep 17 00:00:00 2001 From: hwrn Date: Mon, 6 Apr 2026 15:00:00 +0800 Subject: [PATCH 27/53] feat: update block handling and formatting logic for pure python --- snakefmt/blocken.py | 186 +++++++++++++++++++++++++++++++++++------- tests/test_blocken.py | 63 +++++++++++++- 2 files changed, 217 insertions(+), 32 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 837cd30..b5f18f5 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,12 +1,16 @@ import sys import tokenize from abc import ABC, abstractmethod -from typing import Callable, Iterator, NamedTuple, Optional, Mapping +from typing import Callable, Iterator, Literal, NamedTuple, Optional, Mapping from tokenize import TokenInfo from collections import OrderedDict +import black.parsing + +from snakefmt.config import read_black_config, Mode from snakefmt.exceptions import UnsupportedSyntax +from snakefmt.types import TAB if sys.version_info < (3, 12): is_fstring_start = lambda token: False @@ -27,8 +31,7 @@ def consume_fstring(tokens: Iterator[TokenInfo]): return finished -def extract_deindents(token: TokenInfo) -> str: - line = token.line +def extract_line_indent(line: str) -> str: return line[: len(line) - len(line.lstrip())] @@ -94,7 +97,7 @@ def next_block(self): lines.append(line) # there must be somewhere a DEDENT token to end the block, otherwise raise from __next__ # now check comments - indent = extract_deindents(lines[0].body[0]) + indent = extract_line_indent(lines[0].body[0].line) tail_noncoding = self.denext_by_indent(line, indent, deindelta) return lines, tail_noncoding @@ -112,7 +115,7 @@ def denext_by_indent(self, line: LogicalLine, indent: str, deindelta=1): Return: the head_noncoding tokens belongs to the ending block according to indents: - - if block_indent <= extract_deindents(comments): + - if block_indent <= extract_line_indent(comments.line): - this COMMENT belongs to this block - else: afterwards, all COMMENT belongs to parent (or grand-parents) block - all NL before this COMMENT belongs to this block @@ -132,7 +135,7 @@ def denext_by_indent(self, line: LogicalLine, indent: str, deindelta=1): return dedents[:deindelta] for i, token in enumerate(head): if token.type == tokenize.COMMENT: - if not extract_deindents(token).startswith(indent): + if not extract_line_indent(token.line).startswith(indent): break else: assert token.type == tokenize.NL, f"Unexpected token {token!r}" @@ -290,6 +293,15 @@ def not_deindent(token: TokenInfo) -> bool: return token.type != tokenize.INDENT and token.type != tokenize.DEDENT +class FormatState(NamedTuple): + fmt_off: bool = False + sort_direcives: bool = False + + def update(self, *str): + # TODO: implement state update logic + return self._replace() + + class Block(ABC): """ A block can be: @@ -430,7 +442,7 @@ def append_sub(block_type: type[ColonBlock], header_lines: list[LogicalLine]): assert ( line.body ), "Unexpected empty line at the beginning of a block" - indent_str = extract_deindents(line.body[0]) + indent_str = extract_line_indent(line.body[0].line) if block := self.recognize(line.body[0]): append_sub(block, [line]) elif line.body[0].string == "@": @@ -480,7 +492,7 @@ def full_linestrs(self) -> list[str]: lines = ( self.head_linestrs + [line for block in self.body_blocks for line in block.full_linestrs] - + tokens2linestrs(filter(not_deindent, self.tail_noncoding)) + + tokens2linestrs(iter(self.tail_noncoding)) ) return lines @@ -505,11 +517,11 @@ def components(self) -> "Iterator[DocumentSymbol]": yield from block.components() @abstractmethod - def formatted(self) -> str: + def formatted(self, mode: Mode, state: FormatState) -> tuple[str, FormatState]: """return formatted code of the block""" @abstractmethod - def compilation(self) -> str: + def compilation(self): """return pure python code compiled from the block, without snakemake keywords and comments""" @@ -522,14 +534,63 @@ class DocumentSymbol(NamedTuple): block: "Block" +def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] = ""): + """Format a string using Black formatter. + + if indent: + prefix = make series of `{' ' * i}if 1:\\n` to increase indent level + format(prefix + string) + remove first `indent` lines + if partial == ":": + safe_indent = longest(prefix spacing) + format(string + f"\\n{safe_indent} pass") + remove the last line + if partial == "(": + format("f(" + string + ")") + if string.startswith("f(\\n"): + remove the first line and the last line + else: + remove first three characters and the last character + """ + prefix = "" + for i in range(indent): + prefix += " " * i + "if 1:\n" + if partial == ":": + # for block such as if/else/... + safe_indent = max(extract_line_indent(line) for line in prefix.splitlines()) + string = raw + f"\n{safe_indent} pass" + elif partial == "(": + string = "f(\n" + raw + "\n)" + else: + string = raw + try: + fmted = black.format_str(prefix + string, mode=mode) + except black.parsing.InvalidInput as e: + raise e + if indent: + fix = fmted.split("\n", indent)[-1] + else: + fix = fmted + if partial == ":": + fix = fix.rstrip().rsplit("\n", 1)[0] + "\n" + elif partial == "(": + if string.startswith("f(\n"): + fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" + else: + fix = fix[2:-1] + return fix + + class PythonBlock(Block): """Hold `head_lines` and `tail_noncoding`, no `body_blocks`""" def consume(self, tokens): "Do nothing, win" - def formatted(self): - raise NotImplementedError + def formatted(self, mode, state): + raw = "".join(self.full_linestrs) + formatted = format_black(raw, mode, self.deindent_level) + return formatted, state def compilation(self): raise NotImplementedError @@ -582,9 +643,14 @@ def recognises(self, token: TokenInfo): return token.type == tokenize.NAME and token.string == self.keyword -class FunctionClassBlock(ColonBlock): +class NoSnakemakeBlock(ColonBlock): """A block starting with `def` or `class`, and only has a single body PythonBlock Also contain heading decorators (`@` lines) + + Also, snakemake keywords should not be used in `async` blocks + + TODO: although not recommended, snakemake keywords can be used in function/class body + Should handle that cases in the future """ def consume_body(self, tokens): @@ -592,15 +658,17 @@ def consume_body(self, tokens): self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) self.extend_tail_noncoding(tail_noncoding) - def formatted(self): - raise NotImplementedError + def formatted(self, mode, state): + raw = "".join(self.full_linestrs) + formatted = format_black(raw, mode, self.deindent_level) + return formatted, state def compilation(self): raise NotImplementedError -function_class_blocks: dict[str, type[FunctionClassBlock]] = { - i.lower(): type(i.capitalize(), (FunctionClassBlock,), {}) for i in ("def", "class") +function_class_blocks: dict[str, type[NoSnakemakeBlock]] = { + i.lower(): type(i.capitalize(), (NoSnakemakeBlock,), {}) for i in ("def", "class") } @@ -609,8 +677,28 @@ def consume_body(self, tokens): blocks = GlobalBlock(self.deindent_level + 1, tokens, []).body_blocks self.body_blocks.extend(blocks) - def formatted(self): - raise NotImplementedError + def formatted(self, mode, state): + formatted = [] + if self.body_blocks: + raw = "".join(self.full_linestrs) + return format_black(raw, mode, self.deindent_level), state + raw_head = "".join(self.head_linestrs) + head = format_black(raw_head, mode, self.deindent_level, partial=":") + formatted.append(head) + state_ = state + if isinstance( + self.body_blocks[0], + (NoSnakemakeBlock, NamedBlock, SnakemakeExecutableBlock), + ): + formatted.append("\n") + for block in self.body_blocks: + block_formatted, state_ = block.formatted(mode, state_) + formatted.append(block_formatted) + formatted.append("\n") + formatted.pop() # remove the last "\n" + for comment in tokens2linestrs(iter(self.tail_noncoding)): + formatted.append(TAB * self.deindent_level + comment.lstrip()) + return "".join(formatted), state_ def compilation(self): raise NotImplementedError @@ -635,6 +723,39 @@ class UnknownIndentBlock(IfForTryWithBlock): } +class CaseBlock(IfForTryWithBlock): ... + + +class MatchBlock(ColonBlock): + subautomata = {"case": CaseBlock} + + def consume_body(self, tokens): + blocks = self.consume_subblocks(tokens, ender_subblock=True) + if any(not isinstance(i, CaseBlock) for i in blocks): + raise UnsupportedSyntax( + f"Unexpected content in {self.keyword} block: " + f"only `Case` keyword is allowed, but got {blocks}" + ) + self.body_blocks = blocks + + def formatted(self, mode, state): + raise NotImplementedError + + def compilation(self): + raise NotImplementedError + + +class AsyncBlock(NoSnakemakeBlock): ... + + +python_subautomata: dict[str, type[ColonBlock]] = { + **function_class_blocks, + **if_for_try_with_blocks, + "match": MatchBlock, + "async": AsyncBlock, +} + + class NamedBlock(ColonBlock): __slots__ = ("name",) name: str @@ -658,7 +779,7 @@ class SnakemakeBlock(ColonBlock): def components(self) -> Iterator[DocumentSymbol]: yield from [] - def formatted(self): + def formatted(self, mode, state): raise NotImplementedError def compilation(self): @@ -712,7 +833,7 @@ def consume_body(self, tokens): ) self.extend_tail_noncoding(tail_noncoding) - def formatted(self): + def formatted(self, mode, state): raise NotImplementedError def compilation(self): @@ -784,7 +905,7 @@ def consume_body(self, tokens): ) self.extend_tail_noncoding(tail_noncoding) - def formatted(self): + def formatted(self, mode, state): raise NotImplementedError def compilation(self): @@ -819,7 +940,7 @@ def consume_body(self, tokens): ) self.extend_tail_noncoding(tail_noncoding) - def formatted(self): + def formatted(self, mode, state): raise NotImplementedError def compilation(self): @@ -1032,17 +1153,22 @@ class GlobalBlock(Block): so tail_noncoding always updated to the last body_block """ - subautomata = { - **function_class_blocks, - **if_for_try_with_blocks, - **global_snakemake_subautomata, - } + subautomata = {**python_subautomata, **global_snakemake_subautomata} def consume(self, tokens): self.body_blocks = self.consume_subblocks(tokens) - def formatted(self): - raise NotImplementedError + def formatted(self, mode, state): + formatted = [] + state_ = state + linesep = "\n" if self.deindent_level else "\n\n" + for block in self.body_blocks: + block_formatted, state_ = block.formatted(mode, state_) + formatted.append(block_formatted) + formatted.append(linesep) + if formatted: + formatted.pop() # remove the last "\n" + return "".join(formatted), state_ def compilation(self): raise NotImplementedError diff --git a/tests/test_blocken.py b/tests/test_blocken.py index ce5fd7d..c8793aa 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -1,7 +1,8 @@ import pytest from snakefmt.blocken import ( - FunctionClassBlock, + FormatState, + NoSnakemakeBlock, GlobalBlock, IfForTryWithBlock, PythonBlock, @@ -11,7 +12,9 @@ is_fstring_start, UnsupportedSyntax, parse, + black, ) +from snakefmt.config import read_black_config def generate_tokens(input: str): @@ -271,7 +274,7 @@ def test_parse_python_block(self): "", ] fun1 = block.body_blocks[0] - assert isinstance(fun1, FunctionClassBlock) + assert isinstance(fun1, NoSnakemakeBlock) assert [i.string for i in fun1.colon_line.body] == ["def", "f", "(", ")", ":"] assert [tuple(i) for i in fun1.tail_noncoding] == [ (tokenize.NL, "\n", (3, 0), (3, 1), "\n"), @@ -347,3 +350,59 @@ def test_parse_snakefile(self): "Report:\n 'report'\n", "", ] + + +class TestBlockFormat: + + example1 = ( + "\n" + "@decorator\n" + "\n" + "#def f(\n" + "def f(\n" + " a, b:int\n" + "):\n" # + " return 1\n" + "b = f'''\n" + "{b =} f'''\n" + " # comment\n" + "c = [i for j in k] if m else (\n" + " lambda: None\n" + " )\n" + ) + + mode = read_black_config(None) + state = FormatState() + + def test_format_python_block(self): + block = parse(self.example1) + # fun11.formatted(self.mode, self.state) + assert "".join(block.full_linestrs) == self.example1 + assert [i.full_linestrs for i in block.body_blocks] == [ + [ + "\n", + "@decorator\n", + "\n", + "#def f(\n", + "def f(\n", + " a, b:int\n", + "):\n", + " return 1\n", + ], + [ + "b = f'''\n{b =} f'''\n", + " # comment\n", + "c = [i for j in k] if m else (\n", + " lambda: None\n", + " )\n", + ], + ] + py2 = block.body_blocks[1] + assert len(py2.head_lines) == 3 + assert ( + py2.formatted(self.mode, self.state)[0] + == 'b = f"""\n{b =} f"""\n# comment\nc = [i for j in k] if m else (lambda: None)\n' + ) + assert block.formatted(self.mode, self.state)[0] == black.format_str( + self.example1, mode=self.mode + ) From 1d17ffb01f120eab17e87dea0ec25a0a3f95302a Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 7 Apr 2026 02:06:23 +0800 Subject: [PATCH 28/53] fix: support snakefile back --- snakefmt/blocken.py | 389 ++++++++++++++++++++++++++++++------------ tests/test_blocken.py | 110 +++++++++++- 2 files changed, 379 insertions(+), 120 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index b5f18f5..d08aa5c 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,7 +1,16 @@ import sys import tokenize from abc import ABC, abstractmethod -from typing import Callable, Iterator, Literal, NamedTuple, Optional, Mapping +from typing import ( + Any, + Callable, + Iterator, + Literal, + NamedTuple, + Optional, + Mapping, + TypeVar, +) from tokenize import TokenInfo from collections import OrderedDict import black.parsing @@ -162,6 +171,12 @@ def __next__(self) -> TokenInfo: self._last_token = token return token + @property + def rest(self): + while self._buffered_tokens: + yield self._buffered_tokens.pop() + yield from self._live_tokens + def denext(self, *tokens: TokenInfo) -> None: """.denext(a, b, c): next(token) will return c, then b, then a. pull back tokens so they can be pushed in the correct order when .next() @@ -557,15 +572,17 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] prefix += " " * i + "if 1:\n" if partial == ":": # for block such as if/else/... - safe_indent = max(extract_line_indent(line) for line in prefix.splitlines()) - string = raw + f"\n{safe_indent} pass" + safe_indent = max(extract_line_indent(line) for line in raw.splitlines()) + string = raw + f"{safe_indent} pass" elif partial == "(": - string = "f(\n" + raw + "\n)" + # Tb() effects equals to a entire new indent + string = " " * indent + "Tb(\n" + raw + "\n)" else: string = raw try: fmted = black.format_str(prefix + string, mode=mode) except black.parsing.InvalidInput as e: + breakpoint() raise e if indent: fix = fmted.split("\n", indent)[-1] @@ -574,10 +591,11 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] if partial == ":": fix = fix.rstrip().rsplit("\n", 1)[0] + "\n" elif partial == "(": - if string.startswith("f(\n"): + fix = fix.strip() + if fix.startswith("Tb(\n"): fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" else: - fix = fix[2:-1] + fix = TAB * (indent + 1) + fix[3:-1] + "\n" return fix @@ -589,6 +607,8 @@ def consume(self, tokens): def formatted(self, mode, state): raw = "".join(self.full_linestrs) + if not raw.strip(): + return "", state formatted = format_black(raw, mode, self.deindent_level) return formatted, state @@ -619,11 +639,19 @@ def keyword(self) -> str: """Used such as `yield f"workflow.{self.keyword}("`""" return self._keyword() - @property - def prior_colon(self): ... - - @property - def post_colon(self): ... + def split_colon_line(self): + token_iter = TokenIterator("", iter(self.head_lines[-1].iter)) + last_line_tokens = [] + while True: + component = token_iter.next_component() + if [(i.type, i.string) for i in component] == [(tokenize.OP, ":")]: + break + last_line_tokens.extend(component) + (colon_token,) = component + prior = tokens2linestrs(iter(last_line_tokens)) + prior[-1] = prior[-1][: colon_token.start[1]] + token_iter.denext(colon_token) + return prior, token_iter @property def colon_line(self): @@ -679,7 +707,7 @@ def consume_body(self, tokens): def formatted(self, mode, state): formatted = [] - if self.body_blocks: + if not self.body_blocks: raw = "".join(self.full_linestrs) return format_black(raw, mode, self.deindent_level), state raw_head = "".join(self.head_linestrs) @@ -697,7 +725,8 @@ def formatted(self, mode, state): formatted.append("\n") formatted.pop() # remove the last "\n" for comment in tokens2linestrs(iter(self.tail_noncoding)): - formatted.append(TAB * self.deindent_level + comment.lstrip()) + if comment.strip(): + formatted.append(TAB * self.deindent_level + comment.lstrip()) return "".join(formatted), state_ def compilation(self): @@ -780,7 +809,37 @@ def components(self) -> Iterator[DocumentSymbol]: yield from [] def formatted(self, mode, state): - raise NotImplementedError + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + for comment in tokens2linestrs(iter(self.tail_noncoding)): + if comment.strip(): + formatted.append(TAB * self.deindent_level + comment.lstrip()) + return "".join(formatted), state + + def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: + assert ( + len(self.head_lines) == 1 + ), "Snakemake keywords should only have one head line" + prior_colon, post_colon = self.split_colon_line() + (head,) = prior_colon + assert len(prior_colon) == 1, "Snakemake keywords should be single line" + components = head.strip().split() + formatted_head = TAB * self.deindent_level + " ".join(components) + ":" + if self.colon_line.end_op == ":": + # only a single line comment or empty is possible here, add directly + colon_token = next(post_colon) + post = tokens2linestrs(post_colon.rest) + post[0] = post[0][colon_token.end[1] :] + fake_str = f"if 1:" + "".join(post) + " ..." + fake_fmt = format_black(fake_str, mode).strip() + formatted_head += fake_fmt.split(":", 1)[1].rsplit("\n", 1)[0] + "\n" + return formatted_head, [] + else: + return formatted_head + "\n", list(post_colon.rest) + + @abstractmethod + def format_body(self, mode, state, post_colon: list[TokenInfo]) -> str: ... def compilation(self): raise NotImplementedError @@ -813,36 +872,155 @@ class PythonArgumentsBlock(PythonBlock): """ -class PythonOneLineArgument(PythonArgumentsBlock): - """Only allow simple expressions on the right""" +class PythonArguments(PythonArgumentsBlock): + """Parsed as *args, **kwargs""" + def formatted(self, mode, state): + """PythonArguments and its subclasses always at the terminal + of the snakemake keyword tree, + so returned state never used anymore + """ + assert not self.body_blocks, "PythonArguments should not have body blocks" + raw = "".join(self.head_linestrs) + if not self.head_lines[-1].end_op == ",": + raw += "\n," + tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) + raw += "".join(i for i in tail_noncoding if i.strip()) + formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") + return formatted, state -class PythonListArguments(PythonArgumentsBlock): + @classmethod + def format_post_colon( + cls, post_colon: list[TokenInfo], deindent_level: int, mode: Mode + ): + """If the params are in the same line as the keyword, + e.g. `input: "data.txt"`, + then self.body_blocks is empty, should take those from post_colon + """ + assert ( + post_colon and post_colon[-1].type == tokenize.NEWLINE + ), "Unexpected post_colon without a new line at the end" + colon_token = post_colon[0] + partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) + post = tokens2linestrs(iter(partial_line.body)) + post[0] = post[0][colon_token.end[1] :] + raw = "".join(post) + if not partial_line.end_op == ",": + raw += "\n," + formatted = format_black(raw, mode, deindent_level, partial="(") + return formatted + + +class PythonUnnamedArguments(PythonArguments): """Only allow simple expressions on the right, and the whole block should be a list""" -class PythonListDictArguments(PythonArgumentsBlock): - """Parsed as *args, **kwargs""" +class PythonOneLineArgument(PythonUnnamedArguments): + """Only allow simple expressions on the right""" + def formatted(self, mode, state): + """Only a single expression, trim the trailing comma""" + assert not self.body_blocks, "PythonArguments should not have body blocks" + raw = "".join(self.head_linestrs) + if self.head_lines[-1].end_op == ",": + last_line = self.head_lines[-1] + comma_token = ( + last_line.body[-2] + if last_line.body[-1].type == tokenize.COMMENT + else last_line.body[-1] + ) + comma_start = comma_token.start[1] - len(comma_token.line) + raw = raw[:comma_start] + raw[comma_start + 1 :] + tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) + raw += "".join(i for i in tail_noncoding if i.strip()) + formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") + return formatted, state + + @classmethod + def format_post_colon( + cls, post_colon: list[TokenInfo], deindent_level: int, mode: Mode + ): + assert ( + post_colon and post_colon[-1].type == tokenize.NEWLINE + ), "Unexpected post_colon without a new line at the end" + colon_token = post_colon[0] + partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) + post = tokens2linestrs(iter(partial_line.body)) + post[0] = post[0][colon_token.end[1] :] + raw = "".join(post) + if partial_line.end_op == ",": + comma_token = ( + partial_line.body[-2] + if partial_line.body[-1].type == tokenize.COMMENT + else partial_line.body[-1] + ) + comma_start = comma_token.start[1] - len(comma_token.line) + raw = raw[:comma_start] + raw[comma_start + 1 :] + formatted = format_black(raw, mode, deindent_level, partial="(") + return formatted + + +class SnakemakeArgumentsBlock(SnakemakeBlock): + """Block of snakemake directives, such as `input:`, `output:`, etc. + The content is pure python. + """ + + Argument = PythonArguments -class SnakemakeOneLineArgumentsBlock(SnakemakeBlock): def consume_body(self, tokens): lines, tail_noncoding = tokens.next_block() - self.body_blocks.append( - PythonOneLineArgument(self.deindent_level + 1, tokens, lines) - ) + self.body_blocks.append(self.Argument(self.deindent_level + 1, tokens, lines)) self.extend_tail_noncoding(tail_noncoding) - def formatted(self, mode, state): - raise NotImplementedError + def format_body(self, mode, state, post_colon) -> str: + """Format body as in the function call, + e.g. `input: "data.txt",` -> `input("data.txt")` + """ + if post_colon: + return self.Argument.format_post_colon( + post_colon, self.deindent_level, mode + ) + else: + (param_space,) = self.body_blocks + return param_space.formatted(mode, state)[0] def compilation(self): raise NotImplementedError +class SnakemakeUnnamedArgumentsBlock(SnakemakeArgumentsBlock): + Argument = PythonUnnamedArguments + + +class SnakemakeUnnamedArgumentBlock(SnakemakeUnnamedArgumentsBlock): + Argument = PythonOneLineArgument + + +class SnakemakeInlineArgumentBlock(SnakemakeUnnamedArgumentBlock): + + def formatted(self, mode, state): + """Try to merge the inline argument into the head line. + If the line is too long after merging, then keep them separate. + """ + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): + if formatted_prior.endswith(":\n") and "#" not in formatted_prior: + formatted_merge = formatted_prior[:-1] + " " + formatted_body.lstrip() + if len(formatted_merge) <= mode.line_length: + formatted = [formatted_merge] + for comment in tokens2linestrs(iter(self.tail_noncoding)): + if comment.strip(): + formatted.append(TAB * self.deindent_level + comment.lstrip()) + return "".join(formatted), state + + def init_block_register(): + T = TypeVar("T", bound=SnakemakeBlock) + def register_block(name: Optional[str] = None): - def decorator(type_: type[SnakemakeBlock]): + def decorator(type_: type[T]) -> type[T]: keyword = name or type_._keyword() namespace[keyword] = type_ return type_ @@ -857,114 +1035,80 @@ def decorator(type_: type[SnakemakeBlock]): @_register() -class Include(SnakemakeOneLineArgumentsBlock): ... +class Include(SnakemakeInlineArgumentBlock): ... @_register() -class Workdir(SnakemakeOneLineArgumentsBlock): ... +class Workdir(SnakemakeInlineArgumentBlock): ... @_register() -class Configfile(SnakemakeOneLineArgumentsBlock): ... +class Configfile(SnakemakeInlineArgumentBlock): ... @_register("pepfile") -class Set_Pepfile(SnakemakeOneLineArgumentsBlock): ... +class Set_Pepfile(SnakemakeInlineArgumentBlock): ... @_register() -class Pepschema(SnakemakeOneLineArgumentsBlock): ... +class Pepschema(SnakemakeInlineArgumentBlock): ... @_register() -class Report(SnakemakeOneLineArgumentsBlock): ... +class Report(SnakemakeInlineArgumentBlock): ... @_register() -class Ruleorder(SnakemakeOneLineArgumentsBlock): ... +class Ruleorder(SnakemakeInlineArgumentBlock): ... @_register("singularity") @_register("container") -class Global_Container(SnakemakeOneLineArgumentsBlock): ... +class Global_Container(SnakemakeInlineArgumentBlock): ... @_register("containerized") -class Global_Containerized(SnakemakeOneLineArgumentsBlock): ... +class Global_Containerized(SnakemakeInlineArgumentBlock): ... @_register("conda") -class Global_Conda(SnakemakeOneLineArgumentsBlock): ... - - -class SnakemakeListArgumentsBlock(SnakemakeBlock): - def consume_body(self, tokens): - lines, tail_noncoding = tokens.next_block() - self.body_blocks.append( - PythonListArguments(self.deindent_level + 1, tokens, lines) - ) - self.extend_tail_noncoding(tail_noncoding) - - def formatted(self, mode, state): - raise NotImplementedError - - def compilation(self): - raise NotImplementedError +class Global_Conda(SnakemakeInlineArgumentBlock): ... @_register("envvars") -class Register_Envvars(SnakemakeListArgumentsBlock): ... +class Register_Envvars(SnakemakeUnnamedArgumentsBlock): ... @_register() -class Localrules(SnakemakeListArgumentsBlock): ... +class Localrules(SnakemakeUnnamedArgumentsBlock): ... @_register() -class InputFlags(SnakemakeListArgumentsBlock): ... +class InputFlags(SnakemakeUnnamedArgumentsBlock): ... @_register() -class OutputFlags(SnakemakeListArgumentsBlock): ... - - -class SnakemakeListDictArgumentsBlock(SnakemakeBlock): - """Block of snakemake directives, such as `input:`, `output:`, etc. - The content is pure python. - """ - - def consume_body(self, tokens): - lines, tail_noncoding = tokens.next_block() - self.body_blocks.append( - PythonListDictArguments(self.deindent_level + 1, tokens, lines) - ) - self.extend_tail_noncoding(tail_noncoding) - - def formatted(self, mode, state): - raise NotImplementedError - - def compilation(self): - raise NotImplementedError +class OutputFlags(SnakemakeUnnamedArgumentsBlock): ... @_register("wildcard_constraints") -class Global_Wildcard_Constraints(SnakemakeListDictArgumentsBlock): ... +class Global_Wildcard_Constraints(SnakemakeArgumentsBlock): ... @_register() -class Scattergather(SnakemakeListDictArgumentsBlock): ... +class Scattergather(SnakemakeArgumentsBlock): ... @_register("resource_scope") -class ResourceScope(SnakemakeListDictArgumentsBlock): ... +class ResourceScope(SnakemakeArgumentsBlock): ... @_register("storage") -class Storage(SnakemakeListDictArgumentsBlock): ... +class Storage(SnakemakeArgumentsBlock): ... @_register("pathvars") -class Register_Pathvars(SnakemakeListDictArgumentsBlock): ... +class Register_Pathvars(SnakemakeArgumentsBlock): ... class SnakemakeExecutableBlock(SnakemakeBlock): @@ -977,6 +1121,15 @@ def consume_body(self, tokens): self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) self.extend_tail_noncoding(tail_noncoding) + def format_body(self, mode, state, post_colon): + if post_colon: + return PythonOneLineArgument.format_post_colon( + post_colon, self.deindent_level, mode + ) + else: + (param_space,) = self.body_blocks + return param_space.formatted(mode, state)[0] + @_register() class OnStart(SnakemakeExecutableBlock): ... @@ -1004,34 +1157,43 @@ def consume_body(self, tokens): ) self.body_blocks = blocks + def format_body(self, mode, state, post_colon): + assert not post_colon, "Invalid inline contents" + formatted = [] + for block in self.body_blocks: + block_formatted, state_ = block.formatted(mode, state) + formatted.append(block_formatted) + # no `\n` between + return "".join(formatted) + @_register() class Module(NamedBlock, SnakemakeKeywordBlock): subautomata, _register = init_block_register() @_register() - class Name(SnakemakeOneLineArgumentsBlock): ... + class Name(SnakemakeInlineArgumentBlock): ... @_register() - class Snakefile(SnakemakeOneLineArgumentsBlock): ... + class Snakefile(SnakemakeUnnamedArgumentBlock): ... @_register() - class Meta_Wrapper(SnakemakeOneLineArgumentsBlock): ... + class Meta_Wrapper(SnakemakeUnnamedArgumentBlock): ... @_register() - class Skip_Validation(SnakemakeOneLineArgumentsBlock): ... + class Skip_Validation(SnakemakeUnnamedArgumentBlock): ... @_register() - class Config(SnakemakeOneLineArgumentsBlock): ... + class Config(SnakemakeUnnamedArgumentBlock): ... @_register() - class Pathvars(SnakemakeListDictArgumentsBlock): ... + class Pathvars(SnakemakeArgumentsBlock): ... @_register() - class Prefix(SnakemakeOneLineArgumentsBlock): ... + class Prefix(SnakemakeUnnamedArgumentBlock): ... @_register() - class Replace_Prefix(SnakemakeOneLineArgumentsBlock): ... + class Replace_Prefix(SnakemakeUnnamedArgumentBlock): ... @_register("use") @@ -1039,74 +1201,74 @@ class UseRule(NamedBlock, SnakemakeKeywordBlock): subautomata, _register = init_block_register() @_register() - class Name(SnakemakeOneLineArgumentsBlock): ... + class Name(SnakemakeUnnamedArgumentBlock): ... @_register("default_target") - class Default_Target_Rule(SnakemakeOneLineArgumentsBlock): ... + class Default_Target_Rule(SnakemakeInlineArgumentBlock): ... @_register() - class Input(SnakemakeListDictArgumentsBlock): ... + class Input(SnakemakeArgumentsBlock): ... @_register() - class Output(SnakemakeListDictArgumentsBlock): ... + class Output(SnakemakeArgumentsBlock): ... @_register() - class Log(SnakemakeListDictArgumentsBlock): ... + class Log(SnakemakeArgumentsBlock): ... @_register() - class Benchmark(SnakemakeOneLineArgumentsBlock): ... + class Benchmark(SnakemakeUnnamedArgumentBlock): ... @_register() - class RulePathvars(SnakemakeListDictArgumentsBlock): ... + class RulePathvars(SnakemakeArgumentsBlock): ... @_register("wildcard_constraints") - class Register_Wildcard_Constraints(SnakemakeListDictArgumentsBlock): ... + class Register_Wildcard_Constraints(SnakemakeArgumentsBlock): ... @_register("cache") - class Cache_Rule(SnakemakeOneLineArgumentsBlock): ... + class Cache_Rule(SnakemakeInlineArgumentBlock): ... @_register() - class Priority(SnakemakeOneLineArgumentsBlock): ... + class Priority(SnakemakeInlineArgumentBlock): ... @_register() - class Retries(SnakemakeOneLineArgumentsBlock): ... + class Retries(SnakemakeInlineArgumentBlock): ... @_register() - class Group(SnakemakeOneLineArgumentsBlock): ... + class Group(SnakemakeUnnamedArgumentBlock): ... @_register() - class LocalRule(SnakemakeOneLineArgumentsBlock): ... + class LocalRule(SnakemakeInlineArgumentBlock): ... @_register() - class Handover(SnakemakeOneLineArgumentsBlock): ... + class Handover(SnakemakeInlineArgumentBlock): ... @_register() - class Shadow(SnakemakeOneLineArgumentsBlock): ... + class Shadow(SnakemakeUnnamedArgumentBlock): ... @_register() - class Conda(SnakemakeOneLineArgumentsBlock): ... + class Conda(SnakemakeUnnamedArgumentBlock): ... @_register("singularity") @_register() - class Container(SnakemakeOneLineArgumentsBlock): ... + class Container(SnakemakeUnnamedArgumentBlock): ... @_register() - class Containerized(SnakemakeOneLineArgumentsBlock): ... + class Containerized(SnakemakeUnnamedArgumentBlock): ... @_register() - class EnvModules(SnakemakeListArgumentsBlock): ... + class EnvModules(SnakemakeUnnamedArgumentsBlock): ... @_register() - class Threads(SnakemakeOneLineArgumentsBlock): ... + class Threads(SnakemakeInlineArgumentBlock): ... @_register() - class Resources(SnakemakeListDictArgumentsBlock): ... + class Resources(SnakemakeArgumentsBlock): ... @_register() - class Params(SnakemakeListDictArgumentsBlock): ... + class Params(SnakemakeArgumentsBlock): ... @_register() - class Message(SnakemakeOneLineArgumentsBlock): ... + class Message(SnakemakeUnnamedArgumentBlock): ... deprecated = {"version": "Use conda or container directive instead (see docs)."} @@ -1118,7 +1280,7 @@ class Rule(UseRule): @_register() class Run(SnakemakeExecutableBlock): ... - class AbstractCmd(SnakemakeOneLineArgumentsBlock, Run): ... + class AbstractCmd(SnakemakeUnnamedArgumentBlock, Run): ... @_register() class Shell(AbstractCmd): ... @@ -1164,8 +1326,9 @@ def formatted(self, mode, state): linesep = "\n" if self.deindent_level else "\n\n" for block in self.body_blocks: block_formatted, state_ = block.formatted(mode, state_) - formatted.append(block_formatted) - formatted.append(linesep) + if block_formatted: # avoid adding extra blank lines for empty blocks + formatted.append(block_formatted) + formatted.append(linesep) if formatted: formatted.pop() # remove the last "\n" return "".join(formatted), state_ diff --git a/tests/test_blocken.py b/tests/test_blocken.py index c8793aa..c2c89d3 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -8,6 +8,7 @@ PythonBlock, consume_fstring, TokenIterator, + format_black, tokenize, is_fstring_start, UnsupportedSyntax, @@ -15,6 +16,7 @@ black, ) from snakefmt.config import read_black_config +from snakefmt.types import TAB def generate_tokens(input: str): @@ -352,6 +354,28 @@ def test_parse_snakefile(self): ] +mode = read_black_config(None) +state = FormatState() + + +class TestFormat: + def test_format_colon(self): + raw = "if 1: #comment\n" + fmted = format_black(raw, mode=mode, partial=":") + assert fmted == "if 1: # comment\n" + + def test_format_paren(self): + raw = " 'b', a=1\n," + fmted = format_black(raw, mode=mode, indent=2, partial="(") + assert fmted == ( + f'{TAB * 3}"b",\n' # + f"{TAB * 3}a=1,\n" + ) + raw = " 'b = 2'\n\n," + fmted = format_black(raw, mode=mode, indent=1, partial="(") + assert fmted == (f'{TAB * 2}"b = 2",\n') + + class TestBlockFormat: example1 = ( @@ -371,12 +395,9 @@ class TestBlockFormat: " )\n" ) - mode = read_black_config(None) - state = FormatState() - def test_format_python_block(self): block = parse(self.example1) - # fun11.formatted(self.mode, self.state) + # fun11.formatted(mode, state) assert "".join(block.full_linestrs) == self.example1 assert [i.full_linestrs for i in block.body_blocks] == [ [ @@ -400,9 +421,84 @@ def test_format_python_block(self): py2 = block.body_blocks[1] assert len(py2.head_lines) == 3 assert ( - py2.formatted(self.mode, self.state)[0] + py2.formatted(mode, state)[0] == 'b = f"""\n{b =} f"""\n# comment\nc = [i for j in k] if m else (lambda: None)\n' ) - assert block.formatted(self.mode, self.state)[0] == black.format_str( - self.example1, mode=self.mode + assert block.formatted(mode, state)[0] == black.format_str( + self.example1, mode=mode ) + + example2 = ( + "rule A:\n" # L1 + " input: a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print ( 1 \n )\n" + "\n" + "\n" + "checkpoint:\n" + " name: 'check'\n" # L11 + " params:\n" + " c = [i for \n" + " i in range(1) if 3],\n" + " conda = 'conda.yaml'\n" + " shell: 'touch d'\n" + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" # L21 + "\n" + "\n" + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + " sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n" + "\n" + "\n" + "report:\n" + "\n" + " 'report'\n" # L31 + "\n" + "\n" + "\n", + "rule A:\n" + " input:\n" + ' a="1",\n' + " output:\n" + ' "b = 2",\n' + " run:\n" + " print(1)\n" + "\n" + "\n" + "checkpoint:\n" + " name:\n" + ' "check"\n' + " params:\n" + " c=[i for i in range(1) if 3],\n" + ' conda="conda.yaml",\n' + " shell:\n" + ' "touch d"\n' + "\n" + "\n" + "onsuccess:\n" + " for i in range(10):\n" + " print(i)\n" + "\n" + "\n" + "wildcard_constraints:\n" + ' sth=r"a|b|c",\n' + ' sth2=r"a|b|c",\n' + ' sth3=r"a|b|c",\n' + "\n" + "\n" + 'report: "report"\n', + ) + + def test_format_snakefile(self): + code, formatted = self.example2 + block = parse(code) + assert block.formatted(mode, state)[0].replace("\n", "<\n") == ( + formatted + ).replace("\n", "<\n") From 5f6c7939776301c06f876bf1ed0876fcdfbb4bc3 Mon Sep 17 00:00:00 2001 From: hwrn Date: Tue, 7 Apr 2026 22:33:40 +0800 Subject: [PATCH 29/53] fix: 60 failed left --- snakefmt/blocken.py | 266 +++++++++++++++++++++++++++++++------------- 1 file changed, 191 insertions(+), 75 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index d08aa5c..78d1dc5 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -89,6 +89,9 @@ def next_block(self): it could be INDEDT -> [any content] -> DEDENT, or [any content] -> DEDENT """ line = self.next_new_line() + if line.end.type == tokenize.ENDMARKER: + self.denext(*reversed(list(line.iter))) + return [], [] assert line.deindelta >= 0, "Unexpected DEDENT at the beginning of a block" assert line.body, "Unexpected empty line at the beginning of a block" lines = [line] @@ -582,7 +585,6 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] try: fmted = black.format_str(prefix + string, mode=mode) except black.parsing.InvalidInput as e: - breakpoint() raise e if indent: fix = fmted.split("\n", indent)[-1] @@ -640,7 +642,9 @@ def keyword(self) -> str: return self._keyword() def split_colon_line(self): - token_iter = TokenIterator("", iter(self.head_lines[-1].iter)) + token_iter = TokenIterator( + "", iter(self.colon_line.body + [self.colon_line.end]) + ) last_line_tokens = [] while True: component = token_iter.next_component() @@ -651,7 +655,7 @@ def split_colon_line(self): prior = tokens2linestrs(iter(last_line_tokens)) prior[-1] = prior[-1][: colon_token.start[1]] token_iter.denext(colon_token) - return prior, token_iter + return self.colon_line.head_noncoding, prior, token_iter @property def colon_line(self): @@ -821,11 +825,15 @@ def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: assert ( len(self.head_lines) == 1 ), "Snakemake keywords should only have one head line" - prior_colon, post_colon = self.split_colon_line() - (head,) = prior_colon + indent = TAB * self.deindent_level + noncoding, prior_colon, post_colon = self.split_colon_line() + formatted_comments = "".join( + indent + i.line.lstrip() for i in noncoding if i.type == tokenize.COMMENT + ) assert len(prior_colon) == 1, "Snakemake keywords should be single line" + (head,) = prior_colon components = head.strip().split() - formatted_head = TAB * self.deindent_level + " ".join(components) + ":" + formatted_head = formatted_comments + indent + " ".join(components) + ":" if self.colon_line.end_op == ":": # only a single line comment or empty is possible here, add directly colon_token = next(post_colon) @@ -853,12 +861,63 @@ class PythonArgumentsBlock(PythonBlock): - simple expressions on the right, e.g. `"data.txt",` - assignment with simple names on the left, e.g. `a = 1,` - Specally, allow `*args` and `**kwargs` as normal function + """ + + @classmethod + def format_post_colon( + cls, + mode: Mode, + deindent_level: int, + post_colon: list[TokenInfo], + body_blocks: list[Block], + ) -> str: + """If there is indent after the colon line, + even if expressions exist in that line, + indent body should be formatted as part of the cotent: + input: balabal, # <- expression after the colon + balabal2 # <- indent body, should be formatted as part of the content + to: + input: + balabal, + balabal2, + """ + if post_colon: + assert ( + post_colon[-1].type == tokenize.NEWLINE + ), "Unexpected post_colon without a new line at the end" + colon_token = post_colon[0] + partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) + post = tokens2linestrs(iter(partial_line.body)) + post[0] = post[0][colon_token.end[1] :] + else: + post = [] + if body_blocks: + (param_space,) = body_blocks + assert ( + not param_space.body_blocks + ), "Argument block should not have body blocks" + for line in param_space.head_lines: + post.extend(line.linestrs) + post.extend(tokens2linestrs(iter(param_space.tail_noncoding))) + # here is used to check the end_op + partial_line = param_space.head_lines[-1] + raw = cls.handle_end_comma("".join(post), partial_line) + formatted = format_black(raw, mode, deindent_level, partial="(") + return formatted + + @staticmethod + @abstractmethod + def handle_end_comma(raw: str, last_line: LogicalLine) -> str: ... + + +class PythonArguments(PythonArgumentsBlock): + """Parsed as *args, **kwargs + + Enhancement: accepth expressions without trailing comma, + Since each expression is already splitted by lines, + we can automatically add trailing commas to avoid syntax errors - Enhancement could be done: accepth expressions without trailing comma, - because each expression is already splitted by lines, - and we can add a trailing comma only if needed. - If we want to support expressions without trailing comma, - cases where two lines can makesense without a comma between them + Cases where two lines can makesense without a comma between them should be carefully considered, e.g.: input: @@ -869,11 +928,10 @@ class PythonArgumentsBlock(PythonBlock): (a, b) Although in our view this is naturally two expressions, the action do change with the proposed enhancement. - """ - -class PythonArguments(PythonArgumentsBlock): - """Parsed as *args, **kwargs""" + Further enhancement: support expressions without trailing comma in syntax, + but that's not eazy, especially for unnamed arguments + """ def formatted(self, mode, state): """PythonArguments and its subclasses always at the terminal @@ -889,33 +947,18 @@ def formatted(self, mode, state): formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") return formatted, state - @classmethod - def format_post_colon( - cls, post_colon: list[TokenInfo], deindent_level: int, mode: Mode - ): - """If the params are in the same line as the keyword, - e.g. `input: "data.txt"`, - then self.body_blocks is empty, should take those from post_colon - """ - assert ( - post_colon and post_colon[-1].type == tokenize.NEWLINE - ), "Unexpected post_colon without a new line at the end" - colon_token = post_colon[0] - partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) - post = tokens2linestrs(iter(partial_line.body)) - post[0] = post[0][colon_token.end[1] :] - raw = "".join(post) - if not partial_line.end_op == ",": + @staticmethod + def handle_end_comma(raw, last_line): + if not last_line.end_op == ",": raw += "\n," - formatted = format_black(raw, mode, deindent_level, partial="(") - return formatted + return raw class PythonUnnamedArguments(PythonArguments): """Only allow simple expressions on the right, and the whole block should be a list""" -class PythonOneLineArgument(PythonUnnamedArguments): +class PythonOneLineArgument(PythonArgumentsBlock): """Only allow simple expressions on the right""" def formatted(self, mode, state): @@ -936,28 +979,17 @@ def formatted(self, mode, state): formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") return formatted, state - @classmethod - def format_post_colon( - cls, post_colon: list[TokenInfo], deindent_level: int, mode: Mode - ): - assert ( - post_colon and post_colon[-1].type == tokenize.NEWLINE - ), "Unexpected post_colon without a new line at the end" - colon_token = post_colon[0] - partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) - post = tokens2linestrs(iter(partial_line.body)) - post[0] = post[0][colon_token.end[1] :] - raw = "".join(post) - if partial_line.end_op == ",": + @staticmethod + def handle_end_comma(raw, last_line): + if last_line.end_op == ",": comma_token = ( - partial_line.body[-2] - if partial_line.body[-1].type == tokenize.COMMENT - else partial_line.body[-1] + last_line.body[-2] + if last_line.body[-1].type == tokenize.COMMENT + else last_line.body[-1] ) comma_start = comma_token.start[1] - len(comma_token.line) raw = raw[:comma_start] + raw[comma_start + 1 :] - formatted = format_black(raw, mode, deindent_level, partial="(") - return formatted + return raw class SnakemakeArgumentsBlock(SnakemakeBlock): @@ -965,24 +997,46 @@ class SnakemakeArgumentsBlock(SnakemakeBlock): The content is pure python. """ - Argument = PythonArguments + Argument: type[PythonArgumentsBlock] = PythonArguments + + def consume(self, tokens): + """Even if the colon line contains params after the colon, + we still expect an optional indent body + so: if self.colon_line.end_op == ":" or True: + """ + self.consume_body(tokens) def consume_body(self, tokens): + if self.colon_line.end_op != ":": + # See if the body is indented. + # NL and COMMENT can precede the INDENT; + # anything else means no body. + peeked: list[TokenInfo] = [] + for token in tokens: + peeked.append(token) + if token.type != tokenize.NL and token.type != tokenize.COMMENT: + break + tokens.denext(*reversed(peeked)) + if peeked[-1].type != tokenize.INDENT: + return lines, tail_noncoding = tokens.next_block() - self.body_blocks.append(self.Argument(self.deindent_level + 1, tokens, lines)) - self.extend_tail_noncoding(tail_noncoding) + if lines: + self.body_blocks.append( + self.Argument(self.deindent_level + 1, tokens, lines) + ) + self.extend_tail_noncoding(tail_noncoding) + else: + assert ( + self.colon_line.end_op != ":" + ), "Empty body after colon is not allowed" def format_body(self, mode, state, post_colon) -> str: """Format body as in the function call, e.g. `input: "data.txt",` -> `input("data.txt")` """ - if post_colon: - return self.Argument.format_post_colon( - post_colon, self.deindent_level, mode - ) - else: - (param_space,) = self.body_blocks - return param_space.formatted(mode, state)[0] + return self.Argument.format_post_colon( + mode, self.deindent_level, post_colon, self.body_blocks + ) def compilation(self): raise NotImplementedError @@ -992,7 +1046,7 @@ class SnakemakeUnnamedArgumentsBlock(SnakemakeArgumentsBlock): Argument = PythonUnnamedArguments -class SnakemakeUnnamedArgumentBlock(SnakemakeUnnamedArgumentsBlock): +class SnakemakeUnnamedArgumentBlock(SnakemakeArgumentsBlock): Argument = PythonOneLineArgument @@ -1122,13 +1176,9 @@ def consume_body(self, tokens): self.extend_tail_noncoding(tail_noncoding) def format_body(self, mode, state, post_colon): - if post_colon: - return PythonOneLineArgument.format_post_colon( - post_colon, self.deindent_level, mode - ) - else: - (param_space,) = self.body_blocks - return param_space.formatted(mode, state)[0] + return PythonOneLineArgument.format_post_colon( + mode, self.deindent_level, post_colon, self.body_blocks + ) @_register() @@ -1196,8 +1246,7 @@ class Prefix(SnakemakeUnnamedArgumentBlock): ... class Replace_Prefix(SnakemakeUnnamedArgumentBlock): ... -@_register("use") -class UseRule(NamedBlock, SnakemakeKeywordBlock): +class _Rule(NamedBlock, SnakemakeKeywordBlock): subautomata, _register = init_block_register() @_register() @@ -1273,8 +1322,47 @@ class Message(SnakemakeUnnamedArgumentBlock): ... deprecated = {"version": "Use conda or container directive instead (see docs)."} +@_register("use") +class UseRule(_Rule): + def formatted(self, mode, state): + """Allow: + use rule * from other_workflow exclude ruleC as other_* + use rule * from other_workflow exclude ruleC + use rule * from other_workflow as other_* + use rule * from other_workflow + """ + assert len(self.head_lines) == 1, "use directive should only have one head line" + head_line = tokens2linestrs(iter(self.head_lines[0].body)) + assert len(head_line) == 1, "use directive should be single line" + head_bulk_line = head_line[0].split("#", 1)[0] + if ":" not in head_bulk_line: + # return quickly (also no body block here) + indent = TAB * self.deindent_level + noncoding = self.head_lines[0].head_noncoding + # TODO: format comments using black + formatted_comments = "".join( + indent + format_black(i.line.lstrip(), mode) + for i in noncoding + if i.type == tokenize.COMMENT + ) + components = head_bulk_line.strip().split() + formatted_head = formatted_comments + indent + " ".join(components) + if "#" in head_line[0]: + formatted_head += " " + format_black( + "#" + head_line[0].split("#", 1)[1], mode=mode + ).rstrip("\n") + return formatted_head + "\n", state + formatted_prior, post_colon = self.format_head(mode) + formatted_body = self.format_body(mode, state, post_colon) + formatted = [formatted_prior, formatted_body] + for comment in tokens2linestrs(iter(self.tail_noncoding)): + if comment.strip(): + formatted.append(TAB * self.deindent_level + comment.lstrip()) + return "".join(formatted), state + + @_register() -class Rule(UseRule): +class Rule(_Rule): exec_subautomata, _register = init_block_register() @_register() @@ -1300,7 +1388,7 @@ class TemplateEngine(Script): ... @_register() class CWL(Script): ... - subautomata = {**UseRule.subautomata, **exec_subautomata} + subautomata = {**_Rule.subautomata, **exec_subautomata} @_register() @@ -1315,8 +1403,14 @@ class GlobalBlock(Block): so tail_noncoding always updated to the last body_block """ + __slots__ = ("mode",) + mode: Mode + subautomata = {**python_subautomata, **global_snakemake_subautomata} + def __init__(self, deindent_level, tokens, lines=None): + super().__init__(deindent_level, tokens, lines) + def consume(self, tokens): self.body_blocks = self.consume_subblocks(tokens) @@ -1333,6 +1427,13 @@ def formatted(self, mode, state): formatted.pop() # remove the last "\n" return "".join(formatted), state_ + def get_formatted(self, mode: Mode | None = None): + if mode is None: + mode = getattr(self, "mode", None) + if mode is None: + raise ValueError("Mode should be provided for formatting") + return self.formatted(mode, FormatState())[0] + def compilation(self): raise NotImplementedError @@ -1345,3 +1446,18 @@ def parse(input: str | Callable[[], str], name: str = "") -> GlobalBlock else: tokens = tokenize.generate_tokens(input) return GlobalBlock(0, TokenIterator(name, tokens), []) + + +def setup_formatter( + snake: str, + line_length: int | None = None, + sort_params: bool = False, + black_config_file=None, +): + formatter = parse(snake) + mode = read_black_config(black_config_file) or Mode() + if line_length is not None: + mode.line_length = line_length + + formatter.mode = mode + return formatter From 12ad8cedef606b9c8c2374f88f25a888d1d67137 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 00:21:57 +0800 Subject: [PATCH 30/53] feat: partial match-case block --- snakefmt/blocken.py | 88 ++++++++++++++++++++++++++++++++++++------- tests/test_blocken.py | 29 ++++++++++++++ 2 files changed, 104 insertions(+), 13 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 78d1dc5..17be526 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -2,7 +2,6 @@ import tokenize from abc import ABC, abstractmethod from typing import ( - Any, Callable, Iterator, Literal, @@ -552,6 +551,55 @@ class DocumentSymbol(NamedTuple): block: "Block" +def format_python_colon_head( + raw: str, mode: Mode, keyword: str, indent_str: str = "", indent=0, partial=False +): + """Continuation keywords (else/elif/except/finally) need a preceding fake block + because black cannot parse them in isolation. + """ + if keyword == "elif" or keyword == "else": + fake_head = indent_str + "if 1: pass\n" + fake_head_lines = 2 # black always expands "if 1: pass" to 2 lines + elif keyword == "except" or keyword == "finally": + fake_head = indent_str + "try: pass\n" + fake_head_lines = 2 # black always expands "try: pass" to 2 lines + elif keyword == "match": + # match needs at least one case, add a dummy case + dummy_case = indent_str + " case _: pass\n" + formatted = format_black(raw + dummy_case, mode, indent, "") + # Keep only the match line (first line before case) + match_line = formatted.split("\n")[0] + return match_line + "\n" if match_line else raw + elif keyword == "case": + # case needs to be inside a match, construct the full block + case_content = raw.lstrip() # Remove indentation + # Create match-case block with case inside + full_block = indent_str + "match 1:\n" + indent_str + " " + case_content + formatted = format_black(full_block, mode, 0, ":") + # Extract just the case line(s) and restore original indent + lines = formatted.split("\n") + result_lines = [] + case_started = False + for line in lines: + if line.strip().startswith("case"): + case_started = True + if case_started: + if line.strip() and line.strip() != "pass": + # Remove the extra indentation added by format_black and keep just the case + result_lines.append(indent_str + line.lstrip()) + elif line.strip() == "pass": + break + if result_lines: + return "\n".join(result_lines) + "\n" + return raw + else: + return format_black(raw, mode, indent, ":" if partial else "") + formatted = format_black(fake_head + raw, mode, indent, ":") + if not fake_head: + return formatted + return formatted.split("\n", fake_head_lines)[-1] + + def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] = ""): """Format a string using Black formatter. @@ -713,9 +761,14 @@ def formatted(self, mode, state): formatted = [] if not self.body_blocks: raw = "".join(self.full_linestrs) - return format_black(raw, mode, self.deindent_level), state + head = format_python_colon_head( + raw, mode, self.keyword, self.indent_str, self.deindent_level + ) + return head, state raw_head = "".join(self.head_linestrs) - head = format_black(raw_head, mode, self.deindent_level, partial=":") + head = format_python_colon_head( + raw_head, mode, self.keyword, self.indent_str, self.deindent_level, True + ) formatted.append(head) state_ = state if isinstance( @@ -772,7 +825,7 @@ def consume_body(self, tokens): self.body_blocks = blocks def formatted(self, mode, state): - raise NotImplementedError + raise NotImplementedError("Not supported to format match-case blocks yet") def compilation(self): raise NotImplementedError @@ -1176,9 +1229,13 @@ def consume_body(self, tokens): self.extend_tail_noncoding(tail_noncoding) def format_body(self, mode, state, post_colon): - return PythonOneLineArgument.format_post_colon( - mode, self.deindent_level, post_colon, self.body_blocks - ) + if post_colon: + return PythonOneLineArgument.format_post_colon( + mode, self.deindent_level, post_colon, self.body_blocks + ) + else: + (param_space,) = self.body_blocks + return param_space.formatted(mode, state)[0] @_register() @@ -1339,12 +1396,17 @@ def formatted(self, mode, state): # return quickly (also no body block here) indent = TAB * self.deindent_level noncoding = self.head_lines[0].head_noncoding - # TODO: format comments using black - formatted_comments = "".join( - indent + format_black(i.line.lstrip(), mode) - for i in noncoding - if i.type == tokenize.COMMENT - ) + if noncoding: + raw_noncoding = "".join(tokens2linestrs(iter(noncoding))) + # `1` make sure all comments dedent to no prefix, then we can remove it + foramtted_noindent = format_black(raw_noncoding + "1", mode).split( + "\n" + )[:-2] + formatted_comments = "".join( + indent + i + "\n" if i else "\n" for i in foramtted_noindent + ) + else: + formatted_comments = "" components = head_bulk_line.strip().split() formatted_head = formatted_comments + indent + " ".join(components) if "#" in head_line[0]: diff --git a/tests/test_blocken.py b/tests/test_blocken.py index c2c89d3..46cf6b9 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -9,6 +9,7 @@ consume_fstring, TokenIterator, format_black, + format_python_colon_head, tokenize, is_fstring_start, UnsupportedSyntax, @@ -375,6 +376,34 @@ def test_format_paren(self): fmted = format_black(raw, mode=mode, indent=1, partial="(") assert fmted == (f'{TAB * 2}"b = 2",\n') + def test_format_partial_colon(self): + for i in ( + "if cond:\n", + "else:\n", + "elif x > 0:\n", + "except ValueError:\n", + "finally:\n", + "match val:\n", + ): + fmted = format_python_colon_head( + i, mode, i.strip().split()[0].replace(":", "") + ) + assert fmted == i + for i in ( + f"{TAB}else:\n", + f"{TAB}elif x > 0:\n", + f"{TAB}except (ValueError, KeyError):\n", + f"{TAB}finally:\n", + f"{TAB}case Point(x, 0):\n", + ): + fmted = format_python_colon_head( + i, mode, i.strip().split()[0].replace(":", ""), indent_str=TAB, indent=1 + ) + assert fmted == i + i = " elif (\n x > 0\n ):\n" + fmted = format_python_colon_head(i, mode, "elif", indent_str=TAB, indent=1) + assert fmted == " elif x > 0:\n" + class TestBlockFormat: From 4a6ae69e4497e01355797f5b8e257028d534e246 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 02:02:31 +0800 Subject: [PATCH 31/53] fix: simplify --- snakefmt/blocken.py | 33 ++++++++------------------------- tests/test_blocken.py | 28 ++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 17be526..7d9e0ae 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -564,34 +564,17 @@ def format_python_colon_head( fake_head = indent_str + "try: pass\n" fake_head_lines = 2 # black always expands "try: pass" to 2 lines elif keyword == "match": - # match needs at least one case, add a dummy case + # match needs at least one case dummy_case = indent_str + " case _: pass\n" formatted = format_black(raw + dummy_case, mode, indent, "") - # Keep only the match line (first line before case) - match_line = formatted.split("\n")[0] - return match_line + "\n" if match_line else raw + # Keep only the match line + return formatted.rsplit("\n", 3)[0] + "\n" elif keyword == "case": # case needs to be inside a match, construct the full block - case_content = raw.lstrip() # Remove indentation - # Create match-case block with case inside - full_block = indent_str + "match 1:\n" + indent_str + " " + case_content - formatted = format_black(full_block, mode, 0, ":") - # Extract just the case line(s) and restore original indent - lines = formatted.split("\n") - result_lines = [] - case_started = False - for line in lines: - if line.strip().startswith("case"): - case_started = True - if case_started: - if line.strip() and line.strip() != "pass": - # Remove the extra indentation added by format_black and keep just the case - result_lines.append(indent_str + line.lstrip()) - elif line.strip() == "pass": - break - if result_lines: - return "\n".join(result_lines) + "\n" - return raw + assert indent_str and indent, "`case` block must be indented" + dummy_match = indent_str[:-1] + "match 1:\n" + raw + formatted = format_black(dummy_match, mode, indent - 1, ":") + return formatted.split("\n", 1)[-1] else: return format_black(raw, mode, indent, ":" if partial else "") formatted = format_black(fake_head + raw, mode, indent, ":") @@ -620,7 +603,7 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] """ prefix = "" for i in range(indent): - prefix += " " * i + "if 1:\n" + prefix += " " * i + "def a():\n" if partial == ":": # for block such as if/else/... safe_indent = max(extract_line_indent(line) for line in raw.splitlines()) diff --git a/tests/test_blocken.py b/tests/test_blocken.py index 46cf6b9..c7aace4 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -365,6 +365,11 @@ def test_format_colon(self): fmted = format_black(raw, mode=mode, partial=":") assert fmted == "if 1: # comment\n" + def test_format_def(self): + raw = f"{TAB}def s(a):\n" f"{TAB*2}if a:\n" f'{TAB* 3}return "Hello World"\n' + fmted = format_black(raw, mode=mode, indent=1) + assert fmted == raw + def test_format_paren(self): raw = " 'b', a=1\n," fmted = format_black(raw, mode=mode, indent=2, partial="(") @@ -386,23 +391,38 @@ def test_format_partial_colon(self): "match val:\n", ): fmted = format_python_colon_head( - i, mode, i.strip().split()[0].replace(":", "") + i, mode, i.strip().split()[0].replace(":", ""), partial=True ) assert fmted == i + + def test_format_partial_colon_indent(self): for i in ( f"{TAB}else:\n", f"{TAB}elif x > 0:\n", f"{TAB}except (ValueError, KeyError):\n", f"{TAB}finally:\n", + f"{TAB}match val:\n", f"{TAB}case Point(x, 0):\n", ): fmted = format_python_colon_head( - i, mode, i.strip().split()[0].replace(":", ""), indent_str=TAB, indent=1 + i, + mode, + i.strip().split()[0].replace(":", ""), + indent_str=TAB, + indent=1, + partial=True, ) assert fmted == i - i = " elif (\n x > 0\n ):\n" - fmted = format_python_colon_head(i, mode, "elif", indent_str=TAB, indent=1) + i = f"{TAB}elif (\n x > 0\n ):\n" + fmted = format_python_colon_head( + i, mode, "elif", indent_str=TAB, indent=1, partial=True + ) assert fmted == " elif x > 0:\n" + i = f"{TAB*2}case Point(x, 0):\n" + fmted = format_python_colon_head( + i, mode, "case", indent_str=TAB * 2, indent=2, partial=True + ) + assert fmted == i class TestBlockFormat: From c29abf99ba90a58cf3edb62af4b8ae13c62e189f Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 02:38:50 +0800 Subject: [PATCH 32/53] fix: blank lines --- snakefmt/blocken.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 7d9e0ae..7b7ad7b 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -754,16 +754,23 @@ def formatted(self, mode, state): ) formatted.append(head) state_ = state - if isinstance( - self.body_blocks[0], - (NoSnakemakeBlock, NamedBlock, SnakemakeExecutableBlock), - ): - formatted.append("\n") - for block in self.body_blocks: + prev_was_major = False + for i, block in enumerate(self.body_blocks): + this_is_major = isinstance( + block, (NoSnakemakeBlock, NamedBlock, SnakemakeExecutableBlock) + ) + if ( + this_is_major + or prev_was_major + or (i == 0 and isinstance(block, SnakemakeBlock)) + ): + # Always enforce blank line before/after + # def/class/rule/onstart/etc. blocks + # Always add blank line before the first inline snakemake block + formatted.append("\n") + prev_was_major = this_is_major block_formatted, state_ = block.formatted(mode, state_) formatted.append(block_formatted) - formatted.append("\n") - formatted.pop() # remove the last "\n" for comment in tokens2linestrs(iter(self.tail_noncoding)): if comment.strip(): formatted.append(TAB * self.deindent_level + comment.lstrip()) @@ -1096,7 +1103,8 @@ def formatted(self, mode, state): formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): - if formatted_prior.endswith(":\n") and "#" not in formatted_prior: + last_head_line = formatted_prior.rsplit("\n", 2)[-2] + if formatted_prior.endswith(":\n") and "#" not in last_head_line: formatted_merge = formatted_prior[:-1] + " " + formatted_body.lstrip() if len(formatted_merge) <= mode.line_length: formatted = [formatted_merge] From e6b74fae7bd58563199e2416e7a8a7a5f9eeaa9b Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 09:29:40 +0800 Subject: [PATCH 33/53] fix: _continuation_kws --- snakefmt/blocken.py | 180 +++++++++++++++++++++++--------------------- 1 file changed, 96 insertions(+), 84 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 7b7ad7b..6eb61eb 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -319,6 +319,87 @@ def update(self, *str): return self._replace() +def format_python_colon_head( + raw: str, mode: Mode, keyword: str, indent_str: str = "", indent=0, partial=False +): + """Continuation keywords (else/elif/except/finally) need a preceding fake block + because black cannot parse them in isolation. + """ + if keyword == "elif" or keyword == "else": + fake_head = indent_str + "if 1: pass\n" + fake_head_lines = 2 # black always expands "if 1: pass" to 2 lines + elif keyword == "except" or keyword == "finally": + fake_head = indent_str + "try: pass\n" + fake_head_lines = 2 # black always expands "try: pass" to 2 lines + elif keyword == "match": + # match needs at least one case + dummy_case = indent_str + " case _: pass\n" + formatted = format_black(raw + dummy_case, mode, indent, "") + # Keep only the match line + return formatted.rsplit("\n", 3)[0] + "\n" + elif keyword == "case": + # case needs to be inside a match, construct the full block + assert indent_str and indent, "`case` block must be indented" + dummy_match = indent_str[:-1] + "match 1:\n" + raw + formatted = format_black(dummy_match, mode, indent - 1, ":") + return formatted.split("\n", 1)[-1] + else: + return format_black(raw, mode, indent, ":" if partial else "") + formatted = format_black(fake_head + raw, mode, indent, ":") + if not fake_head: + return formatted + return formatted.split("\n", fake_head_lines)[-1] + + +def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] = ""): + """Format a string using Black formatter. + + if indent: + prefix = make series of `{' ' * i}if 1:\\n` to increase indent level + format(prefix + string) + remove first `indent` lines + if partial == ":": + safe_indent = longest(prefix spacing) + format(string + f"\\n{safe_indent} pass") + remove the last line + if partial == "(": + format("f(" + string + ")") + if string.startswith("f(\\n"): + remove the first line and the last line + else: + remove first three characters and the last character + """ + prefix = "" + for i in range(indent): + prefix += " " * i + "def a():\n" + if partial == ":": + # for block such as if/else/... + safe_indent = max(extract_line_indent(line) for line in raw.splitlines()) + string = raw + f"{safe_indent} pass" + elif partial == "(": + # Tb() effects equals to a entire new indent + string = " " * indent + "Tb(\n" + raw + "\n)" + else: + string = raw + try: + fmted = black.format_str(prefix + string, mode=mode) + except black.parsing.InvalidInput as e: + raise e + if indent: + fix = fmted.split("\n", indent)[-1] + else: + fix = fmted + if partial == ":": + fix = fix.rstrip().rsplit("\n", 1)[0] + "\n" + elif partial == "(": + fix = fix.strip() + if fix.startswith("Tb(\n"): + fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" + else: + fix = TAB * (indent + 1) + fix[3:-1] + "\n" + return fix + + class Block(ABC): """ A block can be: @@ -551,87 +632,6 @@ class DocumentSymbol(NamedTuple): block: "Block" -def format_python_colon_head( - raw: str, mode: Mode, keyword: str, indent_str: str = "", indent=0, partial=False -): - """Continuation keywords (else/elif/except/finally) need a preceding fake block - because black cannot parse them in isolation. - """ - if keyword == "elif" or keyword == "else": - fake_head = indent_str + "if 1: pass\n" - fake_head_lines = 2 # black always expands "if 1: pass" to 2 lines - elif keyword == "except" or keyword == "finally": - fake_head = indent_str + "try: pass\n" - fake_head_lines = 2 # black always expands "try: pass" to 2 lines - elif keyword == "match": - # match needs at least one case - dummy_case = indent_str + " case _: pass\n" - formatted = format_black(raw + dummy_case, mode, indent, "") - # Keep only the match line - return formatted.rsplit("\n", 3)[0] + "\n" - elif keyword == "case": - # case needs to be inside a match, construct the full block - assert indent_str and indent, "`case` block must be indented" - dummy_match = indent_str[:-1] + "match 1:\n" + raw - formatted = format_black(dummy_match, mode, indent - 1, ":") - return formatted.split("\n", 1)[-1] - else: - return format_black(raw, mode, indent, ":" if partial else "") - formatted = format_black(fake_head + raw, mode, indent, ":") - if not fake_head: - return formatted - return formatted.split("\n", fake_head_lines)[-1] - - -def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] = ""): - """Format a string using Black formatter. - - if indent: - prefix = make series of `{' ' * i}if 1:\\n` to increase indent level - format(prefix + string) - remove first `indent` lines - if partial == ":": - safe_indent = longest(prefix spacing) - format(string + f"\\n{safe_indent} pass") - remove the last line - if partial == "(": - format("f(" + string + ")") - if string.startswith("f(\\n"): - remove the first line and the last line - else: - remove first three characters and the last character - """ - prefix = "" - for i in range(indent): - prefix += " " * i + "def a():\n" - if partial == ":": - # for block such as if/else/... - safe_indent = max(extract_line_indent(line) for line in raw.splitlines()) - string = raw + f"{safe_indent} pass" - elif partial == "(": - # Tb() effects equals to a entire new indent - string = " " * indent + "Tb(\n" + raw + "\n)" - else: - string = raw - try: - fmted = black.format_str(prefix + string, mode=mode) - except black.parsing.InvalidInput as e: - raise e - if indent: - fix = fmted.split("\n", indent)[-1] - else: - fix = fmted - if partial == ":": - fix = fix.rstrip().rsplit("\n", 1)[0] + "\n" - elif partial == "(": - fix = fix.strip() - if fix.startswith("Tb(\n"): - fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" - else: - fix = TAB * (indent + 1) + fix[3:-1] + "\n" - return fix - - class PythonBlock(Block): """Hold `head_lines` and `tail_noncoding`, no `body_blocks`""" @@ -1471,13 +1471,25 @@ def formatted(self, mode, state): formatted = [] state_ = state linesep = "\n" if self.deindent_level else "\n\n" - for block in self.body_blocks: + # TODO: better handling of blank lines between blocks + _continuation_kws = {"elif", "else", "except", "finally"} + blocks = self.body_blocks + for i, block in enumerate(blocks): block_formatted, state_ = block.formatted(mode, state_) if block_formatted: # avoid adding extra blank lines for empty blocks formatted.append(block_formatted) - formatted.append(linesep) + # continuation keywords (else/elif/except/finally) must not be + # separated from the preceding block by a full blank line + next_block = blocks[i + 1] if i + 1 < len(blocks) else None + if ( + isinstance(next_block, IfForTryWithBlock) + and next_block.keyword in _continuation_kws + ): + formatted.append("\n") + else: + formatted.append(linesep) if formatted: - formatted.pop() # remove the last "\n" + formatted.pop() # remove the last separator return "".join(formatted), state_ def get_formatted(self, mode: Mode | None = None): From a2323e9ab9f29150bc4a52b0354f61439772967b Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 10:36:49 +0800 Subject: [PATCH 34/53] fix: oneline function --- snakefmt/blocken.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 6eb61eb..0f9abba 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -396,7 +396,13 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] if fix.startswith("Tb(\n"): fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" else: - fix = TAB * (indent + 1) + fix[3:-1] + "\n" + if not "#" in fix: # safe to unpack function + fix = TAB * (indent + 1) + fix[3:-1] + "\n" + else: + fix = ( + format_black(raw + "\n#", mode, indent, partial).rsplit("\n", 2)[0] + + "\n" + ) return fix @@ -1103,11 +1109,15 @@ def formatted(self, mode, state): formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): - last_head_line = formatted_prior.rsplit("\n", 2)[-2] + if formatted_prior.count("\n") > 1: + prev, last_head_line = formatted_prior[:-1].rsplit("\n", 1) + prev += "\n" + else: + prev, last_head_line = "", formatted_prior[:-1] if formatted_prior.endswith(":\n") and "#" not in last_head_line: - formatted_merge = formatted_prior[:-1] + " " + formatted_body.lstrip() + formatted_merge = last_head_line + " " + formatted_body.lstrip() if len(formatted_merge) <= mode.line_length: - formatted = [formatted_merge] + formatted = [prev + formatted_merge] for comment in tokens2linestrs(iter(self.tail_noncoding)): if comment.strip(): formatted.append(TAB * self.deindent_level + comment.lstrip()) @@ -1485,7 +1495,19 @@ def formatted(self, mode, state): isinstance(next_block, IfForTryWithBlock) and next_block.keyword in _continuation_kws ): - formatted.append("\n") + formatted.append("\n") # continuation: elif/else/except/finally + elif isinstance(block, PythonBlock) and isinstance( + next_block, IfForTryWithBlock + ): + formatted.append("") # Python lead-in: no extra blank line + elif ( + isinstance(block, SnakemakeBlock) + and isinstance(next_block, SnakemakeBlock) + and not isinstance( + next_block, (NamedBlock, SnakemakeExecutableBlock) + ) + ): + formatted.append("") # Python lead-in: no extra blank line else: formatted.append(linesep) if formatted: From b0a7971bec6119c2f2ff445079d21d8ee35dfa43 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 18:23:37 +0800 Subject: [PATCH 35/53] fix: serious problems --- snakefmt/blocken.py | 220 +++++++++++++++++++------------------------- 1 file changed, 95 insertions(+), 125 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 0f9abba..266eac2 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -620,9 +620,15 @@ def components(self) -> "Iterator[DocumentSymbol]": for block in self.body_blocks: yield from block.components() - @abstractmethod - def formatted(self, mode: Mode, state: FormatState) -> tuple[str, FormatState]: - """return formatted code of the block""" + def segment2format(self, mode: Mode, state: FormatState): + """yield: + - [unformated_python_code, Literal[False]] + - [formated_snakemake_code, Literal[True]] + """ + yield "".join(self.head_linestrs), False + for block in self.body_blocks: + yield from block.segment2format(mode, state) + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False @abstractmethod def compilation(self): @@ -644,7 +650,7 @@ class PythonBlock(Block): def consume(self, tokens): "Do nothing, win" - def formatted(self, mode, state): + def formatted(self, mode, state) -> tuple[str, FormatState]: raw = "".join(self.full_linestrs) if not raw.strip(): return "", state @@ -727,11 +733,6 @@ def consume_body(self, tokens): self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) self.extend_tail_noncoding(tail_noncoding) - def formatted(self, mode, state): - raw = "".join(self.full_linestrs) - formatted = format_black(raw, mode, self.deindent_level) - return formatted, state - def compilation(self): raise NotImplementedError @@ -746,42 +747,6 @@ def consume_body(self, tokens): blocks = GlobalBlock(self.deindent_level + 1, tokens, []).body_blocks self.body_blocks.extend(blocks) - def formatted(self, mode, state): - formatted = [] - if not self.body_blocks: - raw = "".join(self.full_linestrs) - head = format_python_colon_head( - raw, mode, self.keyword, self.indent_str, self.deindent_level - ) - return head, state - raw_head = "".join(self.head_linestrs) - head = format_python_colon_head( - raw_head, mode, self.keyword, self.indent_str, self.deindent_level, True - ) - formatted.append(head) - state_ = state - prev_was_major = False - for i, block in enumerate(self.body_blocks): - this_is_major = isinstance( - block, (NoSnakemakeBlock, NamedBlock, SnakemakeExecutableBlock) - ) - if ( - this_is_major - or prev_was_major - or (i == 0 and isinstance(block, SnakemakeBlock)) - ): - # Always enforce blank line before/after - # def/class/rule/onstart/etc. blocks - # Always add blank line before the first inline snakemake block - formatted.append("\n") - prev_was_major = this_is_major - block_formatted, state_ = block.formatted(mode, state_) - formatted.append(block_formatted) - for comment in tokens2linestrs(iter(self.tail_noncoding)): - if comment.strip(): - formatted.append(TAB * self.deindent_level + comment.lstrip()) - return "".join(formatted), state_ - def compilation(self): raise NotImplementedError @@ -861,28 +826,33 @@ class SnakemakeBlock(ColonBlock): def components(self) -> Iterator[DocumentSymbol]: yield from [] + def segment2format(self, mode, state): + """yield: + - [unformated_python_code, Literal[False]] + - [formated_snakemake_code, Literal[True]] + """ + head_noncding, body = self.formatted(mode, state) + yield "".join(head_noncding), False + yield body, True + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False + def formatted(self, mode, state): - formatted_prior, post_colon = self.format_head(mode) + noncoding_lines, formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] - for comment in tokens2linestrs(iter(self.tail_noncoding)): - if comment.strip(): - formatted.append(TAB * self.deindent_level + comment.lstrip()) - return "".join(formatted), state + return noncoding_lines, "".join(formatted) - def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: + def format_head(self, mode: Mode) -> tuple[list[str], str, list[TokenInfo]]: assert ( len(self.head_lines) == 1 ), "Snakemake keywords should only have one head line" indent = TAB * self.deindent_level noncoding, prior_colon, post_colon = self.split_colon_line() - formatted_comments = "".join( - indent + i.line.lstrip() for i in noncoding if i.type == tokenize.COMMENT - ) + noncoding_lines = tokens2linestrs(iter(noncoding)) assert len(prior_colon) == 1, "Snakemake keywords should be single line" (head,) = prior_colon components = head.strip().split() - formatted_head = formatted_comments + indent + " ".join(components) + ":" + formatted_head = indent + " ".join(components) + ":" if self.colon_line.end_op == ":": # only a single line comment or empty is possible here, add directly colon_token = next(post_colon) @@ -891,9 +861,9 @@ def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: fake_str = f"if 1:" + "".join(post) + " ..." fake_fmt = format_black(fake_str, mode).strip() formatted_head += fake_fmt.split(":", 1)[1].rsplit("\n", 1)[0] + "\n" - return formatted_head, [] + return noncoding_lines, formatted_head, [] else: - return formatted_head + "\n", list(post_colon.rest) + return noncoding_lines, formatted_head + "\n", list(post_colon.rest) @abstractmethod def format_body(self, mode, state, post_colon: list[TokenInfo]) -> str: ... @@ -929,6 +899,9 @@ def format_post_colon( input: balabal, balabal2, + + Morover, the original snakefmt allow sort positional arguments before keyword arguments. + Here need check, too """ if post_colon: assert ( @@ -1105,7 +1078,7 @@ def formatted(self, mode, state): """Try to merge the inline argument into the head line. If the line is too long after merging, then keep them separate. """ - formatted_prior, post_colon = self.format_head(mode) + noncoding_lines, formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): @@ -1118,10 +1091,7 @@ def formatted(self, mode, state): formatted_merge = last_head_line + " " + formatted_body.lstrip() if len(formatted_merge) <= mode.line_length: formatted = [prev + formatted_merge] - for comment in tokens2linestrs(iter(self.tail_noncoding)): - if comment.strip(): - formatted.append(TAB * self.deindent_level + comment.lstrip()) - return "".join(formatted), state + return noncoding_lines, "".join(formatted) def init_block_register(): @@ -1236,6 +1206,7 @@ def format_body(self, mode, state, post_colon): ) else: (param_space,) = self.body_blocks + assert isinstance(param_space, PythonBlock), "Unexpected body block type" return param_space.formatted(mode, state)[0] @@ -1258,7 +1229,8 @@ class SnakemakeKeywordBlock(SnakemakeBlock): def consume_body(self, tokens): blocks = self.consume_subblocks(tokens, ender_subblock=True) - if any(not isinstance(i, SnakemakeBlock) for i in blocks): + if any(not isinstance(i, SnakemakeBlock) for i in blocks[1:]): + breakpoint() raise UnsupportedSyntax( f"Unexpected content in {self.keyword} block: " f"only snakemake blocks are allowed, but got {blocks}" @@ -1267,10 +1239,25 @@ def consume_body(self, tokens): def format_body(self, mode, state, post_colon): assert not post_colon, "Invalid inline contents" - formatted = [] - for block in self.body_blocks: - block_formatted, state_ = block.formatted(mode, state) - formatted.append(block_formatted) + formatted: list[str] = [] + tail_noncoding: list[str] = [] + indent = TAB * (self.deindent_level + 1) + for i, block in enumerate(self.body_blocks): + if tail_noncoding: + tail_noncoding = [i.lstrip().rstrip("\n") for i in tail_noncoding] + formatted.extend(f"{indent}{i}\n" for i in tail_noncoding if i) + tail_noncoding = [] + if i == 0 and isinstance(block, PythonBlock): + body, state = block.formatted(mode, state) + else: + assert isinstance( + block, SnakemakeBlock + ), "Unexpected block type in snakemake keyword block" + noncoding, body = block.formatted(mode, state) + formatted.extend(i for i in noncoding if i.strip()) + formatted.append(body) + if block.tail_noncoding: + tail_noncoding = tokens2linestrs(iter(block.tail_noncoding)) # no `\n` between return "".join(formatted) @@ -1326,7 +1313,7 @@ class Log(SnakemakeArgumentsBlock): ... class Benchmark(SnakemakeUnnamedArgumentBlock): ... @_register() - class RulePathvars(SnakemakeArgumentsBlock): ... + class Pathvars(SnakemakeArgumentsBlock): ... @_register("wildcard_constraints") class Register_Wildcard_Constraints(SnakemakeArgumentsBlock): ... @@ -1396,32 +1383,18 @@ def formatted(self, mode, state): if ":" not in head_bulk_line: # return quickly (also no body block here) indent = TAB * self.deindent_level - noncoding = self.head_lines[0].head_noncoding - if noncoding: - raw_noncoding = "".join(tokens2linestrs(iter(noncoding))) - # `1` make sure all comments dedent to no prefix, then we can remove it - foramtted_noindent = format_black(raw_noncoding + "1", mode).split( - "\n" - )[:-2] - formatted_comments = "".join( - indent + i + "\n" if i else "\n" for i in foramtted_noindent - ) - else: - formatted_comments = "" + noncoding_lines = tokens2linestrs(iter(self.head_lines[0].head_noncoding)) components = head_bulk_line.strip().split() - formatted_head = formatted_comments + indent + " ".join(components) + formatted_head = indent + " ".join(components) if "#" in head_line[0]: formatted_head += " " + format_black( "#" + head_line[0].split("#", 1)[1], mode=mode ).rstrip("\n") - return formatted_head + "\n", state - formatted_prior, post_colon = self.format_head(mode) + return noncoding_lines, formatted_head + "\n" + noncoding_lines, formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] - for comment in tokens2linestrs(iter(self.tail_noncoding)): - if comment.strip(): - formatted.append(TAB * self.deindent_level + comment.lstrip()) - return "".join(formatted), state + return noncoding_lines, "".join(formatted) @_register() @@ -1477,49 +1450,46 @@ def __init__(self, deindent_level, tokens, lines=None): def consume(self, tokens): self.body_blocks = self.consume_subblocks(tokens) - def formatted(self, mode, state): - formatted = [] - state_ = state - linesep = "\n" if self.deindent_level else "\n\n" - # TODO: better handling of blank lines between blocks - _continuation_kws = {"elif", "else", "except", "finally"} - blocks = self.body_blocks - for i, block in enumerate(blocks): - block_formatted, state_ = block.formatted(mode, state_) - if block_formatted: # avoid adding extra blank lines for empty blocks - formatted.append(block_formatted) - # continuation keywords (else/elif/except/finally) must not be - # separated from the preceding block by a full blank line - next_block = blocks[i + 1] if i + 1 < len(blocks) else None - if ( - isinstance(next_block, IfForTryWithBlock) - and next_block.keyword in _continuation_kws - ): - formatted.append("\n") # continuation: elif/else/except/finally - elif isinstance(block, PythonBlock) and isinstance( - next_block, IfForTryWithBlock - ): - formatted.append("") # Python lead-in: no extra blank line - elif ( - isinstance(block, SnakemakeBlock) - and isinstance(next_block, SnakemakeBlock) - and not isinstance( - next_block, (NamedBlock, SnakemakeExecutableBlock) - ) - ): - formatted.append("") # Python lead-in: no extra blank line - else: - formatted.append(linesep) - if formatted: - formatted.pop() # remove the last separator - return "".join(formatted), state_ - def get_formatted(self, mode: Mode | None = None): if mode is None: mode = getattr(self, "mode", None) if mode is None: raise ValueError("Mode should be provided for formatting") - return self.formatted(mode, FormatState())[0] + python_codes: list[str] = [] + snakemake_codes: list[str] = [] + last_str = "" + for str, is_snake in self.segment2format(mode or self.mode, FormatState()): + if is_snake: + python_codes.append(last_str) + last_str = "" + snakemake_codes.append(str) + else: + last_str += str + place_hode_str = "o" * 50 + raw_str = "".join(python_codes) + while place_hode_str in raw_str: + place_hode_str *= 2 + raw_str = "#\n" + for python_code, snakemake_code in zip(python_codes, snakemake_codes): + if snakemake_code.count("\n") == 1: # must at the end of line + indent_str = extract_line_indent(snakemake_code) + place_hode = f"{indent_str}def l{place_hode_str}1ng(): ...\n" + else: + indent_str = extract_line_indent(snakemake_code) + place_hode = ( + f"{indent_str}def l{place_hode_str}ng():\n{indent_str} return\n" + ) + raw_str += python_code + place_hode + raw_str += last_str + formatted, *formatted_split = format_black(raw_str, mode).split(place_hode_str) + final_str = formatted + for formatted, snakemake_code in zip(formatted_split, snakemake_codes): + final_str = final_str.rsplit("\n", 1)[0] + "\n" + snakemake_code + if formatted.startswith("1"): + final_str += formatted.split("\n", 1)[-1] + else: + final_str += formatted.split("\n", 2)[-1] + return final_str[1:].lstrip("\n") def compilation(self): raise NotImplementedError From f0caba80236a6bd610d1a6dcc99ae4553c6a0713 Mon Sep 17 00:00:00 2001 From: hwrn Date: Wed, 8 Apr 2026 21:09:10 +0800 Subject: [PATCH 36/53] fix: extend to correct block --- snakefmt/blocken.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 266eac2..cd73645 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -730,8 +730,9 @@ class NoSnakemakeBlock(ColonBlock): def consume_body(self, tokens): lines, tail_noncoding = tokens.next_block() - self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) - self.extend_tail_noncoding(tail_noncoding) + codes = PythonBlock(self.deindent_level + 1, tokens, lines) + codes.extend_tail_noncoding(tail_noncoding) + self.body_blocks.append(codes) def compilation(self): raise NotImplementedError @@ -1043,10 +1044,9 @@ def consume_body(self, tokens): return lines, tail_noncoding = tokens.next_block() if lines: - self.body_blocks.append( - self.Argument(self.deindent_level + 1, tokens, lines) - ) - self.extend_tail_noncoding(tail_noncoding) + args = self.Argument(self.deindent_level + 1, tokens, lines) + args.extend_tail_noncoding(tail_noncoding) + self.body_blocks.append(args) else: assert ( self.colon_line.end_op != ":" @@ -1196,8 +1196,9 @@ class SnakemakeExecutableBlock(SnakemakeBlock): def consume_body(self, tokens): lines, tail_noncoding = tokens.next_block() - self.body_blocks.append(PythonBlock(self.deindent_level + 1, tokens, lines)) - self.extend_tail_noncoding(tail_noncoding) + executable = PythonBlock(self.deindent_level + 1, tokens, lines) + executable.extend_tail_noncoding(tail_noncoding) + self.body_blocks.append(executable) def format_body(self, mode, state, post_colon): if post_colon: @@ -1230,7 +1231,6 @@ class SnakemakeKeywordBlock(SnakemakeBlock): def consume_body(self, tokens): blocks = self.consume_subblocks(tokens, ender_subblock=True) if any(not isinstance(i, SnakemakeBlock) for i in blocks[1:]): - breakpoint() raise UnsupportedSyntax( f"Unexpected content in {self.keyword} block: " f"only snakemake blocks are allowed, but got {blocks}" From c46eadc77767f0fdc5ff0f6eca4f9bf5606569cb Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 00:49:46 +0800 Subject: [PATCH 37/53] fix: handle keywords --- snakefmt/blocken.py | 131 +++++++++++++++++++++++++++++++++++++----- tests/test_blocken.py | 30 +++++----- 2 files changed, 135 insertions(+), 26 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index cd73645..8b013c3 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -208,6 +208,20 @@ def end_op(self): return None return last_token.string + @property + def is_keyword_line(self): + if len(self.body) < 2: + return False + if ( + self.body[0].type == tokenize.NAME + and self.body[1].type == tokenize.OP + and self.body[1].string == "=" + ): + return True + if self.body[0].type == "**": + return True + return False + @property def deindelta(self): if not self.deindents: @@ -873,6 +887,29 @@ def compilation(self): raise NotImplementedError +def try_combine_format( + arg_lines: list[str], mode: Mode | None = None +) -> list[list[str]] | None: + """Try to combine multiple param lines without comma inside + Search reversly, so it only give one of the possible results. + + Since the non-comma param is the mistake of the user, + please do not blame if the olgorithm is slow :) + """ + if len(arg_lines) <= 1: + return [arg_lines] + mode = mode or Mode() + for i in range(len(arg_lines) - 1, 0, -1): + try: + combine = format_black("\n".join(arg_lines[:i]) + "\n,", mode) + except black.parsing.InvalidInput: + continue + rest = try_combine_format(arg_lines[i:], mode) + if rest is not None: + return [[combine]] + rest + return None + + class PythonArgumentsBlock(PythonBlock): """Block inside snakemake directives, such as `data.txt` in `input: \n "data.txt"` @@ -903,29 +940,97 @@ def format_post_colon( Morover, the original snakefmt allow sort positional arguments before keyword arguments. Here need check, too + + Input: + post_colon: tokens after the colon in the head line, e.g. `balabal,` in the above example + post_colon[0] := TokenInfo(type=NAME, string='balabal', ...) + body_blocks: indent body blocks, e.g. the block of `balabal2` in the above example """ + assert post_colon or body_blocks, "should have something in the comment" + args: dict[bool, list[list[str]]] = {True: [], False: []} if post_colon: assert ( post_colon[-1].type == tokenize.NEWLINE ), "Unexpected post_colon without a new line at the end" - colon_token = post_colon[0] - partial_line = LogicalLine([], [], post_colon[1:-1], post_colon[-1]) - post = tokens2linestrs(iter(partial_line.body)) - post[0] = post[0][colon_token.end[1] :] + partial_line = LogicalLine([], [], post_colon[:-1], post_colon[-1]) + may_incomplete_param = tokens2linestrs(iter(partial_line.body)) + may_incomplete_param[0] = may_incomplete_param[0][post_colon[0].end[1] :] + this_is_keyword = partial_line.is_keyword_line + if partial_line.end_op == ",": + args[this_is_keyword].append(may_incomplete_param) + may_incomplete_param = [] else: - post = [] + may_incomplete_param = [] + + def _find_split_and_push(): + nonlocal partial_line, may_incomplete_param + try_combined = try_combine_format(may_incomplete_param, mode) + if try_combined: + args[this_is_keyword].append(try_combined[0]) + args[False].extend(try_combined[1:]) + tokens = tokenize.generate_tokens(iter(try_combined[0]).__next__) + _line = TokenIterator("", tokens).next_new_line() + else: + # TODO: raise error here + args[this_is_keyword].append(may_incomplete_param) + _line = line + may_incomplete_param = [] + if this_is_keyword: + partial_line = _line + if body_blocks: (param_space,) = body_blocks - assert ( - not param_space.body_blocks - ), "Argument block should not have body blocks" + assert not param_space.body_blocks, "Argument block have no body blocks" for line in param_space.head_lines: - post.extend(line.linestrs) - post.extend(tokens2linestrs(iter(param_space.tail_noncoding))) + if not line.is_keyword_line: + # without keyword, the line is appandable + if not may_incomplete_param: + this_is_keyword = False + elif line.body[0].type in (tokenize.NAME, tokenize.NUMBER): + # Since the previous line is 'logical complete', + # if the line start with a simple name or number, + # it is impossible to be the continuation of the previous line + may_incomplete_param[-1] += "\n," + _find_split_and_push() + this_is_keyword = False + may_incomplete_param.append("".join(line.linestrs)) + if line.end_op == ",": + _find_split_and_push() + else: + if may_incomplete_param: + # last line not end by comma, + # but actually is a new line between params, + # manually add a comma + may_incomplete_param[-1] += "\n," + _find_split_and_push() + this_is_keyword = True + may_incomplete_param = ["".join(line.linestrs)] + if line.end_op == ",": + args[this_is_keyword].append(may_incomplete_param) + may_incomplete_param = [] + partial_line = line + if may_incomplete_param: + if this_is_keyword or not args[True]: + # if the last line is keyword line, + # or there is no keyword line at all, + # then the last line is used to check the end comma + partial_line = param_space.head_lines[-1] + else: + if not line.end_op == ",": + may_incomplete_param.append("\n,") + args[this_is_keyword].append(may_incomplete_param) + elif not args[True]: + partial_line = line + tail_noncoding = "".join(tokens2linestrs(iter(param_space.tail_noncoding))) + else: + args[this_is_keyword].append(may_incomplete_param) + tail_noncoding = "" # here is used to check the end_op - partial_line = param_space.head_lines[-1] - raw = cls.handle_end_comma("".join(post), partial_line) - formatted = format_black(raw, mode, deindent_level, partial="(") + raw = "".join( + (*(i for l in args[False] for i in l), *(i for l in args[True] for i in l)) + ) + formatable = cls.handle_end_comma(raw, partial_line) + tail_noncoding + formatted = format_black(formatable, mode, deindent_level, partial="(") return formatted @staticmethod diff --git a/tests/test_blocken.py b/tests/test_blocken.py index c7aace4..80f4ef5 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -279,19 +279,19 @@ def test_parse_python_block(self): fun1 = block.body_blocks[0] assert isinstance(fun1, NoSnakemakeBlock) assert [i.string for i in fun1.colon_line.body] == ["def", "f", "(", ")", ":"] - assert [tuple(i) for i in fun1.tail_noncoding] == [ - (tokenize.NL, "\n", (3, 0), (3, 1), "\n"), - (tokenize.NL, "\n", (4, 0), (4, 1), "\n"), - (tokenize.DEDENT, "", (5, 0), (5, 0), "b = f'''\n"), - ] + assert not fun1.tail_noncoding assert ["".join(i.full_linestrs) for i in fun1.body_blocks] == [ - " return 1\n" + " return 1\n\n\n" ] fun11 = fun1.body_blocks[0] assert isinstance(fun11, PythonBlock) assert [line.linestrs for line in fun11.head_lines] == [[" return 1\n"]] assert not fun11.body_blocks - assert not fun11.tail_noncoding + assert [tuple(i) for i in fun11.tail_noncoding] == [ + (tokenize.NL, "\n", (3, 0), (3, 1), "\n"), + (tokenize.NL, "\n", (4, 0), (4, 1), "\n"), + (tokenize.DEDENT, "", (5, 0), (5, 0), "b = f'''\n"), + ] if3 = block.body_blocks[2] assert isinstance(if3, IfForTryWithBlock) assert [i.string for i in if3.colon_line.body] == [ @@ -424,6 +424,11 @@ def test_format_partial_colon_indent(self): ) assert fmted == i + def test_format_reposity_def(self): + key = "o" * 100 + raw = f"def {key}(): ...\n" + assert format_black(raw, mode=mode) == raw + class TestBlockFormat: @@ -469,13 +474,12 @@ def test_format_python_block(self): ] py2 = block.body_blocks[1] assert len(py2.head_lines) == 3 + assert isinstance(py2, PythonBlock) assert ( py2.formatted(mode, state)[0] == 'b = f"""\n{b =} f"""\n# comment\nc = [i for j in k] if m else (lambda: None)\n' ) - assert block.formatted(mode, state)[0] == black.format_str( - self.example1, mode=mode - ) + assert block.get_formatted(mode) == black.format_str(self.example1, mode=mode) example2 = ( "rule A:\n" # L1 @@ -548,6 +552,6 @@ def test_format_python_block(self): def test_format_snakefile(self): code, formatted = self.example2 block = parse(code) - assert block.formatted(mode, state)[0].replace("\n", "<\n") == ( - formatted - ).replace("\n", "<\n") + assert block.get_formatted(mode).replace("\n", "<\n") == (formatted).replace( + "\n", "<\n" + ) From 4e692a97b64863ee3808b712e118ab0405f68bdc Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 02:45:47 +0800 Subject: [PATCH 38/53] fix: `SnakemakeInlineArgumentBlock` should be taken very careful of --- snakefmt/blocken.py | 72 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 8b013c3..03f4f68 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -638,11 +638,77 @@ def segment2format(self, mode: Mode, state: FormatState): """yield: - [unformated_python_code, Literal[False]] - [formated_snakemake_code, Literal[True]] + + `SnakemakeInlineArgumentBlock` should be taken very careful of, + since they are formatedd as `def` blocks, and may not sperate from + blocks with different keywords. So here are the special principles + specially for one-line snakemake blocks: + + - the previous block should be in the same indent of current block; + - if previous line (with no newline nor comments) is: + 1, `def` block; or + 2. another one-line block with differnt keyword: + then add a newline + - if previous line is the same keyword with: + only comment lines but NO blank line between: + merge the two lines into one block, with comments in between + - (doesn't matter if this block is actually one-line or not) """ - yield "".join(self.head_linestrs), False + + if self.head_linestrs: + yield "".join(self.head_linestrs), False + last_keyword = "" + line = "" for block in self.body_blocks: - yield from block.segment2format(mode, state) - yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False + if isinstance(block, ColonBlock): + if block.keyword == "def": + if last_keyword and last_keyword != "def": + # line must exists, check if the last line is start + if ( + line.rstrip() + .rsplit("\n", 1)[-1] + .startswith(block.indent_str + last_keyword) + ): + # Oh, differnt keyword detected, + # is there NO any line before the first line of this block? + if not block.head_lines[0].head_noncoding: + yield "\n", False + last_keyword = "def" + for line, is_snake in block.segment2format(mode, state): + # record `line` for next useage + yield line, is_snake + elif isinstance(block, SnakemakeBlock): + segs = [i for i in block.segment2format(mode, state) if i[0]] + if last_keyword: + if last_keyword == block.keyword: + head_noncoding = block.head_lines[0].head_noncoding + if head_noncoding and "\n" not in tokens2linestrs( + iter(head_noncoding) + ): + # Ah, no line detected, + # just format comment lines (and all are only comments) + # before the next is_snake = False + indent_str = block.indent_str + for i, seg in enumerate(segs): + if seg[1]: # is_snake + break + for line in seg[0].splitlines(keepends=True): + formatted = format_black(line, mode, 0) + yield indent_str + formatted, True + segs = segs[i:] + elif not block.head_lines[0].head_noncoding: + yield "\n", False + last_keyword = block.keyword + for line, is_snake in segs: + yield line, is_snake + else: + last_keyword = "" + yield from block.segment2format(mode, state) + else: + last_keyword = "" + yield from block.segment2format(mode, state) + if self.tail_noncoding: + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False @abstractmethod def compilation(self): From be7bc3ae7d582af14289f5df6ec5d9413f64d050 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 03:10:41 +0800 Subject: [PATCH 39/53] fix: report black bugs --- snakefmt/blocken.py | 57 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 03f4f68..d51ebdc 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -17,7 +17,7 @@ from snakefmt.config import read_black_config, Mode -from snakefmt.exceptions import UnsupportedSyntax +from snakefmt.exceptions import InvalidPython, UnsupportedSyntax from snakefmt.types import TAB if sys.version_info < (3, 12): @@ -365,7 +365,13 @@ def format_python_colon_head( return formatted.split("\n", fake_head_lines)[-1] -def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] = ""): +def format_black( + raw: str, + mode: Mode, + indent=0, + partial: Literal["", ":", "("] = "", + start_token: TokenInfo | None = None, +): """Format a string using Black formatter. if indent: @@ -398,7 +404,22 @@ def format_black(raw: str, mode: Mode, indent=0, partial: Literal["", ":", "("] try: fmted = black.format_str(prefix + string, mode=mode) except black.parsing.InvalidInput as e: - raise e + if start_token is not None: + import re + + match = re.search(r"(Cannot parse.*?:\s*)(?P\d+)(.*)", str(e)) + if match: + err_msg = match.group(1) + str(start_token.start[0]) + match.group(3) + else: + err_msg = str(e) + else: + err_msg = str(e) + err_msg += ( + "\n\n(Note reported line number may be incorrect, as" + " snakefmt could not determine the true line number)" + ) + err_msg = f"Black error:\n```\n{str(err_msg)}\n```\n" + raise InvalidPython(err_msg) from None if indent: fix = fmted.split("\n", indent)[-1] else: @@ -734,7 +755,9 @@ def formatted(self, mode, state) -> tuple[str, FormatState]: raw = "".join(self.full_linestrs) if not raw.strip(): return "", state - formatted = format_black(raw, mode, self.deindent_level) + formatted = format_black( + raw, mode, self.deindent_level, start_token=self.head_lines[0].body[0] + ) return formatted, state def compilation(self): @@ -968,7 +991,7 @@ def try_combine_format( for i in range(len(arg_lines) - 1, 0, -1): try: combine = format_black("\n".join(arg_lines[:i]) + "\n,", mode) - except black.parsing.InvalidInput: + except InvalidPython: continue rest = try_combine_format(arg_lines[i:], mode) if rest is not None: @@ -1096,7 +1119,13 @@ def _find_split_and_push(): (*(i for l in args[False] for i in l), *(i for l in args[True] for i in l)) ) formatable = cls.handle_end_comma(raw, partial_line) + tail_noncoding - formatted = format_black(formatable, mode, deindent_level, partial="(") + formatted = format_black( + formatable, + mode, + deindent_level, + partial="(", + start_token=partial_line.body[0], + ) return formatted @staticmethod @@ -1138,7 +1167,13 @@ def formatted(self, mode, state): raw += "\n," tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) raw += "".join(i for i in tail_noncoding if i.strip()) - formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") + formatted = format_black( + raw, + mode, + self.deindent_level - 1, + partial="(", + start_token=self.head_lines[0].body[0], + ) return formatted, state @staticmethod @@ -1170,7 +1205,13 @@ def formatted(self, mode, state): raw = raw[:comma_start] + raw[comma_start + 1 :] tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) raw += "".join(i for i in tail_noncoding if i.strip()) - formatted = format_black(raw, mode, self.deindent_level - 1, partial="(") + formatted = format_black( + raw, + mode, + self.deindent_level - 1, + partial="(", + start_token=self.head_lines[0].body[0], + ) return formatted, state @staticmethod From 345600fe2cf53e5350ecf7bbfab05f9845aa5888 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 10:05:29 +0800 Subject: [PATCH 40/53] feat: enable sort --- snakefmt/blocken.py | 85 ++++++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index d51ebdc..0f16887 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -333,38 +333,6 @@ def update(self, *str): return self._replace() -def format_python_colon_head( - raw: str, mode: Mode, keyword: str, indent_str: str = "", indent=0, partial=False -): - """Continuation keywords (else/elif/except/finally) need a preceding fake block - because black cannot parse them in isolation. - """ - if keyword == "elif" or keyword == "else": - fake_head = indent_str + "if 1: pass\n" - fake_head_lines = 2 # black always expands "if 1: pass" to 2 lines - elif keyword == "except" or keyword == "finally": - fake_head = indent_str + "try: pass\n" - fake_head_lines = 2 # black always expands "try: pass" to 2 lines - elif keyword == "match": - # match needs at least one case - dummy_case = indent_str + " case _: pass\n" - formatted = format_black(raw + dummy_case, mode, indent, "") - # Keep only the match line - return formatted.rsplit("\n", 3)[0] + "\n" - elif keyword == "case": - # case needs to be inside a match, construct the full block - assert indent_str and indent, "`case` block must be indented" - dummy_match = indent_str[:-1] + "match 1:\n" + raw - formatted = format_black(dummy_match, mode, indent - 1, ":") - return formatted.split("\n", 1)[-1] - else: - return format_black(raw, mode, indent, ":" if partial else "") - formatted = format_black(fake_head + raw, mode, indent, ":") - if not fake_head: - return formatted - return formatted.split("\n", fake_head_lines)[-1] - - def format_black( raw: str, mode: Mode, @@ -1452,25 +1420,47 @@ def consume_body(self, tokens): def format_body(self, mode, state, post_colon): assert not post_colon, "Invalid inline contents" formatted: list[str] = [] + sort_directives: dict[str, str] = {} tail_noncoding: list[str] = [] indent = TAB * (self.deindent_level + 1) for i, block in enumerate(self.body_blocks): + directive = "" if tail_noncoding: - tail_noncoding = [i.lstrip().rstrip("\n") for i in tail_noncoding] - formatted.extend(f"{indent}{i}\n" for i in tail_noncoding if i) + for line in tail_noncoding: + if line.strip(): + # only non-empty lines are formattable + line = format_black(line, mode, 0) + # possible update state? + directive += indent + line tail_noncoding = [] if i == 0 and isinstance(block, PythonBlock): body, state = block.formatted(mode, state) + formatted.append(body) else: assert isinstance( block, SnakemakeBlock ), "Unexpected block type in snakemake keyword block" noncoding, body = block.formatted(mode, state) - formatted.extend(i for i in noncoding if i.strip()) - formatted.append(body) + for line in noncoding: # here noncoding is already formated + if line.strip(): + # only non-empty lines are formattable + # possible update state? + directive += line + directive += body + if state.sort_direcives: + sort_directives[block.keyword] = directive + else: + formatted.append(directive) if block.tail_noncoding: tail_noncoding = tokens2linestrs(iter(block.tail_noncoding)) # no `\n` between + if sort_directives: + for keyword in self.subautomata: + if keyword in sort_directives: + formatted.append(sort_directives[keyword]) + if tail_noncoding: + tail_noncoding = [i.lstrip().rstrip("\n") for i in tail_noncoding] + formatted.extend(f"{indent}{i}\n" for i in tail_noncoding if i) return "".join(formatted) @@ -1481,12 +1471,14 @@ class Module(NamedBlock, SnakemakeKeywordBlock): @_register() class Name(SnakemakeInlineArgumentBlock): ... + # Reference @_register() class Snakefile(SnakemakeUnnamedArgumentBlock): ... @_register() class Meta_Wrapper(SnakemakeUnnamedArgumentBlock): ... + # Override @_register() class Skip_Validation(SnakemakeUnnamedArgumentBlock): ... @@ -1512,6 +1504,7 @@ class Name(SnakemakeUnnamedArgumentBlock): ... @_register("default_target") class Default_Target_Rule(SnakemakeInlineArgumentBlock): ... + # I/O @_register() class Input(SnakemakeArgumentsBlock): ... @@ -1524,12 +1517,14 @@ class Log(SnakemakeArgumentsBlock): ... @_register() class Benchmark(SnakemakeUnnamedArgumentBlock): ... + # Rule logic @_register() class Pathvars(SnakemakeArgumentsBlock): ... @_register("wildcard_constraints") class Register_Wildcard_Constraints(SnakemakeArgumentsBlock): ... + # Scheduling & control @_register("cache") class Cache_Rule(SnakemakeInlineArgumentBlock): ... @@ -1548,6 +1543,7 @@ class LocalRule(SnakemakeInlineArgumentBlock): ... @_register() class Handover(SnakemakeInlineArgumentBlock): ... + # Execution environment @_register() class Shadow(SnakemakeUnnamedArgumentBlock): ... @@ -1564,6 +1560,8 @@ class Containerized(SnakemakeUnnamedArgumentBlock): ... @_register() class EnvModules(SnakemakeUnnamedArgumentsBlock): ... + # Execution resources and parameters + @_register() class Threads(SnakemakeInlineArgumentBlock): ... @@ -1573,6 +1571,7 @@ class Resources(SnakemakeArgumentsBlock): ... @_register() class Params(SnakemakeArgumentsBlock): ... + # Runtime messages @_register() class Message(SnakemakeUnnamedArgumentBlock): ... @@ -1611,6 +1610,7 @@ def formatted(self, mode, state): @_register() class Rule(_Rule): + # Action exec_subautomata, _register = init_block_register() @_register() @@ -1651,8 +1651,9 @@ class GlobalBlock(Block): so tail_noncoding always updated to the last body_block """ - __slots__ = ("mode",) + __slots__ = ("mode", "sort_direcives") mode: Mode + sort_direcives: bool subautomata = {**python_subautomata, **global_snakemake_subautomata} @@ -1662,15 +1663,20 @@ def __init__(self, deindent_level, tokens, lines=None): def consume(self, tokens): self.body_blocks = self.consume_subblocks(tokens) - def get_formatted(self, mode: Mode | None = None): + def get_formatted( + self, mode: Mode | None = None, sort_directives: bool | None = None + ): if mode is None: mode = getattr(self, "mode", None) if mode is None: raise ValueError("Mode should be provided for formatting") + if sort_directives is None: + sort_directives = bool(getattr(self, "sort_direcives", False)) + state = FormatState(sort_direcives=sort_directives) python_codes: list[str] = [] snakemake_codes: list[str] = [] last_str = "" - for str, is_snake in self.segment2format(mode or self.mode, FormatState()): + for str, is_snake in self.segment2format(mode or self.mode, state): if is_snake: python_codes.append(last_str) last_str = "" @@ -1729,4 +1735,5 @@ def setup_formatter( mode.line_length = line_length formatter.mode = mode + formatter.sort_direcives = sort_params return formatter From 5519bf63dcd286dbe391309886232b320bd573a8 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 10:59:26 +0800 Subject: [PATCH 41/53] chore: remove unused --- snakefmt/blocken.py | 99 +++++++++---------------------------------- tests/test_blocken.py | 44 ------------------- 2 files changed, 21 insertions(+), 122 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 0f16887..1c79361 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -275,22 +275,8 @@ def from_token(cls, tokens: Iterator[TokenInfo]): return cls(head_empty_lines, deindents, contents, token) -def split_token_lines(token: TokenInfo): - """Token can be multiline. - e.g., `f'''\\nplaintext\\n'''` has these tokens: - - TokenInfo(type=61 (FSTRING_START), string="f'''", - start=(21, 0), end=(21, 4), line="f'''\\n") - TokenInfo(type=62 (FSTRING_MIDDLE), string='\\ncccccccc\\n', - start=(21, 4), end=(23, 0), line="f'''\\ncccccccc\\n'''\\n") - TokenInfo(type=63 (FSTRING_END), string="'''", - start=(23, 0), end=(23, 3), line="'''\\n") - - lines should be split to drop overlapping lines and keep unique ones. - """ - return zip( - range(token.start[0], token.end[0] + 1), token.line.splitlines(keepends=True) - ) +def not_deindent(token: TokenInfo) -> bool: + return token.type != tokenize.INDENT and token.type != tokenize.DEDENT def tokens2linestrs(tokens: Iterator[TokenInfo]): @@ -304,7 +290,13 @@ def tokens2linestrs(tokens: Iterator[TokenInfo]): string_interior_lines: set[int] = set() for token in tokens: if not_deindent(token) and token.end[0] not in lines: - lines.update(split_token_lines(token)) + # split multiline tokens with lineno for dereplication + lines.update( + zip( + range(token.start[0], token.end[0] + 1), + token.line.splitlines(keepends=True), + ) + ) if token.start[0] != token.end[0]: string_interior_lines.update( range(token.start[0] + 1, token.end[0] + 1) @@ -320,10 +312,6 @@ def tokens2linestrs(tokens: Iterator[TokenInfo]): return newlines -def not_deindent(token: TokenInfo) -> bool: - return token.type != tokenize.INDENT and token.type != tokenize.DEDENT - - class FormatState(NamedTuple): fmt_off: bool = False sort_direcives: bool = False @@ -414,11 +402,12 @@ class Block(ABC): A block can be: a continuous python code of lines with the same indentation level. Also include functions, classes and decoraters (`@` lines) - a single block identifed by keywords in `{PYTHON_INDENT_KEYWORDS}` + a single block identifed by keywords in + if/elif/else / for/while / try/except/finally / with and all the code under it, until the next block of the same or lower indent level. a snakemake keyword block (rule, module, config, etc.) and all the code under it, until the next block of the same or lower indent level. - snakemake keywords should NEVER in functions or classes + (snakemake keywords should NEVER in functions or classes) comments between blocks (exclude the comment right before the indenting keyword, which is considered part of the block) @@ -830,15 +819,9 @@ class UnknownIndentBlock(IfForTryWithBlock): """ -PYTHON_INDENT_KEYWORDS = { - i - for j in ("if elif else", "for while", "try except finally", "with") - for i in j.split() -} - if_for_try_with_blocks: dict[str, type[IfForTryWithBlock]] = { i.lower(): type(i.capitalize(), (IfForTryWithBlock,), {}) - for i in PYTHON_INDENT_KEYWORDS + for i in ("if elif else " "for while " "try except finally " "with").split() } @@ -857,9 +840,6 @@ def consume_body(self, tokens): ) self.body_blocks = blocks - def formatted(self, mode, state): - raise NotImplementedError("Not supported to format match-case blocks yet") - def compilation(self): raise NotImplementedError @@ -1098,7 +1078,11 @@ def _find_split_and_push(): @staticmethod @abstractmethod - def handle_end_comma(raw: str, last_line: LogicalLine) -> str: ... + def handle_end_comma(raw: str, last_line: LogicalLine) -> str: + """ + For PythonArguments: the last line should always endswith `,`; + For PythonOneLineArgument: the last line should never endswith `,`; + """ class PythonArguments(PythonArgumentsBlock): @@ -1124,26 +1108,6 @@ class PythonArguments(PythonArgumentsBlock): but that's not eazy, especially for unnamed arguments """ - def formatted(self, mode, state): - """PythonArguments and its subclasses always at the terminal - of the snakemake keyword tree, - so returned state never used anymore - """ - assert not self.body_blocks, "PythonArguments should not have body blocks" - raw = "".join(self.head_linestrs) - if not self.head_lines[-1].end_op == ",": - raw += "\n," - tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) - raw += "".join(i for i in tail_noncoding if i.strip()) - formatted = format_black( - raw, - mode, - self.deindent_level - 1, - partial="(", - start_token=self.head_lines[0].body[0], - ) - return formatted, state - @staticmethod def handle_end_comma(raw, last_line): if not last_line.end_op == ",": @@ -1158,30 +1122,6 @@ class PythonUnnamedArguments(PythonArguments): class PythonOneLineArgument(PythonArgumentsBlock): """Only allow simple expressions on the right""" - def formatted(self, mode, state): - """Only a single expression, trim the trailing comma""" - assert not self.body_blocks, "PythonArguments should not have body blocks" - raw = "".join(self.head_linestrs) - if self.head_lines[-1].end_op == ",": - last_line = self.head_lines[-1] - comma_token = ( - last_line.body[-2] - if last_line.body[-1].type == tokenize.COMMENT - else last_line.body[-1] - ) - comma_start = comma_token.start[1] - len(comma_token.line) - raw = raw[:comma_start] + raw[comma_start + 1 :] - tail_noncoding = tokens2linestrs(iter(self.tail_noncoding)) - raw += "".join(i for i in tail_noncoding if i.strip()) - formatted = format_black( - raw, - mode, - self.deindent_level - 1, - partial="(", - start_token=self.head_lines[0].body[0], - ) - return formatted, state - @staticmethod def handle_end_comma(raw, last_line): if last_line.end_op == ",": @@ -1418,6 +1358,9 @@ def consume_body(self, tokens): self.body_blocks = blocks def format_body(self, mode, state, post_colon): + """Sort directives in the order of subautomata, + and format them together with the head line. + """ assert not post_colon, "Invalid inline contents" formatted: list[str] = [] sort_directives: dict[str, str] = {} diff --git a/tests/test_blocken.py b/tests/test_blocken.py index 80f4ef5..6bb1c50 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -9,7 +9,6 @@ consume_fstring, TokenIterator, format_black, - format_python_colon_head, tokenize, is_fstring_start, UnsupportedSyntax, @@ -381,49 +380,6 @@ def test_format_paren(self): fmted = format_black(raw, mode=mode, indent=1, partial="(") assert fmted == (f'{TAB * 2}"b = 2",\n') - def test_format_partial_colon(self): - for i in ( - "if cond:\n", - "else:\n", - "elif x > 0:\n", - "except ValueError:\n", - "finally:\n", - "match val:\n", - ): - fmted = format_python_colon_head( - i, mode, i.strip().split()[0].replace(":", ""), partial=True - ) - assert fmted == i - - def test_format_partial_colon_indent(self): - for i in ( - f"{TAB}else:\n", - f"{TAB}elif x > 0:\n", - f"{TAB}except (ValueError, KeyError):\n", - f"{TAB}finally:\n", - f"{TAB}match val:\n", - f"{TAB}case Point(x, 0):\n", - ): - fmted = format_python_colon_head( - i, - mode, - i.strip().split()[0].replace(":", ""), - indent_str=TAB, - indent=1, - partial=True, - ) - assert fmted == i - i = f"{TAB}elif (\n x > 0\n ):\n" - fmted = format_python_colon_head( - i, mode, "elif", indent_str=TAB, indent=1, partial=True - ) - assert fmted == " elif x > 0:\n" - i = f"{TAB*2}case Point(x, 0):\n" - fmted = format_python_colon_head( - i, mode, "case", indent_str=TAB * 2, indent=2, partial=True - ) - assert fmted == i - def test_format_reposity_def(self): key = "o" * 100 raw = f"def {key}(): ...\n" From e75b6484825946e9c193069cec4fc083a9926179 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 19:51:45 +0800 Subject: [PATCH 42/53] fix: partial TestFmtOffOn --- snakefmt/blocken.py | 307 +++++++++++++++++++++++++++++------------- tests/test_blocken.py | 2 +- 2 files changed, 214 insertions(+), 95 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 1c79361..b1f83fa 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -1,25 +1,30 @@ +import re import sys import tokenize from abc import ABC, abstractmethod +from collections import OrderedDict +from tokenize import TokenInfo from typing import ( Callable, + Generator, Iterator, Literal, + Mapping, NamedTuple, Optional, - Mapping, TypeVar, ) -from tokenize import TokenInfo -from collections import OrderedDict -import black.parsing - -from snakefmt.config import read_black_config, Mode +import black.parsing +from snakefmt.config import Mode, read_black_config from snakefmt.exceptions import InvalidPython, UnsupportedSyntax from snakefmt.types import TAB +_FMT_DIRECTIVE_RE = re.compile( + r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" +) + if sys.version_info < (3, 12): is_fstring_start = lambda token: False else: @@ -313,12 +318,67 @@ def tokens2linestrs(tokens: Iterator[TokenInfo]): class FormatState(NamedTuple): - fmt_off: bool = False - sort_direcives: bool = False + fmt_on: bool = True + sort_direcives: bool | None = None + skip_next: bool = False # one-time directive for the next snakemake block - def update(self, *str): - # TODO: implement state update logic - return self._replace() + @property + def not_format(self): + return not self.fmt_on or self.skip_next + + def update(self, comment: str): + """check single line comment line for pattern: + # fmt: off + # fmt: off[option1, option2, ...] + # fmt: on + # fmt: on[option1, option2, ...] + + Currently, options can be: + - sort: whether to sort snakemake directives (e.g. input, output, params, etc.) + - next: whether to apply the directive to the next snakemake block only + Do not effect blocks after empty lines. + Cannot be disabled by `# fmt: on[next]` + - only the first directive will be applied + + If found `# fmt: on` and no `# fmt: off` before: + if `fmt: off[sort]` is False: + sort_direcives == True -> enabled + sort_direcives == False -> disabled in this indent before + sort_direcives == None -> haven't enabled originally + turn it on + """ + match = _FMT_DIRECTIVE_RE.match(comment) + if match := _FMT_DIRECTIVE_RE.match(comment): + directive, options = match.groups() + # Parse options: "sort,next" -> ["sort", "next"] -> "sort" + option = [opt.strip() for opt in (options or "").split(",")][0] + if not self.fmt_on: # only check `# fmt: on` + if directive == "on" and not option: + return self._replace(fmt_on=True) + elif directive == "on": + if option == "sort": + return self._replace(sort_direcives=True) + if self.sort_direcives is False: + # re-enable sorting if it was disabled by `# fmt: off[sort]` before, + # but should effect if no `# fmt: off[sort]` in this indent before. + return self._replace(sort_direcives=None) + elif directive == "off": + if option == "sort": + return self._replace(sort_direcives=False) + if option == "next": + return self._replace(skip_next=True) + return self._replace(fmt_on=False) + return self + + def consume_skip_next(self) -> "FormatState": + """Returns new state with skip_next consumed (set to False)""" + if self.skip_next: + return self._replace(skip_next=False) + return self + + @staticmethod + def found_skip(comment: str): + return "# fmt: skip" in comment def format_black( @@ -327,7 +387,7 @@ def format_black( indent=0, partial: Literal["", ":", "("] = "", start_token: TokenInfo | None = None, -): +) -> str: """Format a string using Black formatter. if indent: @@ -612,7 +672,9 @@ def components(self) -> "Iterator[DocumentSymbol]": for block in self.body_blocks: yield from block.components() - def segment2format(self, mode: Mode, state: FormatState): + def segment2format( + self, mode: Mode, state: FormatState + ) -> Generator[tuple[str, bool], None, None]: """yield: - [unformated_python_code, Literal[False]] - [formated_snakemake_code, Literal[True]] @@ -633,22 +695,34 @@ def segment2format(self, mode: Mode, state: FormatState): - (doesn't matter if this block is actually one-line or not) """ + # comment fmt directives in head_linestrs + # will effect on post blocks of the same indent, + # so should be updated during the parent body_blocks iteration. if self.head_linestrs: yield "".join(self.head_linestrs), False last_keyword = "" line = "" for block in self.body_blocks: + restart_state = state = state.consume_skip_next() + # update state from head_noncoding + for head_line in block.head_lines: + for noncoding_token in head_line.head_noncoding: + if noncoding_token.type == tokenize.COMMENT: + state = state.update(noncoding_token.string) + elif state.skip_next and not noncoding_token.line.strip(): + state = state.consume_skip_next() if isinstance(block, ColonBlock): if block.keyword == "def": if last_keyword and last_keyword != "def": - # line must exists, check if the last line is start + # Oh, differnt keyword detected, so (last)line must exists + # Then check if that line is start if ( line.rstrip() .rsplit("\n", 1)[-1] .startswith(block.indent_str + last_keyword) ): - # Oh, differnt keyword detected, - # is there NO any line before the first line of this block? + # If NO any line before the first line of this block, + # black cannot split them: Add one to force splitting if not block.head_lines[0].head_noncoding: yield "\n", False last_keyword = "def" @@ -656,29 +730,11 @@ def segment2format(self, mode: Mode, state: FormatState): # record `line` for next useage yield line, is_snake elif isinstance(block, SnakemakeBlock): - segs = [i for i in block.segment2format(mode, state) if i[0]] - if last_keyword: - if last_keyword == block.keyword: - head_noncoding = block.head_lines[0].head_noncoding - if head_noncoding and "\n" not in tokens2linestrs( - iter(head_noncoding) - ): - # Ah, no line detected, - # just format comment lines (and all are only comments) - # before the next is_snake = False - indent_str = block.indent_str - for i, seg in enumerate(segs): - if seg[1]: # is_snake - break - for line in seg[0].splitlines(keepends=True): - formatted = format_black(line, mode, 0) - yield indent_str + formatted, True - segs = segs[i:] - elif not block.head_lines[0].head_noncoding: - yield "\n", False - last_keyword = block.keyword - for line, is_snake in segs: + for line, is_snake in block.segment2format( + mode, restart_state, last_keyword + ): yield line, is_snake + last_keyword = block.keyword else: last_keyword = "" yield from block.segment2format(mode, state) @@ -708,14 +764,14 @@ class PythonBlock(Block): def consume(self, tokens): "Do nothing, win" - def formatted(self, mode, state) -> tuple[str, FormatState]: + def formatted(self, mode: Mode): raw = "".join(self.full_linestrs) if not raw.strip(): - return "", state + return "" formatted = format_black( raw, mode, self.deindent_level, start_token=self.head_lines[0].body[0] ) - return formatted, state + return formatted def compilation(self): raise NotImplementedError @@ -758,7 +814,7 @@ def split_colon_line(self): prior = tokens2linestrs(iter(last_line_tokens)) prior[-1] = prior[-1][: colon_token.start[1]] token_iter.denext(colon_token) - return self.colon_line.head_noncoding, prior, token_iter + return prior, token_iter @property def colon_line(self): @@ -878,30 +934,76 @@ class SnakemakeBlock(ColonBlock): def components(self) -> Iterator[DocumentSymbol]: yield from [] - def segment2format(self, mode, state): + def segment2format(self, mode, state, last_keyword=""): """yield: - [unformated_python_code, Literal[False]] - [formated_snakemake_code, Literal[True]] + + If state.skip_next is True, or state.fmt_on is False, + return unformatted content with proper True/False markers. """ - head_noncding, body = self.formatted(mode, state) - yield "".join(head_noncding), False - yield body, True - yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False + + # Get noncoding_lines early to check fmt directives + indent_str = self.indent_str + assert len(self.head_lines) == 1, "Snakemake keywords should only in one line" + noncoding_lines: list[str] = [] + last_fmt_on = state.fmt_on + # Check if there's fmt: on/off in noncoding_lines to update state + for noncoding_line in tokens2linestrs(iter(self.colon_line.head_noncoding)): + if not noncoding_line.strip(): + last_keyword = "" + else: + state = state.update(noncoding_line.lstrip()) + if not state.fmt_on: + noncoding_lines.append(noncoding_line) + else: + noncoding_lines.append( + indent_str + format_black(noncoding_line, mode, 0) + ) + if last_fmt_on and state.fmt_on: + if last_keyword == self.keyword: + # pre-format these lines and yield together + pre_formatted = format_black( + "".join(noncoding_lines), mode, 0 + ).splitlines(keepends=True) + for line in pre_formatted: + if state.found_skip(line): + yield line, False + else: + yield indent_str + line.lstrip(), True + else: + if not noncoding_lines: + yield "\n", False + yield "".join(noncoding_lines), False + else: + yield "".join(noncoding_lines), False + + # Check if this block should be skipped from formatting + if state.not_format: + raw = "".join( + [self.colon_line.body[-1].line] + + [line for block in self.body_blocks for line in block.full_linestrs] + ) + yield raw, True + else: + yield self.formatted(mode, state), True + if self.tail_noncoding: + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False def formatted(self, mode, state): - noncoding_lines, formatted_prior, post_colon = self.format_head(mode) + formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] - return noncoding_lines, "".join(formatted) + return "".join(formatted) - def format_head(self, mode: Mode) -> tuple[list[str], str, list[TokenInfo]]: - assert ( - len(self.head_lines) == 1 - ), "Snakemake keywords should only have one head line" + def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: indent = TAB * self.deindent_level - noncoding, prior_colon, post_colon = self.split_colon_line() - noncoding_lines = tokens2linestrs(iter(noncoding)) - assert len(prior_colon) == 1, "Snakemake keywords should be single line" + if self.colon_line.body[-1].type == tokenize.COMMENT: + line = self.colon_line.body[-1].line + if FormatState.found_skip(line): + return indent + line.lstrip(), [] + prior_colon, post_colon = self.split_colon_line() + assert len(prior_colon) == 1, "Snakemake keywords should be in one line" (head,) = prior_colon components = head.strip().split() formatted_head = indent + " ".join(components) + ":" @@ -913,12 +1015,14 @@ def format_head(self, mode: Mode) -> tuple[list[str], str, list[TokenInfo]]: fake_str = f"if 1:" + "".join(post) + " ..." fake_fmt = format_black(fake_str, mode).strip() formatted_head += fake_fmt.split(":", 1)[1].rsplit("\n", 1)[0] + "\n" - return noncoding_lines, formatted_head, [] + return formatted_head, [] else: - return noncoding_lines, formatted_head + "\n", list(post_colon.rest) + return formatted_head + "\n", list(post_colon.rest) @abstractmethod - def format_body(self, mode, state, post_colon: list[TokenInfo]) -> str: ... + def format_body( + self, mode: Mode, state: FormatState, post_colon: list[TokenInfo] + ) -> str: ... def compilation(self): raise NotImplementedError @@ -983,7 +1087,8 @@ def format_post_colon( post_colon[0] := TokenInfo(type=NAME, string='balabal', ...) body_blocks: indent body blocks, e.g. the block of `balabal2` in the above example """ - assert post_colon or body_blocks, "should have something in the comment" + if not (post_colon or body_blocks): + return "" args: dict[bool, list[list[str]]] = {True: [], False: []} if post_colon: assert ( @@ -1198,7 +1303,7 @@ def formatted(self, mode, state): """Try to merge the inline argument into the head line. If the line is too long after merging, then keep them separate. """ - noncoding_lines, formatted_prior, post_colon = self.format_head(mode) + formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] if formatted_body.count("\n") == 1 and formatted_body.endswith("\n"): @@ -1211,7 +1316,7 @@ def formatted(self, mode, state): formatted_merge = last_head_line + " " + formatted_body.lstrip() if len(formatted_merge) <= mode.line_length: formatted = [prev + formatted_merge] - return noncoding_lines, "".join(formatted) + return "".join(formatted) def init_block_register(): @@ -1328,7 +1433,7 @@ def format_body(self, mode, state, post_colon): else: (param_space,) = self.body_blocks assert isinstance(param_space, PythonBlock), "Unexpected body block type" - return param_space.formatted(mode, state)[0] + return param_space.formatted(mode) @_register() @@ -1363,49 +1468,63 @@ def format_body(self, mode, state, post_colon): """ assert not post_colon, "Invalid inline contents" formatted: list[str] = [] - sort_directives: dict[str, str] = {} + directives: dict[str, str] = {} tail_noncoding: list[str] = [] indent = TAB * (self.deindent_level + 1) for i, block in enumerate(self.body_blocks): - directive = "" - if tail_noncoding: - for line in tail_noncoding: - if line.strip(): - # only non-empty lines are formattable - line = format_black(line, mode, 0) - # possible update state? - directive += indent + line - tail_noncoding = [] + assert not tail_noncoding, "no tail_noncoding before body_blocks" if i == 0 and isinstance(block, PythonBlock): - body, state = block.formatted(mode, state) + body = block.formatted(mode) formatted.append(body) + for line in block.head_linestrs: + state = state.update(line.lstrip()) else: assert isinstance( block, SnakemakeBlock ), "Unexpected block type in snakemake keyword block" - noncoding, body = block.formatted(mode, state) + noncoding = tokens2linestrs(iter(block.colon_line.head_noncoding)) + directive = "" for line in noncoding: # here noncoding is already formated if line.strip(): # only non-empty lines are formattable - # possible update state? directive += line - directive += body - if state.sort_direcives: - sort_directives[block.keyword] = directive + state = state.update(line.lstrip()) + if state.not_format: + if directives: + formatted.extend(self.sort_directives(directives)) + if directive: + formatted.append(directive) + directive = "" + if not line.strip(): + formatted.append(line) + if state.not_format: + formatted.append("".join(block.colon_line.body[-1].line)) + for block_ in block.body_blocks: + formatted.append("".join(block_.full_linestrs)) else: - formatted.append(directive) + directive += block.formatted(mode, state) + if state.sort_direcives: + directives[block.keyword] = directive + else: + formatted.append(directive) if block.tail_noncoding: tail_noncoding = tokens2linestrs(iter(block.tail_noncoding)) # no `\n` between - if sort_directives: - for keyword in self.subautomata: - if keyword in sort_directives: - formatted.append(sort_directives[keyword]) + if directives: + formatted.extend(self.sort_directives(directives)) if tail_noncoding: tail_noncoding = [i.lstrip().rstrip("\n") for i in tail_noncoding] formatted.extend(f"{indent}{i}\n" for i in tail_noncoding if i) return "".join(formatted) + @classmethod + def sort_directives(cls, directives: dict[str, str]): + """Sort directives in the order of subautomata. Clear input""" + for keyword in cls.subautomata: + if keyword in directives: + yield directives.pop(keyword) + assert not directives, f"Unknown directives: {', '.join(directives)}" + @_register() class Module(NamedBlock, SnakemakeKeywordBlock): @@ -1537,18 +1656,17 @@ def formatted(self, mode, state): if ":" not in head_bulk_line: # return quickly (also no body block here) indent = TAB * self.deindent_level - noncoding_lines = tokens2linestrs(iter(self.head_lines[0].head_noncoding)) components = head_bulk_line.strip().split() formatted_head = indent + " ".join(components) if "#" in head_line[0]: formatted_head += " " + format_black( "#" + head_line[0].split("#", 1)[1], mode=mode ).rstrip("\n") - return noncoding_lines, formatted_head + "\n" - noncoding_lines, formatted_prior, post_colon = self.format_head(mode) + return formatted_head + "\n" + formatted_prior, post_colon = self.format_head(mode) formatted_body = self.format_body(mode, state, post_colon) formatted = [formatted_prior, formatted_body] - return noncoding_lines, "".join(formatted) + return "".join(formatted) @_register() @@ -1614,18 +1732,19 @@ def get_formatted( if mode is None: raise ValueError("Mode should be provided for formatting") if sort_directives is None: - sort_directives = bool(getattr(self, "sort_direcives", False)) - state = FormatState(sort_direcives=sort_directives) + sort_directives = getattr(self, "sort_direcives", None) + state = FormatState(sort_direcives=sort_directives or None) + # if set to None, it will not be enabled by `# fmt: on` python_codes: list[str] = [] snakemake_codes: list[str] = [] last_str = "" - for str, is_snake in self.segment2format(mode or self.mode, state): + for segment, is_snake in self.segment2format(mode or self.mode, state): if is_snake: python_codes.append(last_str) last_str = "" - snakemake_codes.append(str) + snakemake_codes.append(segment) else: - last_str += str + last_str += segment place_hode_str = "o" * 50 raw_str = "".join(python_codes) while place_hode_str in raw_str: @@ -1656,7 +1775,7 @@ def compilation(self): raise NotImplementedError -def parse(input: str | Callable[[], str], name: str = "") -> GlobalBlock: +def parse(input: str | Callable[[], str], name: str = ""): if isinstance(input, str): tokens = tokenize.generate_tokens( iter(input.splitlines(keepends=True)).__next__ diff --git a/tests/test_blocken.py b/tests/test_blocken.py index 6bb1c50..f75210a 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -432,7 +432,7 @@ def test_format_python_block(self): assert len(py2.head_lines) == 3 assert isinstance(py2, PythonBlock) assert ( - py2.formatted(mode, state)[0] + py2.formatted(mode) == 'b = f"""\n{b =} f"""\n# comment\nc = [i for j in k] if m else (lambda: None)\n' ) assert block.get_formatted(mode) == black.format_str(self.example1, mode=mode) From 75ac4c392204dd4859eb688e291e1791a79b8852 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 22:17:53 +0800 Subject: [PATCH 43/53] fix: off[sort] --- snakefmt/blocken.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index b1f83fa..f2d422d 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -361,7 +361,7 @@ def update(self, comment: str): if self.sort_direcives is False: # re-enable sorting if it was disabled by `# fmt: off[sort]` before, # but should effect if no `# fmt: off[sort]` in this indent before. - return self._replace(sort_direcives=None) + return self._replace(sort_direcives=True) elif directive == "off": if option == "sort": return self._replace(sort_direcives=False) @@ -380,6 +380,11 @@ def consume_skip_next(self) -> "FormatState": def found_skip(comment: str): return "# fmt: skip" in comment + def reset_sort(self): + if self.sort_direcives is False: + return self._replace(sort_direcives=None) + return self + def format_black( raw: str, @@ -694,7 +699,6 @@ def segment2format( merge the two lines into one block, with comments in between - (doesn't matter if this block is actually one-line or not) """ - # comment fmt directives in head_linestrs # will effect on post blocks of the same indent, # so should be updated during the parent body_blocks iteration. @@ -702,6 +706,7 @@ def segment2format( yield "".join(self.head_linestrs), False last_keyword = "" line = "" + state = state.reset_sort() for block in self.body_blocks: restart_state = state = state.consume_skip_next() # update state from head_noncoding @@ -1485,18 +1490,34 @@ def format_body(self, mode, state, post_colon): noncoding = tokens2linestrs(iter(block.colon_line.head_noncoding)) directive = "" for line in noncoding: # here noncoding is already formated - if line.strip(): + linelstrip = line.lstrip() + last_sort_off = state.sort_direcives + if linelstrip: # only non-empty lines are formattable - directive += line - state = state.update(line.lstrip()) + if state.found_skip(linelstrip): + directive += line + else: + directive += indent + format_black(linelstrip, mode, 0) + state = state.update(linelstrip) if state.not_format: if directives: formatted.extend(self.sort_directives(directives)) if directive: formatted.append(directive) directive = "" - if not line.strip(): + if not linelstrip: formatted.append(line) + elif not state.sort_direcives: + if directives: + formatted.extend(self.sort_directives(directives)) + if directive: + formatted.append(directive) + directive = "" + elif not last_sort_off: + # state.sort_direcives switched on, this comment is + # actually `# fmt: on[sort]` directive, so split from next directive + formatted.append(directive) + directive = "" if state.not_format: formatted.append("".join(block.colon_line.body[-1].line)) for block_ in block.body_blocks: @@ -1506,6 +1527,7 @@ def format_body(self, mode, state, post_colon): if state.sort_direcives: directives[block.keyword] = directive else: + assert not directives, "Already flushed once fmt: off[sort]" formatted.append(directive) if block.tail_noncoding: tail_noncoding = tokens2linestrs(iter(block.tail_noncoding)) From 692ec6c6b4fb45321b450527a2ed1b861ab12216 Mon Sep 17 00:00:00 2001 From: hwrn Date: Thu, 9 Apr 2026 23:10:31 +0800 Subject: [PATCH 44/53] fix: fmt: off[next] --- snakefmt/blocken.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index f2d422d..70e436f 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -939,7 +939,7 @@ class SnakemakeBlock(ColonBlock): def components(self) -> Iterator[DocumentSymbol]: yield from [] - def segment2format(self, mode, state, last_keyword=""): + def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): """yield: - [unformated_python_code, Literal[False]] - [formated_snakemake_code, Literal[True]] @@ -989,7 +989,18 @@ def segment2format(self, mode, state, last_keyword=""): [self.colon_line.body[-1].line] + [line for block in self.body_blocks for line in block.full_linestrs] ) + # Trailing blank lines from body_blocks belong to the next block's + # separator, not this block's content. Strip extra trailing blank + # lines so the compilation loop doesn't double-count them with + # black's blank-line insertion. + if raw.endswith("\n\n") and state.skip_next: + n_trailing_space = len(raw) - len(raw.rstrip("\n")) - 1 + raw = raw.rstrip("\n") + "\n" + else: + n_trailing_space = 0 yield raw, True + if n_trailing_space > 0: + yield "\n" * n_trailing_space, False else: yield self.formatted(mode, state), True if self.tail_noncoding: From ccfab56350b1edbe1abb0bdaaac7a36186db140c Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 00:52:07 +0800 Subject: [PATCH 45/53] feat: swith to differnt logic --- snakefmt/blocken.py | 109 ++++++++++++++---------- tests/test_formatter.py | 178 ++++++++++++++++++++++++---------------- 2 files changed, 173 insertions(+), 114 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 70e436f..7935fd9 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -426,8 +426,6 @@ def format_black( fmted = black.format_str(prefix + string, mode=mode) except black.parsing.InvalidInput as e: if start_token is not None: - import re - match = re.search(r"(Cannot parse.*?:\s*)(?P\d+)(.*)", str(e)) if match: err_msg = match.group(1) + str(start_token.start[0]) + match.group(3) @@ -679,10 +677,10 @@ def components(self) -> "Iterator[DocumentSymbol]": def segment2format( self, mode: Mode, state: FormatState - ) -> Generator[tuple[str, bool], None, None]: + ) -> Generator[tuple[str, str | None], None, None]: """yield: - - [unformated_python_code, Literal[False]] - - [formated_snakemake_code, Literal[True]] + - [unformated_python_code, None] + - [formated_snakemake_code, indent_str] `SnakemakeInlineArgumentBlock` should be taken very careful of, since they are formatedd as `def` blocks, and may not sperate from @@ -703,7 +701,7 @@ def segment2format( # will effect on post blocks of the same indent, # so should be updated during the parent body_blocks iteration. if self.head_linestrs: - yield "".join(self.head_linestrs), False + yield "".join(self.head_linestrs), None last_keyword = "" line = "" state = state.reset_sort() @@ -729,16 +727,16 @@ def segment2format( # If NO any line before the first line of this block, # black cannot split them: Add one to force splitting if not block.head_lines[0].head_noncoding: - yield "\n", False + yield "\n", None last_keyword = "def" - for line, is_snake in block.segment2format(mode, state): + for line, indent in block.segment2format(mode, state): # record `line` for next useage - yield line, is_snake + yield line, indent elif isinstance(block, SnakemakeBlock): - for line, is_snake in block.segment2format( + for line, indent in block.segment2format( mode, restart_state, last_keyword ): - yield line, is_snake + yield line, indent last_keyword = block.keyword else: last_keyword = "" @@ -747,7 +745,7 @@ def segment2format( last_keyword = "" yield from block.segment2format(mode, state) if self.tail_noncoding: - yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), None @abstractmethod def compilation(self): @@ -932,6 +930,14 @@ def components(self): yield this_symbol +def deindent_lines(old_indent: str, target_indent_level: int, lines: list[str]): + target_indent = TAB * target_indent_level + return [ + target_indent + line[len(old_indent) :] if line.startswith(old_indent) else line + for line in lines + ] + + class SnakemakeBlock(ColonBlock): subautomata = {} deprecated = {} @@ -941,15 +947,15 @@ def components(self) -> Iterator[DocumentSymbol]: def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): """yield: - - [unformated_python_code, Literal[False]] - - [formated_snakemake_code, Literal[True]] + - [unformated_python_code, None] + - [formated_snakemake_code, indent] If state.skip_next is True, or state.fmt_on is False, return unformatted content with proper True/False markers. """ # Get noncoding_lines early to check fmt directives - indent_str = self.indent_str + indent_str = TAB * self.deindent_level assert len(self.head_lines) == 1, "Snakemake keywords should only in one line" noncoding_lines: list[str] = [] last_fmt_on = state.fmt_on @@ -959,7 +965,7 @@ def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): last_keyword = "" else: state = state.update(noncoding_line.lstrip()) - if not state.fmt_on: + if state.not_format: noncoding_lines.append(noncoding_line) else: noncoding_lines.append( @@ -973,21 +979,29 @@ def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): ).splitlines(keepends=True) for line in pre_formatted: if state.found_skip(line): - yield line, False + yield line, None else: - yield indent_str + line.lstrip(), True + yield indent_str + line.lstrip(), self.indent_str else: if not noncoding_lines: - yield "\n", False - yield "".join(noncoding_lines), False + yield "\n", None + yield "".join(noncoding_lines), None else: - yield "".join(noncoding_lines), False + yield "".join(noncoding_lines), None # Check if this block should be skipped from formatting if state.not_format: raw = "".join( - [self.colon_line.body[-1].line] - + [line for block in self.body_blocks for line in block.full_linestrs] + deindent_lines( + self.indent_str, + self.deindent_level, + [self.colon_line.body[-1].line] + + [ + line + for block in self.body_blocks + for line in block.full_linestrs + ], + ) ) # Trailing blank lines from body_blocks belong to the next block's # separator, not this block's content. Strip extra trailing blank @@ -998,13 +1012,13 @@ def segment2format(self, mode: Mode, state: FormatState, last_keyword=""): raw = raw.rstrip("\n") + "\n" else: n_trailing_space = 0 - yield raw, True + yield raw, self.indent_str if n_trailing_space > 0: - yield "\n" * n_trailing_space, False + yield "\n" * n_trailing_space, None else: - yield self.formatted(mode, state), True + yield self.formatted(mode, state), self.indent_str if self.tail_noncoding: - yield "".join(tokens2linestrs(iter(self.tail_noncoding))), False + yield "".join(tokens2linestrs(iter(self.tail_noncoding))), None def formatted(self, mode, state): formatted_prior, post_colon = self.format_head(mode) @@ -1517,7 +1531,11 @@ def format_body(self, mode, state, post_colon): formatted.append(directive) directive = "" if not linelstrip: - formatted.append(line) + formatted.extend( + deindent_lines( + block.indent_str, self.deindent_level + 1, [line] + ) + ) elif not state.sort_direcives: if directives: formatted.extend(self.sort_directives(directives)) @@ -1530,9 +1548,18 @@ def format_body(self, mode, state, post_colon): formatted.append(directive) directive = "" if state.not_format: - formatted.append("".join(block.colon_line.body[-1].line)) - for block_ in block.body_blocks: - formatted.append("".join(block_.full_linestrs)) + formatted.extend( + deindent_lines( + block.indent_str, + self.deindent_level + 1, + [block.colon_line.body[-1].line] + + [ + line + for block in block.body_blocks + for line in block.full_linestrs + ], + ) + ) else: directive += block.formatted(mode, state) if state.sort_direcives: @@ -1769,13 +1796,13 @@ def get_formatted( state = FormatState(sort_direcives=sort_directives or None) # if set to None, it will not be enabled by `# fmt: on` python_codes: list[str] = [] - snakemake_codes: list[str] = [] + snakemake_codes: list[tuple[str, str]] = [] last_str = "" - for segment, is_snake in self.segment2format(mode or self.mode, state): - if is_snake: + for segment, indent_proxy in self.segment2format(mode or self.mode, state): + if indent_proxy is not None: python_codes.append(last_str) last_str = "" - snakemake_codes.append(segment) + snakemake_codes.append((segment, indent_proxy)) else: last_str += segment place_hode_str = "o" * 50 @@ -1783,20 +1810,16 @@ def get_formatted( while place_hode_str in raw_str: place_hode_str *= 2 raw_str = "#\n" - for python_code, snakemake_code in zip(python_codes, snakemake_codes): + for python_code, (snakemake_code, indent) in zip(python_codes, snakemake_codes): if snakemake_code.count("\n") == 1: # must at the end of line - indent_str = extract_line_indent(snakemake_code) - place_hode = f"{indent_str}def l{place_hode_str}1ng(): ...\n" + place_hode = f"{indent}def l{place_hode_str}1ng(): ...\n" else: - indent_str = extract_line_indent(snakemake_code) - place_hode = ( - f"{indent_str}def l{place_hode_str}ng():\n{indent_str} return\n" - ) + place_hode = f"{indent}def l{place_hode_str}ng():\n{indent} return\n" raw_str += python_code + place_hode raw_str += last_str formatted, *formatted_split = format_black(raw_str, mode).split(place_hode_str) final_str = formatted - for formatted, snakemake_code in zip(formatted_split, snakemake_codes): + for formatted, (snakemake_code, _) in zip(formatted_split, snakemake_codes): final_str = final_str.rsplit("\n", 1)[0] + "\n" + snakemake_code if formatted.startswith("1"): final_str += formatted.split("\n", 1)[-1] diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 99f9685..f5b680d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -15,7 +15,7 @@ from snakefmt.parser.grammar import SingleParam, SnakeGlobal from snakefmt.parser.syntax import COMMENT_SPACING from snakefmt.types import TAB -from tests import setup_formatter +from snakefmt.blocken import setup_formatter def test_emptyInput_emptyOutput(): @@ -388,28 +388,8 @@ def test_param_comment_multiline(self): class TestSimplePythonFormatting: - @mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True, return_value="" - ) - def test_commented_snakemake_syntax_formatted_as_python_code(self, mock_method): - """ - Tests this line triggers call to black formatting - """ - formatter = setup_formatter("#configfile: 'foo.yaml'") - - formatter.get_formatted() - mock_method.assert_called_once() - def test_python_code_with_multi_indent_passes(self): python_code = "if p:\n" f"{TAB * 1}for elem in p:\n" f"{TAB * 2}dothing(elem)\n" - # test black gets called - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", - spec=True, - return_value="", - ) as mock_m: - setup_formatter(python_code) - mock_m.assert_called_once() # test black formatting output (here, is identical) formatter = setup_formatter(python_code) @@ -555,17 +535,11 @@ def test_snakemake_code_inside_python_code(self): def test_python_code_after_nested_snakecode_gets_formatted(self): snakecode = "if condition:\n" f'{TAB * 1}include: "a"\n' "b=2\n" - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True - ) as mock_m: + with mock.patch("snakefmt.blocken.format_black", spec=True) as mock_m: mock_m.return_value = "if condition:\n" - setup_formatter(snakecode) - assert mock_m.call_count == 3 - assert mock_m.call_args_list[1] == mock.call( - 'f("a")', 0, 3, no_nesting=True - ) - - assert mock_m.call_args_list[2] == mock.call("b=2\n", 0) + formatter = setup_formatter(snakecode) + formatter.get_formatted() + assert mock_m.call_count == 2 formatter = setup_formatter(snakecode) expected = ( @@ -577,12 +551,10 @@ def test_python_code_after_nested_snakecode_gets_formatted(self): def test_python_code_before_nested_snakecode_gets_formatted(self): snakecode = "b=2\n" "if condition:\n" f'{TAB * 1}include: "a"\n' - with mock.patch( - "snakefmt.formatter.Formatter.run_black_format_str", spec=True - ) as mock_m: + with mock.patch("snakefmt.blocken.format_black", spec=True) as mock_m: mock_m.return_value = "b=2\nif condition:\n" - setup_formatter(snakecode) - assert mock_m.call_count == 3 + setup_formatter(snakecode).get_formatted() + assert mock_m.call_count == 2 formatter = setup_formatter(snakecode) expected = "b = 2\n" "if condition:\n\n" f'{TAB * 1}include: "a"\n' @@ -863,13 +835,13 @@ def test_tpq_alignment_and_keep_relative_indenting(self): ''' formatter = setup_formatter(snakecode) - expected = f''' -rule a: + # Now the activity is corrected. + expected = f'''rule a: {TAB * 1}shell: {TAB * 2}"""Starts here {TAB * 0} Hello {TAB * 1}World -{TAB * 2} Tabbed + \t\tTabbed {TAB * 1}""" ''' assert formatter.get_formatted() == expected @@ -924,8 +896,7 @@ def test_single_quoted_multiline_string_proper_tabbing(self): 2> log.stderr" """ formatter = setup_formatter(snakecode) - expected = f""" -rule a: + expected = f"""rule a: {TAB * 1}shell: {TAB * 2}"(kallisto quant \\ {TAB * 2}--pseudobam \\ @@ -1062,7 +1033,7 @@ def test_fstring_spacing_of_consecutive_braces(self): formatter = setup_formatter(snakecode) assert formatter.get_formatted() == snakecode - @mock.patch("snakefmt.formatter.Formatter.run_black_format_str", spec=True) + @mock.patch("snakefmt.blocken.format_black", spec=True) def test_invalid_python_recovery(self, mock_format): from snakefmt.exceptions import InvalidPython @@ -1082,7 +1053,7 @@ def side_effect(val, *args, **kwargs): ) formatter = setup_formatter(snakecode) assert formatter.get_formatted() == snakecode - assert mock_format.call_count == 2 + assert mock_format.call_count == 4 def test_fstring_with_equal_sign_inside_function_call(self): """https://github.com/snakemake/snakefmt/issues/220""" @@ -1127,8 +1098,9 @@ def test_comment_after_parameter_keyword_twonewlines(self): def test_comment_after_keyword_kept(self): snakecode = "rule a: # A comment \n" f"{TAB * 1}threads: 4\n" + formatted = "rule a: # A comment\n" f"{TAB * 1}threads: 4\n" formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted def test_comments_after_parameters_kept(self): snakecode = ( @@ -1172,8 +1144,15 @@ def test_comment_below_paramkeyword_stays_untouched(self): f"{TAB * 2}elem1, #The first elem\n" f"{TAB * 2}elem1, #The second elem\n" ) + formatted = ( + "rule all:\n" + f"{TAB * 1}input:\n" + f"{TAB * 2}# A list of inputs\n" + f"{TAB * 2}elem1, # The first elem\n" + f"{TAB * 2}elem1, # The second elem\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted @pytest.mark.xfail( reason="""This is non-trivial to implement, and black does no align the comments @@ -1218,8 +1197,16 @@ def test_inline_formatted_params_relocate_inline_comments(self): f"{TAB * 1}# Threads 1\n" f"{TAB * 1}threads: 8 # Threads 2\n" ) + new_expected = ( + "include: # Include\n" + f"{TAB * 1}file.txt\n\n\n" + "rule all:\n" + f"{TAB * 1}threads: # Threads 1\n" + f"{TAB * 2}8 # Threads 2\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != expected + assert formatter.get_formatted() == new_expected def test_preceding_comments_in_inline_formatted_params_get_relocated(self): snakecode = ( @@ -1236,8 +1223,16 @@ def test_preceding_comments_in_inline_formatted_params_get_relocated(self): f"{TAB * 1}# Threads3\n" f"{TAB * 1}threads: 8 # Threads 4\n" ) + new_expected = ( + "rule all:\n" + f"{TAB * 1}# Threads1\n" + f"{TAB * 1}threads: # Threads2\n" + f"{TAB * 2}# Threads3\n" + f"{TAB * 2}8 # Threads 4\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != expected + assert formatter.get_formatted() == new_expected def test_no_inline_comments_stay_untouched(self): snakecode = ( @@ -1247,8 +1242,15 @@ def test_no_inline_comments_stay_untouched(self): f"{TAB * 2}#comment1\n" f"{TAB * 2}#comment2\n" ) + formatted = ( + "rule all:\n" + f"{TAB * 1}input:\n" + f"{TAB * 2}p=2,\n" + f"{TAB * 2}# comment1\n" + f"{TAB * 2}# comment2\n" + ) formatter = setup_formatter(snakecode) - assert formatter.get_formatted() == snakecode + assert formatter.get_formatted() == formatted def test_snakecode_after_indented_comment_does_not_get_unindented(self): """https://github.com/snakemake/snakefmt/issues/159#issue-1441174995""" @@ -1478,7 +1480,7 @@ def test_buffer_with_lone_comment(self): def test_comment_inside_python_code_sticks_to_rule(self): snakecode = f"if p:\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' - expected = f"if p:\n\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' + expected = f"if p:\n" f"{TAB * 1}# A comment\n" f'{TAB * 1}include: "a"\n' assert setup_formatter(snakecode).get_formatted() == expected def test_comment_below_keyword_gets_spaced(self): @@ -1666,8 +1668,7 @@ def test_shell_indention_long_line(self): class TestStorage: def test_storage(self): - code = textwrap.dedent(""" - storage http_local: + code = textwrap.dedent(""" storage http_local: provider="http", keep_local=True, """) @@ -1874,20 +1875,20 @@ def test_sorting_with_inline_parameter_comments(self): f"{TAB}name: 'n'\n", "module other:\n" f'{TAB}name: "n"\n' - f"{TAB}pathvars:\n" - f'{TAB * 2}["pv"],\n' f"{TAB}snakefile:\n" f'{TAB * 2}"s"\n' - f"{TAB}config:\n" - f'{TAB * 2}"c"\n' + f"{TAB}meta_wrapper:\n" + f'{TAB * 2}"wrapper"\n' f"{TAB}skip_validation:\n" f"{TAB * 2}True\n" + f"{TAB}config:\n" + f'{TAB * 2}"c"\n' + f"{TAB}pathvars:\n" + f'{TAB * 2}["pv"],\n' f"{TAB}prefix:\n" f'{TAB * 2}"p"\n' f"{TAB}replace_prefix:\n" - f'{TAB * 2}"rp"\n' - f"{TAB}meta_wrapper:\n" - f'{TAB * 2}"wrapper"\n', + f'{TAB * 2}"rp"\n', ) def test_sorting_module(self): @@ -1975,7 +1976,7 @@ def test_invalid_python_error_eof(): msg = str(excinfo.value) assert "Black error:" in msg assert ": 3:" in msg - assert "Note reported line number may be an approximation" in msg + assert "Note reported line number may be incorrect" in msg @mock.patch("black.format_str", spec=True) @@ -2032,7 +2033,7 @@ def side_effect(*args, **kwargs): assert "Custom black error without line number" in msg -@mock.patch("snakefmt.formatter.Formatter.run_black_format_str", spec=True) +@mock.patch("snakefmt.blocken.format_black", spec=True) def test_multiline_fallback(mock_format): from snakefmt.exceptions import InvalidPython @@ -2188,6 +2189,22 @@ def test_fmt_off_on_in_run(self): "z = [4, 5, 6]\n" ) assert setup_formatter(code).get_formatted() == expected + + @pytest.mark.xfail( + reason="Current black version doesn't handle this case correctly" + ) + def test_fmt_off_on_in_run_fail(self): + code = ( + "# ?\n" + "x = [1,2,3]\n" + "# fmt: off\n" + "y = [ 1, 2]\n" + "s = f'''\n" + " {y} \n" + " '''\n" + "# fmt: on\n" + "z = [4,5,6]\n" + ) bad_indent = " " snakecode = "rule:\n" " run:\n" + ( "".join(f"{bad_indent}{i}\n" for i in code.splitlines()) @@ -2376,13 +2393,12 @@ def test_fmt_skip_in_directive(self): expected = ( "rule a:\n" f"{TAB}params:\n" - f"{TAB * 2}x=[1, 2, 3], # fmt: skip\n" - f"{TAB}input:\n" - f'{TAB * 2}a="sth", # fmt: skip\n' + f"{TAB * 2}x = [ 1,2,3] # fmt: skip\n" + f"{TAB * 2},\n" + f"{TAB}input: a= 'sth' # fmt: skip\n" ) # TODO: currently `# fmt: skip` in directives is not supported - assert formatter.get_formatted() # == expected - assert expected + assert formatter.get_formatted() == expected class TestFmtOffSort: @@ -2404,6 +2420,13 @@ def test_fmt_off_sort(self): expected = "# fmt: off[sort]\n" + setup_formatter(code).get_formatted() assert setup_formatter(code1, sort_params=True).get_formatted() == expected + # `# fmt: off[sort]` disables sorting for the second rule + code2 = code1 + "\n\n# nothing\n" + code + expected2 = ( + expected + "\n\n# nothing\n" + setup_formatter(code).get_formatted() + ) + assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 + # `# fmt: on[sort]` re-enables sorting after `# fmt: off[sort]` code2 = code1 + "\n\n# fmt: on[sort]\n" + code expected2 = expected + "\n\n# fmt: on[sort]\n" + formatted @@ -2415,9 +2438,10 @@ def test_fmt_off_sort(self): assert setup_formatter(code2, sort_params=True).get_formatted() == expected2 def test_fmt_off_sort_dedent(self): - """`# fmt: on` or `on[sort]` at a deeper indentation level than `off[sort]` - has no effect""" - code1, formatted1 = TestSortFormatting.sorting_comprehensive + """`# fmt: on` at a deeper indentation level than `off[sort]` has no effect + but `# fmt: on[sort]` does + """ + code1, formatted0 = TestSortFormatting.sorting_comprehensive formatted1 = setup_formatter(code1).get_formatted() code2, formatted2 = TestSortFormatting.sort_with_comments formatted2 = setup_formatter(code2).get_formatted() @@ -2432,7 +2456,6 @@ def test_fmt_off_sort_dedent(self): expected = ( "# fmt: off[sort]\n" "if 1:\n" - "\n" f"{TAB}# fmt: on\n" + "".join(TAB + i for i in formatted1.splitlines(keepends=True)).rstrip() + "\n" @@ -2457,7 +2480,6 @@ def test_fmt_off_sort_on_noeffect(self): expected = ( formatted1 + "\n\n" "if 1:\n" - "\n" f"{TAB}# fmt: off[sort]\n" + "".join(TAB + i for i in formatted2.splitlines(keepends=True)) + "\n\n" @@ -2707,6 +2729,7 @@ def test_rule_if2_rule(self): f"{TAB * 2}" + i for i in format2.splitlines(keepends=True) ).rstrip("\n") + "\n" + + "\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)) + "\n" @@ -2782,6 +2805,7 @@ def test_fmt_off_next_in_if(self): + format3 ) assert formatter.get_formatted() == expected + # will no longer skip formatting the entire block formatter = setup_formatter( code1.rstrip("\n") + "\n# fmt: off[next]\n" "if 1:\n" @@ -2796,7 +2820,16 @@ def test_fmt_off_next_in_if(self): + "\n\n\n" + format3 ) - assert formatter.get_formatted() == expected + assert formatter.get_formatted() != expected + # instead, only effect if right before the snakemake keyword. + expected = ( + format1 + "\n\n# fmt: off[next]\n" + "if 1:\n" + + "".join(f"{TAB * 1}" + i for i in format2.splitlines(keepends=True)) + + "\n\n\n" + + format3 + ) + assert formatter.get_formatted() != expected def test_fmt_off_next_in_2if(self): code1, format1 = TestSimpleParamFormatting.example_shell_newline @@ -2832,6 +2865,7 @@ def test_fmt_off_next_in_2if(self): format1.rstrip("\n") + "\n" "\n\n" "if 1:\n" + "\n" f"{TAB * 1}# fmt: off[next]\n" + "".join(f"{TAB * 1}" + i for i in code2.splitlines(keepends=True)).strip( "\n" @@ -2862,9 +2896,11 @@ def test_fmt_off_2(self): f"{TAB}rule a:\n" f"{TAB * 2}input:\n" f'{TAB * 3}"foo",\n' + "\n" f"{TAB}# fmt: off[next]\n" f"{TAB}rule b:\n" f'{TAB} input: "bar"\n' + "\n" f"{TAB}# fmt: off[next]\n" f"{TAB}rule c:\n" f'{TAB} input: "baz"\n' From 6f1050dd807126210fa1b9ae136d695bcf2ed2f0 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 01:00:40 +0800 Subject: [PATCH 46/53] fix: resolve conflict --- snakefmt/parser/parser.py | 148 +++++------------------------------ snakefmt/parser/syntax.py | 13 ++-- tests/test_formatter.py | 160 -------------------------------------- 3 files changed, 26 insertions(+), 295 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index 95c88d9..e43f7eb 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -1,8 +1,6 @@ import re -import re import tokenize from abc import ABC, abstractmethod -from tokenize import TokenInfo from typing import Literal, NamedTuple, Optional from snakefmt.exceptions import UnsupportedSyntax @@ -28,7 +26,7 @@ class FMT_DIRECTIVE(NamedTuple): modifiers: list[str] @classmethod - def from_token(cls, token: TokenInfo): + def from_token(cls, token: Token): if token.type != tokenize.COMMENT: return None return cls.from_str(token.string) @@ -48,7 +46,7 @@ def from_str(cls, token_string: str): return cls(disable, mods) # type: ignore[arg-type] -def split_token_lines(token: TokenInfo): +def split_token_lines(token: tokenize.TokenInfo): """Token can be multiline. e.g., `f'''\\nplaintext\\n'''` has these tokens: @@ -66,7 +64,7 @@ def split_token_lines(token: TokenInfo): ) -def not_a_comment_related_token(token: TokenInfo): +def not_a_comment_related_token(token: Token): return token.type not in { tokenize.COMMENT, tokenize.NEWLINE, @@ -84,7 +82,7 @@ def check_indent(line: str, indents: list[str]) -> int: raise SyntaxError("Unexpected indent") -def token_indents_updated(token: TokenInfo, indents: list[str]) -> bool: +def token_indents_updated(token: Token, indents: list[str]) -> bool: if token.type == tokenize.INDENT: line = token.line indent = line[: len(line) - len(line.lstrip())] @@ -118,12 +116,12 @@ def __init__(self, fpath_or_stream, rulecount=0): self.rulecount = rulecount self.lines = 0 - def __next__(self) -> TokenInfo: + def __next__(self) -> Token: if self._buffered_tokens: return self._buffered_tokens.pop() return next(self._live_tokens) - def denext(self, token: TokenInfo) -> None: + def denext(self, token: Token) -> None: self._buffered_tokens.append(token) @@ -134,7 +132,7 @@ def comment_start(string: str) -> bool: class Status(NamedTuple): """Communicates the result of parsing a chunk of code""" - token: TokenInfo + token: Token block_indent: int # indent of the start of the parsed block cur_indent: int # indent of the end of the parsed block buffer: str @@ -160,7 +158,6 @@ class Parser(ABC): and the alternation in `:self.last_block_was_snakecode`. """ - def __init__(self, snakefile: Snakefile, sort_directives=False): def __init__(self, snakefile: Snakefile, sort_directives=False): self.context = Context( SnakeGlobal(), KeywordSyntax("Global", keyword_indent=0, accepts_py=True) @@ -173,7 +170,7 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): self.block_indent = 0 self.queriable = True self.in_fstring = False - self.last_token: Optional[TokenInfo] = None + self.last_token: Optional[Token] = None # for `# fmt: off`, (indent, kind) # kind: "region" = off/on, "sort" = off[sort]/on[sort], "next" self.fmt_off: Optional[tuple[int, Literal["next", "region"]]] = None @@ -233,53 +230,15 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): and status.cur_indent < self.sort_off_indent ): self.sort_off_indent = None - if fmt_label := FMT_DIRECTIVE.from_token(status.token): - if fmt_label.disable: - if not fmt_label.modifiers: - self.fmt_off = (status.cur_indent, "region") - self.fmt_off_expected_indent = status.token.line[ - : col_nb(status.token) - ] - elif "next" in fmt_label.modifiers: - self.fmt_off = (status.cur_indent, "next") - self.fmt_off_expected_indent = status.token.line[ - : col_nb(status.token) - ] - elif "sort" in fmt_label.modifiers: - self.sort_off_indent = status.cur_indent - elif self._check_fmt_on(fmt_label, status.token) == "sort": - if not self.from_python and self.keyword_indent: - # multiline string is impossible here - # and we assume that origin_indent is the same indent - # of this comment - token_indent = status.cur_indent - sort_on = token_indent * TAB + status.token.line.strip() + "\n" - self.flush_sort_signal(sort_on) - status = self.get_next_queriable() - self.buffer = status.buffer - continue - elif self.fmt_off and status.cur_indent <= self.fmt_off[0]: - self.fmt_off = None - elif ( - self.sort_off_indent is not None - and status.cur_indent < self.sort_off_indent - ): - self.sort_off_indent = None if self.vocab.recognises(keyword): - new_vocab, new_syntax_cls = self.vocab.get(keyword) - is_context_kw = new_vocab is not None and issubclass( - new_syntax_cls, KeywordSyntax - ) new_vocab, new_syntax_cls = self.vocab.get(keyword) is_context_kw = new_vocab is not None and issubclass( new_syntax_cls, KeywordSyntax ) if status.cur_indent > self.keyword_indent: - if self.syntax.from_python or status.pythonable: if self.syntax.from_python or status.pythonable: self.from_python = True - elif self.from_python and not is_context_kw: elif self.from_python and not is_context_kw: # We are exiting python context, so force spacing out keywords self.last_recognised_keyword = "" @@ -305,20 +264,6 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): self.block_indent = status.block_indent self.last_block_was_snakecode = self.keyword_indent > 0 self.buffer = status.buffer.lstrip() - elif self.fmt_off: - self.flush_buffer( - from_python=True, - in_global_context=self.in_global_context, - ) - if self.keyword_indent > 0: - self.syntax.add_processed_keyword(status.token, keyword) - status = self._consume_fmt_off( - status.token, min_indent=status.cur_indent - ) - if self.last_block_was_snakecode and not status.eof: - self.block_indent = status.block_indent - self.last_block_was_snakecode = self.keyword_indent > 0 - self.buffer = status.buffer.lstrip() else: if not self.syntax.accepts_python_code and not comment_start(keyword): raise SyntaxError( @@ -326,8 +271,6 @@ def __init__(self, snakefile: Snakefile, sort_directives=False): f"in {self.syntax.keyword_name} definition" ) else: - source, status = self._consume_python(status.token) - self.buffer += source source, status = self._consume_python(status.token) self.buffer += source if self.last_block_was_snakecode and not status.eof: @@ -385,7 +328,6 @@ def flush_buffer( final_flush: bool = False, in_global_context: bool = False, exiting_keywords: bool = False, - exiting_keywords: bool = False, ) -> None: """Processes the text in :self.buffer:""" @@ -405,7 +347,7 @@ def post_process_keyword(self) -> None: eg after finishing parsing a 'rule:'""" def _consume_python( - self, start_token: TokenInfo, vocab_recognises=True, added_indent: str = "" + self, start_token: Token, vocab_recognises=True, added_indent: str = "" ) -> tuple[str, Status]: """Collect Python source lines until a snakemake keyword at correct indent, or dedent below min_indent, or EOF. @@ -427,7 +369,7 @@ def _consume_python( consuming_next = False # used with stop_at_min seen_next_block_keyword = False - def _init_min_indent(token: TokenInfo): + def _init_min_indent(token: Token): nonlocal min_indent if not comment_start(token.string): while not token.line.startswith(self.indents[-1]): @@ -439,7 +381,7 @@ def _init_min_indent(token: TokenInfo): try: token = next(self.snakefile) except StopIteration: - eof_token = TokenInfo(tokenize.ENDMARKER, "", (0, 0), (0, 0), "") + eof_token = Token(tokenize.ENDMARKER, "", (0, 0), (0, 0), "") self.snakefile.denext(eof_token) break if min_indent == -1: @@ -521,9 +463,7 @@ def _init_min_indent(token: TokenInfo): pythonable=next_status.pythonable or bool(verbatim.strip()) ) - def _detent_last_indent( - self, token: TokenInfo, last_indent_token: Optional[TokenInfo] - ): + def _detent_last_indent(self, token: Token, last_indent_token: Optional[Token]): """ A whole keyword block consumed, hand the next same-level block back to main loop. @@ -534,7 +474,7 @@ def _detent_last_indent( self.indents.pop() self.syntax.cur_indent = len(self.indents) - 1 - def _consume_fmt_off_in_python(self, token: TokenInfo, lines: dict[int, str]): + def _consume_fmt_off_in_python(self, token: Token, lines: dict[int, str]): """ Consume `# fmt: off/on` directives within Python code. lines is needed to: @@ -595,7 +535,7 @@ def flush_fmt_off_region(self, verbatim: str) -> None: def flush_sort_signal(self, verbatim: str) -> None: """Commit fmt:on sort signal directly.""" - def _consume_fmt_off(self, start_token: TokenInfo, min_indent: int): + def _consume_fmt_off(self, start_token: Token, min_indent: int): verbatim, next_status = self._consume_python( start_token, vocab_recognises=False, added_indent=TAB * min_indent ) @@ -641,7 +581,6 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: if new_vocab is not None and issubclass(new_syntax, KeywordSyntax): in_global_context = self.in_global_context saved_context: Context = self.context - saved_context: Context = self.context # 'use' keyword can not enter a new context self.context = Context( new_vocab(), @@ -663,7 +602,6 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: self.queriable = True self.block_indent = self.syntax.keyword_indent + 1 status = self.get_next_queriable() - if self.context.syntax.accepts_python_code: if self.context.syntax.accepts_python_code: self.buffer += status.buffer.lstrip("\n\r") else: @@ -681,17 +619,12 @@ def process_keyword(self, status: Status, from_python: bool = False) -> Status: self.process_keyword_param(param_context, self.in_global_context) self.syntax.add_processed_keyword(status.token, status.token.string) cur_indent = param_context.cur_indent - if param_context.token.type == tokenize.COMMENT and not param_context.eof: - cur_indent = self._determine_comment_indent(param_context.token) - cur_indent = param_context.cur_indent if param_context.token.type == tokenize.COMMENT and not param_context.eof: cur_indent = self._determine_comment_indent(param_context.token) return Status( param_context.token, cur_indent, cur_indent, - cur_indent, - cur_indent, status.buffer, param_context.eof, self.from_python, @@ -707,8 +640,6 @@ def context_exit(self, status: Status) -> None: if callback_context.syntax.accepts_python_code: # Flushes any code inside 'run' directive self.flush_buffer(exiting_keywords=True) - # Flushes any code inside 'run' directive - self.flush_buffer(exiting_keywords=True) else: callback_context.syntax.check_empty() self.context = self.context_stack[-1] @@ -724,7 +655,7 @@ def context_exit(self, status: Status) -> None: while len(self.indents) - 1 > status.cur_indent: self.indents.pop() - def _determine_comment_indent(self, comment_token: TokenInfo) -> int: + def _determine_comment_indent(self, comment_token: Token) -> int: """ This function returns the real indent level of a comment token and update self.indents if needed, @@ -748,7 +679,7 @@ def _determine_comment_indent(self, comment_token: TokenInfo) -> int: then put all peeked tokens back. """ # ── Step 1: peek ahead to find follow_indent ──────────────────────── - peeked: list[TokenInfo] = [] + peeked: list[Token] = [] saved_indents = list(self.indents) follow_indent = len(self.indents) - 1 try: @@ -777,7 +708,7 @@ def _determine_comment_indent(self, comment_token: TokenInfo) -> int: # highest indent level fitting within the comment's column. return max(check_indent(comment_token.line, self.indents), follow_indent) - def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: TokenInfo): + def _check_fmt_on(self, fmt_label: FMT_DIRECTIVE, token: Token): """Determine which fmt: on can turn on formatting""" if self.fmt_off: # `# fmt: on[sort]` no effect @@ -808,21 +739,17 @@ def get_next_queriable(self) -> Status: newline = False pythonable = False block_indent = -1 - prev_token: Optional[TokenInfo] = TokenInfo( - tokenize.NAME, "", (-1, -1), (-1, -1), "" - ) + prev_token: Optional[Token] = Token(tokenize.NAME, "", (-1, -1), (-1, -1), "") while True: token = next(self.snakefile) self.last_token = token self.in_fstring = fstring_processing(token, prev_token, self.in_fstring) if block_indent == -1 and not_a_comment_related_token(token): block_indent = self.cur_indent - if token_indents_updated(token, self.indents): if token_indents_updated(token, self.indents): prev_token = None newline = True self.syntax.cur_indent = len(self.indents) - 1 - self.syntax.cur_indent = len(self.indents) - 1 continue elif token.type == tokenize.ENDMARKER: return Status( @@ -830,14 +757,6 @@ def get_next_queriable(self) -> Status: ) elif token.type == tokenize.COMMENT: fmt_dir = FMT_DIRECTIVE.from_token(token) - if ( - fmt_dir - and col_nb(token) == 0 - and not (fmt_dir.disable and "next" in fmt_dir.modifiers) - ): - # col-0 comments report cur_indent=0 to trigger context_exit; - # fmt directives at other columns report actual cur_indent. - fmt_dir = FMT_DIRECTIVE.from_token(token) if ( fmt_dir and col_nb(token) == 0 @@ -872,32 +791,6 @@ def get_next_queriable(self) -> Status: return Status( token, block_indent, effective_indent, buffer, False, pythonable ) - # Comments arrive in the token stream *before* any following - # INDENT/DEDENT tokens, so self.cur_indent still reflects the - # previous (potentially higher) level. Delegate to - # _determine_comment_indent which peeks ahead and applies the - # two snapping rules. - effective_indent = self._determine_comment_indent(token) - self.syntax.cur_indent = effective_indent - if effective_indent < max(self.keyword_indent, self.block_indent): - return Status( - token, block_indent, effective_indent, buffer, False, pythonable - ) - # `# fmt: off[next]` always needs parser-level handling. - # Plain `# fmt: off` is parser-level only in global context; in other - # Python contexts it is handled by Black. - if ( - fmt_dir - and fmt_dir.disable - and ( - "next" in fmt_dir.modifiers - or "sort" in fmt_dir.modifiers - or (not fmt_dir.modifiers and self.in_global_context) - ) - ): - return Status( - token, block_indent, effective_indent, buffer, False, pythonable - ) elif is_newline(token): self.queriable, newline = True, True @@ -916,11 +809,6 @@ def get_next_queriable(self) -> Status: else: buffer += TAB * self.effective_indent - if ( - (token.type == tokenize.NAME or token.string == "@") - and self.queriable - and not self.in_fstring - ): if ( (token.type == tokenize.NAME or token.string == "@") and self.queriable diff --git a/snakefmt/parser/syntax.py b/snakefmt/parser/syntax.py index 36462a6..8cb5642 100644 --- a/snakefmt/parser/syntax.py +++ b/snakefmt/parser/syntax.py @@ -6,7 +6,6 @@ from abc import ABC, abstractmethod from collections import OrderedDict from re import match as re_match -from tokenize import TokenInfo from typing import ClassVar, NamedTuple, Optional, Type from snakefmt import fstring_tokeniser_in_use @@ -21,7 +20,14 @@ SyntaxFormError, TooManyParameters, ) -from snakefmt.types import COMMENT_SPACING, TokenIterator, col_nb, line_nb, not_empty +from snakefmt.types import ( + COMMENT_SPACING, + TokenInfo, + TokenIterator, + col_nb, + line_nb, + not_empty, +) # ___Token parsing___# BRACKETS_OPEN = {"(", "[", "{"} @@ -513,17 +519,14 @@ def parse_params(self, snakefile: TokenIterator): self.flush_param(cur_param, skip_empty=True) self.eof = True break - if self.check_exit(cur_param, snakefile): if self.check_exit(cur_param, snakefile): break if self.num_params() == 0: raise NoParametersError(f"{self.line_nb}In {self.keyword_name} definition.") - def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): exit = False - if not self.found_newline or not self.token: if not self.found_newline or not self.token: return exit if not_empty(self.token): diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 9ea4a2b..f5b680d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -9,7 +9,6 @@ import black import black.parsing -import black.parsing import pytest from snakefmt.exceptions import InvalidPython @@ -58,33 +57,7 @@ def test_single_param_keyword_stays_on_same_line(self): f"{TAB * 1}shell:\n" f'{TAB * 2}"for i in $(seq 1 5);" "do echo $i;" "done"\n', ) - example_shell_newline = ( - "rule a:\n" - f'{TAB * 1}shell: "for i in $(seq 1 5);"\n' - f'{TAB * 2}"do echo $i;"\n' - f'{TAB * 2}"done"', - "rule a:\n" - f"{TAB * 1}shell:\n" - f'{TAB * 2}"for i in $(seq 1 5);" "do echo $i;" "done"\n', - ) - def test_shell_param_newline_indented(self): - formatter = setup_formatter(self.example_shell_newline[0]) - assert formatter.get_formatted() == self.example_shell_newline[1] - - example_params_newline = ( - f"rule b: \n" - f'{TAB * 1}input: "a", "b",\n' - f'{TAB * 4}"c"\n' - f'{TAB * 1}wrapper: "mywrapper"', - f"rule b:\n" - f"{TAB * 1}input:\n" - f'{TAB * 2}"a",\n' - f'{TAB * 2}"b",\n' - f'{TAB * 2}"c",\n' - f"{TAB * 1}wrapper:\n" - f'{TAB * 2}"mywrapper"\n', - ) def test_shell_param_newline_indented(self): formatter = setup_formatter(self.example_shell_newline[0]) assert formatter.get_formatted() == self.example_shell_newline[1] @@ -106,27 +79,7 @@ def test_shell_param_newline_indented(self): def test_single_param_keyword_in_rule_gets_newline_indented(self): formatter = setup_formatter(self.example_params_newline[0]) assert formatter.get_formatted() == self.example_params_newline[1] - def test_single_param_keyword_in_rule_gets_newline_indented(self): - formatter = setup_formatter(self.example_params_newline[0]) - assert formatter.get_formatted() == self.example_params_newline[1] - - example_input_threads_newline = ( - f"rule c: \n" - f'{TAB * 1}input: "c"\n' - f"{TAB * 1}threads:\n" - f"{TAB * 2}20\n" - f"{TAB * 1}default_target:\n" - f"{TAB * 2}True\n", - f"rule c:\n" - f"{TAB * 1}input:\n" - f'{TAB * 2}"c",\n' - f"{TAB * 1}threads: 20\n" - f"{TAB * 1}default_target: True\n", - ) - def test_single_numeric_param_keyword_in_rule_stays_on_same_line(self): - formatter = setup_formatter(self.example_input_threads_newline[0]) - assert formatter.get_formatted() == self.example_input_threads_newline[1] example_input_threads_newline = ( f"rule c: \n" f'{TAB * 1}input: "c"\n' @@ -1785,29 +1738,6 @@ class TestSortFormatting: f'{TAB * 2}print("hello world")\n', ) - def test_sorting_of_params(self): - snakecode = self.sort_simple[0] + ( - sort_simple = ( - "rule a:\n" - f"{TAB * 1}# annots\n" - f"{TAB * 1}threads: 1\n" - f'{TAB * 1}log: "b",\n' - f'{TAB * 1}output: "a", "fsdfdsdfd", "ccc"\n' - f"{TAB * 1}run:\n" - f'{TAB * 2}print("hello world")\n', - "rule a:\n" - f"{TAB * 1}output:\n" - f'{TAB * 2}"a",\n' - f'{TAB * 2}"fsdfdsdfd",\n' - f'{TAB * 2}"ccc",\n' - f"{TAB * 1}log:\n" - f'{TAB * 2}"b",\n' - f"{TAB * 1}# annots\n" - f"{TAB * 1}threads: 1\n" - f"{TAB * 1}run:\n" - f'{TAB * 2}print("hello world")\n', - ) - def test_sorting_of_params(self): snakecode = self.sort_simple[0] + ( "if 2:\n" @@ -1825,8 +1755,6 @@ def test_sorting_of_params(self): f'{TAB * 1}print("error")\n' ) formatter = setup_formatter(snakecode, sort_params=True) - expected = self.sort_simple[1] + ( - f"\n\n" expected = self.sort_simple[1] + ( f"\n\n" "if 2:\n" @@ -1931,89 +1859,6 @@ def test_sorting_with_comments_preservation(self): f'{TAB * 2}"echo"\n', ) - def test_sorting_with_inline_parameter_comments(self): - formatter = setup_formatter(self.sort_inline_comments[0], sort_params=True) - assert formatter.get_formatted() == self.sort_inline_comments[1] - sorting_comprehensive = ( - "rule all:\n" - f"{TAB}params: p=1\n" - f"{TAB}resources: mem_mb=100\n" - f"{TAB}threads: 4\n" - f"{TAB}conda: 'env.yaml'\n" - f"{TAB}message: 'finishing'\n" - f"{TAB}log: 'log.txt'\n" - f"{TAB}output: 'out.txt'\n" - f"{TAB}# Important input\n" - f"{TAB}input: 'in.txt'\n" - f"{TAB}name: 'myrule'\n" - f"{TAB}shell: 'echo done'\n", - "rule all:\n" - f"{TAB}name:\n" - f'{TAB * 2}"myrule"\n' - f"{TAB}# Important input\n" - f"{TAB}input:\n" - f'{TAB * 2}"in.txt",\n' - f"{TAB}output:\n" - f'{TAB * 2}"out.txt",\n' - f"{TAB}log:\n" - f'{TAB * 2}"log.txt",\n' - f"{TAB}conda:\n" - f'{TAB * 2}"env.yaml"\n' - f"{TAB}threads: 4\n" - f"{TAB}resources:\n" - f"{TAB * 2}mem_mb=100,\n" - f"{TAB}params:\n" - f"{TAB * 2}p=1,\n" - f"{TAB}message:\n" - f'{TAB * 2}"finishing"\n' - f"{TAB}shell:\n" - f'{TAB * 2}"echo done"\n', - ) - - def test_sorting_comprehensive(self): - formatter = setup_formatter(self.sorting_comprehensive[0], sort_params=True) - assert formatter.get_formatted() == self.sorting_comprehensive[1] - - sort_with_comments = ( - "rule complex:\n" - f"{TAB}# Action comment\n" - f"{TAB}shell: 'do something'\n" - f"{TAB}# Resource comment\n" - f"{TAB}resources: res=1\n" - f"{TAB}# Input comment\n" - f"{TAB}input: 'i'\n", - "rule complex:\n" - f"{TAB}# Input comment\n" - f"{TAB}input:\n" - f'{TAB * 2}"i",\n' - f"{TAB}# Resource comment\n" - f"{TAB}resources:\n" - f"{TAB * 2}res=1,\n" - f"{TAB}# Action comment\n" - f"{TAB}shell:\n" - f'{TAB * 2}"do something"\n', - ) - - def test_sorting_with_comments_preservation(self): - """Comments stay with their keywords""" - formatter = setup_formatter(self.sort_with_comments[0], sort_params=True) - assert formatter.get_formatted() == self.sort_with_comments[1] - - sort_inline_comments = ( - "rule inline_comments:\n" - f"{TAB}shell: 'echo'\n" - f"{TAB}params:\n" - f"{TAB * 2}p=1, # parameter comment\n" - f"{TAB}input: 'i'\n", - "rule inline_comments:\n" - f"{TAB}input:\n" - f'{TAB * 2}"i",\n' - f"{TAB}params:\n" - f"{TAB * 2}p=1, # parameter comment\n" - f"{TAB}shell:\n" - f'{TAB * 2}"echo"\n', - ) - def test_sorting_with_inline_parameter_comments(self): formatter = setup_formatter(self.sort_inline_comments[0], sort_params=True) assert formatter.get_formatted() == self.sort_inline_comments[1] @@ -2062,13 +1907,10 @@ def test_sorting_checkpoint(self): "checkpoint map_reads:\n" f"{TAB}input:\n" f'{TAB * 2}"in.txt",\n' - f'{TAB * 2}"in.txt",\n' f"{TAB}output:\n" f'{TAB * 2}"out.txt",\n' - f'{TAB * 2}"out.txt",\n' f"{TAB}shell:\n" f'{TAB * 2}"echo"\n' - f'{TAB * 2}"echo"\n' ) assert formatter.get_formatted() == expected @@ -2160,13 +2002,11 @@ def side_effect(*args, **kwargs): formatter.black_mode = black.Mode() formatter.from_python = False formatter.fmt_off = None - formatter.fmt_off = None from snakefmt.parser.parser import Context from snakefmt.parser.syntax import KeywordSyntax formatter.context = Context( None, KeywordSyntax("Global", keyword_indent=0, accepts_py=True) # type: ignore - None, KeywordSyntax("Global", keyword_indent=0, accepts_py=True) # type: ignore ) # Manually set last_token to something that isn't DEDENT/ENDMARKER formatter.last_token = tokenize.TokenInfo( From 6d083b103e632d43284c9c1f68eca8efd7eb674e Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 01:04:08 +0800 Subject: [PATCH 47/53] test: revert changes --- snakefmt/parser/parser.py | 2 +- snakefmt/parser/syntax.py | 39 ++++++++++++++++++--------------------- snakefmt/types.py | 11 +++++++---- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/snakefmt/parser/parser.py b/snakefmt/parser/parser.py index e43f7eb..3566865 100644 --- a/snakefmt/parser/parser.py +++ b/snakefmt/parser/parser.py @@ -14,7 +14,7 @@ is_newline, re_add_curly_bracket_if_needed, ) -from snakefmt.types import TAB, TokenIterator, col_nb +from snakefmt.types import TAB, Token, TokenIterator, col_nb _FMT_DIRECTIVE_RE = re.compile( r"^# fmt: (off|on)(?:\[(\w+(?:,\s*\w+)*)\])?(?=$|\s{2}|\s#)" diff --git a/snakefmt/parser/syntax.py b/snakefmt/parser/syntax.py index 8cb5642..ae90fc5 100644 --- a/snakefmt/parser/syntax.py +++ b/snakefmt/parser/syntax.py @@ -22,7 +22,7 @@ ) from snakefmt.types import ( COMMENT_SPACING, - TokenInfo, + Token, TokenIterator, col_nb, line_nb, @@ -110,7 +110,7 @@ def _extract_line_mid( return t -def re_add_curly_bracket_if_needed(token: TokenInfo) -> str: +def re_add_curly_bracket_if_needed(token: Token) -> str: result = "" if ( fstring_tokeniser_in_use @@ -125,7 +125,7 @@ def re_add_curly_bracket_if_needed(token: TokenInfo) -> str: def fstring_processing( - token: TokenInfo, prev_token: Optional[TokenInfo], in_fstring: bool + token: Token, prev_token: Optional[Token], in_fstring: bool ) -> bool: """ Returns True if we are entering, or have already entered and not exited, @@ -140,7 +140,7 @@ def fstring_processing( def operator_skip_spacing( - prev_token: TokenInfo, token: TokenInfo, in_fstring: bool = False + prev_token: Token, token: Token, in_fstring: bool = False ) -> bool: # Check for f-string conversion specifiers: ! followed by r, s, or a if ( @@ -170,7 +170,7 @@ def operator_skip_spacing( def add_token_space( - prev_token: Optional[TokenInfo], token: TokenInfo, in_fstring: bool = False + prev_token: Optional[Token], token: Token, in_fstring: bool = False ) -> bool: result = False if prev_token is not None: @@ -183,27 +183,27 @@ def add_token_space( return result -def is_colon(token: TokenInfo): +def is_colon(token: Token): return token.type == tokenize.OP and token.string == ":" -def is_newline(token: TokenInfo): +def is_newline(token: Token): return token.type == tokenize.NEWLINE or token.type == tokenize.NL -def brack_open(token: TokenInfo): +def brack_open(token: Token): return token.type == tokenize.OP and token.string in BRACKETS_OPEN -def brack_close(token: TokenInfo): +def brack_close(token: Token): return token.type == tokenize.OP and token.string in BRACKETS_CLOSE -def is_equal_sign(token: TokenInfo): +def is_equal_sign(token: Token): return token.type == tokenize.OP and token.string == "=" -def is_comma_sign(token: TokenInfo): +def is_comma_sign(token: Token): return token.type == tokenize.OP and token.string == "," @@ -212,7 +212,7 @@ class Parameter: Holds the value of a parameter-accepting keyword """ - def __init__(self, token: TokenInfo): + def __init__(self, token: Token): self.line_nb = line_nb(token) self.col_nb = col_nb(token) self.key = "" @@ -247,10 +247,7 @@ def has_value(self) -> bool: return len(self.value) > 0 def add_elem( - self, - prev_token: Optional[TokenInfo], - token: TokenInfo, - in_fstring: bool = False, + self, prev_token: Optional[Token], token: Token, in_fstring: bool = False ): if add_token_space(prev_token, token, in_fstring) and len(self.value) > 0: self.value += " " @@ -260,7 +257,7 @@ def add_elem( self.value += token.string - def to_key_val_mode(self, token: TokenInfo): + def to_key_val_mode(self, token: Token): if not self.has_value(): raise InvalidParameterSyntax( f"L{token.start[0]}:Operator = used with no preceding key" @@ -312,7 +309,7 @@ def __init__( self.keyword_indent = keyword_indent self.cur_indent = max(self.keyword_indent - 1, 0) self.comment = "" - self.token: TokenInfo + self.token: Token if snakefile is not None: self.validate_keyword_line(snakefile) @@ -415,7 +412,7 @@ def validate_rulelike_syntax(self, snakefile: TokenIterator): ColonError(self.line_nb, self.token.string, self.keyword_line) self.token = next(snakefile) - def add_processed_keyword(self, token: TokenInfo, keyword: str): + def add_processed_keyword(self, token: Token, keyword: str): self.processed_keywords.add(keyword) def check_empty(self): @@ -537,7 +534,7 @@ def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): # untouched — the real processing will update it once tokens # are put back. temp_indent = self.cur_indent - cached_tokens: list[TokenInfo] = [] + cached_tokens: list[Token] = [] try: while True: t = next(snakefile) @@ -566,7 +563,7 @@ def check_exit(self, cur_param: Parameter, snakefile: TokenIterator): return exit def process_token( - self, cur_param: Parameter, prev_token: Optional[TokenInfo] + self, cur_param: Parameter, prev_token: Optional[Token] ) -> Parameter: token_type = self.token.type # f-string treatment (since python 3.12) diff --git a/snakefmt/types.py b/snakefmt/types.py index f0cb34a..f39b92a 100644 --- a/snakefmt/types.py +++ b/snakefmt/types.py @@ -5,16 +5,19 @@ COMMENT_SPACING = " " # PEP8, minimum of two spaces for inline comments -def line_nb(token: TokenInfo) -> int: +Token = TokenInfo + + +def line_nb(token: Token) -> int: return token.start[0] -def col_nb(token: TokenInfo) -> int: +def col_nb(token: Token) -> int: return token.start[1] -def not_empty(token: TokenInfo): +def not_empty(token: Token): return len(token.string) > 0 and not token.string.isspace() -TokenIterator = Iterator[TokenInfo] +TokenIterator = Iterator[Token] From 618c24783138fa0794b8fc2d4af3a1ac243c0c93 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 01:21:03 +0800 Subject: [PATCH 48/53] style: flake8 --- .flake8 | 3 +- snakefmt/blocken.py | 83 ++++++++++++++++++++++++++++--------------- tests/test_blocken.py | 36 ++++++++++++++----- 3 files changed, 82 insertions(+), 40 deletions(-) diff --git a/.flake8 b/.flake8 index 2052590..3b5443a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,4 @@ [flake8] max-line-length = 88 # the default ignores minus E704 -ignore = E121,E123,E126,E226,E203,E24,W503,W504 - +ignore = E121,E123,E126,E226,E203,E24,E701,E704,W503,W504 diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 7935fd9..f5d8549 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -26,9 +26,14 @@ ) if sys.version_info < (3, 12): - is_fstring_start = lambda token: False + + def is_fstring_start(token: TokenInfo): + return False + else: - is_fstring_start = lambda token: token.type == tokenize.FSTRING_START + + def is_fstring_start(token: TokenInfo): + return token.type == tokenize.FSTRING_START def consume_fstring(tokens: Iterator[TokenInfo]): finished: list[TokenInfo] = [] @@ -79,7 +84,8 @@ def next_component(self): elif token.string in ")]}": if not expect_brackets or expect_brackets[-1] != token.string: raise UnsupportedSyntax( - f"Unexpected closing bracket {token.string!r} at line {token.start[0]}" + f"Unexpected closing bracket " + f"{token.string!r} at line {token.start[0]}" ) expect_brackets.pop() elif is_fstring_start(token): @@ -111,8 +117,8 @@ def next_block(self): assert deindelta == 1 break lines.append(line) - # there must be somewhere a DEDENT token to end the block, otherwise raise from __next__ - # now check comments + # there must be somewhere a DEDENT token to end the block, + # otherwise raise from __next__ now check comments indent = extract_line_indent(lines[0].body[0].line) tail_noncoding = self.denext_by_indent(line, indent, deindelta) return lines, tail_noncoding @@ -173,7 +179,8 @@ def __next__(self) -> TokenInfo: ) from e else: raise UnsupportedSyntax( - f"Unexpected end of file after symbol[{self._last_token}] while parsing '{self.name}'" + f"Unexpected end of file after symbol" + f"[{self._last_token}] while parsing '{self.name}'" ) from e self._last_token = token return token @@ -450,7 +457,7 @@ def format_black( if fix.startswith("Tb(\n"): fix = fix.split("\n", 1)[1].rsplit("\n", 1)[0] + "\n" else: - if not "#" in fix: # safe to unpack function + if "#" not in fix: # safe to unpack function fix = TAB * (indent + 1) + fix[3:-1] + "\n" else: fix = ( @@ -467,20 +474,24 @@ class Block(ABC): Also include functions, classes and decoraters (`@` lines) a single block identifed by keywords in if/elif/else / for/while / try/except/finally / with - and all the code under it, until the next block of the same or lower indent level. + and all the code under it, until the next block + of the same or lower indent level. a snakemake keyword block (rule, module, config, etc.) - and all the code under it, until the next block of the same or lower indent level. + and all the code under it, until the next block + of the same or lower indent level. (snakemake keywords should NEVER in functions or classes) comments between blocks - (exclude the comment right before the indenting keyword, which is considered part of the block) + (exclude the comment right before the indenting keyword, + which is considered part of the block) Starting of blocks (file or new indent): - the space and comments until the first indenting keyword are considered a block of their own. + the space and comments until the first indenting keyword + are considered a block of their own. All other spaces are considered part of the previous block's trailing empty lines. Comment belongness: - Only comments with neither empty lines between/after the next block nor different indent levels - are considered part of the same block. + Only comments with neither empty lines between/after the next block + nor different indent levels are considered part of the same block. e.g.: sth # block 1 # comment 1 -> block 1 @@ -643,7 +654,8 @@ def head_linestrs(self): @property def full_linestrs(self) -> list[str]: - """return the code splited by lines, but should keep multiline-string or multiline-f-string complete, + """return the code splited by lines, but should keep multiline-string + or multiline-f-string complete, to make trimming and reformatting easier. Should and Only should be rewrite for pure python blocks. @@ -659,10 +671,14 @@ def components(self) -> "Iterator[DocumentSymbol]": """ - position := (file, line number, column number) - type := name / rule, input, output / function, class / etc. - if not a name, then that's the definition of the name (should link blank names to here) - - identifier := the identifier of the block, e.g. rule `a`, `input`, input `b`, etc. - when iterating sub-blocks in rule, identifier should modified to reflect the parent block, e.g. `rules.a.input.b` - (`b` may be difficult to identify, but at least we know the content of `input` block) + if not a name, then that's the definition of the name + (should link blank names to here) + - identifier := the identifier of the block, + e.g. rule `a`, `input`, input `b`, etc. + when iterating sub-blocks in rule, identifier should modified to + reflect the parent block, e.g. `rules.a.input.b` + (`b` may be difficult to identify, + but at least we know the content of `input` block) - content := "self.raw()", e.g. `"data.txt"` for input `b` in rule `a`, and the whole content of the block for rule `a` @@ -749,7 +765,8 @@ def segment2format( @abstractmethod def compilation(self): - """return pure python code compiled from the block, without snakemake keywords and comments""" + """return pure python code compiled from the block, + without snakemake keywords and comments""" class DocumentSymbol(NamedTuple): @@ -843,7 +860,8 @@ class NoSnakemakeBlock(ColonBlock): Also, snakemake keywords should not be used in `async` blocks - TODO: although not recommended, snakemake keywords can be used in function/class body + TODO: although not recommended, snakemake keywords can be used in + function/class body Should handle that cases in the future """ @@ -1042,7 +1060,7 @@ def format_head(self, mode: Mode) -> tuple[str, list[TokenInfo]]: colon_token = next(post_colon) post = tokens2linestrs(post_colon.rest) post[0] = post[0][colon_token.end[1] :] - fake_str = f"if 1:" + "".join(post) + " ..." + fake_str = "if 1:" + "".join(post) + " ..." fake_fmt = format_black(fake_str, mode).strip() formatted_head += fake_fmt.split(":", 1)[1].rsplit("\n", 1)[0] + "\n" return formatted_head, [] @@ -1103,19 +1121,21 @@ def format_post_colon( even if expressions exist in that line, indent body should be formatted as part of the cotent: input: balabal, # <- expression after the colon - balabal2 # <- indent body, should be formatted as part of the content + balabal2 # <- indent body, should format as part of the content to: input: balabal, balabal2, - Morover, the original snakefmt allow sort positional arguments before keyword arguments. - Here need check, too + Morover, the original snakefmt allow sort positional arguments + before keyword arguments. Here need check, too Input: - post_colon: tokens after the colon in the head line, e.g. `balabal,` in the above example + post_colon: tokens after the colon in the head line, + e.g. `balabal,` in the above example post_colon[0] := TokenInfo(type=NAME, string='balabal', ...) - body_blocks: indent body blocks, e.g. the block of `balabal2` in the above example + body_blocks: indent body blocks, + e.g. the block of `balabal2` in the above example """ if not (post_colon or body_blocks): return "" @@ -1199,7 +1219,10 @@ def _find_split_and_push(): tail_noncoding = "" # here is used to check the end_op raw = "".join( - (*(i for l in args[False] for i in l), *(i for l in args[True] for i in l)) + ( + *(i for line in args[False] for i in line), + *(i for line in args[True] for i in line), + ) ) formatable = cls.handle_end_comma(raw, partial_line) + tail_noncoding formatted = format_black( @@ -1251,7 +1274,8 @@ def handle_end_comma(raw, last_line): class PythonUnnamedArguments(PythonArguments): - """Only allow simple expressions on the right, and the whole block should be a list""" + """Only allow simple expressions on the right, + and the whole block should be a list""" class PythonOneLineArgument(PythonArgumentsBlock): @@ -1544,7 +1568,8 @@ def format_body(self, mode, state, post_colon): directive = "" elif not last_sort_off: # state.sort_direcives switched on, this comment is - # actually `# fmt: on[sort]` directive, so split from next directive + # actually `# fmt: on[sort]` directive, + # so split from next directive formatted.append(directive) directive = "" if state.not_format: diff --git a/tests/test_blocken.py b/tests/test_blocken.py index f75210a..0f0e3ce 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -134,7 +134,8 @@ def test_next_new_line(self): "def components(self):\n" # " this_symbol: DocumentSymbol = DocumentSymbol(\n" " name=self.name,\n" - " detail='\\n'.join(i.rstrip() for i in self.block_lines()).strip('\\n'),\n" + " detail='\\n'.join(i.rstrip() for i in " + "self.block_lines()).strip('\\n'),\n" " symbol_kind=self._keyword(),\n" " position_start=self.start_token.start,\n" " position_end=self.head_tokens[-1].end,\n" @@ -345,11 +346,26 @@ def test_parse_snakefile(self): assert "".join(block.full_linestrs) == self.example2 assert isinstance(block, GlobalBlock) assert ["".join(i.full_linestrs) for i in block.body_blocks] == [ - "rule A:\n input:\n a = '1'\n output:\n 'b = 2'\n run:\n print(1)\n\n\n", - "checkpoint:\n name: 'check'\n params:\n c = '''\n c = '''\n conda: 'conda.yaml'\n shell: 'touch d'\n\n\n", - "onsuccess:\n for i in range(10):\n print(i)\n\n\n", - "wildcard_constraints:\n sth = r'a|b|c',\n sth2 = r'a|b|c',\n sth3 = r'a|b|c'\n\n\n", - "Report:\n 'report'\n", + "rule A:\n" + " input:\n" + " a = '1'\n" + " output:\n" + " 'b = 2'\n" + " run:\n" + " print(1)\n\n\n", + "checkpoint:\n" + " name: 'check'\n" + " params:\n" + " c = '''\n" + " c = '''\n" + " conda: 'conda.yaml'\n" + " shell: 'touch d'\n\n\n", + "onsuccess:\n" " for i in range(10):\n" " print(i)\n\n\n", + "wildcard_constraints:\n" + " sth = r'a|b|c',\n" + " sth2 = r'a|b|c',\n" + " sth3 = r'a|b|c'\n\n\n", + "Report:\n" " 'report'\n", "", ] @@ -365,7 +381,7 @@ def test_format_colon(self): assert fmted == "if 1: # comment\n" def test_format_def(self): - raw = f"{TAB}def s(a):\n" f"{TAB*2}if a:\n" f'{TAB* 3}return "Hello World"\n' + raw = f"{TAB}def s(a):\n" f"{TAB*2}if a:\n" f'{TAB * 3}return "Hello World"\n' fmted = format_black(raw, mode=mode, indent=1) assert fmted == raw @@ -432,8 +448,10 @@ def test_format_python_block(self): assert len(py2.head_lines) == 3 assert isinstance(py2, PythonBlock) assert ( - py2.formatted(mode) - == 'b = f"""\n{b =} f"""\n# comment\nc = [i for j in k] if m else (lambda: None)\n' + py2.formatted(mode) == 'b = f"""\n' + '{b =} f"""\n' + "# comment\n" + "c = [i for j in k] if m else (lambda: None)\n" ) assert block.get_formatted(mode) == black.format_str(self.example1, mode=mode) From 6498fce92401fa1a9caba0a9336b7ec412a880f4 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 02:11:38 +0800 Subject: [PATCH 49/53] fix: isort --- tests/test_blocken.py | 10 +++++----- tests/test_formatter.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_blocken.py b/tests/test_blocken.py index 0f0e3ce..fde238f 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -2,18 +2,18 @@ from snakefmt.blocken import ( FormatState, - NoSnakemakeBlock, GlobalBlock, IfForTryWithBlock, + NoSnakemakeBlock, PythonBlock, - consume_fstring, TokenIterator, + UnsupportedSyntax, + black, + consume_fstring, format_black, - tokenize, is_fstring_start, - UnsupportedSyntax, parse, - black, + tokenize, ) from snakefmt.config import read_black_config from snakefmt.types import TAB diff --git a/tests/test_formatter.py b/tests/test_formatter.py index f5b680d..820bbe1 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -11,11 +11,11 @@ import black.parsing import pytest +from snakefmt.blocken import setup_formatter from snakefmt.exceptions import InvalidPython from snakefmt.parser.grammar import SingleParam, SnakeGlobal from snakefmt.parser.syntax import COMMENT_SPACING from snakefmt.types import TAB -from snakefmt.blocken import setup_formatter def test_emptyInput_emptyOutput(): From cd6f2649842cb818eebe7a15375631b42fb61d66 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 02:13:52 +0800 Subject: [PATCH 50/53] ./snakefmt/blocken.py:126:38: F821 undefined name 'LogicalLine' --- snakefmt/blocken.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index f5d8549..9ca611e 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -123,7 +123,7 @@ def next_block(self): tail_noncoding = self.denext_by_indent(line, indent, deindelta) return lines, tail_noncoding - def denext_by_indent(self, line: LogicalLine, indent: str, deindelta=1): + def denext_by_indent(self, line: "LogicalLine", indent: str, deindelta=1): """Call when a block is ended by a DEDENT token, to split comments belong to this block from those belong to parent blocks, and reorder tokens so that the next block can be parsed correctly. From f2ee0517795b5a0703ad18f3c98fba81929df492 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 02:19:27 +0800 Subject: [PATCH 51/53] fix: import consume_fstring --- tests/test_blocken.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_blocken.py b/tests/test_blocken.py index fde238f..f052ed9 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -1,3 +1,5 @@ +import sys + import pytest from snakefmt.blocken import ( @@ -9,7 +11,6 @@ TokenIterator, UnsupportedSyntax, black, - consume_fstring, format_black, is_fstring_start, parse, @@ -18,6 +19,9 @@ from snakefmt.config import read_black_config from snakefmt.types import TAB +if sys.version_info >= (3, 12): + from snakefmt.blocken import consume_fstring + def generate_tokens(input: str): return list( @@ -27,6 +31,8 @@ def generate_tokens(input: str): class TestTokenIterator: def test_fstring1(self): + if sys.version_info < (3, 12): + return input = 'f"hello world"' tokens = generate_tokens(input) token_iter = TokenIterator("", iter(tokens)) @@ -45,6 +51,8 @@ def test_fstring1(self): ] def test_fstring_with_bracket(self): + if sys.version_info < (3, 12): + return input = 'a = f"hello {world}"' tokens = generate_tokens(input) token_iter = TokenIterator("", iter(tokens)) From 5da1e6ce3f6848693402b0672d5729181e7f2628 Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 02:24:35 +0800 Subject: [PATCH 52/53] test: skip for version --- tests/test_blocken.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_blocken.py b/tests/test_blocken.py index f052ed9..53b3d91 100644 --- a/tests/test_blocken.py +++ b/tests/test_blocken.py @@ -21,6 +21,9 @@ if sys.version_info >= (3, 12): from snakefmt.blocken import consume_fstring +py12_guard = pytest.mark.skipif( + sys.version_info < (3, 12), reason="Requires Python 3.12 or higher" +) def generate_tokens(input: str): @@ -30,9 +33,9 @@ def generate_tokens(input: str): class TestTokenIterator: + + @py12_guard def test_fstring1(self): - if sys.version_info < (3, 12): - return input = 'f"hello world"' tokens = generate_tokens(input) token_iter = TokenIterator("", iter(tokens)) @@ -50,9 +53,8 @@ def test_fstring1(self): tokenize.FSTRING_END, ] + @py12_guard def test_fstring_with_bracket(self): - if sys.version_info < (3, 12): - return input = 'a = f"hello {world}"' tokens = generate_tokens(input) token_iter = TokenIterator("", iter(tokens)) @@ -91,6 +93,7 @@ def test_consum_all(self): " pass" ) + @py12_guard def test_next_new_line(self): tokens = generate_tokens(self.example1) token_iter = TokenIterator("", iter(tokens)) From 4ea65f650210e626272d2b8b64a87ea4e1eaefdf Mon Sep 17 00:00:00 2001 From: hwrn Date: Fri, 10 Apr 2026 02:38:30 +0800 Subject: [PATCH 53/53] fix: ai --- snakefmt/blocken.py | 72 ++++++++++++++++++++--------------------- tests/test_formatter.py | 17 ++++++++++ 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/snakefmt/blocken.py b/snakefmt/blocken.py index 9ca611e..86297d2 100644 --- a/snakefmt/blocken.py +++ b/snakefmt/blocken.py @@ -58,7 +58,6 @@ def __init__(self, name, tokens: Iterator[TokenInfo]): self.name = name self._live_tokens = tokens self._buffered_tokens: list[TokenInfo] = list() - self.tokens = tokens self.lines = 0 self.rulecount = 0 self._overwrite_cmd: Optional[str] = None @@ -94,10 +93,11 @@ def next_component(self): def next_block(self): """Returns a entire block, just consume until the end of the block. - Donot care if there are nested blocks inside or snakemake keywords inside. + Do not care if there are nested blocks inside or snakemake keywords inside. - it could be INDEDT -> [any content] -> DEDENT, or [any content] -> DEDENT + it could be INDENT -> [any content] -> DEDENT, or [any content] -> DEDENT """ + line = self.next_new_line() if line.end.type == tokenize.ENDMARKER: self.denext(*reversed(list(line.iter))) @@ -230,7 +230,7 @@ def is_keyword_line(self): and self.body[1].string == "=" ): return True - if self.body[0].type == "**": + if self.body[0].string == "**": return True return False @@ -326,7 +326,7 @@ def tokens2linestrs(tokens: Iterator[TokenInfo]): class FormatState(NamedTuple): fmt_on: bool = True - sort_direcives: bool | None = None + sort_directives: bool | None = None skip_next: bool = False # one-time directive for the next snakemake block @property @@ -349,12 +349,11 @@ def update(self, comment: str): If found `# fmt: on` and no `# fmt: off` before: if `fmt: off[sort]` is False: - sort_direcives == True -> enabled - sort_direcives == False -> disabled in this indent before - sort_direcives == None -> haven't enabled originally + sort_directives == True -> enabled + sort_directives == False -> disabled in this indent before + sort_directives == None -> haven't enabled originally turn it on """ - match = _FMT_DIRECTIVE_RE.match(comment) if match := _FMT_DIRECTIVE_RE.match(comment): directive, options = match.groups() # Parse options: "sort,next" -> ["sort", "next"] -> "sort" @@ -364,14 +363,14 @@ def update(self, comment: str): return self._replace(fmt_on=True) elif directive == "on": if option == "sort": - return self._replace(sort_direcives=True) - if self.sort_direcives is False: + return self._replace(sort_directives=True) + if self.sort_directives is False: # re-enable sorting if it was disabled by `# fmt: off[sort]` before, # but should effect if no `# fmt: off[sort]` in this indent before. - return self._replace(sort_direcives=True) + return self._replace(sort_directives=True) elif directive == "off": if option == "sort": - return self._replace(sort_direcives=False) + return self._replace(sort_directives=False) if option == "next": return self._replace(skip_next=True) return self._replace(fmt_on=False) @@ -388,8 +387,8 @@ def found_skip(comment: str): return "# fmt: skip" in comment def reset_sort(self): - if self.sort_direcives is False: - return self._replace(sort_direcives=None) + if self.sort_directives is False: + return self._replace(sort_directives=None) return self @@ -1083,8 +1082,9 @@ def try_combine_format( Search reversly, so it only give one of the possible results. Since the non-comma param is the mistake of the user, - please do not blame if the olgorithm is slow :) + please do not blame if the algorithm is slow :) """ + if len(arg_lines) <= 1: return [arg_lines] mode = mode or Mode() @@ -1540,7 +1540,7 @@ def format_body(self, mode, state, post_colon): directive = "" for line in noncoding: # here noncoding is already formated linelstrip = line.lstrip() - last_sort_off = state.sort_direcives + last_sort_off = state.sort_directives if linelstrip: # only non-empty lines are formattable if state.found_skip(linelstrip): @@ -1560,14 +1560,14 @@ def format_body(self, mode, state, post_colon): block.indent_str, self.deindent_level + 1, [line] ) ) - elif not state.sort_direcives: + elif not state.sort_directives: if directives: formatted.extend(self.sort_directives(directives)) if directive: formatted.append(directive) directive = "" elif not last_sort_off: - # state.sort_direcives switched on, this comment is + # state.sort_directives switched on, this comment is # actually `# fmt: on[sort]` directive, # so split from next directive formatted.append(directive) @@ -1587,7 +1587,7 @@ def format_body(self, mode, state, post_colon): ) else: directive += block.formatted(mode, state) - if state.sort_direcives: + if state.sort_directives: directives[block.keyword] = directive else: assert not directives, "Already flushed once fmt: off[sort]" @@ -1797,9 +1797,9 @@ class GlobalBlock(Block): so tail_noncoding always updated to the last body_block """ - __slots__ = ("mode", "sort_direcives") + __slots__ = ("mode", "sort_directives") mode: Mode - sort_direcives: bool + sort_directives: bool subautomata = {**python_subautomata, **global_snakemake_subautomata} @@ -1817,8 +1817,8 @@ def get_formatted( if mode is None: raise ValueError("Mode should be provided for formatting") if sort_directives is None: - sort_directives = getattr(self, "sort_direcives", None) - state = FormatState(sort_direcives=sort_directives or None) + sort_directives = getattr(self, "sort_directives", None) + state = FormatState(sort_directives=sort_directives or None) # if set to None, it will not be enabled by `# fmt: on` python_codes: list[str] = [] snakemake_codes: list[tuple[str, str]] = [] @@ -1830,19 +1830,19 @@ def get_formatted( snakemake_codes.append((segment, indent_proxy)) else: last_str += segment - place_hode_str = "o" * 50 + placeholder = "o" * 50 raw_str = "".join(python_codes) - while place_hode_str in raw_str: - place_hode_str *= 2 + while placeholder in raw_str: + placeholder *= 2 raw_str = "#\n" for python_code, (snakemake_code, indent) in zip(python_codes, snakemake_codes): if snakemake_code.count("\n") == 1: # must at the end of line - place_hode = f"{indent}def l{place_hode_str}1ng(): ...\n" + snakemake_proxy = f"{indent}def l{placeholder}1ng(): ...\n" else: - place_hode = f"{indent}def l{place_hode_str}ng():\n{indent} return\n" - raw_str += python_code + place_hode + snakemake_proxy = f"{indent}def l{placeholder}ng():\n{indent} return\n" + raw_str += python_code + snakemake_proxy raw_str += last_str - formatted, *formatted_split = format_black(raw_str, mode).split(place_hode_str) + formatted, *formatted_split = format_black(raw_str, mode).split(placeholder) final_str = formatted for formatted, (snakemake_code, _) in zip(formatted_split, snakemake_codes): final_str = final_str.rsplit("\n", 1)[0] + "\n" + snakemake_code @@ -1856,13 +1856,13 @@ def compilation(self): raise NotImplementedError -def parse(input: str | Callable[[], str], name: str = ""): - if isinstance(input, str): +def parse(source: str | Callable[[], str], name: str = ""): + if isinstance(source, str): tokens = tokenize.generate_tokens( - iter(input.splitlines(keepends=True)).__next__ + iter(source.splitlines(keepends=True)).__next__ ) else: - tokens = tokenize.generate_tokens(input) + tokens = tokenize.generate_tokens(source) return GlobalBlock(0, TokenIterator(name, tokens), []) @@ -1878,5 +1878,5 @@ def setup_formatter( mode.line_length = line_length formatter.mode = mode - formatter.sort_direcives = sort_params + formatter.sort_directives = sort_params return formatter diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 820bbe1..64ceef7 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2462,6 +2462,23 @@ def test_fmt_off_sort_dedent(self): "\n\n" + formatted2 ) assert setup_formatter(code, sort_params=True).get_formatted() == expected + code = ( + "# fmt: off[sort]\n" + "if 1:\n" + " # fmt: on[sort]\n" + + "".join(" " + i for i in code1.splitlines(keepends=True)).rstrip() + + "\n" + + code2.rstrip() + ) + expected = ( + "# fmt: off[sort]\n" + "if 1:\n" + f"{TAB}# fmt: on[sort]\n" + + "".join(TAB + i for i in formatted0.splitlines(keepends=True)).rstrip() + + "\n" + "\n\n" + formatted2 + ) + assert setup_formatter(code, sort_params=True).get_formatted() == expected def test_fmt_off_sort_on_noeffect(self): code1, formatted1 = TestSortFormatting.sorting_comprehensive