From 8915061e27cb50e6f0391432bf1b513c24e7bdf6 Mon Sep 17 00:00:00 2001
From: Asia <2736300+humpydonkey@users.noreply.github.com>
Date: Wed, 10 Apr 2024 14:31:59 -0700
Subject: [PATCH] Switch to the tools endpoint (#40)

* get endpoint ready for demo

fixed tools.json

Update vision_agent/tools/tools.py

Bug fixes

* Fix linter errors

* Fix a bug in result parsing

* Include scores in the G-SAM model response

* Removed tools.json, need to find a better format

* Fixing the endpoint for CLIP and adding thresholds for grounding tools

* fix mypy errors

* fixed example notebook

---------

Co-authored-by: Yazhou Cao
Co-authored-by: shankar_ws3
---
 examples/va_example.ipynb          | 236 ++++++++++++++++++++++++++++-
 vision_agent/agent/vision_agent.py |   1 -
 vision_agent/tools/tools.py        | 140 +++++++++++------
 3 files changed, 326 insertions(+), 51 deletions(-)

diff --git a/examples/va_example.ipynb b/examples/va_example.ipynb
index d43173f3..1f52caef 100644
--- a/examples/va_example.ipynb
+++ b/examples/va_example.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 29,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -10,7 +10,9 @@
    "import vision_agent as va\n",
    "import pandas as pd\n",
    "import textwrap\n",
    "import json\n",
-   "from IPython.display import Image"
+   "import os\n",
+   "from IPython.display import Image\n",
+   "from PIL import Image"
   ]
  },
@@ -358,6 +360,236 @@
    "ds_ct = ds_ct.build_index(\"analysis\")\n",
    "ds_ct.search(\"Presence of a Tumor\", top_k=1)"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tool usage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t = va.tools.GroundingDINO()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "14"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ans = t(prompt=\"shoes\", image=\"shoes.jpg\", box_threshold=0.30, iou_threshold=0.2)\n",
+    "len(ans[\"labels\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "dict_keys(['labels', 'scores', 'bboxes', 'size'])"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ans.keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[0.41,\n",
+       " 0.36,\n",
+       " 0.36,\n",
+       " 0.36,\n",
+       " 0.35,\n",
+       " 0.34,\n",
+       " 0.34,\n",
+       " 0.32,\n",
+       " 0.32,\n",
+       " 0.32,\n",
+       " 0.32,\n",
+       " 0.31,\n",
+       " 0.31,\n",
+       " 0.3]"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ans[\"scores\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t1 = va.tools.GroundingSAM()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ans = t1(prompt=\"bird\", image=\"birds.jpg\", box_threshold=0.40, iou_threshold=0.2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[[0.05, 0.03, 0.22, 0.33],\n",
+       " [0.71, 0.38, 0.94, 0.95],\n",
+       " [0.45, 0.26, 0.6, 0.56],\n",
+       " [0.2, 0.21, 0.37, 0.63]]"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ans[\"bboxes\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "dict_keys(['labels', 'bboxes', 'masks', 'scores'])"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ans.keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img = Image.fromarray(ans[\"masks\"][2] * 255)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/jpeg": "<base64 JPEG data omitted (rendered mask preview)>",
+      "image/png": "<base64 PNG data omitted (rendered mask preview)>",
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "display(img)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'labels': [], 'bboxes': [], 'masks': [], 'scores': []}"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "t1(prompt=\"bird\", image=\"shoes.jpg\",)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t2 = va.tools.CLIP()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'labels': ['crow', ' cheetah'], 'scores': [0.9999, 0.0001]}"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "t2(prompt=\"crow, cheetah\", image=\"birds.jpg\",)"
+   ]
+  }
  ],
  "metadata": {
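The library changes that follow route every tool through a single Lambda endpoint and select the model with a "tool" field in the JSON request body. A minimal sketch of that request contract, reconstructed from the tools.py diff below (the payload keys, endpoint URL, and status handling are taken from the diff; the inline base64 encoding is a simplified stand-in for the library's convert_to_b64 helper, which also accepts PIL images):

    import base64

    import requests

    ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"

    def call_tool(tool: str, prompt: str, image_path: str, **kwargs) -> dict:
        # Images travel as base64 strings in the request body.
        with open(image_path, "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode("utf-8")
        payload = {"prompt": prompt, "image": image_b64, "tool": tool}
        if kwargs:
            # Only the grounding tools send kwargs (box_threshold, iou_threshold).
            payload["kwargs"] = kwargs
        res = requests.post(
            ENDPOINT, headers={"Content-Type": "application/json"}, json=payload
        )
        resp_json = res.json()
        # Simplified version of the status check done in the diff below.
        if resp_json.get("statusCode") != 200:
            raise ValueError(f"Request failed: {resp_json}")
        return resp_json["data"]

    # The GroundingDINO call from the notebook above, against the raw endpoint:
    # call_tool("visual_grounding", "shoes", "shoes.jpg", box_threshold=0.30, iou_threshold=0.2)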
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 1aed604f..a072903c 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -368,7 +368,6 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
             continue
 
         for param, call_result in zip(parameters, tool_result["call_results"]):
-            # calls can fail, so we need to check if the call was successful
             if not isinstance(call_result, dict):
                 continue
 
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 3aa9a1b1..700555cc 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -78,32 +78,32 @@ class CLIP(Tool):
     -------
     >>> import vision_agent as va
     >>> clip = va.tools.CLIP()
-    >>> clip(["red line", "yellow dot"], "ct_scan1.jpg"))
+    >>> clip("red line, yellow dot", "ct_scan1.jpg")
     [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}]
     """
 
-    _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws"
+    _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "clip_"
     description = "'clip_' is a tool that can classify or tag any image given a set of input classes or tags."
     usage = {
         "required_parameters": [
-            {"name": "prompt", "type": "List[str]"},
+            {"name": "prompt", "type": "str"},
            {"name": "image", "type": "str"},
         ],
         "examples": [
             {
                 "scenario": "Can you classify this image as a cat? Image name: cat.jpg",
-                "parameters": {"prompt": ["cat"], "image": "cat.jpg"},
+                "parameters": {"prompt": "cat", "image": "cat.jpg"},
             },
             {
                 "scenario": "Can you tag this photograph with cat or dog? Image name: cat_dog.jpg",
-                "parameters": {"prompt": ["cat", "dog"], "image": "cat_dog.jpg"},
+                "parameters": {"prompt": "cat, dog", "image": "cat_dog.jpg"},
             },
             {
                 "scenario": "Can you build me a classifier that classifies red shirts, green shirts and other? Image name: shirts.jpg",
                 "parameters": {
-                    "prompt": ["red shirt", "green shirt", "other"],
+                    "prompt": "red shirt, green shirt, other",
                     "image": "shirts.jpg",
                 },
             },
@@ -111,11 +111,11 @@ class CLIP(Tool):
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
+    def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
         """Invoke the CLIP model.
 
         Parameters:
-            prompt: a list of classes or tags to classify the image.
+            prompt: a comma-separated string of classes or tags to classify the image with.
             image: the input image to classify.
 
         Returns:
@@ -123,8 +123,9 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         """
         image_b64 = convert_to_b64(image)
         data = {
-            "classes": prompt,
-            "images": [image_b64],
+            "prompt": prompt,
+            "image": image_b64,
+            "tool": "closed_set_image_classification",
         }
         res = requests.post(
             self._ENDPOINT,
@@ -138,10 +139,11 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
             _LOGGER.error(f"Request failed: {resp_json}")
             raise ValueError(f"Request failed: {resp_json}")
 
-        rets = []
-        for elt in resp_json["data"]:
-            rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
-        return cast(Dict, rets[0])
+        resp_json["data"]["scores"] = [
+            round(prob, 4) for prob in resp_json["data"]["scores"]
+        ]
+
+        return resp_json["data"]  # type: ignore
 
 
 class GroundingDINO(Tool):
@@ -158,7 +160,7 @@ class GroundingDINO(Tool):
     'scores': [0.98, 0.02]}]
     """
 
-    _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"
+    _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_dino_"
     description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
             {"name": "image", "type": "str"},
         ],
+        "optional_parameters": [
+            {"name": "box_threshold", "type": "float"},
+            {"name": "iou_threshold", "type": "float"},
+        ],
         "examples": [
             {
                 "scenario": "Can you build me a car detector?",
                 "parameters": {"prompt": "car", "image": ""},
             },
@@ -181,32 +187,44 @@ class GroundingDINO(Tool):
                 "parameters": {
                     "prompt": "red shirt. green shirt",
                     "image": "shirts.jpg",
+                    "box_threshold": 0.20,
+                    "iou_threshold": 0.75,
                 },
             },
         ],
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
+    def __call__(
+        self,
+        prompt: str,
+        image: Union[str, Path, ImageType],
+        box_threshold: float = 0.20,
+        iou_threshold: float = 0.75,
+    ) -> Dict:
         """Invoke the Grounding DINO model.
 
         Parameters:
             prompt: one or multiple class names to detect. The classes should be separated by a period if there are multiple classes. E.g. "big dog . small cat"
             image: the input image to run against.
+            box_threshold: the confidence threshold below which detected bounding boxes are filtered out.
+            iou_threshold: the intersection-over-union threshold used by non-maximum suppression; boxes that overlap a higher-scoring box by more than this value are suppressed.
 
         Returns:
             A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
-        data = {
+        request_data = {
             "prompt": prompt,
-            "images": [image_b64],
+            "image": image_b64,
+            "tool": "visual_grounding",
+            "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
         }
         res = requests.post(
             self._ENDPOINT,
             headers={"Content-Type": "application/json"},
-            json=data,
+            json=request_data,
         )
         resp_json: Dict[str, Any] = res.json()
         if (
@@ -214,16 +232,15 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
         ) or "statusCode" not in resp_json:
             _LOGGER.error(f"Request failed: {resp_json}")
             raise ValueError(f"Request failed: {resp_json}")
-        resp_data = resp_json["data"]
-        for elt in resp_data:
-            if "bboxes" in elt:
-                elt["bboxes"] = [
-                    normalize_bbox(box, image_size) for box in elt["bboxes"]
-                ]
-            if "scores" in elt:
-                elt["scores"] = [round(score, 2) for score in elt["scores"]]
-            elt["size"] = (image_size[1], image_size[0])
-        return cast(Dict, resp_data)
+        data: Dict[str, Any] = resp_json["data"]
+        if "bboxes" in data:
+            data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
+        if "scores" in data:
+            data["scores"] = [round(score, 2) for score in data["scores"]]
+        if "labels" in data:
+            data["labels"] = [label for label in data["labels"]]
+        data["size"] = (image_size[1], image_size[0])
+        return data
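The bboxes returned above are normalized to the [0, 1] range, which is what the notebook output earlier in this patch shows. A plausible sketch of the normalize_bbox helper used here, assuming image_size arrives as (height, width) and boxes as pixel-space [x1, y1, x2, y2]; the actual helper ships elsewhere in vision_agent and may differ:

    from typing import List, Tuple

    def normalize_bbox(bbox: List[float], image_size: Tuple[int, int]) -> List[float]:
        # Divide x-coordinates by width and y-coordinates by height, rounding
        # to two decimals to match the precision seen in the notebook output.
        height, width = image_size
        return [
            round(bbox[0] / width, 2),
            round(bbox[1] / height, 2),
            round(bbox[2] / width, 2),
            round(bbox[3] / height, 2),
        ]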
 
 
 class GroundingSAM(Tool):
@@ -234,7 +251,7 @@ class GroundingSAM(Tool):
     -------
     >>> import vision_agent as va
     >>> t = va.tools.GroundingSAM()
-    >>> t(["red line", "yellow dot"], ct_scan1.jpg"])
+    >>> t("red line, yellow dot", "ct_scan1.jpg")
     [{'labels': ['yellow dot', 'red line'],
     'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]],
     'masks': [array([[0, 0, 0, ..., 0, 0, 0],
     [0, 0, 0, ..., 0, 0, 0],
     ...,
     [0, 0, 0, ..., 0, 0, 0],
     [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
     array([[0, 0, 0, ..., 0, 0, 0],
     [0, 0, 0, ..., 0, 0, 0],
     ...,
     [1, 1, 1, ..., 1, 1, 1],
     [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
     """
 
@@ -249,55 +266,71 @@ class GroundingSAM(Tool):
-    _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws"
+    _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_sam_"
     description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
     usage = {
         "required_parameters": [
-            {"name": "prompt", "type": "List[str]"},
+            {"name": "prompt", "type": "str"},
             {"name": "image", "type": "str"},
         ],
+        "optional_parameters": [
+            {"name": "box_threshold", "type": "float"},
+            {"name": "iou_threshold", "type": "float"},
+        ],
         "examples": [
             {
                 "scenario": "Can you build me a car segmentor?",
-                "parameters": {"prompt": ["car"], "image": ""},
+                "parameters": {"prompt": "car", "image": ""},
             },
             {
                 "scenario": "Can you segment the person on the left? Image name: person.jpg",
-                "parameters": {"prompt": ["person on the left"], "image": "person.jpg"},
+                "parameters": {"prompt": "person on the left", "image": "person.jpg"},
             },
             {
                 "scenario": "Can you build me a tool that segments red shirts and green shirts? Image name: shirts.jpg",
                 "parameters": {
-                    "prompt": ["red shirt", "green shirt"],
+                    "prompt": "red shirt, green shirt",
                     "image": "shirts.jpg",
+                    "box_threshold": 0.20,
+                    "iou_threshold": 0.75,
                 },
             },
         ],
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
+    def __call__(
+        self,
+        prompt: str,
+        image: Union[str, ImageType],
+        box_threshold: float = 0.2,
+        iou_threshold: float = 0.75,
+    ) -> Dict:
         """Invoke the Grounding SAM model.
 
         Parameters:
             prompt: a list of classes to segment.
             image: the input image to segment.
+            box_threshold: the confidence threshold below which detected bounding boxes are filtered out.
+            iou_threshold: the intersection-over-union threshold used by non-maximum suppression; boxes that overlap a higher-scoring box by more than this value are suppressed.
 
         Returns:
             A list of dictionaries containing the labels, scores, bboxes and masks. Each dictionary contains the segmentation result for an image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
-        data = {
-            "classes": prompt,
+        request_data = {
+            "prompt": prompt,
             "image": image_b64,
+            "tool": "visual_grounding_segment",
+            "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
         }
         res = requests.post(
             self._ENDPOINT,
             headers={"Content-Type": "application/json"},
-            json=data,
+            json=request_data,
         )
         resp_json: Dict[str, Any] = res.json()
         if (
@@ -305,14 +338,19 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         ) or "statusCode" not in resp_json:
             _LOGGER.error(f"Request failed: {resp_json}")
             raise ValueError(f"Request failed: {resp_json}")
-        resp_data = resp_json["data"]
+        data: Dict[str, Any] = resp_json["data"]
         ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
-        for pred in resp_data["preds"]:
-            encoded_mask = pred["encoded_mask"]
-            mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"])
-            ret_pred["labels"].append(pred["label_name"])
-            ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
-            ret_pred["masks"].append(mask)
+        if "bboxes" in data:
+            ret_pred["bboxes"] = [
+                normalize_bbox(box, image_size) for box in data["bboxes"]
+            ]
+        if "masks" in data:
+            ret_pred["masks"] = [
+                rle_decode(mask_rle=mask, shape=data["mask_shape"])
+                for mask in data["masks"]
+            ]
+        ret_pred["labels"] = data["labels"]
+        ret_pred["scores"] = data["scores"]
         return ret_pred
 
 
@@ -321,8 +359,14 @@ class AgentGroundingSAM(GroundingSAM):
     returns the file name. This makes it easier for agents to use.
     """
 
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
-        rets = super().__call__(prompt, image)
+    def __call__(
+        self,
+        prompt: str,
+        image: Union[str, ImageType],
+        box_threshold: float = 0.2,
+        iou_threshold: float = 0.75,
+    ) -> Dict:
+        rets = super().__call__(prompt, image, box_threshold, iou_threshold)
         mask_files = []
         for mask in rets["masks"]:
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: