Skip to content

Commit 14fe924

Browse files
authored
Merge pull request #6 from julien-c/add-gemma
Add `gemma`
2 parents 876a427 + c9cbedd commit 14fe924

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

src/index.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import type {
55
BaseGGUFMetadata,
66
BloomMetadata,
77
FalconMetadata,
8+
GemmaMetadata,
89
GGUFMetadata,
910
GPT2Metadata,
1011
GPTJMetadata,
@@ -16,6 +17,7 @@ import type {
1617
import {
1718
bloomMetadataSchema,
1819
falconMetadataSchema,
20+
gemmaMetadataSchema,
1921
gPT2MetadataSchema,
2022
gPTJMetadataSchema,
2123
gPTNeoXMetadataSchema,
@@ -263,6 +265,7 @@ const isValidArchitecture = (
263265
'gpt2',
264266
'bloom',
265267
'falcon',
268+
'gemma',
266269
'rwkv',
267270
].includes(architecture)
268271
}
@@ -317,6 +320,11 @@ const validateMetadata = (
317320
if (res.success === false) return { error: res.error }
318321
return { metadata: res.data }
319322
}
323+
case 'gemma': {
324+
const res = gemmaMetadataSchema.safeParse(metadata)
325+
if (res.success === false) return { error: res.error }
326+
return { metadata: res.data }
327+
}
320328
case 'rwkv': {
321329
const res = rWKVMetadataSchema.safeParse(metadata)
322330
if (res.success === false) return { error: res.error }
@@ -604,6 +612,12 @@ export const isFalconMetadata = (
604612
return metadata.general.architecture === 'falcon'
605613
}
606614

615+
export const isGemmaMetadata = (
616+
metadata: GGUFMetadata,
617+
): metadata is GemmaMetadata => {
618+
return metadata.general.architecture === 'gemma'
619+
}
620+
607621
export const isRWKVMetadata = (
608622
metadata: GGUFMetadata,
609623
): metadata is RWKVMetadata => {

src/metadataTypes.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export type ArchitectureType =
1010
| 'gpt2'
1111
| 'bloom'
1212
| 'falcon'
13+
| 'gemma'
1314
| 'rwkv'
1415

1516
export type BaseGGUFMetadata = {
@@ -369,6 +370,37 @@ export type RWKVMetadata = {
369370
}
370371
}
371372

373+
export type GemmaMetadata = {
374+
gemma: {
375+
attention: {
376+
/** Also known as n_head. Number of attention heads. */
377+
head_count: number
378+
/** The number of heads per group used in Grouped-Query-Attention. If not
379+
* present or if present and equal to [llm].attention.head_count, the model
380+
* does not use GQA. */
381+
head_count_kv?: number
382+
/** Layer RMS normalization epsilon. */
383+
layer_norm_rms_epsilon: number
384+
}
385+
block_count: number
386+
/** Length of the context used during training or fine-tuning. RWKV is able
387+
* to handle larger context than this limit, but the output quality
388+
* may suffer. */
389+
context_length: number
390+
/** Also known as n_embd. Embedding layer size. */
391+
embedding_length: number
392+
/** Also known as n_ff. The length of the feedforward layer. */
393+
feed_forward_length: number
394+
}
395+
general: BaseGGUFMetadata & {
396+
/**
397+
* describes what architecture this model implements. All lowercase ASCII,
398+
* with only [a-z0-9]+ characters allowed.
399+
**/
400+
architecture: 'gemma'
401+
}
402+
}
403+
372404
export type WhisperMetadata = {
373405
general: BaseGGUFMetadata & {
374406
/**
@@ -416,5 +448,6 @@ export type GGUFMetadata =
416448
| GPT2Metadata
417449
| BloomMetadata
418450
| FalconMetadata
451+
| GemmaMetadata
419452
| RWKVMetadata
420453
| WhisperMetadata

src/zodValidators.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export const architectureTypeSchema = z.union([
99
z.literal('gpt2'),
1010
z.literal('bloom'),
1111
z.literal('falcon'),
12+
z.literal('gemma'),
1213
z.literal('rwkv'),
1314
])
1415

@@ -218,6 +219,25 @@ export const rWKVMetadataSchema = z.object({
218219
}),
219220
})
220221

222+
export const gemmaMetadataSchema = z.object({
223+
gemma: z.object({
224+
attention: z.object({
225+
head_count: z.number(),
226+
head_count_kv: z.number().optional(),
227+
layer_norm_rms_epsilon: z.number(),
228+
}),
229+
block_count: z.number(),
230+
context_length: z.number(),
231+
embedding_length: z.number(),
232+
feed_forward_length: z.number(),
233+
}),
234+
general: baseGGUFMetadataSchema.and(
235+
z.object({
236+
architecture: z.literal('gemma'),
237+
}),
238+
),
239+
})
240+
221241
export const whisperMetadataSchema = z.object({
222242
general: baseGGUFMetadataSchema.and(
223243
z.object({
@@ -253,6 +273,7 @@ export const gGUFMetadataSchema = z.union([
253273
gPT2MetadataSchema,
254274
bloomMetadataSchema,
255275
falconMetadataSchema,
276+
gemmaMetadataSchema,
256277
rWKVMetadataSchema,
257278
whisperMetadataSchema,
258279
])

0 commit comments

Comments
 (0)