Commit 83b47d9

[Fix] utf-8 codec cant decode
1 parent 5eaa969 commit 83b47d9

5 files changed: +38 -8 lines changed

utilization/dataset/dataset.py

Lines changed: 3 additions & 2 deletions
@@ -766,15 +766,16 @@ def log_final_results(
         return log_final_results(
             raw_predictions=raw_predictions,
             processed_predictions=processed_predictions,
+            evaluation_instances=self.evaluation_instances,
             score_lists=score_lists,
             multiple_source=(self.dataset_name == "winogrande"),
             model_evaluation_method=self.model_evaluation_method,
             use_normalization=self.use_normalization,
             option_nums=self.option_nums,
             len_evaluation_data=len(self.evaluation_data),
-            evaluation_instances=self.evaluation_instances,
             sample_num=self.sample_num,
             references=self.references,
+            local_model=self.model.is_local_model(),
         )
 
     def __repr__(self):
@@ -967,7 +968,7 @@ def step(
         if batch_size > 0:
             tqdm.set_description(self.display_names[self._cur_idx])
         if batch_size > 0:
-            writer.log_batch_results(batch_raw_predictions, self._lines_iter)
+            writer.log_batch_results(batch_raw_predictions, self._datasets[0].model.is_local_model(), self._lines_iter)
 
     def __repr__(self):
         reprs = []

utilization/model/model.py

Lines changed: 4 additions & 2 deletions
@@ -248,8 +248,10 @@ def set_generation_args(self, **extra_model_args):
             logger.warning(f"Unused generation arguments: {extra_model_args}")
         return self.generation_kwargs
 
-    def generation(self, batched_inputs: Union[List[str],
-                                                List[Conversation]]) -> Union[List[str], List[Tuple[str, ...]]]:
+    def generation(
+        self,
+        batched_inputs: Union[List[str], List[Conversation]],
+    ) -> Union[List[str], List[Tuple[str, ...]]]:
         multi_turn_results = self.request(
             prompt=batched_inputs,
             multi_turn=self.multi_turn,

utilization/utils/arguments.py

Lines changed: 2 additions & 1 deletion
@@ -690,14 +690,15 @@ def parse_argument(args: Optional[List[str]] = None,
        epilog=EXAMPLE_STRING,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
-   model_args, dataset_args, evaluation_args = parser.parse_args_into_dataclasses(args)
 
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except (ImportError, ModuleNotFoundError):
        pass
 
+   model_args, dataset_args, evaluation_args = parser.parse_args_into_dataclasses(args)
+
    if model_args.bnb_config:
        bnb_config_dict = json.loads(model_args.bnb_config)
        model_args.bnb_config = BitsAndBytesConfig(**bnb_config_dict)
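
The reordering moves load_dotenv() ahead of argument parsing. A minimal, self-contained sketch of why the order matters, assuming a transformers.HfArgumentParser-style parser (as suggested by parse_args_into_dataclasses); the dataclass, field name, and environment variable below are illustrative stand-ins, not the repository's actual definitions. Anything the dataclasses read from os.environ during instantiation is only visible if the .env file has already been loaded:

import os
from dataclasses import dataclass

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # hypothetical env-backed field; the real argument classes live in utilization/utils/arguments.py
    openai_api_key: str = ""

    def __post_init__(self):
        # runs when parse_args_into_dataclasses() instantiates the dataclass,
        # so os.environ must already contain the values from .env
        if not self.openai_api_key:
            self.openai_api_key = os.environ.get("OPENAI_API_KEY", "")


try:
    from dotenv import load_dotenv
    load_dotenv()  # populate os.environ from .env before the dataclasses are instantiated
except (ImportError, ModuleNotFoundError):
    pass

(model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses(args=[])
print(model_args.openai_api_key)  # reflects the value from .env, if one is set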

utilization/utils/conversation.py

Lines changed: 3 additions & 0 deletions
@@ -331,6 +331,9 @@ def to_model_prompt(
             max_turns=max_turns,
         )[0]
 
+    def apply_prompt_template(self):
+        return self.formatter.apply_prompt_template(self)
+
     def add(
         self,
         other: Optional["Conversation"] = None,

utilization/utils/log_results.py

Lines changed: 26 additions & 3 deletions
@@ -6,6 +6,8 @@
 
 import pandas as pd
 
+from .conversation import Conversation
+
 logger = getLogger(__name__)
 
 if typing.TYPE_CHECKING:
@@ -35,6 +37,22 @@ def wrapper(df: pd.DataFrame):
     return wrapper
 
 
+def dump_conversations(convs: List[Any], local: bool):
+    if isinstance(convs, (str, Conversation)):
+        convs = [convs]
+
+    if isinstance(convs[0], Conversation):
+        if not local:
+            convs = [[str(m['content']) for m in p.messages] for p in convs]
+        else:
+            convs = [p.apply_prompt_template() for p in convs]
+
+    if not isinstance(convs[0], str):
+        convs = [str(p) for p in convs]
+
+    return convs
+
+
 class PredictionWriter:
 
     def __init__(self, evaluation_path: Optional[str]):
@@ -83,14 +101,15 @@ def _write(self, data):
     def log_batch_results(
         self,
         raw_predictions: List[str],
+        local_model: bool,
         lines_iter: Iterator[Tuple[int, str, Any]],
     ) -> int:
+        """Log the batch predictions to the evaluation jsonlines file."""
         if not self.alive():
             return len(raw_predictions)
 
         for raw_prediction, (idx, source, reference) in zip(raw_predictions, lines_iter):
-            if not isinstance(source, str):
-                source = str(source)
+            source = dump_conversations(source, local_model)
             lines = {
                 "index": idx,
                 "source": source,
@@ -140,16 +159,20 @@ def load_continue(self) -> Iterator[typing.Any]:
 def log_final_results(
     raw_predictions: List[str],
     processed_predictions: List[Union[str, float]],
+    evaluation_instances: List[tuple],
     score_lists: Dict[str, List[float]],
     multiple_source: bool,
     model_evaluation_method: str,
     use_normalization: bool,
     option_nums: List[int],
     len_evaluation_data: int,
-    evaluation_instances: List[tuple],
     sample_num: int,
     references: List[Any],
+    local_model: bool,
 ) -> Optional[pd.Series]:
+    """Aggregate the final results and prepare for dumping to a json file."""
+
+    evaluation_instances = dump_conversations(evaluation_instances, local_model)
 
     transposed_score_lists = [dict(zip(score_lists.keys(), values)) for values in zip(*score_lists.values())]
     if model_evaluation_method == "generation":
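
The dump_conversations helper is the core of this change: non-local (API) models have their Conversation sources logged as the raw message contents, while local models log the fully templated prompt string. A minimal, self-contained sketch of the two paths; SimpleConversation is only a stand-in for the real utilization Conversation, providing just the messages attribute and apply_prompt_template method the helper relies on.

from typing import Any, List


class SimpleConversation:
    """Stand-in for utilization.utils.conversation.Conversation."""

    def __init__(self, messages: List[dict]):
        self.messages = messages

    def apply_prompt_template(self) -> str:
        # stand-in for the chat-template formatting a local model would use
        return "".join(f"<|{m['role']}|>{m['content']}" for m in self.messages)


def dump_conversations(convs: Any, local: bool) -> List[str]:
    # mirrors the helper added above, with SimpleConversation in place of Conversation
    if isinstance(convs, (str, SimpleConversation)):
        convs = [convs]
    if isinstance(convs[0], SimpleConversation):
        if not local:
            # API models: keep only the raw message contents
            convs = [[str(m["content"]) for m in p.messages] for p in convs]
        else:
            # local models: dump the fully templated prompt
            convs = [p.apply_prompt_template() for p in convs]
    if not isinstance(convs[0], str):
        convs = [str(p) for p in convs]
    return convs


conv = SimpleConversation([{"role": "user", "content": "2 + 2 = ?"}])
print(dump_conversations(conv, local=False))  # ["['2 + 2 = ?']"]  (stringified list of message contents)
print(dump_conversations(conv, local=True))   # ['<|user|>2 + 2 = ?']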
