Skip to content

Pali gemma modeling #1895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5fd72ed
feat: load and query model
drbh May 8, 2024
b07b53e
feat: improve config and refactor
drbh May 9, 2024
2329434
fix: debugging
drbh May 9, 2024
e13c08f
fix: adjust siglip attention
drbh May 9, 2024
d503007
fix: debug avoid scaling embed
drbh May 10, 2024
36fb4b5
fix: adjust image and text merge logic
drbh May 10, 2024
4df1b25
fix: typo and lint
drbh May 10, 2024
6e8a211
fix: adjust inputs_embeds passed to language model and debug
drbh May 10, 2024
5b3b8fd
fix: prefer gemma rotary embed and split attention weight
drbh May 14, 2024
9b9614c
fix: small test tweak
drbh May 14, 2024
ebbe7ed
Don't break what's not broken.
Narsil May 14, 2024
67e833c
Back functional gemma.
Narsil May 14, 2024
c119ac4
Fixed PaliGemma.
Narsil May 14, 2024
d6e306c
fix: apply paligemma template conditionally
drbh May 14, 2024
70713fc
fix: improve pali test and add snapshot
drbh May 14, 2024
17ac93e
fix: default add special tokens to avoid vlm regressions
drbh May 15, 2024
65bc0aa
Working integration-tests.
Narsil May 15, 2024
1bcaf8f
Fixed.
Narsil May 15, 2024
e8d0218
Small updates.
Narsil May 15, 2024
79b15fe
Installing git.
Narsil May 15, 2024
ec92601
Revert "Installing git."
Narsil May 15, 2024
81e7aac
Revert "Revert "Installing git.""
Narsil May 15, 2024
368c057
Trying to understand the weird failure.
Narsil May 15, 2024
dc0b8d7
Change the dockerfile. It builds locally, something might be up in AWS
Narsil May 15, 2024
f3f7140
Debugging this nightmare.
Narsil May 15, 2024
f8337a9
Using updated runner.
Narsil May 15, 2024
fcb62c7
Another attempt.
Narsil May 15, 2024
9005970
Sshing a cuda 12.4
Narsil May 15, 2024
7f97fda
Upgrade mamba.
Narsil May 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_AMI_ID: ami-0789b6925c11b1fb2
EC2_INSTANCE_TYPE: g5.12xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
Expand Down
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.3.1-1
ARG MAMBA_VERSION=24.3.0-0
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
Expand Down Expand Up @@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
ca-certificates \
make \
curl \
git \
&& rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
Expand Down
Binary file added integration-tests/images/cow_beach.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 2,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 54901,
"logprob": -0.72753906,
"special": false,
"text": "beach"
},
{
"id": 1,
"logprob": -0.011009216,
"special": true,
"text": "<eos>"
}
],
"top_tokens": null
},
"generated_text": "beach"
}
39 changes: 39 additions & 0 deletions integration-tests/models/test_flash_pali_gemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    """Launch a TGI server for the PaliGemma checkpoint, once per module.

    Yields the launcher handle; the server is torn down when the module's
    tests finish.
    """
    server = launcher(
        "google/paligemma-3b-pt-224",
        num_shard=1,
        revision="float16",
        max_input_length=4000,
        max_total_tokens=4096,
    )
    with server as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    """Block until the launched server reports healthy, then return its client."""
    # 300 s budget: model download + weight load can be slow on a cold runner.
    await flash_pali_gemma_handle.health(300)
    client = flash_pali_gemma_handle.client
    return client


def get_cow_beach():
    """Read the cow-on-a-beach test image and return it as a base64 data URI."""
    # Path is relative to the repo root, where the integration tests are run from.
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        raw = image_file.read()
    payload = base64.b64encode(raw).decode("utf-8")
    return "data:image/png;base64," + payload


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    """The model should answer that the cow is standing on a beach."""
    image_uri = get_cow_beach()
    # PaliGemma prompt format: markdown image followed by the question.
    prompt = f"![]({image_uri})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(prompt, max_new_tokens=20)

    assert response.generated_text == "beach"
    # Pin the full response (token ids, logprobs) against the stored snapshot.
    assert response == response_snapshot
21 changes: 19 additions & 2 deletions router/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,13 @@ impl LlavaNext {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
/// Vision-tower configuration fields used by the router.
pub struct ClipVisionModel {
// Presumably the square input resolution in pixels — TODO confirm against the HF config.
image_size: usize,
// Presumably the side length of each image patch in pixels — TODO confirm.
patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// Marker config for Idefics2; the router needs no per-model fields from it.
pub struct Idefics2 {}

Expand All @@ -118,6 +116,24 @@ impl Idefics2 {
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// Subset of PaliGemma's `text_config` that the router needs.
pub struct PaliTextConfig {
// Fixed number of `<image>` placeholder tokens to insert per image.
num_image_tokens: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
/// PaliGemma model configuration as deserialized from the model's `config.json`.
pub struct Paligemma {
text_config: PaliTextConfig,
}

impl Paligemma {
/// Number of `<image>` placeholder slots to reserve for one image.
///
/// Unlike LLaVA-style models, PaliGemma uses a fixed count taken from the
/// config; the actual image dimensions are ignored (hence the underscored
/// parameters, kept so the call site matches the other model configs).
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
self.text_config.num_image_tokens
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
Expand All @@ -139,6 +155,7 @@ pub enum Config {
Phi3,
Llama,
Baichuan,
Paligemma(Paligemma),
Gemma,
Cohere,
Drbx,
Expand Down
24 changes: 24 additions & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,30 @@ fn prepare_input(
inputs = modified_inputs;
tokenizer_query
}
Some(Config::Paligemma(config)) => {
    // Rewrite the prompt: download each markdown image, replace it with its
    // fetched URI in `modified_inputs`, and with `<image>` placeholder tokens
    // in the string actually sent to the tokenizer.
    let mut modified_inputs = String::with_capacity(inputs.len());
    let mut tokenizer_query = String::with_capacity(inputs.len());
    let mut start = 0;
    for chunk in RE.find_iter(&inputs) {
        let chunk_start = chunk.start();
        let chunk_end = chunk.end();
        // Copy the plain text between the previous image and this one.
        if chunk_start != start {
            modified_inputs.push_str(&inputs[start..chunk_start]);
            tokenizer_query.push_str(&inputs[start..chunk_start]);
        }
        let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
        // For PaliGemma the slot count is fixed by the config (dims ignored).
        let slots = config.get_number_of_features(height, width);
        tokenizer_query.push_str(&"<image>".repeat(slots));
        modified_inputs.push_str(&image_uri);
        start = chunk_end;
    }
    // Append any trailing text after the last image. Compare against the
    // full length: the previous `inputs.len() - 1` check silently dropped
    // the remainder when exactly one character followed the last image.
    if start != inputs.len() {
        modified_inputs.push_str(&inputs[start..]);
        tokenizer_query.push_str(&inputs[start..]);
    }
    inputs = modified_inputs;
    tokenizer_query
}
Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut tokenizer_query = String::with_capacity(inputs.len());
Expand Down
Loading