From ab6b0757a8bfb3d5f809796856d9d13f043a8f71 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 7 Nov 2024 10:02:37 +0000 Subject: [PATCH] LLAVA notebook Signed-off-by: Prabod Rathnayaka --- ...gingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb | 879 ++++++++++++++++++ 1 file changed, 879 insertions(+) create mode 100644 examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb diff --git a/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb new file mode 100644 index 00000000000000..a12f939f35e6f4 --- /dev/null +++ b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLAVA.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Import OpenVINO LLAVA models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "This notebook provides a detailed walkthrough on optimizing and importing LLAVA models from HuggingFace for use in Spark NLP, with [Intel OpenVINO toolkit](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html). The focus is on converting the model to the OpenVINO format and applying precision optimizations (INT8 and INT4), to enhance the performance and efficiency on CPU platforms using [Optimum Intel](https://huggingface.co/docs/optimum/main/en/intel/inference).\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- OpenVINO support was introduced in `Spark NLP 5.4.0`, enabling high performance CPU inference for models. So please make sure you have upgraded to the latest Spark NLP release.\n", + "- Model quantization is a computationally expensive process, so it is recommended to use a runtime with more than 32GB memory for exporting the quantized model from HuggingFace.\n", + "- You can import LLAVA models via `LLAVA`. These models are usually under `Text Generation` category and have `LLAVA` in their labels.\n", + "- Reference: [LLAVA](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LLAVA)\n", + "- Some [example models](https://huggingface.co/models?search=LLAVA)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Export and Save the HuggingFace model\n", + "\n", + "- Let's install `transformers` and `openvino` packages with other dependencies. You don't need `openvino` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock `transformers` on version `4.41.2`. This doesn't mean it won't work with the future release, but we wanted you to know which versions have been tested successfully." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import requests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "%pip install -q --upgrade transformers==4.41.2\n", + "%pip install -q --upgrade openvino==2024.1\n", + "%pip install -q \"git+https://github.com/eaidova/optimum-intel.git@ea/minicpmv\"\n", + "%pip install -q \"nncf>=2.13.0\" \"sentencepiece\" \"tokenizers>=0.12.1\" \"transformers>=4.45.0\" \"gradio>=4.36\"\n", + "%pip install -q -U --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly openvino-tokenizers openvino openvino-genai\n", + "%pip install -q --upgrade huggingface_hub\n", + "%pip install -q --upgrade onnx==1.15.0\n", + "%pip install -q --upgrade torch==2.2.1\n", + "\n", + "\n", + "utility_files = [\"notebook_utils.py\", \"cmd_helper.py\"]\n", + "\n", + "for utility in utility_files:\n", + " local_path = Path(utility)\n", + " if not local_path.exists():\n", + " r = requests.get(\n", + " url=f\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/{local_path.name}\",\n", + " )\n", + " with local_path.open(\"w\") as f:\n", + " f.write(r.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1 Convert the model to OpenVino" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from cmd_helper import optimum_cli\n", + "\n", + "model_id = \"llava-hf/llava-1.5-7b-hf\"\n", + "model_path = Path(model_id.split(\"/\")[-1]) / \"FP16\"\n", + "\n", + "if not model_path.exists():\n", + " optimum_cli(model_id, model_path, additional_args={\"weight-format\": \"fp16\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:nncf:Statistics of the bitwidth distribution:\n", + "┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑\n", + "│ Num bits (N) │ % all parameters (layers) │ % ratio-defining parameters (layers) │\n", + "┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥\n", + "│ 4 │ 100% (225 / 225) │ 100% (225 / 225) │\n", + "┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e8f9bad3e593468db17c882e77311335", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import shutil\n",
+    "import nncf\n",
+    "import openvino as ov\n",
+    "import gc\n",
+    "\n",
+    "\n",
+    "compression_mode = \"INT4\"\n",
+    "\n",
+    "core = ov.Core()\n",
+    "\n",
+    "\n",
+    "def compress_model_weights(precision):\n",
+    "    int4_compression_config = {\"mode\": nncf.CompressWeightsMode.INT4_ASYM, \"group_size\": 128, \"ratio\": 1, \"all_layers\": True}\n",
+    "    int8_compression_config = {\"mode\": nncf.CompressWeightsMode.INT8_ASYM}\n",
+    "\n",
+    "    compressed_model_path = model_path.parent / precision\n",
+    "\n",
+    "    if not compressed_model_path.exists():\n",
+    "        ov_model = core.read_model(model_path / \"openvino_language_model.xml\")\n",
+    "        compression_config = int4_compression_config if precision == \"INT4\" else int8_compression_config\n",
+    "        compressed_ov_model = nncf.compress_weights(ov_model, **compression_config)\n",
+    "        ov.save_model(compressed_ov_model, compressed_model_path / \"openvino_language_model.xml\")\n",
+    "        del compressed_ov_model\n",
+    "        del ov_model\n",
+    "        gc.collect()\n",
+    "        for file_name in model_path.glob(\"*\"):\n",
+    "            if file_name.name in [\"openvino_language_model.xml\", \"openvino_language_model.bin\"]:\n",
+    "                continue\n",
+    "            shutil.copy(file_name, compressed_model_path)\n",
+    "\n",
+    "\n",
+    "compress_model_weights(compression_mode)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1.2 Load openvino models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_dir = model_path.parent / compression_mode\n",
+    "language_model = core.read_model(model_dir / \"openvino_language_model.xml\")\n",
+    "vision_embedding = core.compile_model(model_dir / \"openvino_vision_embeddings_model.xml\", \"AUTO\")\n",
+    "text_embedding = core.compile_model(model_dir / \"openvino_text_embeddings_model.xml\", \"AUTO\")\n",
+    "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/prabod/anaconda3/envs/pth23/lib/python3.9/site-packages/torch/cuda/__init__.py:619: UserWarning: Can't initialize NVML\n",
+      "  warnings.warn(\"Can't initialize NVML\")\n"
+     ]
+    }
+   ],
+   "source": [
+    "import requests\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "from transformers import AutoProcessor, AutoConfig\n",
+    "\n",
+    "config = AutoConfig.from_pretrained(model_path)\n",
+    "\n",
+    "processor = AutoProcessor.from_pretrained(\n",
+    "    model_path, patch_size=config.vision_config.patch_size, vision_feature_select_strategy=config.vision_feature_select_strategy\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def load_image(image_file):\n",
+    "    if image_file.startswith(\"http\") or image_file.startswith(\"https\"):\n",
+    "        response = requests.get(image_file)\n",
+    "        image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
+    "    else:\n",
+    "        image = Image.open(image_file).convert(\"RGB\")\n",
+    "    return image\n",
+    "\n",
+    "\n",
+    "image_file = \"https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11\"\n",
+    "text_message = \"What is unusual on this image?\"\n",
+    "\n",
+    "image = load_image(image_file)\n",
+    "\n",
+    "conversation = [\n",
+    "    {\n",
+    "        \"role\": \"user\",\n",
+    "        \"content\": [\n",
+    "            {\"type\": \"text\", \"text\": text_message},\n",
+    "            {\"type\": \"image\"},\n",
+    "        ],\n",
+    "    },\n",
+    "]\n",
+    "\n",
+    "prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
+    "\n",
+    "inputs_new = processor(images=image, text=prompt, return_tensors=\"pt\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "request = compiled_language_model.create_infer_request()\n",
+    "input_names = {key.get_any_name(): idx for idx, key in enumerate(language_model.inputs)}\n",
+    "inputs = {}\n",
+    "# Set the initial input_ids\n",
+    "current_input_ids = inputs_new[\"input_ids\"]\n",
+    "attention_mask = inputs_new[\"attention_mask\"]\n",
+    "position_ids = attention_mask.long().cumsum(-1) - 1\n",
+    "position_ids.masked_fill_(attention_mask == 0, 1)\n",
+    "pixel_values = inputs_new[\"pixel_values\"]\n",
+    "\n",
+    "# Set the initial input_ids\n",
+    "text_out = text_embedding(inputs_new[\"input_ids\"])[0]\n",
+    "vision_out = vision_embedding(pixel_values)[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torch\n",
+    "\n",
+    "class MergeMultiModalInputs(torch.nn.Module):\n",
+    "    def __init__(self,image_seq_length=576,image_token_index=32000):\n",
+    "        super().__init__()\n",
+    "        self.image_seq_length = image_seq_length\n",
+    "        self.image_token_index = image_token_index\n",
+    "\n",
+    "    def forward(\n",
+    "        self,\n",
+    "        vision_embeds,\n",
+    "        inputs_embeds,\n",
+    "        input_ids,\n",
+    "    ):\n",
+    "        image_features = vision_embeds\n",
+    "        inputs_embeds = inputs_embeds\n",
+    "        special_image_mask = (input_ids == self.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)\n",
+    "        # image_features = image_features.to(inputs_embeds.dtype)\n",
+    "        final_embedding = inputs_embeds.masked_scatter(special_image_mask, image_features)\n",
+    "\n",
+    "        return {\n",
+    "            \"final_embedding\": final_embedding\n",
+    "        }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch_model_merge = MergeMultiModalInputs(\n",
+    "    image_seq_length=config.image_seq_length,\n",
+    "    image_token_index=config.image_token_index\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# test the model\n",
+    "inputs_embeds = torch.from_numpy(text_out)\n",
+    "input_ids = inputs_new[\"input_ids\"]\n",
+    "vision_embeds = torch.from_numpy(vision_out)\n",
+    "\n",
+    "final_embedding = torch_model_merge(vision_embeds, inputs_embeds, input_ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WARNING:nncf:NNCF provides best results with torch==2.4.*, while current torch version is 2.3.1+cu121. If you encounter issues, consider switching to torch==2.4.*\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    }
+   ],
+   "source": [
+    "import openvino as ov\n",
+    "\n",
+    "# convert MergeMultiModalInputs to OpenVINO IR\n",
+    "ov_model_merge = ov.convert_model(\n",
+    "    torch_model_merge,\n",
+    "    example_input={\n",
+    "        \"vision_embeds\": torch.from_numpy(vision_out),\n",
+    "        \"inputs_embeds\": torch.from_numpy(text_out),\n",
+    "        \"input_ids\": inputs_new[\"input_ids\"],\n",
+    "    }\n",
+    ")\n",
+    "ov.save_model(ov_model_merge, model_dir/\"openvino_merge_model.xml\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "⌛ Check if all models are converted\n",
+      "✅ All models are converted. You can find results in llava-1.5-7b-hf/INT4\n"
+     ]
+    }
+   ],
+   "source": [
+    "# check if all the models are converted\n",
+    "\n",
+    "print(\"⌛ Check if all models are converted\")\n",
+    "lang_model_path = model_dir / \"openvino_language_model.xml\"\n",
+    "image_embed_path = model_dir / \"openvino_vision_embeddings_model.xml\"\n",
+    "img_projection_path = model_dir / \"openvino_text_embeddings_model.xml\"\n",
+    "merge_model_path = model_dir / \"openvino_merge_model.xml\"\n",
+    "\n",
+    "\n",
+    "\n",
+    "if all(\n",
+    "    [\n",
+    "        lang_model_path.exists(),\n",
+    "        image_embed_path.exists(),\n",
+    "        img_projection_path.exists(),\n",
+    "        merge_model_path.exists(),\n",
+    "    ]\n",
+    "):\n",
+    "    print(f\"✅ All models are converted. You can find results in {model_dir}\")\n",
+    "else:\n",
+    "    print(\"❌ Not all models are converted. Please check the conversion process\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1.2 Copy assets to the assets folder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assets_dir = model_dir / \"assets\"\n",
+    "assets_dir.mkdir(exist_ok=True)\n",
+    "\n",
+    "# copy all the assets to the assets directory (json files, vocab files, etc.)\n",
+    "\n",
+    "import shutil\n",
+    "\n",
+    "# copy all json files\n",
+    "\n",
+    "for file in model_dir.glob(\"*.json\"):\n",
+    "    shutil.copy(file, assets_dir)\n",
+    "\n",
+    "    \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "total 4.1G\n",
+      "-rw-rw-r-- 1 prabod prabod   41 Nov  7 04:33 added_tokens.json\n",
+      "drwxrwxr-x 2 prabod prabod 4.0K Nov  7 04:37 assets\n",
+      "-rw-rw-r-- 1 prabod prabod  701 Nov  7 04:33 chat_template.json\n",
+      "-rw-rw-r-- 1 prabod prabod 4.7K Nov  7 04:33 config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  136 Nov  7 04:33 generation_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod 332K Nov  7 04:33 openvino_detokenizer.bin\n",
+      "-rw-rw-r-- 1 prabod prabod 8.8K Nov  7 04:33 openvino_detokenizer.xml\n",
+      "-rw-rw-r-- 1 prabod prabod 3.2G Nov  7 04:33 openvino_language_model.bin\n",
+      "-rw-rw-r-- 1 prabod prabod 2.9M Nov  7 04:33 openvino_language_model.xml\n",
+      "-rw-rw-r-- 1 prabod prabod   40 Nov  7 04:36 openvino_merge_model.bin\n",
+      "-rw-rw-r-- 1 prabod prabod 9.8K Nov  7 04:36 openvino_merge_model.xml\n",
+      "-rw-rw-r-- 1 prabod prabod 251M Nov  7 04:33 openvino_text_embeddings_model.bin\n",
+      "-rw-rw-r-- 1 prabod prabod 3.1K Nov  7 04:33 openvino_text_embeddings_model.xml\n",
+      "-rw-rw-r-- 1 prabod prabod 1.2M Nov  7 04:33 openvino_tokenizer.bin\n",
+      "-rw-rw-r-- 1 prabod prabod  25K Nov  7 04:33 openvino_tokenizer.xml\n",
+      "-rw-rw-r-- 1 prabod prabod 595M Nov  7 04:33 openvino_vision_embeddings_model.bin\n",
+      "-rw-rw-r-- 1 prabod prabod 929K Nov  7 04:33 openvino_vision_embeddings_model.xml\n",
+      "-rw-rw-r-- 1 prabod prabod  505 Nov  7 04:33 preprocessor_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  134 Nov  7 04:33 processor_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  552 Nov  7 04:33 special_tokens_map.json\n",
+      "-rw-rw-r-- 1 prabod prabod 1.4K Nov  7 04:33 tokenizer_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod 3.5M Nov  7 04:33 tokenizer.json\n",
+      "-rw-rw-r-- 1 prabod prabod 489K Nov  7 04:33 tokenizer.model\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls -lh {model_dir}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "total 3.5M\n",
+      "-rw-rw-r-- 1 prabod prabod   41 Nov  7 04:37 added_tokens.json\n",
+      "-rw-rw-r-- 1 prabod prabod  701 Nov  7 04:37 chat_template.json\n",
+      "-rw-rw-r-- 1 prabod prabod 4.7K Nov  7 04:37 config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  136 Nov  7 04:37 generation_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  505 Nov  7 04:37 preprocessor_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  134 Nov  7 04:37 processor_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod  552 Nov  7 04:37 special_tokens_map.json\n",
+      "-rw-rw-r-- 1 prabod prabod 1.4K Nov  7 04:37 tokenizer_config.json\n",
+      "-rw-rw-r-- 1 prabod prabod 3.5M Nov  7 04:37 tokenizer.json\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls -lh {assets_dir}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1.3 Test the openvino model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import openvino as ov\n",
+    "import torch\n",
+    "\n",
+    "core = ov.Core()\n",
+    "device = \"CPU\"\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "language_model = core.read_model(model_dir / \"openvino_language_model.xml\")\n",
+    "language_model = core.read_model(model_dir / \"openvino_language_model.xml\")\n",
+    "vision_embedding = core.compile_model(model_dir / \"openvino_vision_embeddings_model.xml\", \"AUTO\")\n",
+    "text_embedding = core.compile_model(model_dir / \"openvino_text_embeddings_model.xml\", \"AUTO\")\n",
+    "compiled_language_model = core.compile_model(language_model, \"AUTO\")\n",
+    "merge_multi_modal = core.compile_model(model_dir / \"openvino_merge_model.xml\", \"AUTO\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "generated_tokens = []\n",
+    "\n",
+    "from transformers import AutoProcessor, TextStreamer\n",
+    "\n",
+    "conversation = [\n",
+    "    {\n",
+    "        \"role\": \"user\",\n",
+    "        \"content\": [\n",
+    "            {\"type\": \"text\", \"text\": \"What is unusual on this image?\"},\n",
+    "            {\"type\": \"image\"},\n",
+    "        ],\n",
+    "    },\n",
+    "]\n",
+    "\n",
+    "prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
+    "\n",
+    "inputs_new = processor(images=image, text=prompt, return_tensors=\"pt\")\n",
+    "\n",
+    "# inputs_new = processor(prompt, [image], return_tensors=\"pt\")\n",
+    "\n",
+    "generation_args = {\"max_new_tokens\": 50, \"do_sample\": False, \"streamer\": TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)}\n",
+    "\n",
+    "\n",
+    "request = compiled_language_model.create_infer_request()\n",
+    "merge_model_request = merge_multi_modal.create_infer_request()\n",
+    "input_names = {key.get_any_name(): idx for idx, key in enumerate(language_model.inputs)}\n",
+    "inputs = {}\n",
+    "# Set the initial input_ids\n",
+    "current_input_ids = inputs_new[\"input_ids\"]\n",
+    "attention_mask = inputs_new[\"attention_mask\"]\n",
+    "position_ids = attention_mask.long().cumsum(-1) - 1\n",
+    "position_ids.masked_fill_(attention_mask == 0, 1)\n",
+    "pixel_values = inputs_new[\"pixel_values\"]\n",
+    "\n",
+    "for i in range(generation_args[\"max_new_tokens\"]):\n",
+    "    # Generate input embeds each time\n",
+    "    if current_input_ids.shape[-1] > 1:\n",
+    "        vision_embeds = torch.from_numpy(vision_embedding({\n",
+    "            \"pixel_values\": pixel_values,\n",
+    "        })[0])\n",
+    "    \n",
+    "    text_embeds = torch.from_numpy(text_embedding(current_input_ids)[0])\n",
+    "\n",
+    "    if i == 0:\n",
+    "        merge_model_request.start_async({\n",
+    "            \"vision_embeds\": vision_embeds,\n",
+    "            \"inputs_embeds\": text_embeds,\n",
+    "            \"input_ids\": current_input_ids,\n",
+    "        }, share_inputs=True)\n",
+    "        merge_model_request.wait()\n",
+    "        final_embedding = torch.from_numpy(merge_model_request.get_tensor(\"final_embedding\").data)\n",
+    "    else:\n",
+    "        final_embedding = text_embeds\n",
+    "    if i>0:\n",
+    "        inputs = {}\n",
+    "    # Prepare inputs for the model\n",
+    "    inputs[\"inputs_embeds\"] = final_embedding\n",
+    "    inputs[\"attention_mask\"] = attention_mask\n",
+    "    inputs[\"position_ids\"] = position_ids\n",
+    "    if \"beam_idx\" in input_names:\n",
+    "        inputs[\"beam_idx\"] = np.arange(attention_mask.shape[0], dtype=int)\n",
+    "    \n",
+    "    # Start inference\n",
+    "    request.start_async(inputs, share_inputs=True)\n",
+    "    request.wait()\n",
+    "    \n",
+    "    # Get the logits and find the next token\n",
+    "    logits = torch.from_numpy(request.get_tensor(\"logits\").data)\n",
+    "    next_token = logits.argmax(-1)[0][-1]\n",
+    "    \n",
+    "    # Append the generated token\n",
+    "    generated_tokens.append(next_token)\n",
+    "    \n",
+    "    # Update input_ids with the new token\n",
+    "    current_input_ids = torch.cat([next_token.unsqueeze(0).unsqueeze(0)], dim=-1)\n",
+    "    \n",
+    "    # update the attention mask\n",
+    "    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :1])], dim=-1)\n",
+    "\n",
+    "    # Update inputs for the next iteration\n",
+    "    position_ids = attention_mask.long().cumsum(-1) - 1\n",
+    "    position_ids.masked_fill_(attention_mask == 0, 1)\n",
+    "    position_ids = position_ids[:, -current_input_ids.shape[1] :]\n",
+    "    inputs[\"position_ids\"] = position_ids"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Question:\n",
+      " What is unusual on this picture?\n",
+      "Answer:\n",
+      "The unusual aspect of this image is that a cat is lying inside a cardboard box, which is not a typical place for a cat to rest. Cats are known for their curiosity and love for small, enclosed spaces, but in this case\n"
+     ]
+    }
+   ],
+   "source": [
+    "generated_text = processor.decode(generated_tokens, skip_special_tokens=True)\n",
+    "\n",
+    "image\n",
+    "print(\"Question:\\n What is unusual on this picture?\")\n",
+    "print(\"Answer:\")\n",
+    "print(generated_text)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Import and Save LLAVA in Spark NLP\n",
+    "\n",
+    "- Let's install and setup Spark NLP in Google Colab\n",
+    "- This part is pretty easy via our simple script"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's start Spark with Spark NLP included via our simple `start()` function"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "24/11/07 09:56:55 WARN Utils: Your hostname, minotaur resolves to a loopback address: 127.0.1.1; using 192.168.1.4 instead (on interface eno1)\n",
+      "24/11/07 09:56:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+      "24/11/07 09:56:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting default log level to \"WARN\".\n",
+      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sparknlp\n",
+    "\n",
+    "# let's start Spark with Spark NLP\n",
+    "spark = sparknlp.start()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "24/11/07 09:57:34 WARN NativeLibrary: Failed to load library null: java.lang.UnsatisfiedLinkError: Can't load library: /tmp/openvino-native15331424460843812197/libtbb.so.2\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING: An illegal reflective access operation has occurred\n",
+      "WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/home/prabod/spark/jars/spark-core_2.12-3.3.2.jar) to field java.util.regex.Pattern.pattern\n",
+      "WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$\n",
+      "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n",
+      "WARNING: All illegal access operations will be denied in a future release\n"
+     ]
+    }
+   ],
+   "source": [
+    "imageClassifier = LLAVAForMultiModal.pretrained() \\\n",
+    "            .setInputCols(\"image_assembler\") \\\n",
+    "            .setOutputCol(\"answer\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "imageClassifier.write().overwrite().save(\"LLAVA_spark_nlp\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sparknlp\n",
+    "from sparknlp.base import *\n",
+    "from sparknlp.annotator import *\n",
+    "from pyspark.sql.functions import lit\n",
+    "from pyspark.ml import Pipeline\n",
+    "from pathlib import Path\n",
+    "import os\n",
+    "\n",
+    "# download two images to test into ./images folder\n",
+    "\n",
+    "url1 = \"https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11\"\n",
+    "url2 = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
+    "\n",
+    "Path(\"images\").mkdir(exist_ok=True)\n",
+    "\n",
+    "!wget -q -O images/image1.jpg {url1}\n",
+    "!wget -q -O images/image2.jpg {url2}\n",
+    "\n",
+    "\n",
+    "\n",
+    "images_path = \"file://\" + os.getcwd() + \"/images/\"\n",
+    "image_df = spark.read.format(\"image\").load(\n",
+    "    path=images_path\n",
+    ")\n",
+    "\n",
+    "test_df = image_df.withColumn(\"text\", lit(\"USER: \\n <|image|> \\n What's this picture about? \\n ASSISTANT:\\n\"))\n",
+    "\n",
+    "image_assembler = ImageAssembler().setInputCol(\"image\").setOutputCol(\"image_assembler\")\n",
+    "\n",
+    "imageClassifier = LLAVAForMultiModal.load(\"LLAVA_spark_nlp\")\\\n",
+    "            .setMaxOutputLength(50) \\\n",
+    "            .setInputCols(\"image_assembler\") \\\n",
+    "            .setOutputCol(\"answer\")\n",
+    "\n",
+    "pipeline = Pipeline(\n",
+    "            stages=[\n",
+    "                image_assembler,\n",
+    "                imageClassifier,\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "model = pipeline.fit(test_df)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "image_path: /mnt/research/Projects/ModelZoo/LLAVA/images/image1.jpg\n",
+      "[Annotation(document, 0, 363, This image features a cat comfortably laying inside a cardboard box. The cat appears to be relaxed and enjoying its cozy spot. The scene takes place on a carpeted floor, which adds to the overall warm and inviting atmosphere of the image. The cat's position inside the box creates a sense of security and contentment, making it an endearing and heartwarming scene., Map(), [])]\n"
+     ]
+    }
+   ],
+   "source": [
+    "light_pipeline = LightPipeline(model)\n",
+    "image_path = os.getcwd() + \"/images/\" + \"image1.jpg\"\n",
+    "print(\"image_path: \" + image_path)\n",
+    "annotations_result = light_pipeline.fullAnnotateImage(\n",
+    "    image_path,\n",
+    "    \"USER: \\n <|image|> \\n What's this picture about? \\n ASSISTANT:\\n\"\n",
+    ")\n",
+    "\n",
+    "for result in annotations_result:\n",
+    "    print(result[\"answer\"])"
+   ]
+  },
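+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, the fitted pipeline can also be applied to the whole image DataFrame with `transform`. The snippet below is a minimal sketch; `model`, `test_df`, and the `answer` column come from the pipeline defined above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# run the full Spark NLP pipeline over the image DataFrame and show the generated answers\n",
+    "results = model.transform(test_df)\n",
+    "results.select(\"image.origin\", \"answer.result\").show(truncate=False)"
+   ]
+  },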
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "tempspark",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.16"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}