Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up v2 api #1657

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions vulnerabilities/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#


from django.db.models import Prefetch
from django_filters import rest_framework as filters
from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.utils import extend_schema
Expand Down Expand Up @@ -135,6 +136,13 @@ def get_queryset(self):
if aliases:
queryset = queryset.filter(aliases__alias__in=aliases).distinct()

# Prefetch related fields to reduce queries in serializers
queryset = queryset.prefetch_related(
"aliases",
"weaknesses",
"vulnerabilityreference_set",
"severities",
)
return queryset

def get_serializer_class(self):
Expand All @@ -146,7 +154,6 @@ def list(self, request, *args, **kwargs):
queryset = self.get_queryset()
vulnerability_ids = request.query_params.getlist("vulnerability_id")

# If exactly one vulnerability_id is provided, return the serialized data
if len(vulnerability_ids) == 1:
try:
vulnerability = queryset.get(vulnerability_id=vulnerability_ids[0])
Expand All @@ -155,17 +162,14 @@ def list(self, request, *args, **kwargs):
except Vulnerability.DoesNotExist:
return Response({"detail": "Not found."}, status=404)

# Otherwise, return a dictionary of vulnerabilities keyed by vulnerability_id
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
data = serializer.data
vulnerabilities = {item["vulnerability_id"]: item for item in data}
vulnerabilities = {item["vulnerability_id"]: item for item in serializer.data}
return self.get_paginated_response({"vulnerabilities": vulnerabilities})

serializer = self.get_serializer(queryset, many=True)
data = serializer.data
vulnerabilities = {item["vulnerability_id"]: item for item in data}
vulnerabilities = {item["vulnerability_id"]: item for item in serializer.data}
return Response({"vulnerabilities": vulnerabilities})


Expand All @@ -174,17 +178,17 @@ class PackageV2Serializer(serializers.ModelSerializer):
risk_score = serializers.FloatField(read_only=True)
affected_by_vulnerabilities = serializers.SerializerMethodField()
fixing_vulnerabilities = serializers.SerializerMethodField()
next_non_vulnerable_version = serializers.CharField(read_only=True)
latest_non_vulnerable_version = serializers.CharField(read_only=True)
next_non_vulnerable_package = serializers.CharField(read_only=True)
latest_non_vulnerable_package = serializers.CharField(read_only=True)

class Meta:
model = Package
fields = [
"purl",
"affected_by_vulnerabilities",
"fixing_vulnerabilities",
"next_non_vulnerable_version",
"latest_non_vulnerable_version",
"next_non_vulnerable_package",
"latest_non_vulnerable_package",
"risk_score",
]

Expand Down Expand Up @@ -245,36 +249,56 @@ def get_queryset(self):
queryset = queryset.filter(
fixing_vulnerabilities__vulnerability_id=fixing_vulnerability
)
return queryset.with_is_vulnerable()

queryset = queryset.prefetch_related(
Prefetch(
"affected_by_vulnerabilities",
queryset=Vulnerability.objects.prefetch_related(
"aliases",
"weaknesses",
"vulnerabilityreference_set",
"severities",
),
),
Prefetch(
"fixing_vulnerabilities",
queryset=Vulnerability.objects.prefetch_related(
"aliases",
"weaknesses",
"vulnerabilityreference_set",
"severities",
),
),
)
return queryset

def list(self, request, *args, **kwargs):
queryset = self.get_queryset()

# Apply pagination
page = self.paginate_queryset(queryset)
if page is not None:
# Collect only vulnerabilities for packages in the current page
# Collect vulnerabilities from prefetched data
vulnerabilities = set()
for package in page:
vulnerabilities.update(package.affected_by_vulnerabilities.all())
vulnerabilities.update(package.fixing_vulnerabilities.all())

# Serialize the vulnerabilities with vulnerability_id as keys
# Serialize vulnerabilities
vulnerability_data = {
vuln.vulnerability_id: VulnerabilityV2Serializer(vuln).data
for vuln in vulnerabilities
}

# Serialize the current page of packages
# Serialize packages
serializer = self.get_serializer(page, many=True)
data = serializer.data

# Use 'self.get_paginated_response' to include pagination data
return self.get_paginated_response(
{"vulnerabilities": vulnerability_data, "packages": data}
)

# If pagination is not applied, collect vulnerabilities for all packages
# If no pagination
vulnerabilities = set()
for package in queryset:
vulnerabilities.update(package.affected_by_vulnerabilities.all())
Expand All @@ -284,11 +308,11 @@ def list(self, request, *args, **kwargs):
vuln.vulnerability_id: VulnerabilityV2Serializer(vuln).data for vuln in vulnerabilities
}

# Serialize all packages when pagination is not applied
serializer = self.get_serializer(queryset, many=True)
data = serializer.data
return Response({"vulnerabilities": vulnerability_data, "packages": data})


