Skip to content

Commit

Permalink
[Fix] utf-8 codec can't decode
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed May 26, 2024
1 parent 5eaa969 commit 83b47d9
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 8 deletions.
5 changes: 3 additions & 2 deletions utilization/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,15 +766,16 @@ def log_final_results(
return log_final_results(
raw_predictions=raw_predictions,
processed_predictions=processed_predictions,
evaluation_instances=self.evaluation_instances,
score_lists=score_lists,
multiple_source=(self.dataset_name == "winogrande"),
model_evaluation_method=self.model_evaluation_method,
use_normalization=self.use_normalization,
option_nums=self.option_nums,
len_evaluation_data=len(self.evaluation_data),
evaluation_instances=self.evaluation_instances,
sample_num=self.sample_num,
references=self.references,
local_model=self.model.is_local_model(),
)

def __repr__(self):
Expand Down Expand Up @@ -967,7 +968,7 @@ def step(
if batch_size > 0:
tqdm.set_description(self.display_names[self._cur_idx])
if batch_size > 0:
writer.log_batch_results(batch_raw_predictions, self._lines_iter)
writer.log_batch_results(batch_raw_predictions, self._datasets[0].model.is_local_model(), self._lines_iter)

def __repr__(self):
reprs = []
Expand Down
6 changes: 4 additions & 2 deletions utilization/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ def set_generation_args(self, **extra_model_args):
logger.warning(f"Unused generation arguments: {extra_model_args}")
return self.generation_kwargs

def generation(self, batched_inputs: Union[List[str],
List[Conversation]]) -> Union[List[str], List[Tuple[str, ...]]]:
def generation(
self,
batched_inputs: Union[List[str], List[Conversation]],
) -> Union[List[str], List[Tuple[str, ...]]]:
multi_turn_results = self.request(
prompt=batched_inputs,
multi_turn=self.multi_turn,
Expand Down
3 changes: 2 additions & 1 deletion utilization/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,14 +690,15 @@ def parse_argument(args: Optional[List[str]] = None,
epilog=EXAMPLE_STRING,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
model_args, dataset_args, evaluation_args = parser.parse_args_into_dataclasses(args)

try:
from dotenv import load_dotenv
load_dotenv()
except (ImportError, ModuleNotFoundError):
pass

model_args, dataset_args, evaluation_args = parser.parse_args_into_dataclasses(args)

if model_args.bnb_config:
bnb_config_dict = json.loads(model_args.bnb_config)
model_args.bnb_config = BitsAndBytesConfig(**bnb_config_dict)
Expand Down
3 changes: 3 additions & 0 deletions utilization/utils/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def to_model_prompt(
max_turns=max_turns,
)[0]

def apply_prompt_template(self):
    """Render this conversation through its formatter's prompt template."""
    formatter = self.formatter
    return formatter.apply_prompt_template(self)

def add(
self,
other: Optional["Conversation"] = None,
Expand Down
29 changes: 26 additions & 3 deletions utilization/utils/log_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import pandas as pd

from .conversation import Conversation

logger = getLogger(__name__)

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -35,6 +37,22 @@ def wrapper(df: pd.DataFrame):
return wrapper


def dump_conversations(convs: List[Any], local: bool) -> List[str]:
    """Normalize conversations/predictions into plain strings for JSON logging.

    Accepts a single ``str`` or ``Conversation`` (wrapped into a one-element
    list) or a list of either. ``Conversation`` items are serialized: for
    non-local (API) models each conversation becomes the list of its message
    contents; for local models the conversation's prompt template is applied.

    Args:
        convs: A conversation, a string, or a list of either.
        local: True when the producing model is a local model — selects how
            conversations are rendered.

    Returns:
        A list of strings (an empty input yields an empty list).
    """
    if isinstance(convs, (str, Conversation)):
        convs = [convs]

    # Guard the empty batch: the original indexed convs[0] unconditionally
    # and raised IndexError on an empty list.
    if not convs:
        return []

    if isinstance(convs[0], Conversation):
        if not local:
            # API models: keep the raw per-message contents.
            convs = [[str(m['content']) for m in p.messages] for p in convs]
        else:
            # Local models: render via the conversation's prompt template.
            convs = [p.apply_prompt_template() for p in convs]

    # Coerce any remaining non-string entries (including the nested
    # message-content lists built above) into strings.
    if not isinstance(convs[0], str):
        convs = [str(p) for p in convs]

    return convs


class PredictionWriter:

def __init__(self, evaluation_path: Optional[str]):
Expand Down Expand Up @@ -83,14 +101,15 @@ def _write(self, data):
def log_batch_results(
self,
raw_predictions: List[str],
local_model: bool,
lines_iter: Iterator[Tuple[int, str, Any]],
) -> int:
"""Log the batch predictions to the evaluation jsonlines file."""
if not self.alive():
return len(raw_predictions)

for raw_prediction, (idx, source, reference) in zip(raw_predictions, lines_iter):
if not isinstance(source, str):
source = str(source)
source = dump_conversations(source, local_model)
lines = {
"index": idx,
"source": source,
Expand Down Expand Up @@ -140,16 +159,20 @@ def load_continue(self) -> Iterator[typing.Any]:
def log_final_results(
raw_predictions: List[str],
processed_predictions: List[Union[str, float]],
evaluation_instances: List[tuple],
score_lists: Dict[str, List[float]],
multiple_source: bool,
model_evaluation_method: str,
use_normalization: bool,
option_nums: List[int],
len_evaluation_data: int,
evaluation_instances: List[tuple],
sample_num: int,
references: List[Any],
local_model: bool,
) -> Optional[pd.Series]:
"""Aggregate the final results and prepare for dumping to a json file."""

evaluation_instances = dump_conversations(evaluation_instances, local_model)

transposed_score_lists = [dict(zip(score_lists.keys(), values)) for values in zip(*score_lists.values())]
if model_evaluation_method == "generation":
Expand Down

0 comments on commit 83b47d9

Please sign in to comment.