
feat: Add AbortSignal support across library components #1193

Draft · sroussey wants to merge 2 commits into main
Conversation

@sroussey (Contributor)

  • Introduce abort_signal parameter to multiple methods and constructors
  • Update file retrieval and loading mechanisms to support request cancellation
  • Add AbortSignal handling in pipelines, models, tokenizers, and utility functions

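On the download side, the core of the change is forwarding the signal into `fetch()`. A minimal sketch, assuming a hypothetical `loadFile` helper and the PR's `abort_signal` option name (this is illustrative wiring, not the library's actual code):

```javascript
// Hypothetical sketch of forwarding an `abort_signal` option into fetch().
// `loadFile` is an illustrative name, not the library's actual API.
async function loadFile(url, options = {}) {
    // fetch() natively supports cancellation via its `signal` option.
    const response = await fetch(url, { signal: options.abort_signal });
    if (!response.ok) {
        throw new Error(`Failed to fetch ${url}: ${response.status}`);
    }
    return new Uint8Array(await response.arrayBuffer());
}
```

If the signal has already fired, `fetch()` rejects with an `AbortError` before any network request is made.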
@sroussey (Contributor, Author) left a comment

@xenova -- thoughts?

@emojiiii (Contributor)

Looks like the same as #1190

@sroussey (Contributor, Author)

> Looks like the same as #1190

I will have a look

@sroussey (Contributor, Author)

@xenova, a few thoughts on what I did:

I added `abort_signal` to `PretrainedOptions`. Most places simply pass the options through, but many rebuild them, so you will see more changes than expected.
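The rebuilding problem can be sketched like this (the option shape and helper name are assumptions, not the library's actual `PretrainedOptions`): every place that reconstructs an options object has to carry `abort_signal` along explicitly, or the signal is silently dropped at that layer.

```javascript
// Hypothetical helper: rebuild an options object while preserving abort_signal.
function rebuildOptions(options, overrides = {}) {
    return {
        ...options,
        ...overrides,
        // The spread preserves abort_signal when `overrides` omits the key,
        // but an explicit `abort_signal: undefined` would clobber it; this
        // line keeps the original signal in that case too.
        abort_signal: overrides.abort_signal ?? options.abort_signal,
    };
}
```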

I also added it to the `ModelTokenizerProcessorConstructorArgs` constructor, and there are edge cases like `_call_text_to_spectrogram`. Not sure I like this, though. :/

So far, this PR only handles downloading models, etc.

Do you know where I should look in ONNX/ORT to cancel generation?

@sroussey (Contributor, Author)

> > Looks like the same as #1190
>
> I will have a look

I can see reasons to have a custom `fetch()`, but I think those are orthogonal to the ability to abort.

But perhaps we could have both.

I want to be able to abort generation as well as downloads, in which case `fetch` alone is not enough.
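"Both" could look something like this sketch, where a user-supplied fetch replacement still receives the signal. Both option names here, `custom_fetch` and `abort_signal`, are hypothetical:

```javascript
// Hypothetical: combine a custom fetch() implementation with abort support.
// Whatever fetch is used, the signal is passed through, so downloads stay
// abortable either way.
async function getFile(url, { custom_fetch = fetch, abort_signal = null } = {}) {
    return custom_fetch(url, { signal: abort_signal });
}
```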

@sroussey (Contributor, Author)

For reference: microsoft/onnxruntime#23703

@sroussey (Contributor, Author)

Also, I made `abort_signal` required in various places... that is to help me find the places I have missed. I will make it optional when done.

- Add optional `abort_signal` parameter to `read_audio()` function
- Update `RawImage.read()` and `RawImage.fromURL()` to support AbortSignal
- Modify `VLChatProcessor` to include `abort_signal` configuration
- Extend image and audio preparation methods in pipelines to pass AbortSignal
@joecal

joecal commented Feb 15, 2025

> > > Looks like the same as #1190
> >
> > I will have a look
>
> I can see reasons to have a custom fetch() but I think those are orthogonal to the ability to abort.
>
> But perhaps we could have both.
>
> I want to be able to abort generation as well as downloads, in which case fetch is not enough.

I've been trying to figure out a way to stop text generation, and so far I've come up with the following:

Something like this `abortOnSignal` helper function to wrap async operations so they can be aborted:

```javascript
/**
 * A helper function to wrap asynchronous operations so they can be stopped/aborted by an AbortSignal.
 * Note: an abort rejects the wrapper promise; it does not cancel the underlying operation itself.
 * @param {Promise<unknown>} promise Any promise/async operation to be aborted.
 * @param {AbortSignal | null | undefined} signal The abort signal from an AbortController that can cancel the wrapped promise/async operation.
 * @returns {Promise<unknown>} Either the param promise or a rejected promise with AbortError if aborted.
 */
export function abortOnSignal(promise, signal) {
    if (!signal) {
        return promise;
    }
    if (signal.aborted) {
        // An 'abort' listener added after the signal has already fired never
        // runs, so the already-aborted case must be handled up front.
        return Promise.reject(new DOMException('Aborted', 'AbortError'));
    }
    return new Promise((resolve, reject) => {
        const abortHandler = () => {
            reject(new DOMException('Aborted', 'AbortError'));
        };
        signal.addEventListener('abort', abortHandler, { once: true });
        promise.then(resolve, reject)
            .finally(() => {
                signal.removeEventListener('abort', abortHandler);
            });
    });
}
```

Then add an abort-signal parameter to the `sessionRun` function in `models.js`, and wrap the `session.run` invocation with `abortOnSignal`, passing the signal through:

```javascript
/**
 * Executes an InferenceSession using the specified inputs.
 * NOTE: `inputs` must contain at least the input names of the model.
 *  - If additional inputs are passed, they will be ignored.
 *  - If inputs are missing, an error will be thrown.
 *
 * @param {Object} session The InferenceSession object to run.
 * @param {Object} inputs An object that maps input names to input tensors.
 * @param {AbortSignal | null | undefined} signal Optional abort signal from an AbortController to cancel the InferenceSession.
 * @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
 * @private
 */
async function sessionRun(session, inputs, signal = null) {
    const checkedInputs = validateInputs(session, inputs);
    try {
        // pass the original ort tensor
        const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
        let output = await abortOnSignal(session.run(ortFeed), signal);
        output = replaceTensors(output);
        return output;
    } catch (e) {
        if (e.name === 'AbortError') {
            // Not a model error: rethrow immediately so the abort isn't
            // reported with the misleading input dump below.
            console.log('Generation aborted by user.');
            throw e;
        }
        // Error messages can be long (nested) and uninformative. For this reason,
        // we apply minor formatting to show the most important information
        const formatted = Object.fromEntries(Object.entries(checkedInputs)
            .map(([k, { type, dims, data }]) => [k, {
                // Extract these properties from the underlying ORT tensor
                type, dims, data,
            }]));

        // This usually occurs when the inputs are of the wrong type.
        console.error(`An error occurred during model execution: "${e}".`);
        console.error('Inputs given to model:', formatted);
        throw e;
    }
}
```

To use it, pass the signal into the generator as an argument:

```javascript
const prompt = "How many letter 'r's are in the word 'strawberry'?";
const controller = new AbortController();
const generator = await pipeline('text-generation', MODEL, {
    dtype: 'q4f16',
    device: 'webgpu',
});
// maxTokens and streamer are defined elsewhere in the caller's code.
const generationArgs = { max_new_tokens: maxTokens, streamer: streamer, signal: controller.signal };
await generator(prompt, generationArgs);
// Later, e.g. from a "Stop" button handler: controller.abort();
```

Here's a little demo of it working (ignore the text-output nonsense; I think my laptop processor is messed up somehow):

[video attachment: 0214.mp4]

It works at stopping the generation, but the way I have it now feels kind of messy, passing the signal around so much.

Take `TextGenerationPipeline`, for example. The signal gets passed in through its `_call` args into `this.model.generate`, then into `this.forward`, then into `this._forward` (which is `decoderForward`), then into `decoderForward`'s invocation of `sessionRun`, and finally into the `abortOnSignal` wrapper around `session.run`.

I'm sure there's a cleaner way to do it. Would love to get your take on this approach.
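Condensed, the threading described above looks something like this; every function here is a stand-in for the real method, just to show how the signal has to pass through each layer:

```javascript
// Stand-ins for the real pipeline/model methods, showing how the signal
// must be threaded through every layer to reach session.run().
async function sessionRunStub(inputs, signal) {
    if (signal?.aborted) throw new DOMException('Aborted', 'AbortError');
    return inputs; // stand-in for the model outputs
}
async function decoderForwardStub(inputs, signal) { return sessionRunStub(inputs, signal); }
async function forwardStub(inputs, signal) { return decoderForwardStub(inputs, signal); }
async function generateStub(inputs, signal) { return forwardStub(inputs, signal); }
async function callStub(inputs, { signal } = {}) { return generateStub(inputs, signal); }
```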

@sroussey (Contributor, Author)

I haven't tried it myself yet, but it is definitely undocumented. I don't see signal or anything in the docs.

https://onnxruntime.ai/docs/api/js/interfaces/InferenceSession.SessionOptions.html#freeDimensionOverrides

@sroussey (Contributor, Author)

sroussey commented Feb 15, 2025

Oh, I see now... you are simulating an abort. The text generation keeps going though.

If you are using progress callbacks, you can just ignore them (after the user hits stop) to the same effect.
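The callback-ignoring approach can be sketched like this (`guardedCallback` is a made-up helper name, not a library API):

```javascript
// Wrap a progress/streamer callback so updates are dropped once the
// user's AbortController has fired.
function guardedCallback(callback, signal) {
    return (...args) => {
        if (signal?.aborted) return; // ignore updates after "stop"
        callback(...args);
    };
}
```

Note this only hides further output; the underlying generation keeps running.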

@joecal

joecal commented Feb 15, 2025

> Oh, I see now... you are simulating an abort. The text generation keeps going though.
>
> If you are using progress callbacks, you can just ignore them (after the user hits stop) to the same effect.

Oh, the text generation actually does stop completely when the abort signal is triggered. I did some testing with Chrome DevTools, and the CPU usage drops to 0% when aborted, showing that the generation process fully terminates.

The way it works is:

  1. When the abort signal triggers, the `abortOnSignal` wrapper rejects the promise wrapping the ONNX `session.run()` with an `AbortError`
  2. This error bubbles up through the generation loop in `generate()`, stopping the entire process
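A tiny self-contained model of that bubbling (`fakeGenerate` is a stand-in for the real generation loop, with one iteration per `sessionRun` step):

```javascript
// Each iteration stands in for one sessionRun() step; once the signal has
// fired, the next check throws and the whole loop unwinds.
async function fakeGenerate(steps, signal) {
    const tokens = [];
    for (let i = 0; i < steps; i++) {
        if (signal?.aborted) throw new DOMException('Aborted', 'AbortError');
        tokens.push(i); // stand-in for one generated token
        await Promise.resolve(); // yield, like awaiting session.run()
    }
    return tokens;
}
```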

But yeah, I'm not super happy about having to pass the signal through so many layers. I was trying to follow the pattern used by the Fetch API, where you need to get the abort signal all the way down to the actual async operation being aborted. It would be cool if, in the future, the `InferenceSession` could support the abort signal internally.

But maybe there's a cleaner way to structure this. For example, we could store the signal at the pipeline level instead of threading it through all those args. Let me know if you have any other ideas for improving it.
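The pipeline-level idea could look roughly like this sketch; the class and field names are made up, not the library's actual types:

```javascript
// Hypothetical sketch of storing the signal at the pipeline level: the
// pipeline keeps the controller's signal, and lower layers read it from
// `this` instead of taking it as a parameter on every call.
class AbortablePipeline {
    constructor(model, { abort_signal = null } = {}) {
        this.model = model;
        this.abort_signal = abort_signal;
    }
    async run(inputs) {
        // Any layer with access to the pipeline can perform this check,
        // so the signal no longer has to be threaded through every arg.
        if (this.abort_signal?.aborted) {
            throw new DOMException('Aborted', 'AbortError');
        }
        return this.model.run(inputs);
    }
}
```

The trade-off is that one signal then applies to the pipeline instance as a whole rather than to an individual call.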

@joecal

joecal commented Feb 17, 2025

> It seems like you're looking for https://github.com/huggingface/transformers.js/blob/main/src/models.js#L1573, you can refer to https://github.com/huggingface/transformers.js-examples/blob/main/deepseek-r1-webgpu/src/worker.js#L114.

I've been using `EosTokenCriteria` and found it works great at stopping text generation at known, specific token outputs. However, when I encountered the funny-looking text-output issue in the video I posted above and then saw these abort-signal PRs, it got me thinking: it would be cool if we could use an abort signal to stop any long-running async task in the entire pipeline, whether the task is downloading with fetch, token generation, or any other time-consuming asynchronous task. What do you all think?
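Building on the stopping-criteria mechanism linked above, an abort signal could plug in as a criteria. This sketch includes a minimal stand-in for the base class so it runs standalone; the real transformers.js `StoppingCriteria` interface may differ:

```javascript
// Minimal stand-in for the library's StoppingCriteria base class, included
// only so this sketch is self-contained.
class StoppingCriteria {
    _call(input_ids, scores) { throw new Error('not implemented'); }
}

// Stops every sequence in the batch once the AbortSignal fires, letting
// the existing stopping-criteria machinery unwind the generation loop.
class SignalStoppingCriteria extends StoppingCriteria {
    constructor(signal) {
        super();
        this.signal = signal;
    }
    _call(input_ids, scores) {
        const stop = this.signal?.aborted ?? false;
        return input_ids.map(() => stop);
    }
}
```

Compared to rejecting mid-`session.run`, this stops cleanly at token boundaries without threading the signal through every layer.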
