Skip to content

Commit

Permalink
[llm unify 6/n] extract schema and batch schema to llm map (#1188)
Browse files Browse the repository at this point in the history
* convert schema extraction to llm map

Signed-off-by: Henry Lindeman <[email protected]>

* throw out unneeded prompts

Signed-off-by: Henry Lindeman <[email protected]>

* deprecate extract_schema. do not deprecate extract_batch_schema bc that has weird behavior (call llm once and then apply everywhere)

Signed-off-by: Henry Lindeman <[email protected]>

* mypy

Signed-off-by: Henry Lindeman <[email protected]>

---------

Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 authored Feb 17, 2025
1 parent 21c3fc0 commit f751351
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 72 deletions.
8 changes: 3 additions & 5 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def extract_entity(self, entity_extractor: EntityExtractor, **kwargs) -> "DocSet
llm_map = entity_extractor.as_llm_map(self.plan, context=self.context, **kwargs)
return DocSet(self.context, llm_map)

@deprecated(version="0.1.31", reason="Use llm_map with SchemaZeroShotJinjaPrompt instead")
def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet":
"""
Extracts a JSON schema of extractable properties from each document in this DocSet.
Expand Down Expand Up @@ -534,11 +535,8 @@ def extract_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet
.partition(partitioner=ArynPartitioner())
.extract_schema(schema_extractor=schema_extractor)
"""

from sycamore.transforms import ExtractSchema

schema = ExtractSchema(self.plan, schema_extractor=schema_extractor)
return DocSet(self.context, schema)
comptransform = schema_extractor.as_llm_map(self.plan, **kwargs)
return DocSet(self.context, comptransform)

def extract_batch_schema(self, schema_extractor: SchemaExtractor, **kwargs) -> "DocSet":
"""
Expand Down
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/llms/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
EntityExtractorZeroShotJinjaPrompt,
EntityExtractorFewShotJinjaPrompt,
TextSummarizerGuidancePrompt,
SchemaZeroShotGuidancePrompt,
SchemaZeroShotJinjaPrompt,
PropertiesZeroShotGuidancePrompt,
TaskIdentifierZeroShotGuidancePrompt,
GraphEntityExtractorPrompt,
Expand All @@ -29,7 +29,7 @@
"EntityExtractorZeroShotJinjaPrompt",
"EntityExtractorFewShotJinjaPrompt",
"TextSummarizerGuidancePrompt",
"SchemaZeroShotGuidancePrompt",
"SchemaZeroShotJinjaPrompt",
"PropertiesZeroShotGuidancePrompt",
"GraphEntityExtractorPrompt",
"GraphRelationshipExtractorPrompt",
Expand Down
31 changes: 12 additions & 19 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
J_SET_ENTITY,
J_SET_SCHEMA,
J_ELEMENT_BATCHED_LIST,
J_ELEMENT_LIST_CAPPED,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -173,24 +174,18 @@ class _TextSummarizerGuidancePrompt(SimplePrompt):
)


class _SchemaZeroShotGuidancePrompt(SimplePrompt):
system = "You are a helpful entity extractor. You only return JSON Schema."
user = """You are given a few text elements of a document. Extract JSON Schema representing one entity of
class {entity} from the document. Using this context, FIND, FORMAT, and RETURN the JSON-LD Schema.
Return a flat schema, without nested properties. Return at most {max_num_properties} properties.
Only return JSON Schema as part of your answer.
{query}
"""


SchemaZeroShotGuidancePrompt = ElementListPrompt(
SchemaZeroShotJinjaPrompt = JinjaPrompt(
system="You are a helpful entity extractor. You only return JSON Schema.",
user="""You are given a few text elements of a document. Extract JSON Schema representing one entity of
class {entity} from the document. Using this context, FIND, FORMAT, and RETURN the JSON-LD Schema.
Return a flat schema, without nestes properties. Return at most {max_num_properties} properties.
Only return JSON Schema as part of your answer.
{elements}""",
max_num_properties=7,
user=textwrap.dedent(
"""\
You are given a few text elements of a document. Extract JSON Schema representing
one entity of class {{ entity }} from the document. Using this context, FIND, FORMAT, and
RETURN the JSON-LD Schema. Return a flat schema, without nested properties. Return at most
{{ max_num_properties }} properties. Only return JSON Schema as part of your answer.
{% if prompt_formatter is defined %}{{ prompt_formatter(doc.elements[:num_elements]) }}{% else %}"""
)
+ J_ELEMENT_LIST_CAPPED
+ "{% endif %}",
)


Expand Down Expand Up @@ -498,8 +493,6 @@ def __init__(self, field: str, groups: list[str]):
"ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT": _EntityExtractorFewShotGuidancePrompt,
"TEXT_SUMMARIZER_GUIDANCE_PROMPT": _TextSummarizerGuidancePrompt,
"TEXT_SUMMARIZER_GUIDANCE_PROMPT_CHAT": _TextSummarizerGuidancePrompt,
"SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT": _SchemaZeroShotGuidancePrompt,
"SCHEMA_ZERO_SHOT_GUIDANCE_PROMPT_CHAT": _SchemaZeroShotGuidancePrompt,
}


Expand Down
10 changes: 5 additions & 5 deletions lib/sycamore/sycamore/tests/unit/test_docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
Map,
MapBatch,
Partition,
ExtractSchema,
ExtractBatchSchema,
Query,
)
from sycamore.transforms import Filter
from sycamore.transforms.base import get_name_from_callable, CompositeTransform
from sycamore.transforms.base_llm import LLMMap
from sycamore.transforms.extract_entity import OpenAIEntityExtractor
from sycamore.transforms.extract_schema import SchemaExtractor, LLMPropertyExtractor
from sycamore.transforms.extract_schema import SchemaExtractor, LLMPropertyExtractor, LLMSchemaExtractor
from sycamore.transforms.query import QueryExecutor
from sycamore.transforms.similarity import SimilarityScorer
from sycamore.transforms.sort import Sort
Expand Down Expand Up @@ -269,10 +268,11 @@ def test_rerank(self, mocker):

def test_extract_schema(self, mocker):
context = mocker.Mock(spec=Context)
func = mocker.Mock(spec=Callable, extract_schema=lambda d: {})
llm = mocker.Mock(spec=LLM)
extractor = LLMSchemaExtractor(entity_name="", llm=llm)
docset = DocSet(context, None)
docset = docset.extract_schema(func)
assert isinstance(docset.lineage(), ExtractSchema)
docset = docset.extract_schema(extractor)
assert isinstance(docset.lineage(), CompositeTransform)

def test_extract_batch_schema(self, mocker):
context = mocker.Mock(spec=Context)
Expand Down
30 changes: 17 additions & 13 deletions lib/sycamore/sycamore/tests/unit/transforms/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import datetime
import random
import string
from typing import Optional

from ray.util import inspect_serializability

from sycamore.llms.prompts.default_prompts import _SchemaZeroShotGuidancePrompt
from sycamore.data import Document, Element
from sycamore.llms.llms import LLM, FakeLLM
from sycamore.llms.prompts import RenderedPrompt
from sycamore.plan_nodes import Node
from sycamore.schema import Schema, SchemaField
from sycamore.transforms.base_llm import LLMMap
from sycamore.transforms.map import Map
Expand All @@ -19,6 +21,9 @@ class TrivialExtractor(SchemaExtractor):
def __init__(self):
super().__init__("foo")

def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
return child # type: ignore

def extract_schema(self, document: Document) -> Document:
return document

Expand All @@ -39,7 +44,7 @@ def test_serializable(self, mocker):

def test_extract_schema(self, mocker):
llm = mocker.Mock(spec=LLM)
generate = mocker.patch.object(llm, "generate_old")
generate = mocker.patch.object(llm, "generate")
generate.return_value = '```json {"accidentNumber": "string"}```'

num_of_elements = 10
Expand All @@ -56,28 +61,27 @@ def test_extract_schema(self, mocker):
schema_extractor = LLMSchemaExtractor(
class_name, llm, num_of_elements=num_of_elements, max_num_properties=max_num_properties
)
doc = schema_extractor.extract_schema(doc)
doc = schema_extractor.as_llm_map(None)._local_process([doc])[0]

ground_truth = {
"_schema": {
"accidentNumber": "string",
},
"_schema_class": "AircraftIncident",
}
print(doc.properties)
assert doc.properties == ground_truth
generate.assert_called_once_with(
prompt_kwargs={
"prompt": _SchemaZeroShotGuidancePrompt(),
"entity": class_name,
"max_num_properties": max_num_properties,
"query": schema_extractor._prompt_formatter(doc.elements),
}
)
generate.assert_called_once()
ca = generate.call_args
rp = ca.kwargs["prompt"]
assert isinstance(rp, RenderedPrompt)
messages = rp.messages
assert len(messages) == 2
assert f"ELEMENT None: {element1.text_representation}" in messages[1].content
assert f"ELEMENT None: {element2.text_representation}" in messages[1].content

def test_extract_batch_schema(self, mocker):
llm = mocker.Mock(spec=LLM)
generate = mocker.patch.object(llm, "generate_old")
generate = mocker.patch.object(llm, "generate")
generate.return_value = '```json {"accidentNumber": "string"}```'
schema_extractor = LLMSchemaExtractor("AircraftIncident", llm)

Expand Down
63 changes: 35 additions & 28 deletions lib/sycamore/sycamore/transforms/extract_schema.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional, Union
from typing import Callable, Optional, Union
import json

from sycamore.data import Element, Document
from sycamore.schema import Schema
from sycamore.llms import LLM
from sycamore.llms.prompts.default_prompts import (
_SchemaZeroShotGuidancePrompt,
PropertiesZeroShotJinjaPrompt,
PropertiesFromSchemaJinjaPrompt,
SchemaZeroShotJinjaPrompt,
)
from sycamore.llms.prompts import SycamorePrompt
from sycamore.plan_nodes import Node
from sycamore.transforms.base import CompositeTransform
from sycamore.transforms.map import Map
from sycamore.transforms.base_llm import LLMMap
from sycamore.utils.extract_json import extract_json
Expand All @@ -31,6 +32,10 @@ class SchemaExtractor(ABC):
def __init__(self, entity_name: str):
self._entity_name = entity_name

@abstractmethod
def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
pass

@abstractmethod
def extract_schema(self, document: Document) -> Document:
pass
Expand Down Expand Up @@ -84,35 +89,37 @@ def __init__(
self._prompt_formatter = prompt_formatter
self._max_num_properties = max_num_properties

@timetrace("ExtrSchema")
def extract_schema(self, document: Document) -> Document:
entities = self._handle_zero_shot_prompting(document)

try:
payload = entities
answer = extract_json(payload)
except (json.JSONDecodeError, ValueError):
answer = entities

document.properties.update({"_schema": answer, "_schema_class": self._entity_name})

return document

def _handle_zero_shot_prompting(self, document: Document) -> Any:
sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))]
def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
prompt = SchemaZeroShotJinjaPrompt.set(
entity=self._entity_name,
max_num_properties=self._max_num_properties,
num_elements=self._num_of_elements,
field="text_representation",
)
if self._prompt_formatter is not element_list_formatter:
prompt = prompt.set(prompt_formatter=self._prompt_formatter)

prompt = _SchemaZeroShotGuidancePrompt()
def parse_json(doc: Document) -> Document:
schemastr = doc.properties.get("_schema", "{}")
try:
schema = extract_json(schemastr)
except (json.JSONDecodeError, AttributeError, ValueError):
schema = schemastr
doc.properties["_schema"] = schema
doc.properties["_schema_class"] = self._entity_name
return doc

entities = self._llm.generate_old(
prompt_kwargs={
"prompt": prompt,
"entity": self._entity_name,
"max_num_properties": self._max_num_properties,
"query": self._prompt_formatter(sub_elements),
}
)
llm_map = LLMMap(child, prompt=prompt, output_field="_schema", llm=self._llm)
json_map = Map(llm_map, f=parse_json)
comptransform = CompositeTransform(child, []) # type: ignore
comptransform.nodes = [llm_map, json_map]
return comptransform

return entities
@timetrace("ExtrSchema")
def extract_schema(self, document: Document) -> Document:
comptransform = self.as_llm_map(None)
assert isinstance(comptransform, CompositeTransform)
return comptransform._local_process([document])[0]


class OpenAISchemaExtractor(LLMSchemaExtractor):
Expand Down

0 comments on commit f751351

Please sign in to comment.