Commit e093db9

sample: temporarily use grammars for constrained generation in new engine (ollama#9586)
1 parent a1cda80 commit e093db9

10 files changed: +298 −210 lines

llama/llama.go

Lines changed: 68 additions & 0 deletions
@@ -245,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	return &m, nil
 }
 
+func LoadVocabFromFile(path string) (*Vocab, error) {
+	mp := C.CString(path)
+	defer C.free(unsafe.Pointer(mp))
+	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
+	if v.c == nil {
+		return nil, fmt.Errorf("unable to load vocab: %s", path)
+	}
+	return &v, nil
+}
+
+func FreeVocab(vocab *Vocab) {
+	C.llama_free_vocab(vocab.c)
+}
+
 func FreeModel(model *Model) {
 	C.llama_model_free(model.c)
 }
@@ -293,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }
 
+type Vocab struct {
+	c *C.struct_llama_vocab
+}
+
 func (m *Model) Vocab() *C.struct_llama_vocab {
 	return C.llama_model_get_vocab(m.c)
 }
@@ -669,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	return buf[:n]
 }
+
+type Sampler struct {
+	c *C.struct_llama_sampler
+}
+
+func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
+	cGrammar := C.CString(grammar)
+	cRoot := C.CString("root")
+	defer C.free(unsafe.Pointer(cGrammar))
+	defer C.free(unsafe.Pointer(cRoot))
+
+	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
+
+	return sampler
+}
+
+func (s *Sampler) Accept(token int32) {
+	C.llama_sampler_accept(s.c, C.llama_token(token))
+}
+
+type TokenData struct {
+	Id    int32
+	Logit float32
+}
+
+func (s *Sampler) Apply(tokens []TokenData) {
+	tds := make([]C.struct_llama_token_data, len(tokens))
+	for i, token := range tokens {
+		tds[i] = C.struct_llama_token_data{
+			id:    C.int32_t(token.Id),
+			logit: C.float(token.Logit),
+			p:     C.float(0.0),
+		}
+	}
+	tda := &C.llama_token_data_array{
+		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
+		size:     C.size_t(len(tokens)),
+		selected: C.int64_t(-1),
+		sorted:   C.bool(false),
+	}
+
+	var pinner runtime.Pinner
+	pinner.Pin(&tds[0])
+	defer pinner.Unpin()
+
+	C.llama_sampler_apply(s.c, tda)
+	for i := range tokens {
+		tokens[i].Logit = float32(tds[i].logit)
+	}
+}
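
Taken together, these bindings form a small constrained-decoding loop: load the vocabulary once, build a grammar sampler, mask each step's logits with Apply, pick a token, and advance the grammar with Accept. A minimal sketch of that loop, assuming a logits slice from the model's forward pass (the model path, grammar, and greedy argmax are illustrative stand-ins, not code from this commit):

vocab, err := llama.LoadVocabFromFile("/path/to/model.gguf")
if err != nil {
	log.Fatal(err)
}
defer llama.FreeVocab(vocab)

// GBNF grammar; NewGrammarSampler always starts from the "root" rule.
sampler := llama.NewGrammarSampler(vocab, `root ::= "yes" | "no"`)

// Wrap the raw logits so the grammar sampler can mask them in place.
tokens := make([]llama.TokenData, len(logits))
for i, l := range logits {
	tokens[i] = llama.TokenData{Id: int32(i), Logit: l}
}
sampler.Apply(tokens) // grammar-violating tokens come back with masked logits

// Greedy pick for illustration; the real runner uses its own sampler chain.
best := tokens[0]
for _, t := range tokens[1:] {
	if t.Logit > best.Logit {
		best = t
	}
}
sampler.Accept(best.Id) // advance the grammar state past the chosen token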

llama/sampling_ext.cpp

Lines changed: 22 additions & 0 deletions
@@ -2,6 +2,9 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 #include "json-schema-to-grammar.h"
+#include "llama.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
 
 struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
     try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
         return 0;
     }
 }
+
+struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
+    llama_vocab * vocab = new llama_vocab();
+    try {
+        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+        std::vector<std::string> splits = {};
+        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+        vocab->load(ml, kv);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return nullptr;
+    }
+
+    return vocab;
+}
+
+void llama_free_vocab(struct llama_vocab * vocab) {
+    delete vocab;
+}

llama/sampling_ext.h

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ extern "C"
 
     int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
 
+    struct llama_vocab * llama_load_vocab_from_file(const char * fname);
+    void llama_free_vocab(struct llama_vocab * vocab);
+
#ifdef __cplusplus
 }
 #endif
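
These declarations sit inside the header's extern "C" block, which is what lets cgo link against them; the commit's real binding is LoadVocabFromFile in llama/llama.go above. As a hypothetical, reduced illustration of that cgo pattern (not the commit's actual preamble or file layout):

package llama

/*
#include <stdlib.h>
#include "sampling_ext.h"
*/
import "C"

import "unsafe"

// loadVocab shows the call shape: copy the Go string into C memory,
// free it after the call, and hand back the C pointer, which is nil
// on failure per llama_load_vocab_from_file's nullptr contract.
func loadVocab(path string) *C.struct_llama_vocab {
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	return C.llama_load_vocab_from_file(cPath)
}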

llm/server.go

Lines changed: 16 additions & 21 deletions
@@ -729,29 +729,24 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 
 	if len(req.Format) > 0 {
-		format := string(req.Format)
-		if format != `null` && format != `""` {
-			if s.textProcessor != nil {
-				// New engine handles this on the backend
-				request["format"] = req.Format
-			} else {
-				// old engine
-				switch format {
-				case `"json"`:
-					request["grammar"] = grammarJSON
-				default:
-					if req.Format[0] != '{' {
-						return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
-					}
+		switch string(req.Format) {
+		case `null`, `""`:
+			// Field was set, but "missing" a value. We accept
+			// these as "not set".
+			break
+		case `"json"`:
+			request["grammar"] = grammarJSON
+		default:
+			if req.Format[0] != '{' {
+				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
+			}
 
-				// User provided a JSON schema
-				g := llama.SchemaToGrammar(req.Format)
-				if g == nil {
-					return fmt.Errorf("invalid JSON schema in format")
-				}
-				request["grammar"] = string(g)
-			}
+			// User provided a JSON schema
+			g := llama.SchemaToGrammar(req.Format)
+			if g == nil {
+				return fmt.Errorf("invalid JSON schema in format")
 			}
+			request["grammar"] = string(g)
 		}
 	}
 
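
After this change both engines receive a grammar rather than a raw format: null and "" are treated as unset, "json" selects the prebuilt JSON grammar, and anything else must be a JSON Schema object (leading '{') that SchemaToGrammar can compile. A hedged illustration of the schema path — the schema literal is made up, and request is the map assembled earlier in Completion:

// Illustrative only: a user-supplied JSON Schema becomes a GBNF grammar.
schema := []byte(`{"type": "object", "properties": {"answer": {"type": "string"}}}`)
if g := llama.SchemaToGrammar(schema); g != nil {
	request["grammar"] = string(g)
}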

runner/ollamarunner/runner.go

Lines changed: 19 additions & 4 deletions
@@ -254,6 +254,12 @@ type Server struct {
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
+
+	// vocab is a llama.cpp vocab required for grammar-based
+	// constrained generation (json mode, structured outputs)
+	// TODO: this is temporary until Ollama sampling supports
+	// constrained generation
+	vocab *sample.Vocab
 }
 
 func (s *Server) allNil() bool {
@@ -574,18 +580,25 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	var grammar *sample.Grammar
+	var err error
+	if req.Grammar != "" {
+		grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
+		if err != nil {
+			http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+			return
+		}
+	}
+
 	sampler := sample.NewSampler(
 		req.Temperature,
 		req.TopK,
 		req.TopP,
 		req.MinP,
 		req.Seed,
+		grammar,
 	)
 
-	if req.Grammar != "" {
-		panic("grammars are not yet supported")
-	}
-
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
@@ -797,6 +810,8 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
+	s.vocab = sample.NewVocab(mpath)
+
 	// TODO(jessegross): LoRA loading
 	if lpath.String() != "" {
 		panic("loras are not yet implemented")
