Skip to content

Commit

Permalink
Merge branch 'main' into add-new-tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird authored Aug 26, 2024
2 parents a35f1d7 + 7875feb commit a75b892
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 87 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.2.109"
version = "0.2.111"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -78,6 +78,8 @@ line_length = 88
profile = "black"

[tool.mypy]
plugins = "pydantic.mypy"

exclude = "tests"
show_error_context = true
pretty = true
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DefaultImports:
code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
]

@staticmethod
Expand Down
18 changes: 15 additions & 3 deletions vision_agent/clients/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,9 +37,22 @@ def post(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
response.raise_for_status()
result: Dict[str, Any] = response.json()
_LOGGER.info(json.dumps(result))
except (ConnectionError, Timeout, RequestException) as err:
_LOGGER.warning(f"Error: {err}.")
except json.JSONDecodeError:
resp_text = response.text
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.")
raise
return result

def get(self, url: str) -> Dict[str, Any]:
    """Issue a GET request against the base endpoint and return the JSON body.

    Parameters:
        url (str): Path appended to ``self._base_endpoint`` (joined with "/").

    Returns:
        Dict[str, Any]: The decoded JSON payload of the response.

    Raises:
        json.JSONDecodeError: If the response body cannot be decoded as JSON;
            the raw body is logged as a warning before re-raising.
        Any exception raised by the session call or by
        ``response.raise_for_status()`` propagates unchanged.
    """
    endpoint = f"{self._base_endpoint}/{url}"
    _LOGGER.info(f"Sending data to {endpoint}")

    # Transport and HTTP-status errors are not handled here; only a body that
    # fails JSON decoding gets the extra diagnostic log below.
    response = self._session.get(url=endpoint, timeout=self._TIMEOUT)
    response.raise_for_status()

    try:
        payload: Dict[str, Any] = response.json()
    except json.JSONDecodeError:
        _LOGGER.warning(f"Response seems incorrect: '{response.text}'.")
        raise

    _LOGGER.info(json.dumps(payload))
    return payload
14 changes: 13 additions & 1 deletion vision_agent/clients/landing_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from typing import List
from uuid import UUID

from requests.exceptions import HTTPError

from vision_agent.clients.http import BaseHTTP
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.utils.exceptions import FineTuneModelNotFound
from vision_agent.tools.tools_types import BboxInputBase64, PromptTask, JobStatus


class LandingPublicAPI(BaseHTTP):
Expand All @@ -24,3 +27,12 @@ def launch_fine_tuning_job(
}
response = self.post(url, payload=data)
return UUID(response["jobId"])

def check_fine_tuning_job(self, job_id: UUID) -> JobStatus:
    """Fetch the current status of a fine-tuning job.

    Parameters:
        job_id (UUID): Identifier of the fine-tuning job to look up.

    Returns:
        JobStatus: The status built from the ``"status"`` field of the response.

    Raises:
        FineTuneModelNotFound: If the service answers with HTTP 404.
        HTTPError: For any other HTTP error status; it is re-raised unchanged
            instead of being swallowed.
    """
    url = f"v1/agent/jobs/fine-tuning/{job_id}/status"
    try:
        get_job = self.get(url)
    except HTTPError as err:
        # `err.response` can be None on a requests exception, so guard the
        # attribute access before inspecting the status code.
        if err.response is not None and err.response.status_code == 404:
            raise FineTuneModelNotFound() from err
        # Previously a non-404 error fell through and hit an unbound `get_job`
        # (UnboundLocalError); surface the real HTTP error instead.
        raise
    return JobStatus(get_job["status"])
10 changes: 7 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Callable, List, Optional

from .meta_tools import META_TOOL_DOCSTRING, florencev2_fine_tuning
from .meta_tools import (
META_TOOL_DOCSTRING,
)
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import (
TOOL_DESCRIPTIONS,
TOOL_DOCSTRING,
TOOLS,
TOOLS_DF,
TOOLS_INFO,
UTILITIES_DOCSTRING,
blip_image_caption,
clip,
Expand Down Expand Up @@ -56,15 +59,16 @@ def register_tool(imports: Optional[List] = None) -> Callable:
def decorator(tool: Callable) -> Callable:
import inspect

from .tools import get_tool_descriptions, get_tools_df
from .tools import get_tool_descriptions, get_tools_df, get_tools_info

global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING
global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS_INFO

if tool not in TOOLS:
TOOLS.append(tool)
TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
TOOLS_INFO = get_tools_info(TOOLS) # type: ignore

globals()[tool.__name__] = tool
if imports is not None:
Expand Down
44 changes: 2 additions & 42 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.utils.image_utils import convert_to_b64


# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

CURRENT_FILE = None
Expand Down Expand Up @@ -384,51 +385,11 @@ def edit_file(file_path: str, start: int, end: int, content: str) -> str:

def get_tool_descriptions() -> str:
    """Returns a description of all the tools that `generate_vision_code` has access to.
    Helpful for answering questions about what types of vision tasks you can do with
    `generate_vision_code`."""
    return TOOL_DESCRIPTIONS


def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
    """'florencev2_fine_tuning' is a tool that fine-tunes florencev2 to detect
    objects in an image based on a given dataset. It returns the fine tuning
    job id.

    Parameters:
        bboxes (List[Dict[str, Any]]): A list of dictionaries, each validated
            as a BboxInput, containing the image path, labels and bounding boxes.
        task (str): The name of the florencev2 fine-tuning task (a PromptTask
            member). The options are CAPTION, CAPTION_TO_PHRASE_GROUNDING and
            OBJECT_DETECTION.

    Returns:
        UUID: The fine tuning job id; this id will be used to retrieve the
            fine tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florencev2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
            "OBJECT_DETECTION"
        )
    """
    # Validate every raw dictionary first so malformed input fails before any
    # image is read or encoded.
    validated = [BboxInput.model_validate(raw) for raw in bboxes]
    prompt_task = PromptTask[task]

    request_items = []
    for item in validated:
        request_items.append(
            BboxInputBase64(
                image=convert_to_b64(item.image_path),
                # Keep only the last "/"-separated path component as the name.
                filename=item.image_path.rsplit("/", 1)[-1],
                labels=item.labels,
                bboxes=item.bboxes,
            )
        )

    client = LandingPublicAPI()
    return client.launch_fine_tuning_job("florencev2", prompt_task, request_items)


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
Expand All @@ -442,6 +403,5 @@ def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
search_dir,
search_file,
find_file,
florencev2_fine_tuning,
]
)
30 changes: 0 additions & 30 deletions vision_agent/tools/meta_tools_types.py