@extend_schema(
request=PackageurlListSerializer,
responses={200: PackageV2Serializer(many=True)},
Expand Down
78 changes: 66 additions & 12 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.models import UserManager
from django.core import exceptions
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.core.paginator import Paginator
from django.core.validators import MaxValueValidator
Expand Down Expand Up @@ -471,18 +472,19 @@ def get_fixed_by_package_versions(self, purl: PackageURL, fix=True):
Return a queryset of all the package versions of this `package` that fix any vulnerability.
If `fix` is False, return all package versions whether or not they fix a vulnerability.
"""
filter_dict = {
"name": purl.name,
"namespace": purl.namespace,
# TODO: Move this to Package object method
filters = {
"type": purl.type,
"namespace": purl.namespace,
"name": purl.name,
"qualifiers": purl.qualifiers,
"subpath": purl.subpath,
}

if fix:
filter_dict["fixing_vulnerabilities__isnull"] = False
filters["fixing_vulnerabilities__isnull"] = False

return Package.objects.filter(**filter_dict).distinct()
return Package.objects.filter(**filters).distinct()

def get_or_create_from_purl(self, purl: Union[PackageURL, str]):
"""
Expand Down Expand Up @@ -648,7 +650,8 @@ class Package(PackageURLMixin):
fixing_vulnerabilities = models.ManyToManyField(
to="Vulnerability",
through="FixingPackageRelatedVulnerability",
related_name="fixed_by_packages", # Unique related_name
# Unique related_name
related_name="fixed_by_packages",
)

package_url = models.CharField(
Expand Down Expand Up @@ -779,6 +782,10 @@ def version_class(self):
def current_version(self):
return self.version_class(self.version)

@property
def vulnerabilities(self):
return self.affected_by_vulnerabilities.all() | self.fixing_vulnerabilities.all()

@property
def next_non_vulnerable_version(self):
"""
Expand All @@ -787,10 +794,6 @@ def next_non_vulnerable_version(self):
next_non_vulnerable, _ = self.get_non_vulnerable_versions()
return next_non_vulnerable.version if next_non_vulnerable else None

@property
def vulnerabilities(self):
return self.affected_by_vulnerabilities.all() | self.fixing_vulnerabilities.all()

@property
def latest_non_vulnerable_version(self):
"""
Expand Down Expand Up @@ -823,6 +826,58 @@ def get_non_vulnerable_versions(self):

return None, None

@property
def next_non_vulnerable_package(self):
"""
Return the purl of the next non-vulnerable package version.
"""
next_non_vulnerable, _ = self.get_non_vulnerable_versions_v2()
return next_non_vulnerable.purl if next_non_vulnerable else None

@property
def latest_non_vulnerable_package(self):
"""
Return the purl of the latest non-vulnerable package version.
"""
_, latest_non_vulnerable = self.get_non_vulnerable_versions_v2()
return latest_non_vulnerable.purl if latest_non_vulnerable else None

def get_non_vulnerable_versions_v2(self):
"""
Return a tuple of three Package instance:
- first fixing version
- next non-vulnerable version
- latest non-vulnerable version
Return a tuple of (None, None) if there is no non-vulnerable version.
"""
cache_key = f"non_vulnerable_versions_{self.id}"
result = cache.get(cache_key)
if result is not None:
return result

non_vulnerable_versions = Package.objects.get_fixed_by_package_versions(
self, fix=False
).only_non_vulnerable()
sorted_versions = self.sort_by_version(non_vulnerable_versions)

later_non_vulnerable_versions = [
non_vuln_ver
for non_vuln_ver in sorted_versions
if self.version_class(non_vuln_ver.version) > self.current_version
]

if later_non_vulnerable_versions:
sorted_versions = self.sort_by_version(later_non_vulnerable_versions)
next_non_vulnerable = sorted_versions[0]
latest_non_vulnerable = sorted_versions[-1]
cache.set(
cache_key, (next_non_vulnerable, latest_non_vulnerable), timeout=3600
)
return next_non_vulnerable, latest_non_vulnerable

cache.set(cache_key, (None, None), timeout=3600)
return None, None

@property
def fixed_package_details(self):
"""
Expand Down Expand Up @@ -928,15 +983,14 @@ class PackageRelatedVulnerabilityBase(models.Model):
package = models.ForeignKey(
Package,
on_delete=models.CASCADE,
# related_name="%(class)s_set", # Unique related_name per subclass
)

vulnerability = models.ForeignKey(
Vulnerability,
on_delete=models.CASCADE,
# related_name="%(class)s_set", # Unique related_name per subclass
)

# TODO: Fix the help text
created_by = models.CharField(
max_length=100,
blank=True,
Expand Down
4 changes: 4 additions & 0 deletions vulnerabilities/tests/test_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,7 @@ def test_lookup_with_invalid_purl_format(self):
self.assertEqual(response.status_code, status.HTTP_200_OK)
# No packages or vulnerabilities should be returned
self.assertEqual(len(response.data), 0)

def test_api_packages_single_with_purl_in_query_num_queries(self):
with self.assertNumQueries(13):
self.client.get(f"/api/v2/packages/?purl={self.package2.purl}", format="json")
Loading