From c67fc3725499e85a113f7166233b7304be4405aa Mon Sep 17 00:00:00 2001 From: jigsaw Date: Wed, 17 Jul 2024 17:35:33 +0800 Subject: [PATCH] :bug: Entity._length is utf-16 length --- nonebot/adapters/telegram/message.py | 45 +++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/nonebot/adapters/telegram/message.py b/nonebot/adapters/telegram/message.py index 3faa190..2a24bd0 100644 --- a/nonebot/adapters/telegram/message.py +++ b/nonebot/adapters/telegram/message.py @@ -1,11 +1,13 @@ from typing_extensions import override -from typing import Any, Union, Literal, Iterable, Optional +from typing import Any, Union, Literal, TypeVar, Iterable, Optional from nonebot.adapters import Message as BaseMessage from nonebot.adapters import MessageSegment as BaseMessageSegment from .model import User, MessageEntity +TMS = TypeVar("TMS", bound="MessageSegment") + class MessageSegment(BaseMessageSegment): """ @@ -181,6 +183,12 @@ def markup(): class Entity(MessageSegment): + def __init__(self, type: str, data: dict[str, Any], _length: int = 0): + super().__init__(type, data) + self._length = _length + if _length == 0: + self._length = len(self.data["text"].encode("utf-16-le")) // 2 + @override def is_text(self) -> bool: return True @@ -270,14 +278,28 @@ def custom_emoji(text: str, custom_emoji_id: str) -> "Entity": def from_telegram_entities(text, entities: list[dict[str, Any]]) -> list["Entity"]: nb_entites = [] offset = 0 + text = text.encode("utf-16-le") for entity in entities: if entity["offset"] > offset: nb_entites.append( - Entity("text", {"text": text[offset : entity["offset"]]}) + Entity( + "text", + { + "text": text[offset * 2 : entity["offset"] * 2].decode( + "utf-16-le" + ) + }, + entity["offset"] - offset, + ) ) nb_entity = Entity( entity["type"], - {"text": text[entity["offset"] : entity["offset"] + entity["length"]]}, + { + "text": text[ + entity["offset"] * 2 : (entity["offset"] + entity["length"]) * 2 + ].decode("utf-16-le") + }, + entity["length"], ) if "language" in entity: nb_entity.data["language"] = entity["language"] @@ -290,21 +312,28 @@ def from_telegram_entities(text, entities: list[dict[str, Any]]) -> list["Entity nb_entites.append(nb_entity) offset = entity["offset"] + entity["length"] if offset < len(text): - nb_entites.append(Entity("text", {"text": text[offset:]})) + nb_entites.append( + Entity( + "text", + {"text": text[offset * 2 :].decode("utf-16-le")}, + (len(text) - offset * 2) // 2, + ) + ) return nb_entites @staticmethod - def build_telegram_entities(entities: "Message") -> list[MessageEntity]: + def build_telegram_entities(entities: list["Entity"]) -> list[MessageEntity]: return ( ( [ MessageEntity( type=entity.type, # type: ignore - offset=sum(map(len, entities[:i])), - length=len(entity.data["text"]), + offset=sum(map(lambda _: _._length, entities[:i])), + length=entity._length, url=entity.data.get("url"), user=entity.data.get("user"), language=entity.data.get("language"), + custom_emoji_id=entity.data.get("custom_emoji_id"), ) for i, entity in enumerate(entities) if entity.is_text() and entity.type != "text" @@ -380,7 +409,7 @@ def video_note( return File("video_note", {"file": file, "thumbnail": thumbnail}) -class Message(BaseMessage[MessageSegment]): +class Message(BaseMessage[TMS]): def __repr__(self) -> str: return "".join(repr(seg) for seg in self)