Skip to content

Commit

Permalink
Merge pull request #674 from pennlabs/pcx-docs-fix
Browse files Browse the repository at this point in the history
Fix docs override AutoSchema
  • Loading branch information
shiva-menta committed Sep 13, 2024
2 parents 67c4f50 + a31607d commit 00501b3
Showing 1 changed file with 128 additions and 0 deletions.
128 changes: 128 additions & 0 deletions backend/PennCourses/docs_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from textwrap import dedent

import jsonref
from django.db import models
from django.urls import get_resolver
from rest_framework import serializers
from rest_framework.fields import _UnvalidatedField
from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONOpenAPIRenderer
from rest_framework.schemas.openapi import AutoSchema
from rest_framework.schemas.utils import is_list_view
from rest_framework.settings import api_settings


"""
Expand Down Expand Up @@ -992,6 +995,7 @@ def map_serializer(self, serializer):
"""

result = super().map_serializer(serializer)

properties = result["properties"]
model = None
if hasattr(serializer, "Meta") and hasattr(serializer.Meta, "model"):
Expand All @@ -1012,6 +1016,130 @@ def map_serializer(self, serializer):

return result

# Overrides, uses overridden method
# (https://www.django-rest-framework.org/api-guide/schemas/#map_field)
def map_field(self, field):

# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {"type": "array", "items": []}
if isinstance(field, serializers.Serializer):
data = self.map_serializer(field)
data["type"] = "object"
return data

# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {"type": "array", "items": self.map_field(field.child_relation)}
if isinstance(field, serializers.PrimaryKeyRelatedField):
if getattr(field, "pk_field", False):
return self.map_field(field=field.pk_field)
model = getattr(field.queryset, "model", None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {"type": "integer"}

# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {"type": "array", "items": self.map_choicefield(field)}

if isinstance(field, serializers.ChoiceField):
return self.map_choicefield(field)

# ListField.
if isinstance(field, serializers.ListField):
mapping = {
"type": "array",
"items": {},
}
if not isinstance(field.child, _UnvalidatedField):
mapping["items"] = self.map_field(field.child)
return mapping

# DateField and DateTimeField type is string
if isinstance(field, serializers.DateField):
return {
"type": "string",
"format": "date",
}

if isinstance(field, serializers.DateTimeField):
return {
"type": "string",
"format": "date-time",
}

# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this
# specification."
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
if isinstance(field, serializers.EmailField):
return {"type": "string", "format": "email"}

if isinstance(field, serializers.URLField):
return {"type": "string", "format": "uri"}

if isinstance(field, serializers.UUIDField):
return {"type": "string", "format": "uuid"}

if isinstance(field, serializers.IPAddressField):
content = {
"type": "string",
}
if field.protocol != "both":
content["format"] = field.protocol
return content

if isinstance(field, serializers.DecimalField):
if getattr(field, "coerce_to_string", api_settings.COERCE_DECIMAL_TO_STRING):
content = {
"type": "string",
"format": "decimal",
}
else:
content = {"type": "number"}

if field.decimal_places:
content["multipleOf"] = float("." + (field.decimal_places - 1) * "0" + "1")
if field.max_whole_digits:
content["maximum"] = int(field.max_whole_digits * "9") + 1
content["minimum"] = -content["maximum"]
self._map_min_max(field, content)
return content

if isinstance(field, serializers.FloatField):
content = {
"type": "number",
}
self._map_min_max(field, content)
return content

if isinstance(field, serializers.IntegerField):
content = {"type": "integer"}
self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get("maximum", 0)) > 2147483647:
content["format"] = "int64"
if int(content.get("minimum", 0)) > 2147483647:
content["format"] = "int64"
return content

if isinstance(field, serializers.FileField):
return {"type": "string", "format": "binary"}

# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: "boolean",
serializers.JSONField: "object",
serializers.DictField: "object",
serializers.HStoreField: "object",
}
return {"type": FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, "string")}

# Helper method
def get_action(self, path, method):
"""
Expand Down

0 comments on commit 00501b3

Please sign in to comment.