diff --git a/reposter/core/types.py b/reposter/core/types.py index 44dbb21..a485f8d 100644 --- a/reposter/core/types.py +++ b/reposter/core/types.py @@ -22,6 +22,20 @@ pyrogram.types.Contact, pyrogram.types.Sticker, ] +in_group_media = typing.Union[ + pyrogram.types.Audio, + pyrogram.types.Photo, + pyrogram.types.Video, + pyrogram.types.Document, + pyrogram.types.Animation, +] +input_media = typing.Union[ + pyrogram.types.InputMediaAudio, + pyrogram.types.InputMediaPhoto, + pyrogram.types.InputMediaVideo, + pyrogram.types.InputMediaDocument, + pyrogram.types.InputMediaAnimation, +] class Progress(rich.progress.Progress): diff --git a/reposter/handlers/on_msg.py b/reposter/handlers/on_msg.py index 45524f1..8c89003 100644 --- a/reposter/handlers/on_msg.py +++ b/reposter/handlers/on_msg.py @@ -42,11 +42,18 @@ async def on_new_msg( await stream_notify.notify_all() return if src_msg.has_protected_content or src_msg.chat.has_protected_content: - real_time_resend = reposter.handlers.resend_restricted.ResendRestricted( - src_msg=src_msg, - target_any=self.target_any, - ) - await real_time_resend.resend_all() + if src_msg.media_group_id: + resend_media_group = reposter.handlers.resend_restricted.ResendMediaGroup( + src_msg=src_msg, + target_any=self.target_any, + ) + await resend_media_group.all() + else: + resend_one = reposter.handlers.resend_restricted.ResendOne( + src_msg=src_msg, + target_any=self.target_any, + ) + await resend_one.all() else: real_time_forward = reposter.handlers.forward_unrestricted.ForwardUnrestricted( target_any=self.target_any, diff --git a/reposter/handlers/resend_restricted.py b/reposter/handlers/resend_restricted.py index 90311ad..17950b2 100644 --- a/reposter/handlers/resend_restricted.py +++ b/reposter/handlers/resend_restricted.py @@ -1,15 +1,17 @@ import reposter.handlers.forward_unrestricted +import reposter.handlers.other import reposter.funcs.handle import reposter.funcs.logging import reposter.funcs.other import reposter.core.config +import reposter.core.common import reposter.core.types import reposter.tg.restricted import reposter.db.models import pyrogram.types -class ResendRestricted: +class ResendOne: def __init__( self, target_any: reposter.core.types.target, @@ -19,39 +21,34 @@ def __init__( self.src_msg = src_msg assert isinstance(self.target_any, (str, int, list)) - async def resend_all(self) -> None: + async def all(self): if isinstance(self.target_any, list): - assert reposter.core.config.json.logs_chat - resent_to_log_chat = await reposter.funcs.handle.run_excepted( - self.resend_one, - src_msg=self.src_msg, - target=reposter.core.config.json.logs_chat, - ) - real_time_forward = reposter.handlers.forward_unrestricted.ForwardUnrestricted( - target_any=self.target_any, - src_to_forward=resent_to_log_chat, - src_in_db=self.src_msg, - ) - return await real_time_forward.copy_all() + await self.multiple_targets() elif isinstance(self.target_any, (str, int)): - target_msg = await self.resend_one( - src_msg=self.src_msg, + await self.one_target( target=self.target_any, - ) - await reposter.db.models.Msg.create( - hash = reposter.funcs.other.get_hash(self.src_msg), - src_msg=self.src_msg.id, - src_chat=self.src_msg.chat, - target_msg=target_msg.id, - target_chat=target_msg.chat.id + save_db=True, ) else: raise AssertionError - async def resend_one( + async def multiple_targets(self) -> None: + assert reposter.core.config.json.logs_chat + resent_to_log_chat = await self.one_target( + target=reposter.core.config.json.logs_chat, + save_db=False, + ) + real_time_forward = reposter.handlers.forward_unrestricted.ForwardUnrestricted( + target_any=self.target_any, + src_to_forward=resent_to_log_chat, + src_in_db=self.src_msg, + ) + await real_time_forward.copy_all() + + async def one_target( self, - src_msg: pyrogram.types.Message, target: str | int, + save_db: bool, ) -> pyrogram.types.Message: resender = reposter.tg.restricted.Resender( src_msg=self.src_msg, @@ -63,8 +60,95 @@ async def resend_one( assert target_msg reposter.funcs.logging.log_msg( to_log='[green]\\[success resend][/]', - src_msg=src_msg, + src_msg=self.src_msg, target_msg=target_msg, ) + if save_db: + await reposter.db.models.Msg.create( + hash = reposter.funcs.other.get_hash(self.src_msg), + src_msg=self.src_msg.id, + src_chat=self.src_msg.chat, + target_msg=target_msg.id, + target_chat=target_msg.chat.id + ) return target_msg + + +class ResendMediaGroup: + def __init__( + self, + target_any: reposter.core.types.target, + src_msg: pyrogram.types.Message, + ) -> None: + self.target_any: reposter.core.types.target = target_any + self.src_msg = src_msg + self.media: list[reposter.core.types.input_media] = [] + self.src_media_group: list[pyrogram.types.Message] + assert isinstance(self.target_any, (str, int, list)) + + async def all(self) -> None: + self.src_media_group = await self.src_msg.get_media_group() + if self.src_msg.id != self.src_media_group[0].id: + return + for src_msg in self.src_media_group: + self.media.append( + await self.get_media(src_msg) + ) + await reposter.handlers.other.parse_targets( + target_any=self.target_any, + to_call=self.one_target, + ) + + async def one_target( + self, + target: str | int, + ): + target_media_group = await reposter.core.common.tg.client.send_media_group( + chat_id=target, + media=self.media, + ) + for src_msg, target_msg in zip( + self.src_media_group, + target_media_group, + ): + await reposter.db.models.Msg.create( + hash = reposter.funcs.other.get_hash(self.src_msg), + src_msg=src_msg.id, + src_chat=src_msg.chat.id, + target_msg=target_msg.id, + target_chat=target_msg.chat.id + ) + reposter.funcs.logging.log_msg( + to_log='[green]\\[success resend media group][/]', + src_msg=src_msg, + target_msg=target_msg, + ) + + async def get_media( + self, + src_msg: pyrogram.types.Message, + ) -> reposter.core.types.input_media: + resend_one = ResendOne( + target_any=reposter.core.config.json.logs_chat, + src_msg=src_msg, + ) + resent_to_log_chat = await resend_one.one_target( + target=reposter.core.config.json.logs_chat, + save_db=False, + ) + if resent_to_log_chat.caption: + caption: str = resent_to_log_chat.caption.markdown + else: + caption: str = '' + media_value = str(resent_to_log_chat.media.value) + media = getattr(resent_to_log_chat, media_value) + assert isinstance(media, reposter.core.types.in_group_media) + input_media = getattr(pyrogram.types, 'InputMedia' + media_value.capitalize()) + inputted = input_media( + media=media.file_id, + caption=caption + ) + assert isinstance(inputted, reposter.core.types.input_media) + return inputted +