Skip to content

feat: add AI module for LLM interaction and a heuristic for checking code–docstring consistency #1121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions src/macaron/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a client for interacting with a Large Language Model (LLM)."""

import json
import logging
import re
from typing import Any, TypeVar

from pydantic import BaseModel, ValidationError

from macaron.config.defaults import defaults
from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
from macaron.util import send_post_http_raw

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class AIClient:
    """A client for interacting with a Large Language Model (LLM).

    The configuration (enabled flag, API key, endpoint, model, and context
    window) is read from the ``[llm]`` section of the defaults configuration.
    """

    def __init__(self, system_prompt: str):
        """
        Initialize the AI client.

        Parameters
        ----------
        system_prompt: str
            The system prompt to send with every request. Falls back to a
            generic assistant prompt when blank.

        Raises
        ------
        ConfigurationError
            If the client is enabled but the API key, endpoint, or model is missing.
        """
        self.enabled, self.api_endpoint, self.api_key, self.model, self.context_window = self._load_defaults()
        self.system_prompt = system_prompt.strip() or "You are a helpful AI assistant."
        logger.info("AI client is %s.", "enabled" if self.enabled else "disabled")

    def _load_defaults(self) -> tuple[bool, str, str, str, int]:
        """Load the LLM configuration from the defaults.

        Returns
        -------
        tuple[bool, str, str, str, int]
            A tuple of (enabled, api_endpoint, api_key, model, context_window).

        Raises
        ------
        ConfigurationError
            If the client is enabled but the API key, endpoint, or model is missing.
        """
        section_name = "llm"
        enabled, api_key, api_endpoint, model, context_window = False, "", "", "", 10000

        if defaults.has_section(section_name):
            section = defaults[section_name]
            enabled = section.get("enabled", "False").strip().lower() == "true"
            api_key = section.get("api_key", "").strip()
            api_endpoint = section.get("api_endpoint", "").strip()
            model = section.get("model", "").strip()
            context_window = section.getint("context_window", 10000)

        if enabled:
            # Fail fast on an unusable configuration instead of failing on first use.
            if not api_key:
                raise ConfigurationError("API key for the AI client is not configured.")
            if not api_endpoint:
                raise ConfigurationError("API endpoint for the AI client is not configured.")
            if not model:
                raise ConfigurationError("Model for the AI client is not configured.")

        return enabled, api_endpoint, api_key, model, context_window

    def _validate_response(self, response_text: str, response_model: type[T]) -> T:
        """
        Validate and parse the response from the LLM.

        If raw JSON parsing fails, attempts to extract a JSON object from text.

        Parameters
        ----------
        response_text: str
            The response text from the LLM.
        response_model: type[T]
            The Pydantic model to validate the response against.

        Returns
        -------
        T
            The validated Pydantic model instance.

        Raises
        ------
        HeuristicAnalyzerValueError
            If there is an error in parsing or validating the response.
        """
        try:
            data = json.loads(response_text)
        except json.JSONDecodeError:
            logger.debug("Full JSON parse failed; trying to extract JSON from text.")
            # If the response is not a valid JSON, try to extract a JSON object from the text.
            match = re.search(r"\{.*\}", response_text, re.DOTALL)
            if not match:
                # `from None` suppresses the (already-logged) JSONDecodeError context.
                raise HeuristicAnalyzerValueError("No JSON object found in the LLM response.") from None
            try:
                data = json.loads(match.group(0))
            except json.JSONDecodeError as e:
                logger.error("Failed to parse extracted JSON: %s", e)
                raise HeuristicAnalyzerValueError("Invalid JSON extracted from response.") from e

        try:
            return response_model.model_validate(data)
        except ValidationError as e:
            logger.error("Validation failed against response model: %s", e)
            raise HeuristicAnalyzerValueError("Response JSON validation failed.") from e

    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        max_tokens: int = 4000,
        structured_output: type[T] | None = None,
        timeout: int = 30,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        max_tokens: int
            The maximum number of tokens for the LLM response.
        structured_output: type[T] | None
            The Pydantic model to validate the response against. If provided, the response will be parsed and validated.
        timeout: int
            The timeout for the HTTP request in seconds.

        Returns
        -------
        T | str
            The validated Pydantic model instance if `structured_output` is provided,
            or the raw string response if not.

        Raises
        ------
        ConfigurationError
            If the client is not enabled.
        HeuristicAnalyzerValueError
            If there is an error in getting, parsing, or validating the response.
        """
        if not self.enabled:
            raise ConfigurationError("AI client is not enabled. Please check your configuration.")

        # NOTE(review): the context window is enforced as a whitespace-separated
        # word count, which only approximates the model's token count — confirm
        # this margin is acceptable for the configured model.
        words = user_prompt.split()
        if len(words) > self.context_window:
            logger.warning(
                "User prompt exceeds context window (%s words). "
                "Truncating the prompt to fit within the context window.",
                self.context_window,
            )
            user_prompt = " ".join(words[: self.context_window])

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": self.model,
            "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        try:
            response = send_post_http_raw(url=self.api_endpoint, json_data=payload, headers=headers, timeout=timeout)
            if not response:
                raise HeuristicAnalyzerValueError("No response received from the LLM.")
            response_json = response.json()

            # Log token usage when the provider reports it.
            usage = response_json.get("usage", {})
            if usage:
                usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
                logger.info("LLM call token usage: %s", usage_str)

            message_content = response_json["choices"][0]["message"]["content"]

            if not structured_output:
                logger.debug("Returning raw message content (no structured output requested).")
                return message_content
            return self._validate_response(message_content, structured_output)

        except HeuristicAnalyzerValueError:
            # Already a well-formed analyzer error; re-raise without double-wrapping.
            raise
        except Exception as e:
            logger.error("Error during LLM invocation: %s", e)
            raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e
14 changes: 14 additions & 0 deletions src/macaron/config/defaults.ini
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,17 @@ custom_semgrep_rules_path =
# .yaml prefix. Note, this will be ignored if a path to custom semgrep rules is not provided. This list may not contain
# duplicated elements, meaning that ruleset names must be unique.
disabled_custom_rulesets =

[llm]
# The LLM configuration for Macaron.
# If enabled, the LLM will be used to analyze the results and provide insights.
enabled =
# The API key for the LLM service.
api_key =
# The API endpoint for the LLM service.
api_endpoint =
# The model to use for the LLM service.
model =
# The context window size for the LLM service.
# This is the maximum prompt size the client will send in a single request;
# longer prompts are truncated (the limit is approximated by word count, not exact tokens).
context_window = 10000
3 changes: 3 additions & 0 deletions src/macaron/malware_analyzer/pypi_heuristics/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class Heuristics(str, Enum):
#: Indicates that the package source code contains suspicious code patterns.
SUSPICIOUS_PATTERNS = "suspicious_patterns"

#: Heuristic that checks whether the package's code is consistent with its
#: docstrings; a failed result means a code/docstring mismatch was detected.
MATCHING_DOCSTRINGS = "matching_docstrings"


class HeuristicResult(str, Enum):
"""Result type indicating the outcome of a heuristic."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This analyzer checks the iconsistency of code with its docstrings."""

