Skip to content

Commit

Permalink
💥 version 0.4.0
Browse files Browse the repository at this point in the history
separate MessageSegment.entity
  • Loading branch information
RF-Tar-Railt committed Oct 13, 2023
1 parent 01f4d09 commit 15dfdba
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 65 deletions.
4 changes: 3 additions & 1 deletion nonebot/adapters/satori/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _check_reply(
except ValueError:
return

event.to_me = True
msg_seg = message[index]

event.reply = msg_seg # type: ignore
Expand Down Expand Up @@ -80,6 +81,7 @@ def _is_at_me_seg(segment: MessageSegment) -> bool:
deleted = False
if _is_at_me_seg(message[0]):
message.pop(0)
event.to_me = True
deleted = True
if message and message[0].type == "text":
message[0].data["text"] = message[0].data["text"].lstrip("\xa0").lstrip()
Expand All @@ -99,7 +101,7 @@ def _is_at_me_seg(segment: MessageSegment) -> bool:
last_msg_seg = message[i]

if _is_at_me_seg(last_msg_seg):
deleted = True
event.to_me = True
del message[i:]

if not message:
Expand Down
246 changes: 183 additions & 63 deletions nonebot/adapters/satori/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def get_message_class(cls) -> Type["Message"]:
def text(content: str) -> "Text":
return Text("text", {"text": content})

@staticmethod
def entity(content: str, style: str) -> "Entity":
return Entity("entity", {"text": content, "style": style})

@staticmethod
def at(
user_id: str,
Expand Down Expand Up @@ -68,7 +64,7 @@ def sharp(channel_id: str, name: Optional[str] = None) -> "Sharp":

@staticmethod
def link(href: str) -> "Link":
return Link("link", {"href": href})
return Link("link", {"text": href})

@staticmethod
def image(
Expand Down Expand Up @@ -114,9 +110,41 @@ def file(
data["timeout"] = timeout
return File("file", data)

@staticmethod
def bold(text: str) -> "Bold":
return Bold("bold", {"text": text})

@staticmethod
def italic(text: str) -> "Italic":
return Italic("italic", {"text": text})

@staticmethod
def underline(text: str) -> "Underline":
return Underline("underline", {"text": text})

@staticmethod
def strikethrough(text: str) -> "Strikethrough":
return Strikethrough("strikethrough", {"text": text})

@staticmethod
def spoiler(text: str) -> "Spoiler":
return Spoiler("spoiler", {"text": text})

@staticmethod
def code(text: str) -> "Code":
return Code("code", {"text": text})

@staticmethod
def superscript(text: str) -> "Superscript":
return Superscript("superscript", {"text": text})

@staticmethod
def subscript(text: str) -> "Subscript":
return Subscript("subscript", {"text": text})

@staticmethod
def br() -> "Br":
return Br("br", {})
return Br("br", {"text": "\n"})

@staticmethod
def paragraph(text: str) -> "Paragraph":
Expand Down Expand Up @@ -165,7 +193,7 @@ def author(

@override
def is_text(self) -> bool:
return self.type == "text"
return False


class TextData(TypedDict):
Expand All @@ -180,20 +208,9 @@ class Text(MessageSegment):
def __str__(self) -> str:
return escape(self.data["text"])


class EntityData(TypedDict):
text: str
style: str


@dataclass
class Entity(MessageSegment):
data: EntityData = field(default_factory=dict)

@override
def __str__(self) -> str:
style = self.data["style"]
return f'<{style}>{escape(self.data["text"])}</{style}>'
def is_text(self) -> bool:
return True


class AtData(TypedDict):
Expand All @@ -218,17 +235,17 @@ class Sharp(MessageSegment):
data: SharpData = field(default_factory=dict)


class LinkData(TypedDict):
href: str


@dataclass
class Link(MessageSegment):
data: LinkData = field(default_factory=dict)
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<a href="{escape(self.data["href"])}"/>'
return f'<a href="{escape(self.data["text"])}"/>'

@override
def is_text(self) -> bool:
return True


class ImageData(TypedDict):
Expand Down Expand Up @@ -277,25 +294,135 @@ class File(MessageSegment):
data: FileData = field(default_factory=dict)


@dataclass
class Bold(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<b>{escape(self.data["text"])}</b>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Italic(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<i>{escape(self.data["text"])}</i>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Underline(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<u>{escape(self.data["text"])}</u>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Strikethrough(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<s>{escape(self.data["text"])}</s>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Spoiler(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<spl>{escape(self.data["text"])}</spl>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Code(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<code>{escape(self.data["text"])}</code>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Superscript(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<sup>{escape(self.data["text"])}</sup>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Subscript(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<sub>{escape(self.data["text"])}</sub>'

@override
def is_text(self) -> bool:
return True


@dataclass
class Br(MessageSegment):
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return "<br/>"


class ParagraphData(TypedDict):
text: str
@override
def is_text(self) -> bool:
return True


@dataclass
class Paragraph(MessageSegment):
data: ParagraphData = field(default_factory=dict)
data: TextData = field(default_factory=dict)

@override
def __str__(self):
return f'<p>{escape(self.data["text"])}</p>'

@override
def is_text(self) -> bool:
return True


class RenderMessageData(TypedDict):
id: NotRequired[str]
Expand Down Expand Up @@ -335,17 +462,30 @@ class Author(MessageSegment):
"text": (Text, "text"),
"at": (At, "at"),
"sharp": (Sharp, "sharp"),
"a": (Link, "link"),
"link": (Link, "link"),
"img": (Image, "img"),
"image": (Image, "img"),
"audio": (Audio, "audio"),
"video": (Video, "video"),
"file": (File, "file"),
"br": (Br, "br"),
"author": (Author, "author"),
}

STYLE_TYPE_MAP = {
"b": (Bold, "bold"),
"strong": (Bold, "bold"),
"i": (Italic, "italic"),
"em": (Italic, "italic"),
"u": (Underline, "underline"),
"ins": (Underline, "underline"),
"s": (Strikethrough, "strikethrough"),
"del": (Strikethrough, "strikethrough"),
"spl": (Spoiler, "spoiler"),
"code": (Code, "code"),
"sup": (Superscript, "superscript"),
"sub": (Subscript, "subscript"),
"p": (Paragraph, "paragraph"),
}


class Message(BaseMessage[MessageSegment]):
@classmethod
Expand Down Expand Up @@ -381,30 +521,13 @@ def from_satori_element(cls, elements: List[Element]) -> "Message":
if elem.type in ELEMENT_TYPE_MAP:
seg_cls, seg_type = ELEMENT_TYPE_MAP[elem.type]
msg.append(seg_cls(seg_type, elem.attrs.copy()))
elif elem.type in {
"b",
"strong",
"i",
"em",
"u",
"ins",
"s",
"del",
"spl",
"code",
"sup",
"sub",
}:
msg.append(
Entity(
"entity",
{"text": elem.children[0].attrs["text"], "style": elem.type},
)
)
elif elem.type in ("p", "paragraph"):
msg.append(
Paragraph("paragraph", {"text": elem.children[0].attrs["text"]})
)
elif elem.type in ("a", "link"):
msg.append(Link("link", {"text": elem.attrs["href"]}))
elif elem.type in STYLE_TYPE_MAP:
seg_cls, seg_type = STYLE_TYPE_MAP[elem.type]
msg.append(seg_cls(seg_type, {"text": elem.children[0].attrs["text"]}))
elif elem.type in ("br", "newline"):
msg.append(Br("br", {"text": "\n"}))
elif elem.type in ("message", "quote"):
data = elem.attrs.copy()
if elem.children:
Expand All @@ -414,9 +537,6 @@ def from_satori_element(cls, elements: List[Element]) -> "Message":
msg.append(Text("text", {"text": str(elem)}))
return msg

def extract_content(self) -> str:
return "".join(
str(seg)
for seg in self
if seg.type in ("text", "entity", "at", "sharp", "link")
)
@override
def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nonebot-adapter-satori"
version = "0.3.0"
version = "0.4.0"
description = "Satori Protocol Adapter for Nonebot2"
authors = [
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},
Expand Down

0 comments on commit 15dfdba

Please sign in to comment.