From a2f0d5212482bf918d4fcc27f6e9aae0c5d85a1c Mon Sep 17 00:00:00 2001
From: Victor Torre
Date: Mon, 25 Mar 2024 12:57:29 +0100
Subject: [PATCH] Fix pypi & black

* Fix pypi & black

Co-authored-by: vatorre
---
 promptmeteo/__init__.py                      |  2 +-
 promptmeteo/api_formatter.py                 |  3 +-
 promptmeteo/base.py                          |  2 --
 promptmeteo/document_classifier.py           |  1 -
 promptmeteo/document_qa.py                   |  1 -
 promptmeteo/models/__init__.py               |  6 ++--
 promptmeteo/models/azure_openai.py           |  5 ---
 promptmeteo/models/base.py                   | 13 ++++----
 promptmeteo/models/bedrock.py                | 20 +++++-------
 promptmeteo/models/fake_llm.py               |  4 ---
 promptmeteo/models/google_vertexai.py        |  8 +----
 promptmeteo/models/hf_hub_api.py             |  4 ---
 promptmeteo/models/hf_pipeline.py            |  4 ---
 promptmeteo/models/openai.py                 |  5 ---
 promptmeteo/parsers/__init__.py              |  6 ++--
 promptmeteo/parsers/api_parser.py            |  1 -
 promptmeteo/parsers/classification_parser.py |  1 -
 promptmeteo/parsers/dummy_parser.py          |  1 -
 promptmeteo/parsers/json_parser.py           |  9 ++----
 promptmeteo/prompts/__init__.py              |  1 -
 promptmeteo/prompts/base.py                  | 34 +++++++++++---------
 promptmeteo/selector/__init__.py             | 15 +++++----
 promptmeteo/selector/base.py                 |  2 --
 promptmeteo/selector/custom_selectors.py     | 17 +++++-----
 promptmeteo/summarizer.py                    | 15 +++------
 promptmeteo/tasks/task.py                    |  1 -
 promptmeteo/tasks/task_builder.py            |  2 --
 pyproject.toml                               |  2 +-
 tests/test_base_supervised.py                |  6 ++--
 tests/test_base_unsupervised.py              |  7 ++--
 30 files changed, 71 insertions(+), 127 deletions(-)

diff --git a/promptmeteo/__init__.py b/promptmeteo/__init__.py
index 8b9c6a7..a6fa8ef 100644
--- a/promptmeteo/__init__.py
+++ b/promptmeteo/__init__.py
@@ -4,4 +4,4 @@
 from .document_classifier import DocumentClassifier
 from .api_generator import APIGenerator
 from .api_formatter import APIFormatter
-from .summarizer import Summarizer
\ No newline at end of file
+from .summarizer import Summarizer
diff --git a/promptmeteo/api_formatter.py b/promptmeteo/api_formatter.py
index 4935402..0125916 100644
--- a/promptmeteo/api_formatter.py
+++ b/promptmeteo/api_formatter.py
@@ -43,7 +43,7 @@
 
 
 class APIFormatter(BaseUnsupervised):
-    """ API Formatter Task.
+    """API Formatter Task.
 
     This class initializes the API Formatter Task to correct and format APIs.
 
@@ -397,6 +397,7 @@ def _add_external_information(api: str, replacements: dict) -> str:
         str
             Updated API YAML string.
         """
+
         def replace_values(orig_dict, replace_dict):
             for k, v in replace_dict.items():
                 if k in orig_dict:
diff --git a/promptmeteo/base.py b/promptmeteo/base.py
index 74acb11..352c1b5 100644
--- a/promptmeteo/base.py
+++ b/promptmeteo/base.py
@@ -44,7 +44,6 @@
 
 
 class Base(ABC):
-
     """
     Promptmeteo is a tool powered by LLMs, capable of solving NLP tasks such as
     text classification and Named Entity Recognition. Its interface resembles
@@ -347,7 +346,6 @@ def _load_builder(self, **kwargs) -> TaskBuilder:
 
 
 class BaseSupervised(Base):
