Skip to content
This repository has been archived by the owner on Jan 11, 2021. It is now read-only.

Stable/0.3.x: Fix $ref usage and remove unreferenced models from output #688

Open
wants to merge 2 commits into
base: stable/0.3.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 26 additions & 4 deletions rest_framework_swagger/docgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class DocumentationGenerator(object):
# Response classes defined in docstrings
explicit_response_types = dict()

# Serializers referenced with $ref
ref_serializers = set()

def __init__(self, for_user=None):

# unauthenticated user is expected to be in the form 'module.submodule.Class' if a value is present
Expand Down Expand Up @@ -90,6 +93,9 @@ def get_operations(self, api, apis=None):
response_type = self._get_method_response_type(
doc_parser, serializer, introspector, method_introspector)

if response_type != 'object':
self.get_ref(response_type)

operation = {
'method': method_introspector.get_http_method(),
'summary': method_introspector.get_summary(),
Expand All @@ -107,6 +113,9 @@ def get_operations(self, api, apis=None):
inspector=method_introspector)

operation['parameters'] = parameters or []
for param in operation['parameters']:
if param['type'] not in BaseMethodIntrospector.PRIMITIVES:
self.get_ref(param['type'])

if response_messages:
operation['responseMessages'] = response_messages
Expand All @@ -123,7 +132,7 @@ def get_operations(self, api, apis=None):
# array response
if method_introspector.is_array_response:
operation['items'] = {
'$ref': operation['type']
'$ref': self.get_ref(operation['type'])
}
operation['type'] = 'array'

Expand Down Expand Up @@ -186,10 +195,21 @@ def get_models(self, apis):
# 'properties': data['fields'],
# }

models.update(self.explicit_response_types)
models.update(self.fields_serializers)

# Remove unused serializers
for name in list(models):
if name not in self.ref_serializers:
del models[name]

models.update(self.explicit_response_types)

return models

def get_ref(self, serializer):
self.ref_serializers.add(serializer)
return serializer

def _get_method_serializer(self, method_inspector):
"""
Returns serializer used in method.
Expand Down Expand Up @@ -395,15 +415,17 @@ def _get_serializer_fields(self, serializer):
if getattr(field, 'write_only', False):
field_serializer = "Write{}".format(field_serializer)

f['type'] = field_serializer
if not has_many:
f['$ref'] = self.get_ref(field_serializer)
del f['type']
else:
field_serializer = None
data_type = 'string'

if has_many:
f['type'] = 'array'
if field_serializer:
f['items'] = {'$ref': field_serializer}
f['items'] = {'$ref': self.get_ref(field_serializer)}
elif data_type in BaseMethodIntrospector.PRIMITIVES:
f['items'] = {'type': data_type}

Expand Down
18 changes: 16 additions & 2 deletions rest_framework_swagger/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ class SerializedAPI(ListCreateAPIView):
apis = urlparser.get_apis(url_patterns)

docgen = self.get_documentation_generator()
docgen.generate(apis)
models = docgen.get_models(apis)

self.assertIn('CommentSerializer', models)
Expand Down Expand Up @@ -651,7 +652,9 @@ class OtherSerializer(serializers.Serializer):
fields = docgen._get_serializer_fields(OtherSerializer)

self.assertEqual(1, len(fields['fields']))
self.assertEqual("SomeSerializer", fields['fields']['thing2']['type'])
self.assertIn("$ref", fields['fields']['thing2'])
self.assertNotIn("type", fields['fields']['thing2'])
self.assertEqual("SomeSerializer", fields['fields']['thing2']['$ref'])

def test_get_serializer_fields_api_with_nested_many(self):
class SomeSerializer(serializers.Serializer):
Expand Down Expand Up @@ -1357,6 +1360,14 @@ class HiddenSerializer(serializers.Serializer):
hidden = serializers.HiddenField(default=42)

class SerializedAPI(ListCreateAPIView):
"""
---
POST:
parameters:
- name: HiddenSerializer
type: WriteHiddenSerializer
paramType: body
"""
serializer_class = HiddenSerializer

class_introspector = self.make_introspector2(SerializedAPI)
Expand All @@ -1369,6 +1380,7 @@ class SerializedAPI(ListCreateAPIView):
urlparser = UrlParser()
generator = self.get_documentation_generator()
apis = urlparser.get_apis(url_patterns)
generator.generate(apis)
models = generator.get_models(apis)
self.assertIn("HiddenSerializer", models)
properties = models["HiddenSerializer"]['properties']
Expand Down Expand Up @@ -1398,6 +1410,7 @@ class SerializedAPI(ListCreateAPIView):
urlparser = UrlParser()
generator = self.get_documentation_generator()
apis = urlparser.get_apis(url_patterns)
generator.generate(apis)
models = generator.get_models(apis)
self.assertIn("KitchenSinkSerializer", models)
properties = models["KitchenSinkSerializer"]['properties']
Expand Down Expand Up @@ -2042,9 +2055,10 @@ def post(self, request, *args, **kwargs):
url_patterns = patterns('', url(r'my-api/', SerializedAPI.as_view()))
urlparser = UrlParser()
apis = urlparser.get_apis(url_patterns)
generator.generate(apis)
models = generator.get_models(apis)
self.assertIn('SerializedAPIPostResponse', models)
self.assertIn('WriteCommentSerializer', models)
self.assertNotIn('WriteCommentSerializer', models)
self.assertIn('CommentSerializer', models)
self.assertNotIn('QuerySerializer', models)
self.assertNotIn('WriteQuerySerializer', models)
Expand Down