Add LLM as Judge evaluation metric #40

Merged · 5 commits · Dec 12, 2024
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,6 +17,8 @@ __pycache__/
# C extensions
*.so

.envrc

# Distribution / packaging
.Python
build/
2 changes: 1 addition & 1 deletion evalem/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.0.4-alpha"
__version__ = "0.0.5-alpha"

from ._base.evaluators import Evaluator # noqa
from ._base.pipelines import ( # noqa
5 changes: 3 additions & 2 deletions evalem/_base/metrics.py
@@ -3,7 +3,7 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Iterable, List, Tuple

[GitHub Actions / Flake8] Check failure on line 6 in evalem/_base/metrics.py: 'typing.Iterable' imported but unused (F401)

from jury import Jury
from sklearn.metrics import confusion_matrix
@@ -14,6 +14,7 @@
EvaluationPredictionInstance,
EvaluationReferenceInstance,
MetricResult,
SequenceType,
SinglePredictionInstance,
)

@@ -126,10 +127,10 @@
res = []
for pred, ref in zip(predictions, references):
# if multiple predictions, skip for now
if isinstance(pred, Iterable) and not isinstance(pred, str):
if isinstance(pred, SequenceType) and not isinstance(pred, str):
raise TypeError("Cannot handle multiple prediction instance")
# if multiple references
elif isinstance(ref, Iterable) and not isinstance(ref, str):
elif isinstance(ref, SequenceType) and not isinstance(ref, str):
res.extend(list(map(lambda r: (pred, r), ref)))
else:
res.append((pred, ref))
3 changes: 2 additions & 1 deletion evalem/_base/structures.py
@@ -5,7 +5,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np
import torch
@@ -108,3 +108,4 @@ def __hash__(self) -> str:
MetricOutput = Union[int, float, Dict[str, Union[str, int, float]], MetricResult]

PathType = Union[str, Path]
SequenceType = Union[List, Tuple, Set]
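
For context, the new `SequenceType` alias is what lets the metric code above expand a single prediction against multiple references. A minimal, self-contained sketch of that flattening behaviour (illustrative only, not part of the diff; it substitutes a plain tuple of built-ins for the `Union[List, Tuple, Set]` alias so it runs on any Python version, and the sample data is made up):

# stand-in for evalem's SequenceType = Union[List, Tuple, Set]
SEQUENCE_TYPES = (list, tuple, set)

predictions = ["Paris"]
references = [["Paris", "Paris, France"]]  # one prediction, multiple acceptable references

pairs = []
for pred, ref in zip(predictions, references):
    if isinstance(pred, SEQUENCE_TYPES) and not isinstance(pred, str):
        # multiple predictions per instance are rejected, as in metrics.py above
        raise TypeError("Cannot handle multiple prediction instance")
    elif isinstance(ref, SEQUENCE_TYPES) and not isinstance(ref, str):
        # expand the single prediction against every reference
        pairs.extend((pred, r) for r in ref)
    else:
        pairs.append((pred, ref))

print(pairs)  # [('Paris', 'Paris'), ('Paris', 'Paris, France')]
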
20 changes: 20 additions & 0 deletions evalem/nlp/__init__.py
@@ -0,0 +1,20 @@
# flake8: noqa
from .metrics import (
BartScore,
BertScore,
BleuMetric,
ExactMatchMetric,
LLMAsJudgeMetric,
MeteorMetric,
NLPMetric,
RougeMetric,
SacreBleuMetric,
SemanticMetric,
)
from .models import (
DefaultQAModelWrapper,
HFLMWrapper,
HFPipelineWrapper,
QuestionAnsweringHFPipelineWrapper,
TextClassificationHFPipelineWrapper,
)
1 change: 1 addition & 0 deletions evalem/nlp/metrics/__init__.py
@@ -2,6 +2,7 @@

from ._base import NLPMetric
from .basics import ExactMatchMetric
from .llm import LLMAsJudgeMetric
from .semantics import (
BartScore,
BertScore,
218 changes: 218 additions & 0 deletions evalem/nlp/metrics/llm.py
@@ -0,0 +1,218 @@
#!/usr/bin/env python3

from enum import Enum
from typing import List, Optional, Tuple
from urllib.parse import urljoin

import numpy as np
import outlines
from loguru import logger
from outlines.models.openai import OpenAIConfig

from ..._base.structures import (
EvaluationPredictionInstance,
EvaluationReferenceInstance,
MetricResult,
SequenceType,
)
from ._base import NLPMetric


class AggregationType(Enum):
MEAN = "mean"
AVERAGE = "average"
MAX = "max"


class LLMAsJudgeMetric(NLPMetric):
"""
Uses a language model as a judge: each prediction is classified
(binary) as matching or not matching the provided reference.
The judgement is repeated N times and the scores are aggregated
into a single score per prediction.

The prompt can be changed via the `prompt` attribute.

Args:
```model```: ```str```
OpenAI-API-compatible model name.
Could be:
- OpenAI models
- Ollama models
```api_base```: ```str```
Base URL for API requests.
- openai: https://api.openai.com/v1
- ollama: http://localhost:11434/v1
If `/v1` is not present, it will be appended.
```api_key```: ```Optional[str]```
API key used for completion requests.
```n_tries```: ```int```
Number of times the judgement is run for scoring.
The final aggregated score is computed according to `aggregation_type`.
```prompt```: ```Optional[str]```
Prompt to use for generating the scores.
If not provided, defaults to `LLMAsJudgeMetric._prompt`
```aggregation_type```: ```Optional[AggregationType]```
Decides how to aggregate scores from the multiple judgement tries.
Defaults to `AggregationType.MEAN` if not provided.
```debug```:```bool```
Boolean flag for debug-mode outputs


Usage:
.. code-block: python

import os

from evalem.nlp import LLMAsJudgeMetric

# local Ollama server
model = "ollama/llama3.2:3b"
api_base = "http://localhost:11434/v1"
# or OpenAI:
# model = "gpt-4o-mini"
# api_base = "https://api.openai.com/v1"

references = ["This is title 1", "This has title 2"]
predictions = [
["Title 1", "title 1 absolutely"],
["this is title 3, not title 2"],
]

metric = LLMAsJudgeMetric(
model=model,
api_base=api_base,
api_key=os.environ.get("OPENAI_API_KEY"),  # not needed for local Ollama
n_tries=3,
debug=True,
)
result = metric.compute(references=references, predictions=predictions)
"""

_prompt = (
"You are a very good binary classifier."
+ " Classify the quality of prediction based on the provided reference.\n"
+ "Prediction: {prediction}\n"
+ "Reference: {reference}"
)

def __init__(
self,
model: str,
api_base: str,
api_key: Optional[str] = None,
n_tries: int = 1,
temperature: float = 0.0,
prompt: Optional[str] = None,
aggregation_type: Optional[AggregationType] = None,
debug: bool = False,
) -> None:
super().__init__(debug=debug)
self.api_base = self.__clean_url(api_base)
self.model = outlines.models.openai(
self.__clean_model(model),
base_url=self.api_base,
api_key=api_key,
config=OpenAIConfig(temperature=temperature),
)
self.n_tries = n_tries or 1
self.prompt = prompt or LLMAsJudgeMetric._prompt
self.aggregation_type = aggregation_type or AggregationType.MEAN

self._sanity_check_prompt(self.prompt)

def _sanity_check_prompt(self, prompt: str) -> bool:
if "{prediction}" not in prompt or "{reference}" not in prompt:
raise ValueError(
"Missing '{prediction}' and '{reference}' placeholders in the prompt.",
)
return True

def __clean_model(self, model: str) -> str:
if model.startswith("ollama/"):
model = model.removeprefix("ollama/")
return model

def __clean_url(self, url: str) -> str:
if not url.endswith("/v1"):
url = urljoin(url, "/v1")
return url

@staticmethod
def _flatten_references(
predictions,
references,
) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
res = []
for preds, refs in zip(predictions, references):
# multiple predictions, single reference
if isinstance(preds, SequenceType) and isinstance(refs, str):
res.extend(list(map(lambda p: (p, refs), preds)))

# single prediction, multiple references
elif isinstance(preds, str) and isinstance(refs, SequenceType):
res.extend(list(map(lambda r: (preds, r), refs)))

# single prediction, single reference
else:
res.append((preds, refs))

predictions, references = zip(*res)
return predictions, references

def compute(
self,
predictions: EvaluationPredictionInstance,
references: EvaluationReferenceInstance,
**kwargs,
) -> MetricResult:
# make sure to flatten
predictions, references = self._flatten_references(predictions, references)
if self.debug:
logger.debug(f"Evaluating for {len(predictions)} predictions.")
generator = outlines.generate.choice(self.model, ["0", "1"])
res = []
individual_scores = []
for pred, ref in zip(predictions, references):
prompt = self.prompt.format(prediction=pred, reference=ref)
if self.debug:
logger.debug(f"Prompt :: {prompt}")
scores = []
score = np.nan
with outlines.caching.cache_disabled():
scores = self._compute_single(generator, prompt, self.n_tries)
score = self._aggregate_scores(scores, self.aggregation_type)
individual_scores.append(scores)
res.append(score)
if self.debug:
logger.debug(f"Scores :: {scores}")
logger.debug(f"Aggregated score :: {score}")
return MetricResult(
score=float(np.mean(res)),
total_items=len(predictions),
metric_name=self.__classname__,
extra=dict(scores=individual_scores, model=self.model),
)

@staticmethod
def _aggregate_scores(
scores: List[int],
aggregation_type: AggregationType = AggregationType.MEAN,
) -> float:
if not scores:
return 0.0
res = 0.0
if aggregation_type in [AggregationType.MEAN, AggregationType.AVERAGE]:
res = round(sum(scores) / len(scores), 4)
elif aggregation_type in [AggregationType.MAX]:
res = float(max(scores))
return res

def _compute_single(self, generator, prompt: str, n_tries: int) -> List[int]:
return [int(generator(prompt)) for _ in range(n_tries)]


def main():
pass


if __name__ == "__main__":
main()
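
To make the scoring concrete, here is a small self-contained sketch (not part of the diff) of how `compute` turns per-try 0/1 judgements into a final score, mirroring `_aggregate_scores` above; the judgement values are made up:

import numpy as np

# hypothetical per-try judgements for three (prediction, reference) pairs with n_tries=3;
# 1 = "prediction matches reference", 0 = "it does not"
individual_scores = [
    [1, 1, 1],  # pair 1: judged a match on every try
    [1, 0, 1],  # pair 2: judged a match on 2 of 3 tries
    [0, 0, 0],  # pair 3: never judged a match
]

# AggregationType.MEAN: per-pair score is the mean of its tries, rounded to 4 decimals
per_pair = [round(sum(s) / len(s), 4) for s in individual_scores]
print(per_pair)  # [1.0, 0.6667, 0.0]

# the reported MetricResult.score is the mean over all pairs
print(float(np.mean(per_pair)))  # ~0.5556
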
8 changes: 4 additions & 4 deletions evalem/nlp/metrics/semantics.py
@@ -50,7 +50,7 @@ class BertScore(JuryBasedMetric, SemanticMetric):
Usage:
.. code-block: python

from evalem.metrics import BertScore
from evalem.nlp import BertScore

references = [
"Reference 1",
@@ -185,7 +185,7 @@ class BleuMetric(JuryBasedMetric, SemanticMetric):

.. code-block: python

from evalem.metrics import BleuMetric
from evalem.nlp import BleuMetric

metric = BleuMetric()
results = metric(predictions=predictions, references=references)
@@ -213,7 +213,7 @@ class MeteorMetric(JuryBasedMetric, SemanticMetric):

.. code-block: python

from evalem.metrics import MeteorMetric
from evalem.nlp import MeteorMetric

metric = MeteorMetric()
results = metric(predictions=predictions, references=references)
@@ -236,7 +236,7 @@ class RougeMetric(JuryBasedMetric, SemanticMetric):

.. code-block: python

from evalem.metrics import RougeMetric
from evalem.nlp import RougeMetric

metric = RougeMetric()
results = metric(predictions=predictions, references=references)
41 changes: 23 additions & 18 deletions pyproject.toml
@@ -12,7 +12,7 @@ dynamic = [
description = "An evaluation framework for your NLP pipelines"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.8"
requires-python = ">=3.10"
authors = [
{ email = "[email protected]" },
]
@@ -28,26 +28,26 @@ classifiers = [
"Topic :: Text Processing :: General",
]
dependencies = [
"arrow==1.2.3",
"bert-score==0.3.13",
"datasets==2.7.0",
"evaluate==0.2.2",
"jury==2.2.3",
"loguru==0.6.0",
"numpy==1.24.2",
"onnx==1.14.0",
"onnxruntime==1.15.0",
"optimum==1.8.8",
"pandas==1.5.3",
"pyarrow==11.0.0",
"pyarrow>=18.1.0",
"bert-score>=0.3.13",
"datasets==2.9.0",
"evaluate>=0.4.3",
"jury==2.3.1",
"loguru>=0.6.0",
"numpy>=2.2.0",
"onnx>=1.17.0",
"onnxruntime>=1.20.1",
"optimum>=1.23.3",
"pandas>=2.2.3",
"pytest==7.2.1",
"pytest-cov==4.0.0",
"sacrebleu==2.3.1",
"scikit-learn==1.2.1",
"sentencepiece==0.1.99",
"sacrebleu==2.4.3",
"scikit-learn>=1.6.0",
"sentencepiece==0.2.0",
"seqeval==1.2.2",
"torch==2.0.1",
"transformers==4.28.1",
"torch>=2.5.1",
"transformers>=4.47.0",
"pip>=24.3.1",
]

[project.optional-dependencies]
@@ -58,6 +58,11 @@ nlp = [
# dependencies for nlp module
]

llm = [
"outlines>=0.1.9",
"openai>=1.57.3",
]

[project.urls]
Homepage = "https://github.com/NASA-IMPACT/evalem"
