IBM watsonx_llm fixes & refactor (#2464)
* refactor code, fix config path bug

* update types to be from typing lib

* add pre-commit formatting

* specify version of ibm_watsonx_ai package

* adjust get_watsonx_credentials() function, add minor refactor to address PR review comments

* change the missing-installation hint from ibm_watsonx_ai to lm_eval[ibm_watsonx_ai]
Medokins authored Nov 15, 2024
1 parent 67db63a commit 4259a6d
Showing 2 changed files with 105 additions and 140 deletions.
243 changes: 104 additions & 139 deletions lm_eval/models/ibm_watsonx_ai.py
@@ -1,8 +1,6 @@
-import json
+import copy
 import os
-from configparser import ConfigParser
 from functools import lru_cache
-from pathlib import Path
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
 
 from tqdm import tqdm
@@ -18,57 +16,46 @@ class LogLikelihoodResult(NamedTuple):
     is_greedy: bool
 
 
-@lru_cache(maxsize=None)
-def get_watsonx_credentials(
-    env_name: str = "YP_QA",
-    config_path: str = "config.ini",
-) -> Dict[str, str]:
+def _verify_credentials(creds: Any) -> None:
     """
-    Retrieves Watsonx API credentials from environmental variables or from a configuration file.
+    Verifies that all required keys are present in the credentials dictionary.
     Args:
-        env_name (str, optional): The name of the environment from which to retrieve credentials. Defaults to "YP_QA".
-        config_path (str, optional): The file path to the `config.ini` configuration file. Defaults to "config.ini".
+        creds (Any): A dictionary containing the credentials.
+    Raises:
+        ValueError: If any of the necessary credentials are missing, with guidance on which environment variables need to be set.
+    """
+    required_keys = ["apikey", "url", "project_id"]
+    env_var_mapping = {
+        "apikey": "WATSONX_API_KEY",
+        "url": "WATSONX_URL",
+        "project_id": "WATSONX_PROJECT_ID",
+    }
+    missing_keys = [key for key in required_keys if key not in creds or not creds[key]]
+
+    if missing_keys:
+        missing_env_vars = [env_var_mapping[key] for key in missing_keys]
+        raise ValueError(
+            f"Missing required credentials: {', '.join(missing_keys)}. Please set the following environment variables: {', '.join(missing_env_vars)}"
+        )
+
+
+@lru_cache(maxsize=None)
+def get_watsonx_credentials() -> Dict[str, str]:
+    """
+    Retrieves Watsonx API credentials from environmental variables.
     Returns:
-        dict[str, str]: A dictionary containing the credentials necessary for authentication, including
+        Dict[str, str]: A dictionary containing the credentials necessary for authentication, including
             keys such as `apikey`, `url`, and `project_id`.
     Raises:
-        FileNotFoundError: If the specified configuration file does not exist.
-        AssertionError: If the credentials format is invalid.
+        AssertionError: If the credentials format is invalid or any of the necessary credentials are missing.
     """
-
-    def _verify_credentials(creds: Any) -> None:
-        assert isinstance(creds, dict) and all(
-            key in creds.keys() for key in ["apikey", "url", "project_id"]
-        ), "Wrong configuration for credentials."
-
     credentials = {
         "apikey": os.getenv("WATSONX_API_KEY", None),
         "url": os.getenv("WATSONX_URL", None),
         "project_id": os.getenv("WATSONX_PROJECT_ID", None),
     }
-
-    if any(credentials.get(key) is None for key in ["apikey", "url", "project_id"]):
-        eval_logger.warning(
-            "One or more required environment variables are missing, trying to load config.ini file."
-        )
-
-        config_path = "config.ini" if not config_path else config_path
-
-        if not Path(config_path).is_absolute():
-            config_path = os.path.join(
-                Path(__file__).parent.parent.absolute(), config_path
-            )
-
-        if not os.path.exists(config_path):
-            raise FileNotFoundError(
-                f"Provided config file path {config_path} does not exist. "
-                "You need to specify credentials in config.ini file under specified location."
-            )
-
-        config = ConfigParser()
-        config.read(config_path)
-        credentials = json.loads(config.get(env_name))
-
     _verify_credentials(credentials)
     return credentials
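
With the config.ini fallback removed, credentials now come only from the environment. A minimal sketch of the new flow (not part of the commit; values are placeholders):

import os

os.environ["WATSONX_API_KEY"] = "<your-api-key>"
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # illustrative endpoint
os.environ["WATSONX_PROJECT_ID"] = "<your-project-id>"

creds = get_watsonx_credentials()  # validated by _verify_credentials, memoized by lru_cache

Because of lru_cache, the environment is read once per process, and a missing variable now fails fast with a ValueError naming the exact variables to set.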

@@ -84,7 +71,7 @@ class WatsonxLLM(LM):
     def create_from_arg_string(
         cls: Type["WatsonxLLM"],
         arg_string: str,
-        config_path: Optional[str] = None,
+        additional_config: Optional[Dict] = None,
     ) -> "WatsonxLLM":
         """
         Allow the user to specify model parameters (TextGenerationParameters) in CLI arguments.
@@ -97,6 +84,8 @@ def create_from_arg_string(
         )
 
         args = simple_parse_args_string(arg_string)
+        args.update(additional_config)
+
         model_id = args.pop("model_id", None)
         if model_id is None:
             raise ValueError("'model_id' is required, please pass it in 'model_args'")
@@ -107,7 +96,7 @@ def create_from_arg_string(
             args["top_k"] = None
             args["seed"] = None
 
-        cls.generate_params = {
+        generate_params = {
             GenParams.DECODING_METHOD: (
                 "greedy" if not args.get("do_sample", None) else "sample"
             ),
@@ -130,12 +119,10 @@ def create_from_arg_string(
             },
         }
 
-        generate_params = {
-            k: v for k, v in cls.generate_params.items() if v is not None
-        }
+        generate_params = {k: v for k, v in generate_params.items() if v is not None}
 
         return cls(
-            watsonx_credentials=get_watsonx_credentials(config_path),
+            watsonx_credentials=get_watsonx_credentials(),
            model_id=model_id,
             generate_params=generate_params,
         )
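
Building generate_params as a local, rather than assigning cls.generate_params, also stops one construction from leaking parameters onto the class for every other instance. For illustration, this constructor path can be exercised directly; a hedged sketch (the model_id is a placeholder, and additional_config must be a dict because it is merged via args.update(...), which is what the evaluator normally passes):

from lm_eval.models.ibm_watsonx_ai import WatsonxLLM

lm = WatsonxLLM.create_from_arg_string(
    "model_id=ibm/granite-13b-instruct-v2,do_sample=False",  # parsed by simple_parse_args_string
    additional_config={},  # merged into the parsed args before model_id is popped
)
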
@@ -158,7 +145,7 @@ def __init__(
         project_id = watsonx_credentials.get("project_id", None)
         deployment_id = watsonx_credentials.get("deployment_id", None)
         client.set.default_project(project_id)
-        self.generate_params = generate_params or {}
+        self.generate_params = generate_params
         self.model = ModelInference(
             model_id=model_id,
             deployment_id=deployment_id,
@@ -220,9 +207,9 @@ def _get_log_likelihood(
         """
         Calculates the log likelihood of the generated tokens compared to the context tokens.
         Args:
-            input_tokens (List[dict[str, float]]): A List of token dictionaries, each containing
+            input_tokens (List[Dict[str, float]]): A List of token dictionaries, each containing
                 token information like `text` and `logprob`.
-            context_tokens (List[dict[str, float]]): A List of token dictionaries representing
+            context_tokens (List[Dict[str, float]]): A List of token dictionaries representing
                 the input context.
         Returns:
             LogLikelihoodResult: An object containing the calculated log likelihood and a boolean
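
The body of _get_log_likelihood is not shown in this diff. As a rough, assumption-laden sketch of the computation the docstring describes (summing continuation-token logprobs; the "rank" field is an assumption, not a confirmed part of the watsonx token payload):

def _get_log_likelihood_sketch(input_tokens, context_tokens):
    # Tokens beyond the tokenized context belong to the continuation.
    continuation_tokens = input_tokens[len(context_tokens):]
    # The continuation's log likelihood is the sum of its per-token logprobs.
    log_likelihood = sum(token["logprob"] for token in continuation_tokens)
    # Greedy iff every continuation token was the model's top-1 choice.
    is_greedy = all(token.get("rank", 1) == 1 for token in continuation_tokens)
    return LogLikelihoodResult(log_likelihood, is_greedy)
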
@@ -252,27 +239,24 @@ def generate_until(self, requests: List[Instance]) -> List[str]:
         Returns:
             List[str]: A List of generated responses.
         """
-        requests = [request.args[0] for request in requests]
+        requests = [request.args for request in requests]
         results = []
-        batch_size = 5
 
-        for i in tqdm(
-            range(0, len(requests), batch_size),
-            desc=f"Running generate_until function with batch size {batch_size}",
+        for request in tqdm(
+            requests,
+            desc="Running generate_until function ...",
         ):
-            batch = requests[i : i + batch_size]
+            context, continuation = request
             try:
-                responses = self.model.generate_text(batch, self.generate_params)
-
+                response = self.model.generate_text(context, self.generate_params)
             except Exception as exp:
-                eval_logger.error(f"Error while generating text {exp}")
-                continue
-
-            for response, context in zip(responses, batch):
-                results.append(response)
-                self.cache_hook.add_partial("generated_text", context, response)
+                eval_logger.error("Error while generating text.")
+                raise exp
 
-            eval_logger.info("Cached responses")
+            results.append(response)
+            self.cache_hook.add_partial(
+                "generate_until", (context, continuation), response
+            )
 
         return results
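
Two things changed here besides dropping the batch loop: failures now raise instead of being silently skipped, so the results list can no longer drift out of alignment with the request list, and the cache key moved from an ad-hoc "generated_text" entry to the standard ("generate_until", (context, continuation)) pair. For reference, a generate_until request's args is a (context, gen_kwargs) tuple in current harness versions, roughly:

example_args = ("Question: 2 + 2 = ?\nAnswer:", {"until": ["\n"]})  # illustrative shape
context, continuation = example_args  # "continuation" here is actually the gen_kwargs dict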

@@ -284,7 +268,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             2. a target string on which the loglikelihood of the LM producing this target,
             conditioned on the input, will be returned.
         Returns:
-            tuple (loglikelihood, is_greedy) for each request according to the input order:
+            Tuple (loglikelihood, is_greedy) for each request according to the input order:
                 loglikelihood: probability of generating the target string conditioned on the input
                 is_greedy: True if and only if the target string would be generated by greedy sampling from the LM
         """
@@ -295,65 +279,59 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
                 "Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
             )
         self._check_model_logprobs_support()
-        self.generate_params[GenParams.MAX_NEW_TOKENS] = 1
+        generate_params = copy.copy(self.generate_params)
+        generate_params[GenParams.MAX_NEW_TOKENS] = 1
 
         requests = [request.args for request in requests]
         results: List[LogLikelihoodResult] = []
-        batch_size = 5
 
-        for i in tqdm(
-            range(0, len(requests), batch_size),
-            desc=f"Running loglikelihood function with batch size {batch_size}",
+        # Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
+        for request in tqdm(
+            requests,
+            desc="Running loglikelihood function ...",
         ):
-            batch = requests[i : i + batch_size]
+            context, continuation = request
             try:
-                tokenized_contexts = [
-                    self.model.tokenize(prompt=context, return_tokens=True)["result"][
-                        "tokens"
-                    ]
-                    for context, _ in batch
-                ]
+                tokenized_context = self.model.tokenize(
+                    prompt=context, return_tokens=True
+                )["result"]["tokens"]
             except Exception as exp:
-                eval_logger.error(f"Error while model tokenize:\n {exp}")
-                continue
+                eval_logger.error("Error while model tokenize.")
+                raise exp
 
-            input_prompts = [context + continuation for context, continuation in batch]
+            input_prompt = context + continuation
 
             try:
-                responses = self.model.generate_text(
-                    prompt=input_prompts, params=self.generate_params, raw_response=True
+                response = self.model.generate_text(
+                    prompt=input_prompt, params=generate_params, raw_response=True
                 )
             except Exception as exp:
-                eval_logger.error(f"Error while model generate text:\n {exp}")
-                continue
-
-            for response, tokenized_context, (context, continuation) in zip(
-                responses, tokenized_contexts, batch
-            ):
-                log_likelihood_response = self._get_log_likelihood(
-                    response["results"][0]["input_tokens"], tokenized_context
-                )
-                results.append(log_likelihood_response)
-                self.cache_hook.add_partial(
-                    "loglikelihood",
-                    (context, continuation),
-                    (
-                        log_likelihood_response.log_likelihood,
-                        log_likelihood_response.is_greedy,
-                    ),
-                )
-            eval_logger.info("Cached batch")
+                eval_logger.error("Error while model generate text.")
+                raise exp
+
+            log_likelihood_response = self._get_log_likelihood(
+                response["results"][0]["input_tokens"], tokenized_context
+            )
+            results.append(log_likelihood_response)
+            self.cache_hook.add_partial(
+                "loglikelihood",
+                (context, continuation),
+                (
+                    log_likelihood_response.log_likelihood,
+                    log_likelihood_response.is_greedy,
+                ),
+            )
 
         return cast(List[Tuple[float, bool]], results)
 
     def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
         """
         Used to evaluate perplexity on a data distribution.
         Args:
-            requests: Each request contains Instance.args : tuple[str] containing an input string to the model whose
+            requests: Each request contains Instance.args : Tuple[str] containing an input string to the model whose
                 entire loglikelihood, conditioned on purely the EOT token, will be calculated.
         Returns:
-            tuple (loglikelihood,) for each request according to the input order:
+            Tuple (loglikelihood,) for each request according to the input order:
                 loglikelihood: solely the probability of producing each piece of text given no starting input.
         """
         try:
@@ -363,47 +341,34 @@ def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
                 "Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
             )
         self._check_model_logprobs_support()
-        self.generate_params[GenParams.MAX_NEW_TOKENS] = 1
+        generate_params = copy.deepcopy(self.generate_params)
+        generate_params[GenParams.MAX_NEW_TOKENS] = 1
 
-        requests = [request.args[0] for request in requests]
+        requests = [request.args for request in requests]
         results: List[LogLikelihoodResult] = []
-        batch_size = 5
 
-        for i in tqdm(
-            range(0, len(requests), batch_size),
-            desc=f"Running loglikelihood_rolling function with batch size {batch_size}",
+        # Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
+        for request in tqdm(
+            requests,
+            desc="Running loglikelihood_rolling function ...",
        ):
-            batch = requests[i : i + batch_size]
-
+            context, continuation = request
             try:
-                responses = self.model.generate_text(
-                    prompt=batch, params=self.generate_params, raw_response=True
+                response = self.model.generate_text(
+                    prompt=context, params=generate_params, raw_response=True
                 )
             except Exception as exp:
-                eval_logger.error(f"Error while model generate text:\n {exp}")
-                continue
-
-            for response, context in zip(responses, batch):
-                try:
-                    log_likelihood_response = self._get_log_likelihood(
-                        response["results"][0]["input_tokens"], []
-                    )
-                    results.append(log_likelihood_response)
-
-                    self.cache_hook.add_partial(
-                        "loglikelihood_rolling",
-                        context,
-                        (
-                            log_likelihood_response.log_likelihood,
-                            log_likelihood_response.is_greedy,
-                        ),
-                    )
-                except Exception as exp:
-                    eval_logger.error(
-                        f"Error during log likelihood calculation:\n {exp}"
-                    )
-                    continue
-
-            eval_logger.info("Cached batch")
+                eval_logger.error("Error while model generate text.")
+                raise exp
+
+            log_likelihood_response = self._get_log_likelihood(
+                response["results"][0]["input_tokens"], []
+            )
+            results.append(log_likelihood_response)
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling",
+                (context, continuation),
+                log_likelihood_response.log_likelihood,
+            )
 
         return cast(List[Tuple[float, bool]], results)
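
Both loglikelihood methods now write MAX_NEW_TOKENS into a copy instead of mutating self.generate_params in place, which previously let one call's override leak into every later generation. A small illustrative sketch of the difference between the two copies used:

params = copy.copy(self.generate_params)    # shallow: top-level keys detached
params[GenParams.MAX_NEW_TOKENS] = 1        # leaves self.generate_params untouched
deep = copy.deepcopy(self.generate_params)  # also detaches nested values, e.g. a length-penalty mapping, if present
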
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -62,7 +62,7 @@ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
 deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
 gptq = ["auto-gptq[triton]>=0.6.0"]
 hf_transfer = ["hf_transfer"]
-ibm_watsonx_ai = ["ibm_watsonx_ai"]
+ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22"]
 ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
 neuronx = ["optimum[neuronx]"]
 mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
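
With this pin in place, installing the extra resolves a compatible SDK: pip install "lm_eval[ibm_watsonx_ai]" now pulls ibm_watsonx_ai 1.1.22 or newer, matching the missing-installation hint the model file prints.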