This file was deleted.

33 changes: 27 additions & 6 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from vision_agent.utils.type_defs import LandingaiAPIKey

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
_LND_API_URL = "https://api.landing.ai/v1/agent/model"
_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
_LND_BASE_URL = os.environ.get("LANDINGAI_URL", "https://api.landing.ai")
_LND_API_URL = f"{_LND_BASE_URL}/v1/agent/model"
_LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"


class ToolCallTrace(BaseModel):
Expand All @@ -32,7 +33,10 @@ def send_inference_request(
endpoint_name: str,
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# TODO: runtime_tag and function_name should be in metadata_payload and not included
# in the service payload
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
Expand Down Expand Up @@ -69,9 +73,13 @@ def send_inference_request(
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
raise RemoteToolCallFailed(
payload["function_name"], res.status_code, res.text
)
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
Expand Down Expand Up @@ -149,3 +157,16 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data["doc"].append(doc)

return pd.DataFrame(data) # type: ignore


def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
    """Map each function's name to a "name(signature):\\ndocstring" summary.

    Parameters:
        funcs (List[Callable[..., Any]]): The callables to describe.

    Returns:
        Dict[str, str]: One entry per function name; a function without a
            docstring gets an empty description after the newline. If two
            functions share a name, the later one wins.
    """
    return {
        func.__name__: (
            f"{func.__name__}{inspect.signature(func)}:\n"
            f"{'' if func.__doc__ is None else func.__doc__}"
        )
        for func in funcs
    }
Loading

0 comments on commit a75b892

Please sign in to comment.