Multimodal prefix caching #1209

Status: Closed. Wants to merge 1 commit.

16 changes: 8 additions & 8 deletions mistralrs-core/src/engine/mod.rs

@@ -702,13 +702,6 @@ impl Engine {
                 warn!("Prompt for request {} was {} tokens over the model maximum length. The last {} tokens were truncated to make space for generation.", request.id, currently_over, prompt_len - prompt_tokens.len());
             }
         }
-        let prefill_cache = handle_seq_error!(
-            self.prefix_cacher.search_for_matching_cache(
-                &prompt_tokens,
-                images.as_ref().is_some_and(|x| !x.is_empty())
-            ),
-            request.response
-        );

         let topk = request
             .sampling_params

@@ -926,7 +919,7 @@ impl Engine {
         let now = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .expect("Time travel has occurred!");
-        let seq = Sequence::new_waiting(
+        let mut seq = Sequence::new_waiting(
             prompt_tokens.clone(),
             prompt_text.clone(),
             self.id,

@@ -959,6 +952,13 @@
             seq_preallocated_cache,
             request.return_raw_logits,
         );
+
+        let prefill_cache = handle_seq_error!(
+            self.prefix_cacher
+                .search_for_matching_cache(&mut seq, &*get_mut_arcmutex!(self.pipeline)),
+            request.response
+        );
+
         let seq = if let Some(prefill_cache) = prefill_cache.clone() {
             seq.prefill_v2(
                 prefill_cache.normal,
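
Moving the lookup below Sequence::new_waiting means the cacher matches against the sequence's processed tokens rather than the raw prompt tokens, which is why the old images guard could be dropped. A self-contained toy sketch of why this ordering matters for multimodal prompts; the token ids and the expansion rule are invented for illustration and do not mirror any real tokenizer:

    // Toy: one "image" placeholder token expands into several embedding tokens
    // during input processing.
    fn expand_image_placeholders(toks: &[u32], placeholder: u32, expansion: &[u32]) -> Vec<u32> {
        let mut out = Vec::new();
        for &t in toks {
            if t == placeholder {
                out.extend_from_slice(expansion); // pre-processing grows the sequence
            } else {
                out.push(t);
            }
        }
        out
    }

    fn main() {
        let placeholder = 99;
        let expansion = [7, 7, 7, 7];
        let raw_prompt = vec![1, 2, placeholder, 3];

        // What a cached sequence's tokens look like after processing:
        let cached = expand_image_placeholders(&raw_prompt, placeholder, &expansion);

        // Matching on raw prompt tokens (the old call site) misses:
        assert!(!cached.starts_with(&raw_prompt));

        // Matching on processed tokens (the new call site) hits:
        let processed = expand_image_placeholders(&raw_prompt, placeholder, &expansion);
        assert!(cached.starts_with(&processed));
    }
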
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/inputs_processor.rs

@@ -45,6 +45,10 @@ pub trait InputsProcessor {
         mapper: Option<&dyn DeviceMapper>,
     ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>>;
 
+    fn supports_pre_processed_images(&self) -> bool {
+        true
+    }
+
     fn get_type(&self) -> InputsProcessorType;
 }
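
Because the new method has a default body returning true, every existing InputsProcessor implementation opts in automatically. A minimal standalone sketch of the same default-method pattern, with a toy trait standing in for InputsProcessor (the real trait also requires process_inputs and get_type):

    trait SupportsPreProcessed {
        // Default: the processor may safely be re-run on a sequence whose
        // images were already pre-processed, so prefix caching is allowed.
        fn supports_pre_processed_images(&self) -> bool {
            true
        }
    }

    struct ReRunnableProcessor;
    impl SupportsPreProcessed for ReRunnableProcessor {} // inherits the default

    struct OneShotProcessor;
    impl SupportsPreProcessed for OneShotProcessor {
        fn supports_pre_processed_images(&self) -> bool {
            false // opt out: the prefix cacher will not re-run input processing
        }
    }

    fn main() {
        assert!(ReRunnableProcessor.supports_pre_processed_images());
        assert!(!OneShotProcessor.supports_pre_processed_images());
    }
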

52 changes: 46 additions & 6 deletions mistralrs-core/src/prefix_cacher.rs

@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, ops::Range};
 
 use candle_core::{Device, Result};
 use either::Either;

@@ -8,6 +8,7 @@
 use crate::{
     pipeline::{KvCache, RotatingCache, SingleCache},
     sequence::Sequence,
+    Pipeline,
 };
 
 #[derive(PartialEq, Eq, Debug, Hash)]

@@ -36,6 +37,7 @@
 struct CacheElement {
     cache: Vec<Option<KvCache>>,
     devices: Vec<Option<Device>>,
+    mm_positions: Option<Vec<Range<usize>>>,
 }
 
 pub struct PrefixCacheManagerV2 {

CI annotations on line 40 (Clippy failure, plus warnings from the Check, Test Suite, and Docs jobs on ubuntu, macOS, and windows): field mm_positions is never read

@@ -65,7 +67,7 @@
 
     /// This always keeps the cache on the device.
     pub fn add_sequence(&mut self, seq: &mut Sequence) {
-        if self.no_prefix_cache || seq.has_images() {
+        if self.no_prefix_cache {
             return;
         }
         let cache = seq.normal_cache().to_vec();

@@ -75,7 +77,11 @@
             .collect::<Vec<_>>();
         self.caches.insert(
             seq.get_toks().to_vec().into(),
-            CacheElement { cache, devices },
+            CacheElement {
+                cache,
+                devices,
+                mm_positions: seq.get_mm_positions().cloned(),
+            },
         );
     }


@@ -217,10 +223,44 @@
     /// Search for a matching cache given some toks
     pub fn search_for_matching_cache(
         &mut self,
-        toks: &[u32],
-        contains_images: bool,
+        seq: &mut Sequence,
+        pipeline: &dyn Pipeline,
     ) -> Result<Option<MatchingCache>> {
-        if self.no_prefix_cache || toks.is_empty() || contains_images {
+        if pipeline
+            .get_processor()
+            .inputs_processor()
+            .supports_pre_processed_images()
+            && seq.has_images()
+        {
+            pipeline
+                .get_processor()
+                .inputs_processor()
+                .process_inputs(
+                    pipeline.tokenizer(),
+                    &mut [seq],
+                    true,
+                    pipeline.get_metadata().is_xlora,
+                    &pipeline.device(),
+                    pipeline.get_metadata().no_kv_cache,
+                    None,
+                    false,
+                    pipeline.get_input_processor_config(),
+                    None,
+                    pipeline.get_metadata().prompt_chunksize,
+                    pipeline.device_mapper(),
+                )
+                .into_iter()
+                .collect::<anyhow::Result<Vec<_>>>()
+                .map_err(candle_core::Error::msg)?;
+
+            let toks = seq.get_toks();
+
+            dbg!(toks.len(), seq.has_images());
+        }
+
+        let toks = seq.get_toks();
+
+        if self.no_prefix_cache || toks.is_empty() {
             return Ok(None);
         }

Check failure (GitHub Actions / Clippy) on line 235: useless conversion to the same type: std::boxed::Box<dyn std::iter::Iterator<Item = std::result::Result<pipeline::inputs_processor::InputProcessorOutput, anyhow::Error>>>
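
For orientation, the manager keys cached entries by full token vectors, so a lookup wants the longest cached entry that is a prefix of the incoming tokens. A toy sketch of that matching under those assumptions; the key type and the string payload are simplifications (the real CacheElement holds KV caches, devices, and now mm_positions):

    use std::collections::HashMap;

    // Return the longest cached token vector that is a prefix of `toks`.
    fn longest_cached_prefix<'a>(
        caches: &'a HashMap<Vec<u32>, &'static str>,
        toks: &[u32],
    ) -> Option<(&'a [u32], &'static str)> {
        caches
            .iter()
            .filter(|(k, _)| toks.starts_with(k))
            .max_by_key(|(k, _)| k.len())
            .map(|(k, v)| (k.as_slice(), *v))
    }

    fn main() {
        let mut caches = HashMap::new();
        caches.insert(vec![1, 2, 3], "short entry");
        caches.insert(vec![1, 2, 3, 4, 5], "longer entry");

        let incoming = [1, 2, 3, 4, 5, 6];
        let (key, entry) = longest_cached_prefix(&caches, &incoming).unwrap();
        assert_eq!(key.len(), 5);
        assert_eq!(entry, "longer entry"); // prefer the longest reusable prefix
    }
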

49 changes: 44 additions & 5 deletions mistralrs-core/src/sequence.rs

@@ -16,6 +16,7 @@ use crate::{
 use candle_core::Tensor;
 use std::{
     fmt::Display,
+    ops::Range,
     sync::{Arc, RwLock},
     time::{SystemTime, UNIX_EPOCH},
 };

@@ -210,6 +211,8 @@ pub struct Sequence {
     pub cached_pixel_values: Option<Tensor>,
     pub cached_img_thw: Option<Tensor>,
     pub cached_vid_thw: Option<Tensor>,
+    mm_positions: Option<Vec<Range<usize>>>,
+    have_processed_images: bool,
 
     // GPU things
     pub prompt_tok_per_sec: f32,

@@ -352,6 +355,8 @@ impl Sequence {
             cached_vid_thw: None,
             return_raw_logits,
             token_offset: 0,
+            mm_positions: None,
+            have_processed_images: false,
         }
     }


@@ -507,6 +512,44 @@
         }
     }
 
+    // Recompute the multimodal token position ranges by locating each image
+    // token sub-sequence within the current tokens.
+    pub(crate) fn recompute_mm_positions(&mut self, image_sequence_toks: Vec<Vec<u32>>) {
+        let mut accum = Vec::new();
+
+        let mut last_i = 0;
+        for img_seq_toks in image_sequence_toks {
+            let mut start = last_i;
+            for i in last_i..self.tokens.len() {
+                start = i;
+                if img_seq_toks
+                    == self.tokens[i..(img_seq_toks.len() + i).min(self.tokens.len() - 1)]
+                {
+                    break;
+                }
+            }
+
+            if start != self.tokens.len() - 1 {
+                let end = start + img_seq_toks.len();
+                last_i = end;
+                assert_eq!(img_seq_toks, self.tokens[start..end]);
+                accum.push(start..end);
+            }
+        }
+
+        self.mm_positions = Some(accum);
+    }
+
+    /// Indicate that this sequence has processed images and return the previous state.
+    pub(crate) fn flag_have_processed_images(&mut self) -> bool {
+        let prev = self.have_processed_images;
+        self.have_processed_images = true;
+        prev
+    }
+
+    pub(crate) fn get_mm_positions(&self) -> Option<&Vec<Range<usize>>> {
+        self.mm_positions.as_ref()
+    }
+
     pub fn completion_bytes(&self) -> &[u8] {
         &self.completion_bytes
     }

@@ -784,11 +827,7 @@ impl Sequence {
         self.adapters.clone()
     }
 
-    pub fn take_images(&mut self) -> Option<Vec<image::DynamicImage>> {
-        self.input_images.take()
-    }
-
-    pub fn clone_images(&mut self) -> Option<Vec<image::DynamicImage>> {
+    pub fn clone_images(&self) -> Option<Vec<image::DynamicImage>> {
         self.input_images.clone()
     }
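
A standalone sketch of the scan recompute_mm_positions performs: locate each image token sub-sequence in the full token stream, starting from the end of the previous match, and record its half-open range. Token values are invented and the boundary handling is simplified relative to the diff above:

    use std::ops::Range;

    fn find_mm_positions(tokens: &[u32], image_seqs: &[Vec<u32>]) -> Vec<Range<usize>> {
        let mut accum = Vec::new();
        let mut last_i = 0;
        for img in image_seqs {
            if img.is_empty() || img.len() > tokens.len() - last_i {
                continue; // no room left for this sub-sequence
            }
            // Scan forward from the end of the previous match.
            if let Some(start) = (last_i..=tokens.len() - img.len())
                .find(|&i| tokens[i..i + img.len()] == img[..])
            {
                let end = start + img.len();
                accum.push(start..end);
                last_i = end;
            }
        }
        accum
    }

    fn main() {
        // 9 stands for an image token; two images of three tokens each.
        let tokens = vec![5, 9, 9, 9, 6, 9, 9, 9, 7];
        let image_seqs = vec![vec![9, 9, 9], vec![9, 9, 9]];
        assert_eq!(find_mm_positions(&tokens, &image_seqs), vec![1..4, 5..8]);
    }
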
67 changes: 41 additions & 26 deletions mistralrs-core/src/vision_models/gemma3/inputs_processor.rs

@@ -159,7 +159,9 @@ impl InputsProcessor for Gemma3ImageProcessor {
         let config = other_config.expect("Need a PreProcessorConfig config.");
         let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
 
-        let has_images = input_seqs.iter().all(|seq| seq.has_images());
+        let has_images = input_seqs
+            .iter()
+            .all(|seq| seq.is_prompt() && seq.has_images());
 
         let (new_input, pixel_values) = if has_images {
             let mut pixel_values_accum = Vec::new();

@@ -184,7 +186,7 @@
                     num_crops,
                 } = self
                     .preprocess(
-                        seq.take_images()
+                        seq.clone_images()
                             .expect("Need to have images by this point."),
                         vec![],
                         config,

@@ -194,43 +196,56 @@
                         .expect("Preprocessing failed");
 
                     let num_crops = num_crops.unwrap();
 
                     // Deliberately no .unsqueeze here
                     pixel_values_accum.push(pixel_values.clone());
 
-                    let mut prompt = tokenizer
-                        .decode(seq.get_toks(), false)
-                        .expect("Detokenization failed!");
+                    if !seq.flag_have_processed_images() {
+                        let mut prompt = tokenizer
+                            .decode(seq.get_toks(), false)
+                            .expect("Detokenization failed!");
 
-                    let image_indexes: Vec<usize> =
-                        re.find_iter(&prompt).map(|mat| mat.start()).collect();
+                        let image_indexes: Vec<usize> =
+                            re.find_iter(&prompt).map(|mat| mat.start()).collect();
 
-                    assert_ne!(pixel_values.dim(0).unwrap(), image_indexes.len());
+                        assert_ne!(pixel_values.dim(0).unwrap(), image_indexes.len());
 
-                    for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
-                        if num != 0 {
-                            let formatted_image_text = format!(
+                        let num_imaged_embed = num_crops.iter().filter(|x| **x != 0).count();
+                        for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
+                            if num != 0 {
+                                let formatted_image_text = format!(
                                 "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ")
                             );
-                            prompt = format!(
-                                "{}{formatted_image_text}{}",
-                                &prompt[..idx],
-                                &prompt[idx + BOI_TOKEN.len()..]
-                            );
-                        }
-                    }
+                                prompt = format!(
+                                    "{}{formatted_image_text}{}",
+                                    &prompt[..idx],
+                                    &prompt[idx + BOI_TOKEN.len()..]
+                                );
+                            }
+                        }
 
-                    prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);
+                        prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);
 
-                    seq.set_initial_prompt(prompt.clone());
-                    let toks = tokenizer
-                        .encode(prompt, false)
-                        .expect("Detokenization failed!");
+                        seq.set_initial_prompt(prompt.clone());
+                        let toks = tokenizer
+                            .encode(prompt, false)
+                            .expect("Detokenization failed!");
 
-                    let ids = toks.get_ids().to_vec();
-                    all_ids.push(ids.clone());
+                        let ids = toks.get_ids().to_vec();
+                        all_ids.push(ids.clone());
 
-                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
+                        seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
+
+                        let full_image_sequence_toks = tokenizer
+                            .encode(self.full_image_sequence.clone(), false)
+                            .expect("Detokenization failed!")
+                            .get_ids()
+                            .to_vec();
+                        let full_image_sequence_toks =
+                            vec![full_image_sequence_toks; num_imaged_embed];
+                        seq.recompute_mm_positions(full_image_sequence_toks);
+                    } else {
+                        let ids = seq.get_toks().to_vec();
+                        all_ids.push(ids.clone());
+                    }
                 }
 
                 let mut all_ids_new = Vec::new();
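
The flag_have_processed_images guard is what makes this processor safe to re-run from the prefix cacher: only the first pass rewrites the prompt and reallocates tokens, while later passes leave the already-expanded tokens untouched. A toy model of just that guard; the Seq struct and token values are invented and the real prompt rewriting is elided:

    struct Seq {
        toks: Vec<u32>,
        have_processed_images: bool,
    }

    impl Seq {
        // Mirrors flag_have_processed_images: set the flag, return the prior state.
        fn flag_have_processed_images(&mut self) -> bool {
            std::mem::replace(&mut self.have_processed_images, true)
        }
    }

    // Stand-in for the processor body above: expand image tokens once, then
    // leave the tokens alone on every later invocation.
    fn process(seq: &mut Seq) {
        if !seq.flag_have_processed_images() {
            seq.toks = vec![1, 7, 7, 7, 2]; // first pass: the placeholder expands
        }
    }

    fn main() {
        let mut seq = Seq { toks: vec![1, 99, 2], have_processed_images: false };
        process(&mut seq);
        let after_first = seq.toks.clone();
        process(&mut seq); // a re-run (e.g. from the prefix cacher) is a no-op
        assert_eq!(seq.toks, after_first);
    }
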
Idefics2 inputs processor

@@ -209,7 +209,9 @@ impl InputsProcessor for Idefics2ImageProcessor {
         let config = other_config.expect("Need a PreProcessorConfig config.");
         let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
 
-        let has_images = input_seqs.iter().all(|seq| seq.has_images());
+        let has_images = input_seqs
+            .iter()
+            .all(|seq| seq.is_prompt() && seq.has_images());
 
         let (pixel_values, pixel_attention_mask) = if has_images {
             let mut pixel_values_accum = Vec::new();

@@ -233,7 +235,7 @@
                     num_crops: _,
                 } = self
                     .preprocess(
-                        seq.take_images()
+                        seq.clone_images()
                             .expect("Need to have images by this point."),
                         vec![],
                         config,