Skip to content

Commit

Permalink
general fixes and clean up (#284)
Browse files Browse the repository at this point in the history
* refactor reranker

* refactor reranker logs

* copy native assets for udx-native and sodium-native

* 3.13.18
  • Loading branch information
rjmacarthy authored Aug 9, 2024
1 parent 784b967 commit f39bcef
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 50 deletions.
10 changes: 0 additions & 10 deletions .vscodeignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,3 @@ scripts
docs
target
.vscode-test.json
!out/node_modules/**

!node_modules/udx-native/**
!node_modules/sodium-native/**
!node_modules/b4a/**
!node_modules/node-gyp-build/**
!node_modules/streamx/**
!node_modules/queue-tick/**
!node_modules/text-decoder/**
!node_modules/fast-fifo/**
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "twinny",
"displayName": "twinny - AI Code Completion and Chat",
"description": "Locally hosted AI code completion plugin for vscode",
"version": "3.13.17",
"version": "3.13.18",
"icon": "assets/icon.png",
"keywords": [
"code-inference",
Expand Down
10 changes: 9 additions & 1 deletion scripts/build.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import esbuild from 'esbuild'
import { copy } from 'esbuild-plugin-copy';


(async () => {
const extensionConfig = {
bundle: true,
Expand Down Expand Up @@ -39,6 +38,15 @@ import { copy } from 'esbuild-plugin-copy';
from: './node_modules/web-tree-sitter/tree-sitter.wasm',
to: './out/tree-sitter.wasm'
}
,
{
from: `./node_modules/sodium-native/prebuilds/${process.platform}-${process.arch}/sodium-native.node`,
to: './out/sodium-native.node'
},
{
from: `./node_modules/udx-native/prebuilds/${process.platform}-${process.arch}/udx-native.node`,
to: './out/udx-native.node'
}
],
watch: true,
}),
Expand Down
90 changes: 54 additions & 36 deletions src/extension/reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import * as path from 'path'
import { Toxe } from 'toxe'
import { Logger } from '../common/logger'

// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
ort.env.wasm.numThreads = 1

const logger = new Logger()
Expand All @@ -30,50 +28,69 @@ export class Reranker {
}
}

public sigmoid(value: number) {
return 1 / (1 + Math.exp(-value))
}

public async rerank(sample: string, samples: string[]): Promise<number[] | undefined> {
const ids = await this._tokenizer?.encode(sample, samples);
if (!ids?.length) return undefined;

const inputIds = ids.map((id) => BigInt(id));
const inputTensor = new ort.Tensor('int64', BigInt64Array.from(inputIds), [
samples.length,
inputIds.length / samples.length
]);
public async rerank(
sample: string,
samples: string[]
): Promise<number[] | undefined> {
const ids = await this._tokenizer?.encode(sample, samples)
if (!ids?.length) return undefined

const attentionMaskTensor = new ort.Tensor(
'int64',
new BigInt64Array(inputIds.length).fill(1n),
[samples.length, inputIds.length / samples.length]
);
const inputTensor = this.getInputTensor(ids, samples.length)
const attentionMaskTensor = this.getOutputTensor(
ids.length,
samples.length
)

const output = await this._session?.run({
input_ids: inputTensor,
attention_mask: attentionMaskTensor
});
})

if (!output) return undefined;
if (!output) return undefined

const data = await output.logits.getData();
const logits = Array.prototype.slice.call(data);

const normalizedProbabilities = this.softmax(logits);
const logits = await this.getLogits(output)
const normalizedProbabilities = this.softmax(logits)

logger.log(
`Reranked samples: \n${this.formatResults(samples, normalizedProbabilities)}`
);
`Reranked samples: \n${this.formatResults(
samples,
normalizedProbabilities
)}`
)
return normalizedProbabilities
}

return normalizedProbabilities;
private getInputTensor(ids: number[], sampleCount: number): ort.Tensor {
const inputIds = ids.map(BigInt)
return new ort.Tensor('int64', BigInt64Array.from(inputIds), [
sampleCount,
inputIds.length / sampleCount
])
}

private getOutputTensor(
inputLength: number,
sampleCount: number
): ort.Tensor {
return new ort.Tensor('int64', new BigInt64Array(inputLength).fill(1n), [
sampleCount,
inputLength / sampleCount
])
}

private async getLogits(
output: ort.InferenceSession.OnnxValueMapType
): Promise<number[]> {
const data = await output.logits.getData()
const logits = Array.prototype.slice.call(data)
return logits
}

private softmax(logits: number[]): number[] {
const maxLogit = Math.max(...logits);
const scores = logits.map(l => Math.exp(l - maxLogit));
const sum = scores.reduce((a, b) => a + b, 0);
return scores.map(s => s / sum);
const maxLogit = Math.max(...logits)
const scores = logits.map((l) => Math.exp(l - maxLogit))
const sum = scores.reduce((a, b) => a + b, 0)
return scores.map((s) => s / sum)
}

private formatResults(samples: string[], probabilities: number[]): string {
Expand All @@ -84,12 +101,13 @@ export class Reranker {

private async loadModel(): Promise<void> {
try {
logger.log('Loading reranker model...')
this._session = await ort.InferenceSession.create(this._modelPath, {
executionProviders: ['wasm']
})
logger.log(`Model loaded from ${this._modelPath}`)
logger.log('Reranker model loaded')
} catch (error) {
console.error('Error loading model:', error)
console.error(error)
throw error
}
}
Expand All @@ -100,7 +118,7 @@ export class Reranker {
this._tokenizer = new Toxe(this._tokenizerPath)
logger.log('Tokenizer loaded')
} catch (error) {
console.error('Error loading tokenizer:', error)
console.error(error)
throw error
}
}
Expand Down

0 comments on commit f39bcef

Please sign in to comment.