feat: Add AbortSignal support across library components #1193
base: main
- Introduce `abort_signal` parameter to multiple methods and constructors
- Update file retrieval and loading mechanisms to support request cancellation
- Add AbortSignal handling in pipelines, models, tokenizers, and utility functions
@xenova -- thoughts?
Looks like the same as #1190
I will have a look
@xenova a few thoughts on what I did: I added abort_signal to PretrainedOptions. Most places just pass the options along, but many rebuild them, so you will see more changes than expected. I also added it to the ModelTokenizerProcessorConstructorArgs constructor, and there are edge cases like _call_text_to_spectrogram. I'm not sure I like this, though. :/ So far this PR only handles downloads (models, etc.). Do you know where I should look in ONNX/ORT to cancel generation?
I can see reasons to have a custom fetch(), but I think those are orthogonal to the ability to abort. Perhaps we could have both. I want to be able to abort generation as well as downloads, in which case fetch alone is not enough.
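To illustrate why a custom fetch() and an abort signal can coexist, here is a minimal sketch. `withAbortSignal` and `authedFetch` are hypothetical names made up for this example and are not part of the library; the only real API assumed is that `fetch` accepts a `signal` in its init object.

```javascript
// Hypothetical helper, not part of the library: wraps any fetch-compatible
// function so that every request it makes carries the same AbortSignal.
function withAbortSignal(customFetch, signal) {
  return (url, init = {}) => customFetch(url, { ...init, signal });
}

// Usage sketch: a caller-supplied fetch (here one that adds an example header)
// stays abortable without knowing anything about the signal.
const controller = new AbortController();
const authedFetch = (url, init = {}) =>
  fetch(url, { ...init, headers: { ...init.headers, 'X-Example': '1' } });
const abortableFetch = withAbortSignal(authedFetch, controller.signal);
// Calling controller.abort() rejects any in-flight abortableFetch(...) request
// with an AbortError.
```

This only covers network requests. Cancelling generation still needs the signal to reach the inference loop, which is the part fetch cannot help with.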
For reference: microsoft/onnxruntime#23703
Also, I made
- Add optional `abort_signal` parameter to `read_audio()` function
- Update `RawImage.read()` and `RawImage.fromURL()` to support AbortSignal
- Modify `VLChatProcessor` to include `abort_signal` configuration
- Extend image and audio preparation methods in pipelines to pass AbortSignal
I’ve been trying to figure out a way to stop text generation, and so far I’ve come up with the following. Something like this:
/**
* A helper function to wrap asynchronous operations that can be stopped/aborted by an AbortSignal.
* @param {Promise<unknown>} promise Any promise/async operation to be aborted.
* @param {AbortSignal | null | undefined} signal The abort signal from an AbortController that can cancel the wrapped promise/async operation.
* @returns {Promise<unknown>} Either the param promise or a rejected promise with AbortError if aborted.
*/
export function abortOnSignal(promise, signal) {
    if (!signal) {
        return promise;
    }
    return new Promise((resolve, reject) => {
        const abortHandler = () => {
            reject(new DOMException('Aborted', 'AbortError'));
        };
        signal.addEventListener('abort', abortHandler, { once: true });
        promise.then(resolve, reject)
            .finally(() => {
                signal.removeEventListener('abort', abortHandler);
            });
    });
}
Then add an abort
/**
* Executes an InferenceSession using the specified inputs.
* NOTE: `inputs` must contain at least the input names of the model.
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @param {AbortSignal | null | undefined} signal Optional abort signal from an AbortController to cancel the InferenceSession.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
*/
async function sessionRun(session, inputs, signal = null) {
    const checkedInputs = validateInputs(session, inputs);
    try {
        // pass the original ort tensor
        const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
        let output = await abortOnSignal(session.run(ortFeed), signal);
        output = replaceTensors(output);
        return output;
    } catch (e) {
        if (e.name === 'AbortError') {
            // An abort is not a model error, so rethrow without dumping the inputs.
            console.log("Generation aborted by user.");
            throw e;
        }
        // Error messages can be long (nested) and uninformative. For this reason,
        // we apply minor formatting to show the most important information
        const formatted = Object.fromEntries(Object.entries(checkedInputs)
            .map(([k, { type, dims, data }]) => [k, {
                // Extract these properties from the underlying ORT tensor
                type, dims, data,
            }]));
        // This usually occurs when the inputs are of the wrong type.
        console.error(`An error occurred during model execution: "${e}".`);
        console.error('Inputs given to model:', formatted);
        throw e;
    }
}
To use it you pass the signal into the generator as an argument:
import { pipeline } from '@huggingface/transformers';

const prompt = "How many letter 'r's are in the word 'strawberry'?";
const controller = new AbortController();
const generator = await pipeline('text-generation', MODEL, {
    dtype: 'q4f16',
    device: 'webgpu'
});
// `signal` is only honored with the sessionRun change above; it is not a documented generation option.
const generationArgs = { max_new_tokens: maxTokens, streamer: streamer, signal: controller.signal };
await generator(prompt, generationArgs);
Here’s a little demo of it working (ignore the text output nonsense; I think my laptop processor is messed up somehow): 0214.mp4
It works at stopping the generation, but the way I have it now feels kinda messy passing the Take I'm sure there's a cleaner way to do it. Would love to get your take on this approach.
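One edge case in the helper above: if the signal is already aborted before the call, the 'abort' event never fires again, so the wrapper would wait for the promise as if nothing had happened. Below is a minimal sketch of a variant that checks `signal.aborted` first. The name `abortOnSignalChecked` is hypothetical and used only for this example.

```javascript
// Hypothetical variant of abortOnSignal that also handles a signal that was
// aborted before the call. The name is illustrative only.
function abortOnSignalChecked(promise, signal) {
  if (!signal) {
    return promise;
  }
  const makeAbortError = () => new DOMException('Aborted', 'AbortError');
  if (signal.aborted) {
    // The 'abort' event has already fired and will not fire again.
    // Swallow a later rejection of the original promise so it does not
    // surface as an unhandled rejection.
    promise.catch(() => {});
    return Promise.reject(makeAbortError());
  }
  return new Promise((resolve, reject) => {
    const onAbort = () => reject(makeAbortError());
    signal.addEventListener('abort', onAbort, { once: true });
    promise.then(resolve, reject).finally(() => {
      signal.removeEventListener('abort', onAbort);
    });
  });
}
```

In both versions the wrapped promise itself is not cancelled; the caller only stops waiting for it. Any work already handed to the runtime for that step may still finish in the background.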
I haven't tried it myself yet, but it is definitely undocumented. I don't see signal or anything in the docs.
Oh, I see now... you are simulating an abort. The text generation keeps going though. If you are using progress callbacks, you can just ignore them (after the user hits stop) to the same effect.
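For comparison, the "just ignore the callbacks" approach could be sketched like this. `ignoreAfterAbort` is a hypothetical helper written for this example, and `callback` stands in for whatever streaming or progress callback the caller already has.

```javascript
// Hypothetical helper: wraps a streaming/progress callback so that it becomes
// a no-op once the signal has been aborted.
function ignoreAfterAbort(callback, signal) {
  return (...args) => {
    if (signal.aborted) {
      return; // drop updates that arrive after the user hit stop
    }
    callback(...args);
  };
}
```

Note that this only hides the output. The underlying work keeps running until it finishes on its own, which is the difference from rejecting out of the generation loop.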
Oh, the text generation actually does stop completely when the abort signal is triggered. I did some testing with Chrome dev tools, and CPU usage drops to 0% when aborted, which shows the generation process fully terminates. The way it works is:
But yeah, I'm not super happy with having to pass the signal through so many layers. I was trying to follow the pattern used by the Fetch API, where you need to get the abort signal all the way down to the actual async operation being aborted. Would be cool if in the future the
But maybe there's a cleaner way to structure this. Like maybe we could store the signal at the pipeline level instead of threading it through all those args. Let me know if you have any other ideas for improving it.
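One way the "store the signal at the pipeline level" idea could look, as a rough sketch: `bindSignal` is a hypothetical wrapper, and `generator` stands for any async callable taking `(input, options)`. It assumes the callee honors `options.signal`, as the patched sessionRun path above does; that is not documented library behavior.

```javascript
// Hypothetical wrapper: binds one AbortSignal to a generator-like callable so
// callers do not have to pass it with every call.
function bindSignal(generator, signal) {
  return async (input, options = {}) => {
    if (signal.aborted) {
      // Fail fast instead of starting work that would be aborted immediately.
      throw new DOMException('Aborted', 'AbortError');
    }
    return generator(input, { ...options, signal });
  };
}
```

The signal still has to travel down to the awaited operation internally, so this mainly tidies up the call site and leaves the plumbing underneath unchanged.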
It seems like you're looking for https://github.com/huggingface/transformers.js/blob/main/src/models.js#L1573,
I've been using the