From 1a5fb7d906edc0dffd8efc3528266734b8de7070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Mon, 26 Aug 2024 18:00:43 +0200 Subject: [PATCH 01/15] Stashing changes to logging --- src/lighteval/logging/evaluation_tracker.py | 332 ++++++++------------ src/lighteval/utils/io.py | 50 +++ tests/logging/test_evaluation_tracker.py | 138 ++++++++ 3 files changed, 319 insertions(+), 201 deletions(-) create mode 100644 src/lighteval/utils/io.py create mode 100644 tests/logging/test_evaluation_tracker.py diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index ebae02b6..e4395bf5 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -28,12 +28,12 @@ from dataclasses import asdict, is_dataclass from datetime import datetime from enum import Enum -from pathlib import Path +from tempfile import TemporaryDirectory import torch from datasets import Dataset, load_dataset from datasets.utils.metadata import MetadataConfigs -from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HFSummaryWriter, hf_hub_url +from huggingface_hub import DatasetCard, DatasetCardData, HfApi, hf_hub_url from lighteval.logging.hierarchical_logger import hlog, hlog_warn from lighteval.logging.info_loggers import ( @@ -44,6 +44,7 @@ VersionsLogger, ) from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available +from lighteval.utils.io import FsspecDataResource from lighteval.utils.utils import obj_to_markdown @@ -94,13 +95,10 @@ class EvaluationTracker: def __init__( self, - output_dir: str = None, - hub_results_org: str = "", - push_results_to_hub: bool = False, - push_details_to_hub: bool = False, - push_results_to_tensorboard: bool = False, - tensorboard_metric_prefix: str = "eval", - public: bool = False, + output_dir: str, + save_results: bool, + save_details: bool, + save_tensorboard: bool, token: str = "", nanotron_run_info: "GeneralArgs" = None, ) -> None: @@ -131,48 +129,31 @@ def __init__( self.api = HfApi(token=token) - self.output_dir = output_dir - - self.hub_results_org = hub_results_org # will also contain tensorboard results - if hub_results_org in ["", None] and any( - [push_details_to_hub, push_results_to_hub, push_results_to_tensorboard] - ): - raise Exception( - "You need to select which org to push to, using `--results_org`, if you want to save information to the hub." 
- ) - - self.hub_results_repo = f"{hub_results_org}/results" - self.hub_private_results_repo = f"{hub_results_org}/private-results" - self.push_results_to_hub = push_results_to_hub - self.push_details_to_hub = push_details_to_hub - - self.push_results_to_tensorboard = push_results_to_tensorboard - self.tensorboard_repo = f"{hub_results_org}/tensorboard_logs" - self.tensorboard_metric_prefix = tensorboard_metric_prefix + self.output_res = FsspecDataResource.from_uri(output_dir) + self.save_results = save_results + self.save_details = save_details + self.save_tensorboard = save_tensorboard self.nanotron_run_info = nanotron_run_info - self.public = public - def save(self) -> None: - """Saves the experiment information and results to files, and to the hub if requested.""" - hlog("Saving experiment tracker") - date_id = datetime.now().isoformat().replace(":", "-") - - output_dir_results = Path(self.output_dir) / "results" / self.general_config_logger.model_name - output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name - output_dir_details_sub_folder = output_dir_details / date_id - output_dir_results.mkdir(parents=True, exist_ok=True) - output_dir_details_sub_folder.mkdir(parents=True, exist_ok=True) + """Saves the experiment information and results to files, and to the hub if requested. + Note: + In case of save failure, this function will only print a warning, with the error message. - output_results_file = output_dir_results / f"results_{date_id}.json" - output_results_in_details_file = output_dir_details / f"results_{date_id}.json" - - hlog(f"Saving results to {output_results_file} and {output_results_in_details_file}") + Args: + output_dir (str): Local folder path where you want results to be saved. + save_results (bool): If True, results are saved to the specified logging URI. + save_details (bool): If True, detailed results are saved to the specified logging URI. + save_tensorboard (bool, optional): If True, tensorboard logs are saved to the specified logging URI. Defaults to False. 
+ """ + hlog("Saving experiment tracker") + date_id = datetime.now().isoformat().replace(":", "-") config_general = copy.deepcopy(self.general_config_logger) + config_general.config = ( + config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config + ) config_general = asdict(config_general) - # We remove the config from logging, which contains context/accelerator objects - config_general.pop("config") to_dump = { "config_general": config_general, @@ -182,45 +163,44 @@ def save(self) -> None: "summary_tasks": self.details_logger.compiled_details, "summary_general": asdict(self.details_logger.compiled_details_over_all_tasks), } - dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2) - - with open(output_results_file, "w") as f: - f.write(dumped) - - with open(output_results_in_details_file, "w") as f: - f.write(dumped) - - for task_name, task_details in self.details_logger.details.items(): - output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet" - # Create a dataset from the dictionary - we force cast to str to avoid formatting problems for nested objects - dataset = Dataset.from_list([{k: str(v) for k, v in asdict(detail).items()} for detail in task_details]) - - # We don't keep 'id' around if it's there - column_names = dataset.column_names - if "id" in dataset.column_names: - column_names = [t for t in dataset.column_names if t != "id"] - - # Sort column names to make it easier later - dataset = dataset.select_columns(sorted(column_names)) - # Save the dataset to a Parquet file - dataset.to_parquet(output_file_details.as_posix()) - - if self.push_results_to_hub: - self.api.upload_folder( - repo_id=self.hub_results_repo if self.public else self.hub_private_results_repo, - folder_path=output_dir_results, - path_in_repo=self.general_config_logger.model_name, - repo_type="dataset", - commit_message=f"Updating model {self.general_config_logger.model_name}", - ) + dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False) + + if self.save_results: + output_results_res = self.output_res / "results" / self.general_config_logger.model_name + hlog(f"Saving results to {output_results_res}") + output_results_res.fs.mkdirs(output_results_res.path, exist_ok=True) + output_results_file = output_results_res / f"results_{date_id}.json" + with output_results_file.fs.open(output_results_file.path, "w") as f: + f.write(dumped) + + if self.save_details: + output_details_res = self.output_res / "details" / self.general_config_logger.model_name + hlog(f"Saving details to {output_details_res}") + output_details_res.fs.mkdirs(output_details_res.path, exist_ok=True) + + for task_name, task_details in self.details_logger.details.items(): + output_task_details_file = output_details_res / f"details_{task_name}_{date_id}.parquet" + # Create a dataset from the dictionary + try: + dataset = Dataset.from_list([asdict(detail) for detail in task_details]) + except Exception: + # We force cast to str to avoid formatting problems for nested objects + dataset = Dataset.from_list( + [{k: str(v) for k, v in asdict(detail).items()} for detail in task_details] + ) - if self.push_details_to_hub: - self.details_to_hub( - results_file_path=output_results_in_details_file, - details_folder_path=output_dir_details_sub_folder, - ) + # We don't keep 'id' around if it's there + column_names = dataset.column_names + if "id" in dataset.column_names: + column_names = [t for t in dataset.column_names if t != "id"] + + # Sort column 
names to make it easier later + dataset = dataset.select_columns(sorted(column_names)) + # Save the dataset to a Parquet file + with output_task_details_file.fs.open(output_task_details_file.path, "wb") as f: + dataset.to_parquet(f) - if self.push_results_to_tensorboard: + if self.save_tensorboard: self.push_to_tensorboard( results=self.metrics_logger.metric_aggregated, details=self.details_logger.details ) @@ -246,63 +226,6 @@ def generate_final_dict(self) -> dict: return final_dict - def details_to_hub( - self, - results_file_path: Path | str, - details_folder_path: Path | str, - ) -> None: - """Pushes the experiment details (all the model predictions for every step) to the hub. - - Args: - results_file_path (str or Path): Local path of the current's experiment aggregated results individual file - details_folder_path (str or Path): Local path of the current's experiment details folder. - The details folder (created by [`EvaluationTracker.save`]) should contain one parquet file per task used during the evaluation run of the current model. - - """ - results_file_path = str(results_file_path) - details_folder_path = str(details_folder_path) - - sanitized_model_name = self.general_config_logger.model_name.replace("/", "__") - - # "Default" detail names are the public detail names (same as results vs private-results) - repo_id = f"{self.hub_results_org}/details_{sanitized_model_name}" - if not self.public: # if not public, we add `_private` - repo_id = f"{repo_id}_private" - - sub_folder_path = os.path.basename(results_file_path).replace(".json", "").replace("results_", "") - - paths_to_check = [os.path.basename(results_file_path)] - try: - checked_paths = list(self.api.get_paths_info(repo_id=repo_id, paths=paths_to_check, repo_type="dataset")) - except Exception: - checked_paths = [] - - if len(checked_paths) == 0: - hlog(f"Repo {repo_id} not found for {results_file_path}. Creating it.") - self.api.create_repo(repo_id, private=not (self.public), repo_type="dataset", exist_ok=True) - - # Create parquet version of results file as well - results = load_dataset("json", data_files=results_file_path) - parquet_name = os.path.basename(results_file_path).replace(".json", ".parquet") - parquet_local_path = os.path.join(os.path.dirname(results_file_path), parquet_name) - results["train"].to_parquet(parquet_local_path) - - # Upload results file (json and parquet) and folder - self.api.upload_file( - repo_id=repo_id, - path_or_fileobj=results_file_path, - path_in_repo=os.path.basename(results_file_path), - repo_type="dataset", - ) - self.api.upload_file( - repo_id=repo_id, path_or_fileobj=parquet_local_path, path_in_repo=parquet_name, repo_type="dataset" - ) - self.api.upload_folder( - repo_id=repo_id, folder_path=details_folder_path, path_in_repo=sub_folder_path, repo_type="dataset" - ) - - self.recreate_metadata_card(repo_id) - def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 """Fully updates the details repository metadata card for the currently evaluated model @@ -513,6 +436,9 @@ def push_to_tensorboard( # noqa: C901 if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard without having nanotron installed. 
Skipping") return + + from tensorboardX import SummaryWriter + prefix = self.tensorboard_metric_prefix if self.nanotron_run_info is not None: @@ -522,71 +448,75 @@ def push_to_tensorboard( # noqa: C901 global_step = 0 run = prefix - output_dir_tb = Path(self.output_dir) / "tb" / run - output_dir_tb.mkdir(parents=True, exist_ok=True) - tb_context = HFSummaryWriter( - logdir=str(output_dir_tb), - repo_id=self.tensorboard_repo, - repo_private=True, - path_in_repo="tb", - commit_every=6000, # Very long time so that we can change our files names and trigger push ourselves (see below) - ) - bench_averages = {} - for name, values in results.items(): - splited_name = name.split("|") - if len(splited_name) == 3: - _, task_name, _ = splited_name - else: - task_name = name - bench_suite = None - if ":" in task_name: - bench_suite = task_name.split(":")[0] # e.g. MMLU - hlog(f"bench_suite {bench_suite} in {task_name}") + with TemporaryDirectory() as tmp_dir: + tb_context = SummaryWriter( + logdir=tmp_dir, + ) + bench_averages = {} + for name, values in results.items(): + splited_name = name.split("|") + if len(splited_name) == 3: + _, task_name, _ = splited_name + else: + task_name = name + bench_suite = None + if ":" in task_name: + bench_suite = task_name.split(":")[0] # e.g. MMLU + hlog(f"bench_suite {bench_suite} in {task_name}") + for metric, value in values.items(): + if "stderr" in metric: + continue + if bench_suite not in bench_averages: + bench_averages[bench_suite] = {} + bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [ + float(value) + ] + hlog(f"Pushing {task_name} {values} to tensorboard") for metric, value in values.items(): if "stderr" in metric: - continue - if bench_suite not in bench_averages: - bench_averages[bench_suite] = {} - bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)] - hlog(f"Pushing {task_name} {values} to tensorboard") - for metric, value in values.items(): - if "stderr" in metric: - tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step) - elif bench_suite is not None: + tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step) + elif bench_suite is not None: + tb_context.add_scalar( + f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step + ) + else: + tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step) + # Tasks with subtasks + for name, values in bench_averages.items(): + for metric, values in values.items(): + hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard") tb_context.add_scalar( - f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step + f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step ) - else: - tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step) - # Tasks with subtasks - for name, values in bench_averages.items(): - for metric, values in values.items(): - hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard") - tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step) - - tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step) - - for task_name, task_details in details.items(): - tb_context.add_text( - f"eval_details_{task_name}", - obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 
else {}}), - global_step=global_step, - ) - # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order - # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints - # See: https://github.com/tensorflow/tensorboard/issues/5958 - # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files - - tb_context.close() # flushes the unfinished write operations - time.sleep(5) - files = os.listdir(output_dir_tb) - for file in files: - os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}")) - - # Now we can push to the hub - tb_context.scheduler.trigger() - hlog( - f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard" - f"at global_step {global_step}" - ) + tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step) + + for task_name, task_details in details.items(): + tb_context.add_text( + f"eval_details_{task_name}", + obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}), + global_step=global_step, + ) + + # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order + # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints + # See: https://github.com/tensorflow/tensorboard/issues/5958 + # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files + + tb_context.close() # flushes the unfinished write operations + time.sleep(5) + files = os.listdir(tmp_dir) + for file in files: + os.rename(os.path.join(tmp_dir, file), os.path.join(tmp_dir, f"{global_step:07d}_{file}")) + + output_dir_tb = self.output_res / "tb" / run + output_dir_tb.fs.mkdirs(output_dir_tb.path, exist_ok=True) + for root, _, files in os.walk(tmp_dir): + for file in files: + file_path = os.path.join(root, file) + with output_dir_tb.fs.open(output_dir_tb / file, "wb") as output_f, open( + file_path, "rb" + ) as input_f: + output_f.write(input_f.read()) + + hlog(f"Pushed to tensorboard at {output_dir_tb}" f"at global_step {global_step}") diff --git a/src/lighteval/utils/io.py b/src/lighteval/utils/io.py new file mode 100644 index 00000000..662f077f --- /dev/null +++ b/src/lighteval/utils/io.py @@ -0,0 +1,50 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +from dataclasses import dataclass + +from fsspec import AbstractFileSystem, url_to_fs +from huggingface_hub import HfFileSystem + + +@dataclass(frozen=True) +class FsspecDataResource: + fs: AbstractFileSystem + path: str + + @classmethod + def from_uri(cls, uri: str) -> "FsspecDataResource": + fs, path = url_to_fs(uri) + return cls(fs=fs, path=path) + + def __truediv__(self, other: str) -> "FsspecDataResource": + return FsspecDataResource(fs=self.fs, path=os.path.join(self.path, other)) + + def __str__(self) -> str: + return self.path + + +def get_hf_repo_id(resource: FsspecDataResource) -> str: + if isinstance(resource.fs, HfFileSystem): + return "/".join(resource.path.split("/")[:2]) + raise ValueError("Resource is not a Hugging Face Hub repository") diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py new file mode 100644 index 00000000..552c95ad --- /dev/null +++ b/tests/logging/test_evaluation_tracker.py @@ -0,0 +1,138 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
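For illustration, a minimal usage sketch of the `FsspecDataResource` helper added in `src/lighteval/utils/io.py` above (not part of the patch); the output path is a hypothetical example, and any other fsspec URI such as `s3://bucket/prefix` resolves the same way once the matching filesystem package is installed:

    from lighteval.utils.io import FsspecDataResource

    # from_uri() resolves a URI into a (filesystem, path) pair; a plain local path
    # yields a LocalFileSystem, an s3:// URI an S3FileSystem, and so on.
    res = FsspecDataResource.from_uri("/tmp/lighteval-output")  # hypothetical output dir
    results_res = res / "results" / "my_model"                  # __truediv__ joins path segments
    res.fs.mkdirs(results_res.path, exist_ok=True)              # same call works on any fsspec backend

    target = results_res / "results_2024-01-01T00-00-00.json"
    with target.fs.open(target.path, "w") as f:
        f.write("{}")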
+ +import json +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from datasets import Dataset + +from lighteval.logging.evaluation_tracker import EvaluationTracker +from lighteval.logging.info_loggers import DetailsLogger + + +@pytest.fixture +def mock_evaluation_tracker(): + with tempfile.TemporaryDirectory() as temp_dir: + tracker = EvaluationTracker( + output_dir=temp_dir, + save_results=True, + save_details=True, + save_tensorboard=True, + ) + tracker.general_config_logger.model_name = "test_model" + yield tracker + + +def test_tensorboard_logging(mock_evaluation_tracker): + mock_evaluation_tracker.save_results = False + mock_evaluation_tracker.save_details = False + mock_evaluation_tracker.save_tensorboard = True + + mock_evaluation_tracker.metrics_logger.metric_aggregated = { + "task1": {"accuracy": 0.8, "f1": 0.75}, + "task2": {"precision": 0.9, "recall": 0.85}, + } + + mock_evaluation_tracker.save() + + with open( + Path(mock_evaluation_tracker.output_res.path) / "tensorboard" / "test_model" / "events.out.tfevents", "r" + ) as f: + content = f.read() + # Check if SummaryWriter was called + assert "SummaryWriter" in content, "SummaryWriter was not called" + + # Check if scalar values were added + assert "add_scalar" in content, "Scalar values were not added" + assert "task1/accuracy" in content, "task1/accuracy was not logged" + assert "task1/f1" in content, "task1/f1 was not logged" + assert "task2/precision" in content, "task2/precision was not logged" + assert "task2/recall" in content, "task2/recall was not logged" + + # Check if SummaryWriter was called + + # Check if scalar values were added + + +def test_results_logging(mock_evaluation_tracker: EvaluationTracker): + mock_evaluation_tracker.metrics_logger.log("task1", {"accuracy": 0.8, "f1": 0.75}) + mock_evaluation_tracker.metrics_logger.log("task2", {"precision": 0.9, "recall": 0.85}) + + mock_evaluation_tracker.save() + + results_dir = Path(mock_evaluation_tracker.output_res.path) / "results" / "test_model" + assert results_dir.exists() + + result_files = list(results_dir.glob("results_*.json")) + assert len(result_files) == 1 + + with open(result_files[0], "r") as f: + saved_results = json.load(f) + + assert "results" in saved_results + assert saved_results["results"] == mock_evaluation_tracker.metrics_logger.metric_aggregated + + +def test_details_logging(mock_evaluation_tracker): + mock_evaluation_tracker.details_logger.details = { + "task1": [DetailsLogger.CompiledDetail(task_name="task1", num_samples=100)], + "task2": [DetailsLogger.CompiledDetail(task_name="task2", num_samples=200)], + } + + mock_evaluation_tracker.save() + + details_dir = Path(mock_evaluation_tracker.output_res.path) / "details" / "test_model" + assert details_dir.exists() + + detail_files = list(details_dir.glob("details_*.parquet")) + assert len(detail_files) == 2 + + for file in detail_files: + dataset = Dataset.from_parquet(file) + assert len(dataset) == 1 + assert "task_name" in dataset.column_names + assert "num_samples" in dataset.column_names + + +@patch("lighteval.logging.evaluation_tracker.HfApi") +@patch("lighteval.logging.evaluation_tracker.DatasetCard") +def test_recreate_metadata_card(mock_dataset_card, mock_hf_api, mock_evaluation_tracker): + mock_api_instance = MagicMock() + mock_hf_api.return_value = mock_api_instance + mock_api_instance.list_repo_files.return_value = [ + "results_2023-01-01T00-00-00.json", + "details_task1_2023-01-01T00-00-00.parquet", + 
"details_task2_2023-01-01T00-00-00.parquet", + ] + + mock_dataset = MagicMock() + mock_dataset.__getitem__.return_value = [{"results": {"task1": {"accuracy": 0.8}, "task2": {"precision": 0.9}}}] + + with patch("lighteval.logging.evaluation_tracker.load_dataset", return_value=mock_dataset): + mock_evaluation_tracker.recreate_metadata_card("test/repo") + + mock_dataset_card.from_template.assert_called_once() + mock_card = mock_dataset_card.from_template.return_value + mock_card.push_to_hub.assert_called_once_with("test/repo", repo_type="dataset") From 0892317c413668171ea7bacac7fb1535a8f53e49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Thu, 29 Aug 2024 16:03:04 +0200 Subject: [PATCH 02/15] revert the push to hub removal --- src/lighteval/logging/evaluation_tracker.py | 345 ++++++++++++-------- src/lighteval/main_accelerate.py | 5 +- src/lighteval/main_nanotron.py | 1 + src/lighteval/parsers.py | 7 +- tests/fixtures.py | 52 +++ tests/logging/test_evaluation_tracker.py | 156 +++++---- 6 files changed, 346 insertions(+), 220 deletions(-) create mode 100644 tests/fixtures.py diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index e4395bf5..9c1ca554 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -28,12 +28,14 @@ from dataclasses import asdict, is_dataclass from datetime import datetime from enum import Enum -from tempfile import TemporaryDirectory +from io import BytesIO +from pathlib import Path import torch from datasets import Dataset, load_dataset from datasets.utils.metadata import MetadataConfigs -from huggingface_hub import DatasetCard, DatasetCardData, HfApi, hf_hub_url +from fsspec import url_to_fs +from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HFSummaryWriter, hf_hub_url from lighteval.logging.hierarchical_logger import hlog, hlog_warn from lighteval.logging.info_loggers import ( @@ -44,7 +46,6 @@ VersionsLogger, ) from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available -from lighteval.utils.io import FsspecDataResource from lighteval.utils.utils import obj_to_markdown @@ -96,10 +97,13 @@ class EvaluationTracker: def __init__( self, output_dir: str, - save_results: bool, - save_details: bool, - save_tensorboard: bool, - token: str = "", + save_details: bool = True, + push_to_hub: bool = False, + push_results_to_tensorboard: bool = False, + hub_results_org: str = "", + tensorboard_metric_prefix: str = "eval", + public: bool = False, + token: str | None = None, nanotron_run_info: "GeneralArgs" = None, ) -> None: """) @@ -107,17 +111,17 @@ def __init__( Args: output_dir (str): Local folder path where you want results to be saved - hub_results_org (str): The organisation to push the results to. See - more details about the datasets organisation in - [`EvaluationTracker.save`] - push_results_to_hub (bool): If True, results are pushed to the hub. - Results will be pushed either to `{hub_results_org}/results`, a public dataset, if `public` is True else to `{hub_results_org}/private-results`, a private dataset. - push_details_to_hub (bool): If True, details are pushed to the hub. + save_details (bool): If True, details are saved to the output_dir + push_to_hub (bool): If True, details are pushed to the hub. 
Results are pushed to `{hub_results_org}/details__{sanitized model_name}` for the model `model_name`, a public dataset, if `public` is True else `{hub_results_org}/details__{sanitized model_name}_private`, a private dataset. push_results_to_tensorboard (bool): If True, will create and push the results for a tensorboard folder on the hub + hub_results_org (str): The organisation to push the results to. See + more details about the datasets organisation in + [`EvaluationTracker.save`] + tensorboard_metric_prefix (str): Prefix for the metrics in the tensorboard logs public (bool): If True, results and details are pushed in private orgs - token (str): Token to use when pushing to the hub. This token should + token (str | None): Token to use when pushing to the hub. This token should have write access to `hub_results_org`. nanotron_run_info (GeneralArgs): Reference to informations about Nanotron models runs """ @@ -129,33 +133,36 @@ def __init__( self.api = HfApi(token=token) - self.output_res = FsspecDataResource.from_uri(output_dir) - self.save_results = save_results - self.save_details = save_details - self.save_tensorboard = save_tensorboard - self.nanotron_run_info = nanotron_run_info + self.fs, self.output_dir = url_to_fs(output_dir) - def save(self) -> None: - """Saves the experiment information and results to files, and to the hub if requested. - Note: - In case of save failure, this function will only print a warning, with the error message. + self.hub_results_org = hub_results_org # will also contain tensorboard results + if hub_results_org in ["", None] and any([push_to_hub, push_results_to_tensorboard]): + raise Exception( + "You need to select which org to push to, using `--results_org`, if you want to save information to the hub." + ) - Args: - output_dir (str): Local folder path where you want results to be saved. - save_results (bool): If True, results are saved to the specified logging URI. - save_details (bool): If True, detailed results are saved to the specified logging URI. - save_tensorboard (bool, optional): If True, tensorboard logs are saved to the specified logging URI. Defaults to False. 
- """ + self.should_push_to_hub = push_to_hub + self.should_save_details = save_details + + self.should_push_results_to_tensorboard = push_results_to_tensorboard + self.tensorboard_repo = f"{hub_results_org}/tensorboard_logs" + self.tensorboard_metric_prefix = tensorboard_metric_prefix + self.nanotron_run_info = nanotron_run_info + + self.public = public + def save(self) -> None: + """Saves the experiment information and results to files, and to the hub if requested.""" hlog("Saving experiment tracker") date_id = datetime.now().isoformat().replace(":", "-") + + # We first prepare data to save config_general = copy.deepcopy(self.general_config_logger) - config_general.config = ( - config_general.config.as_dict() if is_dataclass(config_general.config) else config_general.config - ) config_general = asdict(config_general) + # We remove the config from logging, which contains context/accelerator objects + config_general.pop("config") - to_dump = { + results_dict = { "config_general": config_general, "results": self.metrics_logger.metric_aggregated, "versions": self.versions_logger.versions, @@ -163,48 +170,58 @@ def save(self) -> None: "summary_tasks": self.details_logger.compiled_details, "summary_general": asdict(self.details_logger.compiled_details_over_all_tasks), } - dumped = json.dumps(to_dump, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False) - - if self.save_results: - output_results_res = self.output_res / "results" / self.general_config_logger.model_name - hlog(f"Saving results to {output_results_res}") - output_results_res.fs.mkdirs(output_results_res.path, exist_ok=True) - output_results_file = output_results_res / f"results_{date_id}.json" - with output_results_file.fs.open(output_results_file.path, "w") as f: - f.write(dumped) - - if self.save_details: - output_details_res = self.output_res / "details" / self.general_config_logger.model_name - hlog(f"Saving details to {output_details_res}") - output_details_res.fs.mkdirs(output_details_res.path, exist_ok=True) - - for task_name, task_details in self.details_logger.details.items(): - output_task_details_file = output_details_res / f"details_{task_name}_{date_id}.parquet" - # Create a dataset from the dictionary - try: - dataset = Dataset.from_list([asdict(detail) for detail in task_details]) - except Exception: - # We force cast to str to avoid formatting problems for nested objects - dataset = Dataset.from_list( - [{k: str(v) for k, v in asdict(detail).items()} for detail in task_details] - ) - # We don't keep 'id' around if it's there - column_names = dataset.column_names - if "id" in dataset.column_names: - column_names = [t for t in dataset.column_names if t != "id"] + # Create the details datasets for later upload + details_datasets: dict[str, Dataset] = {} + for task_name, task_details in self.details_logger.details.items(): + # Create a dataset from the dictionary - we force cast to str to avoid formatting problems for nested objects + dataset = Dataset.from_list([{k: str(v) for k, v in asdict(detail).items()} for detail in task_details]) + + # We don't keep 'id' around if it's there + column_names = dataset.column_names + if "id" in dataset.column_names: + column_names = [t for t in dataset.column_names if t != "id"] + + # Sort column names to make it easier later + dataset = dataset.select_columns(sorted(column_names)) + details_datasets[task_name] = dataset - # Sort column names to make it easier later - dataset = dataset.select_columns(sorted(column_names)) - # Save the dataset to a Parquet file - with 
output_task_details_file.fs.open(output_task_details_file.path, "wb") as f: - dataset.to_parquet(f) + # We save results at every case + self.save_results(date_id, results_dict) + + if self.should_save_details: + self.save_details(date_id, details_datasets) + + if self.should_push_to_hub: + self.push_to_hub( + date_id=date_id, + details=details_datasets, + results_dict=results_dict, + ) - if self.save_tensorboard: + if self.should_push_results_to_tensorboard: self.push_to_tensorboard( - results=self.metrics_logger.metric_aggregated, details=self.details_logger.details + results=self.metrics_logger.metric_aggregated, details=self.details_logger.compiled_details ) + def save_results(self, date_id: str, results_dict: dict): + output_dir_results = Path(self.output_dir) / "results" / self.general_config_logger.model_name + self.fs.mkdirs(output_dir_results, exist_ok=True) + output_results_file = output_dir_results / f"results_{date_id}.json" + hlog(f"Saving results to {output_results_file}") + with self.fs.open(output_results_file, "w") as f: + f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)) + + def save_details(self, date_id: str, details_datasets: dict[str, Dataset]): + output_dir_details = Path(self.output_dir) / "details" / self.general_config_logger.model_name + output_dir_details_sub_folder = output_dir_details / date_id + self.fs.mkdirs(output_dir_details_sub_folder, exist_ok=True) + hlog(f"Saving details to {output_dir_details_sub_folder}") + for task_name, dataset in details_datasets.items(): + output_file_details = output_dir_details_sub_folder / f"details_{task_name}_{date_id}.parquet" + with self.fs.open(str(output_file_details), "wb") as f: + dataset.to_parquet(f) + def generate_final_dict(self) -> dict: """Aggregates and returns all the logger's experiment information in a dictionary. 
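For reference, a minimal sketch (not from the patch) of the fsspec pattern that the `save_results` and `save_details` methods above rely on; the bucket name is a hypothetical example, and writing to `s3://` URIs assumes the optional `s3fs` dependency is installed:

    from datasets import Dataset
    from fsspec import url_to_fs

    # url_to_fs() splits a URI into a filesystem object and a protocol-less path;
    # a plain local directory resolves to a LocalFileSystem the same way.
    fs, output_dir = url_to_fs("s3://my-bucket/evals")  # hypothetical bucket
    fs.mkdirs(f"{output_dir}/details/my_model", exist_ok=True)

    # Parquet details can be streamed through the filesystem's file handle,
    # exactly as save_details does with dataset.to_parquet(f).
    dataset = Dataset.from_list([{"metric": "0.8"}])
    with fs.open(f"{output_dir}/details/my_model/details_task1_2024-01-01T00-00-00.parquet", "wb") as f:
        dataset.to_parquet(f)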
@@ -226,6 +243,47 @@ def generate_final_dict(self) -> dict: return final_dict + def push_to_hub( + self, + date_id: str, + details: dict[str, Dataset], + results_dict: dict, + ) -> None: + """Pushes the experiment details (all the model predictions for every step) to the hub.""" + sanitized_model_name = self.general_config_logger.model_name.replace("/", "__") + + # "Default" detail names are the public detail names (same as results vs private-results) + repo_id = f"{self.hub_results_org}/details_{sanitized_model_name}" + if not self.public: # if not public, we add `_private` + repo_id = f"{repo_id}_private" + + fsspec_repo_uri = f"hf://datasets/{repo_id}" + + if not self.api.repo_exists(repo_id): + self.api.create_repo(repo_id, private=not (self.public), repo_type="dataset", exist_ok=True) + hlog(f"Repository {repo_id} not found, creating it.") + + # We upload it both as a json and a parquet file + result_file_base_name = f"results_{date_id}" + results_json = json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False) + self.api.upload_file( + repo_id=repo_id, + path_or_fileobj=BytesIO(results_json.encode("utf-8")), + path_in_repo=f"{result_file_base_name}.json", + repo_type="dataset", + ) + + results_dataset = Dataset.from_dict( + {key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()} + ) + results_dataset.to_parquet(f"{fsspec_repo_uri}/{result_file_base_name}.parquet") + + for task_name, dataset in details.items(): + output_file_details = Path(date_id) / f"details_{task_name}_{date_id}.parquet" + dataset.to_parquet(f"{fsspec_repo_uri}/{output_file_details}") + + self.recreate_metadata_card(repo_id) + def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 """Fully updates the details repository metadata card for the currently evaluated model @@ -236,6 +294,8 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") results_files = [f for f in files_in_repo if ".json" in f] parquet_files = [f for f in files_in_repo if ".parquet" in f] + + details_file_regex = re.compile(r"details_(?P.*?)_(?P\d+-\d+-\d+T.*)\.parquet$") multiple_results = len(results_files) > 1 # Get last eval results date for each task (evals might be non overlapping) @@ -249,8 +309,8 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 # `2023-09-03T10-57-04.203304/details_harness|hendrycksTest-us_foreign_policy|5_2023-09-03T10-57-04.203304.parquet` # in the iso date, the `:` are replaced by `-` because windows does not allow `:` in their filenames task_name = ( - os.path.basename(sub_file).replace("details_", "").split("_202")[0] - ) # 202 for dates, 2023, 2024, ... 
+ details_file_regex.match(os.path.basename(sub_file)).group("task_name") # type: ignore + ) # task_name is then equal to `leaderboard|mmlu:us_foreign_policy|5` # to be able to parse the filename as iso dates, we need to re-replace the `-` with `:` @@ -282,9 +342,15 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", max_last_eval_date_results) repo_file_name = os.path.basename(sub_file) else: - task_name = os.path.basename(sub_file).replace("details_", "").split("_2023")[0].split("_2024")[0] + filename = os.path.basename(sub_file) + + task_name_match = details_file_regex.match(filename) # type: ignore + if not task_name_match: + raise ValueError(f"Could not parse task name from filename: {filename}") + task_name = task_name_match.group("task_name") + eval_date = task_name_match.group("date") + sanitized_task = re.sub(r"\W", "_", task_name) - eval_date = os.path.dirname(sub_file) sanitized_last_eval_date_results = re.sub(r"[^\w\.]", "_", last_eval_date_results[task_name]) repo_file_name = os.path.join("**", os.path.basename(sub_file)) @@ -436,9 +502,6 @@ def push_to_tensorboard( # noqa: C901 if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard without having nanotron installed. Skipping") return - - from tensorboardX import SummaryWriter - prefix = self.tensorboard_metric_prefix if self.nanotron_run_info is not None: @@ -448,75 +511,71 @@ def push_to_tensorboard( # noqa: C901 global_step = 0 run = prefix - with TemporaryDirectory() as tmp_dir: - tb_context = SummaryWriter( - logdir=tmp_dir, - ) - bench_averages = {} - for name, values in results.items(): - splited_name = name.split("|") - if len(splited_name) == 3: - _, task_name, _ = splited_name - else: - task_name = name - bench_suite = None - if ":" in task_name: - bench_suite = task_name.split(":")[0] # e.g. MMLU - hlog(f"bench_suite {bench_suite} in {task_name}") - for metric, value in values.items(): - if "stderr" in metric: - continue - if bench_suite not in bench_averages: - bench_averages[bench_suite] = {} - bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [ - float(value) - ] - hlog(f"Pushing {task_name} {values} to tensorboard") + output_dir_tb = Path(self.output_dir) / "tb" / run + output_dir_tb.mkdir(parents=True, exist_ok=True) + tb_context = HFSummaryWriter( + logdir=str(output_dir_tb), + repo_id=self.tensorboard_repo, + repo_private=True, + path_in_repo="tb", + commit_every=6000, # Very long time so that we can change our files names and trigger push ourselves (see below) + ) + bench_averages = {} + for name, values in results.items(): + splited_name = name.split("|") + if len(splited_name) == 3: + _, task_name, _ = splited_name + else: + task_name = name + bench_suite = None + if ":" in task_name: + bench_suite = task_name.split(":")[0] # e.g. 
MMLU + hlog(f"bench_suite {bench_suite} in {task_name}") for metric, value in values.items(): if "stderr" in metric: - tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step) - elif bench_suite is not None: - tb_context.add_scalar( - f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step - ) - else: - tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step) - # Tasks with subtasks - for name, values in bench_averages.items(): - for metric, values in values.items(): - hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard") + continue + if bench_suite not in bench_averages: + bench_averages[bench_suite] = {} + bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)] + hlog(f"Pushing {task_name} {values} to tensorboard") + for metric, value in values.items(): + if "stderr" in metric: + tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step) + elif bench_suite is not None: tb_context.add_scalar( - f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step + f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step ) + else: + tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step) + # Tasks with subtasks + for name, values in bench_averages.items(): + for metric, values in values.items(): + hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard") + tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step) + + tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step) + + for task_name, task_details in details.items(): + tb_context.add_text( + f"eval_details_{task_name}", + obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}), + global_step=global_step, + ) - tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step) - - for task_name, task_details in details.items(): - tb_context.add_text( - f"eval_details_{task_name}", - obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}), - global_step=global_step, - ) - - # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order - # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints - # See: https://github.com/tensorflow/tensorboard/issues/5958 - # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files - - tb_context.close() # flushes the unfinished write operations - time.sleep(5) - files = os.listdir(tmp_dir) - for file in files: - os.rename(os.path.join(tmp_dir, file), os.path.join(tmp_dir, f"{global_step:07d}_{file}")) - - output_dir_tb = self.output_res / "tb" / run - output_dir_tb.fs.mkdirs(output_dir_tb.path, exist_ok=True) - for root, _, files in os.walk(tmp_dir): - for file in files: - file_path = os.path.join(root, file) - with output_dir_tb.fs.open(output_dir_tb / file, "wb") as output_f, open( - file_path, "rb" - ) as input_f: - output_f.write(input_f.read()) - - hlog(f"Pushed to tensorboard at {output_dir_tb}" f"at global_step {global_step}") + # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order + # This messes up with tensorboard, so the easiest is to 
rename files in the order of the checkpoints + # See: https://github.com/tensorflow/tensorboard/issues/5958 + # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files + + tb_context.close() # flushes the unfinished write operations + time.sleep(5) + files = os.listdir(output_dir_tb) + for file in files: + os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}")) + + # Now we can push to the hub + tb_context.scheduler.trigger() + hlog( + f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard" + f"at global_step {global_step}" + ) diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 22054c97..79346e7c 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -48,9 +48,8 @@ def main(args): env_config = EnvConfig(token=TOKEN, cache_dir=args.cache_dir) evaluation_tracker = EvaluationTracker( output_dir=args.output_dir, - hub_results_org=args.results_org, - push_results_to_hub=args.push_results_to_hub, - push_details_to_hub=args.push_details_to_hub, + save_details=args.save_details, + push_to_hub=args.push_to_hub, push_results_to_tensorboard=args.push_results_to_tensorboard, public=args.public_run, token=TOKEN, diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 6e219b30..2b0d60de 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -66,6 +66,7 @@ def main( else: lighteval_config = model_config.lighteval + # TODO: Once lighteval config is owned by Ligteval fix this evaluation_tracker = EvaluationTracker( token=os.getenv("HF_TOKEN"), output_dir=lighteval_config.logging.local_output_path, diff --git a/src/lighteval/parsers.py b/src/lighteval/parsers.py index 499d945e..46bfb93a 100644 --- a/src/lighteval/parsers.py +++ b/src/lighteval/parsers.py @@ -55,13 +55,8 @@ def parser_accelerate(parser=None): # Saving parser.add_argument("--output_dir", required=True, type=str, help="Directory to save the results") - parser.add_argument( - "--push_results_to_hub", default=False, action="store_true", help="Set to push the results to the hub" - ) parser.add_argument("--save_details", action="store_true", help="Save the details of the run in the output_dir") - parser.add_argument( - "--push_details_to_hub", default=False, action="store_true", help="Set to push the details to the hub" - ) + parser.add_argument("--push_to_hub", default=False, action="store_true", help="Set to push the details to the hub") parser.add_argument("--push_results_to_tensorboard", default=False, action="store_true") parser.add_argument( "--public_run", default=False, action="store_true", help="Push results and details to a public repo" diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 00000000..a85279fb --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,52 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission 
notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import pytest +from huggingface_hub import HfApi +from huggingface_hub.hf_api import DatasetInfo + + +@pytest.fixture +def testing_empty_hf_org_id(): + org_id = "lighteval-testing" + + def list_repos(org_id: str): + return list(hf_api.list_models(author=org_id)) + list(hf_api.list_datasets(author=org_id)) + + def clean_repos(org_id: str): + repos = list_repos(org_id) + for repo in repos: + hf_api.delete_repo(repo.id, repo_type="dataset" if isinstance(repo, DatasetInfo) else "model") + + hf_api = HfApi() + # Remove all repositories in the HF org + clean_repos(org_id) + + # Verify that all repositories have been removed + remaining_repos = list_repos(org_id) + assert len(remaining_repos) == 0, f"Expected 0 repositories, but found {len(remaining_repos)}" + + yield org_id + + # Clean up: recreate any necessary default repositories after the test + # This step is optional and depends on your specific needs + clean_repos(org_id) diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index 552c95ad..665ced35 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -22,11 +22,12 @@ import json import tempfile +from datetime import datetime from pathlib import Path -from unittest.mock import MagicMock, patch import pytest from datasets import Dataset +from huggingface_hub import HfApi from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.info_loggers import DetailsLogger @@ -37,52 +38,41 @@ def mock_evaluation_tracker(): with tempfile.TemporaryDirectory() as temp_dir: tracker = EvaluationTracker( output_dir=temp_dir, - save_results=True, - save_details=True, - save_tensorboard=True, + save_details=False, + push_to_hub=False, + push_results_to_tensorboard=False, ) tracker.general_config_logger.model_name = "test_model" yield tracker -def test_tensorboard_logging(mock_evaluation_tracker): - mock_evaluation_tracker.save_results = False - mock_evaluation_tracker.save_details = False - mock_evaluation_tracker.save_tensorboard = True - - mock_evaluation_tracker.metrics_logger.metric_aggregated = { - "task1": {"accuracy": 0.8, "f1": 0.75}, - "task2": {"precision": 0.9, "recall": 0.85}, - } - - mock_evaluation_tracker.save() - - with open( - Path(mock_evaluation_tracker.output_res.path) / "tensorboard" / "test_model" / "events.out.tfevents", "r" - ) as f: - content = f.read() - # Check if SummaryWriter was called - assert "SummaryWriter" in content, "SummaryWriter was not called" +@pytest.fixture +def mock_datetime(monkeypatch): + mock_date = datetime(2023, 1, 1, 12, 0, 0) - # Check if scalar values were added - assert "add_scalar" in content, "Scalar values were not added" - assert "task1/accuracy" in content, "task1/accuracy was not logged" - assert "task1/f1" in content, "task1/f1 was not logged" - assert "task2/precision" in content, "task2/precision was not logged" - assert "task2/recall" in content, "task2/recall was 
not logged" + class MockDatetime: + @classmethod + def now(cls): + return mock_date - # Check if SummaryWriter was called + @classmethod + def fromisoformat(cls, date_string: str): + return mock_date - # Check if scalar values were added + monkeypatch.setattr("lighteval.logging.evaluation_tracker.datetime", MockDatetime) + return mock_date def test_results_logging(mock_evaluation_tracker: EvaluationTracker): - mock_evaluation_tracker.metrics_logger.log("task1", {"accuracy": 0.8, "f1": 0.75}) - mock_evaluation_tracker.metrics_logger.log("task2", {"precision": 0.9, "recall": 0.85}) + task_metrics = { + "task1": {"accuracy": 0.8, "f1": 0.75}, + "task2": {"precision": 0.9, "recall": 0.85}, + } + mock_evaluation_tracker.metrics_logger.metric_aggregated = task_metrics mock_evaluation_tracker.save() - results_dir = Path(mock_evaluation_tracker.output_res.path) / "results" / "test_model" + results_dir = Path(mock_evaluation_tracker.output_dir) / "results" / "test_model" assert results_dir.exists() result_files = list(results_dir.glob("results_*.json")) @@ -92,47 +82,77 @@ def test_results_logging(mock_evaluation_tracker: EvaluationTracker): saved_results = json.load(f) assert "results" in saved_results - assert saved_results["results"] == mock_evaluation_tracker.metrics_logger.metric_aggregated + assert saved_results["results"] == task_metrics + assert saved_results["config_general"]["model_name"] == "test_model" -def test_details_logging(mock_evaluation_tracker): - mock_evaluation_tracker.details_logger.details = { - "task1": [DetailsLogger.CompiledDetail(task_name="task1", num_samples=100)], - "task2": [DetailsLogger.CompiledDetail(task_name="task2", num_samples=200)], +def test_details_logging(mock_evaluation_tracker, mock_datetime): + mock_evaluation_tracker.should_save_details = True + task_details = { + "task1": [DetailsLogger.CompiledDetail(truncated=10, padded=5)], + "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)], } + mock_evaluation_tracker.details_logger.details = task_details mock_evaluation_tracker.save() - details_dir = Path(mock_evaluation_tracker.output_res.path) / "details" / "test_model" + date_id = mock_datetime.isoformat().replace(":", "-") + details_dir = Path(mock_evaluation_tracker.output_dir) / "details" / "test_model" / date_id assert details_dir.exists() - detail_files = list(details_dir.glob("details_*.parquet")) - assert len(detail_files) == 2 - - for file in detail_files: - dataset = Dataset.from_parquet(file) + for task in ["task1", "task2"]: + file_path = details_dir / f"details_{task}_{date_id}.parquet" + dataset = Dataset.from_parquet(str(file_path)) assert len(dataset) == 1 - assert "task_name" in dataset.column_names - assert "num_samples" in dataset.column_names - - -@patch("lighteval.logging.evaluation_tracker.HfApi") -@patch("lighteval.logging.evaluation_tracker.DatasetCard") -def test_recreate_metadata_card(mock_dataset_card, mock_hf_api, mock_evaluation_tracker): - mock_api_instance = MagicMock() - mock_hf_api.return_value = mock_api_instance - mock_api_instance.list_repo_files.return_value = [ - "results_2023-01-01T00-00-00.json", - "details_task1_2023-01-01T00-00-00.parquet", - "details_task2_2023-01-01T00-00-00.parquet", - ] - - mock_dataset = MagicMock() - mock_dataset.__getitem__.return_value = [{"results": {"task1": {"accuracy": 0.8}, "task2": {"precision": 0.9}}}] - - with patch("lighteval.logging.evaluation_tracker.load_dataset", return_value=mock_dataset): - mock_evaluation_tracker.recreate_metadata_card("test/repo") - - 
mock_dataset_card.from_template.assert_called_once() - mock_card = mock_dataset_card.from_template.return_value - mock_card.push_to_hub.assert_called_once_with("test/repo", repo_type="dataset") + assert int(dataset[0]["truncated"]) == task_details[task][0].truncated + assert int(dataset[0]["padded"]) == task_details[task][0].padded + + +def test_no_details_output(mock_evaluation_tracker: EvaluationTracker): + mock_evaluation_tracker.should_save_details = False + mock_evaluation_tracker.save() + + details_dir = Path(mock_evaluation_tracker.output_dir) / "details" / "test_model" + assert not details_dir.exists() + + +def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: EvaluationTracker, mock_datetime): + mock_evaluation_tracker.should_push_to_hub = True + mock_evaluation_tracker.hub_results_org = testing_empty_hf_org_id + + # Prepare the dummy data + task_metrics = { + "task1": {"accuracy": 0.8, "f1": 0.75}, + "task2": {"precision": 0.9, "recall": 0.85}, + } + mock_evaluation_tracker.metrics_logger.metric_aggregated = task_metrics + + task_details = { + "task1": [DetailsLogger.CompiledDetail(truncated=10, padded=5)], + "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)], + } + mock_evaluation_tracker.details_logger.details = task_details + + mock_evaluation_tracker.save() + + # Verify using HfApi + api = HfApi() + + # Check if repo exists and it's private + expected_repo_id = f"{testing_empty_hf_org_id}/details_test_model_private" + assert api.repo_exists(repo_id=expected_repo_id, repo_type="dataset") + assert api.repo_info(repo_id=expected_repo_id, repo_type="dataset").private + + repo_files = api.list_repo_files(repo_id=expected_repo_id, repo_type="dataset") + # Check if README.md exists + assert any(file == "README.md" for file in repo_files) + + # Check that both results files were uploaded + result_files = [file for file in repo_files if file.startswith("results_")] + assert len(result_files) == 2 + assert len([file for file in result_files if file.endswith(".json")]) == 1 + assert len([file for file in result_files if file.endswith(".parquet")]) == 1 + + # Check that the details dataset was uploaded + details_files = [file for file in repo_files if "details_" in file and file.endswith(".parquet")] + assert len(details_files) == 2 From d1ea3708c8161a476ff7c1710a6f72a1870c5219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Thu, 29 Aug 2024 16:22:13 +0200 Subject: [PATCH 03/15] add new deps to pyproject --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e301d7af..224c56ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "sentencepiece>=0.1.99", "protobuf==3.20.*", # pinned for sentencepiece compat "pycountry", + "fsspec>=2023.12.2", ] [project.optional-dependencies] @@ -94,6 +95,7 @@ extended_tasks = [ "langdetect", # ifeval "openai", # llm as a judge using openai models ] +s3 = ["s3fs"] [project.urls] Homepage = "https://github.com/huggingface/lighteval" From efb4ce8618c4807d173ce5416c72dc4089b617ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Thu, 29 Aug 2024 16:24:32 +0200 Subject: [PATCH 04/15] add better comment to output_dir --- src/lighteval/parsers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/lighteval/parsers.py b/src/lighteval/parsers.py index 46bfb93a..43a3c7c1 100644 --- a/src/lighteval/parsers.py +++ b/src/lighteval/parsers.py @@ -54,7 +54,12 @@ def 
parser_accelerate(parser=None): parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="") # Saving - parser.add_argument("--output_dir", required=True, type=str, help="Directory to save the results") + parser.add_argument( + "--output_dir", + required=True, + type=str, + help="Directory to save the results, fsspec compliant (e.g. s3://bucket/path)", + ) parser.add_argument("--save_details", action="store_true", help="Save the details of the run in the output_dir") parser.add_argument("--push_to_hub", default=False, action="store_true", help="Set to push the details to the hub") parser.add_argument("--push_results_to_tensorboard", default=False, action="store_true") From 8a8af5b1545f6f4064755d982b1f7bb0bd165d2c Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 12:45:33 +0000 Subject: [PATCH 05/15] fix tensorboard and push to hub args --- src/lighteval/logging/evaluation_tracker.py | 16 +++++++++------- src/lighteval/main_accelerate.py | 3 ++- src/lighteval/parsers.py | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 9c1ca554..2329c42e 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -99,14 +99,14 @@ def __init__( output_dir: str, save_details: bool = True, push_to_hub: bool = False, - push_results_to_tensorboard: bool = False, + push_to_tensorboard: bool = False, hub_results_org: str = "", tensorboard_metric_prefix: str = "eval", public: bool = False, token: str | None = None, nanotron_run_info: "GeneralArgs" = None, ) -> None: - """) + """ Creates all the necessary loggers for evaluation tracking. Args: @@ -136,7 +136,7 @@ def __init__( self.fs, self.output_dir = url_to_fs(output_dir) self.hub_results_org = hub_results_org # will also contain tensorboard results - if hub_results_org in ["", None] and any([push_to_hub, push_results_to_tensorboard]): + if hub_results_org in ["", None] and any([push_to_hub, push_to_tensorboard]): raise Exception( "You need to select which org to push to, using `--results_org`, if you want to save information to the hub." ) @@ -144,7 +144,7 @@ def __init__( self.should_push_to_hub = push_to_hub self.should_save_details = save_details - self.should_push_results_to_tensorboard = push_results_to_tensorboard + self.should_push_results_to_tensorboard = push_to_tensorboard self.tensorboard_repo = f"{hub_results_org}/tensorboard_logs" self.tensorboard_metric_prefix = tensorboard_metric_prefix self.nanotron_run_info = nanotron_run_info @@ -502,6 +502,7 @@ def push_to_tensorboard( # noqa: C901 if not is_nanotron_available(): hlog_warn("You cannot push results to tensorboard without having nanotron installed. 
Skipping") return + prefix = self.tensorboard_metric_prefix if self.nanotron_run_info is not None: @@ -513,6 +514,7 @@ def push_to_tensorboard( # noqa: C901 output_dir_tb = Path(self.output_dir) / "tb" / run output_dir_tb.mkdir(parents=True, exist_ok=True) + tb_context = HFSummaryWriter( logdir=str(output_dir_tb), repo_id=self.tensorboard_repo, @@ -558,7 +560,7 @@ def push_to_tensorboard( # noqa: C901 for task_name, task_details in details.items(): tb_context.add_text( f"eval_details_{task_name}", - obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}), + obj_to_markdown({"0": task_details}), global_step=global_step, ) @@ -567,7 +569,7 @@ def push_to_tensorboard( # noqa: C901 # See: https://github.com/tensorflow/tensorboard/issues/5958 # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files - tb_context.close() # flushes the unfinished write operations + # tb_context.close() # flushes the unfinished write operations time.sleep(5) files = os.listdir(output_dir_tb) for file in files: @@ -577,5 +579,5 @@ def push_to_tensorboard( # noqa: C901 tb_context.scheduler.trigger() hlog( f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard" - f"at global_step {global_step}" + f" at global_step {global_step}" ) diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 79346e7c..51bafd24 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -50,9 +50,10 @@ def main(args): output_dir=args.output_dir, save_details=args.save_details, push_to_hub=args.push_to_hub, - push_results_to_tensorboard=args.push_results_to_tensorboard, + push_to_tensorboard=args.push_to_tensorboard, public=args.public_run, token=TOKEN, + hub_results_org=args.results_org, ) pipeline_params = PipelineParameters( launcher_type=ParallelismManager.ACCELERATE, diff --git a/src/lighteval/parsers.py b/src/lighteval/parsers.py index 43a3c7c1..aecaf7c6 100644 --- a/src/lighteval/parsers.py +++ b/src/lighteval/parsers.py @@ -62,7 +62,7 @@ def parser_accelerate(parser=None): ) parser.add_argument("--save_details", action="store_true", help="Save the details of the run in the output_dir") parser.add_argument("--push_to_hub", default=False, action="store_true", help="Set to push the details to the hub") - parser.add_argument("--push_results_to_tensorboard", default=False, action="store_true") + parser.add_argument("--push_to_tensorboard", default=False, action="store_true") parser.add_argument( "--public_run", default=False, action="store_true", help="Push results and details to a public repo" ) From 3d26bbc4cef471d9652e51fd5f5bed13c1f5b0b9 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 3 Sep 2024 13:06:54 +0000 Subject: [PATCH 06/15] fix tests --- tests/logging/test_evaluation_tracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index 665ced35..c4f599e0 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -40,7 +40,7 @@ def mock_evaluation_tracker(): output_dir=temp_dir, save_details=False, push_to_hub=False, - push_results_to_tensorboard=False, + push_to_tensorboard=False, ) tracker.general_config_logger.model_name = "test_model" yield tracker From ff84ae4910faa4583314789e96ba9e699d3285ce Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Tue, 3 Sep 2024 16:25:29 +0200 Subject: [PATCH 07/15] improve fixtures --- tests/fixtures.py | 7 +++--- tests/logging/test_evaluation_tracker.py | 29 ++++++++++++++---------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index a85279fb..632fb54d 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -25,10 +25,11 @@ from huggingface_hub.hf_api import DatasetInfo -@pytest.fixture -def testing_empty_hf_org_id(): - org_id = "lighteval-testing" +TESTING_EMPTY_HF_ORG_ID = "lighteval-tests" + +@pytest.fixture +def testing_empty_hf_org_id(org_id: str = TESTING_EMPTY_HF_ORG_ID): def list_repos(org_id: str): return list(hf_api.list_models(author=org_id)) + list(hf_api.list_datasets(author=org_id)) diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index c4f599e0..0b693d84 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -31,17 +31,24 @@ from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.info_loggers import DetailsLogger +from tests.fixtures import TESTING_EMPTY_HF_ORG_ID @pytest.fixture -def mock_evaluation_tracker(): +def mock_evaluation_tracker(request): + passed_params = {} + if request.keywords.get("evaluation_tracker"): + passed_params = request.keywords["evaluation_tracker"].kwargs + with tempfile.TemporaryDirectory() as temp_dir: - tracker = EvaluationTracker( - output_dir=temp_dir, - save_details=False, - push_to_hub=False, - push_to_tensorboard=False, - ) + kwargs = { + "output_dir": temp_dir, + "save_details": passed_params.get("save_details", False), + "push_to_hub": passed_params.get("push_to_hub", False), + "push_to_tensorboard": passed_params.get("push_to_tensorboard", False), + "hub_results_org": passed_params.get("hub_results_org", ""), + } + tracker = EvaluationTracker(**kwargs) tracker.general_config_logger.model_name = "test_model" yield tracker @@ -86,8 +93,8 @@ def test_results_logging(mock_evaluation_tracker: EvaluationTracker): assert saved_results["config_general"]["model_name"] == "test_model" +@pytest.mark.evaluation_tracker(save_details=True) def test_details_logging(mock_evaluation_tracker, mock_datetime): - mock_evaluation_tracker.should_save_details = True task_details = { "task1": [DetailsLogger.CompiledDetail(truncated=10, padded=5)], "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)], @@ -108,18 +115,16 @@ def test_details_logging(mock_evaluation_tracker, mock_datetime): assert int(dataset[0]["padded"]) == task_details[task][0].padded +@pytest.mark.evaluation_tracker(save_details=False) def test_no_details_output(mock_evaluation_tracker: EvaluationTracker): - mock_evaluation_tracker.should_save_details = False mock_evaluation_tracker.save() details_dir = Path(mock_evaluation_tracker.output_dir) / "details" / "test_model" assert not details_dir.exists() +@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID) def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: EvaluationTracker, mock_datetime): - mock_evaluation_tracker.should_push_to_hub = True - mock_evaluation_tracker.hub_results_org = testing_empty_hf_org_id - # Prepare the dummy data task_metrics = { "task1": {"accuracy": 0.8, "f1": 0.75}, From db73c2c688da9c0283d10f875fc40f4329468e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 4 Sep 2024 00:03:31 +0200 Subject: 
[PATCH 08/15] =?UTF-8?q?=F0=9F=90=9B=20fix=20tets?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index cff23167..221426a9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -122,7 +122,7 @@ def fake_evaluate_task( task_name = f"{task.suite[0]}|{task.name}" task_dict = {task_name: task} - evaluation_tracker = EvaluationTracker() + evaluation_tracker = EvaluationTracker(output_dir="") evaluation_tracker.task_config_logger.log(task_dict) # Create a mock Registry class From 72630347a1db69fe3882a7ad15c11cde4293fc69 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Wed, 4 Sep 2024 10:37:09 +0000 Subject: [PATCH 09/15] fix tests by adding the HF TOKEN and importing fixture --- tests/fixtures.py | 6 +++++- tests/logging/test_evaluation_tracker.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index a85279fb..8ebe10dc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import os + import pytest from huggingface_hub import HfApi from huggingface_hub.hf_api import DatasetInfo @@ -27,7 +29,9 @@ @pytest.fixture def testing_empty_hf_org_id(): - org_id = "lighteval-testing" + org_id = "lighteval-tests" + token = os.getenv("HF_TEST_TOKEN") + os.environ["HF_TOKEN"] = token def list_repos(org_id: str): return list(hf_api.list_models(author=org_id)) + list(hf_api.list_datasets(author=org_id)) diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index c4f599e0..0862352c 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -132,7 +132,6 @@ def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: Eva "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)], } mock_evaluation_tracker.details_logger.details = task_details - mock_evaluation_tracker.save() # Verify using HfApi From 476116baca85623b3a17e8fdbe9dc6150a1b7d9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 4 Sep 2024 13:59:17 +0200 Subject: [PATCH 10/15] add import so that fixture is imported + hf token handilng --- src/lighteval/config/lighteval_config.py | 5 +-- src/lighteval/logging/evaluation_tracker.py | 31 ++++++------- src/lighteval/main_accelerate.py | 2 +- src/lighteval/main_nanotron.py | 5 ++- src/lighteval/utils/io.py | 50 --------------------- tests/logging/test_evaluation_tracker.py | 13 ++++-- 6 files changed, 30 insertions(+), 76 deletions(-) delete mode 100644 src/lighteval/utils/io.py diff --git a/src/lighteval/config/lighteval_config.py b/src/lighteval/config/lighteval_config.py index 3b8a3332..5c10a9ff 100644 --- a/src/lighteval/config/lighteval_config.py +++ b/src/lighteval/config/lighteval_config.py @@ -58,9 +58,8 @@ class LightEvalLoggingArgs: output_dir: str save_details: bool = True - push_results_to_hub: bool = False - push_details_to_hub: bool = False - push_results_to_tensorboard: bool = False + push_to_hub: bool = False + push_to_tensorboard: bool = False public_run: bool = False results_org: str | None = None tensorboard_metric_prefix: str = "eval" diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 2329c42e..9d9f6d1b 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ 
b/src/lighteval/logging/evaluation_tracker.py @@ -50,7 +50,7 @@ if is_nanotron_available(): - from nanotron.config import GeneralArgs + from nanotron.config import GeneralArgs # type: ignore class EnhancedJSONEncoder(json.JSONEncoder): @@ -62,7 +62,7 @@ class EnhancedJSONEncoder(json.JSONEncoder): def default(self, o): if is_dataclass(o): try: - return asdict(o) + return asdict(o) # type: ignore except Exception: return str(o) if callable(o): @@ -87,23 +87,16 @@ class EvaluationTracker: requested. """ - details_logger: DetailsLogger - metrics_logger: MetricsLogger - versions_logger: VersionsLogger - general_config_logger: GeneralConfigLogger - task_config_logger: TaskConfigLogger - hub_results_org: str - def __init__( self, output_dir: str, save_details: bool = True, push_to_hub: bool = False, push_to_tensorboard: bool = False, - hub_results_org: str = "", + hub_results_org: str | None = "", tensorboard_metric_prefix: str = "eval", public: bool = False, - token: str | None = None, + hf_token: str | None = None, nanotron_run_info: "GeneralArgs" = None, ) -> None: """ @@ -131,9 +124,9 @@ def __init__( self.general_config_logger = GeneralConfigLogger() self.task_config_logger = TaskConfigLogger() - self.api = HfApi(token=token) - - self.fs, self.output_dir = url_to_fs(output_dir) + self.api = HfApi(token=hf_token) + self.fs, self.output_dir = url_to_fs(output_dir, token=hf_token) + self.hf_token = hf_token self.hub_results_org = hub_results_org # will also contain tensorboard results if hub_results_org in ["", None] and any([push_to_hub, push_to_tensorboard]): @@ -276,11 +269,13 @@ def push_to_hub( results_dataset = Dataset.from_dict( {key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()} ) - results_dataset.to_parquet(f"{fsspec_repo_uri}/{result_file_base_name}.parquet") + results_dataset.to_parquet( + f"{fsspec_repo_uri}/{result_file_base_name}.parquet", storage_options={"token": self.hf_token} + ) for task_name, dataset in details.items(): output_file_details = Path(date_id) / f"details_{task_name}_{date_id}.parquet" - dataset.to_parquet(f"{fsspec_repo_uri}/{output_file_details}") + dataset.to_parquet(f"{fsspec_repo_uri}/{output_file_details}", storage_options={"token": self.hf_token}) self.recreate_metadata_card(repo_id) @@ -445,7 +440,7 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 # Get the top results last_results_file = [f for f in results_files if max_last_eval_date_results.replace(":", "-") in f][0] last_results_file_path = hf_hub_url(repo_id=repo_id, filename=last_results_file, repo_type="dataset") - f = load_dataset("json", data_files=last_results_file_path, split="train") + f: Dataset = load_dataset("json", data_files=last_results_file_path, split="train", token=self.hf_token) # type: ignore results_dict = f["results"][0] new_dictionary = {"all": results_dict} new_dictionary.update(results_dict) @@ -490,7 +485,7 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 card_data, pretty_name=card_data.pretty_name, ) - card.push_to_hub(repo_id, repo_type="dataset") + card.push_to_hub(repo_id, repo_type="dataset", token=self.hf_token) def push_to_tensorboard( # noqa: C901 self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail] diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 51bafd24..74702575 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -52,7 +52,7 @@ def main(args): 
push_to_hub=args.push_to_hub, push_to_tensorboard=args.push_to_tensorboard, public=args.public_run, - token=TOKEN, + hf_token=TOKEN, hub_results_org=args.results_org, ) pipeline_params = PipelineParameters( diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 80f2d05f..a00ef884 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -66,10 +66,13 @@ def main( lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore nanotron_config = FullNanotronConfig(lighteval_config, model_config) - # TODO: Once lighteval config is owned by Ligteval fix this evaluation_tracker = EvaluationTracker( output_dir=lighteval_config.logging.output_dir, hub_results_org=lighteval_config.logging.results_org, + public=lighteval_config.logging.public_run, + push_to_hub=lighteval_config.logging.push_to_hub, + push_to_tensorboard=lighteval_config.logging.push_to_tensorboard, + save_details=lighteval_config.logging.save_details, tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix, nanotron_run_info=nanotron_config.nanotron_config.general, ) diff --git a/src/lighteval/utils/io.py b/src/lighteval/utils/io.py deleted file mode 100644 index 662f077f..00000000 --- a/src/lighteval/utils/io.py +++ /dev/null @@ -1,50 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import os -from dataclasses import dataclass - -from fsspec import AbstractFileSystem, url_to_fs -from huggingface_hub import HfFileSystem - - -@dataclass(frozen=True) -class FsspecDataResource: - fs: AbstractFileSystem - path: str - - @classmethod - def from_uri(cls, uri: str) -> "FsspecDataResource": - fs, path = url_to_fs(uri) - return cls(fs=fs, path=path) - - def __truediv__(self, other: str) -> "FsspecDataResource": - return FsspecDataResource(fs=self.fs, path=os.path.join(self.path, other)) - - def __str__(self) -> str: - return self.path - - -def get_hf_repo_id(resource: FsspecDataResource) -> str: - if isinstance(resource.fs, HfFileSystem): - return "/".join(resource.path.split("/")[:2]) - raise ValueError("Resource is not a Hugging Face Hub repository") diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index 0b693d84..90638afc 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -21,6 +21,7 @@ # SOFTWARE. 
import json +import os import tempfile from datetime import datetime from pathlib import Path @@ -31,7 +32,12 @@ from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.info_loggers import DetailsLogger -from tests.fixtures import TESTING_EMPTY_HF_ORG_ID + +# ruff: noqa +from tests.fixtures import TESTING_EMPTY_HF_ORG_ID, testing_empty_hf_org_id + + +HF_TEST_TOKEN = os.getenv("HF_TEST_TOKEN") @pytest.fixture @@ -47,6 +53,7 @@ def mock_evaluation_tracker(request): "push_to_hub": passed_params.get("push_to_hub", False), "push_to_tensorboard": passed_params.get("push_to_tensorboard", False), "hub_results_org": passed_params.get("hub_results_org", ""), + "hf_token": passed_params.get("hf_token", None), } tracker = EvaluationTracker(**kwargs) tracker.general_config_logger.model_name = "test_model" @@ -123,7 +130,7 @@ def test_no_details_output(mock_evaluation_tracker: EvaluationTracker): assert not details_dir.exists() -@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID) +@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID, hf_token=HF_TEST_TOKEN) def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: EvaluationTracker, mock_datetime): # Prepare the dummy data task_metrics = { @@ -141,7 +148,7 @@ def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: Eva mock_evaluation_tracker.save() # Verify using HfApi - api = HfApi() + api = HfApi(token=HF_TEST_TOKEN) # Check if repo exists and it's private expected_repo_id = f"{testing_empty_hf_org_id}/details_test_model_private" From 2f3b7ba1795f716a8341e49839b9532444615603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 4 Sep 2024 14:02:13 +0200 Subject: [PATCH 11/15] remove the overwriting hf test token --- tests/fixtures.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index 8a569d95..70a37c36 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import os import pytest from huggingface_hub import HfApi @@ -32,9 +31,6 @@ @pytest.fixture def testing_empty_hf_org_id(org_id: str = TESTING_EMPTY_HF_ORG_ID): - token = os.getenv("HF_TEST_TOKEN") - os.environ["HF_TOKEN"] = token - def list_repos(org_id: str): return list(hf_api.list_models(author=org_id)) + list(hf_api.list_datasets(author=org_id)) From 3867c3feac52e7d387948bc7f7b6ede446da2aee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 4 Sep 2024 14:30:51 +0200 Subject: [PATCH 12/15] remove token from evaluation tracker --- src/lighteval/logging/evaluation_tracker.py | 18 ++++++------------ src/lighteval/main_accelerate.py | 1 - tests/fixtures.py | 6 ++++++ tests/logging/test_evaluation_tracker.py | 8 ++------ 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 9d9f6d1b..444ca10b 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -96,7 +96,6 @@ def __init__( hub_results_org: str | None = "", tensorboard_metric_prefix: str = "eval", public: bool = False, - hf_token: str | None = None, nanotron_run_info: "GeneralArgs" = None, ) -> None: """ @@ -114,8 +113,6 @@ def __init__( [`EvaluationTracker.save`] tensorboard_metric_prefix (str): Prefix for the metrics in the tensorboard logs public (bool): If True, results and details are pushed in private orgs - token (str | None): Token to use when pushing to the hub. This token should - have write access to `hub_results_org`. nanotron_run_info (GeneralArgs): Reference to informations about Nanotron models runs """ self.details_logger = DetailsLogger() @@ -124,9 +121,8 @@ def __init__( self.general_config_logger = GeneralConfigLogger() self.task_config_logger = TaskConfigLogger() - self.api = HfApi(token=hf_token) - self.fs, self.output_dir = url_to_fs(output_dir, token=hf_token) - self.hf_token = hf_token + self.api = HfApi() + self.fs, self.output_dir = url_to_fs(output_dir) self.hub_results_org = hub_results_org # will also contain tensorboard results if hub_results_org in ["", None] and any([push_to_hub, push_to_tensorboard]): @@ -269,13 +265,11 @@ def push_to_hub( results_dataset = Dataset.from_dict( {key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()} ) - results_dataset.to_parquet( - f"{fsspec_repo_uri}/{result_file_base_name}.parquet", storage_options={"token": self.hf_token} - ) + results_dataset.to_parquet(f"{fsspec_repo_uri}/{result_file_base_name}.parquet") for task_name, dataset in details.items(): output_file_details = Path(date_id) / f"details_{task_name}_{date_id}.parquet" - dataset.to_parquet(f"{fsspec_repo_uri}/{output_file_details}", storage_options={"token": self.hf_token}) + dataset.to_parquet(f"{fsspec_repo_uri}/{output_file_details}") self.recreate_metadata_card(repo_id) @@ -440,7 +434,7 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 # Get the top results last_results_file = [f for f in results_files if max_last_eval_date_results.replace(":", "-") in f][0] last_results_file_path = hf_hub_url(repo_id=repo_id, filename=last_results_file, repo_type="dataset") - f: Dataset = load_dataset("json", data_files=last_results_file_path, split="train", token=self.hf_token) # type: ignore + f: Dataset = load_dataset("json", data_files=last_results_file_path, split="train") # type: ignore results_dict = f["results"][0] new_dictionary = {"all": results_dict} 
new_dictionary.update(results_dict) @@ -485,7 +479,7 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 card_data, pretty_name=card_data.pretty_name, ) - card.push_to_hub(repo_id, repo_type="dataset", token=self.hf_token) + card.push_to_hub(repo_id, repo_type="dataset") def push_to_tensorboard( # noqa: C901 self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail] diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 74702575..95465079 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -52,7 +52,6 @@ def main(args): push_to_hub=args.push_to_hub, push_to_tensorboard=args.push_to_tensorboard, public=args.public_run, - hf_token=TOKEN, hub_results_org=args.results_org, ) pipeline_params = PipelineParameters( diff --git a/tests/fixtures.py b/tests/fixtures.py index 70a37c36..d78255f9 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -21,6 +21,8 @@ # SOFTWARE. +import os + import pytest from huggingface_hub import HfApi from huggingface_hub.hf_api import DatasetInfo @@ -31,6 +33,9 @@ @pytest.fixture def testing_empty_hf_org_id(org_id: str = TESTING_EMPTY_HF_ORG_ID): + old_token = os.getenv("HF_TOKEN") + os.environ["HF_TOKEN"] = os.getenv("HF_TEST_TOKEN") + def list_repos(org_id: str): return list(hf_api.list_models(author=org_id)) + list(hf_api.list_datasets(author=org_id)) @@ -52,3 +57,4 @@ def clean_repos(org_id: str): # Clean up: recreate any necessary default repositories after the test # This step is optional and depends on your specific needs clean_repos(org_id) + os.environ["HF_TOKEN"] = old_token diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index a1b7184d..1712616b 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -37,9 +37,6 @@ from tests.fixtures import TESTING_EMPTY_HF_ORG_ID, testing_empty_hf_org_id -HF_TEST_TOKEN = os.getenv("HF_TEST_TOKEN") - - @pytest.fixture def mock_evaluation_tracker(request): passed_params = {} @@ -53,7 +50,6 @@ def mock_evaluation_tracker(request): "push_to_hub": passed_params.get("push_to_hub", False), "push_to_tensorboard": passed_params.get("push_to_tensorboard", False), "hub_results_org": passed_params.get("hub_results_org", ""), - "hf_token": passed_params.get("hf_token", None), } tracker = EvaluationTracker(**kwargs) tracker.general_config_logger.model_name = "test_model" @@ -130,7 +126,7 @@ def test_no_details_output(mock_evaluation_tracker: EvaluationTracker): assert not details_dir.exists() -@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID, hf_token=HF_TEST_TOKEN) +@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID) def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: EvaluationTracker, mock_datetime): # Prepare the dummy data task_metrics = { @@ -147,7 +143,7 @@ def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: Eva mock_evaluation_tracker.save() # Verify using HfApi - api = HfApi(token=HF_TEST_TOKEN) + api = HfApi() # Check if repo exists and it's private expected_repo_id = f"{testing_empty_hf_org_id}/details_test_model_private" From ef40c6cc87fe667e3c7696953e03c04672a07dd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 4 Sep 2024 15:15:44 +0200 Subject: [PATCH 13/15] expose secret to tests --- .github/workflows/tests.yaml | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 233c4672..68909a19 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -34,6 +34,8 @@ jobs: path: "cache" key: test-cache-HF - name: Test + env: + HF_TEST_SECRET: ${{ secrets.HF_TEST_SECRET }} run: | # PYTHONPATH="${PYTHONPATH}:src" HF_DATASETS_CACHE="cache/datasets" HF_HOME="cache/models" python -m pytest --disable-pytest-warnings - name: Write cache From 1e14b054447b9d043fdacdef588910070b4ffdf4 Mon Sep 17 00:00:00 2001 From: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:02:22 +0200 Subject: [PATCH 14/15] Update .github/workflows/tests.yaml --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 68909a19..f633e318 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: key: test-cache-HF - name: Test env: - HF_TEST_SECRET: ${{ secrets.HF_TEST_SECRET }} + HF_TEST_TOKEN: ${{ secrets.HF_TEST_TOKEN }} run: | # PYTHONPATH="${PYTHONPATH}:src" HF_DATASETS_CACHE="cache/datasets" HF_HOME="cache/models" python -m pytest --disable-pytest-warnings - name: Write cache From b6678f9631d059e446087a04272cf3ab0cd8ebf4 Mon Sep 17 00:00:00 2001 From: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:44:51 +0200 Subject: [PATCH 15/15] Update tests/fixtures.py --- tests/fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index d78255f9..ac1b97fb 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -57,4 +57,4 @@ def clean_repos(org_id: str): # Clean up: recreate any necessary default repositories after the test # This step is optional and depends on your specific needs clean_repos(org_id) - os.environ["HF_TOKEN"] = old_token + os.environ["HF_TOKEN"] = old_token if old_token is not None else ""
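
Note on the marker-driven fixture pattern introduced in these patches: the custom
`evaluation_tracker` mark is read back inside `mock_evaluation_tracker` through
`request.keywords`. Below is a minimal alternative sketch, not part of the patch
series, showing the same idea with the mark registered (so strict marker checking
accepts it) and read through the documented `request.node.get_closest_marker` API;
the conftest.py placement is an assumption, and the fixture body simply mirrors the
one defined in tests/logging/test_evaluation_tracker.py.

# conftest.py -- illustrative sketch only, not part of the patches above.
import tempfile

import pytest

from lighteval.logging.evaluation_tracker import EvaluationTracker


def pytest_configure(config):
    # Register the custom mark so `pytest --strict-markers` does not reject it.
    config.addinivalue_line(
        "markers",
        "evaluation_tracker(**kwargs): keyword arguments forwarded to the mock_evaluation_tracker fixture",
    )


@pytest.fixture
def mock_evaluation_tracker(request):
    # get_closest_marker returns the Mark object, or None when the test is unmarked.
    marker = request.node.get_closest_marker("evaluation_tracker")
    passed_params = marker.kwargs if marker is not None else {}

    with tempfile.TemporaryDirectory() as temp_dir:
        tracker = EvaluationTracker(
            output_dir=temp_dir,
            save_details=passed_params.get("save_details", False),
            push_to_hub=passed_params.get("push_to_hub", False),
            push_to_tensorboard=passed_params.get("push_to_tensorboard", False),
            hub_results_org=passed_params.get("hub_results_org", ""),
        )
        tracker.general_config_logger.model_name = "test_model"
        yield tracker

Test call sites stay unchanged either way, e.g.
@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID),
whether the kwargs are looked up via request.keywords or via get_closest_marker.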