diff --git a/examples/va_example.ipynb b/examples/va_example.ipynb index d43173f3..2bddca05 100644 --- a/examples/va_example.ipynb +++ b/examples/va_example.ipynb @@ -10,7 +10,8 @@ "import pandas as pd\n", "import textwrap\n", "import json\n", - "from IPython.display import Image" + "from IPython.display import Image\n", + "from PIL import Image" ] }, { @@ -358,6 +359,243 @@ "ds_ct = ds_ct.build_index(\"analysis\")\n", "ds_ct.search(\"Presence of a Tumor\", top_k=1)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool usage" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "t = va.tools.GroundingDINO()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "14" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans = t(prompt=\"shoes\", image=\"/home/shankar/workspace/img/shoes.jpg\", box_threshold=0.30, iou_threshold=0.2)\n", + "len(ans[\"labels\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['labels', 'scores', 'bboxes', 'size'])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.41,\n", + " 0.36,\n", + " 0.36,\n", + " 0.36,\n", + " 0.35,\n", + " 0.34,\n", + " 0.34,\n", + " 0.32,\n", + " 0.32,\n", + " 0.32,\n", + " 0.32,\n", + " 0.31,\n", + " 0.31,\n", + " 0.3]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans[\"scores\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "t1 = va.tools.GroundingSAM()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "ans = t1(prompt=\"bird\", image=\"/home/shankar/workspace/img/birds.jpg\", box_threshold=0.40, iou_threshold=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[0.05, 0.03, 0.22, 0.33],\n", + " [0.71, 0.38, 0.94, 0.95],\n", + " [0.45, 0.26, 0.6, 0.56],\n", + " [0.2, 0.21, 0.37, 0.63]]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans[\"bboxes\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['labels', 'bboxes', 'masks', 'scores'])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "img = Image.fromarray(ans[\"masks\"][2] * 255)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAH0AlgBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+iiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivSPAvwZ8QeMohe3B/snTDgrcXMTF5QV3Bo043LyvzEgfNwTgivU4P2b/Cq28S3Gq6zJOEAkeOSJFZsckKUJAz2yceprL1n9muzfe+h+IJ4sRHbDfQiTfJzjMibdqngfdYjk89K8k8Z/DnxH4FdG1e2ja0lfZFeW774nbaDjOAVPXhgM7WxkDNcnRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRXonwg+H8PjvxLMdQ8z+ydPRZbkIwBlZj8kec5AbDEkdlIyCQR9d0UVHPBDdW8tvcRRzQSoUkjkUMrqRggg8EEcYr50+MPwe/sr7R4m8M23/Ev5kvbGNf+Pb1kQf8APP1X+HqPl+54fRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRViwsbjU9RtrCzj8y6upUhhTcBudiAoyeBkkda+2/CPhex8HeGrTR7COMCJAZpVTaZ5cDdI3JOSR0ycDAHAFblFFFFeD/ABT+CNvJZya34OsvKuI9z3OmxZIlBJJaIdmGfuDggAKARhvniiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivSPgZo39r/FCylZIJIdPikvJEmGc4GxSowfmDujDpjbnOQK+t6KKKKKK8P+MPwe/tX7R4m8M23/Ew5kvbGNf+Pn1kQf8APT1X+LqPm+/84UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV9P/AHwb/YvheTxDeQ7b7Vf9TvXDR24Py4yoI3nLcEhlEZr2Ciiiiiiivn/wCOnwx+/wCLtAsf7z6rFEfoRMEx/vbyD6Nj77V4BRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRX3npOmw6No1jpdu0jQWVvHbxtIQWKooUE4AGcD0FXKKKKKKKKjnghureW3uIo5oJUKSRyKGV1IwQQeCCOMV8WfEPww/hHxzqel+R5NqJTLZgbipgY5TDNy2B8pPPzKwycVy9FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFdZ8MdNm1X4m+HbeBo1dL1LglyQNsR81hwDztQge+OnWvtOiiiiiiiiivn/wDaU0T/AJAevxW/9+yuJ9//AAOJduf+uxyB9T0rwCiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiivXP2dYIZviNdPLFG7w6ZK8TMoJRvMjXK+h2swyOxI719R0UUUUUUUUV5n8edNhvvhXeXErSB7C4huIgpGCxcRYbjptkY8Y5A+h+TKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK9c/Z1nhh+I10kssaPNpkqRKzAF28yNsL6narHA7AntX1HRRRRRRRRRXB/GeNZfhJryvNHCAkTbnDEEiZCF+UE5JGB2yRkgZI+PKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK3PB2vt4W8Y6VrStIEtbhWl8tVZmiPyyKA3GShYduvUda+46KKKKKKKKK8/8Ajb/ySHXf+3f/ANKI6+QKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK+r/gX4u/4SPwMmm3DZvtG22z8fehIPlNwoA4BTGSf3eT96vUKKKKKKKKK8/8Ajb/ySHXf+3f/ANKI6+QKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK6TwH4rm8GeMbDWUMhgR9l1Gmf3kLcOMZAJx8wBONyqT0r7XgnhureK4t5Y5oJUDxyRsGV1IyCCOCCOc1JRRRRRRRXmfx51KGx+Fd5byrIXv7iG3iKgYDBxLluem2Nhxnkj6j5Moooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor3T4BfEFbG4Pg/VJ40t53MlhLLI3EpIzCM8ANyw6fNkclxj6Looooooor5I+NXjG48T+Obmw+5YaPLJa26FAG3ggSsTk5yycf7IXgHOfN6KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK+i/hN8aIbu3TQfF19HDcxJ/o2pXEgVZlA+7Kx4DgdGP3u/zcv7pRRRRRXlfxc+K1z4CuNP07Sbe0uNQuEM832pXKxxZKrgKVyWYN/FxsOR8wNfLl/fXGp6jc395J5l1dSvNM+0Dc7EljgcDJJ6VXoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor2D4b/ABxuPCunRaNr9tPqGmxcQTRMDNAgBwmGwHXO0DJG0Z5IAUfS9hfW+p6dbX9nJ5lrdRJNC+0jcjAFTg8jII61Yooor5o/aRsbiPxlpN+0eLWbT/JjfcPmdJHLDHXgSJ+fsa8XoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooorU0DxHq/hbVF1LRb6S0uwhTeoDBlPUMrAhh0OCDyAeoFe9+Ef2idOvNtt4rs/wCz5uf9MtFZ4T948py68BRxvyST8or2SHVtNudLOqQahaS6eEZzdpMrRBVzuO8HGBg5OeMGsP4d+JLzxd4E03XL+OCO6uvN3pApCDbK6DAJJ6KO9dRXk/7Qejf2h8Ol1FEg8zTbuOVpHHziN/3ZVTjuzRkjgHb6gV8sUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV6J8L/AB9aeGHv9C16KSfw3rCGK7Cu+YMqVLgA9Cpw2PmwFIOVAP03o2i+GvAenJYacsGmWt3dgIk1yx82dwAFUyMSWIQYUenSugqnq2mw6zo19pdw0iwXtvJbyNGQGCupUkZBGcH0NfCl/Y3Gmajc2F5H5d1ayvDMm4Ha6khhkcHBB6VXoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooor2D4QfE7+zdRTw/wCKr7ztEl8v7K94PMW0mQr5fzMfkjG0eoVgpG0bjX0/RXzB8dPh7/wj2sP4ns33WOq3bebFjHkTFQ3UsS28iVuAAuMeleP0UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUV658Jvi+vg23TQdZgkl0mS43pcq7FrQMPm+Tncm7DYXBGXPzE4r6H0Xxj4c8Q6XNqWl6xaT2kCF53L7DAo3cyK2Cg+VjlgMgZ6V4p8a/irpetaOfC/h66+1RyShr66QfuyEY4jUkfN8wVty8YUYLBjjweiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiv//Z", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAH0CAAAAADEOwwIAAAJeUlEQVR4Ae3d63ajNhSA0UxX3/+V00wuNhdJIHFIDWf3zxgJMN76Rk6na3Xe3vxDgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECCwU+DPzvOc9v4gYPagqL+AVLeZzDyr+h7kNtEpvQRUUlmMrbL6nEe3YJod0plxFA/KXX2cCq/o9TnIpm7zNVPN6mOaXlUPTZVmuytp1fH+qU+Z2SHw3trQdlx/21OE1V7a7W62z2i/w01nhdVc2D3V7Dmn+Sa3nBRWa1n3NbPvrNb73HBOWAGLqqw1orDWJo+R3cHsPvFx69u/EFbIEitrySispcjzuKeWnnOf73DjV8KqLm5fK31nV9/0NhP+5L22lP2lsJxY2rEmGAdf9qd48A1f+XJhVVZnpJKRaypvf/lhYUUuobIemsJ6UMxeDCYyeNnsre9xIKzYdVTWt6ewYsN6U9YXqLCKYR3I48ClxUe56KCwwhdOWX9JhRUelm9DYcVH9XlHe5Ydq5zW0TKOXl9+qkuN+io8ZbmUJaxTwvJzlrDOCSt9WcIqhBXyRRZyk8LDXWRIWKctVO6yhLUOK6qIqPusn/ACI8K6wCJd8RGFtVq1uI0m7k6rh3z5AWGduUSJyxLWMqzQGEJvtnzSlz4W1ksvz3UfTljnrl3aLUtY54aV9k/ghbUIK+0Ws3A4eiiso4KuLwoIq8gSOJh0CxRWYENu9RQQ1tPipFc5tyxhzXPKWcHcIORIWCGMbrIUENZSJP445S4orPiQ3PFDQFgyOEVAWDPWlN9aM4GoA2FFSTbukzFXYTWCMDUuIKxxO1c2BIQ1xcn4nTX9/IGvhRWI6VZPAWE9Lc57lXAnFNZ5OaW+s7BSL/95H15Y59mmvrOwUi//eR9eWOfZpr6zsFIv/3kfXljn2aa+s7BSL/95H15Y59mmvrOwfmX58/3Ru7B+Jax8f120sKZh5Vv/6acPfS2sUE43+xEQ1o+EX0MFhDXj9F044zhwIKwDeC6tCwirbmPmgICw5ni+C+cew0fCGqZzYUtAWC0dc8MCwlrQ+S5cgAweCmsQzmVtAWEtfWxZS5GhY2ENsbloS0BYKyFb1opkYEBYA2gu2RYQ1trIlrU26R4RVjeZC/YICKugZMsqoHQOCasEpqySSteYsLq4nLxXQFhFKVtWkaVjUFgdWE7dLyCsslXslhV7t/ITv9iosF5sQe7yOMKqrGTCTaYiMTYsrJpbYFmBt6o97cuNC6u6JBlzqGJ0Twirm6z7gpSFCqveSVAQQbepP+dLzgirsSw5k2iAdEwJqwNr6NSkdQqrVUvSKFoke+eE1ZRSVpOnMSmsBs7HlLLaPtVZYVVpviaUtQFUmRZWBeYxfLCsg5c/HuNqL7J+7p51Gv9fHifWtWNtJzacx/CF28/08mcIa8cSZQ5kB0/xFGEVWRaDY2WNXbV466seCmvXyv1JHckuosVJwlqA1A6VVZMpjwur7LIeVdbapDEirAbOfEpZc4/2kbDaPtNZZU01Nl4LawNoOu1H+KlG+7Ww2j6LWWktQKqHwqrSlCc6vg/H/1NQ+a0vNSqsSy3XdR624zfgdT7U6U+6cy/KjGvHGqlwZzE7+xt5gpe/RlhDS+SH+C02YW0JVeaXaS2PK5elGRbW8FJLqUUnrJbOxtxmWol/yBLWRjzt6c202pffeFZYBxf3O62d/5548M0udLmwDi/WT1M/v05vmPe7UFjTDsZel4oau9ONrvr3Rp/l//woebemirodqwLTM9zYstIGJ6yegmrnNsqqXXL3cWGFrLCylozCWoqMHVfLyvpdKKyxkMpX+fPSh4uwHhTHXnxuWVm3pwJddQsvnGuoLVDOKqmwHasdS89sOaFybj33veS5wgpctnJZgW9woVsJ6/TFek+5ZwkrMixb1kNTWA+KiBflsjJuWcKK6Ol5j3JZz/k0r0CEL3Vpf8rHbMcKDytfRCVCYZVUoscStias6Ij8NSmfogl/L8WXtL7j/OesjMh2rHUVASOzlGYHATe/xC1SfujfWZnvXSupcNKP/Wtp8f0dau9CgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIECBAgAABAgQIEJgJ/AeQClMrDRAyZwAAAABJRU5ErkJggg==", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(img)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "t2 = va.tools.CLIP()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'labels': ['crow', ' cheetah'], 'scores': [0.9999, 0.0001]}" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t2(prompt=\"crow, cheetah\", image=\"/home/shankar/workspace/img/birds.jpg\",)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'labels': [], 'bboxes': [], 'masks': [], 'scores': []}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t1(prompt=\"bird\", image=\"/home/shankar/workspace/img/shoes.jpg\",)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 51ae01ca..96aed1bd 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -63,7 +63,7 @@ class CLIP(Tool): [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] """ - _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" + _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" name = "clip_" description = "'clip_' is a tool that can classify or tag any image given a set of input classes or tags." @@ -106,6 +106,7 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: data = { "prompt": prompt, "image": image_b64, + "tool": "closed_set_image_classification", } res = requests.post( self._ENDPOINT, @@ -119,10 +120,11 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") - rets = [] - for elt in resp_json["data"]: - rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]}) - return cast(Dict, rets[0]) + resp_json["data"]["scores"] = [ + round(prob, 4) for prob in resp_json["data"]["scores"] + ] + + return resp_json["data"] class GroundingDINO(Tool): @@ -148,6 +150,10 @@ class GroundingDINO(Tool): {"name": "prompt", "type": "str"}, {"name": "image", "type": "str"}, ], + "optional_parameters": [ + {"name": "box_threshold", "type": "float"}, + {"name": "iou_threshold", "type": "float"}, + ], "examples": [ { "scenario": "Can you build me a car detector?", @@ -162,18 +168,28 @@ class GroundingDINO(Tool): "parameters": { "prompt": "red shirt. green shirt", "image": "shirts.jpg", + "box_threshold": 0.20, + "iou_threshold": 0.75, }, }, ], } # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: + def __call__( + self, + prompt: str, + image: Union[str, Path, ImageType], + box_threshold: float = 0.20, + iou_threshold: float = 0.75, + ) -> Dict: """Invoke the Grounding DINO model. Parameters: prompt: one or multiple class names to detect. The classes should be separated by a period if there are multiple classes. E.g. "big dog . small cat" image: the input image to run against. + box_threshold: the threshold to filter out the bounding boxes with low scores. + iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold. Returns: A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image. @@ -184,6 +200,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: "prompt": prompt, "image": image_b64, "tool": "visual_grounding", + "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } res = requests.post( self._ENDPOINT, @@ -198,13 +215,11 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: raise ValueError(f"Request failed: {resp_json}") data: Dict[str, Any] = resp_json["data"] if "bboxes" in data: - data["bboxes"] = [ - normalize_bbox(box, image_size) for box in data["bboxes"][0] - ] + data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]] if "scores" in data: - data["scores"] = [round(score, 2) for score in data["scores"][0]] + data["scores"] = [round(score, 2) for score in data["scores"]] if "labels" in data: - data["labels"] = [label for label in data["labels"][0]] + data["labels"] = [label for label in data["labels"]] data["size"] = (image_size[1], image_size[0]) return data @@ -241,6 +256,10 @@ class GroundingSAM(Tool): {"name": "prompt", "type": "str"}, {"name": "image", "type": "str"}, ], + "optional_parameters": [ + {"name": "box_threshold", "type": "float"}, + {"name": "iou_threshold", "type": "float"}, + ], "examples": [ { "scenario": "Can you build me a car segmentor?", @@ -255,18 +274,28 @@ class GroundingSAM(Tool): "parameters": { "prompt": "red shirt, green shirt", "image": "shirts.jpg", + "box_threshold": 0.20, + "iou_threshold": 0.75, }, }, ], } # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: + def __call__( + self, + prompt: str, + image: Union[str, ImageType], + box_threshold: float = 0.2, + iou_threshold: float = 0.75, + ) -> Dict: """Invoke the Grounding SAM model. Parameters: prompt: a list of classes to segment. image: the input image to segment. + box_threshold: the threshold to filter out the bounding boxes with low scores. + iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold. Returns: A list of dictionaries containing the labels, scores, bboxes and masks. Each dictionary contains the segmentation result for an image. @@ -277,6 +306,7 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: "prompt": prompt, "image": image_b64, "tool": "visual_grounding_segment", + "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } res = requests.post( self._ENDPOINT,