diff --git a/allms/utils/io_utils.py b/allms/utils/io_utils.py index bc36edf..b0289ae 100644 --- a/allms/utils/io_utils.py +++ b/allms/utils/io_utils.py @@ -1,9 +1,9 @@ +import csv import logging from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union, OrderedDict import fsspec -import pandas as pd from allms.constants.input_data import IODataConstants from allms.domain.input_data import InputData @@ -11,25 +11,34 @@ logger = logging.getLogger(__name__) -def load_data( +def load_csv( path: str, limit: Optional[int] = None -) -> List[InputData]: +) -> List[OrderedDict[Any, Any]]: logger.info(f"Loading test data from {path}") - input_df = pd.read_csv(path) - input_df = input_df.head(limit) if limit else input_df - return load_input_data(input_df) + with open(path, mode='r') as csv_file: + csv_reader = csv.DictReader(csv_file) + data = list(csv_reader) + return data[:limit] if limit else data -def load_input_data(input_df: pd.DataFrame) -> List[InputData]: +def load_csv_to_input_data(path: str, limit: Optional[int] = None) -> List[InputData]: + csv_data = load_csv(path, limit=limit) return list( map( - lambda row: InputData(input_mappings=row[1].drop(IODataConstants.ID).to_dict(), id=str(row[1].id)), - input_df.iterrows() + lambda row: InputData(input_mappings=drop_dict_key(row, IODataConstants.ID), + id=str(row[IODataConstants.ID])), + csv_data ) ) +def drop_dict_key(dictionary: Dict[Any, Any], key: Any) -> Dict[Any, Any]: + dict_copy = dictionary.copy() + dict_copy.pop(key) + return dict_copy + + def load_credentials(path: Union[str, Path]) -> str: with fsspec.open(path, "r") as credentials_file: return credentials_file.readline() diff --git a/poetry.lock b/poetry.lock index c138ddc..bfb9a58 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. [[package]] name = "aiohttp" @@ -706,11 +706,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -1891,78 +1891,6 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] -[[package]] -name = "pandas" -version = "2.2.0" -description = "Powerful data structures for data analysis, time series, and statistics" -optional = false -python-versions = ">=3.9" -files = [ - {file = "pandas-2.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8108ee1712bb4fa2c16981fba7e68b3f6ea330277f5ca34fa8d557e986a11670"}, - {file = "pandas-2.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:736da9ad4033aeab51d067fc3bd69a0ba36f5a60f66a527b3d72e2030e63280a"}, - {file = "pandas-2.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38e0b4fc3ddceb56ec8a287313bc22abe17ab0eb184069f08fc6a9352a769b18"}, - {file = "pandas-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20404d2adefe92aed3b38da41d0847a143a09be982a31b85bc7dd565bdba0f4e"}, - {file = "pandas-2.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7ea3ee3f125032bfcade3a4cf85131ed064b4f8dd23e5ce6fa16473e48ebcaf5"}, - {file = "pandas-2.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f9670b3ac00a387620489dfc1bca66db47a787f4e55911f1293063a78b108df1"}, - {file = "pandas-2.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:5a946f210383c7e6d16312d30b238fd508d80d927014f3b33fb5b15c2f895430"}, - {file = "pandas-2.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a1b438fa26b208005c997e78672f1aa8138f67002e833312e6230f3e57fa87d5"}, - {file = "pandas-2.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8ce2fbc8d9bf303ce54a476116165220a1fedf15985b09656b4b4275300e920b"}, - {file = "pandas-2.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2707514a7bec41a4ab81f2ccce8b382961a29fbe9492eab1305bb075b2b1ff4f"}, - {file = "pandas-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85793cbdc2d5bc32620dc8ffa715423f0c680dacacf55056ba13454a5be5de88"}, - {file = "pandas-2.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:cfd6c2491dc821b10c716ad6776e7ab311f7df5d16038d0b7458bc0b67dc10f3"}, - {file = "pandas-2.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a146b9dcacc3123aa2b399df1a284de5f46287a4ab4fbfc237eac98a92ebcb71"}, - {file = "pandas-2.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbc1b53c0e1fdf16388c33c3cca160f798d38aea2978004dd3f4d3dec56454c9"}, - {file = "pandas-2.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a41d06f308a024981dcaa6c41f2f2be46a6b186b902c94c2674e8cb5c42985bc"}, - {file = "pandas-2.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:159205c99d7a5ce89ecfc37cb08ed179de7783737cea403b295b5eda8e9c56d1"}, - {file = "pandas-2.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1e1f3861ea9132b32f2133788f3b14911b68102d562715d71bd0013bc45440"}, - {file = "pandas-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:761cb99b42a69005dec2b08854fb1d4888fdf7b05db23a8c5a099e4b886a2106"}, - {file = "pandas-2.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a20628faaf444da122b2a64b1e5360cde100ee6283ae8effa0d8745153809a2e"}, - {file = "pandas-2.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f5be5d03ea2073627e7111f61b9f1f0d9625dc3c4d8dda72cc827b0c58a1d042"}, - {file = "pandas-2.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:a626795722d893ed6aacb64d2401d017ddc8a2341b49e0384ab9bf7112bdec30"}, - {file = "pandas-2.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9f66419d4a41132eb7e9a73dcec9486cf5019f52d90dd35547af11bc58f8637d"}, - {file = "pandas-2.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:57abcaeda83fb80d447f28ab0cc7b32b13978f6f733875ebd1ed14f8fbc0f4ab"}, - {file = "pandas-2.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e60f1f7dba3c2d5ca159e18c46a34e7ca7247a73b5dd1a22b6d59707ed6b899a"}, - {file = "pandas-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb61dc8567b798b969bcc1fc964788f5a68214d333cade8319c7ab33e2b5d88a"}, - {file = "pandas-2.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52826b5f4ed658fa2b729264d63f6732b8b29949c7fd234510d57c61dbeadfcd"}, - {file = "pandas-2.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bde2bc699dbd80d7bc7f9cab1e23a95c4375de615860ca089f34e7c64f4a8de7"}, - {file = "pandas-2.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:3de918a754bbf2da2381e8a3dcc45eede8cd7775b047b923f9006d5f876802ae"}, - {file = "pandas-2.2.0.tar.gz", hash = "sha256:30b83f7c3eb217fb4d1b494a57a2fda5444f17834f5df2de6b2ffff68dc3c8e2"}, -] - -[package.dependencies] -numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, -] -python-dateutil = ">=2.8.2" -pytz = ">=2020.1" -tzdata = ">=2022.7" - -[package.extras] -all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] -aws = ["s3fs (>=2022.11.0)"] -clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] -compression = ["zstandard (>=0.19.0)"] -computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] -consortium-standard = ["dataframe-api-compat (>=0.1.7)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] -feather = ["pyarrow (>=10.0.1)"] -fss = ["fsspec (>=2022.11.0)"] -gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] -hdf5 = ["tables (>=3.8.0)"] -html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] -mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] -output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] -parquet = ["pyarrow (>=10.0.1)"] -performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] -plot = ["matplotlib (>=3.6.3)"] -postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] -spss = ["pyreadstat (>=1.2.0)"] -sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] -test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.9.2)"] - [[package]] name = "pathspec" version = "0.12.1" @@ -2253,17 +2181,6 @@ files = [ [package.dependencies] six = ">=1.5" -[[package]] -name = "pytz" -version = "2023.4" -description = "World timezone definitions, modern and historical" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2023.4-py2.py3-none-any.whl", hash = "sha256:f90ef520d95e7c46951105338d918664ebfd6f1d995bd7d153127ce90efafa6a"}, - {file = "pytz-2023.4.tar.gz", hash = "sha256:31d4583c4ed539cd037956140d695e42c033a19e984bfce9964a3f7d59bc2b40"}, -] - [[package]] name = "pywin32-ctypes" version = "0.2.2" @@ -2300,7 +2217,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2852,7 +2768,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} typing-extensions = ">=4.6.0" [package.extras] @@ -2896,40 +2812,47 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] [[package]] name = "tiktoken" -version = "0.4.0" +version = "0.6.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false python-versions = ">=3.8" files = [ - {file = "tiktoken-0.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:176cad7f053d2cc82ce7e2a7c883ccc6971840a4b5276740d0b732a2b2011f8a"}, - {file = "tiktoken-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:450d504892b3ac80207700266ee87c932df8efea54e05cefe8613edc963c1285"}, - {file = "tiktoken-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00d662de1e7986d129139faf15e6a6ee7665ee103440769b8dedf3e7ba6ac37f"}, - {file = "tiktoken-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5727d852ead18b7927b8adf558a6f913a15c7766725b23dbe21d22e243041b28"}, - {file = "tiktoken-0.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c06cd92b09eb0404cedce3702fa866bf0d00e399439dad3f10288ddc31045422"}, - {file = "tiktoken-0.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9ec161e40ed44e4210d3b31e2ff426b4a55e8254f1023e5d2595cb60044f8ea6"}, - {file = "tiktoken-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:1e8fa13cf9889d2c928b9e258e9dbbbf88ab02016e4236aae76e3b4f82dd8288"}, - {file = "tiktoken-0.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bb2341836b725c60d0ab3c84970b9b5f68d4b733a7bcb80fb25967e5addb9920"}, - {file = "tiktoken-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ca30367ad750ee7d42fe80079d3092bd35bb266be7882b79c3bd159b39a17b0"}, - {file = "tiktoken-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3dc3df19ddec79435bb2a94ee46f4b9560d0299c23520803d851008445671197"}, - {file = "tiktoken-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d980fa066e962ef0f4dad0222e63a484c0c993c7a47c7dafda844ca5aded1f3"}, - {file = "tiktoken-0.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:329f548a821a2f339adc9fbcfd9fc12602e4b3f8598df5593cfc09839e9ae5e4"}, - {file = "tiktoken-0.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b1a038cee487931a5caaef0a2e8520e645508cde21717eacc9af3fbda097d8bb"}, - {file = "tiktoken-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:08efa59468dbe23ed038c28893e2a7158d8c211c3dd07f2bbc9a30e012512f1d"}, - {file = "tiktoken-0.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f3020350685e009053829c1168703c346fb32c70c57d828ca3742558e94827a9"}, - {file = "tiktoken-0.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba16698c42aad8190e746cd82f6a06769ac7edd415d62ba027ea1d99d958ed93"}, - {file = "tiktoken-0.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c15d9955cc18d0d7ffcc9c03dc51167aedae98542238b54a2e659bd25fe77ed"}, - {file = "tiktoken-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64e1091c7103100d5e2c6ea706f0ec9cd6dc313e6fe7775ef777f40d8c20811e"}, - {file = "tiktoken-0.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e87751b54eb7bca580126353a9cf17a8a8eaadd44edaac0e01123e1513a33281"}, - {file = "tiktoken-0.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e063b988b8ba8b66d6cc2026d937557437e79258095f52eaecfafb18a0a10c03"}, - {file = "tiktoken-0.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:9c6dd439e878172dc163fced3bc7b19b9ab549c271b257599f55afc3a6a5edef"}, - {file = "tiktoken-0.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8d1d97f83697ff44466c6bef5d35b6bcdb51e0125829a9c0ed1e6e39fb9a08fb"}, - {file = "tiktoken-0.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b6bce7c68aa765f666474c7c11a7aebda3816b58ecafb209afa59c799b0dd2d"}, - {file = "tiktoken-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a73286c35899ca51d8d764bc0b4d60838627ce193acb60cc88aea60bddec4fd"}, - {file = "tiktoken-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0394967d2236a60fd0aacef26646b53636423cc9c70c32f7c5124ebe86f3093"}, - {file = "tiktoken-0.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:dae2af6f03ecba5f679449fa66ed96585b2fa6accb7fd57d9649e9e398a94f44"}, - {file = "tiktoken-0.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55e251b1da3c293432179cf7c452cfa35562da286786be5a8b1ee3405c2b0dd2"}, - {file = "tiktoken-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:c835d0ee1f84a5aa04921717754eadbc0f0a56cf613f78dfc1cf9ad35f6c3fea"}, - {file = "tiktoken-0.4.0.tar.gz", hash = "sha256:59b20a819969735b48161ced9b92f05dc4519c17be4015cfb73b65270a243620"}, + {file = "tiktoken-0.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:277de84ccd8fa12730a6b4067456e5cf72fef6300bea61d506c09e45658d41ac"}, + {file = "tiktoken-0.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9c44433f658064463650d61387623735641dcc4b6c999ca30bc0f8ba3fccaf5c"}, + {file = "tiktoken-0.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afb9a2a866ae6eef1995ab656744287a5ac95acc7e0491c33fad54d053288ad3"}, + {file = "tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c62c05b3109fefca26fedb2820452a050074ad8e5ad9803f4652977778177d9f"}, + {file = "tiktoken-0.6.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0ef917fad0bccda07bfbad835525bbed5f3ab97a8a3e66526e48cdc3e7beacf7"}, + {file = "tiktoken-0.6.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e095131ab6092d0769a2fda85aa260c7c383072daec599ba9d8b149d2a3f4d8b"}, + {file = "tiktoken-0.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:05b344c61779f815038292a19a0c6eb7098b63c8f865ff205abb9ea1b656030e"}, + {file = "tiktoken-0.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cefb9870fb55dca9e450e54dbf61f904aab9180ff6fe568b61f4db9564e78871"}, + {file = "tiktoken-0.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:702950d33d8cabc039845674107d2e6dcabbbb0990ef350f640661368df481bb"}, + {file = "tiktoken-0.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8d49d076058f23254f2aff9af603863c5c5f9ab095bc896bceed04f8f0b013a"}, + {file = "tiktoken-0.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:430bc4e650a2d23a789dc2cdca3b9e5e7eb3cd3935168d97d43518cbb1f9a911"}, + {file = "tiktoken-0.6.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:293cb8669757301a3019a12d6770bd55bec38a4d3ee9978ddbe599d68976aca7"}, + {file = "tiktoken-0.6.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7bd1a288b7903aadc054b0e16ea78e3171f70b670e7372432298c686ebf9dd47"}, + {file = "tiktoken-0.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac76e000183e3b749634968a45c7169b351e99936ef46f0d2353cd0d46c3118d"}, + {file = "tiktoken-0.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:17cc8a4a3245ab7d935c83a2db6bb71619099d7284b884f4b2aea4c74f2f83e3"}, + {file = "tiktoken-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:284aebcccffe1bba0d6571651317df6a5b376ff6cfed5aeb800c55df44c78177"}, + {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c1a3a5d33846f8cd9dd3b7897c1d45722f48625a587f8e6f3d3e85080559be8"}, + {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6318b2bb2337f38ee954fd5efa82632c6e5ced1d52a671370fa4b2eff1355e91"}, + {file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f5f0f2ed67ba16373f9a6013b68da298096b27cd4e1cf276d2d3868b5c7efd1"}, + {file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:75af4c0b16609c2ad02581f3cdcd1fb698c7565091370bf6c0cf8624ffaba6dc"}, + {file = "tiktoken-0.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:45577faf9a9d383b8fd683e313cf6df88b6076c034f0a16da243bb1c139340c3"}, + {file = "tiktoken-0.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7c1492ab90c21ca4d11cef3a236ee31a3e279bb21b3fc5b0e2210588c4209e68"}, + {file = "tiktoken-0.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e2b380c5b7751272015400b26144a2bab4066ebb8daae9c3cd2a92c3b508fe5a"}, + {file = "tiktoken-0.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9f497598b9f58c99cbc0eb764b4a92272c14d5203fc713dd650b896a03a50ad"}, + {file = "tiktoken-0.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e65e8bd6f3f279d80f1e1fbd5f588f036b9a5fa27690b7f0cc07021f1dfa0839"}, + {file = "tiktoken-0.6.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5f1495450a54e564d236769d25bfefbf77727e232d7a8a378f97acddee08c1ae"}, + {file = "tiktoken-0.6.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6c4e4857d99f6fb4670e928250835b21b68c59250520a1941618b5b4194e20c3"}, + {file = "tiktoken-0.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:168d718f07a39b013032741867e789971346df8e89983fe3c0ef3fbd5a0b1cb9"}, + {file = "tiktoken-0.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:47fdcfe11bd55376785a6aea8ad1db967db7f66ea81aed5c43fad497521819a4"}, + {file = "tiktoken-0.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fb7d2ccbf1a7784810aff6b80b4012fb42c6fc37eaa68cb3b553801a5cc2d1fc"}, + {file = "tiktoken-0.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ccb7a111ee76af5d876a729a347f8747d5ad548e1487eeea90eaf58894b3138"}, + {file = "tiktoken-0.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2048e1086b48e3c8c6e2ceeac866561374cd57a84622fa49a6b245ffecb7744"}, + {file = "tiktoken-0.6.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:07f229a5eb250b6403a61200199cecf0aac4aa23c3ecc1c11c1ca002cbb8f159"}, + {file = "tiktoken-0.6.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:432aa3be8436177b0db5a2b3e7cc28fd6c693f783b2f8722539ba16a867d0c6a"}, + {file = "tiktoken-0.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:8bfe8a19c8b5c40d121ee7938cd9c6a278e5b97dc035fd61714b4f0399d2f7a1"}, + {file = "tiktoken-0.6.0.tar.gz", hash = "sha256:ace62a4ede83c75b0374a2ddfa4b76903cf483e9cb06247f566be3bf14e6beed"}, ] [package.dependencies] @@ -3224,17 +3147,6 @@ files = [ mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" -[[package]] -name = "tzdata" -version = "2023.4" -description = "Provider of IANA time zone data" -optional = false -python-versions = ">=2" -files = [ - {file = "tzdata-2023.4-py2.py3-none-any.whl", hash = "sha256:aa3ace4329eeacda5b7beb7ea08ece826c28d761cda36e747cfbf97996d39bf3"}, - {file = "tzdata-2023.4.tar.gz", hash = "sha256:dd54c94f294765522c77399649b4fefd95522479a664a0cec87f41bebc6148c9"}, -] - [[package]] name = "urllib3" version = "2.2.0" @@ -3493,4 +3405,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f8cce433ad6004c08f2e70f70c534acf05693f6d374986b48be129058833b48e" +content-hash = "2a63aeb94ee8c2072bdc003d9acc29b224e28d3589908645a0d8e88ca703b8ba" diff --git a/pyproject.toml b/pyproject.toml index 01c9565..7cef706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,15 +9,14 @@ packages = [{include = "allms"}] [tool.poetry.dependencies] python = "^3.10" fsspec = "^2023.6.0" -pandas = "^2.0.3" -openai = "^0.27.8" google-cloud-aiplatform = "1.38.0" -tiktoken = "^0.4.0" pydash = "^7.0.6" transformers = "^4.34.1" pydantic = "1.10.13" langchain = "^0.0.351" aioresponses = "^0.7.6" +tiktoken = "^0.6.0" +openai = "^0.27.8" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 73294ca..1950282 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,7 +1,6 @@ import re -import pandas as pd - +from allms.constants.input_data import IODataConstants from allms.domain.prompt_dto import KeywordsOutputClass from allms.utils import io_utils from tests.conftest import AzureOpenAIEnv @@ -29,7 +28,7 @@ def test_model_is_queried_successfully( repeat=True ) - input_data = io_utils.load_data( + input_data = io_utils.load_csv_to_input_data( limit=5, path="./tests/resources/test_input_data.csv" ) @@ -48,24 +47,24 @@ def test_model_is_queried_successfully( parsed_responses = sorted(parsed_responses, key=lambda key: key.input_data.id) # THEN - expected_output = pd.read_csv("./tests/resources/test_end_to_end_expected_output.csv") - expected_output = expected_output.astype({"id": "str", "text": "str"}) - expected_output = expected_output.sort_values(by="id").reset_index(drop=True) - expected_output["response"] = expected_output["response"].apply(lambda x: eval(x)) + expected_output = io_utils.load_csv("./tests/resources/test_end_to_end_expected_output.csv") + expected_output = sorted(expected_output, key=lambda example: example[IODataConstants.ID]) + for idx in range(len(expected_output)): + expected_output[idx]["response"] = eval(expected_output[idx]["response"]) - assert expected_output["id"].values.tolist() == list( + assert list(map(lambda output: output[IODataConstants.ID], expected_output)) == list( map(lambda example: example.input_data.id, parsed_responses)) - assert expected_output["text"].values.tolist() == list( + assert list(map(lambda output: output[IODataConstants.TEXT], expected_output)) == list( map(lambda example: example.input_data.input_mappings["text"], parsed_responses)) - assert expected_output["response"].values.tolist() == list( + assert list(map(lambda output: output[IODataConstants.RESPONSE_STR_NAME], expected_output)) == list( map(lambda example: example.response.keywords, parsed_responses)) - assert expected_output["number_of_prompt_tokens"].values.tolist() == list( + assert list(map(lambda output: int(output[IODataConstants.PROMPT_TOKENS_NUMBER]), expected_output)) == list( map(lambda example: example.number_of_prompt_tokens, parsed_responses)) - assert expected_output["number_of_generated_tokens"].values.tolist() == list( + assert list(map(lambda output: int(output[IODataConstants.GENERATED_TOKENS_NUMBER]), expected_output)) == list( map(lambda example: example.number_of_generated_tokens, parsed_responses)) def test_model_times_out( diff --git a/tests/test_utf_characters_data.py b/tests/test_utf_characters_data.py index f69c7cf..f7606f7 100644 --- a/tests/test_utf_characters_data.py +++ b/tests/test_utf_characters_data.py @@ -10,7 +10,6 @@ class TestModelBehaviorForSpecialCharacters: @pytest.mark.parametrize("input_character", list(html.entities.entitydefs.values())) def test_model_is_not_broken_by_special_characters(self, tokens_mock, arun_mock, input_character, models): # GIVEN - print(tokens_mock) arun_mock.return_value = f"{input_character}" tokens_mock.return_value = 1