
Pali gemma modeling #1895

Merged 29 commits into main on May 16, 2024

Conversation

drbh (Collaborator) commented on May 14, 2024

This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: huggingface/transformers#30814

Install the latest changes and run with:

# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf

A basic example sending various requests:

from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")


images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # Inline the image as markdown ahead of the text prompt.
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print([f"{prompt}\n{generated_output}"])
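The same request can also be sent without the client library by posting JSON to TGI's `/generate` endpoint. This is a minimal sketch; the `build_payload` helper is illustrative (not part of TGI), and the payload shape mirrors the parameters used above:

```python
import json


def build_payload(image_url: str, prompt: str, max_new_tokens: int = 30) -> dict:
    # TGI expects the image inlined as markdown ahead of the text prompt,
    # exactly as in the InferenceClient example above.
    return {
        "inputs": f"![]({image_url}){prompt}\n",
        "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False},
    }


payload = build_payload(
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "Where is the cow standing?",
)
print(json.dumps(payload, indent=2))

# To actually send it (requires the server from above to be running):
# import requests
# r = requests.post("http://127.0.0.1:3000/generate", json=payload)
# print(r.json()["generated_text"])
```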

Comment on lines 176 to 179:

    if config.model_type == "paligemma":
        full_text += "<bos>" + chunk["content"] + "\n"
    else:
        full_text += chunk["content"]
Collaborator:
Can you revert this? This is already taken care of by the PaliGemmaBatch.

Also we should probably raise an error when the query is not {image}, {text}. (single text, single image, image before text)
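The validation suggested above could look something like the following sketch. The `validate_chunks` helper and the chunk dict shape are illustrative assumptions, not the actual TGI internals:

```python
def validate_chunks(chunks: list) -> None:
    """Raise unless the query is exactly one image followed by one text chunk.

    `chunks` is assumed to be a list of dicts such as {"type": "image", ...}
    or {"type": "text", "content": ...}; this shape is illustrative.
    """
    types = [c["type"] for c in chunks]
    if types != ["image", "text"]:
        raise ValueError(
            "PaliGemma expects a single image followed by a single text "
            f"prompt; got chunk order: {types}"
        )


# A well-formed query passes silently:
validate_chunks([{"type": "image"}, {"type": "text", "content": "hi"}])
```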

@Narsil Narsil left a comment (Collaborator):

We also need to add the causal flag to all flash attention places.
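For context, the causal flag controls whether each position may attend to later positions. A minimal NumPy sketch of the semantics (not the flash-attention kernel itself):

```python
import numpy as np


def attention_weights(scores: np.ndarray, causal: bool) -> np.ndarray:
    """Turn raw (seq_len, seq_len) attention logits into weights.

    With causal=True, position i may only attend to positions <= i,
    which is what the flag requests from the flash-attention kernels.
    """
    if causal:
        # Mask out the strict upper triangle (future positions).
        mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable row-wise softmax.
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)


w = attention_weights(np.zeros((4, 4)), causal=True)
```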

@Narsil Narsil merged commit 40213c9 into main May 16, 2024
8 checks passed
@Narsil Narsil deleted the pali-gemma-modeling branch May 16, 2024 04:58
2 participants