diff --git a/README.md b/README.md index 76b52634b..05f295639 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ -# LightEval [kaːbiː] +# LightEval 🌤️ ## Context LightEval is an evaluation suite which gathers a selection of features from widely used benchmarks recently proposed: - from the [Eleuther AI Harness](https://github.com/EleutherAI/lm-evaluation-harness), we use the nice request management -- from [HELM](https://crfm.stanford.edu/helm/latest/), we keep the qualitative metrics -- from our previous internal evaluation suite, we keep the easy evaluation loading. +- from [HELM](https://crfm.stanford.edu/helm/latest/), we keep the qualitative and rich metrics +- from our previous internal evaluation suite, we keep the easy editing, evaluation loading and speed. -We also ported all the evaluations from HELM and BigBench. +It is still an early, internal version - it should be nice to use but don't expect 100% stability! -## How to install and use -At the moment, the core of our code relies on the evaluation harness as a dependency. This is likely to change from v0 to v1. +In case of problems or questions, feel free to open an issue! +## How to install and use ### Requirements 0) Create your virtual environment using virtualenv or conda depending on your preferences. We require Python3.10 @@ -32,16 +32,16 @@ Optional: - Using data parallelism on several GPUs - If you want to use data parallelism, first configure accelerate (`accelerate config`). - `accelerate launch main.py --model_args="pretrained=" ` - for instance: `accelerate launch --multi_gpu --num_processes 8 main.py --model_args="pretrained=EleutherAI/gpt-j-6b,dtype=float16,model_parallel=True" --tasks "helm|hellaswag,harness|hellaswag" --override_batch_size 8 --num_fewshot 10 --output_dir output_dir` + for instance: `accelerate launch --multi_gpu --num_processes 8 main.py --model_args="pretrained=EleutherAI/gpt-j-6b,dtype=float16,model_parallel=True" --tasks "helm|hellaswag,lighteval|hellaswag" --override_batch_size 8 --num_fewshot 10 --output_dir output_dir` - Note: if you use model_parallel, accelerate will use 2 processes for model parallel, num_processes for data parallel The task parameters indicate which tasks you want to launch. You can select: -- one or several tasks, with `--tasks task_names`, with task_names in the [metadata table](metadata_table.json), separated by commas. You must specify which version of the task you want (= in which suite it is), by prepending the suite name (`suite|task`). You can also add the number of training few_shots prompts for the given task (`suite|task|few_shot`), and whether you want truncation for your task (`suite|task|few_shot|1 or 0 to indicate if you want few_shot truncation or not`). +- one or several tasks, with `--tasks task_names`, with task_names in the [metadata table](metadata_table.json), separated by commas. You must specify which version of the task you want (= in which suite it is), by prepending the suite name, as well as the number of training few-shot prompts for the given task, and whether you want to automatically reduce the number of few-shots if they make the prompt too long (`suite|task|few_shot|1 or 0 to automatically reduce the number of few_shots or not`). - a file path, which contains tasks following the above format.
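For example, a task file passed with `--tasks` could look like the sketch below, one task per line, following the `suite|task|few_shot|auto_reduce` format described above (the task names and few-shot counts here are only illustrative, see the [metadata table](metadata_table.json) for what is actually available):

```
lighteval|arc:challenge|25|1
helm|hellaswag|0|0
lighteval|gsm8k|5|1
```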
Example If you want to compare hellaswag from helm and the harness on Gpt-6j, you can do -`python run_eval.py --model hf_causal --model_args="pretrained=EleutherAI/gpt-j-6b" --tasks helm|hellaswag,harness|hellaswag` +`python run_eval.py --model hf_causal --model_args="pretrained=EleutherAI/gpt-j-6b" --tasks helm|hellaswag,lighteval|hellaswag` Other cool parameters: - `--save_queries` will print the prompts, generations and golds. @@ -54,33 +54,43 @@ Other cool parameters: ## Adding a new task To add a new task, first **add its dataset** on the hub. -Then, **find a suitable prompt function** or **create a new prompt function** in `src/prompt_formatting.py`. This function must output a dict, which should contain `query`, your prompt, and either `gold`, the gold output, or `choices` and `gold_index`, the list of choices and index or indices of correct answers. If your query contains an instruction which should not be repeated in a few shot setup, add it to an `instruction` field. +Then, **find a suitable prompt function** or **create a new prompt function** in `src/prompt_formatting.py`. This function must output a `Doc` object, which should contain `query`, your prompt, and either `gold`, the gold output, or `choices` and `gold_index`, the list of choices and index or indices of correct answers. If your query contains an instruction which should not be repeated in a few shot setup, add it to an `instruction` field. Lastly, create a **line summary** of your evaluation, in `metadata_table.json`. This summary should contain the following fields: - `name` (str), your evaluation name -- `hf_repo` (str), the path of your eval on the hub -- `hf_subset` (str), the subset you want to use (note1: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`) (note2: you cannot use a list here) +- `suite` (list), the suite(s) to which your evaluation should belong. This field allows us to compare different task implementations, and is used as a task selector to differentiate the versions to launch. At the moment, you'll find the keywords ["helm", "bigbench", "original", "lighteval"]; you can also add new ones (for testing, we recommend using "custom"). +- `prompt_function` (str), the name of the prompt function you defined in the step above +- `hf_repo` (str), the path to your evaluation dataset on the hub +- `hf_subset` (str), the specific subset you want to use for your evaluation (note: when the dataset has no subset, fill this field with `"default"`, not with `None` or `""`) - `hf_avail_splits` (list), all the splits available for your dataset (train, valid or validation, test, other...) - `evaluation_splits` (list), the splits you want to use for evaluation +- `few_shots_split` (str, can be `null`), the specific split from which you want to select samples for your few-shot examples. It should be different from the sets included in `evaluation_splits` +- `few_shots_select` (str, can be `null`), the method that you will use to select items for your few-shot examples.
Can be `null`, or one of: + - `balanced` selects examples from the `few_shots_split` with balanced labels, to avoid skewing the few shot examples (hence the model generations) towards one specific label + - `random` selects examples at random from the `few_shots_split` + - `random_sampling` selects new examples at random from the `few_shots_split` for every new item, but if a sampled item is equal to the current one, it is removed from the available samples + - `random_sampling_from_train` selects new examples at random from the `few_shots_split` for every new item, but if a sampled item is equal to the current one, it is kept! Only use this if you know what you are doing. + - `sequential` selects the first `n` examples of the `few_shots_split` - `generation_size` (int), the maximum number of tokens allowed for a generative evaluation. If your evaluation is a log likelihood evaluation (multi-choice), this value should be -1 - `stop_sequence` (list), a list of strings acting as end of sentence tokens for your generation - `metric` (list), the metrics you want to use for your evaluation (see next section for a detailed explanation) -- `suite` (list), the suites to which your evaluation should belong. At the moment, we cover ["helm", "harness", "bigbench", "original", "lighteval"], and you can add new ones (for test, we recommend using "custom"). This section is also where we'll put tags (qa, summarization, ...) and any information we might want to use to group evaluations. This field is very important if you are adding an evaluation with the same name as an already existing one, as you'll select it on the suite. -- `prompt_function` (str), the name of the prompt function you defined in the step above - `output_regex` (str), A regex string that will be used to filter your generation. (Genrative metrics will only select tokens that are between the first and the second sequence matched by the regex. For example, for a regex matching `\n` and a generation `\nModel generation output\nSome other text` the metric will only be fed with `Model generation output`) +- `frozen` (bool), for now is set to False, but we will steadily pass all stable tasks to True. ## Available metrics ### Metrics for multiple choice tasks These metrics use log-likelihood of the different possible targets. 
-- `loglikelihood_acc` (Harness): Fraction of instances where the choice with the best logprob was correct, -- `loglikelihood_acc_norm` (Harness): Fraction of instances where the choice with the best logprob, normalized by sequence length, was correct, -- `loglikelihood_f1` (Harness): Average F1 score of the multichoice selection, +- `loglikelihood_acc` (Harness): Fraction of instances where the choice with the best logprob was correct - also exists in a faster version for tasks where the possible choices include only one token (`loglikelihood_acc_single_token`) +- `loglikelihood_acc_norm` (Harness): Fraction of instances where the choice with the best logprob, normalized by sequence length, was correct - also exists in a faster version for tasks where the possible choices include only one token (`loglikelihood_acc_norm_single_token`) +- `loglikelihood_acc_norm_nospace` (Harness): Fraction of instances where the choice with the best logprob, normalized by sequence length, was correct, with the first space ignored +- `loglikelihood_f1` (Harness): Corpus level F1 score of the multichoice selection - also exists in a faster version for tasks where the possible choices include only one token (`loglikelihood_f1_single_token`) - `mcc` (Harness): Matthew's correlation coefficient (measure of agreement between statistical distributions), -- `recall@1` (Harness): Fraction of instances where the choice with the best logprob was correct (equivalent here to `loglikelihood_acc`), -- `recall@2` (Harness): Fraction of instances where the choice with the 2nd best logprob or better was correct, -- `mrr` (Harness): Mean reciprocal rank, measure of the quality of a ranking of choices ordered by correctness/relevance, +- `recall_at_1` (Harness): Fraction of instances where the choice with the best logprob was correct - also exists in a faster version for tasks where the possible choices include only one token per choice (`recall_at_1_single_token`) +- `recall_at_2` (Harness): Fraction of instances where the choice with the 2nd best logprob or better was correct - also exists in a faster version for tasks where the possible choices include only one token per choice (`recall_at_2_single_token`) +- `mrr` (Harness): Mean reciprocal rank, measure of the quality of a ranking of choices ordered by correctness/relevance - also exists in a faster version for tasks where the possible choices include only one token (`mrr_single_token`) - `target_perplexity` (Harness): Perplexity of the different choices available. - `acc_golds_likelihood`: (Harness): A bit different, it actually checks if the average logprob of a single target is above or below 0.5 +- `multi_f1_numeric`: Loglikelihood F1 score for multiple gold targets All these metrics also exist in a "single token" version (`loglikelihood_acc_single_token`, `loglikelihood_acc_norm_single_token`, `loglikelihood_f1_single_token`, `mcc_single_token`, `recall@2_single_token` and `mrr_single_token`). When the multichoice option compare only one token (ex: "A" vs "B" vs "C" vs "D", or "yes" vs "no"), using these metrics in the single token version will divide the time spent by the number of choices. Single token evals also include: - `multi_f1_numeric` (Harness, for CB): computes the f1 score of all possible choices and averages it. @@ -97,22 +107,21 @@ These metrics need the model to generate an output. They are therefore slower. - Base: - `perfect_exact_match` (Harness): Fraction of instances where the prediction matches the gold exactly. 
- `exact_match` (HELM): Fraction of instances where the prediction matches the gold at the exception of the border whitespaces (= after a `strip` has been applied to both). - - `quasi_exact_match` (HELM): Fraction of instances where the normalized prediction matches the normalized gold (normalization done on whitespace, articles, capitalization, ...) + - `quasi_exact_match` (HELM): Fraction of instances where the normalized prediction matches the normalized gold (normalization done on whitespace, articles, capitalization, ...). Other variations exist, with other normalizers, such as `quasi_exact_match_triviaqa`, which only normalizes the predictions after applying a strip to all sentences. - `prefix_exact_match` (HELM): Fraction of instances where the beginning of the prediction matches the gold at the exception of the border whitespaces (= after a `strip` has been applied to both). - `prefix_quasi_exact_match` (HELM): Fraction of instances where the normalized beginning of the prediction matches the normalized gold (normalization done on whitespace, articles, capitalization, ...) - `exact_match_indicator`: Exact match with some preceding context (before an indicator) removed - - `f1_sequence` (BigBench): Average F1 score at the sentence level. - - `f1_from_bags` (Harness): Average F1 score at the bag of word level (sentence > bag of words). - - `f1_quasi` (HELM): Average F1 score in terms of word overlap between the model output and gold, with external whitespaces removed using strip -- Reasoning: - - `iou_set_match` (HELM): Intersection over union in terms of set overlap between the model predicted set and gold set. - - `exact_set_match` (HELM): Fraction of instances that the predicted output set matches the gold set exactly. - - `f1_set_match` (HELM): Average F1 score in terms of set overlap between the model predicted set and correct reference set. + - `f1_score_quasi` (HELM): Average F1 score in terms of word overlap between the model output and gold, with both being normalized first + - `f1_score`: Average F1 score in terms of word overlap between the model output and gold without normalisation + - `f1_score_macro`: Corpus level macro F1 score + - `f1_score_micro`: Corpus level micro F1 score - Summarization: - `rouge` (Harness): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) - - `rouge_1` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 1-gram overlap. - - `rouge_2` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 2-gram overlap. - - `rouge_l` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on longest common subsequence overlap. + - `rouge1` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 1-gram overlap. + - `rouge2` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 2-gram overlap. + - `rougeL` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on longest common subsequence overlap. + - `rougeLsum` (HELM): Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on longest common subsequence overlap. + - `rouge_t5` (BigBench): Corpus level ROUGE score for all available ROUGE metrics - `faithfulness` (HELM): Faithfulness scores based on the SummaC method of [Laban et al. (2022)](https://aclanthology.org/2022.tacl-1.10/).
- `extractiveness` (HELM): Reports, based on [(Grusky et al., 2018)](https://aclanthology.org/N18-1065/) - `summarization_coverage`: Extent to which the model-generated summaries are extractive fragments from the source document, @@ -120,9 +129,9 @@ These metrics need the model to generate an output. They are therefore slower. - `summarization_compression`: Extent to which the model-generated summaries are compressed relative to the source document. - `bert_score` (HELM): Reports the average BERTScore precision, recall, and f1 score [(Zhang et al., 2020)](https://openreview.net/pdf?id=SkeHuCVFDr) between model generation and gold summary. - Translation - - `bleu` (Harness): Average Corpus BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) - uses the sacrebleu implementation. - - `bleu_1` (HELM): Average BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) based on 1-gram overlap - uses the nltk implementation. - - `bleu_4` (HELM): Average BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) based on 4-gram overlap - uses the nltk implementation. + - `bleu`: Corpus level BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) - uses the sacrebleu implementation. + - `bleu_1` (HELM): Average sample BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) based on 1-gram overlap - uses the nltk implementation. + - `bleu_4` (HELM): Average sample BLEU score [(Papineni et al., 2002)](https://aclanthology.org/P02-1040/) based on 4-gram overlap - uses the nltk implementation. - `chrf` (Harness): Character n-gram matches f-score. - `ter` (Harness): Translation edit/error rate. - Bias, toxicity, copyright @@ -131,27 +140,19 @@ These metrics need the model to generate an output. They are therefore slower. - `longest_common_prefix_length`: average length of longest common prefix between model generation and reference, - `edit_distance`: average Levenshtein edit distance between model generation and reference, - `edit_similarity`: average Levenshtein edit similarity (normalized by length of longer sequence) between model generation and reference. -- Math and code: - - `code_eval_HE` (HELM): Reports metrics for the HumanEval code dataset (*implies executing generated code locally!*) - - `code_eval_acc`: Fraction of instances that the model output evaluates to the correct answer. - - `pass@1`: Fraction of model outputs that pass the associated test cases. - - `pass@k`: Fraction of k model outputs that pass the associated test cases. - - `code_eval_APPS` (HELM): Reports metrics for the APPS code dataset (*implies executing generated code locally!*) - - `code_eval_test_avg`: Fraction of test cases passed. - - `code_eval_strict_acc`: Fraction of models outputs that pass all associated test cases. +- Math: - `quasi_exact_match_math` (HELM): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for math, where latex symbols, units, etc are removed) + - `quasi_exact_match_gsm8k` (Harness): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for gsm8k, where latex symbols, units, etc are removed) + +### Metrics for specific tasks +To keep compatibility with the Harness for some specific tasks, we ported their evaluations more or less as such. They include `drop` (for the DROP dataset) and `truthfulqa_mc_metrics` (for TruthfulQA). 
In general, except for tasks where the dataset has a very different formatting than usual (an other language, programming language, math, ...), we want to use standard implementations of the above metrics. It makes little sense to have 10 different versions of an exact match depending on the task. However, most of the above metrics are parametrizable so that you can change the normalization applied easily for experimental purposes. ### Not working yet These metrics need both the generation and its logprob. They are not working at the moment, as this fn is not in the AI Harness. -- `prediction_perplexity` (HELM): Measure of the logprob of a given generation. - -### Specific metrics -Metrics in the `specific` file are metrics which have been designed for one precise dataset in one evaluation suite. They are not generic and shouldn't be used outside of their specific use case. Use them as little as possible, as it's redefined metrics like these which reduce the quality and reproducibility of evaluations. +- `prediction_perplexity` (HELM): Measure of the logprob of a given input. ## Adding a new metric -If you want to add a new metric, define its function in the corresponding file in `src/metrics` (summarization for summarization metrics, code for code evaluation metrics, you get the gist), which should return a dict of `{"metric_name": score}`. You also need to add 2 mappings to the "metric_name": which aggregation method to use in `type_aggregate` (`summarization_aggregate` for a summarization metric for ex, at the end of the file), and if a higher value for your metric indicates a better score (in `type_higher_is_better`, such as `summarization_higher_is_better`). -You then need to add your metric to one of the lists in `src/metrics/__init__.py`, depending on what your metric needs (respective log likelihoods of different choices? log likelihood of prompt (for perplexity for ex)? generation? generation and log likelihood?), and lastly to edit `process_results` in `src/tasks_from_config` to indicate the mapping which exists between the function name and the score. - +If you want to add a new metric, first check if you can use one of the parametrized functions in `src.lighteval.metrics.metrics_corpus` or `metrics_sample`. If not, add it to either of these files depending on the level at which it is applied. Then, follow the example in `src.lighteval.metrics.metrics` to register your metric. 
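As a concrete illustration, here is a minimal registration sketch modelled on the existing `mrr` entry of `src/lighteval/metrics/metrics.py`. The metric name and parameters below are hypothetical, and the entries in that file may take additional fields not shown here, so use the neighbouring definitions as the actual reference:

```python
# Hypothetical entry to add inside the `Metrics` enum of src/lighteval/metrics/metrics.py.
# SampleLevelMetric, MetricCategory and MetricUseCase are already imported in that file;
# MRR comes from lighteval.metrics.metrics_sample and np is numpy.
import numpy as np

from lighteval.metrics.metrics_sample import MRR

mrr_length_norm = SampleLevelMetric(
    metric="mrr_length_norm",                                # name under which the score is reported
    sample_level_fn=MRR(length_normalization=True).compute,  # computed once per sample
    category=MetricCategory.MULTICHOICE,                     # needs the log-likelihood of every choice
    use_case=MetricUseCase.ACCURACY,
    corpus_level_fn=np.mean,                                 # how per-sample scores are aggregated
)
```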
## Examples of scripts to launch lighteval on the cluster ### Evaluate a whole suite on one node, 8 GPUs diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 44ae0cb77..df9af332a 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -13,6 +13,7 @@ from lighteval.metrics.metrics_sample import ( BLEU, BLEURT, + MRR, ROUGE, BertScore, ExactMatches, @@ -23,7 +24,6 @@ acc_golds_likelihood, extractiveness, faithfulness, - mrr, ) from lighteval.metrics.normalizations import ( bigbench_normalizer, @@ -277,7 +277,7 @@ class Metrics(Enum): ) mrr = SampleLevelMetric( metric="mrr", - sample_level_fn=mrr, + sample_level_fn=MRR().compute, category=MetricCategory.MULTICHOICE, use_case=MetricUseCase.ACCURACY, corpus_level_fn=np.mean, diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 73fcdaafa..5cdfa6d13 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -1,4 +1,4 @@ -"""This module manages all the score aggregations and computations occurring at the corpus level. +"""This module manages all the metrics occurring at the corpus level. Some metrics (such as corpus BLEU) are not computed at the individual item level, but over all the corpus. A number of these aggregations come from the EleutherAIHarness """ @@ -10,6 +10,7 @@ from lighteval.metrics.sample_preparator import ( GenerativeCorpusMetricInput, + LogprobCorpusMetricInput, PerplexityCorpusMetricInput, ) from lighteval.utils import as_list @@ -20,7 +21,7 @@ def matthews_corrcoef(items: list[GenerativeCorpusMetricInput]) -> float: """Computes the Matthews Correlation Coefficient, using scikit learn ([doc](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html)). Args: - items (list[dict]): List of the correctly formatted dictionarinput + items (list[dict]): List of GenerativeCorpusMetricInput Returns: float: Score @@ -32,13 +33,23 @@ def matthews_corrcoef(items: list[GenerativeCorpusMetricInput]) -> float: class CorpusLevelF1Score: def __init__(self, average: str, num_classes: int = 2): - # If num_classes > 2, we compute multi_f1_corpus_aggregation - self.average = average # weighted, macro, micro + """Stores the relevant parameters for the task's corpus level f1 score. + + Args: + average (str): Method to use to compute the f1 score. Can be weighted, macro, micro. + num_classes (int, optional): Num of possible choice classes. Defaults to 2. If this parameter is above 2, we'll compute multi f1 corpus score + """ + if average not in ["weighted", "macro", "micro", None]: + raise ValueError( + f"A CorpusLevelF1Score must be initialized with weighted, macro, micro, or None as an average function. {average} was used." 
+ ) + self.average = average self.num_classes = num_classes - def compute(self, items): - golds = [i["golds"] for i in items] - preds = [i["preds"] for i in items] + def compute(self, items: list[LogprobCorpusMetricInput]): + """Computes the metric score over all the corpus generated items, by using the scikit learn implementation.""" + golds = [i.golds for i in items] + preds = [i.preds for i in items] # Single f1 if self.num_classes == 2: fscore = sklearn.metrics.f1_score(golds, preds, average=self.average) @@ -48,11 +59,16 @@ def compute(self, items): f1s = [] for i in range(self.num_classes): f1s.append(sklearn.metrics.f1_score(y_true=golds == i, y_pred=preds == i)) - return np.mean(f1s) + return float(np.mean(f1s)) class CorpusLevelTranslationMetric: def __init__(self, metric_type: str): + """Stores the relevant parameters for a corpus level translation metric. + + Args: + metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use. + """ if metric_type == "bleu": self.metric = sacrebleu.corpus_bleu elif metric_type == "chrf": @@ -63,19 +79,32 @@ def __init__(self, metric_type: str): raise ValueError(f"Unknown corpus level translation metric type : {metric_type}") def compute(self, items: list[GenerativeCorpusMetricInput]) -> float: + """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation.""" golds = [i.golds for i in items] preds = [as_list(i.preds) for i in items] - return self.metric(hypotheses=preds, references=golds).score + return float(self.metric(hypotheses=preds, references=golds).score) class CorpusLevelPerplexityMetric: def __init__(self, metric_type: str): + """Stores the relevant parameter for a corpus level perplexity metric. + Perplexity metrics compute more or less the same thing, which is a variation on the + average of log-probabilities over a sequence, but the normalization and processing applied + is different depending on the metric type. + Perplexity uses an exponential and no weights for the average, weighted perplexity uses an exponential + and the number of words as weights for the log-prob average, and bits per byte uses the number of bits + for normalization and divides the results by log(2). + + Args: + metric_type (str): Can be any of `perplexity`, `weighted_perplexity` or `bits_per_byte` + """ if metric_type not in ["perplexity", "weighted_perplexity", "bits_per_byte"]: raise ValueError(f"Unknown corpus level perplexity metric type : {metric_type}") self.metric_type = metric_type def compute(self, items: list[PerplexityCorpusMetricInput]): + """Computes the metric score over all the corpus generated items.""" logprobs = [i.logprobs for i in items] weights = [i.weights for i in items] diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 9ea9b3a51..ec123741b 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -1,3 +1,6 @@ +"""This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated +using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category. +""" import nltk import numpy as np from nltk.metrics.distance import edit_distance @@ -16,7 +19,6 @@ from lighteval.utils import as_list -# Parametrized metrics are defined as classes class ExactMatches: def __init__( self, @@ -26,6 +28,22 @@ def __init__( strip_strings: bool = False, type_exact_match: str = "full", ): + """An exact match class. 
+ + Args: + aggregation_function (callable, optional): How to aggregate the item results. Defaults to max. + Used if there are several golds or predictions on which scores were computed. + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. + strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False. + type_exact_match (str, optional): Defines what type of match to apply (post normalization if present). + Can be any of `prefix`, `suffix` or `full`. Defaults to "full". + `prefix` checks if the prediction starts with the gold, + `suffix` if the prediction ends with the gold, + `full` if the prediction and gold are equal + """ if aggregation_function is None: aggregation_function = max self.aggregation_function = aggregation_function @@ -41,6 +59,15 @@ def __init__( self.type_exact_match = type_exact_match def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float: + """Computes the metric over a list of golds and predictions for one single sample. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + float: Aggregated score over the current sample's items. + """ results = [] # We might need to flatten golds if they are a list of lists for gold in golds: @@ -53,6 +80,15 @@ def compute_one_item( gold: str, pred: str, ) -> float: + """Compares two strings only. + + Args: + gold (str): One of the possible references + pred (str): One of the possible predictions + + Returns: + float: The exact match score. Will be 1 for a match, 0 otherwise. + """ if not pred: return 0 @@ -79,8 +115,18 @@ def __init__( normalize_gold: callable = None, normalize_pred: callable = None, strip_strings: bool = False, - type_f1: str = "", ): + """An F1 score class. F1 is computed over the bag of words of the golds and predictions. + + Args: + aggregation_function (callable, optional): How to aggregate the item results. Defaults to max. + Used if there are several golds or predictions on which scores were computed. + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. + strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False. + """ if aggregation_function is None: aggregation_function = max self.aggregation_function = aggregation_function @@ -88,9 +134,17 @@ def __init__( self.normalize_gold = normalize_gold self.normalize_pred = normalize_pred self.strip_strings = strip_strings - self.type_f1 = type_f1 def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float: + """Computes the metric over a list of golds and predictions for one single sample. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + float: Aggregated score over the current sample's items. 
+ """ results = [] # We might need to flatten golds if they are a list of lists for gold in golds: @@ -99,6 +153,15 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float: return self.aggregation_function(results) def compute_one_item(self, gold: str, pred: str) -> float: + """Compares two strings only. + + Args: + gold (str): One of the possible references + pred (str): One of the possible predictions + + Returns: + float: The f1 score over the bag of words, computed using nltk. + """ if self.normalize_gold: gold = self.normalize_gold(gold) @@ -117,10 +180,32 @@ def compute_one_item(self, gold: str, pred: str) -> float: class LoglikelihoodAcc: def __init__(self, length_normalization: bool = False, ignore_first_space: bool = False) -> None: + """Log likelihood accuracy class. It tests if the highest log-probability of the possible choices + is actually in the gold ones. + + Args: + length_normalization (bool, optional): Whether log-likelihood scores should be normalized for sentence length. Defaults to False. + Should be True for most cases. + ignore_first_space (bool, optional): Whether to ignore the first token's log prob (if it's a space only). Defaults to False. + Only case when it should be True is when the possible choices (for example `A`,`B` ...) have an extra + space added in front of them to manage tokenization issues (` A`, ` B`, ...) for some models. + """ self.length_normalization = length_normalization self.ignore_first_space = ignore_first_space - def compute(self, gold_ixs: list[int], choices_logprob: list[list[float]], formatted_doc: Doc, **kwargs): + def compute(self, gold_ixs: list[int], choices_logprob: list[float], formatted_doc: Doc, **kwargs) -> int: + """Computs the log likelihood accuracy: is the choice with the highest logprob in `choices_logprob` present + in the `gold_idxs`? + + Args: + gold_ixs (list[int]): All the gold choices indices + choices_logprob (list[float]): Summed log-probabilities of all the possible choices for the model, ordered as the choices. + formatted_doc (Doc): Original document for the sample. + Used to get the original choices's length for possible normalisation + + Returns: + int: The eval score: 1 if the best log-prob choice is in gold, 0 otherwise. + """ if self.length_normalization: normalized_log_probs = [] for ix, choice in enumerate(formatted_doc.choices): @@ -139,21 +224,67 @@ def compute(self, gold_ixs: list[int], choices_logprob: list[list[float]], forma class Recall: def __init__(self, at: int) -> None: + """Recall metric class. It checks if the top `at` best choices include one of the golds or not. + + Args: + at (int): Depth level of the recall. + Recall at 1 is equivalent to a logprob accuracy without normalization. + """ self.recall_depth = at - def compute(self, choices_logprob, gold_ixs, **kwargs): - if self.at == 1: + def compute(self, choices_logprob: list[float], gold_ixs: list[int], **kwargs) -> int: + """Computes the recall at the requested depth level: looks at the `n` best predicted choices (with the + highest log probabilies) and see if there is an actual gold among them. + + Args: + gold_ixs (list[int]): All the gold choices indices + choices_logprob (list[float]): Summed log-probabilities of all the possible choices for the model, ordered as the choices. + + Returns: + int: Score: 1 if one of the top level predicted choices was correct, 0 otherwise. 
+ """ + if self.recall_depth == 1: return int(np.argmax(choices_logprob) in gold_ixs) return (int(any(ix in gold_ixs for ix in np.array(choices_logprob).argsort()[::-1][: self.recall_depth])),) -def mrr(choices_logprob: list[float], gold_ixs: list[float], **kwargs): - ranked_choices = [sorted(choices_logprob, reverse=True).index(choices_logprob[gold]) for gold in gold_ixs] - return 1.0 / (min(ranked_choices) + 1) +class MRR: + def __init__(self, length_normalization: bool = False): + """A mean reciprocal rank class. + + Args: + length_normalization (bool, optional): Whether to use normalisation be choice length when computing the best log-probabilities. Defaults to False. + """ + self.length_normalization = length_normalization + + def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted_doc: Doc, **kwargs) -> float: + """Mean reciprocal rank. Measures the quality of a ranking of choices (ordered by correctness). + + Args: + gold_ixs (list[int]): All the gold choices indices + choices_logprob (list[float]): Summed log-probabilities of all the possible choices for the model, ordered as the choices. + formatted_doc (Doc): Original document for the sample. + Used to get the original choices's length for possible normalisation + + Returns: + float: MRR score. + """ + if self.length_normalization: + choices_logprob = [choices_logprob[ix] / len(formatted_doc.choices[ix]) for ix in len(choices_logprob)] + ranked_choices = [sorted(choices_logprob, reverse=True).index(choices_logprob[gold]) for gold in gold_ixs] + return 1.0 / (min(ranked_choices) + 1) + +def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int: + """Tests if at least one of predicted gold targets' log-likelihood is above 0.5. -def acc_golds_likelihood(results: list[int], formatted_doc: Doc, **kwargs): - results = results[: len(formatted_doc.get_golds())] # todo: check, might not be needed + Args: + results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated. + formatted_doc (Doc): _description_ + + Returns: + int: 1 if at least one of the possible golds had a log-likelihood above 0.5. + """ return max([int(acc_ppl) for _, acc_ppl in results]) @@ -169,8 +300,22 @@ def __init__( normalize_pred: callable = None, aggregation_function: callable = None, ): + """A ROUGE wrapper method. Relies on `rouge_scorer`. + + Args: + methods (str | list[str]): What type of ROUGE scoring to use. Can be one or any of `rouge1`, `rouge2`, `rougeL` or `rougeLsum`. + multiple_golds (bool, optional): Whether to compute ROUGE by allowing the comparision to several golds + at once, or to compute ROUGE on individual gold/prediction pairs and aggregate afterwards. Defaults to False. + bootstrap (bool, optional): Whether to use bootstrapping. Defaults to False. + aggregation_function (callable, optional): How to aggregate the item results. Defaults to max. + Used if there are several golds or predictions on which scores were computed. + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. + """ if aggregation_function and bootstrap: - hlog_warn("Can't use both bootstrapping and an aggreagation function in Rouge. 
Keeping bootstrap.") + hlog_warn("Can't use both bootstrapping and an aggregation function in Rouge. Keeping bootstrap.") self.aggregation_function = aggregation_function if self.aggregation_function is None: self.aggregation_function = np.mean @@ -186,7 +331,17 @@ def __init__( self.normalize_gold = normalize_gold self.normalize_pred = normalize_pred - def compute(self, golds: list[str], predictions: list[str], **kwargs): + def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float | dict: + """Computes the metric(s) over a list of golds and predictions for one single sample. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + float or dict: Aggregated score over the current sample's items. + If several rouge functions have been selected, returns a dict which maps name and scores. + """ # Normalize if self.normalize_gold: golds = [self.normalize_gold(g) for g in golds] @@ -195,17 +350,17 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs): predictions = [self.normalize_pred(p) for p in predictions] if self.bootstrap: # For t5 style rouge score - scores = self.rouge_score_with_bootsrap(golds=golds, predictions=predictions) + scores = self._rouge_score_with_bootsrap(golds=golds, predictions=predictions) elif self.multiple_golds: - scores = self.rouge_score_multi_golds(golds=golds, preds=predictions) + scores = self._rouge_score_multi_golds(golds=golds, preds=predictions) else: - scores = self.rouge_score(golds=golds, preds=predictions) + scores = self._rouge_score(golds=golds, preds=predictions) if len(scores) == 1: return list(scores.values())[0] return scores - def rouge_score(self, golds: list[str], preds: list[str]): + def _rouge_score(self, golds: list[str], preds: list[str]): scores = {m: [] for m in self.methods} for pred in preds: for gold in golds: @@ -214,7 +369,7 @@ def rouge_score(self, golds: list[str], preds: list[str]): scores[method].append(cur_scores[method].fmeasure) return {method: self.aggregation_function(scores[method]) for method in self.methods} - def rouge_score_multi_golds(self, golds: list[str], preds: list[str]): + def _rouge_score_multi_golds(self, golds: list[str], preds: list[str]): scores = {m: [] for m in self.methods} for pred in preds: cur_scores = self.scorer.score_multi(golds, pred) @@ -222,7 +377,7 @@ def rouge_score_multi_golds(self, golds: list[str], preds: list[str]): scores[method].append(cur_scores[method].fmeasure) return {method: self.aggregation_function(scores[method]) for method in self.methods} - def rouge_score_with_bootsrap(self, golds: list[str], preds: list[str]): + def _rouge_score_with_bootsrap(self, golds: list[str], preds: list[str]): aggregator = scoring.BootstrapAggregator() for g, p in zip(golds, preds): aggregator.add_scores(self.scorer.score(g, p)) @@ -236,6 +391,15 @@ def __init__( normalize_gold: callable = None, normalize_pred: callable = None, ): + """A BERT scorer class. Relies on some called extracted from `bert-score`. By default, will use the + `microsoft/deberta-large-mnli` as scorer + + Args: + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. 
+ """ self.bert_scorer = BERTScorer( model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9 ) @@ -243,7 +407,16 @@ def __init__( self.normalize_gold = normalize_gold self.normalize_pred = normalize_pred - def compute(self, golds: list[str], predictions: list[str]): + def compute(self, golds: list[str], predictions: list[str]) -> dict: + """Computes the prediction, recall and f1 score using the bert scorer. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + dict: Scores over the current sample's items. + """ golds = as_list(golds) predictions = as_list(predictions) # Normalize @@ -257,6 +430,7 @@ def compute(self, golds: list[str], predictions: list[str]): return {"BERTScore-P": p[0].item(), "BERTScore-R": r[0].item(), "BERTScore-F": f[0].item()} +# todo: make into clean classes with call to normalizer def extractiveness(formatted_doc: Doc, predictions: list[str], **kwargs): inp = remove_braces(formatted_doc.specific["text"]) pred = remove_braces_and_strip(predictions[0]) @@ -268,6 +442,7 @@ def extractiveness(formatted_doc: Doc, predictions: list[str], **kwargs): } +# todo: make into clean classes with call to normalizer def faithfulness(formatted_doc: Doc, predictions: list[str], **kwargs): inp = remove_braces(formatted_doc.specific["text"]) pred = remove_braces_and_strip(predictions[0]) @@ -276,13 +451,24 @@ def faithfulness(formatted_doc: Doc, predictions: list[str], **kwargs): class BLEURT: - # Model chosen could also be Elron/bleurt-base-128 def __init__(self): + """Creates a BLEURT scorer using a light bleurt-tiny-512 model. + For more complex use cases, could also be Elron/bleurt-base-128 + """ self.tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512") self.model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512") self.model.eval() def compute(self, golds: list[str], predictions: list[str]) -> float: + """Uses the stored BLEURT scorer to compute the score on the current sample. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + float: Score over the current sample's items. + """ if len(predictions) == 1: predictions = predictions * len(golds) scores = self.model(**self.tokenizer(golds, predictions, return_tensors="pt"))[0].squeeze() @@ -292,12 +478,36 @@ def compute(self, golds: list[str], predictions: list[str]) -> float: class BLEU: def __init__(self, n_gram: int): + """BLEU scorer class. Relies on `nltk`'s sentencebleu for scoring. + TODO: Will have to move this to sacrebleu. + + Args: + n_gram (int): Number of n_grams to use for scoring. + """ self.n_gram = n_gram def compute(self, golds: list[str], predictions: list[str], **kwargs): - return np.mean([self.bleu_score(golds, p) for p in predictions]) + """Computes the sentence level BLEU between the golds and each prediction, then takes the average. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): Predicted strings + + Returns: + float: Score over the current sample's items. + """ + return np.mean([self._bleu_score(golds, p) for p in predictions]) + + def _bleu_score(self, gold: list[str], pred: str) -> float: + """Computes the BLEU score between a list of golds and the current prediction. + + Args: + golds (list[str]): Reference targets + predictions (str): One of the predicted strings - def bleu_score(self, gold: list[str], pred: str): + Returns: + float: Score over the current prediction. 
+ """ weights = [1 if ix == self.n_gram else 0 for ix in range(1, 5)] return sentence_bleu([word_tokenize(g) for g in gold], word_tokenize(pred), weights=weights) @@ -308,6 +518,12 @@ def __init__( metric_types: list[str] | str, strip_prediction: bool = True, ): + """Contains a number of string distance and edition metrics. Relies on nltk to compute the edit distance. + + Args: + metric_types (list[str] | str): Can be one or any of `longest_common_prefix_length`, `edit_distance` or `edit_similarity`. + strip_prediction (bool, optional): Whether to strip the prediction. Defaults to True. + """ allowed_values = ["longest_common_prefix_length", "edit_distance", "edit_similarity"] metric_types = as_list(metric_types) if any(metric_type not in allowed_values for metric_type in metric_types): @@ -318,7 +534,16 @@ def __init__( self.strip_prediction = strip_prediction self.sample_aggregations = {"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max} - def compute(self, gold: list[str], predictions: list[str], **kwargs): + def compute(self, gold: list[str], predictions: list[str], **kwargs) -> dict: + """Computes all the requested metrics on the golds and prediction. + + Args: + gold (list[str]): A list of possible golds. If it contains more than one item, only the first one is kept. + predictions (list[str]): Predicted strings. + + Returns: + dict: The different scores computed + """ if len(gold) > 0: hlog_warn("Provided more than one gold to compute a string distance metric. Just using the first one.") reference = gold[0] @@ -357,7 +582,7 @@ def longest_common_prefix_length(self, s1: np.ndarray, s2: np.ndarray) -> float: """Compute the length of the longest common prefix.""" min_len = min(len(s1), len(s2)) s1, s2 = s1[:min_len], s2[:min_len] - (nonzeros,) = np.cumprod(s1 == s2).nonzero() # Get indices (inclusive) up to which s1 and s2 are the same. + (nonzeros,) = np.cumprod(s1 == s2).nonzero() return int(np.max(nonzeros)) + 1 if len(nonzeros) > 0 else 0 def edit_similarity(self, s1, s2): @@ -369,7 +594,4 @@ def edit_similarity(self, s1, s2): arXiv preprint arXiv:2107.06499 (2021). """ edist = edit_distance(s1, s2) - - # Some models can return an empty completion e.g., openai/text-davinci-002 - # returns '<|endoftext|>' token immediately for a certain request. 
return 1.0 - edist / max(len(s1), len(s2)) if len(s1) > 0 and len(s2) > 0 else 0 diff --git a/tasks_examples/all_tasks_500.txt b/tasks_examples/all_tasks_500.txt deleted file mode 100644 index 625338706..000000000 --- a/tasks_examples/all_tasks_500.txt +++ /dev/null @@ -1,330 +0,0 @@ -lighteval|anli|0|0 -lighteval|anli:r1|0|0 -lighteval|anli:r2|0|0 -lighteval|anli:r3|0|0 -original|arc:c:options|0|0 -original|arc:c:simple|0|0 -lighteval|arc:challenge|0|0 -lighteval|arc:easy|0|0 -lighteval|arithmetic:1dc|0|0 -lighteval|arithmetic:2da|0|0 -lighteval|arithmetic:2dm|0|0 -lighteval|arithmetic:2ds|0|0 -lighteval|arithmetic:3da|0|0 -lighteval|arithmetic:3ds|0|0 -lighteval|arithmetic:4da|0|0 -lighteval|arithmetic:4ds|0|0 -lighteval|arithmetic:5da|0|0 -lighteval|arithmetic:5ds|0|0 -lighteval|asdiv|0|0 -helm|babi_qa|0|0 -helm|bbq|0|0 -helm|bbq:Age|0|0 -helm|bbq:Disability_status|0|0 -helm|bbq:Gender_identity|0|0 -helm|bbq:Nationality|0|0 -helm|bbq:Physical_appearance|0|0 -helm|bbq:Race_ethnicity|0|0 -helm|bbq:Race_x_SES|0|0 -helm|bbq:Race_x_gender|0|0 -helm|bbq:Religion|0|0 -helm|bbq:SES|0|0 -helm|bbq:Sexual_orientation|0|0 -lighteval|blimp:adjunct_island|0|0 -helm|blimp:adjunct_island|0|0 -lighteval|blimp:anaphor_gender_agreement|0|0 -helm|blimp:anaphor_gender_agreement|0|0 -lighteval|blimp:anaphor_number_agreement|0|0 -helm|blimp:anaphor_number_agreement|0|0 -lighteval|blimp:animate_subject_passive|0|0 -helm|blimp:animate_subject_passive|0|0 -lighteval|blimp:animate_subject_trans|0|0 -helm|blimp:animate_subject_trans|0|0 -lighteval|blimp:causative|0|0 -helm|blimp:causative|0|0 -lighteval|blimp:complex_NP_island|0|0 -helm|blimp:complex_NP_island|0|0 -lighteval|blimp:coordinate_structure_constraint_complex_left_branch|0|0 -helm|blimp:coordinate_structure_constraint_complex_left_branch|0|0 -lighteval|blimp:coordinate_structure_constraint_object_extraction|0|0 -helm|blimp:coordinate_structure_constraint_object_extraction|0|0 -lighteval|blimp:determiner_noun_agreement_1|0|0 -helm|blimp:determiner_noun_agreement_1|0|0 -lighteval|blimp:determiner_noun_agreement_2|0|0 -helm|blimp:determiner_noun_agreement_2|0|0 -lighteval|blimp:determiner_noun_agreement_irregular_1|0|0 -helm|blimp:determiner_noun_agreement_irregular_1|0|0 -lighteval|blimp:determiner_noun_agreement_irregular_2|0|0 -helm|blimp:determiner_noun_agreement_irregular_2|0|0 -lighteval|blimp:determiner_noun_agreement_with_adj_2|0|0 -helm|blimp:determiner_noun_agreement_with_adj_2|0|0 -lighteval|blimp:determiner_noun_agreement_with_adj_irregular_1|0|0 -helm|blimp:determiner_noun_agreement_with_adj_irregular_1|0|0 -lighteval|blimp:determiner_noun_agreement_with_adj_irregular_2|0|0 -helm|blimp:determiner_noun_agreement_with_adj_irregular_2|0|0 -lighteval|blimp:determiner_noun_agreement_with_adjective_1|0|0 -helm|blimp:determiner_noun_agreement_with_adjective_1|0|0 -lighteval|blimp:distractor_agreement_relational_noun|0|0 -helm|blimp:distractor_agreement_relational_noun|0|0 -lighteval|blimp:distractor_agreement_relative_clause|0|0 -helm|blimp:distractor_agreement_relative_clause|0|0 -lighteval|blimp:drop_argument|0|0 -helm|blimp:drop_argument|0|0 -lighteval|blimp:ellipsis_n_bar_1|0|0 -helm|blimp:ellipsis_n_bar_1|0|0 -lighteval|blimp:ellipsis_n_bar_2|0|0 -helm|blimp:ellipsis_n_bar_2|0|0 -lighteval|blimp:existential_there_object_raising|0|0 -helm|blimp:existential_there_object_raising|0|0 -lighteval|blimp:existential_there_quantifiers_1|0|0 -helm|blimp:existential_there_quantifiers_1|0|0 -lighteval|blimp:existential_there_quantifiers_2|0|0 
-helm|blimp:existential_there_quantifiers_2|0|0 -lighteval|blimp:existential_there_subject_raising|0|0 -helm|blimp:existential_there_subject_raising|0|0 -lighteval|blimp:expletive_it_object_raising|0|0 -helm|blimp:expletive_it_object_raising|0|0 -lighteval|blimp:inchoative|0|0 -helm|blimp:inchoative|0|0 -lighteval|blimp:intransitive|0|0 -helm|blimp:intransitive|0|0 -lighteval|blimp:irregular_past_participle_adjectives|0|0 -helm|blimp:irregular_past_participle_adjectives|0|0 -lighteval|blimp:irregular_past_participle_verbs|0|0 -helm|blimp:irregular_past_participle_verbs|0|0 -lighteval|blimp:irregular_plural_subject_verb_agreement_1|0|0 -helm|blimp:irregular_plural_subject_verb_agreement_1|0|0 -lighteval|blimp:irregular_plural_subject_verb_agreement_2|0|0 -helm|blimp:irregular_plural_subject_verb_agreement_2|0|0 -lighteval|blimp:left_branch_island_echo_question|0|0 -helm|blimp:left_branch_island_echo_question|0|0 -lighteval|blimp:left_branch_island_simple_question|0|0 -helm|blimp:left_branch_island_simple_question|0|0 -lighteval|blimp:matrix_question_npi_licensor_present|0|0 -helm|blimp:matrix_question_npi_licensor_present|0|0 -lighteval|blimp:npi_present_1|0|0 -helm|blimp:npi_present_1|0|0 -lighteval|blimp:npi_present_2|0|0 -helm|blimp:npi_present_2|0|0 -lighteval|blimp:only_npi_licensor_present|0|0 -helm|blimp:only_npi_licensor_present|0|0 -lighteval|blimp:only_npi_scope|0|0 -helm|blimp:only_npi_scope|0|0 -lighteval|blimp:passive_1|0|0 -helm|blimp:passive_1|0|0 -lighteval|blimp:passive_2|0|0 -helm|blimp:passive_2|0|0 -lighteval|blimp:principle_A_c_command|0|0 -helm|blimp:principle_A_c_command|0|0 -lighteval|blimp:principle_A_case_1|0|0 -helm|blimp:principle_A_case_1|0|0 -lighteval|blimp:principle_A_case_2|0|0 -helm|blimp:principle_A_case_2|0|0 -lighteval|blimp:principle_A_domain_1|0|0 -helm|blimp:principle_A_domain_1|0|0 -lighteval|blimp:principle_A_domain_2|0|0 -helm|blimp:principle_A_domain_2|0|0 -lighteval|blimp:principle_A_domain_3|0|0 -helm|blimp:principle_A_domain_3|0|0 -lighteval|blimp:principle_A_reconstruction|0|0 -helm|blimp:principle_A_reconstruction|0|0 -lighteval|blimp:regular_plural_subject_verb_agreement_1|0|0 -helm|blimp:regular_plural_subject_verb_agreement_1|0|0 -lighteval|blimp:regular_plural_subject_verb_agreement_2|0|0 -helm|blimp:regular_plural_subject_verb_agreement_2|0|0 -lighteval|blimp:sentential_negation_npi_licensor_present|0|0 -helm|blimp:sentential_negation_npi_licensor_present|0|0 -lighteval|blimp:sentential_negation_npi_scope|0|0 -helm|blimp:sentential_negation_npi_scope|0|0 -lighteval|blimp:sentential_subject_island|0|0 -helm|blimp:sentential_subject_island|0|0 -lighteval|blimp:superlative_quantifiers_1|0|0 -helm|blimp:superlative_quantifiers_1|0|0 -lighteval|blimp:superlative_quantifiers_2|0|0 -helm|blimp:superlative_quantifiers_2|0|0 -lighteval|blimp:tough_vs_raising_1|0|0 -helm|blimp:tough_vs_raising_1|0|0 -lighteval|blimp:tough_vs_raising_2|0|0 -helm|blimp:tough_vs_raising_2|0|0 -lighteval|blimp:transitive|0|0 -helm|blimp:transitive|0|0 -lighteval|blimp:wh_island|0|0 -helm|blimp:wh_island|0|0 -lighteval|blimp:wh_questions_object_gap|0|0 -helm|blimp:wh_questions_object_gap|0|0 -lighteval|blimp:wh_questions_subject_gap|0|0 -helm|blimp:wh_questions_subject_gap|0|0 -lighteval|blimp:wh_questions_subject_gap_long_distance|0|0 -helm|blimp:wh_questions_subject_gap_long_distance|0|0 -lighteval|blimp:wh_vs_that_no_gap|0|0 -helm|blimp:wh_vs_that_no_gap|0|0 -lighteval|blimp:wh_vs_that_no_gap_long_distance|0|0 -helm|blimp:wh_vs_that_no_gap_long_distance|0|0 
-lighteval|blimp:wh_vs_that_with_gap|0|0 -helm|blimp:wh_vs_that_with_gap|0|0 -lighteval|blimp:wh_vs_that_with_gap_long_distance|0|0 -helm|blimp:wh_vs_that_with_gap_long_distance|0|0 -helm|bold|0|0 -helm|bold:gender|0|0 -helm|bold:political_ideology|0|0 -helm|bold:profession|0|0 -helm|bold:race|0|0 -helm|bold:religious_ideology|0|0 -helm|boolq|0|0 -helm|boolq:contrastset|0|0 -helm|civil_comments|0|0 -helm|civil_comments:LGBTQ|0|0 -helm|civil_comments:black|0|0 -helm|civil_comments:christian|0|0 -helm|civil_comments:female|0|0 -helm|civil_comments:male|0|0 -helm|civil_comments:muslim|0|0 -helm|civil_comments:other_religions|0|0 -helm|civil_comments:white|0|0 -helm|commonsenseqa|0|0 -helm|copyright:n_books_1000-extractions_per_book_1-prefix_length_125|0|0 -helm|copyright:n_books_1000-extractions_per_book_1-prefix_length_25|0|0 -helm|copyright:n_books_1000-extractions_per_book_1-prefix_length_5|0|0 -helm|copyright:n_books_1000-extractions_per_book_3-prefix_length_125|0|0 -helm|copyright:n_books_1000-extractions_per_book_3-prefix_length_25|0|0 -helm|copyright:n_books_1000-extractions_per_book_3-prefix_length_5|0|0 -helm|copyright:oh_the_places|0|0 -helm|copyright:pilot|0|0 -helm|copyright:popular_books-prefix_length_10|0|0 -helm|copyright:popular_books-prefix_length_125|0|0 -helm|copyright:popular_books-prefix_length_25|0|0 -helm|copyright:popular_books-prefix_length_250|0|0 -helm|copyright:popular_books-prefix_length_5|0|0 -helm|copyright:popular_books-prefix_length_50|0|0 -helm|copyright:prompt_num_line_1-min_lines_20|0|0 -helm|copyright:prompt_num_line_10-min_lines_20|0|0 -helm|copyright:prompt_num_line_5-min_lines_20|0|0 -lighteval|coqa|0|0 -lighteval|coqa_bb|0|0 -helm|covid_dialogue|0|0 -lighteval|drop|0|0 -helm|dyck_language:2|0|0 -helm|dyck_language:3|0|0 -helm|dyck_language:4|0|0 -helm|entity_data_imputation:Buy|0|0 -helm|entity_data_imputation:Restaurant|0|0 -helm|entity_matching:Abt_Buy|0|0 -helm|entity_matching:Amazon_Google|0|0 -helm|entity_matching:Beer|0|0 -helm|entity_matching:Company|0|0 -helm|entity_matching:DBLP_ACM|0|0 -helm|entity_matching:DBLP_GoogleScholar|0|0 -helm|entity_matching:Dirty_DBLP_ACM|0|0 -helm|entity_matching:Dirty_DBLP_GoogleScholar|0|0 -helm|entity_matching:Dirty_Walmart_Amazon|0|0 -helm|entity_matching:Dirty_iTunes_Amazon|0|0 -helm|entity_matching:Fodors_Zagats|0|0 -helm|entity_matching:Walmart_Amazon|0|0 -helm|entity_matching:iTunes_Amazon|0|0 -lighteval|ethics:commonsense|0|0 -lighteval|ethics:deontology|0|0 -lighteval|ethics:justice|0|0 -lighteval|ethics:utilitarianism|0|0 -lighteval|ethics:virtue|0|0 -lighteval|glue:cola|0|0 -lighteval|glue:mnli|0|0 -lighteval|glue:mnli_mismatched|0|0 -lighteval|glue:mrpc|0|0 -lighteval|glue:qnli|0|0 -lighteval|glue:qqp|0|0 -lighteval|glue:rte|0|0 -lighteval|glue:sst2|0|0 -lighteval|glue:stsb|0|0 -lighteval|glue:wnli|0|0 -lighteval|gsm8k|0|0 -lighteval|hellaswag|0|0 -helm|hellaswag|0|0 -helm|humaneval|0|0 -helm|imdb|0|0 -helm|imdb:contrastset|0|0 -helm|interactive_qa_mmlu:abstract_algebra|0|0 -helm|interactive_qa_mmlu:college_chemistry|0|0 -helm|interactive_qa_mmlu:global_facts|0|0 -helm|interactive_qa_mmlu:miscellaneous|0|0 -helm|interactive_qa_mmlu:nutrition|0|0 -helm|interactive_qa_mmlu:us_foreign_policy|0|0 -lighteval|iwslt17:ar-en|0|0 -lighteval|iwslt17:de-en|0|0 -lighteval|iwslt17:en-ar|0|0 -lighteval|iwslt17:en-de|0|0 -lighteval|iwslt17:en-fr|0|0 -lighteval|iwslt17:en-ja|0|0 -lighteval|iwslt17:en-ko|0|0 -lighteval|iwslt17:en-zh|0|0 -lighteval|iwslt17:fr-en|0|0 -lighteval|iwslt17:ja-en|0|0 
-lighteval|iwslt17:ko-en|0|0
-lighteval|iwslt17:zh-en|0|0
-lighteval|lambada:standard|0|0
-lighteval|lambada:standard_cloze|0|0
-lighteval|lambada:openai|0|0
-lighteval|lambada:openai:de|0|0
-lighteval|lambada:openai:en|0|0
-lighteval|lambada:openai:es|0|0
-lighteval|lambada:openai:fr|0|0
-lighteval|lambada:openai:it|0|0
-lighteval|lambada:openai_cloze|0|0
-helm|legal_summarization:billsum|0|0
-helm|legal_summarization:eurlexsum|0|0
-helm|legal_summarization:multilexsum|0|0
-helm|lexglue:case_hold|0|0
-helm|lexglue:ecthr_a|0|0
-helm|lexglue:ecthr_b|0|0
-helm|lexglue:eurlex|0|0
-helm|lexglue:ledgar|0|0
-helm|lexglue:scotus|0|0
-helm|lexglue:unfair_tos|0|0
-helm|lextreme:brazilian_court_decisions_judgment|0|0
-helm|lextreme:brazilian_court_decisions_unanimity|0|0
-helm|lextreme:covid19_emergency_event|0|0
-helm|lextreme:german_argument_mining|0|0
-helm|lextreme:greek_legal_code_chapter|0|0
-helm|lextreme:greek_legal_code_subject|0|0
-helm|lextreme:greek_legal_code_volume|0|0
-helm|lextreme:greek_legal_ner|0|0
-helm|lextreme:legalnero|0|0
-helm|lextreme:lener_br|0|0
-helm|lextreme:mapa_coarse|0|0
-helm|lextreme:mapa_fine|0|0
-helm|lextreme:multi_eurlex_level_1|0|0
-helm|lextreme:multi_eurlex_level_2|0|0
-helm|lextreme:multi_eurlex_level_3|0|0
-helm|lextreme:online_terms_of_service_clause_topics|0|0
-helm|lextreme:online_terms_of_service_unfairness_levels|0|0
-helm|lextreme:swiss_judgment_prediction|0|0
-lighteval|logiqa|0|0
-helm|lsat_qa|0|0
-helm|lsat_qa:assignment|0|0
-helm|lsat_qa:grouping|0|0
-helm|lsat_qa:miscellaneous|0|0
-helm|lsat_qa:ordering|0|0
-lighteval|math:algebra|0|0
-lighteval|math:counting_and_probability|0|0
-lighteval|math:geometry|0|0
-lighteval|math:intermediate_algebra|0|0
-lighteval|math:number_theory|0|0
-lighteval|math:prealgebra|0|0
-lighteval|math:precalculus|0|0
-lighteval|mathqa|0|0
-helm|me_q_sum|0|0
-helm|med_dialog:healthcaremagic|0|0
-helm|med_dialog:icliniq|0|0
-helm|med_paragraph_simplification|0|0
-lighteval|mgsm:en|0|0
-lighteval|mgsm:es|0|0
-lighteval|mgsm:fr|0|0
-lighteval|mgsm:de|0|0
-lighteval|mgsm:ru|0|0
-lighteval|mgsm:zh|0|0
-lighteval|mgsm:ja|0|0
-lighteval|mgsm:th|0|0
-lighteval|mgsm:sw|0|0
-lighteval|mgsm:bn|0|0
-lighteval|mgsm:te|0|0
-helm|mmlu|0|0
-original|mmlu|0|0
diff --git a/tasks_examples/custom_evaluation_tasks.py b/tasks_examples/custom_tasks/custom_evaluation_tasks.py
similarity index 100%
rename from tasks_examples/custom_evaluation_tasks.py
rename to tasks_examples/custom_tasks/custom_evaluation_tasks.py
diff --git a/tasks_examples/custom_evaluation_utils.py b/tasks_examples/custom_tasks/custom_evaluation_utils.py
similarity index 100%
rename from tasks_examples/custom_evaluation_utils.py
rename to tasks_examples/custom_tasks/custom_evaluation_utils.py
diff --git a/tasks_examples/custom_task.py b/tasks_examples/custom_tasks/custom_task.py
similarity index 100%
rename from tasks_examples/custom_task.py
rename to tasks_examples/custom_tasks/custom_task.py
diff --git a/tasks_examples/lighteval_config_override_template.yaml b/tasks_examples/custom_tasks/lighteval_config_override_template.yaml
similarity index 100%
rename from tasks_examples/lighteval_config_override_template.yaml
rename to tasks_examples/custom_tasks/lighteval_config_override_template.yaml
diff --git a/tasks_examples/harness_mmlu_tasks.txt b/tasks_examples/harness_mmlu_tasks.txt
deleted file mode 100644
index 8dde96683..000000000
--- a/tasks_examples/harness_mmlu_tasks.txt
+++ /dev/null
@@ -1,57 +0,0 @@
-lighteval|mmlu:abstract_algebra|5
-lighteval|mmlu:anatomy|5
-lighteval|mmlu:astronomy|5
-lighteval|mmlu:business_ethics|5
-lighteval|mmlu:clinical_knowledge|5
-lighteval|mmlu:college_biology|5
-lighteval|mmlu:college_chemistry|5
-lighteval|mmlu:college_computer_science|5
-lighteval|mmlu:college_mathematics|5
-lighteval|mmlu:college_medicine|5
-lighteval|mmlu:college_physics|5
-lighteval|mmlu:computer_security|5
-lighteval|mmlu:conceptual_physics|5
-lighteval|mmlu:econometrics|5
-lighteval|mmlu:electrical_engineering|5
-lighteval|mmlu:elementary_mathematics|5
-lighteval|mmlu:formal_logic|5
-lighteval|mmlu:global_facts|5
-lighteval|mmlu:high_school_biology|5
-lighteval|mmlu:high_school_chemistry|5
-lighteval|mmlu:high_school_computer_science|5
-lighteval|mmlu:high_school_european_history|5
-lighteval|mmlu:high_school_geography|5
-lighteval|mmlu:high_school_government_and_politics|5
-lighteval|mmlu:high_school_macroeconomics|5
-lighteval|mmlu:high_school_mathematics|5
-lighteval|mmlu:high_school_microeconomics|5
-lighteval|mmlu:high_school_physics|5
-lighteval|mmlu:high_school_psychology|5
-lighteval|mmlu:high_school_statistics|5
-lighteval|mmlu:high_school_us_history|5
-lighteval|mmlu:high_school_world_history|5
-lighteval|mmlu:human_aging|5
-lighteval|mmlu:human_sexuality|5
-lighteval|mmlu:international_law|5
-lighteval|mmlu:jurisprudence|5
-lighteval|mmlu:logical_fallacies|5
-lighteval|mmlu:machine_learning|5
-lighteval|mmlu:management|5
-lighteval|mmlu:marketing|5
-lighteval|mmlu:medical_genetics|5
-lighteval|mmlu:miscellaneous|5
-lighteval|mmlu:moral_disputes|5
-lighteval|mmlu:moral_scenarios|5
-lighteval|mmlu:nutrition|5
-lighteval|mmlu:philosophy|5
-lighteval|mmlu:prehistory|5
-lighteval|mmlu:professional_accounting|5
-lighteval|mmlu:professional_law|5
-lighteval|mmlu:professional_medicine|5
-lighteval|mmlu:professional_psychology|5
-lighteval|mmlu:public_relations|5
-lighteval|mmlu:security_studies|5
-lighteval|mmlu:sociology|5
-lighteval|mmlu:us_foreign_policy|5
-lighteval|mmlu:virology|5
-lighteval|mmlu:world_religions|5
diff --git a/tasks_examples/open_llm_leaderboard_tasks.txt b/tasks_examples/open_llm_leaderboard_tasks.txt
index 16390a477..41c0ff35a 100644
--- a/tasks_examples/open_llm_leaderboard_tasks.txt
+++ b/tasks_examples/open_llm_leaderboard_tasks.txt
@@ -1,60 +1,60 @@
-lighteval|arc:challenge|25
-lighteval|hellaswag|10
-lighteval|truthfulqa:mc|0
-lighteval|mmlu:abstract_algebra|5
-lighteval|mmlu:anatomy|5
-lighteval|mmlu:astronomy|5
-lighteval|mmlu:business_ethics|5
-lighteval|mmlu:clinical_knowledge|5
-lighteval|mmlu:college_biology|5
-lighteval|mmlu:college_chemistry|5
-lighteval|mmlu:college_computer_science|5
-lighteval|mmlu:college_mathematics|5
-lighteval|mmlu:college_medicine|5
-lighteval|mmlu:college_physics|5
-lighteval|mmlu:computer_security|5
-lighteval|mmlu:conceptual_physics|5
-lighteval|mmlu:econometrics|5
-lighteval|mmlu:electrical_engineering|5
-lighteval|mmlu:elementary_mathematics|5
-lighteval|mmlu:formal_logic|5
-lighteval|mmlu:global_facts|5
-lighteval|mmlu:high_school_biology|5
-lighteval|mmlu:high_school_chemistry|5
-lighteval|mmlu:high_school_computer_science|5
-lighteval|mmlu:high_school_european_history|5
-lighteval|mmlu:high_school_geography|5
-lighteval|mmlu:high_school_government_and_politics|5
-lighteval|mmlu:high_school_macroeconomics|5
-lighteval|mmlu:high_school_mathematics|5
-lighteval|mmlu:high_school_microeconomics|5
-lighteval|mmlu:high_school_physics|5
-lighteval|mmlu:high_school_psychology|5
-lighteval|mmlu:high_school_statistics|5
-lighteval|mmlu:high_school_us_history|5
-lighteval|mmlu:high_school_world_history|5
-lighteval|mmlu:human_aging|5
-lighteval|mmlu:human_sexuality|5
-lighteval|mmlu:international_law|5
-lighteval|mmlu:jurisprudence|5
-lighteval|mmlu:logical_fallacies|5
-lighteval|mmlu:machine_learning|5
-lighteval|mmlu:management|5
-lighteval|mmlu:marketing|5
-lighteval|mmlu:medical_genetics|5
-lighteval|mmlu:miscellaneous|5
-lighteval|mmlu:moral_disputes|5
-lighteval|mmlu:moral_scenarios|5
-lighteval|mmlu:nutrition|5
-lighteval|mmlu:philosophy|5
-lighteval|mmlu:prehistory|5
-lighteval|mmlu:professional_accounting|5
-lighteval|mmlu:professional_law|5
-lighteval|mmlu:professional_medicine|5
-lighteval|mmlu:professional_psychology|5
-lighteval|mmlu:public_relations|5
-lighteval|mmlu:security_studies|5
-lighteval|mmlu:sociology|5
-lighteval|mmlu:us_foreign_policy|5
-lighteval|mmlu:virology|5
-lighteval|mmlu:world_religions|5
+lighteval|arc:challenge|25|0
+lighteval|hellaswag|10|0
+lighteval|truthfulqa:mc|0|0
+lighteval|mmlu:abstract_algebra|5|0
+lighteval|mmlu:anatomy|5|0
+lighteval|mmlu:astronomy|5|0
+lighteval|mmlu:business_ethics|5|0
+lighteval|mmlu:clinical_knowledge|5|0
+lighteval|mmlu:college_biology|5|0
+lighteval|mmlu:college_chemistry|5|0
+lighteval|mmlu:college_computer_science|5|0
+lighteval|mmlu:college_mathematics|5|0
+lighteval|mmlu:college_medicine|5|0
+lighteval|mmlu:college_physics|5|0
+lighteval|mmlu:computer_security|5|0
+lighteval|mmlu:conceptual_physics|5|0
+lighteval|mmlu:econometrics|5|0
+lighteval|mmlu:electrical_engineering|5|0
+lighteval|mmlu:elementary_mathematics|5|0
+lighteval|mmlu:formal_logic|5|0
+lighteval|mmlu:global_facts|5|0
+lighteval|mmlu:high_school_biology|5|0
+lighteval|mmlu:high_school_chemistry|5|0
+lighteval|mmlu:high_school_computer_science|5|0
+lighteval|mmlu:high_school_european_history|5|0
+lighteval|mmlu:high_school_geography|5|0
+lighteval|mmlu:high_school_government_and_politics|5|0
+lighteval|mmlu:high_school_macroeconomics|5|0
+lighteval|mmlu:high_school_mathematics|5|0
+lighteval|mmlu:high_school_microeconomics|5|0
+lighteval|mmlu:high_school_physics|5|0
+lighteval|mmlu:high_school_psychology|5|0
+lighteval|mmlu:high_school_statistics|5|0
+lighteval|mmlu:high_school_us_history|5|0
+lighteval|mmlu:high_school_world_history|5|0
+lighteval|mmlu:human_aging|5|0
+lighteval|mmlu:human_sexuality|5|0
+lighteval|mmlu:international_law|5|0
+lighteval|mmlu:jurisprudence|5|0
+lighteval|mmlu:logical_fallacies|5|0
+lighteval|mmlu:machine_learning|5|0
+lighteval|mmlu:management|5|0
+lighteval|mmlu:marketing|5|0
+lighteval|mmlu:medical_genetics|5|0
+lighteval|mmlu:miscellaneous|5|0
+lighteval|mmlu:moral_disputes|5|0
+lighteval|mmlu:moral_scenarios|5|0
+lighteval|mmlu:nutrition|5|0
+lighteval|mmlu:philosophy|5|0
+lighteval|mmlu:prehistory|5|0
+lighteval|mmlu:professional_accounting|5|0
+lighteval|mmlu:professional_law|5|0
+lighteval|mmlu:professional_medicine|5|0
+lighteval|mmlu:professional_psychology|5|0
+lighteval|mmlu:public_relations|5|0
+lighteval|mmlu:security_studies|5|0
+lighteval|mmlu:sociology|5|0
+lighteval|mmlu:us_foreign_policy|5|0
+lighteval|mmlu:virology|5|0
+lighteval|mmlu:world_religions|5|0
\ No newline at end of file
diff --git a/tasks_examples/original_mmlu_tasks.txt b/tasks_examples/original_mmlu_tasks.txt
deleted file mode 100644
index 9a0b1c3f2..000000000
--- a/tasks_examples/original_mmlu_tasks.txt
+++ /dev/null
@@ -1,57 +0,0 @@
-original|mmlu:abstract_algebra|5
-original|mmlu:anatomy|5
-original|mmlu:astronomy|5
-original|mmlu:business_ethics|5
-original|mmlu:clinical_knowledge|5
-original|mmlu:college_biology|5
-original|mmlu:college_chemistry|5
-original|mmlu:college_computer_science|5
-original|mmlu:college_mathematics|5
-original|mmlu:college_medicine|5
-original|mmlu:college_physics|5
-original|mmlu:computer_security|5
-original|mmlu:conceptual_physics|5
-original|mmlu:econometrics|5
-original|mmlu:electrical_engineering|5
-original|mmlu:elementary_mathematics|5
-original|mmlu:formal_logic|5
-original|mmlu:global_facts|5
-original|mmlu:high_school_biology|5
-original|mmlu:high_school_chemistry|5
-original|mmlu:high_school_computer_science|5
-original|mmlu:high_school_european_history|5
-original|mmlu:high_school_geography|5
-original|mmlu:high_school_government_and_politics|5
-original|mmlu:high_school_macroeconomics|5
-original|mmlu:high_school_mathematics|5
-original|mmlu:high_school_microeconomics|5
-original|mmlu:high_school_physics|5
-original|mmlu:high_school_psychology|5
-original|mmlu:high_school_statistics|5
-original|mmlu:high_school_us_history|5
-original|mmlu:high_school_world_history|5
-original|mmlu:human_aging|5
-original|mmlu:human_sexuality|5
-original|mmlu:international_law|5
-original|mmlu:jurisprudence|5
-original|mmlu:logical_fallacies|5
-original|mmlu:machine_learning|5
-original|mmlu:management|5
-original|mmlu:marketing|5
-original|mmlu:medical_genetics|5
-original|mmlu:miscellaneous|5
-original|mmlu:moral_disputes|5
-original|mmlu:moral_scenarios|5
-original|mmlu:nutrition|5
-original|mmlu:philosophy|5
-original|mmlu:prehistory|5
-original|mmlu:professional_accounting|5
-original|mmlu:professional_law|5
-original|mmlu:professional_medicine|5
-original|mmlu:professional_psychology|5
-original|mmlu:public_relations|5
-original|mmlu:security_studies|5
-original|mmlu:sociology|5
-original|mmlu:us_foreign_policy|5
-original|mmlu:virology|5
-original|mmlu:world_religions|5
diff --git a/tasks_examples/tasks_to_look_at.txt b/tasks_examples/tasks_to_look_at.txt
deleted file mode 100644
index 341802975..000000000
--- a/tasks_examples/tasks_to_look_at.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-lighteval|headqa:en|0|0 # problem when loading dataset, check dataset
-lighteval|headqa:es|0|0 # problem when loading dataset, check dataset
-bigbench
-# task with both generative and logprob, possible issue...
-helm|legalsupport|0|0
-helm|med_mcqa|0|0
-helm|med_qa|0|0
-original|arc:c:letters|0|0
diff --git a/tasks_examples/tested_harness_tasks.txt b/tasks_examples/tested_harness_tasks.txt
deleted file mode 100644
index 7e29e6db9..000000000
--- a/tasks_examples/tested_harness_tasks.txt
+++ /dev/null
@@ -1,319 +0,0 @@
-lighteval|anli:r1
-lighteval|anli:r2
-lighteval|anli:r3
-lighteval|arc:challenge
-lighteval|arc:easy
-lighteval|arithmetic:1dc
-lighteval|arithmetic:2da
-lighteval|arithmetic:2dm
-lighteval|arithmetic:2ds
-lighteval|arithmetic:3da
-lighteval|arithmetic:3ds
-lighteval|arithmetic:4da
-lighteval|arithmetic:4ds
-lighteval|arithmetic:5da
-lighteval|arithmetic:5ds
-lighteval|blimp:adjunct_island
-lighteval|blimp:anaphor_gender_agreement
-lighteval|blimp:anaphor_number_agreement
-lighteval|blimp:animate_subject_passive
-lighteval|blimp:animate_subject_trans
-lighteval|blimp:causative
-lighteval|blimp:complex_NP_island
-lighteval|blimp:coordinate_structure_constraint_complex_left_branch
-lighteval|blimp:coordinate_structure_constraint_object_extraction
-lighteval|blimp:determiner_noun_agreement_1
-lighteval|blimp:determiner_noun_agreement_2
-lighteval|blimp:determiner_noun_agreement_irregular_1
-lighteval|blimp:determiner_noun_agreement_irregular_2
-lighteval|blimp:determiner_noun_agreement_with_adj_2
-lighteval|blimp:determiner_noun_agreement_with_adj_irregular_1
-lighteval|blimp:determiner_noun_agreement_with_adj_irregular_2
-lighteval|blimp:determiner_noun_agreement_with_adjective_1
-lighteval|blimp:distractor_agreement_relational_noun
-lighteval|blimp:distractor_agreement_relative_clause
-lighteval|blimp:drop_argument
-lighteval|blimp:ellipsis_n_bar_1
-lighteval|blimp:ellipsis_n_bar_2
-lighteval|blimp:existential_there_object_raising
-lighteval|blimp:existential_there_quantifiers_1
-lighteval|blimp:existential_there_quantifiers_2
-lighteval|blimp:existential_there_subject_raising
-lighteval|blimp:expletive_it_object_raising
-lighteval|blimp:inchoative
-lighteval|blimp:intransitive
-lighteval|blimp:irregular_past_participle_adjectives
-lighteval|blimp:irregular_past_participle_verbs
-lighteval|blimp:irregular_plural_subject_verb_agreement_1
-lighteval|blimp:irregular_plural_subject_verb_agreement_2
-lighteval|blimp:left_branch_island_echo_question
-lighteval|blimp:left_branch_island_simple_question
-lighteval|blimp:matrix_question_npi_licensor_present
-lighteval|blimp:npi_present_1
-lighteval|blimp:npi_present_2
-lighteval|blimp:only_npi_licensor_present
-lighteval|blimp:only_npi_scope
-lighteval|blimp:passive_1
-lighteval|blimp:passive_2
-lighteval|blimp:principle_A_c_command
-lighteval|blimp:principle_A_case_1
-lighteval|blimp:principle_A_case_2
-lighteval|blimp:principle_A_domain_1
-lighteval|blimp:principle_A_domain_2
-lighteval|blimp:principle_A_domain_3
-lighteval|blimp:principle_A_reconstruction
-lighteval|blimp:regular_plural_subject_verb_agreement_1
-lighteval|blimp:regular_plural_subject_verb_agreement_2
-lighteval|blimp:sentential_negation_npi_licensor_present
-lighteval|blimp:sentential_negation_npi_scope
-lighteval|blimp:sentential_subject_island
-lighteval|blimp:superlative_quantifiers_1
-lighteval|blimp:superlative_quantifiers_2
-lighteval|blimp:tough_vs_raising_1
-lighteval|blimp:tough_vs_raising_2
-lighteval|blimp:transitive
-lighteval|blimp:wh_island
-lighteval|blimp:wh_questions_object_gap
-lighteval|blimp:wh_questions_subject_gap
-lighteval|blimp:wh_questions_subject_gap_long_distance
-lighteval|blimp:wh_vs_that_no_gap
-lighteval|blimp:wh_vs_that_no_gap_long_distance
-lighteval|blimp:wh_vs_that_with_gap
-lighteval|blimp:wh_vs_that_with_gap_long_distance
-lighteval|drop
-lighteval|ethics:commonsense
-lighteval|ethics:deontology
-lighteval|ethics:justice
-lighteval|ethics:utilitarianism
-lighteval|ethics:virtue
-lighteval|glue:cola
-lighteval|glue:mnli
-lighteval|glue:mnli_mismatched
-lighteval|glue:mrpc
-lighteval|glue:qnli
-lighteval|glue:qqp
-lighteval|glue:rte
-lighteval|glue:wnli
-lighteval|gsm8k
-lighteval|headqa:en
-lighteval|headqa:es
-lighteval|hellaswag
-lighteval|lambada:openai
-lighteval|lambada:openai_cloze
-lighteval|lambada:standard
-lighteval|lambada:standard_cloze
-lighteval|logiqa
-lighteval|math:algebra
-lighteval|math:counting_and_probability
-lighteval|math:geometry
-lighteval|math:intermediate_algebra
-lighteval|math:number_theory
-lighteval|math:prealgebra
-lighteval|math:precalculus
-lighteval|mathqa
-lighteval|mgsm:en
-lighteval|mgsm:es
-lighteval|mgsm:fr
-lighteval|mgsm:de
-lighteval|mgsm:ru
-lighteval|mgsm:zh
-lighteval|mgsm:ja
-lighteval|mgsm:th
-lighteval|mgsm:sw
-lighteval|mgsm:bn
-lighteval|mgsm:te
-lighteval|mmlu:abstract_algebra
-lighteval|mmlu:abstract_algebra
-lighteval|mmlu:anatomy
-lighteval|mmlu:anatomy
-lighteval|mmlu:astronomy
-lighteval|mmlu:astronomy
-lighteval|mmlu:business_ethics
-lighteval|mmlu:business_ethics
-lighteval|mmlu:clinical_knowledge
-lighteval|mmlu:clinical_knowledge
-lighteval|mmlu:college_biology
-lighteval|mmlu:college_biology
-lighteval|mmlu:college_chemistry
-lighteval|mmlu:college_chemistry
-lighteval|mmlu:college_computer_science
-lighteval|mmlu:college_computer_science
-lighteval|mmlu:college_mathematics
-lighteval|mmlu:college_mathematics
-lighteval|mmlu:college_medicine
-lighteval|mmlu:college_medicine
-lighteval|mmlu:college_physics
-lighteval|mmlu:college_physics
-lighteval|mmlu:computer_security
-lighteval|mmlu:computer_security
-lighteval|mmlu:conceptual_physics
-lighteval|mmlu:conceptual_physics
-lighteval|mmlu:econometrics
-lighteval|mmlu:econometrics
-lighteval|mmlu:electrical_engineering
-lighteval|mmlu:electrical_engineering
-lighteval|mmlu:elementary_mathematics
-lighteval|mmlu:elementary_mathematics
-lighteval|mmlu:formal_logic
-lighteval|mmlu:formal_logic
-lighteval|mmlu:global_facts
-lighteval|mmlu:global_facts
-lighteval|mmlu:high_school_biology
-lighteval|mmlu:high_school_biology
-lighteval|mmlu:high_school_chemistry
-lighteval|mmlu:high_school_chemistry
-lighteval|mmlu:high_school_computer_science
-lighteval|mmlu:high_school_computer_science
-lighteval|mmlu:high_school_european_history
-lighteval|mmlu:high_school_european_history
-lighteval|mmlu:high_school_geography
-lighteval|mmlu:high_school_geography
-lighteval|mmlu:high_school_government_and_politics
-lighteval|mmlu:high_school_government_and_politics
-lighteval|mmlu:high_school_macroeconomics
-lighteval|mmlu:high_school_macroeconomics
-lighteval|mmlu:high_school_mathematics
-lighteval|mmlu:high_school_mathematics
-lighteval|mmlu:high_school_microeconomics
-lighteval|mmlu:high_school_microeconomics
-lighteval|mmlu:high_school_physics
-lighteval|mmlu:high_school_physics
-lighteval|mmlu:high_school_psychology
-lighteval|mmlu:high_school_psychology
-lighteval|mmlu:high_school_statistics
-lighteval|mmlu:high_school_statistics
-lighteval|mmlu:high_school_us_history
-lighteval|mmlu:high_school_us_history
-lighteval|mmlu:high_school_world_history
-lighteval|mmlu:high_school_world_history
-lighteval|mmlu:human_aging
-lighteval|mmlu:human_aging
-lighteval|mmlu:human_sexuality
-lighteval|mmlu:human_sexuality
-lighteval|mmlu:international_law
-lighteval|mmlu:international_law
-lighteval|mmlu:jurisprudence
-lighteval|mmlu:jurisprudence
-lighteval|mmlu:logical_fallacies
-lighteval|mmlu:logical_fallacies
-lighteval|mmlu:machine_learning
-lighteval|mmlu:machine_learning
-lighteval|mmlu:management
-lighteval|mmlu:management
-lighteval|mmlu:marketing
-lighteval|mmlu:marketing
-lighteval|mmlu:medical_genetics
-lighteval|mmlu:medical_genetics
-lighteval|mmlu:miscellaneous
-lighteval|mmlu:miscellaneous
-lighteval|mmlu:moral_disputes
-lighteval|mmlu:moral_disputes
-lighteval|mmlu:moral_scenarios
-lighteval|mmlu:moral_scenarios
-lighteval|mmlu:nutrition
-lighteval|mmlu:nutrition
-lighteval|mmlu:philosophy
-lighteval|mmlu:philosophy
-lighteval|mmlu:prehistory
-lighteval|mmlu:prehistory
-lighteval|mmlu:professional_accounting
-lighteval|mmlu:professional_accounting
-lighteval|mmlu:professional_law
-lighteval|mmlu:professional_law
-lighteval|mmlu:professional_medicine
-lighteval|mmlu:professional_medicine
-lighteval|mmlu:professional_psychology
-lighteval|mmlu:professional_psychology
-lighteval|mmlu:public_relations
-lighteval|mmlu:public_relations
-lighteval|mmlu:security_studies
-lighteval|mmlu:security_studies
-lighteval|mmlu:sociology
-lighteval|mmlu:sociology
-lighteval|mmlu:us_foreign_policy
-lighteval|mmlu:us_foreign_policy
-lighteval|mmlu:virology
-lighteval|mmlu:virology
-lighteval|mmlu:world_religions
-lighteval|mmlu:world_religions
-lighteval|mutual
-lighteval|mutual_plus
-lighteval|openbookqa
-lighteval|piqa
-lighteval|prost
-lighteval|pubmedqa
-lighteval|qa4mre:2011
-lighteval|qa4mre:2012
-lighteval|qa4mre:2013
-lighteval|race:high
-lighteval|sciq
-lighteval|super_glue:boolq
-lighteval|super_glue:cb
-lighteval|super_glue:copa
-lighteval|super_glue:record
-lighteval|super_glue:multirc
-lighteval|super_glue:wic
-lighteval|super_glue:wsc
-lighteval|swag
-lighteval|the_pile:arxiv
-lighteval|the_pile:bookcorpus2
-lighteval|the_pile:books3
-lighteval|the_pile:dm-mathematics
-lighteval|the_pile:enron
-lighteval|the_pile:europarl
-lighteval|the_pile:freelaw
-lighteval|the_pile:github
-lighteval|the_pile:gutenberg
-lighteval|the_pile:hackernews
-lighteval|the_pile:nih-exporter
-lighteval|the_pile:opensubtitles
-lighteval|the_pile:openwebtext2
-lighteval|the_pile:philpapers
-lighteval|the_pile:pile-cc
-lighteval|the_pile:pubmed-abstracts
-lighteval|the_pile:pubmed-central
-lighteval|the_pile:stackexchange
-lighteval|the_pile:ubuntu-irc
-lighteval|the_pile:uspto
-lighteval|the_pile:wikipedia
-lighteval|the_pile:youtubesubtitles
-lighteval|toxigen
-lighteval|triviaqa
-lighteval|truthfulqa:mc
-lighteval|unscramble:anagrams1
-lighteval|unscramble:anagrams2
-lighteval|unscramble:cycle_letters
-lighteval|unscramble:random_insertion
-lighteval|unscramble:reversed_words
-lighteval|webqs
-lighteval|wikitext
-lighteval|winogrande
-lighteval|wsc273
-lighteval|xcopa:et
-lighteval|xcopa:ht
-lighteval|xcopa:it
-lighteval|xcopa:id
-lighteval|xcopa:qu
-lighteval|xcopa:sw
-lighteval|xcopa:zh
-lighteval|xcopa:ta
-lighteval|xcopa:th
-lighteval|xcopa:tr
-lighteval|xcopa:vi
-lighteval|xstory_cloze:en
-lighteval|xstory_cloze:ru
-lighteval|xstory_cloze:zh
-lighteval|xstory_cloze:es
-lighteval|xstory_cloze:ar
-lighteval|xstory_cloze:hi
-lighteval|xstory_cloze:id
-lighteval|xstory_cloze:te
-lighteval|xstory_cloze:sw
-lighteval|xstory_cloze:eu
-lighteval|xstory_cloze:my
-lighteval|xwinograd:en
-lighteval|xwinograd:fr
-lighteval|xwinograd:jp
-lighteval|xwinograd:pt
-lighteval|xwinograd:ru
-lighteval|xwinograd:zh