Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add AbortSignal support across library components #1193

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ export class PretrainedConfig {
cache_dir = null,
local_files_only = false,
revision = 'main',
abort_signal = undefined,
} = {}) {
if (config && !(config instanceof PretrainedConfig)) {
config = new PretrainedConfig(config);
Expand All @@ -378,6 +379,7 @@ export class PretrainedConfig {
cache_dir,
local_files_only,
revision,
abort_signal,
})
return new this(data);
}
Expand Down
4 changes: 4 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ export class PreTrainedModel extends Callable {
local_files_only = false,
revision = 'main',
model_file_name = null,
abort_signal = undefined,
subfolder = 'onnx',
device = null,
dtype = null,
Expand All @@ -999,6 +1000,7 @@ export class PreTrainedModel extends Callable {
dtype,
use_external_data_format,
session_options,
abort_signal,
}

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
Expand Down Expand Up @@ -6999,6 +7001,7 @@ export class PretrainedMixin {
dtype = null,
use_external_data_format = null,
session_options = {},
abort_signal = undefined,
} = {}) {

const options = {
Expand All @@ -7013,6 +7016,7 @@ export class PretrainedMixin {
dtype,
use_external_data_format,
session_options,
abort_signal,
}
options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

Expand Down
8 changes: 4 additions & 4 deletions src/models/mgp_str/processing_mgp_str.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ export class MgpstrProcessor extends Processor {
}
}
/** @type {typeof Processor.from_pretrained} */
static async from_pretrained(...args) {
const base = await super.from_pretrained(...args);
static async from_pretrained(pretrained_model_name_or_path, options) {
const base = await super.from_pretrained(pretrained_model_name_or_path, options);

// Load Transformers.js-compatible versions of the BPE and WordPiece tokenizers
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2") // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased") // google-bert/bert-base-uncased
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2", { abort_signal: options?.abort_signal }) // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased", { abort_signal: options?.abort_signal }) // google-bert/bert-base-uncased

// Update components
base.components = {
Expand Down
12 changes: 9 additions & 3 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ export class Pipeline extends Callable {
* @param {PreTrainedModel} [options.model] The model used by the pipeline.
* @param {PreTrainedTokenizer} [options.tokenizer=null] The tokenizer used by the pipeline (if any).
* @param {Processor} [options.processor=null] The processor used by the pipeline (if any).
* @param {AbortSignal} [options.abort_signal=undefined] An optional AbortSignal to cancel the request.
*/
constructor({ task, model, tokenizer = null, processor = null }) {
/**
 * Create a new Pipeline.
 * @param {Object} options An object containing the following properties:
 * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
 * @param {PreTrainedModel} [options.model] The model used by the pipeline.
 * @param {PreTrainedTokenizer} [options.tokenizer=null] The tokenizer used by the pipeline (if any).
 * @param {Processor} [options.processor=null] The processor used by the pipeline (if any).
 * @param {AbortSignal} [options.abort_signal=undefined] An optional AbortSignal to cancel the request.
 */
constructor({ task, model, tokenizer = null, processor = null, abort_signal = undefined }) {
    super();
    // Store all pipeline components on the instance in one step; property
    // insertion order (task, model, tokenizer, processor, abort_signal)
    // matches the original field-by-field assignments.
    Object.assign(this, { task, model, tokenizer, processor, abort_signal });
}

/** @type {DisposeType} */
Expand Down Expand Up @@ -210,6 +212,7 @@ export class Pipeline extends Callable {
* @property {PreTrainedModel} model The model used by the pipeline.
* @property {PreTrainedTokenizer} tokenizer The tokenizer used by the pipeline.
* @property {Processor} processor The processor used by the pipeline.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal to cancel the request.
*
* @typedef {ModelTokenizerProcessorConstructorArgs} TextAudioPipelineConstructorArgs An object used to instantiate a text- and audio-based pipeline.
* @typedef {ModelTokenizerProcessorConstructorArgs} TextImagePipelineConstructorArgs An object used to instantiate a text- and image-based pipeline.
Expand Down Expand Up @@ -2776,17 +2779,18 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi

async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {


// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' , abort_signal : this.abort_signal });
}

// Load speaker embeddings as Float32Array from path/URL
if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
// Load from URL with fetch
speaker_embeddings = new Float32Array(
await (await fetch(speaker_embeddings)).arrayBuffer()
await (await fetch(speaker_embeddings, { signal: this.abort_signal })).arrayBuffer()
);
}

Expand Down Expand Up @@ -3301,6 +3305,7 @@ export async function pipeline(
dtype = null,
model_file_name = null,
session_options = {},
abort_signal = undefined,
} = {}
) {
// Helper method to construct pipeline
Expand Down Expand Up @@ -3331,6 +3336,7 @@ export async function pipeline(
dtype,
model_file_name,
session_options,
abort_signal,
}

const classes = new Map([
Expand Down
4 changes: 4 additions & 0 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const info = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -2691,6 +2692,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only,
revision,
legacy,
abort_signal,
})

// @ts-ignore
Expand Down Expand Up @@ -4351,6 +4353,7 @@ export class AutoTokenizer {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -4360,6 +4363,7 @@ export class AutoTokenizer {
local_files_only,
revision,
legacy,
abort_signal,
})

// Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
Expand Down
44 changes: 29 additions & 15 deletions src/utils/hub.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { dispatchCallback } from './core.js';
* @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
* NOTE: This setting is ignored for local requests.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal to cancel the request.
*/

/**
Expand Down Expand Up @@ -58,9 +59,11 @@ class FileResponse {
/**
* Creates a new `FileResponse` object.
* @param {string|URL} filePath
* @param {AbortSignal} abort_signal An optional AbortSignal to cancel the request.
*/
constructor(filePath) {
constructor(filePath, abort_signal) {
this.filePath = filePath;
this.abort_signal = abort_signal;
this.headers = new Headers();

this.exists = fs.existsSync(filePath);
Expand All @@ -79,9 +82,16 @@ class FileResponse {
self.arrayBuffer().then(buffer => {
controller.enqueue(new Uint8Array(buffer));
controller.close();
})
}).catch(error => {
controller.error(error);
});

abort_signal?.addEventListener('abort', () => {
controller.error(new Error('Request aborted'));
});
}
});

} else {
this.status = 404;
this.statusText = 'Not Found';
Expand All @@ -105,7 +115,7 @@ class FileResponse {
* @returns {FileResponse} A new FileResponse object with the same properties as the current object.
*/
clone() {
let response = new FileResponse(this.filePath);
let response = new FileResponse(this.filePath, this.abort_signal);
response.exists = this.exists;
response.status = this.status;
response.statusText = this.statusText;
Expand Down Expand Up @@ -185,12 +195,13 @@ function isValidUrl(string, protocols = null, validHosts = null) {
* Helper function to get a file, using either the Fetch API or FileSystem API.
*
* @param {URL|string} urlOrPath The URL/path of the file to get.
* @param {AbortSignal} abort_signal An optional AbortSignal to cancel the request.
* @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
*/
export async function getFile(urlOrPath) {
export async function getFile(urlOrPath, abort_signal) {

if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) {
return new FileResponse(urlOrPath);
return new FileResponse(urlOrPath, abort_signal);

} else if (typeof process !== 'undefined' && process?.release?.name === 'node') {
const IS_CI = !!process.env?.TESTING_REMOTELY;
Expand All @@ -210,12 +221,12 @@ export async function getFile(urlOrPath) {
headers.set('Authorization', `Bearer ${token}`);
}
}
return fetch(urlOrPath, { headers });
return fetch(urlOrPath, { headers, signal: abort_signal });
} else {
// Running in a browser-environment, so we use default headers
// NOTE: We do not allow passing authorization headers in the browser,
// since this would require exposing the token to the client.
return fetch(urlOrPath);
return fetch(urlOrPath, { signal: abort_signal });
}
}

Expand Down Expand Up @@ -263,13 +274,15 @@ class FileCache {

/**
* Checks whether the given request is in the cache.
* @param {string} request
* @param {string} request
* @param {Object} options An object containing the following properties:
* @param {AbortSignal} [options.abort_signal] An optional AbortSignal to cancel the request.
* @returns {Promise<FileResponse | undefined>}
*/
async match(request) {
async match(request, { abort_signal = undefined } = {}) {

let filePath = path.join(this.path, request);
let file = new FileResponse(filePath);
let file = new FileResponse(filePath, abort_signal);

if (file.exists) {
return file;
Expand Down Expand Up @@ -309,13 +322,14 @@ class FileCache {
/**
*
* @param {FileCache|Cache} cache The cache to search
* @param {AbortSignal} abort_signal An optional AbortSignal to cancel the request.
* @param {string[]} names The names of the item to search for
* @returns {Promise<FileResponse|Response|undefined>} The item from the cache, or undefined if not found.
*/
async function tryCache(cache, ...names) {
async function tryCache(cache, abort_signal, ...names) {
for (let name of names) {
try {
let result = await cache.match(name);
let result = await cache.match(name, {abort_signal});
if (result) return result;
} catch (e) {
continue;
Expand Down Expand Up @@ -433,7 +447,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
// 1. We first try to get from cache using the local path. In some environments (like deno),
// non-URL cache keys are not allowed. In these cases, `response` will be undefined.
// 2. If no response is found, we try to get from cache using the remote URL or file system cache.
response = await tryCache(cache, localPath, proposedCacheKey);
response = await tryCache(cache, options?.abort_signal, localPath, proposedCacheKey);
}

const cacheHit = response !== undefined;
Expand All @@ -447,7 +461,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
const isURL = isValidUrl(requestURL, ['http:', 'https:']);
if (!isURL) {
try {
response = await getFile(localPath);
response = await getFile(localPath, options?.abort_signal);
cacheKey = localPath; // Update the cache key to be the local path
} catch (e) {
// Something went wrong while trying to get the file locally.
Expand Down Expand Up @@ -479,7 +493,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
}

// File not found locally, so we try to download it from the remote server
response = await getFile(remoteURL);
response = await getFile(remoteURL, options?.abort_signal);

if (response.status !== 200) {
return handleError(response.status, remoteURL, fatal);
Expand Down