Skip to content

Commit

Permalink
Fix generate tools (#156)
Browse files Browse the repository at this point in the history
* fix docs

* fixed generate det

* fixed test case
  • Loading branch information
dillonalaird authored Jun 27, 2024
1 parent 214e3eb commit 600667d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ you. For example:

```python
>>> import vision_agent as va
>>> llm = va.llm.OpenAILMM()
>>> detector = llm.generate_detector("Can you build a jar detector for me?")
>>> lmm = va.lmm.OpenAILMM()
>>> detector = lmm.generate_detector("Can you build a jar detector for me?")
>>> detector(va.tools.load_image("jar.jpg"))
[{"labels": ["jar",],
"scores": [0.99],
Expand Down
12 changes: 8 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export OPENAI_API_KEY="your-api-key"
```

### Important Note on API Usage
Please be aware that using the API in this project requires you to have API credits (minimum of five US dollars). This is different from the OpenAI subscription used in this chatbot. If you don't have credit, further information can be found [here](https://github.com/moutasemalakkad/vision-agent/blob/f491252f3477103b7f517c45e6dea2f9d9f7abc4/docs/index.md#L207)
Please be aware that using the API in this project requires you to have API credits (minimum of five US dollars). This is different from the OpenAI subscription used in this chatbot. If you don't have credit, further information can be found [here](https://github.com/landing-ai/vision-agent?tab=readme-ov-file#how-to-get-started-with-openai-api-credits)

### Vision Agent
#### Basic Usage
Expand Down Expand Up @@ -137,8 +137,8 @@ you. For example:

```python
>>> import vision_agent as va
>>> llm = va.llm.OpenAILMM()
>>> detector = llm.generate_detector("Can you build a jar detector for me?")
>>> lmm = va.lmm.OpenAILMM()
>>> detector = lmm.generate_detector("Can you build a jar detector for me?")
>>> detector(va.tools.load_image("jar.jpg"))
[{"labels": ["jar",],
"scores": [0.99],
Expand Down Expand Up @@ -203,8 +203,12 @@ You can then run Vision Agent using the Azure OpenAI models:
import vision_agent as va
agent = va.agent.AzureVisionAgent()
```

******************************************************************************************************************************
#### To get started with API credits:

### Q&A

#### How to get started with OpenAI API credits

1. Visit the [OpenAI API platform](https://beta.openai.com/signup/) to sign up for an API key.
2. Follow the instructions to purchase and manage your API credits.
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ def test_generate_classifier(openai_lmm_mock): # noqa: F811
indirect=["openai_lmm_mock"],
)
def test_generate_detector(openai_lmm_mock): # noqa: F811
with patch("vision_agent.tools.grounding_dino") as grounding_dino_mock:
grounding_dino_mock.return_value = "test"
grounding_dino_mock.__name__ = "grounding_dino"
grounding_dino_mock.__doc__ = "grounding_dino"
with patch("vision_agent.tools.owl_v2") as owl_v2_mock:
owl_v2_mock.return_value = "test"
owl_v2_mock.__name__ = "owl_v2"
owl_v2_mock.__doc__ = "owl_v2"

lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
detector = lmm.generate_detector(prompt)
dummy_image = np.zeros((10, 10, 3)).astype(np.uint8)
detector(dummy_image)
assert grounding_dino_mock.call_args[0][0] == "cat"
assert owl_v2_mock.call_args[0][0] == "cat"


@pytest.mark.parametrize(
Expand Down
7 changes: 5 additions & 2 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def generate_classifier(self, question: str) -> Callable:
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
)

try:
Expand All @@ -179,14 +180,15 @@ def generate_classifier(self, question: str) -> Callable:
return lambda x: T.clip(x, params["prompt"])

def generate_detector(self, question: str) -> Callable:
api_doc = T.get_tool_documentation([T.grounding_dino])
api_doc = T.get_tool_documentation([T.owl_v2])
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
)

try:
Expand All @@ -199,7 +201,7 @@ def generate_detector(self, question: str) -> Callable:
)
raise ValueError("Failed to decode response")

return lambda x: T.grounding_dino(params["prompt"], x)
return lambda x: T.owl_v2(params["prompt"], x)

def generate_segmentor(self, question: str) -> Callable:
api_doc = T.get_tool_documentation([T.grounding_sam])
Expand All @@ -210,6 +212,7 @@ def generate_segmentor(self, question: str) -> Callable:
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
)

try:
Expand Down

0 comments on commit 600667d

Please sign in to comment.