import logging
import time
from typing import Literal

from pydantic import BaseModel, Field

from macaron.ai import AIClient
from macaron.json_tools import JsonType
from macaron.malware_analyzer.pypi_heuristics.base_analyzer import BaseHeuristicAnalyzer
from macaron.malware_analyzer.pypi_heuristics.heuristics import HeuristicResult, Heuristics
from macaron.slsa_analyzer.package_registry.pypi_registry import PyPIPackageJsonAsset

logger: logging.Logger = logging.getLogger(__name__)


class Result(BaseModel):
    """The structured result of analyzing code against its docstrings."""

    # The verdict on whether the code matches its docstrings.
    decision: Literal["consistent", "inconsistent"] = Field(
        description="The final decision after analysing the code with its docstrings. "
        "It can be either 'consistent' or 'inconsistent'."
    )
    # A short, human-readable justification for the decision.
    reason: str = Field(
        description="The reason for the decision made. It should be a short sentence explaining the decision."
    )
    # The offending code snippet; None when the decision is 'consistent'.
    inconsistent_code_part: str | None = Field(
        default=None,
        description="The specific part of the code that is inconsistent with the docstring. "
        "None if the decision is 'consistent'.",
    )


class MatchingDocstringsAnalyzer(BaseHeuristicAnalyzer):
    """Check whether the docstrings and the code components are consistent."""

    # System prompt sent verbatim to the LLM for every analyzed file.
    SYSTEM_PROMPT = """
    You are a code master who can detect the inconsistency of the code with the docstrings that describes its components.
    You will be given a python code file. Your task is to determine whether the code is consistent with the docstrings.
    Wrap the output in `json` tags.
    Your response must be a JSON object matching this schema:
    {
        "decision": "'consistent' or 'inconsistent'",
        "reason": "A short explanation.", "inconsistent_code_part":
        "The inconsistent code, or null."
    }

    /no_think
    """

    # Seconds to wait between consecutive LLM requests.
    REQUEST_INTERVAL = 0.5

    def __init__(self) -> None:
        """Initialize the analyzer and its backing AI client."""
        super().__init__(
            name="matching_docstrings_analyzer",
            heuristic=Heuristics.MATCHING_DOCSTRINGS,
            depends_on=None,
        )
        self.client = AIClient(system_prompt=self.SYSTEM_PROMPT.strip())

    def analyze(self, pypi_package_json: PyPIPackageJsonAsset) -> tuple[HeuristicResult, dict[str, JsonType]]:
        """Analyze the package.

        Each ``.py`` source file is sent to the LLM; the first file judged
        inconsistent fails the heuristic.

        Parameters
        ----------
        pypi_package_json: PyPIPackageJsonAsset
            The PyPI package JSON asset object.

        Returns
        -------
        tuple[HeuristicResult, dict[str, JsonType]]:
            The result and related information collected during the analysis.
        """
        # Guard: without an enabled client or downloadable sources there is
        # nothing to analyze, so skip rather than fail.
        if not self.client.enabled:
            logger.warning("AI client is not enabled, skipping the matching docstrings analysis.")
            return HeuristicResult.SKIP, {}

        if not pypi_package_json.download_sourcecode():
            logger.warning("No source code found for the package, skipping the matching docstrings analysis.")
            return HeuristicResult.SKIP, {}

        for path, blob in pypi_package_json.iter_sourcecode():
            if not path.endswith(".py"):
                continue
            time.sleep(self.REQUEST_INTERVAL)  # Respect the request interval to avoid rate limiting.
            verdict = self.client.invoke(
                user_prompt=blob.decode("utf-8", "ignore"),
                structured_output=Result,
            )
            if not verdict or verdict.decision != "inconsistent":
                continue
            # Stop at the first inconsistent file and report its details.
            return HeuristicResult.FAIL, {
                "file": path,
                "reason": verdict.reason,
                "inconsistent part": verdict.inconsistent_code_part or "",
            }
        return HeuristicResult.PASS, {}
