diff --git a/daras_ai/mdit_wa_plugin.py b/daras_ai/mdit_wa_plugin.py new file mode 100644 index 000000000..a815d0df0 --- /dev/null +++ b/daras_ai/mdit_wa_plugin.py @@ -0,0 +1,86 @@ +from mdformat.renderer._context import ( + RenderContext, + make_render_children, + _render_inline_as_text, + WRAP_POINT, +) +from mdformat.renderer._tree import RenderTreeNode +from mdformat.plugins import ParserExtensionInterface +from mdformat.renderer._util import maybe_add_link_brackets + + +def wa_heading_renderer(node: RenderTreeNode, context: RenderContext) -> str: + text = make_render_children(separator="")(node, context) + text = text.lstrip("*") + text = text.rstrip("*") + + return "*" + text + "*" + + +def wa_em_renderer(node: RenderTreeNode, context: RenderContext) -> str: + text = make_render_children(separator="")(node, context) + return "_" + text + "_" + + +def wa_strong_renderer(node: RenderTreeNode, context: RenderContext) -> str: + text = make_render_children(separator="")(node, context) + return "*" + text + "*" + + +def wa_link_renderer(node: RenderTreeNode, context: RenderContext) -> str: + if node.info == "auto": + autolink_url = node.attrs["href"] + # Remove 'mailto:' if the URL is a mailto link and the content doesn't start with 'mailto:' + if autolink_url.startswith("mailto:") and not node.children[ + 0 + ].content.startswith("mailto:"): + autolink_url = autolink_url[7:] + return f"{autolink_url}" + + # Get the display text for the link + text = "".join(child.render(context) for child in node.children) + + uri = node.attrs["href"] + return f"{text} ({uri})" + + +def wa_image_renderer(node: RenderTreeNode, context: RenderContext) -> str: + description = _render_inline_as_text(node, context) + ref_label = node.meta.get("label") + if ref_label: + context.env["used_refs"].add(ref_label) + ref_label_repr = ref_label.lower() + if description.lower() == ref_label_repr: + return f"[{description}]" + return f" {description} [{ref_label_repr}]" + + uri = node.attrs["src"] + assert isinstance(uri, str) + uri = maybe_add_link_brackets(uri) + title = node.attrs.get("title") + if title is not None: + return f'{description} ({uri} "{title}")' + return f"{description} ({uri})" + + +def wa_hr_renderer(node: RenderTreeNode, context: RenderContext) -> str: + return "" + + +def wa_strikethrough_renderer(node: RenderTreeNode, context: RenderContext) -> str: + # Render the content inside the strikethrough element + text = make_render_children(separator="")(node, context) + return f"~{text}~" + + +class WhatsappParser(ParserExtensionInterface): + + RENDERERS = { + "heading": wa_heading_renderer, + "em": wa_em_renderer, + "strong": wa_strong_renderer, + "link": wa_link_renderer, + "hr": wa_hr_renderer, + "image": wa_image_renderer, + "s": wa_strikethrough_renderer, + } diff --git a/daras_ai/text_format.py b/daras_ai/text_format.py index a7f4eac23..7b288e725 100644 --- a/daras_ai/text_format.py +++ b/daras_ai/text_format.py @@ -1,13 +1,37 @@ import ast - +import re import parse +import requests +from typing import Mapping, Any + +from furl import furl from markdown_it import MarkdownIt +from mdformat.renderer import MDRenderer +from daras_ai.image_input import upload_file_from_bytes +from daras_ai.mdit_wa_plugin import WhatsappParser +from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.tts_markdown_renderer import RendererPlain +from daras_ai_v2.text_splitter import new_para +from loguru import logger input_spec_parse_pattern = "{" * 5 + "}" * 5 +WA_FORMATTING_OPTIONS: Mapping[str, Any] = { + "mdformat": {"number": True}, + "parser_extension": [WhatsappParser], +} + +WHATSAPP_VALID_IMAGE_FORMATS = [ + "image/jpeg", + "image/png", + "image/gif", + "image/tiff", + "image/webp", + "image/bmp", +] + def daras_ai_format_str(format_str, variables): from glom import glom @@ -48,3 +72,94 @@ def format_number_with_suffix(num: int) -> str: def unmarkdown(text: str) -> str: """markdown to plaintext""" return MarkdownIt(renderer_cls=RendererPlain).render(text) + + +def extract_image_urls(tokens) -> list[str]: + image_urls = [] + + for token in tokens: + if token.type == "inline" and token.children: + for child in token.children: + if child.type == "image" and "src" in child.attrs: + image_urls.append(child.attrs["src"]) + + return image_urls + + +def get_mimetype_from_url(url: str) -> str: + try: + r = requests.head(url) + raise_for_status(r) + return r.headers.get("content-type", "application/octet-stream") + except requests.RequestException as e: + logger.warning(f"Error fetching mimetype for {url}: {e}") + return "application/octet-stream" + + +def process_wa_image_urls(image_urls: list[str]) -> list[str]: + from wand.image import Image + + processed_images = [] + for image_url in image_urls: + + parsed_url = furl(image_url) + if parsed_url.scheme not in ["http", "https"]: + continue + + mime_type = get_mimetype_from_url(image_url) + + if mime_type in WHATSAPP_VALID_IMAGE_FORMATS: + r = requests.get(image_url) + raise_for_status(r) + filename = ( + r.headers.get("content-disposition", "") + .split("filename=")[-1] + .strip('"') + ) + image_data = r.content + + with Image(blob=image_data) as img: + if img.format.lower() not in ["png", "jpeg"]: + png_blob = img.make_blob(format="png") + processed_images.append( + upload_file_from_bytes(filename, png_blob, "image/png") + ) + else: + processed_images.append(image_url) + + return processed_images + + +def wa_markdown(text: str) -> str | tuple[list[str | Any], str]: + """commonmark to WA compatible Markdown""" + + if text is None: + return "" + + md = MarkdownIt("commonmark").enable("strikethrough") + tokens = md.parse(text) + image_urls = extract_image_urls(tokens) + processed_images = process_wa_image_urls(image_urls) + whatsapp_msg_text = MDRenderer().render( + tokens, options=WA_FORMATTING_OPTIONS, env={} + ) + return processed_images, whatsapp_msg_text + + +def is_list_item_complete(text: str) -> bool: + """Returns True if the last block is a list item, False otherwise.""" + + if text is None: + return False + blocks = re.split(new_para, text.strip()) + + if not blocks: + return False + + last_block = blocks[-1].strip() + lines = [ln for ln in last_block.split("\n") if ln.strip()] + list_item_pattern = re.compile(r"^\s*(?:[*+\-]|\d+\.)\s+") + + is_list_block = any(list_item_pattern.match(ln) for ln in lines) + + return is_list_block diff --git a/daras_ai_v2/facebook_bots.py b/daras_ai_v2/facebook_bots.py index cbda7b982..799970c0c 100644 --- a/daras_ai_v2/facebook_bots.py +++ b/daras_ai_v2/facebook_bots.py @@ -10,6 +10,7 @@ from daras_ai_v2.bots import BotInterface, ReplyButton, ButtonPressed from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.text_splitter import text_splitter +from daras_ai.text_format import wa_markdown WA_MSG_MAX_SIZE = 1024 @@ -138,6 +139,8 @@ def send_msg_to( ) -> str | None: # see https://developers.facebook.com/docs/whatsapp/api/messages/media/ + images, text = wa_markdown(text) + # split text into chunks if too long if text and len(text) > WA_MSG_MAX_SIZE: splits = text_splitter( @@ -190,6 +193,32 @@ def send_msg_to( }, }, ] + + elif images: + if buttons: + messages = _build_msg_buttons( + buttons, + { + "body": { + "text": text or "\u200b", + }, + "header": { + "type": "image", + "image": {"link": images[0]}, + }, + }, + ) + else: + messages = [ + { + "type": "image", + "image": { + "link": images[0], + "caption": text, + }, + }, + ] + elif buttons: # interactive text msg messages = _build_msg_buttons( diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 25ad7a144..827c13b16 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -25,6 +25,7 @@ ) from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes +from daras_ai.text_format import is_list_item_complete from daras_ai_v2.asr import get_google_auth_session from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.gpu_server import call_celery_task @@ -1242,6 +1243,10 @@ def _stream_openai_chunked( if not (isinstance(last_part, str) and last_part.strip()): continue + # add regex to handle not breaking in list items + if is_list_item_complete(chunk): + continue + # iterate through the separators and find the best one that matches for sep in default_separators[:-1]: # find the last occurrence of the separator diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 04d0aaccc..9d9ed58ee 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -502,7 +502,7 @@ def do_check_document_updates( metadatas = yield from apply_parallel( doc_or_yt_url_to_file_metas, lookups.keys(), - message="Fetching latest knowlege docs...", + message="Fetching latest knowledge docs...", max_workers=100, ) diff --git a/poetry.lock b/poetry.lock index 15656eb9f..f3d6dae42 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3166,6 +3166,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3260,6 +3270,21 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mdformat" +version = "0.7.22" +description = "CommonMark compliant Markdown formatter" +optional = false +python-versions = ">=3.9" +files = [ + {file = "mdformat-0.7.22-py3-none-any.whl", hash = "sha256:61122637c9e1d9be1329054f3fa216559f0d1f722b7919b060a8c2a4ae1850e5"}, + {file = "mdformat-0.7.22.tar.gz", hash = "sha256:eef84fa8f233d3162734683c2a8a6222227a229b9206872e6139658d99acb1ea"}, +] + +[package.dependencies] +markdown-it-py = ">=1.0.0,<4.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "mdurl" version = "0.1.2" @@ -4799,6 +4824,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4806,8 +4832,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4824,6 +4858,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4831,6 +4866,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6868,4 +6904,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "6e78e465edd05b409691ddb0ac8ab1deaea6d0b49573b59d24ba98eeb6faf757" +content-hash = "7d0b0b5757643ebec9c1a6488e752bc0b06e77709c58e6d492d60a7aa16d4ebf" diff --git a/pyproject.toml b/pyproject.toml index a1e16d778..ddfd20623 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ python-pptx = "^1.0.2" azure-identity = "^1.19.0" azure-keyvault-secrets = "^4.9.0" xlrd = "^2.0.1" +mdformat = "^0.7.21" [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" diff --git a/tests/test_wa_markdown.py b/tests/test_wa_markdown.py new file mode 100644 index 000000000..20975b8c9 --- /dev/null +++ b/tests/test_wa_markdown.py @@ -0,0 +1,66 @@ +import pytest +from daras_ai.text_format import wa_markdown + +WHATSAPP_MARKDOWN_TEST = [ + ( + """# This is an