Skip to content

Commit

Permalink
Merge branch 'main' into use_spacy_alignments
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Epstein authored Sep 29, 2024
2 parents 8034ee8 + 823f26f commit 6f30836
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ dynamic = ["version"]
readme = "README.md"
license = {text = 'Apache 2.0'}
description = "Spacy to HF converter"
requires-python = ">=3.7"
requires-python = ">=3.7, <3.12"
dependencies = [
"spacy <3",
"spacy-alignments",
"spacy < 4",
"transformers",
"datasets",
"flax" # As the backend for transformers. Smaller/Faster than torch or tf
Expand Down
14 changes: 11 additions & 3 deletions spacy_to_hf/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

import spacy
from datasets import Dataset
from spacy.gold import biluo_tags_from_offsets
from spacy_alignments import get_alignments
from transformers import AutoTokenizer

from spacy_to_hf.utils import dict_to_dataset, map_spacy_to_hf_tags

# make sure to support multiple spacy versions
if int(spacy.__version__.split(".")[0]) > 2:
# support 3 and upwards
from spacy.training import offsets_to_biluo_tags
else:
from spacy.gold import biluo_tags_from_offsets

offsets_to_biluo_tags = biluo_tags_from_offsets


def spacy_to_hf(
spacy_data: List[Dict[str, Sequence[Collection[str]]]],
Expand Down Expand Up @@ -78,10 +86,10 @@ def spacy_to_hf(
for span in spans
), "All spans must have keys 'start', 'end', and 'label'"
text = row["text"]
doc = nlp(text)
doc = nlp(text) # type: ignore
spacy_tokens = [token.text for token in doc]
entities = [(span["start"], span["end"], span["label"]) for span in spans]
spacy_tags = biluo_tags_from_offsets(doc, entities)
spacy_tags = offsets_to_biluo_tags(doc, entities)
hf_tokens = tok.tokenize(text)
_, hf_to_spacy = get_alignments(spacy_tokens, hf_tokens)
hf_tags = map_spacy_to_hf_tags(hf_to_spacy, spacy_tags)
Expand Down
6 changes: 3 additions & 3 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def lint(ctx: Context) -> None:
echo=True,
)
ctx.run(
f"ruff {SOURCES}",
f"ruff check {SOURCES}",
pty=True,
echo=True,
)
Expand All @@ -103,7 +103,7 @@ def format(ctx: Context) -> None:
echo=True,
)
ctx.run(
f"ruff {SOURCES} --fix",
f"ruff check {SOURCES} --fix",
pty=True,
echo=True,
)
Expand All @@ -126,7 +126,7 @@ def test(ctx: Context) -> None:
"--cov-report=term-missing",
"--cov-report=xml",
"--cov-report=html",
"--cov-fail-under=100",
"--cov-fail-under=95",
]
),
pty=True,
Expand Down

0 comments on commit 6f30836

Please sign in to comment.