Skip to content

Commit

Permalink
🐛 Entity._length is utf-16 length
Browse files Browse the repository at this point in the history
  • Loading branch information
j1g5awi committed Jul 17, 2024
1 parent 56b20c3 commit c67fc37
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions nonebot/adapters/telegram/message.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing_extensions import override
from typing import Any, Union, Literal, Iterable, Optional
from typing import Any, Union, Literal, TypeVar, Iterable, Optional

from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment

from .model import User, MessageEntity

TMS = TypeVar("TMS", bound="MessageSegment")


class MessageSegment(BaseMessageSegment):
"""
Expand Down Expand Up @@ -181,6 +183,12 @@ def markup():


class Entity(MessageSegment):
def __init__(self, type: str, data: dict[str, Any], _length: int = 0):
    """Create an entity segment and record its length in UTF-16 code units.

    :param type: segment type, forwarded to the base initializer
    :param data: segment payload, forwarded to the base initializer; its
        ``"text"`` entry is read when ``_length`` is 0
    :param _length: precomputed length in UTF-16 code units; 0 (the
        default) means the length is derived from ``data["text"]``
    """
    super().__init__(type, data)
    if _length != 0:
        # The caller already supplied the length; store it unchanged.
        self._length = _length
        return
    # UTF-16-LE uses two bytes per code unit and writes no BOM, so half the
    # encoded byte count is the number of code units.
    utf16_bytes = self.data["text"].encode("utf-16-le")
    self._length = len(utf16_bytes) // 2

@override
def is_text(self) -> bool:
return True
Expand Down Expand Up @@ -270,14 +278,28 @@ def custom_emoji(text: str, custom_emoji_id: str) -> "Entity":
def from_telegram_entities(text, entities: list[dict[str, Any]]) -> list["Entity"]:
nb_entites = []
offset = 0
text = text.encode("utf-16-le")
for entity in entities:
if entity["offset"] > offset:
nb_entites.append(
Entity("text", {"text": text[offset : entity["offset"]]})
Entity(
"text",
{
"text": text[offset * 2 : entity["offset"] * 2].decode(
"utf-16-le"
)
},
entity["offset"] - offset,
)
)
nb_entity = Entity(
entity["type"],
{"text": text[entity["offset"] : entity["offset"] + entity["length"]]},
{
"text": text[
entity["offset"] * 2 : (entity["offset"] + entity["length"]) * 2
].decode("utf-16-le")
},
entity["length"],
)
if "language" in entity:
nb_entity.data["language"] = entity["language"]
Expand All @@ -290,21 +312,28 @@ def from_telegram_entities(text, entities: list[dict[str, Any]]) -> list["Entity
nb_entites.append(nb_entity)
offset = entity["offset"] + entity["length"]
if offset < len(text):
nb_entites.append(Entity("text", {"text": text[offset:]}))
nb_entites.append(
Entity(
"text",
{"text": text[offset * 2 :].decode("utf-16-le")},
(len(text) - offset * 2) // 2,
)
)
return nb_entites

@staticmethod
def build_telegram_entities(entities: "Message") -> list[MessageEntity]:
def build_telegram_entities(entities: list["Entity"]) -> list[MessageEntity]:
return (
(
[
MessageEntity(
type=entity.type, # type: ignore
offset=sum(map(len, entities[:i])),
length=len(entity.data["text"]),
offset=sum(map(lambda _: _._length, entities[:i])),
length=entity._length,
url=entity.data.get("url"),
user=entity.data.get("user"),
language=entity.data.get("language"),
custom_emoji_id=entity.data.get("custom_emoji_id"),
)
for i, entity in enumerate(entities)
if entity.is_text() and entity.type != "text"
Expand Down Expand Up @@ -380,7 +409,7 @@ def video_note(
return File("video_note", {"file": file, "thumbnail": thumbnail})


class Message(BaseMessage[MessageSegment]):
class Message(BaseMessage[TMS]):
def __repr__(self) -> str:
return "".join(repr(seg) for seg in self)

Expand Down

0 comments on commit c67fc37

Please sign in to comment.