Commit 46a870c0 by Alexander Gaevsky Committed by Tom Christie

Fix schema generation for APIView, since it does not have get_serializer_class method. (#4285)

parent 3586c8a6
...@@ -112,7 +112,6 @@ class SchemaGenerator(object): ...@@ -112,7 +112,6 @@ class SchemaGenerator(object):
for pattern in patterns: for pattern in patterns:
path_regex = prefix + pattern.regex.pattern path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern): if isinstance(pattern, RegexURLPattern):
path = self.get_path(path_regex) path = self.get_path(path_regex)
callback = pattern.callback callback = pattern.callback
...@@ -254,6 +253,9 @@ class SchemaGenerator(object): ...@@ -254,6 +253,9 @@ class SchemaGenerator(object):
fields = [] fields = []
if not (hasattr(view, 'get_serializer_class') and callable(getattr(view, 'get_serializer_class'))):
return []
serializer_class = view.get_serializer_class() serializer_class = view.get_serializer_class()
serializer = serializer_class() serializer = serializer_class()
......
...@@ -5,8 +5,11 @@ from django.test import TestCase, override_settings ...@@ -5,8 +5,11 @@ from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi from rest_framework.compat import coreapi
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
from rest_framework.test import APIClient from rest_framework.test import APIClient
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
...@@ -31,11 +34,24 @@ class ExampleViewSet(ModelViewSet): ...@@ -31,11 +34,24 @@ class ExampleViewSet(ModelViewSet):
serializer_class = ExampleSerializer serializer_class = ExampleSerializer
class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
def get(self, request, *args, **kwargs):
return Response()
def post(self, request, *args, **kwargs):
return Response()
router = DefaultRouter(schema_title='Example API' if coreapi else None) router = DefaultRouter(schema_title='Example API' if coreapi else None)
router.register('example', ExampleViewSet, base_name='example') router.register('example', ExampleViewSet, base_name='example')
urlpatterns = [ urlpatterns = [
url(r'^', include(router.urls)) url(r'^', include(router.urls))
] ]
urlpatterns2 = [
url(r'^example-view/$', ExampleView.as_view(), name='example-view')
]
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
...@@ -135,3 +151,29 @@ class TestRouterGeneratedSchema(TestCase): ...@@ -135,3 +151,29 @@ class TestRouterGeneratedSchema(TestCase):
} }
) )
self.assertEqual(response.data, expected) self.assertEqual(response.data, expected)
@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaGenerator(TestCase):
def test_view(self):
schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2)
schema = schema_generator.get_schema()
expected = coreapi.Document(
url='',
title='Test View',
content={
'example-view': {
'create': coreapi.Link(
url='/example-view/',
action='post',
fields=[]
),
'read': coreapi.Link(
url='/example-view/',
action='get',
fields=[]
)
}
}
)
self.assertEquals(schema, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment