diff --git a/README.md b/README.md index 8bf2860..81a4358 100644 --- a/README.md +++ b/README.md @@ -67,14 +67,34 @@ that part to the negative prompt. ## Configuration -The extension settings allow you to change the format of the tag in case there -is some incompatibility with another extension. +Separator used when adding to the negative prompt: You can specify the separator used when adding to the negative prompt (by default it's ", "). -You can also specify the separator added to the negative prompt which by -default is ", ". +Ignore tags with repeated content: by default it ignores repeated content to avoid repetitions in the negative prompt. -By default it ignores repeated content and also tries to clean up the prompt -after removing the tags, but these can also be changed in the settings. +Join attention modifiers (weights) when possible: by default it joins attention modifiers when possible (joins into one, multipliying their values). + +Try to clean-up the prompt after processing: by default cleans up the positive prompt after processing, removing extra spaces and separators. + +## Notes + +The content of the negative tags is not processed and is copied as is to the negative prompt. Other modifiers around the tags are processed in the following way. + +### Attention modifiers (weights) + +They will be translated to the negative prompt. For example: + +* `(red:1.5)` will end up as `(square:1.5)` in the negative prompt +* `(red[]:1.5)` will end up as `(square:1.35)` in the negative prompt (weight=1.5*0.9) +* However `(red:1.5)` will end up as `([square]:1.5)` in the negative prompt. The content of the negative tag is copied as is, and not joined with the surrounding modifier. + +### Prompt editing constructs (alternation and scheduling) + +Negative tags inside such constructs will copy the construct to the negative prompt, but separating its elements. For example: + +* Alternation: `[red|blue]` will end up as `[square|], [|circle]` in the negative prompt, instead of `[square|circle]` +* Scheduling: `[red:blue:0.5]` will end up as `[square::0.5], [:circle:0.5]` instead of `[square:circle:0.5]` + +This should still work as intended, and the only negative point i see is the unnecessary separators. ## License diff --git a/scripts/extension.py b/scripts/extension.py index 9bae103..8ba3999 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -36,38 +36,6 @@ def process(self, p: StableDiffusionProcessing, *args, **kwargs): def __on_ui_settings(self): section = ("send-to-negative", SendToNegative.NAME) - shared.opts.add_option( - key="stn_tagstart", - info=shared.OptionInfo( - SendToNegative.DEFAULT_TAG_START, - label="Tag start", - section=section, - ), - ) - shared.opts.add_option( - key="stn_tagend", - info=shared.OptionInfo( - SendToNegative.DEFAULT_TAG_END, - label="Tag end", - section=section, - ), - ) - shared.opts.add_option( - key="stn_tagparamstart", - info=shared.OptionInfo( - SendToNegative.DEFAULT_TAG_PARAM_START, - label="Tag parameter start", - section=section, - ), - ) - shared.opts.add_option( - key="stn_tagparamend", - info=shared.OptionInfo( - SendToNegative.DEFAULT_TAG_PARAM_END, - label="Tag parameter end", - section=section, - ), - ) shared.opts.add_option( key="stn_separator", info=shared.OptionInfo( @@ -84,6 +52,14 @@ def __on_ui_settings(self): section=section, ), ) + shared.opts.add_option( + key="stn_joinattention", + info=shared.OptionInfo( + True, + label="Join attention modifiers (weights) when possible", + section=section, + ), + ) shared.opts.add_option( key="stn_cleanup", info=shared.OptionInfo( diff --git a/sendtonegative.py b/sendtonegative.py index b232920..70be69c 100644 --- a/sendtonegative.py +++ b/sendtonegative.py @@ -1,30 +1,25 @@ +from collections import namedtuple import logging +import lark import re class SendToNegative: # pylint: disable=too-few-public-methods NAME = "Send to Negative" - VERSION = "1.1" + VERSION = "2.0" - DEFAULT_TAG_START = "" - DEFAULT_TAG_PARAM_START = "!" - DEFAULT_TAG_PARAM_END = "!" DEFAULT_SEPARATOR = ", " def __init__( self, - tag_start=None, - tag_end=None, - tag_param_start=None, - tag_param_end=None, separator=None, ignore_repeats=None, + join_attention=None, cleanup=None, opts=None, ): """ - Default format for the tag: + Format for the tag: @@ -39,39 +34,16 @@ def __init__( iN - tags the position of insertion point N. Used only in the negative prompt and does not accept content. N can be 0 to 9. """ self.__logger = logging.getLogger(__name__) - - str_start = ( - tag_start - if tag_start is not None - else getattr(opts, "stn_tagstart", self.DEFAULT_TAG_START) - if opts is not None - else self.DEFAULT_TAG_START - ) - str_end = ( - tag_end - if tag_end is not None - else getattr(opts, "stn_tagend", self.DEFAULT_TAG_END) - if opts is not None - else self.DEFAULT_TAG_END - ) - str_param_start = ( - tag_param_start - if tag_param_start is not None - else getattr(opts, "stn_tagparamstart", self.DEFAULT_TAG_PARAM_START) - if opts is not None - else self.DEFAULT_TAG_PARAM_START - ) - str_param_end = ( - tag_param_end - if tag_param_end is not None - else getattr(opts, "stn_tagparamend", self.DEFAULT_TAG_PARAM_END) - if opts is not None - else self.DEFAULT_TAG_PARAM_END - ) - escape_sequence = r"(?" for x in range(10)] + # Process with lark (debug with https://www.lark-parser.org/ide/) + self.__schedule_parser = lark.Lark( + r""" + start: (prompt | /[\][():|<>!]/+)* + ?prompt: (emphasized | deemphasized | scheduled | alternate | modeltag | negtag | plain)* + ?nonegprompt: (emphasized | deemphasized | scheduled | alternate | modeltag | plain)* + emphasized: "(" prompt [":" numpar] ")" + deemphasized: "[" prompt "]" + scheduled: "[" [prompt ":"] prompt ":" numpar "]" + alternate: "[" alternateoption ("|" alternateoption)+ "]" + alternateoption: prompt + negtag: "" + negtagparameters: "!" /s|e|[ip]\d/ "!" + modeltag: "<" /(?!!)[^>]+/ ">" + numpar: WHITESPACE* NUMBER WHITESPACE* + WHITESPACE: /\s+/ + ?plain: /([^\\[\]():|<>!]|\\.)+/s + %import common.SIGNED_NUMBER -> NUMBER + """, + propagate_positions=True, ) + class ReadTree(lark.visitors.Interpreter): + def __init__(self, logger, ignorerepeats, joinattention, prompt, add_at): + super().__init__() + self.__logger = logger + self.__ignore_repeats = ignorerepeats + self.__join_attention = joinattention + self.__prompt = prompt + self.AccumulatedShell = namedtuple("AccumulatedShell", ["type", "info1", "info2"]) + AccumulatedShell = self.AccumulatedShell + self.__shell: list[AccumulatedShell] = [] + self.NegTag = namedtuple("NegTag", ["start", "end", "content", "parameters", "shell"]) + NegTag = self.NegTag + self.__negtags: list[NegTag] = [] + self.__already_processed = [] + self.add_at = add_at + self.remove = [] + + def scheduled(self, tree): + if len(tree.children) > 2: # before & after + before = tree.children[0] + else: + before = None + after = tree.children[-2] + numpar = tree.children[-1] + pos = float(numpar.children[0].value) + if pos >= 1: + pos = int(pos) + # self.__shell.append(self.AccumulatedShell("sc", tree.meta.start_pos, pos)) + if before is not None and hasattr(before, "data"): + self.__logger.debug( + f"Shell scheduled before at {[before.meta.start_pos,before.meta.end_pos] if hasattr(before,'meta') else '?'} : {pos}" + ) + self.__shell.append(self.AccumulatedShell("scb", pos, None)) + self.visit(before) + self.__shell.pop() + if hasattr(after, "data"): + self.__logger.debug( + f"Shell scheduled after at {[after.meta.start_pos,after.meta.end_pos] if hasattr(after,'meta') else '?'} : {pos}" + ) + self.__shell.append(self.AccumulatedShell("sca", pos, None)) + self.visit(after) + self.__shell.pop() + # self.__shell.pop() + + def alternate(self, tree): + # self.__shell.append(self.AccumulatedShell("al", tree.meta.start_pos, len(tree.children))) + for i, opt in enumerate(tree.children): + self.__logger.debug( + f"Shell alternate at {[opt.meta.start_pos,opt.meta.end_pos] if hasattr(opt,'meta') else '?'} : {i+1}" + ) + if hasattr(opt, "data"): + self.__shell.append(self.AccumulatedShell("alo", i + 1, len(tree.children))) + self.visit(opt) + self.__shell.pop() + # self.__shell.pop() + + def emphasized(self, tree): + numpar = tree.children[-1] + weight = float(numpar.children[0].value) if numpar is not None else 1.1 + self.__logger.debug( + f"Shell attention at {[tree.meta.start_pos,tree.meta.end_pos] if hasattr(tree,'meta') else '?'}: {weight}" + ) + self.__shell.append(self.AccumulatedShell("at", weight, None)) + self.visit_children(tree) + self.__shell.pop() + + def deemphasized(self, tree): + weight = 0.9 + self.__logger.debug( + f"Shell attention at {[tree.meta.start_pos,tree.meta.end_pos] if hasattr(tree,'meta') else '?'}: {weight}" + ) + self.__shell.append(self.AccumulatedShell("at", weight, None)) + self.visit_children(tree) + self.__shell.pop() + + def negtag(self, tree): + negtagparameters = tree.children[0] + parameters = negtagparameters.children[0].value if negtagparameters is not None else "" + rest = [] + for x in tree.children[1::]: + rest.append(self.__prompt[x.meta.start_pos : x.meta.end_pos] if hasattr(x, "meta") else x.value) + content = "".join(rest) + self.__negtags.append( + self.NegTag(tree.meta.start_pos, tree.meta.end_pos, content, parameters, self.__shell.copy()) + ) + self.__logger.debug( + f"Negative tag at {[tree.meta.start_pos,tree.meta.end_pos] if hasattr(tree,'meta') else '?'}: {parameters}: {content.encode('unicode_escape').decode('utf-8')}" + ) + + def start(self, tree): + self.visit_children(tree) + # process the found negtags + for nt in self.__negtags: + if self.__join_attention: + # join consecutive attention elements + for i in range(len(nt.shell) - 1, 0, -1): + if nt.shell[i].type == "at" and nt.shell[i - 1].type == "at": + nt.shell[i - 1] = self.AccumulatedShell( + "at", + (100 * nt.shell[i - 1].info1 * nt.shell[i].info1) / 100, # we limit to two decimals + None, + ) + nt.shell.pop(i) + start = "" + end = "" + for s in nt.shell: + match s.type: + case "at": + if s.info1 == 0.9: + start += "[" + end = "]" + end + elif s.info1 == 1.1: + start += "(" + end = ")" + end + else: + start += "(" + end = f":{s.info1})" + end + #case "sc": + case "scb": + start += "[" + end = f"::{s.info1}]" + end + case "sca": + start += "[" + end = f":{s.info1}]" + end + #case "al": + case "alo": + start += "[" + ("|" * int(s.info1 - 1)) + end = ("|" * int(s.info2 - s.info1)) + "]" + end + content = start + nt.content + end + position = nt.parameters or "s" + if len(content) > 0: + if content not in self.__already_processed: + if self.__ignore_repeats: + self.__already_processed.append(content) + self.__logger.debug( + f"Adding content at position {position}: {content.encode('unicode_escape').decode('utf-8')}" + ) + if position == "e": + self.add_at["end"].append(content) + elif position.startswith("p"): + n = int(position[1]) + self.add_at["insertion_point"][n].append(content) + else: # position == "s" or invalid + self.add_at["start"].append(content) + else: + self.__logger.warning( + f"Ignoring repeated content: {content.encode('unicode_escape').decode('utf-8')}" + ) + # remove from prompt + self.remove.append([nt.start, nt.end]) + def process_prompt(self, original_prompt, original_negative_prompt): """ Extract from the prompt the tagged parts and add them to the negative prompt @@ -107,54 +235,48 @@ def process_prompt(self, original_prompt, original_negative_prompt): try: prompt = original_prompt negative_prompt = original_negative_prompt - self.__logger.debug(f"Input prompt: {prompt}") - self.__logger.debug(f"Input negative_prompt: {negative_prompt}") + self.__logger.debug(f"Input prompt: {prompt.encode('unicode_escape').decode('utf-8')}") + self.__logger.debug(f"Input negative_prompt: {negative_prompt.encode('unicode_escape').decode('utf-8')}") prompt, add_at = self.__find_tags(prompt) negative_prompt = self.__add_to_insertion_points(negative_prompt, add_at["insertion_point"]) if len(add_at["start"]) > 0: negative_prompt = self.__add_to_start(negative_prompt, add_at["start"]) if len(add_at["end"]) > 0: negative_prompt = self.__add_to_end(negative_prompt, add_at["end"]) - self.__logger.debug(f"Output prompt: {prompt}") - self.__logger.debug(f"Output negative_prompt: {negative_prompt}") + self.__logger.debug(f"Output prompt: {prompt.encode('unicode_escape').decode('utf-8')}") + self.__logger.debug(f"Output negative_prompt: {negative_prompt.encode('unicode_escape').decode('utf-8')}") return prompt, negative_prompt except Exception as e: # pylint: disable=broad-exception-caught self.__logger.exception(e) return original_prompt, original_negative_prompt def __find_tags(self, prompt): - already_processed = [] add_at = {"start": [], "insertion_point": [[] for x in range(10)], "end": []} - # process tags in prompt - matches = self.__regex.findall(prompt) - for match in matches: - position = match[1] or "s" - content = match[2] - if len(content) > 0: - if content not in already_processed: - if self.__ignore_repeats: - already_processed.append(content) - self.__logger.debug(f"Processing content at position {position}: {content}") - if position == "e": - add_at["end"].append(content) - elif position.startswith("p"): - n = int(position[1]) - add_at["insertion_point"][n].append(content) - else: # position == "s" or invalid - add_at["start"].append(content) - else: - self.__logger.warning(f"Ignoring repeated content: {content}") - # clean-up - prompt = prompt.replace(match[0], "") - if self.__cleanup: - prompt = ( - prompt.replace(" ", " ") - .replace(self.__separator + self.__separator, self.__separator) - .replace(" " + self.__separator, self.__separator) - .removeprefix(self.__separator) - .removesuffix(self.__separator) - .strip() - ) + tree = self.__schedule_parser.parse(prompt) + self.__logger.debug(f"Initial tree: {tree.pretty()}") + + readtree = self.ReadTree(self.__logger, self.__ignore_repeats, self.__join_attention, prompt, add_at) + readtree.visit(tree) + + for r in readtree.remove[::-1]: + prompt = prompt[: r[0]] + prompt[r[1] :] + if self.__cleanup: + prompt = re.sub(r"\((?::[\d\.]+)?\)", "", prompt) # clean up empty attention + prompt = re.sub(r"\[\]", "", prompt) # clean up empty attention + prompt = re.sub(r"\[:?:[\d\.]+\]", "", prompt) # clean up empty scheduling + prompt = re.sub(r"\[\|+\]", "", prompt) # clean up empty alternation + # clean up whitespace and extra separators + prompt = ( + prompt.replace(" ", " ") + .replace(self.__separator + self.__separator, self.__separator) + .replace(" " + self.__separator, self.__separator) + .removeprefix(self.__separator) + .removesuffix(self.__separator) + .strip() + ) + add_at = readtree.add_at + self.__logger.debug(f"New negative additions: {add_at}") + return prompt, add_at def __add_to_insertion_points(self, negative_prompt, add_at_insertion_point): diff --git a/tests/tests.py b/tests/tests.py index 3d72ee2..b9ef4b7 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,3 +1,4 @@ +import logging import unittest import sys import os @@ -10,14 +11,12 @@ class TestSendToNegative(unittest.TestCase): def setUp(self): self.defstn = SendToNegative( - tag_start="", - tag_param_start="!", - tag_param_end="!", separator=", ", ignore_repeats=True, + join_attention=True, cleanup=True, ) + logging.basicConfig(level=logging.DEBUG) def process( self, @@ -110,29 +109,82 @@ def test_tag_insertion_end_no_sep(self): def test_complex(self): self.process( - " , flowers , , ", + " (), flowers , , ", "normal quality, , bad quality, worse quality", "flowers", - "red, pink, normal quality, yellow, bad quality, green, worse quality, purple, blue", + "red, (pink), normal quality, yellow, bad quality, green, worse quality, purple, blue", ) def test_complex_no_cleanup(self): self.process( - " , flowers , , ", + " (), flowers , , ", "normal quality, , bad quality, worse quality", - " , flowers , , ", - "red, pink, normal quality, yellow, bad quality, green, worse quality, purple, blue", + " (), flowers , , ", + "red, (pink), normal quality, yellow, bad quality, green, worse quality, purple, blue", SendToNegative( - tag_start="", - tag_param_start="!", - tag_param_end="!", separator=", ", ignore_repeats=True, + join_attention=True, cleanup=False, ), ) + def test_inside_attention1(self): + self.process( + "[] this is a ((test) (test:2.0):1.5)", + "normal quality", + "this is a ((test) (test:2.0):1.5)", + "[neg1], normal quality, (neg2:1.65)", + ) + + def test_inside_attention2(self): + self.process( + "(red:1.5)", + "", + "(red:1.5)", + "([square]:1.5)", + ) + + def test_inside_alternation1(self): + self.process( + "this is a (([complex|simple|regular] test)(test:2.0):1.5)", + "normal quality", + "this is a (([complex|simple|regular] test)(test:2.0):1.5)", + "([|neg1|]:1.65), normal quality", + ) + + def test_inside_alternation2(self): + self.process( + "this is a (([complex|simple|regular] test)(test:2.0):1.5)", + "normal quality", + "this is a (([complex|simple|regular] test)(test:2.0):1.5)", + "([neg1||]:1.65), ([|neg2|]:1.65), ([||neg3]:1.65), normal quality", + ) + + def test_inside_alternation3(self): + self.process( + "this is a (([complex[one|two|three|four()]|simple|regular] test)(test:2.0):1.5)", + "normal quality", + "this is a (([complex[one|two|three|four]|simple|regular] test)(test:2.0):1.5)", + "([neg1||]:1.65), ([[|neg12||]||]:1.65), ([[|||(neg14)]||]:1.65), ([|neg2|]:1.65), ([||neg3]:1.65), normal quality", + ) + + def test_inside_scheduling(self): + self.process( + "this is [abc:def:5]", + "normal quality", + "this is [abc:def:5]", + "[neg1::5], normal quality, [neg2:5]", + ) + + def test_complex_features(self): + self.process( + "[] this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5] ", + "normal quality, ", + "this is: a (([complex|simple|regular] test)(test:2.0):1.5) \nBREAK with [abc:def:5] ", + "[neg5], ([|neg6|]:1.65), (neg1:1.65), [neg4::5], normal quality, [neg2(neg3:1.6):5]", + ) + if __name__ == "__main__": unittest.main()