4 changes: 2 additions & 2 deletions src/macaron/slsa_analyzer/build_tool/gradle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022 - 2024, Oracle and/or its affiliates. All rights reserved.
# Copyright (c) 2022 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module contains the Gradle class which inherits BaseBuildTool.
Expand Down Expand Up @@ -122,7 +122,7 @@ def get_dep_analyzer(self) -> CycloneDxGradle:
raise DependencyAnalyzerError("No default dependency analyzer is found.")
if not DependencyAnalyzer.tool_valid(defaults.get("dependency.resolver", "dep_tool_gradle")):
raise DependencyAnalyzerError(
f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.",
f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.",
)

tool_name, tool_version = tuple(
Expand Down
4 changes: 2 additions & 2 deletions src/macaron/slsa_analyzer/build_tool/maven.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022 - 2024, Oracle and/or its affiliates. All rights reserved.
# Copyright (c) 2022 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module contains the Maven class which inherits BaseBuildTool.
Expand Down Expand Up @@ -116,7 +116,7 @@ def get_dep_analyzer(self) -> CycloneDxMaven:
raise DependencyAnalyzerError("No default dependency analyzer is found.")
if not DependencyAnalyzer.tool_valid(defaults.get("dependency.resolver", "dep_tool_maven")):
raise DependencyAnalyzerError(
f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_maven')} is not valid.",
f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_maven')} is not valid.",
)

tool_name, tool_version = tuple(
Expand Down
4 changes: 2 additions & 2 deletions src/macaron/slsa_analyzer/build_tool/pip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 - 2024, Oracle and/or its affiliates. All rights reserved.
# Copyright (c) 2023 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module contains the Pip class which inherits BaseBuildTool.
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_dep_analyzer(self) -> DependencyAnalyzer:
tool_name = "cyclonedx_py"
if not DependencyAnalyzer.tool_valid(f"{tool_name}:{cyclonedx_version}"):
raise DependencyAnalyzerError(
f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.",
f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.",
)
return CycloneDxPython(
resources_path=global_config.resources_path,
Expand Down
4 changes: 2 additions & 2 deletions src/macaron/slsa_analyzer/build_tool/poetry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 - 2024, Oracle and/or its affiliates. All rights reserved.
# Copyright (c) 2023 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module contains the Poetry class which inherits BaseBuildTool.
Expand Down Expand Up @@ -126,7 +126,7 @@ def get_dep_analyzer(self) -> DependencyAnalyzer:
tool_name = "cyclonedx_py"
if not DependencyAnalyzer.tool_valid(f"{tool_name}:{cyclonedx_version}"):
raise DependencyAnalyzerError(
f"Dependency analyzer {defaults.get('dependency.resolver','dep_tool_gradle')} is not valid.",
f"Dependency analyzer {defaults.get('dependency.resolver', 'dep_tool_gradle')} is not valid.",
)
return CycloneDxPython(
resources_path=global_config.resources_path,
Expand Down
Loading
Loading