-
     """
     Base class for supervised training tasks.
""" diff --git a/promptmeteo/document_classifier.py b/promptmeteo/document_classifier.py index a7fe1ed..d283cc1 100644 --- a/promptmeteo/document_classifier.py +++ b/promptmeteo/document_classifier.py @@ -33,7 +33,6 @@ class DocumentClassifier(BaseSupervised): - """ DocumentClassifier Task diff --git a/promptmeteo/document_qa.py b/promptmeteo/document_qa.py index 5acd159..9a2745d 100644 --- a/promptmeteo/document_qa.py +++ b/promptmeteo/document_qa.py @@ -26,7 +26,6 @@ class DocumentQA(BaseUnsupervised): - """ Question Answering over Documents Task diff --git a/promptmeteo/models/__init__.py b/promptmeteo/models/__init__.py index ec3fd32..59b7bac 100644 --- a/promptmeteo/models/__init__.py +++ b/promptmeteo/models/__init__.py @@ -33,7 +33,6 @@ class ModelProvider(str, Enum): - """ LLM providers currently supported by Promptmeteo """ @@ -47,7 +46,6 @@ class ModelProvider(str, Enum): class ModelFactory: - """ The ModelFactory class is used to create a BaseModel object from the given configuration. @@ -59,7 +57,7 @@ class ModelFactory: ModelProvider.PROVIDER_2: HFHubApiLLM, ModelProvider.PROVIDER_3: HFPipelineLLM, ModelProvider.PROVIDER_3: GoogleVertexAILLM, - ModelProvider.PROVIDER_5: BedrockLLM + ModelProvider.PROVIDER_5: BedrockLLM, } @classmethod @@ -90,7 +88,7 @@ def factory_method( elif model_provider_name == ModelProvider.PROVIDER_4.value: model_cls = GoogleVertexAILLM - + elif model_provider_name == ModelProvider.PROVIDER_5.value: model_cls = BedrockLLM diff --git a/promptmeteo/models/azure_openai.py b/promptmeteo/models/azure_openai.py index 1492f9f..4637da7 100644 --- a/promptmeteo/models/azure_openai.py +++ b/promptmeteo/models/azure_openai.py @@ -33,7 +33,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. """ @@ -54,13 +53,11 @@ def has_value( class ModelEnum(Enum): - """ Model types with their parameters. """ class GPT35TurboInstruct: - """ Default parameters for TextDavinci003 model. """ @@ -76,7 +73,6 @@ class GPT35TurboInstruct: } class GPT35Turbo: - """ Default parameters for GPT35Turbo model. """ @@ -93,7 +89,6 @@ class GPT35Turbo: class AzureOpenAILLM(BaseModel): - """ OpenAI LLM model. """ diff --git a/promptmeteo/models/base.py b/promptmeteo/models/base.py index db1acfa..38413c9 100644 --- a/promptmeteo/models/base.py +++ b/promptmeteo/models/base.py @@ -21,20 +21,21 @@ # THE SOFTWARE. from abc import ABC +from typing import Optional + from langchain.llms.base import BaseLLM from langchain.schema import HumanMessage from langchain.embeddings.base import Embeddings class BaseModel(ABC): - """ Model Interface. """ - def __init__(self): - self._llm: BaseLLM = None - self._embeddings: Embeddings = None + def __init__(self, **kwargs): + self._llm: Optional[BaseLLM] = kwargs.get("llm", None) + self._embeddings: Optional[Embeddings] = kwargs.get("embeddings", None) @property def llm( @@ -59,10 +60,10 @@ def run( """ try: - return self._llm(prompt=sample) + return self.llm(prompt=sample) except TypeError: - return self._llm([HumanMessage(content=sample)]).content + return self.llm([HumanMessage(content=sample)]).content except Exception as error: raise RuntimeError( diff --git a/promptmeteo/models/bedrock.py b/promptmeteo/models/bedrock.py index 6536d13..796fb97 100644 --- a/promptmeteo/models/bedrock.py +++ b/promptmeteo/models/bedrock.py @@ -32,7 +32,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. """ @@ -52,13 +51,11 @@ def has_value( class ModelEnum(Enum): - """ Model Parameters. 
""" class AnthropicClaudeV2: - """ Default parameters for Anthropic Claude V2 """ @@ -67,16 +64,15 @@ class AnthropicClaudeV2: embedding = HuggingFaceEmbeddings model_task: str = "text2text-generation" params: dict = { - 'max_tokens_to_sample': 2048, - 'temperature': 0.3, - 'top_k': 250, - 'top_p': 0.999, - 'stop_sequences': ['Human:'] + "max_tokens_to_sample": 2048, + "temperature": 0.3, + "top_k": 250, + "top_p": 0.999, + "stop_sequences": ["Human:"], } class BedrockLLM(BaseModel): - """ Bedrock LLM model. """ @@ -86,7 +82,7 @@ def __init__( model_name: Optional[str] = "", model_params: Optional[Dict] = None, model_provider_token: Optional[str] = "", - **kwargs + **kwargs, ) -> None: """ Make predictions using a model from OpenAI. @@ -98,7 +94,7 @@ def __init__( f"`model_name`={model_name} not in supported model names: " f"{[i.name for i in ModelTypes]}" ) - self.boto3_bedrock = boto3.client('bedrock-runtime', **kwargs) + self.boto3_bedrock = boto3.client("bedrock-runtime", **kwargs) super(BedrockLLM, self).__init__() # Model name @@ -117,7 +113,7 @@ def __init__( self._llm = ModelEnum[model].value.client( model_id=model_name, model_kwargs=self.model_params, - client = self.boto3_bedrock + client=self.boto3_bedrock, ) embedding_name = "sentence-transformers/all-MiniLM-L6-v2" diff --git a/promptmeteo/models/fake_llm.py b/promptmeteo/models/fake_llm.py index 14fb58b..2a7f8c2 100644 --- a/promptmeteo/models/fake_llm.py +++ b/promptmeteo/models/fake_llm.py @@ -35,7 +35,6 @@ class FakeStaticLLM(LLM): - """ Fake Static LLM wrapper for testing purposes. """ @@ -82,7 +81,6 @@ async def _acall( class FakePromptCopyLLM(LLM): - """ Fake Prompt Copy LLM wrapper for testing purposes. """ @@ -127,7 +125,6 @@ async def _acall( class FakeListLLM(LLM): - """ Fake LLM wrapper for testing purposes. """ @@ -178,7 +175,6 @@ async def _acall( class ModelTypes(Enum): - """ FakeLLM Model Types. """ diff --git a/promptmeteo/models/google_vertexai.py b/promptmeteo/models/google_vertexai.py index e18a941..8a2274a 100644 --- a/promptmeteo/models/google_vertexai.py +++ b/promptmeteo/models/google_vertexai.py @@ -32,7 +32,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. """ @@ -55,13 +54,11 @@ def has_value( class ModelEnum(Enum): - """ Model types with their parameters. """ class TextBison001: - """ Default parameters for text-bison model. """ @@ -70,7 +67,6 @@ class TextBison001: model_kwargs = {"temperature": 0.4, "max_tokens": 256, "max_retries": 3} class TextBison: - """ Default parameters for text-bison model in their latest version """ @@ -79,7 +75,6 @@ class TextBison: model_kwargs = {"temperature": 0.4, "max_tokens": 256, "max_retries": 3} class TextBison32k: - """ Default parameters for text-bison-32 model in their latest version """ @@ -89,7 +84,6 @@ class TextBison32k: class GoogleVertexAILLM(BaseModel): - """ Google VertexAI LLM model. """ @@ -129,4 +123,4 @@ def __init__( # Model Parameters if not model_params: model_params = ModelEnum[model].value - self.model_params = model_params \ No newline at end of file + self.model_params = model_params diff --git a/promptmeteo/models/hf_hub_api.py b/promptmeteo/models/hf_hub_api.py index 4d642be..bfc77f4 100644 --- a/promptmeteo/models/hf_hub_api.py +++ b/promptmeteo/models/hf_hub_api.py @@ -30,7 +30,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. 
""" @@ -50,13 +49,11 @@ def has_value( class ModelEnum(Enum): - """ Model Parameters Enum """ class FlanT5Xxl: - """ Flan-t5-xxl default params """ @@ -66,7 +63,6 @@ class FlanT5Xxl: class HFHubApiLLM(BaseModel): - """ HuggingFace API call. """ diff --git a/promptmeteo/models/hf_pipeline.py b/promptmeteo/models/hf_pipeline.py index 661a728..9da8522 100644 --- a/promptmeteo/models/hf_pipeline.py +++ b/promptmeteo/models/hf_pipeline.py @@ -31,7 +31,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. """ @@ -48,13 +47,11 @@ def has_value(cls, value): class ModelParams(Enum): - """ Model Parameters. """ class MODEL_1: - """ Parameters Model 1. """ @@ -65,7 +62,6 @@ class MODEL_1: class HFPipelineLLM(BaseModel): - """ HuggingFace Local Pipeline. """ diff --git a/promptmeteo/models/openai.py b/promptmeteo/models/openai.py index 1d61626..c87e53f 100644 --- a/promptmeteo/models/openai.py +++ b/promptmeteo/models/openai.py @@ -32,7 +32,6 @@ class ModelTypes(str, Enum): - """ Enum of available model types. """ @@ -53,13 +52,11 @@ def has_value( class ModelEnum(Enum): - """ Model Parameters. """ class GPT35TurboInstruct: - """ Default parameters for TextDavinci003 model. """ @@ -76,7 +73,6 @@ class GPT35TurboInstruct: } class GPT35Turbo: - """ Default parameters for GPT3.5Turbo model. """ @@ -92,7 +88,6 @@ class GPT35Turbo: class OpenAILLM(BaseModel): - """ OpenAI LLM model. """ diff --git a/promptmeteo/parsers/__init__.py b/promptmeteo/parsers/__init__.py index 000af85..8e6b781 100644 --- a/promptmeteo/parsers/__init__.py +++ b/promptmeteo/parsers/__init__.py @@ -31,7 +31,6 @@ class ParserTypes(str, Enum): - """ Enum of availables parsers. """ @@ -47,7 +46,6 @@ class ParserTypes(str, Enum): class ParserFactory: - """ Factory of Parsers. """ @@ -80,10 +78,10 @@ def factory_method( elif task_type == ParserTypes.PARSER_6.value: parser_cls = ApiParser - + elif task_type == ParserTypes.PARSER_7.value: parser_cls = JSONParser - + elif task_type == ParserTypes.PARSER_8.value: parser_cls = DummyParser diff --git a/promptmeteo/parsers/api_parser.py b/promptmeteo/parsers/api_parser.py index fce254d..227e181 100644 --- a/promptmeteo/parsers/api_parser.py +++ b/promptmeteo/parsers/api_parser.py @@ -26,7 +26,6 @@ class ApiParser(BaseParser): - """ Dummy parser, returns what it receives. """ diff --git a/promptmeteo/parsers/classification_parser.py b/promptmeteo/parsers/classification_parser.py index 5628bbd..fd5bb7f 100644 --- a/promptmeteo/parsers/classification_parser.py +++ b/promptmeteo/parsers/classification_parser.py @@ -26,7 +26,6 @@ class ClassificationParser(BaseParser): - """ Parser for the classification task. """ diff --git a/promptmeteo/parsers/dummy_parser.py b/promptmeteo/parsers/dummy_parser.py index ff0482d..e5a4566 100644 --- a/promptmeteo/parsers/dummy_parser.py +++ b/promptmeteo/parsers/dummy_parser.py @@ -26,7 +26,6 @@ class DummyParser(BaseParser): - """ Dummy parser, returns what it receives. """ diff --git a/promptmeteo/parsers/json_parser.py b/promptmeteo/parsers/json_parser.py index c7f3ca9..0c7159c 100644 --- a/promptmeteo/parsers/json_parser.py +++ b/promptmeteo/parsers/json_parser.py @@ -21,14 +21,12 @@ # THE SOFTWARE. 
from typing import List -import re from .base import BaseParser import regex import json class JSONParser(BaseParser): - """ Parser for potential JSON outputs """ @@ -48,7 +46,6 @@ def run( return json_output except: return "" - def _preprocess( self, @@ -59,9 +56,9 @@ def _preprocess( such as end-of-line presence and beginning and finishing with empty space. """ - pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}') + pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}") str_json = pattern.findall(text)[0] - - str_json = str_json.replace("'",'"') + + str_json = str_json.replace("'", '"') return str_json diff --git a/promptmeteo/prompts/__init__.py b/promptmeteo/prompts/__init__.py index a5e3a75..cd120b7 100644 --- a/promptmeteo/prompts/__init__.py +++ b/promptmeteo/prompts/__init__.py @@ -92,7 +92,6 @@ def get_files_taxonomy(sep: str = "_"): class PromptFactory: - """ Factory of Prompts """ diff --git a/promptmeteo/prompts/base.py b/promptmeteo/prompts/base.py index 2f7f155..49e9e44 100644 --- a/promptmeteo/prompts/base.py +++ b/promptmeteo/prompts/base.py @@ -163,21 +163,21 @@ def run( """ prompt_variables = dict( - __PROMPT_SAMPLE__=self.PROMPT_SAMPLE - if hasattr(self, "PROMPT_SAMPLE") - else "", + __PROMPT_SAMPLE__=( + self.PROMPT_SAMPLE if hasattr(self, "PROMPT_SAMPLE") else "" + ), __PROMPT_LABELS__="", __PROMPT_DOMAIN__="", __PROMPT_DETAIL__="", - __SHOT_EXAMPLES__=self.SHOT_EXAMPLES - if hasattr(self, "SHOT_EXAMPLES") - else "", - __ANSWER_FORMAT__=self.ANSWER_FORMAT - if hasattr(self, "ANSWER_FORMAT") - else "", - __CHAIN_THOUGHT__=self.CHAIN_THOUGHT - if hasattr(self, "CHAIN_THOUGHT") - else "", + __SHOT_EXAMPLES__=( + self.SHOT_EXAMPLES if hasattr(self, "SHOT_EXAMPLES") else "" + ), + __ANSWER_FORMAT__=( + self.ANSWER_FORMAT if hasattr(self, "ANSWER_FORMAT") else "" + ), + __CHAIN_THOUGHT__=( + self.CHAIN_THOUGHT if hasattr(self, "CHAIN_THOUGHT") else "" + ), ) # Labels @@ -193,8 +193,9 @@ def run( ) # Domain - prompt_variables["__PROMPT_DOMAIN__"] = self.PROMPT_DOMAIN.format(__DOMAIN__=self._prompt_domain) - + prompt_variables["__PROMPT_DOMAIN__"] = self.PROMPT_DOMAIN.format( + __DOMAIN__=self._prompt_domain + ) # Detail prompt_detail = ( @@ -202,8 +203,9 @@ def run( if isinstance(self._prompt_detail, list) else self._prompt_detail ) - prompt_variables["__PROMPT_DETAIL__"] = self.PROMPT_DETAIL.format(__DETAIL__=prompt_detail) - + prompt_variables["__PROMPT_DETAIL__"] = self.PROMPT_DETAIL.format( + __DETAIL__=prompt_detail + ) return PromptTemplate.from_template( PromptTemplate.from_template(self.TEMPLATE).format( diff --git a/promptmeteo/selector/__init__.py b/promptmeteo/selector/__init__.py index 0f3b47c..dad91d6 100644 --- a/promptmeteo/selector/__init__.py +++ b/promptmeteo/selector/__init__.py @@ -34,7 +34,6 @@ class SelectorTypes(str, Enum): - """ Enum with the avaialable selector types. """ @@ -44,7 +43,6 @@ class SelectorTypes(str, Enum): class SelectorFactory: - """ Factory of Selectors """ @@ -67,12 +65,15 @@ def factory_method( selector_cls = BaseSelectorSupervised elif selector_type == SelectorTypes.UNSUPERVISED.value: - if selector_algorithm == SelectorAlgorithms.SIMILARITY_CLASS_BALANCED.value: + if ( + selector_algorithm + == SelectorAlgorithms.SIMILARITY_CLASS_BALANCED.value + ): raise ValueError( - f"{cls.__name__} error in class method `factory_method`. " - f"Selector algorithm {selector_algorithm} " - f"is only valid for DocumentClassifier models" - ) + f"{cls.__name__} error in class method `factory_method`. 
" + f"Selector algorithm {selector_algorithm} " + f"is only valid for DocumentClassifier models" + ) selector_cls = BaseSelectorUnsupervised else: diff --git a/promptmeteo/selector/base.py b/promptmeteo/selector/base.py index c0445c8..82f5c75 100644 --- a/promptmeteo/selector/base.py +++ b/promptmeteo/selector/base.py @@ -43,7 +43,6 @@ class SelectorAlgorithms(str, Enum): - """ Enum with the avaialable selector algorithms. """ @@ -54,7 +53,6 @@ class SelectorAlgorithms(str, Enum): class BaseSelector(ABC): - """ Base Selector Interface """ diff --git a/promptmeteo/selector/custom_selectors.py b/promptmeteo/selector/custom_selectors.py index f7f82b8..e80d29a 100644 --- a/promptmeteo/selector/custom_selectors.py +++ b/promptmeteo/selector/custom_selectors.py @@ -1,8 +1,8 @@ """Custom selectors""" + from typing import Any, Dict, List, Optional, Type import random from langchain_core.example_selectors.base import BaseExampleSelector -from langchain_core.pydantic_v1 import BaseModel, Extra from langchain_core.pydantic_v1 import BaseModel from langchain_core.vectorstores import VectorStore from langchain_core.embeddings import Embeddings @@ -83,8 +83,7 @@ def from_examples( f"be greater than number of classes ({len(class_list)} classes)" f"for balanced examples selection" ) - - + if input_keys: string_examples = [ " ".join(sorted_values({k: eg[k] for k in input_keys})) @@ -112,12 +111,14 @@ def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: final_examples = [] new_class_list = self.class_list.copy() random.shuffle(new_class_list) - new_class_list =(self.class_list*ceil(self.k/len(set(self.class_list))))[:self.k] - - dict_k_per_class = {i:new_class_list.count(i) for i in new_class_list} - + new_class_list = ( + self.class_list * ceil(self.k / len(set(self.class_list))) + )[: self.k] + + dict_k_per_class = {i: new_class_list.count(i) for i in new_class_list} + # Get the docs with the highest similarity. - for cl,k in dict_k_per_class.items(): + for cl, k in dict_k_per_class.items(): if self.input_keys: input_variables = { key: input_variables[key] for key in self.input_keys diff --git a/promptmeteo/summarizer.py b/promptmeteo/summarizer.py index d3768d6..5e71ba0 100644 --- a/promptmeteo/summarizer.py +++ b/promptmeteo/summarizer.py @@ -1,5 +1,5 @@ -#%% -#!/usr/bin/python3 +# %% +# !/usr/bin/python3 # Copyright (c) 2023 Paradigma Digital S.L. # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -19,26 +19,19 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -import re import tarfile import tempfile import json import os -import yaml -from copy import deepcopy -from typing import List - try: from typing import Self except ImportError: from typing_extensions import Self -from langchain.prompts import PromptTemplate from .base import BaseUnsupervised -from .tasks import TaskTypes, TaskBuilder +from .tasks import TaskTypes from .tools import add_docstring_from -from .validations import version_validation class Summarizer(BaseUnsupervised): @@ -61,7 +54,7 @@ class Summarizer(BaseUnsupervised): """ TASK_TYPE = TaskTypes.SUMMARIZATION.value - + @add_docstring_from(BaseUnsupervised.__init__) def __init__( self, diff --git a/promptmeteo/tasks/task.py b/promptmeteo/tasks/task.py index c1b6be4..59fbd65 100644 --- a/promptmeteo/tasks/task.py +++ b/promptmeteo/tasks/task.py @@ -31,7 +31,6 @@ class Task: - """ Base Task interface. 
""" diff --git a/promptmeteo/tasks/task_builder.py b/promptmeteo/tasks/task_builder.py index 36cec4f..0351386 100644 --- a/promptmeteo/tasks/task_builder.py +++ b/promptmeteo/tasks/task_builder.py @@ -40,7 +40,6 @@ class TaskTypes(str, Enum): - """ Enum with all the available task types """ @@ -54,7 +53,6 @@ class TaskTypes(str, Enum): class TaskBuilder: - """ Builder of Tasks. """ diff --git a/pyproject.toml b/pyproject.toml index 5b55693..079b0fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ [tool.setuptools_scm] [tool.setuptools.packages.find] -where = ["promptmeteo"] +where = ["."] [tool.black] line-length = 80 diff --git a/tests/test_base_supervised.py b/tests/test_base_supervised.py index 01fa15d..b81bbf2 100644 --- a/tests/test_base_supervised.py +++ b/tests/test_base_supervised.py @@ -71,9 +71,9 @@ def test_wrong_predict(self): ).predict([1, 2, 3]) assert error.value.args[0] == ( - f"BaseSupervised error in function `predict()`. " - f"Arguments `examples` are expected to be of type " - f"`List[str]`. Some values seem no to be of type `str`." + "BaseSupervised error in function `predict()`. " + "Arguments `examples` are expected to be of type " + "`List[str]`. Some values seem no to be of type `str`." ) def test_wrong_train(self): diff --git a/tests/test_base_unsupervised.py b/tests/test_base_unsupervised.py index 370cc1f..354e6e7 100644 --- a/tests/test_base_unsupervised.py +++ b/tests/test_base_unsupervised.py @@ -1,10 +1,7 @@ import os -import tempfile import pytest -from promptmeteo.tasks import Task -from promptmeteo.tasks import TaskBuilder from promptmeteo.base import BaseUnsupervised @@ -129,8 +126,8 @@ def test_wrong_train(self): ) assert error.value.args[0] == ( - f"TypeError: BaseUnsupervised.train() got an unexpected " - f"keyword argument 'annotations'" + "TypeError: BaseUnsupervised.train() got an unexpected " + "keyword argument 'annotations'" ) def test_load_model(self):