Fix generate tools #156

Merged: 3 commits, Jun 27, 2024.

This PR fixes the `generate_*` tool helpers: the documentation's `va.llm.OpenAILMM()` example is corrected to `va.lmm.OpenAILMM()`, the generated detector is switched from `grounding_dino` to `owl_v2`, and the classifier/detector/segmentor LMM calls now request JSON mode (`response_format={"type": "json_object"}`).

Changes from all commits:

README.md (4 changes: 2 additions & 2 deletions)

@@ -145,8 +145,8 @@ you. For example:

 ```python
 >>> import vision_agent as va
->>> llm = va.llm.OpenAILMM()
->>> detector = llm.generate_detector("Can you build a jar detector for me?")
+>>> lmm = va.lmm.OpenAILMM()
+>>> detector = lmm.generate_detector("Can you build a jar detector for me?")
 >>> detector(va.tools.load_image("jar.jpg"))
 [{"labels": ["jar",],
 "scores": [0.99],

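The module-path fix lines up with the implementation file at the bottom of this diff: `OpenAILMM` is defined in `vision_agent/lmm/lmm.py`, so the example has to go through `va.lmm` rather than the old `va.llm` spelling. A minimal sketch of the corrected usage (hypothetical run, assuming an installed `vision-agent` package, an `OPENAI_API_KEY` in the environment, and a local `jar.jpg`):

```python
import vision_agent as va

# The multimodal model wrapper lives in the `lmm` module.
lmm = va.lmm.OpenAILMM()

# generate_detector asks the LMM to extract a text prompt from the request
# and returns a callable that runs the detection tool with that prompt.
detector = lmm.generate_detector("Can you build a jar detector for me?")
predictions = detector(va.tools.load_image("jar.jpg"))
print(predictions)  # e.g. [{"labels": ["jar"], "scores": [0.99], ...}]
```
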
docs/index.md (12 changes: 8 additions & 4 deletions)

@@ -33,7 +33,7 @@ export OPENAI_API_KEY="your-api-key"
 ```

 ### Important Note on API Usage
-Please be aware that using the API in this project requires you to have API credits (minimum of five US dollars). This is different from the OpenAI subscription used in this chatbot. If you don't have credit, further information can be found [here](https://github.com/moutasemalakkad/vision-agent/blob/f491252f3477103b7f517c45e6dea2f9d9f7abc4/docs/index.md#L207)
+Please be aware that using the API in this project requires you to have API credits (minimum of five US dollars). This is different from the OpenAI subscription used in this chatbot. If you don't have credit, further information can be found [here](https://github.com/landing-ai/vision-agent?tab=readme-ov-file#how-to-get-started-with-openai-api-credits)

 ### Vision Agent
 #### Basic Usage

@@ -137,8 +137,8 @@ you. For example:

 ```python
 >>> import vision_agent as va
->>> llm = va.llm.OpenAILMM()
->>> detector = llm.generate_detector("Can you build a jar detector for me?")
+>>> lmm = va.lmm.OpenAILMM()
+>>> detector = lmm.generate_detector("Can you build a jar detector for me?")
 >>> detector(va.tools.load_image("jar.jpg"))
 [{"labels": ["jar",],
 "scores": [0.99],

@@ -203,8 +203,12 @@ You can then run Vision Agent using the Azure OpenAI models:
 import vision_agent as va
 agent = va.agent.AzureVisionAgent()
 ```

 ******************************************************************************************************************************
-#### To get started with API credits:
+
+### Q&A
+
+#### How to get started with OpenAI API credits
+
 1. Visit the[OpenAI API platform](https://beta.openai.com/signup/) to sign up for an API key.
 2. Follow the instructions to purchase and manage your API credits.

tests/unit/test_lmm.py (10 changes: 5 additions & 5 deletions)

@@ -98,17 +98,17 @@ def test_generate_classifier(openai_lmm_mock):  # noqa: F811
     indirect=["openai_lmm_mock"],
 )
 def test_generate_detector(openai_lmm_mock):  # noqa: F811
-    with patch("vision_agent.tools.grounding_dino") as grounding_dino_mock:
-        grounding_dino_mock.return_value = "test"
-        grounding_dino_mock.__name__ = "grounding_dino"
-        grounding_dino_mock.__doc__ = "grounding_dino"
+    with patch("vision_agent.tools.owl_v2") as owl_v2_mock:
+        owl_v2_mock.return_value = "test"
+        owl_v2_mock.__name__ = "owl_v2"
+        owl_v2_mock.__doc__ = "owl_v2"

         lmm = OpenAILMM()
         prompt = "Can you generate a cat classifier?"
         detector = lmm.generate_detector(prompt)
         dummy_image = np.zeros((10, 10, 3)).astype(np.uint8)
         detector(dummy_image)
-        assert grounding_dino_mock.call_args[0][0] == "cat"
+        assert owl_v2_mock.call_args[0][0] == "cat"


 @pytest.mark.parametrize(

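Two details of this test are easy to miss: `patch` swaps the module-level tool for a `MagicMock`, whose `__name__` and `__doc__` the test must set by hand (presumably because the tool-documentation helper reads those attributes), and `call_args[0][0]` inspects the first positional argument of the most recent call. A self-contained illustration of that assertion pattern (the names below are illustrative, not from the repo):

```python
from unittest.mock import MagicMock

# call_args records the most recent call: index [0] is the tuple of
# positional arguments, so [0][0] is the first positional argument.
mock_tool = MagicMock(return_value="test")
mock_tool("cat", "image-bytes")
assert mock_tool.call_args[0][0] == "cat"

# A bare MagicMock has no function-style __name__, so code that builds
# documentation from a tool's name and docstring needs both stubbed:
mock_tool.__name__ = "owl_v2"
mock_tool.__doc__ = "owl_v2"
```
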
vision_agent/lmm/lmm.py (7 changes: 5 additions & 2 deletions)

@@ -164,6 +164,7 @@ def generate_classifier(self, question: str) -> Callable:
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
             ],
+            response_format={"type": "json_object"},
         )

         try:

@@ -179,14 +180,15 @@ def generate_classifier(self, question: str) -> Callable:
         return lambda x: T.clip(x, params["prompt"])

     def generate_detector(self, question: str) -> Callable:
-        api_doc = T.get_tool_documentation([T.grounding_dino])
+        api_doc = T.get_tool_documentation([T.owl_v2])
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
             ],
+            response_format={"type": "json_object"},
         )

         try:

@@ -199,7 +201,7 @@ def generate_detector(self, question: str) -> Callable:
             )
             raise ValueError("Failed to decode response")

-        return lambda x: T.grounding_dino(params["prompt"], x)
+        return lambda x: T.owl_v2(params["prompt"], x)

     def generate_segmentor(self, question: str) -> Callable:
         api_doc = T.get_tool_documentation([T.grounding_sam])

@@ -210,6 +212,7 @@ def generate_segmentor(self, question: str) -> Callable:
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
             ],
+            response_format={"type": "json_object"},
         )

         try:

Expand Down
Loading