feat: Add AbortSignal support across library components #1193

Draft: wants to merge 2 commits into base: main
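This PR threads a single `abort_signal` option through config, model, tokenizer, processor, and pipeline loading, so callers can cancel in-flight downloads with a standard `AbortController`. As a rough sketch of the intended usage (the option name follows this diff and is not part of any released API yet):

```js
import { pipeline } from '@huggingface/transformers';

const controller = new AbortController();

// Give up on the model download after 5 seconds.
const timeout = setTimeout(() => controller.abort(), 5000);

try {
    // `abort_signal` is the option introduced by this PR; aborting the
    // controller should reject the pending fetches with an AbortError.
    const classifier = await pipeline('image-classification', 'Xenova/vit-base-patch16-224', {
        abort_signal: controller.signal,
    });
    clearTimeout(timeout);
    console.log(await classifier('https://example.com/cat.jpg'));
} catch (err) {
    if (err.name !== 'AbortError') throw err;
    console.log('Model loading was cancelled.');
}
```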
2 changes: 2 additions & 0 deletions src/configs.js
@@ -367,6 +367,7 @@ export class PretrainedConfig {
cache_dir = null,
local_files_only = false,
revision = 'main',
abort_signal = undefined,
} = {}) {
if (config && !(config instanceof PretrainedConfig)) {
config = new PretrainedConfig(config);
@@ -378,6 +379,7 @@ export class PretrainedConfig {
cache_dir,
local_files_only,
revision,
abort_signal,
})
return new this(data);
}
4 changes: 4 additions & 0 deletions src/models.js
@@ -980,6 +980,7 @@ export class PreTrainedModel extends Callable {
local_files_only = false,
revision = 'main',
model_file_name = null,
abort_signal = undefined,
subfolder = 'onnx',
device = null,
dtype = null,
@@ -999,6 +1000,7 @@ export class PreTrainedModel extends Callable {
dtype,
use_external_data_format,
session_options,
abort_signal,
}

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
@@ -6999,6 +7001,7 @@ export class PretrainedMixin {
dtype = null,
use_external_data_format = null,
session_options = {},
abort_signal = undefined,
} = {}) {

const options = {
Expand All @@ -7013,6 +7016,7 @@ export class PretrainedMixin {
dtype,
use_external_data_format,
session_options,
abort_signal,
}
options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

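The same option is accepted when loading models directly rather than through `pipeline()`. A minimal sketch, assuming the `PretrainedMixin` changes above land as written:

```js
import { AutoModel } from '@huggingface/transformers';

const controller = new AbortController();

// One signal covers both the config fetch and the ONNX weight downloads,
// since from_pretrained forwards `abort_signal` to AutoConfig as well.
const model = await AutoModel.from_pretrained('Xenova/all-MiniLM-L6-v2', {
    abort_signal: controller.signal,
});
```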
3 changes: 2 additions & 1 deletion src/models/janus/processing_janus.js
@@ -18,6 +18,7 @@ export class VLChatProcessor extends Processor {
this.image_start_tag = this.config.image_start_tag;
this.image_end_tag = this.config.image_end_tag;
this.num_image_tokens = this.config.num_image_tokens;
this.abort_signal = this.config.abort_signal;
}

/**
@@ -50,7 +51,7 @@ export class VLChatProcessor extends Processor {
conversation
.filter((msg) => msg.images)
.flatMap((msg) => msg.images)
.map((img) => RawImage.read(img))
.map((img) => RawImage.read(img, this.abort_signal))
);
} else if (!Array.isArray(images)) {
images = [images];
8 changes: 4 additions & 4 deletions src/models/mgp_str/processing_mgp_str.js
@@ -143,12 +143,12 @@ export class MgpstrProcessor extends Processor {
}
}
/** @type {typeof Processor.from_pretrained} */
static async from_pretrained(...args) {
const base = await super.from_pretrained(...args);
static async from_pretrained(pretrained_model_name_or_path, options) {
const base = await super.from_pretrained(pretrained_model_name_or_path, options);

// Load Transformers.js-compatible versions of the BPE and WordPiece tokenizers
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2") // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased") // google-bert/bert-base-uncased
const bpe_tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2", { abort_signal: options?.abort_signal }) // openai-community/gpt2
const wp_tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased", { abort_signal: options?.abort_signal }) // google-bert/bert-base-uncased

// Update components
base.components = {
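The `MgpstrProcessor` change shows the pattern for composite loaders: take the options object explicitly instead of `...args`, and forward the signal to every nested `from_pretrained` so that one `abort()` cancels all sub-downloads together. A generic sketch of that shape (the class and its components here are hypothetical):

```js
import { AutoTokenizer, AutoProcessor } from '@huggingface/transformers';

// Hypothetical composite processor following the MgpstrProcessor pattern.
class CompositeProcessor {
    constructor(components) {
        Object.assign(this, components);
    }

    static async from_pretrained(model_id, options) {
        const abort_signal = options?.abort_signal;
        // Forward the caller's signal to each nested loader.
        const [tokenizer, processor] = await Promise.all([
            AutoTokenizer.from_pretrained(model_id, { abort_signal }),
            AutoProcessor.from_pretrained(model_id, { abort_signal }),
        ]);
        return new CompositeProcessor({ tokenizer, processor });
    }
}
```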
53 changes: 31 additions & 22 deletions src/pipelines.js
@@ -85,16 +85,17 @@ import { RawImage } from './utils/image.js';
/**
* Prepare images for further tasks.
* @param {ImagePipelineInputs} images images to prepare.
* @param {AbortSignal} [abort_signal] An optional AbortSignal used to cancel the underlying image fetches.
* @returns {Promise<RawImage[]>} returns processed images.
* @private
*/
async function prepareImages(images) {
async function prepareImages(images, abort_signal) {
if (!Array.isArray(images)) {
images = [images];
}

// Possibly convert any non-images to images
return await Promise.all(images.map(x => RawImage.read(x)));
return await Promise.all(images.map(x => RawImage.read(x, abort_signal)));
}

/**
@@ -106,17 +107,18 @@ async function prepareImages(images) {
* Prepare audios for further tasks.
* @param {AudioPipelineInputs} audios audios to prepare.
* @param {number} sampling_rate sampling rate of the audios.
* @param {AbortSignal} [abort_signal] An optional AbortSignal used to cancel the underlying audio fetches.
* @returns {Promise<Float32Array[]>} The preprocessed audio data.
* @private
*/
async function prepareAudios(audios, sampling_rate) {
async function prepareAudios(audios, sampling_rate, abort_signal) {
if (!Array.isArray(audios)) {
audios = [audios];
}

return await Promise.all(audios.map(x => {
if (typeof x === 'string' || x instanceof URL) {
return read_audio(x, sampling_rate);
return read_audio(x, sampling_rate, abort_signal);
} else if (x instanceof Float64Array) {
return new Float32Array(x);
}
@@ -169,13 +171,15 @@ export class Pipeline extends Callable {
* @param {PreTrainedModel} [options.model] The model used by the pipeline.
* @param {PreTrainedTokenizer} [options.tokenizer=null] The tokenizer used by the pipeline (if any).
* @param {Processor} [options.processor=null] The processor used by the pipeline (if any).
* @param {AbortSignal} [options.abort_signal=undefined] An optional AbortSignal used to cancel network requests made by the pipeline.
*/
constructor({ task, model, tokenizer = null, processor = null }) {
constructor({ task, model, tokenizer = null, processor = null, abort_signal = undefined }) {
super();
this.task = task;
this.model = model;
this.tokenizer = tokenizer;
this.processor = processor;
this.abort_signal = abort_signal;
}

/** @type {DisposeType} */
@@ -198,6 +202,7 @@
* @property {string} task The task of the pipeline. Useful for specifying subtasks.
* @property {PreTrainedModel} model The model used by the pipeline.
* @property {Processor} processor The processor used by the pipeline.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal used to cancel network requests made by the pipeline.
*
* @typedef {ModelProcessorConstructorArgs} AudioPipelineConstructorArgs An object used to instantiate an audio-based pipeline.
* @typedef {ModelProcessorConstructorArgs} ImagePipelineConstructorArgs An object used to instantiate an image-based pipeline.
@@ -210,6 +215,7 @@
* @property {PreTrainedModel} model The model used by the pipeline.
* @property {PreTrainedTokenizer} tokenizer The tokenizer used by the pipeline.
* @property {Processor} processor The processor used by the pipeline.
* @property {AbortSignal} [abort_signal=undefined] An optional AbortSignal used to cancel network requests made by the pipeline.
*
* @typedef {ModelTokenizerProcessorConstructorArgs} TextAudioPipelineConstructorArgs An object used to instantiate a text- and audio-based pipeline.
* @typedef {ModelTokenizerProcessorConstructorArgs} TextImagePipelineConstructorArgs An object used to instantiate a text- and image-based pipeline.
@@ -1401,7 +1407,7 @@ export class ImageFeatureExtractionPipeline extends (/** @type {new (options: Im
pool = null,
} = {}) {

const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);
const { pixel_values } = await this.processor(preparedImages);
const outputs = await this.model({ pixel_values });

@@ -1491,7 +1497,7 @@ export class AudioClassificationPipeline extends (/** @type {new (options: Audio
} = {}) {

const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
const preparedAudios = await prepareAudios(audio, sampling_rate, this.abort_signal);

// @ts-expect-error TS2339
const id2label = this.model.config.id2label;
@@ -1593,7 +1599,7 @@ export class ZeroShotAudioClassificationPipeline extends (/** @type {new (option
});

const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
const preparedAudios = await prepareAudios(audio, sampling_rate, this.abort_signal);

const toReturn = [];
for (const aud of preparedAudios) {
@@ -1764,7 +1770,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
}

const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
const preparedAudios = await prepareAudios(audio, sampling_rate, this.abort_signal);

const toReturn = [];
for (const aud of preparedAudios) {
@@ -1809,7 +1815,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
const hop_length = this.processor.feature_extractor.config.hop_length;

const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
const preparedAudios = await prepareAudios(audio, sampling_rate, this.abort_signal);

const toReturn = [];
for (const aud of preparedAudios) {
@@ -1906,7 +1912,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options
audio = [/** @type {AudioInput} */ (audio)];
}
const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
const preparedAudios = await prepareAudios(audio, sampling_rate);
const preparedAudios = await prepareAudios(audio, sampling_rate, this.abort_signal);
const toReturn = [];
for (const aud of preparedAudios) {
const inputs = await this.processor(aud);
@@ -1971,7 +1977,7 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe
async _call(images, generate_kwargs = {}) {

const isBatched = Array.isArray(images);
const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

const { pixel_values } = await this.processor(preparedImages);

@@ -2061,7 +2067,7 @@ export class ImageClassificationPipeline extends (/** @type {new (options: Image
top_k = 5
} = {}) {

const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

const { pixel_values } = await this.processor(preparedImages);
const output = await this.model({ pixel_values });
@@ -2162,7 +2168,7 @@ export class ImageSegmentationPipeline extends (/** @type {new (options: ImagePi
throw Error("Image segmentation pipeline currently only supports a batch size of 1.");
}

const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);
const imageSizes = preparedImages.map(x => [x.height, x.width]);

const { pixel_values, pixel_mask } = await this.processor(preparedImages);
@@ -2292,7 +2298,7 @@ export class ZeroShotImageClassificationPipeline extends (/** @type {new (option
} = {}) {

const isBatched = Array.isArray(images);
const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

// Insert label into hypothesis template
const texts = candidate_labels.map(
@@ -2397,7 +2403,7 @@ export class ObjectDetectionPipeline extends (/** @type {new (options: ImagePipe
if (isBatched && images.length !== 1) {
throw Error("Object detection pipeline currently only supports a batch size of 1.");
}
const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

const imageSizes = percentage ? null : preparedImages.map(x => [x.height, x.width]);

@@ -2530,7 +2536,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
} = {}) {

const isBatched = Array.isArray(images);
const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

// Run tokenization
const text_inputs = this.tokenizer(candidate_labels, {
@@ -2636,7 +2642,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
// NOTE: For now, we only support a batch size of 1

// Preprocess image
const preparedImage = (await prepareImages(image))[0];
const preparedImage = (await prepareImages(image, this.abort_signal))[0];
const { pixel_values } = await this.processor(preparedImage);

// Run tokenization
@@ -2776,17 +2782,18 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi

async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {

// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' });
this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32', abort_signal: this.abort_signal });
}

// Load speaker embeddings as Float32Array from path/URL
if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) {
// Load from URL with fetch
speaker_embeddings = new Float32Array(
await (await fetch(speaker_embeddings)).arrayBuffer()
await (await fetch(speaker_embeddings, { signal: this.abort_signal })).arrayBuffer()
);
}

@@ -2854,7 +2861,7 @@ export class ImageToImagePipeline extends (/** @type {new (options: ImagePipelin
/** @type {ImageToImagePipelineCallback} */
async _call(images) {

const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);
const inputs = await this.processor(preparedImages);
const outputs = await this.model(inputs);

@@ -2917,7 +2924,7 @@ export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipe
/** @type {DepthEstimationPipelineCallback} */
async _call(images) {

const preparedImages = await prepareImages(images);
const preparedImages = await prepareImages(images, this.abort_signal);

const inputs = await this.processor(preparedImages);
const { predicted_depth } = await this.model(inputs);
@@ -3301,6 +3308,7 @@ export async function pipeline(
dtype = null,
model_file_name = null,
session_options = {},
abort_signal = undefined,
} = {}
) {
// Helper method to construct pipeline
@@ -3331,6 +3339,7 @@
dtype,
model_file_name,
session_options,
abort_signal,
}

const classes = new Map([
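Because the `Pipeline` constructor stores the signal on the instance, the same signal also covers call-time input fetches (`prepareImages`, `prepareAudios`, and the speaker-embedding `fetch` above), not just model loading. A sketch, assuming `classifier` was constructed with `{ abort_signal: controller.signal }` as in the earlier example:

```js
// Start a call whose input is a remote image; prepareImages() passes
// this.abort_signal down to RawImage.read().
const pending = classifier('https://example.com/large-image.png');

// Aborting now should cancel the in-flight image download and reject
// the call with an AbortError.
controller.abort();

await pending.catch((err) => console.log(err.name)); // 'AbortError'
```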
4 changes: 4 additions & 0 deletions src/tokenizers.js
@@ -2682,6 +2682,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const info = await loadTokenizer(pretrained_model_name_or_path, {
@@ -2691,6 +2692,7 @@
local_files_only,
revision,
legacy,
abort_signal,
})

// @ts-ignore
@@ -4351,6 +4353,7 @@ export class AutoTokenizer {
local_files_only = false,
revision = 'main',
legacy = null,
abort_signal = undefined,
} = {}) {

const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
@@ -4360,6 +4363,7 @@
local_files_only,
revision,
legacy,
abort_signal,
})

// Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
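Both `PreTrainedTokenizer.from_pretrained` and `AutoTokenizer.from_pretrained` route the signal into the same `loadTokenizer` call, so an already-aborted signal should fail fast without fetching anything. A sketch under that assumption:

```js
import { AutoTokenizer } from '@huggingface/transformers';

const controller = new AbortController();
controller.abort(); // Abort before the request even starts.

try {
    await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased', {
        abort_signal: controller.signal,
    });
} catch (err) {
    console.log(err.name); // Expected: 'AbortError'
}
```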
5 changes: 3 additions & 2 deletions src/utils/audio.js
@@ -23,9 +23,10 @@ import { Tensor, matmul } from './tensor.js';
* Helper function to read audio from a path/URL.
* @param {string|URL} url The path/URL to load the audio from.
* @param {number} sampling_rate The sampling rate to use when decoding the audio.
* @param {AbortSignal} [abort_signal] An optional AbortSignal used to cancel the audio fetch.
* @returns {Promise<Float32Array>} The decoded audio as a `Float32Array`.
*/
export async function read_audio(url, sampling_rate) {
export async function read_audio(url, sampling_rate, abort_signal) {
if (typeof AudioContext === 'undefined') {
// Running in node or an environment without AudioContext
throw Error(
@@ -35,7 +36,7 @@ export async function read_audio(url, sampling_rate) {
)
}

const response = await (await getFile(url)).arrayBuffer();
const response = await (await getFile(url, abort_signal)).arrayBuffer();
const audioCTX = new AudioContext({ sampleRate: sampling_rate });
if (typeof sampling_rate === 'undefined') {
console.warn(`No sampling rate provided, using default of ${audioCTX.sampleRate}Hz.`)
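The standalone `read_audio` helper gains the same optional third argument, which is forwarded to `getFile` to cancel the download. A browser-only sketch (the helper requires `AudioContext`; the sample URL is the one used in the transformers.js docs):

```js
import { read_audio } from '@huggingface/transformers';

const controller = new AbortController();

// Cancel the download if the file has not arrived within 10 seconds.
const timeout = setTimeout(() => controller.abort(), 10_000);

const audio = await read_audio(
    'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav',
    16000,
    controller.signal,
);
clearTimeout(timeout);
console.log(audio.length); // Number of decoded Float32Array samples.
```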