From c0dde366ad5348f2332eae6ccee2732b88135ae7 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 23 Apr 2024 17:24:23 -0700
Subject: [PATCH] Add Custom Tools (#61)

* added custom tools
* updated readme
* register tool returns tool
* Add a new tool: determine if a bbox is contained within another bbox (#59)
* Add a new bounding box contains tool
* Fix format
* [skip ci] chore(release): vision-agent 0.1.5
* Add Count tools (#56)
* Adding counting tools to vision agent
* fixed heatmap overlay and addressed PR comments
* adding the counting tool to take both absolute coordinates and normalized coordinates, refactoring code, adding llm generate counter tool
* fix linting
* Remove torch and cuda dependencies (#60)

Resolve merge conflicts

* [skip ci] chore(release): vision-agent 0.2.1
* make it easier to use custom tools
* ran isort
* fix linting error
* added OCR
* added example template matching use case
* formatting and typing fix
* round scores
* fix readme typo

---------

Co-authored-by: Asia <92344512+AsiaCao@users.noreply.github.com>
Co-authored-by: GitHub Actions Bot
Co-authored-by: Shankar <90070882+shankar-landing-ai@users.noreply.github.com>
---
 README.md                                | 30 ++++++-
 examples/custom_tools/pid.png            | Bin 0 -> 9070 bytes
 examples/custom_tools/pid_template.png   | Bin 0 -> 4274 bytes
 examples/custom_tools/run_custom_tool.py | 49 ++++++++++++
 examples/custom_tools/template_match.py  | 96 +++++++++++++++++++++++
 tests/tools/test_tools.py                | 70 +++++++++++++++++
 vision_agent/agent/vision_agent.py       | 23 +++---
 vision_agent/tools/__init__.py           |  6 +-
 vision_agent/tools/tools.py              | 87 +++++++++++++++++++-
 9 files changed, 343 insertions(+), 18 deletions(-)
 create mode 100644 examples/custom_tools/pid.png
 create mode 100644 examples/custom_tools/pid_template.png
 create mode 100644 examples/custom_tools/run_custom_tool.py
 create mode 100644 examples/custom_tools/template_match.py

diff --git a/README.md b/README.md
index 48675938..835b8b99 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ pip install vision-agent
 ```
 
 Ensure you have an OpenAI API key and set it as an environment variable (if you are
-using Azure OpenAI please see the additional setup section):
+using Azure OpenAI please see the Azure Setup section):
 
 ```bash
 export OPENAI_API_KEY="your-api-key"
@@ -96,6 +96,31 @@ you. For example:
 }]
 ```
 
+#### Custom Tools
+You can also add your own custom tools for your vision agent to use:
+
+```python
+>>> from vision_agent.tools import Tool, register_tool
+>>> @register_tool
+>>> class NumItems(Tool):
+>>>     name = "num_items_"
+>>>     description = "Returns the number of items in a list."
+>>>     usage = {
+>>>         "required_parameters": [{"name": "prompt", "type": "list"}],
+>>>         "examples": [
+>>>             {
+>>>                 "scenario": "How many items are in this list? ['a', 'b', 'c']",
+>>>                 "parameters": {"prompt": "['a', 'b', 'c']"},
+>>>             }
+>>>         ],
+>>>     }
+>>>     def __call__(self, prompt: list[str]) -> int:
+>>>         return len(prompt)
+```
+This registers the tool with the list of tools Vision Agent has access to. The agent
+can then pick the tool based on its description and call it according to the usage
+examples you provide.
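+
+Once registered, the custom tool can be exercised end to end through the agent. A
+minimal sketch (this assumes the `NumItems` tool above has already been registered;
+the exact response text depends on the underlying model):
+
+```python
+>>> import vision_agent as va
+>>> agent = va.agent.VisionAgent()
+>>> agent("How many items are in this list? ['a', 'b', 'c']")
+```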
+
+#### Tool List
 | Tool | Description |
 | --- | --- |
 | CLIP | CLIP is a tool that can classify or tag any image given a set of input classes or tags. |
@@ -114,11 +139,12 @@ you. For example:
 | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
 | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
 | VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |
+| OCR | OCR returns the text detected in an image along with its location. |
 
 It also has a basic set of calculator tools such as add, subtract, multiply and divide.
 
-### Additional Setup
+### Azure Setup
 If you want to use Azure OpenAI models, you can set the environment variable:
 
 ```bash

diff --git a/examples/custom_tools/pid.png b/examples/custom_tools/pid.png
new file mode 100644
index 0000000000000000000000000000000000000000..713b731716a42104aaad5d17de26db26c053da67
GIT binary patch
[binary image data for pid.png omitted]
z)4j@bPV8eVT_8p8>qIIFB^;qrCecHTVyk^T#f0fZ)|iu~3FUP7O0#F=#xMh80?HP` zox&MuHy3JEZDa1xGtA7g8#(Obaro1S?W*aN!jX)`*0T&jRiYjzYRzfOHU^j$588~a za{eIeX@IZoowjr{(Rqe`L}pBYI2k9RcKiwcjd%rqtTqw3XJ z-pw%hCejVhd5(vwi{?0!e*cLE4H)7QNk5Rh%F6CM0pFeDt|y5FT=1Dv$jIC3LjzgiJGf?*g8pn1tV4RZ^yIPAH%BEXo)qi_@!X)W z1h4XEdloEa>!5SS|CB-I7#8!8#0}^OdRWc(mTDw+9%e1nNJ^(ox?_Bt;zUryW_62& zb#uk&fPB4N1)tx(R3Yre%JZYp^r+^VbWdEkhvg;6{6vrrL6K`d=qC&qC7qd+wF@c1SsWs5#euoVbV;V0CKE9dHvs z06?ZF0ygD9u|k}NCi#s8^Z$o7$v3*Ecnhub&BmOwv62hn48NvG>=d-h`bDnOV3ct8 zqTO7jyuIu%Y6pWg%0`}=Z08@?JhAlA>K2g!Z)0`q`))g?a=4HQHSEfSD;~c3+Ynio^?h3x&YdR>iztj;L<9inWA9 z4Ra2drcC;VG_ z4V`0R*$ukKEbZs{c0%PaIG2w?4H|Lt(U{j(MZXXT__&Q_-P*U+jP6!J9MzF)T&iWm z@hE3-SY#F0DZd}6l+|KhJ;fmSX{gFvugfPUhZDSg%2qF92O+&;dIV6F)a$`4>6=32 z>93tZXAY6n@axeS97Ru9oTdbvDQ2dS!^0~@)PdK!w1E!;`KGeW4Ezl zUq{_fixS?6kM4+LYi<%%^cddjCxY8HU(@XpxOX#)^uvqV=&vtRWG8sA-Ol$~z}_F@kD>6-&?r42^!@TE2UvE-kxRVw>d2?4mS(7Q zcr?t(&_Z%8V=4%ksf{8M($@wrh)v1+l&HTYrkV*z02KlR=aX z$ri^1p+lgqnPE}nq-fQ_L`6L&N75k((?7Ae=?P1ff!?W5VTI(_w8E%rTQ#l)qW|yH hy#Fz|W`?|>rjrp8w(Qm~{c{v6BcUi>C2Ab_zW~(r0Y3l$ literal 0 HcmV?d00001 diff --git a/examples/custom_tools/pid_template.png b/examples/custom_tools/pid_template.png new file mode 100644 index 0000000000000000000000000000000000000000..c736c6cb43dae2643c328e2ad5f691adbe7bfb8b GIT binary patch literal 4274 zcmZ`+cQ{;I*B>Q%iy9@2(GqQxVHifQSB)!5Mi*hCw-^(mL>nRq(Gw$x=v@qgo9Lo< zAqde+MDO3oz2Cj}dEWQk=Q(Guv)1po>i%bA4D~gsDcC3g006bNmYNZMXU4z6WcUw2 zO@T!izY#ks>nQ^O6>*ejw%71DVx+37p|h%{s=JH3*Y6&-_70j3ZVp~ZdrdD}e@$Cg zS6c~d2>`$>JkONWB5|C#dA1YqqCEdud4hvd2tWcP8N9(=7o`mhp|c|){H7AkRE*$| z5fqHRXTcN#yAxxw)?PyUumC$^Cn8$hltD~awE1}Zu<2~e?{XHtJnh{FUw-|VD0oZb zxf)1?0N`(jgMboVD%^%Y)vqV*Y`Nh?BxRu|9EI2@7P>r(6RU;C`&%U&J60pXzpeN*_g>bcaYbA zM}Mju5Z^3rL}y+EAxjO>?UP zM(d@cho06J2huDIt)~@kq;3~Ir8)GdNk&~(@$>Z|qWEYJ0N3JE)*rps>;}kJlt{v` zO5v3Rq>F!OL^9WXWG2A1mTPlIQY{wBJ6V*=zPUEIi<*(i^|`vd9k4f*{HPry$^p})&$^WN2A zm>7OI%9A&=sE`_^;7{axy_8;*fs?M32(}T#U(Ha@U9VF|8uUJL^O%lke~RjR*p#SfDD{_?&~$ad zhny}dOmQwUoD@;tf4s_j@#s`X7^JY{#0w8Hdo1m898_Z~W`87jp4ng8e_5%jQG7oc zg-cVDD5q`y#j2fTp(nxn3GcRsv>14b7xNV>;{yzSz?M6!Tp`QmO}2? 

diff --git a/examples/custom_tools/run_custom_tool.py b/examples/custom_tools/run_custom_tool.py
new file mode 100644
--- /dev/null
+++ b/examples/custom_tools/run_custom_tool.py
@@ -0,0 +1,49 @@
+import vision_agent as va
+from template_match import template_matching_with_rotation
+from vision_agent.image_utils import get_image_size, normalize_bbox
+from vision_agent.tools import Tool, register_tool
+
+
+@register_tool
+class TemplateMatch(Tool):
+    name = "template_match_"
+    description = "'template_match_' finds the locations of a template image in a target image."
+    usage = {
+        "required_parameters": [
+            {"name": "target_image", "type": "str"},
+            {"name": "template_image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you find the locations of the template in the target image? Image name: target.png Reference image: template.png",
+                "parameters": {
+                    "target_image": "target.png",
+                    "template_image": "template.png",
+                },
+            }
+        ],
+    }
+
+    def __call__(self, target_image: str, template_image: str) -> dict:
+        image_size = get_image_size(target_image)
+        matches = template_matching_with_rotation(target_image, template_image)
+        matches["bboxes"] = [
+            normalize_bbox(box, image_size) for box in matches["bboxes"]
+        ]
+        return matches
+
+
+if __name__ == "__main__":
+    agent = va.agent.VisionAgent(verbose=True)
+    resp, tools = agent.chat_with_workflow(
+        [
+            {
+                "role": "user",
+                "content": "Can you find the locations of the pid_template.png in pid.png and tell me if any are nearby 'NOTE 5'?",
+            }
+        ],
+        image="pid.png",
+        reference_data={"image": "pid_template.png"},
+        visualize_output=True,
+    )
diff --git a/examples/custom_tools/template_match.py b/examples/custom_tools/template_match.py
new file mode 100644
index 00000000..1dd9fbe0
--- /dev/null
+++ b/examples/custom_tools/template_match.py
@@ -0,0 +1,96 @@
+import cv2
+import numpy as np
+import torch
+from torchvision.ops import nms
+
+
+def rotate_image(mat, angle):
+    """
+    Rotates an image (angle in degrees) and expands the image to avoid cropping.
+    """
+
+    height, width = mat.shape[:2]  # image shape has 3 dimensions
+    image_center = (
+        width / 2,
+        height / 2,
+    )  # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
+
+    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+    # rotation calculates the cos and sin, taking absolutes of those.
+    abs_cos = abs(rotation_mat[0, 0])
+    abs_sin = abs(rotation_mat[0, 1])
+
+    # find the new width and height bounds
+    bound_w = int(height * abs_sin + width * abs_cos)
+    bound_h = int(height * abs_cos + width * abs_sin)
+
+    # subtract the old image center (bringing the image back to the origin) and add the new image center coordinates
+    rotation_mat[0, 2] += bound_w / 2 - image_center[0]
+    rotation_mat[1, 2] += bound_h / 2 - image_center[1]
+
+    # rotate image with the new bounds and translated rotation matrix
+    rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
+    return rotated_mat
+
+
+def template_matching_with_rotation(
+    main_image_path: str,
+    template_path: str,
+    max_rotation: int = 360,
+    step: int = 90,
+    threshold: float = 0.75,
+    visualize: bool = False,
+) -> dict:
+    """
+    Finds all locations where the template appears in the main image, re-running
+    the match with the template rotated in `step`-degree increments, and returns
+    pixel bounding boxes and scores for the surviving matches.
+    """
+    main_image = cv2.imread(main_image_path)
+    template = cv2.imread(template_path)
+    template_height, template_width = template.shape[:2]
+
+    # Convert images to grayscale
+    main_image_gray = cv2.cvtColor(main_image, cv2.COLOR_BGR2GRAY)
+    template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
+
+    boxes = []
+    scores = []
+
+    for angle in range(0, max_rotation, step):
+        # Rotate the template
+        rotated_template = rotate_image(template_gray, angle)
+
+        # Perform template matching
+        result = cv2.matchTemplate(
+            main_image_gray,
+            rotated_template,
+            cv2.TM_CCOEFF_NORMED,
+        )
+
+        y_coords, x_coords = np.where(result >= threshold)
+        for x, y in zip(x_coords, y_coords):
+            boxes.append(
+                (x, y, x + rotated_template.shape[1], y + rotated_template.shape[0])
+            )
+            scores.append(result[y, x])
+
+    if len(boxes) == 0:
+        # nothing matched above the threshold; nms requires a non-empty tensor
+        return {"bboxes": [], "scores": []}
+
+    # Drop overlapping detections with non-maximum suppression
+    indices = (
+        nms(
+            torch.tensor(boxes).float(),
+            torch.tensor(scores).float(),
+            0.2,
+        )
+        .numpy()
+        .tolist()
+    )
+    boxes = [boxes[i] for i in indices]
+    scores = [scores[i] for i in indices]
+
+    if visualize:
+        # Draw a rectangle around each match
+        for box in boxes:
+            cv2.rectangle(main_image, (box[0], box[1]), (box[2], box[3]), 255, 2)
+
+        # Display the result
+        cv2.imshow("Best Match", main_image)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+
+    return {"bboxes": boxes, "scores": scores}
diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py
index 12c21347..6de8d6c8 100644
--- a/tests/tools/test_tools.py
+++ b/tests/tools/test_tools.py
@@ -2,8 +2,10 @@
 import tempfile
 
 import numpy as np
+import pytest
 from PIL import Image
 
+from vision_agent.tools import TOOLS, Tool, register_tool
 from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU
 
 
@@ -65,3 +67,71 @@ def test_box_distance():
     box1 = [0, 0, 2, 2]
     box2 = [1, 1, 3, 3]
     assert box_dist(box1, box2) == 0.0
+
+
+def test_register_tool():
+    assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"
+
+    @register_tool
+    class TestTool(Tool):
+        name = "test_tool_"
+        description = "Test Tool"
+        usage = {
+            "required_parameters": [{"name": "prompt", "type": "str"}],
+            "examples": [
+                {
+                    "scenario": "Test",
+                    "parameters": {"prompt": "Test Prompt"},
+                }
+            ],
+        }
+
+        def __call__(self, prompt: str) -> str:
+            return prompt
+
+    assert TOOLS[len(TOOLS) - 1]["name"] == "test_tool_"
+
+
+def test_register_tool_incorrect():
+    with pytest.raises(ValueError):
+
+        @register_tool
+        class NoAttributes(Tool):
+            pass
+
+    with pytest.raises(ValueError):
+
+        @register_tool
+        class NoName(Tool):
+            description = "Test Tool"
+            usage = {
+                "required_parameters": [{"name": "prompt", "type": "str"}],
+                "examples": [
+                    {
+                        "scenario": "Test",
+                        "parameters": {"prompt": "Test Prompt"},
+                    }
+                ],
+            }
+
+    with pytest.raises(ValueError):
+
+        @register_tool
+        class NoDescription(Tool):
+            name = "test_tool_"
+            usage = {
+                "required_parameters": [{"name": "prompt", "type": "str"}],
+                "examples": [
+                    {
+                        "scenario": "Test",
+                        "parameters": {"prompt": "Test Prompt"},
+                    }
+                ],
+            }
+
+    with pytest.raises(ValueError):
+
+        @register_tool
+        class NoUsage(Tool):
+            name = "test_tool_"
+            description = "Test Tool"
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 6854ce43..93218e6c 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -377,6 +377,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
             "dinov_",
             "zero_shot_counting_",
             "visual_prompt_counting_",
+            "ocr_",
         ]:
             continue
 
@@ -523,20 +524,20 @@ def chat_with_workflow(
         if image:
             question += f" Image name: {image}"
         if reference_data:
-            if not (
-                "image" in reference_data
-                and ("mask" in reference_data or "bbox" in reference_data)
-            ):
-                raise ValueError(
-                    f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
-                )
-            visual_prompt_data = (
-                f"Reference mask: {reference_data['mask']}"
+            question += (
+                f" Reference image: {reference_data['image']}"
+                if "image" in reference_data
+                else ""
+            )
+            question += (
+                f" Reference mask: {reference_data['mask']}"
                 if "mask" in reference_data
-                else f"Reference bbox: {reference_data['bbox']}"
+                else ""
             )
             question += (
-                f" Reference image: {reference_data['image']}, {visual_prompt_data}"
+                f" Reference bbox: {reference_data['bbox']}"
+                if "bbox" in reference_data
+                else ""
             )
 
         reflections = ""
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 38bb08d4..67248156 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -1,6 +1,7 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
 from .tools import (  # Counter,
     CLIP,
+    OCR,
     TOOLS,
     BboxArea,
     BboxIoU,
@@ -11,9 +12,10 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
-    ZeroShotCounting,
-    VisualPromptCounting,
     SegArea,
     SegIoU,
     Tool,
+    VisualPromptCounting,
+    ZeroShotCounting,
+    register_tool,
 )
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 6528c795..a661aeb0 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -1,8 +1,9 @@
+import io
 import logging
 import tempfile
 from abc import ABC
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union, cast
+from typing import Any, Dict, List, Tuple, Type, Union, cast
 
 import numpy as np
 import requests
@@ -11,10 +12,10 @@
 
 from vision_agent.image_utils import (
     convert_to_b64,
+    denormalize_bbox,
     get_image_size,
-    rle_decode,
     normalize_bbox,
-    denormalize_bbox,
+    rle_decode,
 )
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
@@ -29,6 +30,9 @@ class Tool(ABC):
     description: str
     usage: Dict
 
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
 
 class NoOp(Tool):
     name = "noop_"
@@ -858,6 +862,57 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
         return result
 
 
+class OCR(Tool):
+    name = "ocr_"
+    description = "'ocr_' extracts text from an image."
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you extract the text from this image? Image name: image.png",
+                "parameters": {"image": "image.png"},
+            },
+        ],
+    }
+    _API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV"
+    _URL = "https://app.landing.ai/ocr/v1/detect-text"
+
+    def __call__(self, image: str) -> dict:
+        pil_image = Image.open(image).convert("RGB")
+        image_size = pil_image.size[::-1]
+        image_buffer = io.BytesIO()
+        pil_image.save(image_buffer, format="PNG")
+        buffer_bytes = image_buffer.getvalue()
+        image_buffer.close()
+
+        res = requests.post(
+            self._URL,
+            files={"images": buffer_bytes},
+            data={"language": "en"},
+            headers={"contentType": "multipart/form-data", "apikey": self._API_KEY},
+        )
+        if res.status_code != 200:
+            _LOGGER.error(f"Request failed: {res.text}")
+            raise ValueError(f"Request failed: {res.text}")
+
+        data = res.json()
+        output: Dict[str, List] = {"labels": [], "bboxes": [], "scores": []}
+        for det in data[0]:
+            output["labels"].append(det["text"])
+            box = [
+                det["location"][0]["x"],
+                det["location"][0]["y"],
+                det["location"][2]["x"],
+                det["location"][2]["y"],
+            ]
+            box = normalize_bbox(box, image_size)
+            output["bboxes"].append(box)
+            output["scores"].append(round(det["score"], 2))
+        return output
+
+
 class Calculator(Tool):
     r"""Calculator is a tool that can perform basic arithmetic operations."""
 
@@ -903,6 +958,7 @@ def __call__(self, equation: str) -> float:
             SegIoU,
             BboxContains,
             BoxDistance,
+            OCR,
             Calculator,
         ]
     )
 }
+
+def register_tool(tool: Type[Tool]) -> Type[Tool]:
+    r"""Add a tool to the list of available tools.
+
+    Parameters:
+        tool: The tool to add.
+    """
+
+    if (
+        not hasattr(tool, "name")
+        or not hasattr(tool, "description")
+        or not hasattr(tool, "usage")
+    ):
+        raise ValueError(
+            "The tool must have 'name', 'description' and 'usage' attributes."
+        )
+
+    TOOLS[len(TOOLS)] = {
+        "name": tool.name,
+        "description": tool.description,
+        "usage": tool.usage,
+        "class": tool,
+    }
+    return tool
+
+
 def _send_inference_request(
     payload: Dict[str, Any], endpoint_name: str
 ) -> Dict[str, Any]: