Skip to content

Commit

Permalink
Gguf updates (#543)
Browse files Browse the repository at this point in the history
1. [Use length rather than
newOffset](fcab2c9)
(discussed
[here](#540 (comment)))
2. [custom fetch
fn](18f93f3)
(discussed
[here](#540 (comment)))
  • Loading branch information
mishig25 committed Mar 12, 2024
1 parent 9366d4a commit 8ec3643
Showing 1 changed file with 44 additions and 26 deletions.
70 changes: 44 additions & 26 deletions packages/gguf/src/gguf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,26 @@ const HTTP_TOTAL_MAX_SIZE = 50 * 10 ** 6; /// 50MB
class RangeView {
private chunk: number;
private buffer: ArrayBuffer;
private fetch: typeof fetch;

readonly view: DataView;

constructor(public url: string) {
constructor(
public url: string,
params?: {
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}
) {
this.chunk = 0;
/// TODO(fix typing)
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
this.buffer = new ArrayBuffer(0, { maxByteLength: HTTP_TOTAL_MAX_SIZE });
this.view = new DataView(this.buffer);
this.fetch = params?.fetch ?? fetch;
}
/**
* Fetch a new chunk from the server
Expand All @@ -78,7 +88,7 @@ class RangeView {
const range = [this.chunk * HTTP_CHUNK_SIZE, (this.chunk + 1) * HTTP_CHUNK_SIZE - 1];
const buf = new Uint8Array(
await (
await fetch(this.url, {
await this.fetch(this.url, {
headers: {
Range: `bytes=${range[0]}-${range[1]}`,
},
Expand Down Expand Up @@ -115,54 +125,54 @@ function readVersionedSize(view: DataView, byteOffset: number, version: Version)
}
}

function readString(view: DataView, offset: number): { value: string; newOffset: number } {
function readString(view: DataView, offset: number): { value: string; length: number } {
const length = view.getBigUint64(offset, true);
const value = new TextDecoder().decode(view.buffer.slice(offset + 8, offset + 8 + Number(length)));
return { value, newOffset: offset + 8 + Number(length) };
return { value, length: 8 + Number(length) };
}

function readMetadataValue(
view: DataView,
type: GGUFValueType,
offset: number
): { value: MetadataValue; newOffset: number } {
): { value: MetadataValue; length: number } {
switch (type) {
case GGUFValueType.UINT8:
return { value: view.getUint8(offset), newOffset: offset + 1 };
return { value: view.getUint8(offset), length: 1 };
case GGUFValueType.INT8:
return { value: view.getInt8(offset), newOffset: offset + 1 };
return { value: view.getInt8(offset), length: 1 };
case GGUFValueType.UINT16:
return { value: view.getUint16(offset, true), newOffset: offset + 2 };
return { value: view.getUint16(offset, true), length: 2 };
case GGUFValueType.INT16:
return { value: view.getInt16(offset, true), newOffset: offset + 2 };
return { value: view.getInt16(offset, true), length: 2 };
case GGUFValueType.UINT32:
return { value: view.getUint32(offset, true), newOffset: offset + 4 };
return { value: view.getUint32(offset, true), length: 4 };
case GGUFValueType.INT32:
return { value: view.getInt32(offset, true), newOffset: offset + 4 };
return { value: view.getInt32(offset, true), length: 4 };
case GGUFValueType.FLOAT32:
return { value: view.getFloat32(offset, true), newOffset: offset + 4 };
return { value: view.getFloat32(offset, true), length: 4 };
case GGUFValueType.BOOL:
return { value: view.getUint8(offset) !== 0, newOffset: offset + 1 };
return { value: view.getUint8(offset) !== 0, length: 1 };
case GGUFValueType.STRING:
return readString(view, offset);
case GGUFValueType.ARRAY: {
const arrayType = view.getUint32(offset, true);
const arrayLength = view.getBigUint64(offset + 4, true);
let arrayOffset = offset + 12;
let length = 12;
const arrayValues: MetadataValue[] = [];
for (let i = 0; i < arrayLength; i++) {
const { value, newOffset } = readMetadataValue(view, arrayType, arrayOffset);
const { value, length: _length } = readMetadataValue(view, arrayType, offset + length);
arrayValues.push(value);
arrayOffset = newOffset;
length += _length;
}
return { value: arrayValues, newOffset: arrayOffset };
return { value: arrayValues, length };
}
case GGUFValueType.UINT64:
return { value: view.getBigUint64(offset, true), newOffset: offset + 8 };
return { value: view.getBigUint64(offset, true), length: 8 };
case GGUFValueType.INT64:
return { value: view.getBigInt64(offset, true), newOffset: offset + 8 };
return { value: view.getBigInt64(offset, true), length: 8 };
case GGUFValueType.FLOAT64:
return { value: view.getFloat64(offset, true), newOffset: offset + 8 };
return { value: view.getFloat64(offset, true), length: 8 };
}
}

Expand All @@ -185,8 +195,16 @@ export interface GGUFParseOutput {
tensorInfos: GGUFTensorInfo[];
}

export async function gguf(url: string): Promise<GGUFParseOutput> {
const r = new RangeView(url);
export async function gguf(
url: string,
params?: {
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}
): Promise<GGUFParseOutput> {
const r = new RangeView(url, params);
await r.fetchChunk();

if (r.view.getUint32(0, true) !== new DataView(ggufMagicNumber.buffer).getUint32(0, true)) {
Expand All @@ -213,7 +231,7 @@ export async function gguf(url: string): Promise<GGUFParseOutput> {

// read key
const keyResult = readString(r.view, offset);
offset = keyResult.newOffset;
offset += keyResult.length;

// read value type
const valueType = r.view.getUint32(offset, true);
Expand All @@ -223,7 +241,7 @@ export async function gguf(url: string): Promise<GGUFParseOutput> {
throw new Error("Unsupported metadata type: " + valueType);
}

let valueResult: { value: MetadataValue; newOffset: number } | undefined;
let valueResult: ReturnType<typeof readMetadataValue> | undefined;
while (!valueResult) {
try {
// read value
Expand All @@ -236,7 +254,7 @@ export async function gguf(url: string): Promise<GGUFParseOutput> {
}
}
}
offset = valueResult.newOffset;
offset += valueResult.length;
metadata[keyResult.value] = valueResult.value;
}

Expand All @@ -247,7 +265,7 @@ export async function gguf(url: string): Promise<GGUFParseOutput> {

// read tensor name
const keyResult = readString(r.view, offset);
offset = keyResult.newOffset;
offset += keyResult.length;

const nDims = r.view.getUint32(offset, true);
offset += 4;
Expand Down

0 comments on commit 8ec3643

Please sign in to comment.