Commit cfea046
added ollama vision agent coder
dillonalaird committed Aug 26, 2024
1 parent 456fb12 commit cfea046
Showing 2 changed files with 58 additions and 7 deletions.
6 changes: 5 additions & 1 deletion vision_agent/agent/__init__.py
@@ -1,3 +1,7 @@
 from .agent import Agent
 from .vision_agent import VisionAgent
-from .vision_agent_coder import AzureVisionAgentCoder, VisionAgentCoder
+from .vision_agent_coder import (
+    AzureVisionAgentCoder,
+    OllamaVisionAgentCoder,
+    VisionAgentCoder,
+)
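
With the widened re-export, the new coder is importable directly from vision_agent.agent. A minimal usage sketch, assuming a local Ollama server is running and the models named in the class docstring below have already been pulled:

    # Minimal sketch, not part of the diff. Assumes `ollama serve` is running
    # locally and that `ollama pull llava` and `ollama pull mxbai-embed-large`
    # have completed.
    from vision_agent.agent import OllamaVisionAgentCoder

    agent = OllamaVisionAgentCoder(verbosity=1)
    code = agent(
        "What percentage of the area of the jar is filled with coffee beans?",
        media="jar.jpg",
    )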
59 changes: 53 additions & 6 deletions vision_agent/agent/vision_agent_coder.py
@@ -27,11 +27,11 @@
     TEST_PLANS,
     USER_REQ,
 )
-from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
+from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
 from vision_agent.utils import CodeInterpreterFactory, Execution
 from vision_agent.utils.execute import CodeInterpreter
 from vision_agent.utils.image_utils import b64_to_pil
-from vision_agent.utils.sim import AzureSim, Sim
+from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
 from vision_agent.utils.video import play_video

 logging.basicConfig(stream=sys.stdout)
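
The new OllamaSim import is what lets the tool recommender run locally: like Sim and AzureSim, it ranks rows of the tool-description dataframe by embedding similarity, here via a local embedding model. A hedged sketch of standalone use — the constructor call is taken from this diff, but the top_k lookup is an assumed signature, inferred from how VisionAgentCoder consumes its tool recommender:

    # Sketch only: OllamaSim(T.TOOLS_DF, sim_key="desc") appears in the diff;
    # the top_k(query, k) call is an assumption, and the embedding model must
    # first be fetched with `ollama pull mxbai-embed-large`.
    import vision_agent.tools as T
    from vision_agent.utils.sim import OllamaSim

    sim = OllamaSim(T.TOOLS_DF, sim_key="desc")
    matches = sim.top_k("detect objects in an image", k=3)  # assumed signature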
@@ -572,8 +572,8 @@ class VisionAgentCoder(Agent):
     Example
     -------
-        >>> from vision_agent.agent import VisionAgentCoder
-        >>> agent = VisionAgentCoder()
+        >>> import vision_agent as va
+        >>> agent = va.agent.VisionAgentCoder()
         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
     """

@@ -841,6 +841,53 @@ def log_progress(self, data: Dict[str, Any]) -> None:
             self.report_progress_callback(data)
 
 
+class OllamaVisionAgentCoder(VisionAgentCoder):
+    """VisionAgentCoder that uses Ollama models for planning, coding, testing.
+
+    Pre-requisites:
+    1. Run ollama pull llava for the LMM (or any other LMM model that can consume images)
+    2. Run ollama pull mxbai-embed-large for the embedding similarity model
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> agent = va.agent.OllamaVisionAgentCoder()
+        >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
+    """
+
+    def __init__(
+        self,
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
+        tool_recommender: Optional[Sim] = None,
+        verbosity: int = 0,
+        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+    ) -> None:
+        super().__init__(
+            planner=(
+                OllamaLMM(temperature=0.0, json_mode=True)
+                if planner is None
+                else planner
+            ),
+            coder=OllamaLMM(temperature=0.0) if coder is None else coder,
+            tester=OllamaLMM(temperature=0.0) if tester is None else tester,
+            debugger=(
+                OllamaLMM(temperature=0.0, json_mode=True)
+                if debugger is None
+                else debugger
+            ),
+            tool_recommender=(
+                OllamaSim(T.TOOLS_DF, sim_key="desc")
+                if tool_recommender is None
+                else tool_recommender
+            ),
+            verbosity=verbosity,
+            report_progress_callback=report_progress_callback,
+        )
+
+
 class AzureVisionAgentCoder(VisionAgentCoder):
     """VisionAgentCoder that uses Azure OpenAI APIs for planning, coding, testing.
@@ -850,8 +897,8 @@ class AzureVisionAgentCoder(VisionAgentCoder):
     Example
     -------
-        >>> from vision_agent import AzureVisionAgentCoder
-        >>> agent = AzureVisionAgentCoder()
+        >>> import vision_agent as va
+        >>> agent = va.agent.AzureVisionAgentCoder()
         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
     """

