Skip to content

Commit

Permalink
json-schema-to-grammar : fix order of props + non-str const/enum (gge…
Browse files Browse the repository at this point in the history
…rganov#6232)

* json: ordered json in server/schema converter to respect orig order

* json: ws nits

* json: support non-string const / enums
  • Loading branch information
ochafik authored and hodlen committed Apr 3, 2024
1 parent 52fcea9 commit 238322e
Show file tree
Hide file tree
Showing 8 changed files with 1,452 additions and 1,481 deletions.
16 changes: 6 additions & 10 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <unordered_set>
#include <vector>

using json = nlohmann::json;
using json = nlohmann::ordered_json;

const std::string SPACE_RULE = "\" \"?";

Expand Down Expand Up @@ -124,7 +124,7 @@ static std::string replacePattern(const std::string & input, const std::regex &
}

static std::string format_literal(const std::string & literal) {
std::string escaped = replacePattern(json(literal).dump(), GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
char c = match.str()[0];
return GRAMMAR_LITERAL_ESCAPES.at(c);
});
Expand All @@ -137,7 +137,7 @@ class SchemaConverter {
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, nlohmann::json> _refs;
std::unordered_map<std::string, json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
Expand Down Expand Up @@ -413,7 +413,7 @@ class SchemaConverter {
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
prop_kv_rule_names[prop_name] = _add_rule(
name + (name.empty() ? "" : "-") + prop_name + "-kv",
format_literal(prop_name) + " space \":\" space " + prop_rule_name
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
);
if (required.find(prop_name) != required.end()) {
required_props.push_back(prop_name);
Expand Down Expand Up @@ -495,7 +495,7 @@ class SchemaConverter {
_rules["space"] = SPACE_RULE;
}

void resolve_refs(nlohmann::json & schema, const std::string & url) {
void resolve_refs(json & schema, const std::string & url) {
/*
* Resolves all $ref fields in the given schema, fetching any remote schemas,
* replacing each $ref with absolute reference URL and populates _refs with the
Expand Down Expand Up @@ -557,11 +557,7 @@ class SchemaConverter {
}

std::string _generate_constant_rule(const json & value) {
if (!value.is_string()) {
_errors.push_back("Only std::string constants are supported, got " + value.dump());
return "";
}
return format_literal(value.get<std::string>());
return format_literal(value.dump());
}

std::string visit(const json & schema, const std::string & name) {
Expand Down
2 changes: 1 addition & 1 deletion common/json-schema-to-grammar.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#pragma once
#include "json.hpp"

std::string json_schema_to_grammar(const nlohmann::json& schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
7 changes: 3 additions & 4 deletions examples/json-schema-to-grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):

def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
)
return f'"{escaped}"'

Expand Down Expand Up @@ -308,8 +308,7 @@ def _resolve_ref(self, ref):
return ref_name

def _generate_constant_rule(self, value):
assert isinstance(value, str), f'Only string constants are supported, got {value}'
return self._format_literal(value)
return self._format_literal(json.dumps(value))

def visit(self, schema, name):
schema_type = schema.get('type')
Expand Down Expand Up @@ -428,7 +427,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
Expand Down
2,791 changes: 1,384 additions & 1,407 deletions examples/server/json-schema-to-grammar.mjs.hpp

Large diffs are not rendered by default.

12 changes: 3 additions & 9 deletions examples/server/public/json-schema-to-grammar.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class SchemaConverter {
}

_formatLiteral(literal) {
const escaped = JSON.stringify(literal).replace(
const escaped = literal.replace(
GRAMMAR_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
Expand Down Expand Up @@ -327,10 +327,7 @@ export class SchemaConverter {
}

_generateConstantRule(value) {
if (typeof value !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(value));
}
return this._formatLiteral(value);
return this._formatLiteral(JSON.stringify(value));
}

visit(schema, name) {
Expand All @@ -346,9 +343,6 @@ export class SchemaConverter {
} else if (Array.isArray(schemaType)) {
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
} else if ('const' in schema) {
if (typeof schema.const !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(schema.const));
}
return this._addRule(ruleName, this._generateConstantRule(schema.const));
} else if ('enum' in schema) {
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
Expand Down Expand Up @@ -457,7 +451,7 @@ export class SchemaConverter {
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
propKvRuleNames[propName] = this._addRule(
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
`${this._formatLiteral(propName)} space ":" space ${propRuleName}`
`${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}`
);
}
const requiredProps = sortedProps.filter(k => required.has(k));
Expand Down
2 changes: 1 addition & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <signal.h>
#include <memory>

using json = nlohmann::json;
using json = nlohmann::ordered_json;

bool server_verbose = false;
bool server_log_json = true;
Expand Down
2 changes: 1 addition & 1 deletion examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

using json = nlohmann::json;
using json = nlohmann::ordered_json;

// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
Expand Down
101 changes: 53 additions & 48 deletions tests/test-json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
FAILURE,
"invalid type type",
"invalid type",
R"""({
"type": 123
})""",
Expand Down Expand Up @@ -193,21 +193,27 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
});

test({
FAILURE,
SUCCESS,
"non-string const",
R"""({
"const": 123
})""",
""
R"""(
root ::= "123"
space ::= " "?
)"""
});

test({
FAILURE,
SUCCESS,
"non-string enum",
R"""({
"enum": [123]
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
""
R"""(
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
space ::= " "?
)"""
});

test({
Expand Down Expand Up @@ -378,28 +384,27 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
SUCCESS,
"required props",
"required props in original order",
R"""({
"type": "object",
"properties": {
"a": {
"type": "string"
},
"b": {
"type": "string"
}
"b": {"type": "string"},
"c": {"type": "string"},
"a": {"type": "string"}
},
"required": [
"a",
"b"
"b",
"c"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
root ::= "{" space a-kv "," space b-kv "}" space
c-kv ::= "\"c\"" space ":" space string
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
Expand Down Expand Up @@ -458,13 +463,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
SUCCESS,
"required + optional props",
"required + optional props each in original order",
R"""({
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"},
"d": {"type": "string"}
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
Expand All @@ -473,14 +478,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
c-rest ::= ( "," space d-kv )?
d-kv ::= "\"d\"" space ":" space string
root ::= "{" space a-kv "," space b-kv ( "," space ( c-kv c-rest | d-kv ) )? "}" space
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)* "\"" space
)"""
});

Expand Down Expand Up @@ -648,16 +653,16 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"$ref": "#/definitions/MyType",
"definitions": {
"MyType": {
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
}
}
})""",
Expand All @@ -683,10 +688,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
"properties": {"b": {"type": "number"}}
}
},
"type": "object"
Expand Down Expand Up @@ -720,16 +725,16 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
"properties": {"b": {"type": "number"}}
},
"bam": {
"properties": {"c": {"type": "number"}}
"properties": {"c": {"type": "number"}}
},
"baz": {
"properties": {"d": {"type": "number"}}
"properties": {"d": {"type": "number"}}
}
},
"type": "object"
Expand Down Expand Up @@ -757,15 +762,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"properties": {
"number": {
"type": "object",
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
}
},
"required": [
Expand Down Expand Up @@ -796,7 +801,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
Expand Down

0 comments on commit 238322e

Please sign in to comment.