From e8643680f955e4c2ec819f98de8f3de2c4cc2f0f Mon Sep 17 00:00:00 2001 From: Antonio Cordero Balcazar Date: Sun, 17 Dec 2023 17:58:23 +0100 Subject: [PATCH] * Support for prompt composition (AND) and cleaning around it. * Option to cleanup around extra network tags. * Improved cleanup. * Some refactoring. --- README.md | 21 ++-- ppp.py | 240 +++++++++++++++++++++++++++++++----------- scripts/ppp_script.py | 32 ++++-- tests/tests.py | 186 ++++++++++++-------------------- 4 files changed, 280 insertions(+), 199 deletions(-) diff --git a/README.md b/README.md index 7d62e89..5087f83 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Currently this extension has these functions: * Detect invalid wildcards and act on them. * Clean up the prompt and negative prompt. -Note: The extension must be loaded after the installed wildcards extension (or any other that modifies the prompt). Extensions load by their folder name in alphanumeric order. +Note: The extension must be loaded after the installed wildcards extension (or any other that modifies the prompt or has it's own syntax expressions). Extensions load by their folder name in alphanumeric order. With the ["Dynamic Prompts" extension](https://github.com/adieyal/sd-dynamic-prompts) this happens by default due to default folder names for both extensions. But if this is not the case, you can just rename this extension's folder so the ordering works out. @@ -23,11 +23,14 @@ Notes: * **Attention**: \[prompt\] (prompt) (prompt:weight) * **Alternation**: \[prompt1|prompt2|...\] * **Scheduling**: \[prompt1:prompt2:step\] - * **Models**: \ + * **Extra networks**: \ + * **BREAK**: prompt1 BREAK prompt2 + * **Composable Diffusion**: prompt1 AND prompt2 In SD.Next that means only the *A1111* or *Full* parsers. It will warn you if you use the *Compel* parser. 2. It only recognizes wildcards in the *\_\_wildcard\_\_* and *{choice|choice}* formats. -3. It does not translate equivalent *AND/BREAK* separations into the negative prompt. +3. Since it should run after other extensions that apply to the prompt, the content should have already been processed by them and there should't be any non recognized syntax anymore. +4. It does not create *AND/BREAK* constructs when moving content to the negative prompt. ## Installation @@ -100,11 +103,15 @@ Then, if that option is chosen this extension will process it later and move tha * **Apply in img2img**: check if you want to do this processing in img2img processes. * **Remove empty constructs**: removes attention/scheduling/alternation constructs when they are invalid. -* **Remove extra separators**: removes unnecesary separators. This applies to the configured separator and regular commas. -* **Clean up around BREAKs**: removes consecutive BREAKs and unnecesary commas and space around them. -* **Remove extra spaces**: removes unnecesary spaces. +* **Remove extra separators**: removes unnecessary separators. This applies to the configured separator and regular commas. +* **Clean up around BREAKs**: removes consecutive BREAKs and unnecessary commas and space around them. +* **Clean up around ANDs**: removes consecutive ANDs and unnecessary commas and space around them. +* **Clean up around extra network tags**: removes spaces around them. +* **Remove extra spaces**: removes other unnecessary spaces. -## Notes +## Notes on negative tags + +Positional insertion tags have less priority that start/end tags, so even if they are at the start or end of the negative prompt, they will end up inside any start/end (and default position) tags. The content of the negative tags is not processed and is copied as-is to the negative prompt. Other modifiers around the tags are processed in the following way. diff --git a/ppp.py b/ppp.py index 3973a47..907bbb5 100644 --- a/ppp.py +++ b/ppp.py @@ -40,7 +40,7 @@ class PromptPostProcessor: # pylint: disable=too-few-public-methods,too-many-in """ NAME = "Prompt Post-Processor" - VERSION = "2.2.0" + VERSION = "2.3.0" DEFAULT_STN_SEPARATOR = ", " IFWILDCARDS_CHOICES = { @@ -92,30 +92,44 @@ def __init__( self.cup_emptyconstructs = getattr(opts, "ppp_cup_emptyconstructs", True) if opts is not None else True self.cup_extraseparators = getattr(opts, "ppp_cup_extraseparators", True) if opts is not None else True self.cup_breaks = getattr(opts, "ppp_cup_breaks", True) if opts is not None else True + self.cup_ands = getattr(opts, "ppp_cup_ands", True) if opts is not None else True + self.cup_extranetworktags = getattr(opts, "ppp_cup_extranetworktags", False) if opts is not None else False self.__insertion_point_tags = [f"" for x in range(10)] # Process with lark (debug with https://www.lark-parser.org/ide/) self.__parser_complete = lark.Lark( r""" - start: (prompt | /[\][():|<>!{}]/+)* - prompt: (emphasized | deemphasized | scheduled | alternate | modeltag | negtag | wildcard | choices | plain)* - nonegprompt: (emphasized | deemphasized | scheduled | alternate | modeltag | wildcard | choices | plain)* + start: (promptcomp | specialchars)* // BUG: sometimes it chooses specialchars instead of promptcomp and fails!!! + // prompt composition with AND + promptcomp: promptcomppart ([":" numpar] (/\bAND\b/ promptcomppart [":" numpar])+)? + promptcomppart: prompt + // prompt scheduling and alternation + alternate: "[" alternateoption ("|" alternateoption)+ "]" + alternateoption: prompt + scheduled: "[" [prompt ":"] prompt ":" numpar "]" + // wildcard extension support wildcard: "__" /(?:(?!__)\S)+/ "__" choices: "{" choice ("|" choice)* "}" - choice: prompt # we ignore weight and any other parameters + // we ignore weight and any other parameters in each choice + choice: prompt + // simple prompts + prompt: (emphasized | deemphasized | scheduled | alternate | extranetworktag | negtag | wildcard | choices | plain)* + nonegprompt: (emphasized | deemphasized | scheduled | alternate | extranetworktag | wildcard | choices | plain)* + // attention modifiers emphasized: "(" prompt [":" numpar] ")" deemphasized: "[" prompt "]" - scheduled: "[" [prompt ":"] prompt ":" numpar "]" - alternate: "[" alternateoption ("|" alternateoption)+ "]" - alternateoption: prompt + // extra network tags + extranetworktag: "<" /(?!!)[^>]+/ ">" + // negative tags negtag: "" negtagparameters: "!" /s|e|[ip]\d/ "!" - modeltag: "<" /(?!!)[^>]+/ ">" + // plain text and weights numpar: WHITESPACE* NUMBER WHITESPACE* WHITESPACE: /\s+/ - plain: /((?!__)[^\\[\]():|<>!{}]|\\.)+/s + plain: /((?!__|\bAND\b)[^\\[\]():|<>!{}]|\\.)+/s + specialchars: /[\]():|<>!{}]|\bAND\b/+ %import common.SIGNED_NUMBER -> NUMBER - """, # prompt, nonegprompt, plain with ? + """, propagate_positions=True, ) @@ -158,7 +172,7 @@ def __init__(self, ppp, prompt, add_at): super().__init__() self.__ppp = ppp self.__prompt = prompt - self.AccumulatedShell = namedtuple("AccumulatedShell", ["type", "info1", "info2"]) + self.AccumulatedShell = namedtuple("AccumulatedShell", ["type", "data", "position"]) AccumulatedShell = self.AccumulatedShell self.__shell: list[AccumulatedShell] = [] self.NegTag = namedtuple("NegTag", ["start", "end", "content", "parameters", "shell"]) @@ -190,6 +204,9 @@ def scheduled(self, tree): Returns: None """ + treemetaposition = ( + [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else None + ) if len(tree.children) > 2: # before & after before = tree.children[0] else: @@ -199,7 +216,7 @@ def scheduled(self, tree): pos = self.__get_numpar_value(numpar) if pos >= 1: pos = int(pos) - # self.__shell.append(self.AccumulatedShell("sc", tree.meta.start_pos, pos)) + # self.__shell.append(self.AccumulatedShell("sc", pos, treemetaposition)) if before is not None and hasattr(before, "data"): if self.__ppp.debug: before_metaposition = ( @@ -208,7 +225,7 @@ def scheduled(self, tree): else "?" ) self.__ppp.logger.info(f"Shell scheduled before at {before_metaposition} with position {pos}") - self.__shell.append(self.AccumulatedShell("scb", pos, None)) + self.__shell.append(self.AccumulatedShell("scb", pos, treemetaposition)) self.visit(before) self.__shell.pop() if hasattr(after, "data"): @@ -219,7 +236,7 @@ def scheduled(self, tree): else "?" ) self.__ppp.logger.info(f"Shell scheduled after at {after_metaposition} with position {pos}") - self.__shell.append(self.AccumulatedShell("sca", pos, None)) + self.__shell.append(self.AccumulatedShell("sca", pos, treemetaposition)) self.visit(after) self.__shell.pop() # self.__shell.pop() @@ -234,7 +251,10 @@ def alternate(self, tree): Returns: None """ - # self.__shell.append(self.AccumulatedShell("al", tree.meta.start_pos, len(tree.children))) + treemetaposition = ( + [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else None + ) + # self.__shell.append(self.AccumulatedShell("al", len(tree.children), treemetaposition)) for i, opt in enumerate(tree.children): if self.__ppp.debug: metaposition = ( @@ -242,7 +262,9 @@ def alternate(self, tree): ) self.__ppp.logger.info(f"Shell alternate at {metaposition} option {i+1}") if hasattr(opt, "data"): - self.__shell.append(self.AccumulatedShell("alo", i + 1, len(tree.children))) + self.__shell.append( + self.AccumulatedShell("alo", {"pos": i + 1, "len": len(tree.children)}, treemetaposition) + ) self.visit(opt) self.__shell.pop() # self.__shell.pop() @@ -257,14 +279,14 @@ def emphasized(self, tree): Returns: None """ + treemetaposition = ( + [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else None + ) numpar = tree.children[-1] weight = self.__get_numpar_value(numpar) if numpar is not None else 1.1 if self.__ppp.debug: - metaposition = ( - [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else "?" - ) - self.__ppp.logger.info(f"Shell attention at {metaposition} with weight {weight}") - self.__shell.append(self.AccumulatedShell("at", weight, None)) + self.__ppp.logger.info(f"Shell attention at {treemetaposition or '?'} with weight {weight}") + self.__shell.append(self.AccumulatedShell("at", weight, treemetaposition)) self.visit_children(tree) self.__shell.pop() @@ -279,12 +301,12 @@ def deemphasized(self, tree): None """ weight = 0.9 + treemetaposition = ( + [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else None + ) if self.__ppp.debug: - metaposition = ( - [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else "?" - ) - self.__ppp.logger.info(f"Shell attention at {metaposition} with weight {weight}") - self.__shell.append(self.AccumulatedShell("at", weight, None)) + self.__ppp.logger.info(f"Shell attention at {treemetaposition or '?'} with weight {weight}") + self.__shell.append(self.AccumulatedShell("at", weight, treemetaposition)) self.visit_children(tree) self.__shell.pop() @@ -298,6 +320,9 @@ def negtag(self, tree): Returns: None """ + treemetaposition = ( + [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else None + ) negtagparameters = tree.children[0] parameters = negtagparameters.children[0].value if negtagparameters is not None else "" rest = [] @@ -312,11 +337,8 @@ def negtag(self, tree): self.NegTag(tree.meta.start_pos, tree.meta.end_pos, content, parameters, self.__shell.copy()) ) if self.__ppp.debug: - metaposition = ( - [tree.meta.start_pos, tree.meta.end_pos] if hasattr(tree, "meta") and not tree.meta.empty else "?" - ) self.__ppp.logger.info( - f"Negative tag at {metaposition}: {parameters or 'with no parameters :'} {self.__ppp.formatOutput(content)}" + f"Negative tag at {treemetaposition or '?'}: {parameters or 'with no parameters :'} {self.__ppp.formatOutput(content)}" ) def start(self, tree): @@ -338,9 +360,9 @@ def start(self, tree): if negtag.shell[i].type == "at" and negtag.shell[i - 1].type == "at": negtag.shell[i - 1] = self.AccumulatedShell( "at", - math.floor(100 * negtag.shell[i - 1].info1 * negtag.shell[i].info1) + math.floor(100 * negtag.shell[i - 1].data * negtag.shell[i].data) / 100, # we limit the new weight to two decimals - None, + negtag.shell[i - 1].position, ) negtag.shell.pop(i) start = "" @@ -348,26 +370,26 @@ def start(self, tree): for s in negtag.shell: match s.type: case "at": - if s.info1 == 0.9: + if s.data == 0.9: start += "[" end = "]" + end - elif s.info1 == 1.1: + elif s.data == 1.1: start += "(" end = ")" + end else: start += "(" - end = f":{s.info1})" + end + end = f":{s.data})" + end # case "sc": case "scb": start += "[" - end = f"::{s.info1}]" + end + end = f"::{s.data}]" + end case "sca": start += "[" - end = f":{s.info1}]" + end + end = f":{s.data}]" + end # case "al": case "alo": - start += "[" + ("|" * int(s.info1 - 1)) - end = ("|" * int(s.info2 - s.info1)) + "]" + end + start += "[" + ("|" * int(s.data["pos"] - 1)) + end = ("|" * int(s.data["len"] - s.data["pos"])) + "]" + end content = start + negtag.content + end position = negtag.parameters or "s" if len(content) > 0: @@ -403,7 +425,6 @@ def __find_tags(self, prompt): """ add_at = {"start": [], "insertion_point": [[] for x in range(10)], "end": []} tree = self.__parser_complete.parse(prompt) - # self.logger.info(f"tree from prompt:\n{tree.pretty()}") readtree = self.STNTree(self, prompt, add_at) readtree.visit(tree) @@ -526,11 +547,12 @@ class TransformerTree(lark.visitors.Transformer_NonRecursive): __ppp (object): An instance of the parent class `ppp`. Methods: + promptcomp(tree): Replicates prompt composition constructs. scheduled(tree): Replicates or removes scheduling constructs based on conditions. alternate(tree): Replicates or removes alternation constructs based on conditions. emphasized(tree): Replicates or removes attention constructs based on conditions. deemphasized(tree): Replicates or removes attention constructs based on conditions. - modeltag(tree): Replicates model constructs. + extranetworktag(tree): Replicates extra network constructs. numpar(tree): Cleans up number parameter. negtag(tree): Replicates or removes negative tag constructs based on conditions. wildcard(tree): Replicates or removes wildcard constructs based on conditions. @@ -538,7 +560,6 @@ class TransformerTree(lark.visitors.Transformer_NonRecursive): choice(tree): Replicates choices. plain(tree): Cleans up plain text based on conditions. __default__(data, children, meta): Default method for joining children and cleaning up text based on conditions. - """ def __init__(self, ppp, phase="cleanup"): @@ -547,6 +568,27 @@ def __init__(self, ppp, phase="cleanup"): self.__phase = phase self.detectedWildcards = [] + def promptcomp(self, tree): + r = tree[0] + if len(tree) > 1: + if tree[1] is not None: + r += f":{tree[1]}" + for i in range(2, len(tree), 3): + if self.__phase == "cleanup" and self.__ppp.cup_ands: + r = re.sub(r"[, ]+$", " ", r) + if r[-1:].isalnum(): # add space if needed + r += " " + r += "AND" + t = tree[i + 1] + if self.__phase == "cleanup" and self.__ppp.cup_ands: + t = re.sub(r"^[, ]+", " ", t) + if t[0:1].isalnum(): # add space if needed + r += " " + r += t + if tree[i + 2] is not None: + r += f":{tree[i+2]}" + return r + def scheduled(self, tree): if len(tree) == 0 and self.__phase == "cleanup" and self.__ppp.cup_emptyconstructs: return "" # remove invalid scheduling construct (probably this is not reachable) @@ -572,8 +614,8 @@ def deemphasized(self, tree): return "" # remove empty attention construct (invalid scheduling or alternation constructs end up here too?) return f"[{tree[0]}]" # replicate attention construct - def modeltag(self, tree): - return f"<{tree[0]}>" # replicate model construct + def extranetworktag(self, tree): + return f"<{tree[0]}>" # replicate extra network construct def numpar(self, tree): return next(x for x in tree if x.type == "NUMBER").value.strip() # clean up number parameter @@ -610,7 +652,7 @@ def plain(self, tree): def __default__(self, data, children, meta): joined = "".join(children) # join all children if self.__phase == "cleanup": - # clean up joined text if there are no constructs to take care of cleaning the joints + # take care of cleaning the joints only if there are no constructs that can be affected if not re.match(r"[([<{]", joined): joined = self.__ppp.cleanup_text(joined) return joined @@ -648,7 +690,7 @@ def __cleanup(self, prompt, negative_prompt): def cleanup_text(self, text): """ - Cleans up the given text by removing extra separators, breaks, and spaces. + Cleans up the given text by removing extra separators, breaks, and spaces. This is called for plain text only or when there are no constructs. Args: text (str): The text to be cleaned up. @@ -656,24 +698,38 @@ def cleanup_text(self, text): Returns: str: The cleaned up text. """ + # NOTE: we can't use start/end of line regex since the text might only be a part of a larger line due to the parser if self.cup_extraseparators: + # # sendtonegative separator + # escapedSeparator = re.escape(self.stn_separator) + # collapse separators text = re.sub(r"(?:\s*" + escapedSeparator + r"\s*){2,}", self.stn_separator, text) + # # regular comma separator + # + # collapse separators text = re.sub(r"(?:\s*,\s*){2,}", ", ", text) if self.cup_breaks: + # collapse separators and commas before BREAK + text = re.sub(r"[, ]+BREAK\b", " BREAK", text) + # collapse separators and commas after BREAK + text = re.sub(r"\bBREAK[, ]+", "BREAK ", text) + # collapse separators and commas around BREAK text = re.sub(r"[, ]+BREAK[, ]+", " BREAK ", text) - text = re.sub(r"BREAK(?:\s+BREAK)+[ ]+", "BREAK ", text) - text = re.sub(r"[ ]+BREAK(?:\s+BREAK)+", " BREAK", text) + # collapse BREAKs + text = re.sub(r"\bBREAK(?:\s+BREAK)+\b", " BREAK ", text) if self.cup_extraspaces: - text = re.sub(r"[ ]+,", ",", text) # remove spaces before comma - text = re.sub(r"[ ]{2,}", " ", text) # collapse spaces + # remove spaces before comma + text = re.sub(r"[ ]+,", ",", text) + # collapse spaces + text = re.sub(r"[ ]{2,}", " ", text) return text def trim_text(self, text): """ - Trims the given text based on the specified cleanup options. + Trims the given text based on the specified cleanup options. This is only called for the reconstructed prompt. Args: text (str): The text to be trimmed. @@ -681,18 +737,69 @@ def trim_text(self, text): Returns: str: The trimmed text. """ + # NOTE: here we can only do cleanups that can be done on the whole text, including inside constructs and around them if self.cup_extraseparators: + # # sendtonegative separator + # escapedSeparator = re.escape(self.stn_separator) - text = re.sub(r"^(?:\s*" + escapedSeparator + r"\s*)", "", text) - text = re.sub(r"(?:\s*" + escapedSeparator + r"\s*)$", "", text) + # remove duplicate separator after starting parenthesis or bracket + text = re.sub(r"(\s*" + escapedSeparator + r"\s*[([])\s*" + escapedSeparator + r"\s*", r"\1", text) + # remove before colon or ending parenthesis or bracket + text = re.sub(r"\s*" + escapedSeparator + r"\s*([:)\]]\s*" + escapedSeparator + r"\s*)", r"\1", text) + # remove at start of prompt or line + text = re.sub(r"^(?:\s*" + escapedSeparator + r"\s*)", "", text, flags=re.MULTILINE) + # remove at end of prompt or line + text = re.sub(r"(?:\s*" + escapedSeparator + r"\s*)$", "", text, flags=re.MULTILINE) + # # regular comma separator - text = re.sub(r"^\s*,\s*", "", text) - text = re.sub(r"\s*,\s*$", "", text) + # + # remove duplicate separators after starting parenthesis or bracket + text = re.sub(r"(\s*,\s*[([])\s*,\s*", r"\1", text) + # remove duplicate separators before colon or ending parenthesis or bracket + text = re.sub(r"\s*,\s*([:)\]]\s*,\s*)", r"\1", text) + # remove at start of prompt or line + text = re.sub(r"^\s*,\s*", "", text, flags=re.MULTILINE) + # remove at end of prompt or line + text = re.sub(r"\s*,\s*$", "", text, flags=re.MULTILINE) if self.cup_breaks: - text = re.sub(r"^BREAK\s+", "", text) - text = re.sub(r"\s+BREAK$", "", text) + # remove spaces between start of line and BREAK + text = re.sub(r"^[ ]+BREAK\b", "BREAK", text, flags=re.MULTILINE) + # remove spaces between BREAK and end of line + text = re.sub(r"\bBREAK[ ]+$", "BREAK", text, flags=re.MULTILINE) + # remove at start of prompt + text = re.sub(r"\ABREAK\b", "", text) + # remove at end of prompt + text = re.sub(r"\bBREAK\Z", "", text) + if self.cup_ands: + # collapse ANDs with space after + text = re.sub(r"\bAND(?:\s+AND)+\s+", "AND ", text) + # collapse ANDs without space after + text = re.sub(r"\bAND(?:\s+AND)+\b", "AND", text) + # collapse separators and spaces before ANDs + text = re.sub(r"[, ]+AND\b", " AND", text) + # collapse separators and spaces after ANDs + text = re.sub(r"\bAND[, ]+", "AND ", text) + # remove at start of prompt + text = re.sub(r"\AAND\b", "", text) + # remove at end of prompt + text = re.sub(r"\bAND\Z", "", text) + if self.cup_extranetworktags: + # + # all cases since we can't find them inside plain text + # + # remove spaces before < + text = re.sub(r"\B\s+<(?!!)", "<", text) + # remove spaces after > + text = re.sub(r"(?\s+\B", ">", text) if self.cup_extraspaces: + # remove extra spaces after starting parenthesis or bracket + text = re.sub(r"([,\.;\s]+[([])\s+", r"\1", text) + # remove extra spaces before ending parenthesis or bracket + text = re.sub(r"\s+([)\]][,\.;\s]+)", r"\1", text) + # collapse spaces + # text = re.sub(r"[ ]{2,}", " ", text) + # remove spaces at start and end text = text.strip() return text @@ -714,7 +821,6 @@ def __findwildcards(self, prompt, negative_prompt): p_transformtree = self.TransformerTree(self, phase="wildcards") try: p_tree = self.__parser_complete.parse(prompt) - # self.logger.info(f"Wildcards tree from prompt:\n{p_tree.pretty()}") prompt = p_transformtree.transform(p_tree) except Exception as e: # pylint: disable=broad-except self.logger.warning("Wildcards parsing failed in prompt!: %s", e) @@ -722,7 +828,6 @@ def __findwildcards(self, prompt, negative_prompt): np_transformtree = self.TransformerTree(self, phase="wildcards") try: np_tree = self.__parser_complete.parse(negative_prompt) - # self.logger.info(f"Wildcards tree from negative prompt:\n{np_tree.pretty()}") negative_prompt = np_transformtree.transform(np_tree) except Exception as e: # pylint: disable=broad-except self.logger.warning("Wildcards parsing failed in negative prompt!: %s", e) @@ -745,7 +850,8 @@ def __findwildcards(self, prompt, negative_prompt): prompt = self.WILDCARD_STOP + prompt if foundNP: negative_prompt = self.WILDCARD_STOP + negative_prompt - self.script.ppp_interrupt() + if hasattr(self.script, "ppp_interrupt"): + self.script.ppp_interrupt() if self.debug: self.logger.info(f"prompt after wildcards: {self.formatOutput(prompt)}") self.logger.info(f"negative_prompt after wildcards: {self.formatOutput(negative_prompt)}") @@ -769,7 +875,11 @@ def process_prompt(self, original_prompt, original_negative_prompt): if not self.is_i2i or self.stn_doi2i or self.cup_doi2i: if self.debug: self.logger.info(f"Input prompt: {self.formatOutput(prompt)}") + p_tree = self.__parser_complete.parse(prompt) + self.logger.info(f"Tree from prompt:\n{p_tree.pretty()}") self.logger.info(f"Input negative_prompt: {self.formatOutput(negative_prompt)}") + np_tree = self.__parser_complete.parse(negative_prompt) + self.logger.info(f"Tree from negative prompt:\n{np_tree.pretty()}") if self.ifwildcards != self.IFWILDCARDS_CHOICES["ignore"]: prompt, negative_prompt = self.__findwildcards(prompt, negative_prompt) @@ -779,10 +889,14 @@ def process_prompt(self, original_prompt, original_negative_prompt): # pylint: disable-next=too-many-boolean-expressions if (not self.is_i2i or self.cup_doi2i) and ( - self.cup_extraspaces or self.cup_emptyconstructs or self.cup_extraseparators or self.cup_breaks + self.cup_extraspaces + or self.cup_emptyconstructs + or self.cup_extraseparators + or self.cup_breaks + or self.cup_ands + or self.cup_extranetworktags ): prompt, negative_prompt = self.__cleanup(prompt, negative_prompt) - return prompt, negative_prompt except Exception as e: # pylint: disable=broad-exception-caught self.logger.exception(e) diff --git a/scripts/ppp_script.py b/scripts/ppp_script.py index d14c569..b766237 100644 --- a/scripts/ppp_script.py +++ b/scripts/ppp_script.py @@ -194,14 +194,6 @@ def __on_ui_settings(self): section=section, ), ) - shared.opts.add_option( - key="ppp_cup_extraspaces", - info=shared.OptionInfo( - True, - label="Remove extra spaces", - section=section, - ), - ) shared.opts.add_option( key="ppp_cup_emptyconstructs", info=shared.OptionInfo( @@ -226,3 +218,27 @@ def __on_ui_settings(self): section=section, ), ) + shared.opts.add_option( + key="ppp_cup_ands", + info=shared.OptionInfo( + True, + label="Clean up around ANDs", + section=section, + ), + ) + shared.opts.add_option( + key="ppp_cup_extranetworktags", + info=shared.OptionInfo( + False, + label="Clean up around extra network tags", + section=section, + ), + ) + shared.opts.add_option( + key="ppp_cup_extraspaces", + info=shared.OptionInfo( + True, + label="Remove extra spaces", + section=section, + ), + ) diff --git a/tests/tests.py b/tests/tests.py index 77bd99b..4cd4fca 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -49,12 +49,29 @@ def setUp(self): "ppp_cup_extraseparators": True, "ppp_cup_extraspaces": True, "ppp_cup_breaks": True, + "ppp_cup_ands": True, + "ppp_cup_extranetworktags": True, + } + ) + self.__nocupopts = DictToObj( + { + "ppp_gen_debug": True, + "ppp_gen_ifwildcards": PromptPostProcessor.IFWILDCARDS_CHOICES["ignore"], + "ppp_stn_doi2i": False, + "ppp_stn_separator": ", ", + "ppp_stn_ignore_repeats": True, + "ppp_stn_join_attention": True, + "ppp_cup_doi2i": False, + "ppp_cup_emptyconstructs": False, + "ppp_cup_extraseparators": False, + "ppp_cup_extraspaces": False, + "ppp_cup_breaks": False, + "ppp_cup_ands": False, + "ppp_cup_extranetworktags": False, } ) self.defppp = PromptPostProcessor(self, self.__defopts) - - def ppp_interrupt(self): - pass # fake interrupt + self.nocupppp = PromptPostProcessor(self, self.__nocupopts) def process( self, @@ -88,131 +105,40 @@ def process( # Send To Negative tests - def test_tag_default(self): # negtag with no parameters - self.process( - "flowers", - "normal quality, worse quality", - "flowers", - "red, normal quality, worse quality", - ) - - def test_tag_start(self): # negtag with s parameter - self.process( - "flowers", - "normal quality, worse quality", - "flowers", - "red, normal quality, worse quality", - ) - - def test_tag_end(self): # negtag with e parameter + def test_nt_simple(self): # negtags with different parameters and separations self.process( - "flowers", - "normal quality, worse quality", + "flowers, , , ", + "normal quality, worse quality", "flowers", - "normal quality, worse quality, red", + "red, green, yellow, normal quality, purple, worse quality, black, blue", ) - def test_tag_insertion_mid_sep(self): # negtag with p parameter and insertion in the middle + def test_nt_complex(self): # complex negtags self.process( - "flowers", - "normal quality, , worse quality", - "flowers", - "normal quality, red, worse quality", - ) - - def test_tag_insertion_mid_no_sep(self): # negtag with p parameter and insertion in the middle without separator - self.process( - "flowers", - "normal qualityworse quality", - "flowers", - "normal quality, red, worse quality", - ) - - def test_tag_insertion_start_sep(self): # negtag with p parameter and insertion at the start - self.process( - "flowers", - ", normal quality, worse quality", - "flowers", - "red, normal quality, worse quality", - ) - - def test_tag_insertion_start_no_sep(self): # negtag with p parameter and insertion at the start without separator - self.process( - "flowers", - "normal quality, worse quality", - "flowers", - "red, normal quality, worse quality", - ) - - def test_tag_insertion_end_sep(self): # negtag with p parameter and insertion at the end - self.process( - "flowers", - "normal quality, worse quality, ", - "flowers", - "normal quality, worse quality, red", - ) - - def test_tag_insertion_end_no_sep(self): # negtag with p parameter and insertion at the end without separator - self.process( - "flowers", - "normal quality, worse quality", - "flowers", - "normal quality, worse quality, red", - ) - - def test_complex(self): # complex negtags - self.process( - " (), flowers , , ", + " (()), flowers , , ", "normal quality, , bad quality, worse quality", "flowers", - "red, (pink), normal quality, yellow, bad quality, green, worse quality, purple, blue", + "red, (pink:1.21), normal quality, mauve, yellow, bad quality, green, worse quality, purple, blue", ) - def test_complex_no_cleanup(self): # complex negtags with no cleanup + def test_nt_complex_nocleanup(self): # complex negtags with no cleanup self.process( - " (), flowers , , ", + " (()), flowers , , ", "normal quality, , bad quality, worse quality", - " (), flowers , , ", - "red, (pink), normal quality, yellow, bad quality, green, worse quality, purple, blue", - PromptPostProcessor( - self, - DictToObj( - { - **self.__defopts.__dict__, - "ppp_cup_emptyconstructs": False, - "ppp_cup_extraseparators": False, - "ppp_cup_extraspaces": False, - "ppp_cup_breaks": False, - } - ), - ), + " (()), flowers , , ", + "red, (pink:1.21), normal quality, mauve, yellow, bad quality, green, worse quality, purple, blue", + self.nocupppp, ) - def test_inside_attention1(self): # negtag inside attention + def test_nt_inside_attention(self): # negtag inside attention self.process( - "[] this is a ((test) (test:2.0): 1.5 )", + "[] this is a ((test) (test:2.0): 1.5 ) (red:1.5)", "normal quality", - "this is a ((test) (test:2.0):1.5)", - "[neg1], normal quality, (neg2:1.65)", + "this is a ((test) (test:2.0):1.5) (red:1.5)", + "[neg1], ([square]:1.5), normal quality, (neg2:1.65)", ) - def test_inside_attention2(self): # negtag inside attention - self.process( - "(red:1.5)", - "", - "(red:1.5)", - "([square]:1.5)", - ) - - def test_inside_alternation1(self): # negtag inside alternation - self.process( - "this is a (([complex|simple|regular] test)(test:2.0):1.5)", - "normal quality", - "this is a (([complex|simple|regular] test)(test:2.0):1.5)", - "([|neg1|]:1.65), normal quality", - ) - - def test_inside_alternation2(self): # negtag inside alternation + def test_nt_inside_alternation(self): # negtag inside alternation self.process( "this is a (([complex|simple|regular] test)(test:2.0):1.5)", "normal quality", @@ -220,7 +146,7 @@ def test_inside_alternation2(self): # negtag inside alternation "([neg1||]:1.65), ([|neg2|]:1.65), ([||neg3]:1.65), normal quality", ) - def test_inside_alternation3(self): # negtag inside alternation (recursive alternation) + def test_nt_inside_alternation_recursive(self): # negtag inside alternation (recursive alternation) self.process( "this is a (([complex[one|two||three|four()]|simple|regular] test)(test:2.0):1.5)", "normal quality", @@ -228,7 +154,7 @@ def test_inside_alternation3(self): # negtag inside alternation (recursive alte "([neg1||]:1.65), ([[|neg12|||]||]:1.65), ([[||||(neg14)]||]:1.65), ([|neg2|]:1.65), ([||neg3]:1.65), normal quality", ) - def test_inside_scheduling(self): # negtag inside scheduling + def test_nt_inside_scheduling(self): # negtag inside scheduling self.process( "this is [abc:def: 5 ]", "normal quality", @@ -236,17 +162,17 @@ def test_inside_scheduling(self): # negtag inside scheduling "[neg1::5], normal quality, [neg2:5]", ) - def test_complex_features(self): # complex negtags with features + def test_nt_complex_features(self): # complex negtags with AND, BREAK and other features self.process( - "[] this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK, BREAK with [abc:def:5] ", + "[] this \\(is\\): a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK, BREAK with [abc:def:5]:0.5 AND loraword AND AND hypernetword :0.3", "normal quality, ", - "this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5] ", + "this \\(is\\): a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5]:0.5 AND loraword AND hypernetword :0.3", "[neg5], ([|neg6|]:1.65), (neg1:1.65), [neg4::5], normal quality, [neg2(neg3:1.6):5]", ) # Wildcard tests - def test_wildcards_ignore(self): # wildcards with ignore option + def test_wc_ignore(self): # wildcards with ignore option self.process( "__bad_wildcard__", "{option1|option2}", @@ -263,11 +189,11 @@ def test_wildcards_ignore(self): # wildcards with ignore option ), ) - def test_wildcards_remove(self): # wildcards with remove option + def test_wc_remove(self): # wildcards with remove option self.process( "[] this is: __bad_wildcard__ a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK, BREAK with [abc:def:5] ", "normal quality, {option1|option2}", - "this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5] ", + "this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5]", "[neg5], ([|neg6|]:1.65), (neg1:1.65), [neg4::5], normal quality, [neg2(neg3:1.6):5]", PromptPostProcessor( self, @@ -280,7 +206,7 @@ def test_wildcards_remove(self): # wildcards with remove option ), ) - def test_wildcards_warn(self): # wildcards with warn option + def test_wc_warn(self): # wildcards with warn option self.process( "__bad_wildcard__", "{option1|option2}", @@ -297,7 +223,7 @@ def test_wildcards_warn(self): # wildcards with warn option ), ) - def test_wildcards_stop(self): # wildcards with stop option + def test_wc_stop(self): # wildcards with stop option self.process( "__bad_wildcard__", "{option1|option2}", @@ -314,6 +240,24 @@ def test_wildcards_stop(self): # wildcards with stop option ), ) + # Cleanup tests + + def test_cl_simple(self): # simple cleanup + self.process( + " this is a ((test ), , () [] ( , test ,:2.0):1.5) (red:1.5) ", + " normal quality ", + "this is a ((test), (test,:2.0):1.5) (red:1.5)", + "normal quality", + ) + + def test_cl_complex(self): # complex cleanup + self.process( + " this is BREAKABLE a ((test), ,AND AND() [] ANDERSON (test:2.0):1.5) :o BREAK \n BREAK (red:1.5) ", + " [:hands, feet, :0.15]normal quality ", + "this is BREAKABLE a ((test) AND ANDERSON (test:2.0):1.5) :o BREAK (red:1.5)", + "[:hands, feet, :0.15]normal quality", + ) + if __name__ == "__main__": unittest.main()