Commit d54df8c4 by Carlton Gibson Committed by Tom Christie

Refactor schema generation to allow per-view customisation (#5354)

* Initial Refactor Step

* Add descriptor class
* call from generator
* proxy back to generator for implementation.

* Move `get_link` to descriptor

* Move `get_description` to descriptor

* Remove need for generator in get_description

* Move get_path_fields to descriptor

* Move `get_serializer_fields` to descriptor

* Move `get_pagination_fields` to descriptor

* Move `get_filter_fields` to descriptor

* Move `get_encoding` to descriptor.

* Pass just `url` from SchemaGenerator to descriptor

* Make `view` a property

Encapsulates check for a view instance.

* Adjust API Reference docs

* Add `ManualSchema` class

* Refactor to `ViewInspector` plus `AutoSchema`

The interface then is **just** `get_link()`

* Add `manual_fields` kwarg to AutoSchema

* Add schema decorator for FBVs

* Adjust comments

* Docs: Provide full params in example

Ref feedback https://github.com/encode/django-rest-framework/pull/5354/files/b52e372f8f936204753b17fe7c9bfb517b93a045#r137254795

* Add docstring for ViewInstpector.__get__ descriptor method.

Ref https://github.com/encode/django-rest-framework/pull/5354#discussion_r137265022

* Make `schemas` a package.

* Split generators, inspectors, views.

* Adjust imports

* Rename to EndpointEnumerator

* Adjust ManualSchema to take `fields`

… and `description`.

Allows `url` and `action` to remain dynamic

* Add package/module docstrings
parent 5ea810d5
...@@ -10,7 +10,14 @@ API schemas are a useful tool that allow for a range of use cases, including ...@@ -10,7 +10,14 @@ API schemas are a useful tool that allow for a range of use cases, including
generating reference documentation, or driving dynamic client libraries that generating reference documentation, or driving dynamic client libraries that
can interact with your API. can interact with your API.
## Representing schemas internally ## Install Core API
You'll need to install the `coreapi` package in order to add schema support
for REST framework.
pip install coreapi
## Internal schema representation
REST framework uses [Core API][coreapi] in order to model schema information in REST framework uses [Core API][coreapi] in order to model schema information in
a format-independent representation. This information can then be rendered a format-independent representation. This information can then be rendered
...@@ -68,9 +75,34 @@ has to be rendered into the actual bytes that are used in the response. ...@@ -68,9 +75,34 @@ has to be rendered into the actual bytes that are used in the response.
REST framework includes a renderer class for handling this media type, which REST framework includes a renderer class for handling this media type, which
is available as `renderers.CoreJSONRenderer`. is available as `renderers.CoreJSONRenderer`.
### Alternate schema formats
Other schema formats such as [Open API][open-api] ("Swagger"), Other schema formats such as [Open API][open-api] ("Swagger"),
[JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can [JSON HyperSchema][json-hyperschema], or [API Blueprint][api-blueprint] can also
also be supported by implementing a custom renderer class. be supported by implementing a custom renderer class that handles converting a
`Document` instance into a bytestring representation.
If there is a Core API codec package that supports encoding into the format you
want to use then implementing the renderer class can be done by using the codec.
#### Example
For example, the `openapi_codec` package provides support for encoding or decoding
to the Open API ("Swagger") format:
from rest_framework import renderers
from openapi_codec import OpenAPICodec
class SwaggerRenderer(renderers.BaseRenderer):
media_type = 'application/openapi+json'
format = 'swagger'
def render(self, data, media_type=None, renderer_context=None):
codec = OpenAPICodec()
return codec.dump(data)
## Schemas vs Hypermedia ## Schemas vs Hypermedia
...@@ -89,18 +121,121 @@ document, detailing both the current state and the available interactions. ...@@ -89,18 +121,121 @@ document, detailing both the current state and the available interactions.
Further information and support on building Hypermedia APIs with REST framework Further information and support on building Hypermedia APIs with REST framework
is planned for a future version. is planned for a future version.
--- ---
# Adding a schema # Creating a schema
You'll need to install the `coreapi` package in order to add schema support REST framework includes functionality for auto-generating a schema,
for REST framework. or allows you to specify one explicitly.
pip install coreapi ## Manual Schema Specification
REST framework includes functionality for auto-generating a schema, To manually specify a schema you create a Core API `Document`, similar to the
or allows you to specify one explicitly. There are a few different ways to example above.
add a schema to your API, depending on exactly what you need.
schema = coreapi.Document(
title='Flight Search API',
content={
...
}
)
## Automatic Schema Generation
Automatic schema generation is provided by the `SchemaGenerator` class.
`SchemaGenerator` processes a list of routed URL pattterns and compiles the
appropriately structured Core API Document.
Basic usage is just to provide the title for your schema and call
`get_schema()`:
generator = schemas.SchemaGenerator(title='Flight Search API')
schema = generator.get_schema()
### Per-View Schema Customisation
By default, view introspection is performed by an `AutoSchema` instance
accessible via the `schema` attribute on `APIView`. This provides the
appropriate Core API `Link` object for the view, request method and path:
auto_schema = view.schema
coreapi_link = auto_schema.get_link(...)
(In compiling the schema, `SchemaGenerator` calls `view.schema.get_link()` for
each view, allowed method and path.)
To customise the `Link` generation you may:
* Instantiate `AutoSchema` on your view with the `manual_fields` kwarg:
from rest_framework.views import APIView
from rest_framework.schemas import AutoSchema
class CustomView(APIView):
...
schema = AutoSchema(
manual_fields=[
coreapi.Field("extra_field", ...),
]
)
This allows extension for the most common case without subclassing.
* Provide an `AutoSchema` subclass with more complex customisation:
from rest_framework.views import APIView
from rest_framework.schemas import AutoSchema
class CustomSchema(AutoSchema):
def get_link(...):
# Implemet custom introspection here (or in other sub-methods)
class CustomView(APIView):
...
schema = CustomSchema()
This provides complete control over view introspection.
* Instantiate `ManualSchema` on your view, providing the Core API `Fields` for
the view explicitly:
from rest_framework.views import APIView
from rest_framework.schemas import ManualSchema
class CustomView(APIView):
...
schema = ManualSchema(fields=[
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
])
This allows manually specifying the schema for some views whilst maintaining
automatic generation elsewhere.
---
**Note**: For full details on `SchemaGenerator` plus the `AutoSchema` and
`ManualSchema` descriptors see the [API Reference below](#api-reference).
---
# Adding a schema view
There are a few different ways to add a schema view to your API, depending on
exactly what you need.
## The get_schema_view shortcut ## The get_schema_view shortcut
...@@ -342,38 +477,12 @@ A generic viewset with sections in the class docstring, using multi-line style. ...@@ -342,38 +477,12 @@ A generic viewset with sections in the class docstring, using multi-line style.
--- ---
# Alternate schema formats
In order to support an alternate schema format, you need to implement a custom renderer
class that handles converting a `Document` instance into a bytestring representation.
If there is a Core API codec package that supports encoding into the format you
want to use then implementing the renderer class can be done by using the codec.
## Example
For example, the `openapi_codec` package provides support for encoding or decoding
to the Open API ("Swagger") format:
from rest_framework import renderers
from openapi_codec import OpenAPICodec
class SwaggerRenderer(renderers.BaseRenderer):
media_type = 'application/openapi+json'
format = 'swagger'
def render(self, data, media_type=None, renderer_context=None):
codec = OpenAPICodec()
return codec.dump(data)
---
# API Reference # API Reference
## SchemaGenerator ## SchemaGenerator
A class that deals with introspecting your API views, which can be used to A class that walks a list of routed URL patterns, requests the schema for each view,
generate a schema. and collates the resulting CoreAPI Document.
Typically you'll instantiate `SchemaGenerator` with a single argument, like so: Typically you'll instantiate `SchemaGenerator` with a single argument, like so:
...@@ -406,39 +515,108 @@ Return a nested dictionary containing all the links that should be included in t ...@@ -406,39 +515,108 @@ Return a nested dictionary containing all the links that should be included in t
This is a good point to override if you want to modify the resulting structure of the generated schema, This is a good point to override if you want to modify the resulting structure of the generated schema,
as you can build a new dictionary with a different layout. as you can build a new dictionary with a different layout.
### get_link(self, path, method, view)
## AutoSchema
A class that deals with introspection of individual views for schema generation.
`AutoSchema` is attached to `APIView` via the `schema` attribute.
The `AutoSchema` constructor takes a single keyword argument `manual_fields`.
**`manual_fields`**: a `list` of `coreapi.Field` instances that will be added to
the generated fields. Generated fields with a matching `name` will be overwritten.
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
For more advanced customisation subclass `AutoSchema` to customise schema generation.
class CustomViewSchema(AutoSchema):
"""
Overrides `get_link()` to provide Custom Behavior X
"""
def get_link(self, path, method, base_url):
link = super().get_link(path, method, base_url)
# Do something to customize link here...
return link
class MyView(APIView):
schema = CustomViewSchema()
The following methods are available to override.
### get_link(self, path, method, base_url)
Returns a `coreapi.Link` instance corresponding to the given view. Returns a `coreapi.Link` instance corresponding to the given view.
This is the main entry point.
You can override this if you need to provide custom behaviors for particular views. You can override this if you need to provide custom behaviors for particular views.
### get_description(self, path, method, view) ### get_description(self, path, method)
Returns a string to use as the link description. By default this is based on the Returns a string to use as the link description. By default this is based on the
view docstring as described in the "Schemas as Documentation" section above. view docstring as described in the "Schemas as Documentation" section above.
### get_encoding(self, path, method, view) ### get_encoding(self, path, method)
Returns a string to indicate the encoding for any request body, when interacting Returns a string to indicate the encoding for any request body, when interacting
with the given view. Eg. `'application/json'`. May return a blank string for views with the given view. Eg. `'application/json'`. May return a blank string for views
that do not expect a request body. that do not expect a request body.
### get_path_fields(self, path, method, view): ### get_path_fields(self, path, method):
Return a list of `coreapi.Link()` instances. One for each path parameter in the URL. Return a list of `coreapi.Link()` instances. One for each path parameter in the URL.
### get_serializer_fields(self, path, method, view) ### get_serializer_fields(self, path, method)
Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view. Return a list of `coreapi.Link()` instances. One for each field in the serializer class used by the view.
### get_pagination_fields(self, path, method, view ### get_pagination_fields(self, path, method)
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view. Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method on any pagination class used by the view.
### get_filter_fields(self, path, method, view) ### get_filter_fields(self, path, method)
Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view. Return a list of `coreapi.Link()` instances, as returned by the `get_schema_fields()` method of any filter classes used by the view.
## ManualSchema
Allows manually providing a list of `coreapi.Field` instances for the schema,
plus an optional description.
class MyView(APIView):
schema = ManualSchema(fields=[
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
]
)
The `ManualSchema` constructor takes two arguments:
**`fields`**: A list of `coreapi.Field` instances. Required.
**`description`**: A string description. Optional.
--- ---
## Core API ## Core API
......
...@@ -184,6 +184,28 @@ The available decorators are: ...@@ -184,6 +184,28 @@ The available decorators are:
Each of these decorators takes a single argument which must be a list or tuple of classes. Each of these decorators takes a single argument which must be a list or tuple of classes.
## View schema decorator
To override the default schema generation for function based views you may use
the `@schema` decorator. This must come *after* (below) the `@api_view`
decorator. For example:
from rest_framework.decorators import api_view, schema
from rest_framework.schemas import AutoSchema
class CustomAutoSchema(AutoSchema):
def get_link(self, path, method, base_url):
# override view introspection here...
@api_view(['GET'])
@schema(CustomAutoSchema())
def view(request):
return Response({"message": "Hello for today! See you tomorrow!"})
This decorator takes a single `AutoSchema` instance, an `AutoSchema` subclass
instance or `ManualSchema` instance as described in the [Schemas documentation][schemas],
[cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite]: http://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
[settings]: settings.md [settings]: settings.md
......
...@@ -72,6 +72,9 @@ def api_view(http_method_names=None, exclude_from_schema=False): ...@@ -72,6 +72,9 @@ def api_view(http_method_names=None, exclude_from_schema=False):
WrappedAPIView.permission_classes = getattr(func, 'permission_classes', WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes) APIView.permission_classes)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
WrappedAPIView.exclude_from_schema = exclude_from_schema WrappedAPIView.exclude_from_schema = exclude_from_schema
return WrappedAPIView.as_view() return WrappedAPIView.as_view()
return decorator return decorator
...@@ -112,6 +115,13 @@ def permission_classes(permission_classes): ...@@ -112,6 +115,13 @@ def permission_classes(permission_classes):
return decorator return decorator
def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
def detail_route(methods=None, **kwargs): def detail_route(methods=None, **kwargs):
""" """
Used to mark a method on a ViewSet that should be routed for detail requests. Used to mark a method on a ViewSet that should be routed for detail requests.
......
...@@ -26,7 +26,8 @@ from rest_framework import views ...@@ -26,7 +26,8 @@ from rest_framework import views
from rest_framework.compat import NoReverseMatch from rest_framework.compat import NoReverseMatch
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator, SchemaView from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
......
import re
from collections import OrderedDict
from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied
from django.db import models
from django.http import Http404
from django.utils import six
from django.utils.encoding import force_text, smart_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, coreschema, uritemplate,
urlparse
)
from rest_framework.request import clone_request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from rest_framework.utils.model_meta import _get_pk
from rest_framework.views import APIView
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
return coreschema.Array(
items=coreschema.String(),
title=title,
description=description
)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices.keys())),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices.keys()),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)
def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
def is_api_view(callback):
"""
Return `True` if the given view callback is a REST framework view/viewset.
"""
cls = getattr(callback, 'cls', None)
return (cls is not None) and issubclass(cls, APIView)
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
{'a': {'b': {'c': 123}}}
"""
for key in keys[:-1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = value
def is_custom_action(action):
return action not in set([
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
])
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (path, method_priority)
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class EndpointInspector(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
# Use the default Django URL conf
urlconf = settings.ROOT_URLCONF
# Load the given URLconf module
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
return api_endpoints
def get_path_from_regex(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
path = simplify_regex(path_regex)
path = path.replace('<', '{').replace('>', '}')
return path
def should_include_endpoint(self, path, callback):
"""
Return `True` if the given endpoint should be included.
"""
if not is_api_view(callback):
return False # Ignore anything except REST framework views.
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs.
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
actions = set(callback.actions.keys())
http_method_names = set(callback.cls.http_method_names)
return [method.upper() for method in actions & http_method_names]
return [
method for method in
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
class SchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
endpoint_inspector_cls = EndpointInspector
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'):
url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
self.urlconf = urlconf
self.title = title
self.description = description
self.url = url
self.endpoints = None
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = OrderedDict()
# Generate (path, method, view) given (path, method, callback).
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
if getattr(view, 'exclude_from_schema', False):
continue
path = self.coerce_path(path, method, view)
paths.append(path)
view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = self.get_link(path, method, view)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
# Methods used when we generate a view instance from the raw callback...
def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.
This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.
For example:
/api/v1/users/
/api/v1/users/{pk}/
The path prefix is '/api/v1/'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
"""
if view.request is None:
return True
try:
view.check_permissions(view.request)
except (exceptions.APIException, Http404, PermissionDenied):
return False
return True
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
# Methods for generating each individual `Link` instance...
def get_link(self, path, method, view):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
fields = self.get_path_fields(path, method, view)
fields += self.get_serializer_fields(path, method, view)
fields += self.get_pagination_fields(path, method, view)
fields += self.get_filter_fields(path, method, view)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method, view)
else:
encoding = None
description = self.get_description(path, method, view)
if self.url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(self.url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method, view):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return formatting.dedent(smart_text(method_docstring))
description = view.get_view_description()
lines = [line.strip() for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
header = getattr(view, 'action', method.lower())
if header in sections:
return sections[header].strip()
if header in self.coerce_method_names:
if self.coerce_method_names[header] in sections:
return sections[self.coerce_method_names[header]].strip()
return sections[''].strip()
def get_encoding(self, path, method, view):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
# Core API supports the following request encodings over HTTP...
supported_media_types = set((
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
))
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
def get_path_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method, view):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method, view):
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def get_filter_fields(self, path, method, view):
if not is_list_view(path, method, view):
return []
if not getattr(view, 'filter_backends', None):
return []
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_schema_fields(view)
return fields
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
class SchemaView(APIView):
_ignore_model_permissions = True
exclude_from_schema = True
renderer_classes = None
schema_generator = None
public = False
def __init__(self, *args, **kwargs):
super(SchemaView, self).__init__(*args, **kwargs)
if self.renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes = [
renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer,
]
else:
self.renderer_classes = [renderers.CoreJSONRenderer]
def get(self, request, *args, **kwargs):
schema = self.schema_generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator):
"""
Return a schema view.
"""
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
)
"""
rest_framework.schemas
schemas:
__init__.py
generators.py # Top-down schema generation
inspectors.py # Per-endpoint view introspection
utils.py # Shared helper functions
views.py # Houses `SchemaView`, `APIView` subclass.
We expose a minimal "public" API directly from `schemas`. This covers the
basic use-cases:
from rest_framework.schemas import (
AutoSchema,
ManualSchema,
get_schema_view,
SchemaGenerator,
)
Other access should target the submodules directly
"""
from .generators import SchemaGenerator
from .inspectors import AutoSchema, ManualSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator):
"""
Return a schema view.
"""
# Avoid import cycle on APIView
from .views import SchemaView
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
)
"""
generators.py # Top-down schema generation
See schemas.__init__.py for package overview.
"""
from collections import OrderedDict
from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils import six
from rest_framework import exceptions
from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, coreschema
)
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk
from .utils import is_list_view
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)
def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
def is_api_view(callback):
"""
Return `True` if the given view callback is a REST framework view/viewset.
"""
# Avoid import cycle on APIView
from rest_framework.views import APIView
cls = getattr(callback, 'cls', None)
return (cls is not None) and issubclass(cls, APIView)
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
{'a': {'b': {'c': 123}}}
"""
for key in keys[:-1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = value
def is_custom_action(action):
return action not in set([
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
])
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (path, method_priority)
class EndpointEnumerator(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
# Use the default Django URL conf
urlconf = settings.ROOT_URLCONF
# Load the given URLconf module
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + pattern.regex.pattern
if isinstance(pattern, RegexURLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
return api_endpoints
def get_path_from_regex(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
path = simplify_regex(path_regex)
path = path.replace('<', '{').replace('>', '}')
return path
def should_include_endpoint(self, path, callback):
"""
Return `True` if the given endpoint should be included.
"""
if not is_api_view(callback):
return False # Ignore anything except REST framework views.
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs.
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
actions = set(callback.actions.keys())
http_method_names = set(callback.cls.http_method_names)
return [method.upper() for method in actions & http_method_names]
return [
method for method in
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
class SchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
endpoint_inspector_cls = EndpointEnumerator
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
if url and not url.endswith('/'):
url += '/'
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
self.urlconf = urlconf
self.title = title
self.description = description
self.url = url
self.endpoints = None
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = OrderedDict()
# Generate (path, method, view) given (path, method, callback).
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
if getattr(view, 'exclude_from_schema', False):
continue
path = self.coerce_path(path, method, view)
paths.append(path)
view_endpoints.append((path, method, view))
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
# Methods used when we generate a view instance from the raw callback...
def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.
This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.
For example:
/api/v1/users/
/api/v1/users/{pk}/
The path prefix is '/api/v1/'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
"""
if view.request is None:
return True
try:
view.check_permissions(view.request)
except (exceptions.APIException, Http404, PermissionDenied):
return False
return True
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
"""
inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
from collections import OrderedDict
from django.db import models
from django.utils.encoding import force_text, smart_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from rest_framework.compat import coreapi, coreschema, uritemplate, urlparse
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
return coreschema.Array(
items=coreschema.String(),
title=title,
description=description
)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices.keys())),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices.keys()),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)
class ViewInspector(object):
"""
Descriptor class on APIView.
Provide subclass for per-view schema generation
"""
def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
This is how `view.schema` knows about `view`.
`__get__` is called when the descriptor is accessed on the owner.
(That will be when view.schema is called in our case.)
`owner` is always the owner class. (An APIView, or subclass for us.)
`instance` is the view instance or `None` if accessed from the class,
rather than an instance.
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
self.view = instance
return self
@property
def view(self):
"""View property."""
assert self._view is not None, "Schema generation REQUIRES a view instance. (Hint: you accessed `schema` from the view class rather than an instance.)"
return self._view
@view.setter
def view(self, value):
self._view = value
@view.deleter
def view(self):
self._view = None
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
raise NotImplementedError(".get_link() must be overridden.")
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view instrospection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
if self._manual_fields is not None:
by_name = {f.name: f for f in fields}
for f in self._manual_fields:
by_name[f.name] = f
fields = list(by_name.values())
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_description(self, path, method):
"""
Determine a link description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return formatting.dedent(smart_text(method_docstring))
description = view.get_view_description()
lines = [line.strip() for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
header = getattr(view, 'action', method.lower())
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_text(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def get_filter_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
if not getattr(view, 'filter_backends', None):
return []
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_schema_fields(view)
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = set((
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
))
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description=''):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `descripton`: String description for view. Optional.
"""
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=urlparse.urljoin(base_url, path),
action=method.lower(),
encoding=None,
fields=self._fields,
description=self._description
)
return self._link
"""
utils.py # Shared helper functions
See schemas.__init__.py for package overview.
"""
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
"""
views.py # Houses `SchemaView`, `APIView` subclass.
See schemas.__init__.py for package overview.
"""
from rest_framework import exceptions, renderers
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
class SchemaView(APIView):
_ignore_model_permissions = True
exclude_from_schema = True
renderer_classes = None
schema_generator = None
public = False
def __init__(self, *args, **kwargs):
super(SchemaView, self).__init__(*args, **kwargs)
if self.renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes = [
renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer,
]
else:
self.renderer_classes = [renderers.CoreJSONRenderer]
def get(self, request, *args, **kwargs):
schema = self.schema_generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
...@@ -19,6 +19,7 @@ from rest_framework import exceptions, status ...@@ -19,6 +19,7 @@ from rest_framework import exceptions, status
from rest_framework.compat import set_rollback from rest_framework.compat import set_rollback
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting from rest_framework.utils import formatting
...@@ -113,6 +114,7 @@ class APIView(View): ...@@ -113,6 +114,7 @@ class APIView(View):
# Mark the view as being included or excluded from schema generation. # Mark the view as being included or excluded from schema generation.
exclude_from_schema = False exclude_from_schema = False
schema = AutoSchema()
@classmethod @classmethod
def as_view(cls, **initkwargs): def as_view(cls, **initkwargs):
......
...@@ -6,12 +6,13 @@ from rest_framework import status ...@@ -6,12 +6,13 @@ from rest_framework import status
from rest_framework.authentication import BasicAuthentication from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import ( from rest_framework.decorators import (
api_view, authentication_classes, parser_classes, permission_classes, api_view, authentication_classes, parser_classes, permission_classes,
renderer_classes, throttle_classes renderer_classes, schema, throttle_classes
) )
from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
...@@ -151,3 +152,17 @@ class DecoratorTestCase(TestCase): ...@@ -151,3 +152,17 @@ class DecoratorTestCase(TestCase):
response = view(request) response = view(request)
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self):
"""
Checks CustomSchema class is set on view
"""
class CustomSchema(AutoSchema):
pass
@api_view(['GET'])
@schema(CustomSchema())
def view(request):
return Response({})
assert isinstance(view.cls.schema, CustomSchema)
import unittest import unittest
import pytest
from django.conf.urls import include, url from django.conf.urls import include, url
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
...@@ -10,7 +11,9 @@ from rest_framework.compat import coreapi, coreschema ...@@ -10,7 +11,9 @@ from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator, get_schema_view from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
)
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
...@@ -496,3 +499,81 @@ class Test4605Regression(TestCase): ...@@ -496,3 +499,81 @@ class Test4605Regression(TestCase):
'/auth/convert-token/' '/auth/convert-token/'
]) ])
assert prefix == '/' assert prefix == '/'
class TestDescriptor(TestCase):
def test_apiview_schema_descriptor(self):
view = APIView()
assert hasattr(view, 'schema')
assert isinstance(view.schema, AutoSchema)
def test_get_link_requires_instance(self):
descriptor = APIView.schema # Accessed from class
with pytest.raises(AssertionError):
descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
def test_manual_fields(self):
class CustomView(APIView):
schema = AutoSchema(manual_fields=[
coreapi.Field(
"my_extra_field",
required=True,
location="path",
schema=coreschema.String()
),
])
view = CustomView()
link = view.schema.get_link('/a/url/{id}/', 'GET', '')
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
def test_view_with_manual_schema(self):
path = '/example'
method = 'get'
base_url = None
fields = [
coreapi.Field(
"first_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"second_field",
required=True,
location="path",
schema=coreschema.String()
),
coreapi.Field(
"third_field",
required=True,
location="path",
schema=coreschema.String()
),
]
description = "A test endpoint"
class CustomView(APIView):
"""
ManualSchema takes list of fields for endpoint.
- Provides url and action, which are always dynamic
"""
schema = ManualSchema(fields, description)
expected = coreapi.Link(
url=path,
action=method,
fields=fields,
description=description
)
view = CustomView()
link = view.schema.get_link(path, method, base_url)
assert link == expected
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment