From 6605804fa7494a09e108d7013cb4ba8fb347ea23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BC=A8=E7=BC=A8?= Date: Tue, 24 Dec 2024 17:52:53 +0800 Subject: [PATCH] refactor: handle the failed task message (#600) * refactor: handle the failed task message * chore: release petercat-utils/0.1.41 --- petercat_utils/rag_helper/git_doc_task.py | 2 + petercat_utils/rag_helper/git_issue_task.py | 66 ++++++++++++--------- petercat_utils/rag_helper/git_task.py | 12 +++- petercat_utils/rag_helper/task.py | 20 ++----- pyproject.toml | 2 +- server/aws/service.py | 29 +++++---- subscriber/handler.py | 30 ++++++---- 7 files changed, 91 insertions(+), 70 deletions(-) diff --git a/petercat_utils/rag_helper/git_doc_task.py b/petercat_utils/rag_helper/git_doc_task.py index 74d16e0d..08f03e15 100644 --- a/petercat_utils/rag_helper/git_doc_task.py +++ b/petercat_utils/rag_helper/git_doc_task.py @@ -68,6 +68,7 @@ def __init__( status=TaskStatus.NOT_STARTED, from_id=None, id=None, + retry_count=0, ): super().__init__( type=TaskType.GIT_DOC, @@ -75,6 +76,7 @@ def __init__( id=id, status=status, repo_name=repo_name, + retry_count=retry_count, ) self.commit_id = commit_id self.node_type = GitDocTaskNodeType(node_type) diff --git a/petercat_utils/rag_helper/git_issue_task.py b/petercat_utils/rag_helper/git_issue_task.py index fac2a077..a8690422 100644 --- a/petercat_utils/rag_helper/git_issue_task.py +++ b/petercat_utils/rag_helper/git_issue_task.py @@ -11,10 +11,10 @@ def add_rag_git_issue_task(config: RAGGitIssueConfig): g.get_repo(config.repo_name) issue_task = GitIssueTask( - issue_id='', + issue_id="", node_type=GitIssueTaskNodeType.REPO, bot_id=config.bot_id, - repo_name=config.repo_name + repo_name=config.repo_name, ) res = issue_task.save() issue_task.send() @@ -26,17 +26,26 @@ class GitIssueTask(GitTask): issue_id: str node_type: GitIssueTaskNodeType - def __init__(self, - issue_id, - node_type: GitIssueTaskNodeType, - bot_id, - repo_name, - status=TaskStatus.NOT_STARTED, - from_id=None, - id=None - ): - super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status, - repo_name=repo_name) + def __init__( + self, + issue_id, + node_type: GitIssueTaskNodeType, + bot_id, + repo_name, + status=TaskStatus.NOT_STARTED, + from_id=None, + id=None, + retry_count=0, + ): + super().__init__( + bot_id=bot_id, + type=TaskType.GIT_ISSUE, + from_id=from_id, + id=id, + status=status, + repo_name=repo_name, + retry_count=retry_count, + ) self.issue_id = issue_id self.node_type = GitIssueTaskNodeType(node_type) @@ -75,27 +84,28 @@ def handle_repo_node(self): if len(task_list) > 0: result = self.get_table().insert(task_list).execute() for record in result.data: - issue_task = GitIssueTask(id=record["id"], - issue_id=record["issue_id"], - repo_name=record["repo_name"], - node_type=record["node_type"], - bot_id=record["bot_id"], - status=record["status"], - from_id=record["from_task_id"] - ) + issue_task = GitIssueTask( + id=record["id"], + issue_id=record["issue_id"], + repo_name=record["repo_name"], + node_type=record["node_type"], + bot_id=record["bot_id"], + status=record["status"], + from_id=record["from_task_id"], + ) issue_task.send() - return (self.get_table().update( - {"status": TaskStatus.COMPLETED.value}) - .eq("id", self.id) - .execute()) + return ( + self.get_table() + .update({"status": TaskStatus.COMPLETED.value}) + .eq("id", self.id) + .execute() + ) def handle_issue_node(self): issue_retrieval.add_knowledge_by_issue( RAGGitIssueConfig( - repo_name=self.repo_name, - bot_id=self.bot_id, - issue_id=self.issue_id + repo_name=self.repo_name, bot_id=self.bot_id, issue_id=self.issue_id ) ) return self.update_status(TaskStatus.COMPLETED) diff --git a/petercat_utils/rag_helper/git_task.py b/petercat_utils/rag_helper/git_task.py index 61d6cf58..8b37f152 100644 --- a/petercat_utils/rag_helper/git_task.py +++ b/petercat_utils/rag_helper/git_task.py @@ -24,12 +24,14 @@ def __init__( status=TaskStatus.NOT_STARTED, from_id=None, id=None, + retry_count=0, ): self.type = type self.id = id self.from_id = from_id self.status = status self.repo_name = repo_name + self.retry_count = retry_count @staticmethod def get_table_name(type: TaskType): @@ -82,11 +84,17 @@ def send(self): QueueUrl=SQS_QUEUE_URL, DelaySeconds=10, MessageBody=( - json.dumps({"task_id": self.id, "task_type": self.type.value}) + json.dumps( + { + "task_id": self.id, + "task_type": self.type.value, + "retry_count": self.retry_count, + } + ) ), ) message_id = response["MessageId"] print( - f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}" + f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}, retry_count={self.retry_count}" ) return message_id diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index 144632ab..c9e8ab67 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -23,15 +23,6 @@ SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") -def send_task_message(task_id: str): - response = sqs.send_message( - QueueUrl=SQS_QUEUE_URL, - DelaySeconds=10, - MessageBody=(json.dumps({"task_id": task_id})), - ) - return response["MessageId"] - - def get_oldest_task(): supabase = get_client() @@ -54,10 +45,7 @@ def get_task_by_id(task_id): return response.data[0] if (len(response.data) > 0) else None -def get_task( - task_type: TaskType, - task_id: str, -) -> GitTask: +def get_task(task_type: TaskType, task_id: str, retry_count=0) -> GitTask: supabase = get_client() response = ( supabase.table(GitTask.get_table_name(task_type)) @@ -77,6 +65,7 @@ def get_task( path=data["path"], status=data["status"], from_id=data["from_task_id"], + retry_count=retry_count, ) if task_type == TaskType.GIT_ISSUE: return GitIssueTask( @@ -87,11 +76,12 @@ def get_task( bot_id=data["bot_id"], status=data["status"], from_id=data["from_task_id"], + retry_count=retry_count, ) -def trigger_task(task_type: TaskType, task_id: Optional[str]): - task = get_task(task_type, task_id) if task_id else get_oldest_task() +def trigger_task(task_type: TaskType, task_id: Optional[str], retry_count: int = 0): + task = get_task(task_type, task_id, retry_count) if task_id else get_oldest_task() if task is None: return task return task.handle() diff --git a/pyproject.toml b/pyproject.toml index 6790e442..a855c8f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "petercat_utils" -version = "0.1.40" +version = "0.1.41" description = "" authors = ["raoha.rh "] readme = "README.md" diff --git a/server/aws/service.py b/server/aws/service.py index d451b317..fdc82133 100644 --- a/server/aws/service.py +++ b/server/aws/service.py @@ -14,30 +14,32 @@ STATIC_SECRET_NAME = get_env_variable("STATIC_SECRET_NAME") STATIC_KEYPAIR_ID = get_env_variable("STATIC_KEYPAIR_ID") + def rsa_signer(message): private_key_str = get_private_key(STATIC_SECRET_NAME) - private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode('utf-8')) - return rsa.sign(message, private_key, 'SHA-1') + private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode("utf-8")) + return rsa.sign(message, private_key, "SHA-1") + def create_signed_url(url, expire_minutes=60) -> str: cloudfront_signer = CloudFrontSigner(STATIC_KEYPAIR_ID, rsa_signer) - + # 设置过期时间 expire_date = datetime.now() + timedelta(minutes=expire_minutes) - + # 创建签名 URL signed_url = cloudfront_signer.generate_presigned_url( - url=url, - date_less_than=expire_date + url=url, date_less_than=expire_date ) - + return signed_url + def upload_image_to_s3(file, metadata: ImageMetaData, s3_client): try: file_content = file.file.read() md5_hash = hashlib.md5() - md5_hash.update(file.filename.encode('utf-8')) + md5_hash.update(file.filename.encode("utf-8")) s3_key = md5_hash.hexdigest() encoded_filename = ( base64.b64encode(metadata.title.encode("utf-8")).decode("utf-8") @@ -62,11 +64,12 @@ def upload_image_to_s3(file, metadata: ImageMetaData, s3_client): ContentType=file.content_type, Metadata=custom_metadata, ) - # you need to redirect your static domain to your s3 bucket domain s3_url = f"{STATIC_URL}/{s3_key}" - signed_url = create_signed_url(url=s3_url, expire_minutes=60) \ - if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID) \ - else s3_url - return {"message": "File uploaded successfully", "url": signed_url } + signed_url = ( + create_signed_url(url=s3_url, expire_minutes=60) + if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID) + else s3_url + ) + return {"message": "File uploaded successfully", "url": signed_url} except Exception as e: raise UploadError(detail=str(e)) diff --git a/subscriber/handler.py b/subscriber/handler.py index 84ba74e7..b4132d25 100644 --- a/subscriber/handler.py +++ b/subscriber/handler.py @@ -3,6 +3,8 @@ from petercat_utils import task as task_helper from petercat_utils.data_class import TaskType +MAX_RETRY_COUNT = 5 + def lambda_handler(event, context): if event: @@ -10,23 +12,29 @@ def lambda_handler(event, context): sqs_batch_response = {} for record in event["Records"]: - try: - body = record["body"] - print(f"receive message here: {body}") + body = record["body"] + print(f"receive message here: {body}") - message_dict = json.loads(body) - task_id = message_dict["task_id"] - task_type = message_dict["task_type"] - task = task_helper.get_task(TaskType(task_type), task_id) + message_dict = json.loads(body) + task_id = message_dict["task_id"] + task_type = message_dict["task_type"] + retry_count = message_dict["retry_count"] + task = task_helper.get_task(TaskType(task_type), task_id) + try: if task is None: return task task.handle() - # process message - print(f"message content: message={message_dict}, task_id={task_id}, task={task}") + print( + f"message content: message={message_dict}, task_id={task_id}, task={task}, retry_count={retry_count}" + ) except Exception as e: - print(f"message handle error: ${e}") - batch_item_failures.append({"itemIdentifier": record['messageId']}) + if retry_count < MAX_RETRY_COUNT: + retry_count += 1 + task_helper.trigger_task(task_type, task_id, retry_count) + else: + print(f"message handle error: ${e}") + batch_item_failures.append({"itemIdentifier": record["messageId"]}) sqs_batch_response["batchItemFailures"] = batch_item_failures return sqs_batch_response