diff --git a/fastapi/__init__.py b/fastapi/__init__.py
index a52bbccf6..2bb1b27c2 100644
--- a/fastapi/__init__.py
+++ b/fastapi/__init__.py
@@ -1,3 +1,3 @@
"""Fast API framework, fast high performance, fast to learn, fast to code"""
-__version__ = '0.1'
+__version__ = "0.1"
diff --git a/fastapi/applications.py b/fastapi/applications.py
index 2e1875aa1..3f5a45b73 100644
--- a/fastapi/applications.py
+++ b/fastapi/applications.py
@@ -1,61 +1,19 @@
-import typing
-import inspect
+from typing import Any, Callable, Dict, List, Type
from starlette.applications import Starlette
-from starlette.middleware.lifespan import LifespanMiddleware
+from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware
-from starlette.exceptions import ExceptionMiddleware
-from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
-from starlette.requests import Request
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.middleware.lifespan import LifespanMiddleware
+from starlette.responses import JSONResponse
-from pydantic import BaseModel, BaseConfig, Schema
-from pydantic.utils import lenient_issubclass
-from pydantic.fields import Field
-from pydantic.schema import (
- field_schema,
- get_flat_models_from_models,
- get_flat_models_from_fields,
- get_model_name_map,
- schema,
- model_process_schema,
-)
-from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
-from .pydantic_utils import jsonable_encoder
+from fastapi import routing
+from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
-def docs(openapi_url):
- return HTMLResponse(
- """
-
-
-
-
-
-
-
-
-
-
-
-
-
- """,
- media_type="text/html",
- )
+async def http_exception(request, exc: HTTPException):
+ print(exc)
+ return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
class FastAPI(Starlette):
@@ -67,24 +25,26 @@ class FastAPI(Starlette):
description: str = "",
version: str = "0.1.0",
openapi_url: str = "/openapi.json",
- docs_url: str = "/docs",
- **extra: typing.Dict[str, typing.Any],
+ swagger_ui_url: str = "/docs",
+ redoc_url: str = "/redoc",
+ **extra: Dict[str, Any],
) -> None:
self._debug = debug
- self.router = APIRouter()
+ self.router = routing.APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
- self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
+ self.schema_generator = None
self.template_env = self.load_template_env(template_directory)
self.title = title
self.description = description
self.version = version
self.openapi_url = openapi_url
- self.docs_url = docs_url
+ self.swagger_ui_url = swagger_ui_url
+ self.redoc_url = redoc_url
self.extra = extra
self.openapi_version = "3.0.2"
@@ -93,29 +53,52 @@ class FastAPI(Starlette):
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
- if self.docs_url:
+ if self.swagger_ui_url or self.redoc_url:
assert self.openapi_url, "The openapi_url is required for the docs"
+ self.setup()
- self.add_route(
- self.openapi_url,
- lambda req: JSONResponse(self.openapi()),
- include_in_schema=False,
- )
- self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
+ def setup(self):
+ if self.openapi_url:
+ self.add_route(
+ self.openapi_url,
+ lambda req: JSONResponse(
+ get_openapi(
+ title=self.title,
+ version=self.version,
+ openapi_version=self.openapi_version,
+ description=self.description,
+ routes=self.routes,
+ )
+ ),
+ include_in_schema=False,
+ )
+ if self.swagger_ui_url:
+ self.add_route(
+ self.swagger_ui_url,
+ lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
+ include_in_schema=False,
+ )
+ if self.redoc_url:
+ self.add_route(
+ self.redoc_url,
+ lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
+ include_in_schema=False,
+ )
+ self.add_exception_handler(HTTPException, http_exception)
def add_api_route(
self,
path: str,
- endpoint: typing.Callable,
- methods: typing.List[str] = None,
+ endpoint: Callable,
+ methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -126,7 +109,7 @@ class FastAPI(Starlette):
methods=methods,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -140,27 +123,27 @@ class FastAPI(Starlette):
def api_route(
self,
path: str,
- methods: typing.List[str] = None,
+ methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
- ) -> typing.Callable:
- def decorator(func: typing.Callable) -> typing.Callable:
+ ) -> Callable:
+ def decorator(func: Callable) -> Callable:
self.router.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -179,12 +162,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -193,7 +176,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -209,12 +192,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -223,7 +206,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -239,12 +222,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -253,7 +236,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -269,12 +252,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -283,7 +266,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -299,12 +282,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -313,7 +296,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -329,12 +312,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -343,7 +326,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -359,12 +342,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -373,7 +356,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -389,12 +372,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -403,7 +386,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -413,169 +396,3 @@ class FastAPI(Starlette):
response_code=response_code,
response_wrapper=response_wrapper,
)
-
- def openapi(self):
- info = {"title": self.title, "version": self.version}
- if self.description:
- info["description"] = self.description
- output = {"openapi": self.openapi_version, "info": info}
- components = {}
- paths = {}
- methods_with_body = set(("POST", "PUT"))
- body_fields_from_routes = []
- responses_from_routes = []
- ref_prefix = "#/components/schemas/"
- for route in self.routes:
- route: APIRoute
- if route.include_in_schema and isinstance(route, APIRoute):
- if route.request_body:
- assert isinstance(
- route.request_body, Field
- ), "A request body must be a Pydantic BaseModel or Field"
- body_fields_from_routes.append(route.request_body)
- if route.response_field:
- responses_from_routes.append(route.response_field)
- flat_models = get_flat_models_from_fields(
- body_fields_from_routes + responses_from_routes
- )
- model_name_map = get_model_name_map(flat_models)
- definitions = {}
- for model in flat_models:
- m_schema, m_definitions = model_process_schema(
- model, model_name_map=model_name_map, ref_prefix=ref_prefix
- )
- definitions.update(m_definitions)
- model_name = model_name_map[model]
- definitions[model_name] = m_schema
- validation_error_definition = {
- "title": "ValidationError",
- "type": "object",
- "properties": {
- "loc": {
- "title": "Location",
- "type": "array",
- "items": {"type": "string"},
- },
- "msg": {"title": "Message", "type": "string"},
- "type": {"title": "Error Type", "type": "string"},
- },
- "required": ["loc", "msg", "type"],
- }
- validation_error_response_definition = {
- "title": "HTTPValidationError",
- "type": "object",
- "properties": {
- "detail": {
- "title": "Detail",
- "type": "array",
- "items": {"$ref": ref_prefix + "ValidationError"},
- }
- },
- }
- for route in self.routes:
- route: APIRoute
- if route.include_in_schema and isinstance(route, APIRoute):
- path = paths.get(route.path, {})
- for method in route.methods:
- operation = {}
- if route.tags:
- operation["tags"] = route.tags
- if route.summary:
- operation["summary"] = route.summary
- if route.description:
- operation["description"] = route.description
- if route.operation_id:
- operation["operationId"] = route.operation_id
- else:
- operation["operationId"] = route.name
- if route.deprecated:
- operation["deprecated"] = route.deprecated
- parameters = []
- flat_dependant = get_flat_dependant(route.dependant)
- security_definitions = {}
- for security_scheme in flat_dependant.security_schemes:
- security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
- security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
- security_definitions[security_name] = security_definition
- if security_definitions:
- components.setdefault("securitySchemes", {}).update(security_definitions)
- operation["security"] = [{name: []} for name in security_definitions]
- all_route_params = get_openapi_params(route.dependant)
- for param in all_route_params:
- if "ValidationError" not in definitions:
- definitions["ValidationError"] = validation_error_definition
- definitions[
- "HTTPValidationError"
- ] = validation_error_response_definition
- parameter = {
- "name": param.alias,
- "in": param.schema.in_.value,
- "required": param.required,
- "schema": field_schema(param, model_name_map={})[0],
- }
- if param.schema.description:
- parameter["description"] = param.schema.description
- if param.schema.deprecated:
- parameter["deprecated"] = param.schema.deprecated
- parameters.append(parameter)
- if parameters:
- operation["parameters"] = parameters
- if method in methods_with_body:
- request_body = getattr(route, "request_body", None)
- if request_body:
- assert isinstance(request_body, Field)
- body_schema, _ = field_schema(
- request_body,
- model_name_map=model_name_map,
- ref_prefix=ref_prefix,
- )
- required = request_body.required
- request_body_oai = {}
- if required:
- request_body_oai["required"] = required
- request_body_oai["content"] = {
- "application/json": {"schema": body_schema}
- }
- operation["requestBody"] = request_body_oai
- response_code = str(route.response_code)
- response_schema = {"type": "string"}
- if lenient_issubclass(route.response_wrapper, JSONResponse):
- response_media_type = "application/json"
- if route.response_field:
- response_schema, _ = field_schema(
- route.response_field,
- model_name_map=model_name_map,
- ref_prefix=ref_prefix,
- )
- else:
- response_schema = {}
- elif lenient_issubclass(route.response_wrapper, HTMLResponse):
- response_media_type = "text/html"
- else:
- response_media_type = "text/plain"
- content = {response_media_type: {"schema": response_schema}}
- operation["responses"] = {
- response_code: {
- "description": route.response_description,
- "content": content,
- }
- }
- if all_route_params or getattr(route, "request_body", None):
- operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
- "description": "Validation Error",
- "content": {
- "application/json": {
- "schema": {
- "$ref": ref_prefix + "HTTPValidationError"
- }
- }
- },
- }
- path[method.lower()] = operation
- paths[route.path] = path
- if definitions:
- components.setdefault("schemas", {}).update(definitions)
- if components:
- output["components"] = components
- output["paths"] = paths
- return output
diff --git a/fastapi/dependencies/__init__.py b/fastapi/dependencies/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py
new file mode 100644
index 000000000..ad9419db5
--- /dev/null
+++ b/fastapi/dependencies/models.py
@@ -0,0 +1,46 @@
+from typing import Any, Callable, Dict, List, Sequence, Tuple
+
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi.security.base import SecurityBase
+from pydantic import BaseConfig, Schema
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+
+param_supported_types = (str, int, float, bool)
+
+
+class SecurityRequirement:
+ def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None):
+ self.security_scheme = security_scheme
+ self.scopes = scopes
+
+
+class Dependant:
+ def __init__(
+ self,
+ *,
+ path_params: List[Field] = None,
+ query_params: List[Field] = None,
+ header_params: List[Field] = None,
+ cookie_params: List[Field] = None,
+ body_params: List[Field] = None,
+ dependencies: List["Dependant"] = None,
+ security_schemes: List[SecurityRequirement] = None,
+ name: str = None,
+ call: Callable = None,
+ request_param_name: str = None,
+ ) -> None:
+ self.path_params = path_params or []
+ self.query_params = query_params or []
+ self.header_params = header_params or []
+ self.cookie_params = cookie_params or []
+ self.body_params = body_params or []
+ self.dependencies = dependencies or []
+ self.security_requirements = security_schemes or []
+ self.request_param_name = request_param_name
+ self.name = name
+ self.call = call
diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
new file mode 100644
index 000000000..6e86de5a5
--- /dev/null
+++ b/fastapi/dependencies/utils.py
@@ -0,0 +1,327 @@
+import asyncio
+import inspect
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Tuple
+
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi import params
+from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.security.base import SecurityBase
+from fastapi.utils import get_path_param_names
+from pydantic import BaseConfig, Schema, create_model
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+from pydantic.utils import lenient_issubclass
+
+param_supported_types = (str, int, float, bool)
+
+
+def get_sub_dependant(*, param: inspect.Parameter, path: str):
+ depends: params.Depends = param.default
+ if depends.dependency:
+ dependency = depends.dependency
+ else:
+ dependency = param.annotation
+ assert callable(dependency)
+ sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
+ if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase):
+ security_requirement = SecurityRequirement(
+ security_scheme=dependency, scopes=depends.scopes
+ )
+ sub_dependant.security_requirements.append(security_requirement)
+ return sub_dependant
+
+
+def get_flat_dependant(dependant: Dependant):
+ flat_dependant = Dependant(
+ path_params=dependant.path_params.copy(),
+ query_params=dependant.query_params.copy(),
+ header_params=dependant.header_params.copy(),
+ cookie_params=dependant.cookie_params.copy(),
+ body_params=dependant.body_params.copy(),
+ security_schemes=dependant.security_requirements.copy(),
+ )
+ for sub_dependant in dependant.dependencies:
+ if sub_dependant is dependant:
+ raise ValueError("recursion", dependant.dependencies)
+ flat_sub = get_flat_dependant(sub_dependant)
+ flat_dependant.path_params.extend(flat_sub.path_params)
+ flat_dependant.query_params.extend(flat_sub.query_params)
+ flat_dependant.header_params.extend(flat_sub.header_params)
+ flat_dependant.cookie_params.extend(flat_sub.cookie_params)
+ flat_dependant.body_params.extend(flat_sub.body_params)
+ flat_dependant.security_requirements.extend(flat_sub.security_requirements)
+ return flat_dependant
+
+
+def get_dependant(*, path: str, call: Callable, name: str = None):
+ path_param_names = get_path_param_names(path)
+ endpoint_signature = inspect.signature(call)
+ signature_params = endpoint_signature.parameters
+ dependant = Dependant(call=call, name=name)
+ for param_name in signature_params:
+ param = signature_params[param_name]
+ if isinstance(param.default, params.Depends):
+ sub_dependant = get_sub_dependant(param=param, path=path)
+ dependant.dependencies.append(sub_dependant)
+ for param_name in signature_params:
+ param = signature_params[param_name]
+ if (
+ (param.default == param.empty) or isinstance(param.default, params.Path)
+ ) and (param_name in path_param_names):
+ assert lenient_issubclass(
+ param.annotation, param_supported_types
+ ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
+ param = signature_params[param_name]
+ add_param_to_fields(
+ param=param,
+ dependant=dependant,
+ default_schema=params.Path,
+ force_type=params.ParamTypes.path,
+ )
+ elif (param.default == param.empty or param.default is None) and (
+ param.annotation == param.empty
+ or lenient_issubclass(param.annotation, param_supported_types)
+ ):
+ add_param_to_fields(
+ param=param, dependant=dependant, default_schema=params.Query
+ )
+ elif isinstance(param.default, params.Param):
+ if param.annotation != param.empty:
+ assert lenient_issubclass(
+ param.annotation, param_supported_types
+ ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
+ add_param_to_fields(
+ param=param, dependant=dependant, default_schema=params.Query
+ )
+ elif lenient_issubclass(param.annotation, Request):
+ dependant.request_param_name = param_name
+ elif not isinstance(param.default, params.Depends):
+ add_param_to_body_fields(param=param, dependant=dependant)
+ return dependant
+
+
+def add_param_to_fields(
+ *,
+ param: inspect.Parameter,
+ dependant: Dependant,
+ default_schema=params.Param,
+ force_type: params.ParamTypes = None,
+):
+ default_value = Required
+ if not param.default == param.empty:
+ default_value = param.default
+ if isinstance(default_value, params.Param):
+ schema = default_value
+ default_value = schema.default
+ if schema.in_ is None:
+ schema.in_ = default_schema.in_
+ if force_type:
+ schema.in_ = force_type
+ else:
+ schema = default_schema(default_value)
+ required = default_value == Required
+ annotation = Any
+ if not param.annotation == param.empty:
+ annotation = param.annotation
+ annotation = get_annotation_from_schema(annotation, schema)
+ field = Field(
+ name=param.name,
+ type_=annotation,
+ default=None if required else default_value,
+ alias=schema.alias or param.name,
+ required=required,
+ model_config=BaseConfig(),
+ class_validators=[],
+ schema=schema,
+ )
+ if schema.in_ == params.ParamTypes.path:
+ dependant.path_params.append(field)
+ elif schema.in_ == params.ParamTypes.query:
+ dependant.query_params.append(field)
+ elif schema.in_ == params.ParamTypes.header:
+ dependant.header_params.append(field)
+ else:
+ assert (
+ schema.in_ == params.ParamTypes.cookie
+ ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
+ dependant.cookie_params.append(field)
+
+
+def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
+ default_value = Required
+ if not param.default == param.empty:
+ default_value = param.default
+ if isinstance(default_value, Schema):
+ schema = default_value
+ default_value = schema.default
+ else:
+ schema = Schema(default_value)
+ required = default_value == Required
+ annotation = get_annotation_from_schema(param.annotation, schema)
+ field = Field(
+ name=param.name,
+ type_=annotation,
+ default=None if required else default_value,
+ alias=schema.alias or param.name,
+ required=required,
+ model_config=BaseConfig,
+ class_validators=[],
+ schema=schema,
+ )
+ dependant.body_params.append(field)
+
+
+def is_coroutine_callable(call: Callable = None):
+ if not call:
+ return False
+ if inspect.isfunction(call):
+ return asyncio.iscoroutinefunction(call)
+ if inspect.isclass(call):
+ return False
+ call = getattr(call, "__call__", None)
+ if not call:
+ return False
+ return asyncio.iscoroutinefunction(call)
+
+
+async def solve_dependencies(
+ *, request: Request, dependant: Dependant, body: Dict[str, Any] = None
+):
+ values: Dict[str, Any] = {}
+ errors: List[ErrorWrapper] = []
+ for sub_dependant in dependant.dependencies:
+ sub_values, sub_errors = await solve_dependencies(
+ request=request, dependant=sub_dependant, body=body
+ )
+ if sub_errors:
+ return {}, errors
+ if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
+ solved = await sub_dependant.call(**sub_values)
+ else:
+ solved = await run_in_threadpool(sub_dependant.call, **sub_values)
+ values[
+ sub_dependant.name
+ ] = solved # type: ignore # Sub-dependants always have a name
+ path_values, path_errors = request_params_to_args(
+ dependant.path_params, request.path_params
+ )
+ query_values, query_errors = request_params_to_args(
+ dependant.query_params, request.query_params
+ )
+ header_values, header_errors = request_params_to_args(
+ dependant.header_params, request.headers
+ )
+ cookie_values, cookie_errors = request_params_to_args(
+ dependant.cookie_params, request.cookies
+ )
+ values.update(path_values)
+ values.update(query_values)
+ values.update(header_values)
+ values.update(cookie_values)
+ errors = path_errors + query_errors + header_errors + cookie_errors
+ if dependant.body_params:
+ body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
+ dependant.body_params, body
+ )
+ values.update(body_values)
+ errors.extend(body_errors)
+ if dependant.request_param_name:
+ values[dependant.request_param_name] = request
+ return values, errors
+
+
+def request_params_to_args(
+ required_params: List[Field], received_params: Dict[str, Any]
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
+ values = {}
+ errors = []
+ for field in required_params:
+ value = received_params.get(field.alias)
+ if value is None:
+ if field.required:
+ errors.append(
+ ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
+ )
+ else:
+ values[field.name] = deepcopy(field.default)
+ continue
+ v_, errors_ = field.validate(
+ value, values, loc=(field.schema.in_.value, field.alias)
+ )
+ if isinstance(errors_, ErrorWrapper):
+ errors.append(errors_)
+ elif isinstance(errors_, list):
+ errors.extend(errors_)
+ else:
+ values[field.name] = v_
+ return values, errors
+
+
+async def request_body_to_args(
+ required_params: List[Field], received_body: Dict[str, Any]
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
+ values = {}
+ errors = []
+ if required_params:
+ field = required_params[0]
+ embed = getattr(field.schema, "embed", None)
+ if len(required_params) == 1 and not embed:
+ received_body = {field.alias: received_body}
+ for field in required_params:
+ value = received_body.get(field.alias)
+ if value is None:
+ if field.required:
+ errors.append(
+ ErrorWrapper(
+ MissingError(), loc=("body", field.alias), config=BaseConfig
+ )
+ )
+ else:
+ values[field.name] = deepcopy(field.default)
+ continue
+ v_, errors_ = field.validate(value, values, loc=("body", field.alias))
+ if isinstance(errors_, ErrorWrapper):
+ errors.append(errors_)
+ elif isinstance(errors_, list):
+ errors.extend(errors_)
+ else:
+ values[field.name] = v_
+ return values, errors
+
+
+def get_body_field(*, dependant: Dependant, name: str):
+ flat_dependant = get_flat_dependant(dependant)
+ if not flat_dependant.body_params:
+ return None
+ first_param = flat_dependant.body_params[0]
+ embed = getattr(first_param.schema, "embed", None)
+ if len(flat_dependant.body_params) == 1 and not embed:
+ return first_param
+ model_name = "Body_" + name
+ BodyModel = create_model(model_name)
+ for f in flat_dependant.body_params:
+ BodyModel.__fields__[f.name] = f
+ required = any(True for f in flat_dependant.body_params if f.required)
+ if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
+ BodySchema = params.File
+ elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
+ BodySchema = params.Form
+ else:
+ BodySchema = params.Body
+
+ field = Field(
+ name="body",
+ type_=BodyModel,
+ default=None,
+ required=required,
+ model_config=BaseConfig,
+ class_validators=[],
+ alias="body",
+ schema=BodySchema(None),
+ )
+ return field
diff --git a/fastapi/pydantic_utils.py b/fastapi/encoders.py
similarity index 56%
rename from fastapi/pydantic_utils.py
rename to fastapi/encoders.py
index 8fc6589a4..95ce4479e 100644
--- a/fastapi/pydantic_utils.py
+++ b/fastapi/encoders.py
@@ -1,33 +1,44 @@
+from enum import Enum
from types import GeneratorType
from typing import Set
+
from pydantic import BaseModel
-from enum import Enum
from pydantic.json import pydantic_encoder
def jsonable_encoder(
- obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
+ obj,
+ include: Set[str] = None,
+ exclude: Set[str] = set(),
+ by_alias: bool = False,
+ include_none=True,
):
if isinstance(obj, BaseModel):
return jsonable_encoder(
- obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
+ obj.dict(include=include, exclude=exclude, by_alias=by_alias),
+ include_none=include_none,
)
- elif isinstance(obj, Enum):
+ if isinstance(obj, Enum):
return obj.value
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
return {
jsonable_encoder(
- key, by_alias=by_alias, include_none=include_none,
- ): jsonable_encoder(
- value, by_alias=by_alias, include_none=include_none,
- )
- for key, value in obj.items() if value is not None or include_none
+ key, by_alias=by_alias, include_none=include_none
+ ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
+ for key, value in obj.items()
+ if value is not None or include_none
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [
- jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
+ jsonable_encoder(
+ item,
+ include=include,
+ exclude=exclude,
+ by_alias=by_alias,
+ include_none=include_none,
+ )
for item in obj
]
return pydantic_encoder(obj)
diff --git a/fastapi/openapi/__init__.py b/fastapi/openapi/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fastapi/openapi/constants.py b/fastapi/openapi/constants.py
new file mode 100644
index 000000000..1d94a3377
--- /dev/null
+++ b/fastapi/openapi/constants.py
@@ -0,0 +1,2 @@
+METHODS_WITH_BODY = set(("POST", "PUT"))
+REF_PREFIX = "#/components/schemas/"
diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py
new file mode 100644
index 000000000..e3d96bd7f
--- /dev/null
+++ b/fastapi/openapi/models.py
@@ -0,0 +1,347 @@
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Schema as PSchema
+from pydantic.types import UrlStr
+
+try:
+ import pydantic.types.EmailStr
+ from pydantic.types import EmailStr
+except ImportError:
+ logging.warning(
+ "email-validator not installed, email fields will be treated as str"
+ )
+
+ class EmailStr(str):
+ pass
+
+
+class Contact(BaseModel):
+ name: Optional[str] = None
+ url: Optional[UrlStr] = None
+ email: Optional[EmailStr] = None
+
+
+class License(BaseModel):
+ name: str
+ url: Optional[UrlStr] = None
+
+
+class Info(BaseModel):
+ title: str
+ description: Optional[str] = None
+ termsOfService: Optional[str] = None
+ contact: Optional[Contact] = None
+ license: Optional[License] = None
+ version: str
+
+
+class ServerVariable(BaseModel):
+ enum: Optional[List[str]] = None
+ default: str
+ description: Optional[str] = None
+
+
+class Server(BaseModel):
+ url: UrlStr
+ description: Optional[str] = None
+ variables: Optional[Dict[str, ServerVariable]] = None
+
+
+class Reference(BaseModel):
+ ref: str = PSchema(..., alias="$ref")
+
+
+class Discriminator(BaseModel):
+ propertyName: str
+ mapping: Optional[Dict[str, str]] = None
+
+
+class XML(BaseModel):
+ name: Optional[str] = None
+ namespace: Optional[str] = None
+ prefix: Optional[str] = None
+ attribute: Optional[bool] = None
+ wrapped: Optional[bool] = None
+
+
+class ExternalDocumentation(BaseModel):
+ description: Optional[str] = None
+ url: UrlStr
+
+
+class SchemaBase(BaseModel):
+ ref: Optional[str] = PSchema(None, alias="$ref")
+ title: Optional[str] = None
+ multipleOf: Optional[float] = None
+ maximum: Optional[float] = None
+ exclusiveMaximum: Optional[float] = None
+ minimum: Optional[float] = None
+ exclusiveMinimum: Optional[float] = None
+ maxLength: Optional[int] = PSchema(None, gte=0)
+ minLength: Optional[int] = PSchema(None, gte=0)
+ pattern: Optional[str] = None
+ maxItems: Optional[int] = PSchema(None, gte=0)
+ minItems: Optional[int] = PSchema(None, gte=0)
+ uniqueItems: Optional[bool] = None
+ maxProperties: Optional[int] = PSchema(None, gte=0)
+ minProperties: Optional[int] = PSchema(None, gte=0)
+ required: Optional[List[str]] = None
+ enum: Optional[List[str]] = None
+ type: Optional[str] = None
+ allOf: Optional[List[Any]] = None
+ oneOf: Optional[List[Any]] = None
+ anyOf: Optional[List[Any]] = None
+ not_: Optional[List[Any]] = PSchema(None, alias="not")
+ items: Optional[Any] = None
+ properties: Optional[Dict[str, Any]] = None
+ additionalProperties: Optional[Union[bool, Any]] = None
+ description: Optional[str] = None
+ format: Optional[str] = None
+ default: Optional[Any] = None
+ nullable: Optional[bool] = None
+ discriminator: Optional[Discriminator] = None
+ readOnly: Optional[bool] = None
+ writeOnly: Optional[bool] = None
+ xml: Optional[XML] = None
+ externalDocs: Optional[ExternalDocumentation] = None
+ example: Optional[Any] = None
+ deprecated: Optional[bool] = None
+
+
+class Schema(SchemaBase):
+ allOf: Optional[List[SchemaBase]] = None
+ oneOf: Optional[List[SchemaBase]] = None
+ anyOf: Optional[List[SchemaBase]] = None
+ not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
+ items: Optional[SchemaBase] = None
+ properties: Optional[Dict[str, SchemaBase]] = None
+ additionalProperties: Optional[Union[bool, SchemaBase]] = None
+
+
+class Example(BaseModel):
+ summary: Optional[str] = None
+ description: Optional[str] = None
+ value: Optional[Any] = None
+ externalValue: Optional[UrlStr] = None
+
+
+class ParameterInType(Enum):
+ query = "query"
+ header = "header"
+ path = "path"
+ cookie = "cookie"
+
+
+class Encoding(BaseModel):
+ contentType: Optional[str] = None
+ # Workaround OpenAPI recursive reference, using Any
+ headers: Optional[Dict[str, Union[Any, Reference]]] = None
+ style: Optional[str] = None
+ explode: Optional[bool] = None
+ allowReserved: Optional[bool] = None
+
+
+class MediaType(BaseModel):
+ schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+ example: Optional[Any] = None
+ examples: Optional[Dict[str, Union[Example, Reference]]] = None
+ encoding: Optional[Dict[str, Encoding]] = None
+
+
+class ParameterBase(BaseModel):
+ description: Optional[str] = None
+ required: Optional[bool] = None
+ deprecated: Optional[bool] = None
+ # Serialization rules for simple scenarios
+ style: Optional[str] = None
+ explode: Optional[bool] = None
+ allowReserved: Optional[bool] = None
+ schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+ example: Optional[Any] = None
+ examples: Optional[Dict[str, Union[Example, Reference]]] = None
+ # Serialization rules for more complex scenarios
+ content: Optional[Dict[str, MediaType]] = None
+
+
+class Parameter(ParameterBase):
+ name: str
+ in_: ParameterInType = PSchema(..., alias="in")
+
+
+class Header(ParameterBase):
+ pass
+
+
+# Workaround OpenAPI recursive reference
+class EncodingWithHeaders(Encoding):
+ headers: Optional[Dict[str, Union[Header, Reference]]] = None
+
+
+class RequestBody(BaseModel):
+ description: Optional[str] = None
+ content: Dict[str, MediaType]
+ required: Optional[bool] = None
+
+
+class Link(BaseModel):
+ operationRef: Optional[str] = None
+ operationId: Optional[str] = None
+ parameters: Optional[Dict[str, Union[Any, str]]] = None
+ requestBody: Optional[Union[Any, str]] = None
+ description: Optional[str] = None
+ server: Optional[Server] = None
+
+
+class Response(BaseModel):
+ description: str
+ headers: Optional[Dict[str, Union[Header, Reference]]] = None
+ content: Optional[Dict[str, MediaType]] = None
+ links: Optional[Dict[str, Union[Link, Reference]]] = None
+
+
+class Responses(BaseModel):
+ default: Response
+
+
+class Operation(BaseModel):
+ tags: Optional[List[str]] = None
+ summary: Optional[str] = None
+ description: Optional[str] = None
+ externalDocs: Optional[ExternalDocumentation] = None
+ operationId: Optional[str] = None
+ parameters: Optional[List[Union[Parameter, Reference]]] = None
+ requestBody: Optional[Union[RequestBody, Reference]] = None
+ responses: Union[Responses, Dict[Union[str], Response]]
+ # Workaround OpenAPI recursive reference
+ callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
+ deprecated: Optional[bool] = None
+ security: Optional[List[Dict[str, List[str]]]] = None
+ servers: Optional[List[Server]] = None
+
+
+class PathItem(BaseModel):
+ ref: Optional[str] = PSchema(None, alias="$ref")
+ summary: Optional[str] = None
+ description: Optional[str] = None
+ get: Optional[Operation] = None
+ put: Optional[Operation] = None
+ post: Optional[Operation] = None
+ delete: Optional[Operation] = None
+ options: Optional[Operation] = None
+ head: Optional[Operation] = None
+ patch: Optional[Operation] = None
+ trace: Optional[Operation] = None
+ servers: Optional[List[Server]] = None
+ parameters: Optional[List[Union[Parameter, Reference]]] = None
+
+
+# Workaround OpenAPI recursive reference
+class OperationWithCallbacks(BaseModel):
+ callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
+
+
+class SecuritySchemeType(Enum):
+ apiKey = "apiKey"
+ http = "http"
+ oauth2 = "oauth2"
+ openIdConnect = "openIdConnect"
+
+
+class SecurityBase(BaseModel):
+ type_: SecuritySchemeType = PSchema(..., alias="type")
+ description: Optional[str] = None
+
+
+class APIKeyIn(Enum):
+ query = "query"
+ header = "header"
+ cookie = "cookie"
+
+
+class APIKey(SecurityBase):
+ type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
+ in_: APIKeyIn = PSchema(..., alias="in")
+ name: str
+
+
+class HTTPBase(SecurityBase):
+ type_ = PSchema(SecuritySchemeType.http, alias="type")
+ scheme: str
+
+
+class HTTPBearer(HTTPBase):
+ scheme = "bearer"
+ bearerFormat: Optional[str] = None
+
+
+class OAuthFlow(BaseModel):
+ refreshUrl: Optional[str] = None
+ scopes: Dict[str, str] = {}
+
+
+class OAuthFlowImplicit(OAuthFlow):
+ authorizationUrl: str
+
+
+class OAuthFlowPassword(OAuthFlow):
+ tokenUrl: str
+
+
+class OAuthFlowClientCredentials(OAuthFlow):
+ tokenUrl: str
+
+
+class OAuthFlowAuthorizationCode(OAuthFlow):
+ authorizationUrl: str
+ tokenUrl: str
+
+
+class OAuthFlows(BaseModel):
+ implicit: Optional[OAuthFlowImplicit] = None
+ password: Optional[OAuthFlowPassword] = None
+ clientCredentials: Optional[OAuthFlowClientCredentials] = None
+ authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
+
+
+class OAuth2(SecurityBase):
+ type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
+ flows: OAuthFlows
+
+
+class OpenIdConnect(SecurityBase):
+ type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
+ openIdConnectUrl: str
+
+
+SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect]
+
+
+class Components(BaseModel):
+ schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
+ responses: Optional[Dict[str, Union[Response, Reference]]] = None
+ parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
+ examples: Optional[Dict[str, Union[Example, Reference]]] = None
+ requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
+ headers: Optional[Dict[str, Union[Header, Reference]]] = None
+ securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
+ links: Optional[Dict[str, Union[Link, Reference]]] = None
+ callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
+
+
+class Tag(BaseModel):
+ name: str
+ description: Optional[str] = None
+ externalDocs: Optional[ExternalDocumentation] = None
+
+
+class OpenAPI(BaseModel):
+ openapi: str
+ info: Info
+ servers: Optional[List[Server]] = None
+ paths: Dict[str, PathItem]
+ components: Optional[Components] = None
+ security: Optional[List[Dict[str, List[str]]]] = None
+ tags: Optional[List[Tag]] = None
+ externalDocs: Optional[ExternalDocumentation] = None
diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py
new file mode 100644
index 000000000..3cf800740
--- /dev/null
+++ b/fastapi/openapi/utils.py
@@ -0,0 +1,280 @@
+from typing import Any, Dict, Sequence, Type
+
+from starlette.responses import HTMLResponse, JSONResponse
+from starlette.routing import BaseRoute
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+
+from fastapi import routing
+from fastapi.dependencies.models import Dependant
+from fastapi.dependencies.utils import get_flat_dependant
+from fastapi.encoders import jsonable_encoder
+from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
+from fastapi.openapi.models import OpenAPI
+from fastapi.params import Body
+from fastapi.utils import get_flat_models_from_routes, get_model_definitions
+from pydantic.fields import Field
+from pydantic.schema import field_schema, get_model_name_map
+from pydantic.utils import lenient_issubclass
+
+validation_error_definition = {
+ "title": "ValidationError",
+ "type": "object",
+ "properties": {
+ "loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
+ "msg": {"title": "Message", "type": "string"},
+ "type": {"title": "Error Type", "type": "string"},
+ },
+ "required": ["loc", "msg", "type"],
+}
+
+validation_error_response_definition = {
+ "title": "HTTPValidationError",
+ "type": "object",
+ "properties": {
+ "detail": {
+ "title": "Detail",
+ "type": "array",
+ "items": {"$ref": REF_PREFIX + "ValidationError"},
+ }
+ },
+}
+
+
+def get_openapi_params(dependant: Dependant):
+ flat_dependant = get_flat_dependant(dependant)
+ return (
+ flat_dependant.path_params
+ + flat_dependant.query_params
+ + flat_dependant.header_params
+ + flat_dependant.cookie_params
+ )
+
+def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
+ if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
+ return None
+ path = {}
+ security_schemes = {}
+ definitions = {}
+ for method in route.methods:
+ operation: Dict[str, Any] = {}
+ if route.tags:
+ operation["tags"] = route.tags
+ if route.summary:
+ operation["summary"] = route.summary
+ if route.description:
+ operation["description"] = route.description
+ if route.operation_id:
+ operation["operationId"] = route.operation_id
+ else:
+ operation["operationId"] = route.name
+ if route.deprecated:
+ operation["deprecated"] = route.deprecated
+ parameters = []
+ flat_dependant = get_flat_dependant(route.dependant)
+ security_definitions = {}
+ for security_requirement in flat_dependant.security_requirements:
+ security_definition = jsonable_encoder(
+ security_requirement.security_scheme,
+ exclude={"scheme_name"},
+ by_alias=True,
+ include_none=False,
+ )
+ security_name = (
+ getattr(
+ security_requirement.security_scheme, "scheme_name", None
+ )
+ or security_requirement.security_scheme.__class__.__name__
+ )
+ security_definitions[security_name] = security_definition
+ operation.setdefault("security", []).append(
+ {security_name: security_requirement.scopes}
+ )
+ if security_definitions:
+ security_schemes.update(
+ security_definitions
+ )
+ all_route_params = get_openapi_params(route.dependant)
+ for param in all_route_params:
+ if "ValidationError" not in definitions:
+ definitions["ValidationError"] = validation_error_definition
+ definitions[
+ "HTTPValidationError"
+ ] = validation_error_response_definition
+ parameter = {
+ "name": param.alias,
+ "in": param.schema.in_.value,
+ "required": param.required,
+ "schema": field_schema(param, model_name_map={})[0],
+ }
+ if param.schema.description:
+ parameter["description"] = param.schema.description
+ if param.schema.deprecated:
+ parameter["deprecated"] = param.schema.deprecated
+ parameters.append(parameter)
+ if parameters:
+ operation["parameters"] = parameters
+ if method in METHODS_WITH_BODY:
+ body_field = route.body_field
+ if body_field:
+ assert isinstance(body_field, Field)
+ body_schema, _ = field_schema(
+ body_field,
+ model_name_map=model_name_map,
+ ref_prefix=REF_PREFIX,
+ )
+ if isinstance(body_field.schema, Body):
+ request_media_type = body_field.schema.media_type
+ else:
+ # Includes not declared media types (Schema)
+ request_media_type = "application/json"
+ required = body_field.required
+ request_body_oai = {}
+ if required:
+ request_body_oai["required"] = required
+ request_body_oai["content"] = {
+ request_media_type: {"schema": body_schema}
+ }
+ operation["requestBody"] = request_body_oai
+ response_code = str(route.response_code)
+ response_schema = {"type": "string"}
+ if lenient_issubclass(route.response_wrapper, JSONResponse):
+ response_media_type = "application/json"
+ if route.response_field:
+ response_schema, _ = field_schema(
+ route.response_field,
+ model_name_map=model_name_map,
+ ref_prefix=REF_PREFIX,
+ )
+ else:
+ response_schema = {}
+ elif lenient_issubclass(route.response_wrapper, HTMLResponse):
+ response_media_type = "text/html"
+ else:
+ response_media_type = "text/plain"
+ content = {response_media_type: {"schema": response_schema}}
+ operation["responses"] = {
+ response_code: {
+ "description": route.response_description,
+ "content": content,
+ }
+ }
+ if all_route_params or route.body_field:
+ operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
+ }
+ },
+ }
+ path[method.lower()] = operation
+ return path, security_schemes, definitions
+
+
+def get_openapi(
+ *,
+ title: str,
+ version: str,
+ openapi_version: str = "3.0.2",
+ description: str = None,
+ routes: Sequence[BaseRoute]
+):
+ info = {"title": title, "version": version}
+ if description:
+ info["description"] = description
+ output = {"openapi": openapi_version, "info": info}
+ components: Dict[str, Dict] = {}
+ paths: Dict[str, Dict] = {}
+ flat_models = get_flat_models_from_routes(routes)
+ model_name_map = get_model_name_map(flat_models)
+ definitions = get_model_definitions(
+ flat_models=flat_models, model_name_map=model_name_map
+ )
+ for route in routes:
+ result = get_openapi_path(route=route, model_name_map=model_name_map)
+ if result:
+ path, security_schemes, path_definitions = result
+ if path:
+ paths.setdefault(route.path, {}).update(path)
+ if security_schemes:
+ components.setdefault("securitySchemes", {}).update(security_schemes)
+ if path_definitions:
+ definitions.update(path_definitions)
+ if definitions:
+ components.setdefault("schemas", {}).update(definitions)
+ if components:
+ output["components"] = components
+ output["paths"] = paths
+ return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
+
+
+def get_swagger_ui_html(*, openapi_url: str, title: str):
+ return HTMLResponse(
+ """
+
+
+
+
+
+ """ + title + """
+
+
+
+
+
+
+
+
+
+
+ """,
+ media_type="text/html",
+ )
+
+
+def get_redoc_html(*, openapi_url: str, title: str):
+ return HTMLResponse(
+ """
+
+
+
+
+ """ + title + """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ media_type="text/html",
+ )
diff --git a/fastapi/params.py b/fastapi/params.py
index 98b80943c..abbce8aeb 100644
--- a/fastapi/params.py
+++ b/fastapi/params.py
@@ -1,5 +1,6 @@
-from typing import Sequence
from enum import Enum
+from typing import Sequence, Any, Dict
+
from pydantic import Schema
@@ -12,6 +13,7 @@ class ParamTypes(Enum):
class Param(Schema):
in_: ParamTypes
+
def __init__(
self,
default,
@@ -27,7 +29,7 @@ class Param(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
self.deprecated = deprecated
super().__init__(
@@ -64,7 +66,7 @@ class Path(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@@ -103,7 +105,7 @@ class Query(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@@ -141,7 +143,7 @@ class Header(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@@ -179,7 +181,7 @@ class Cookie(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@@ -200,11 +202,49 @@ class Cookie(Param):
class Body(Schema):
+ def __init__(
+ self,
+ default,
+ *,
+ embed=False,
+ media_type: str = "application/json",
+ alias: str = None,
+ title: str = None,
+ description: str = None,
+ gt: float = None,
+ ge: float = None,
+ lt: float = None,
+ le: float = None,
+ min_length: int = None,
+ max_length: int = None,
+ regex: str = None,
+ **extra: Dict[str, Any],
+ ):
+ self.embed = embed
+ self.media_type = media_type
+ super().__init__(
+ default,
+ alias=alias,
+ title=title,
+ description=description,
+ gt=gt,
+ ge=ge,
+ lt=lt,
+ le=le,
+ min_length=min_length,
+ max_length=max_length,
+ regex=regex,
+ **extra,
+ )
+
+
+class Form(Body):
def __init__(
self,
default,
*,
sub_key=False,
+ media_type: str = "application/x-www-form-urlencoded",
alias: str = None,
title: str = None,
description: str = None,
@@ -215,11 +255,49 @@ class Body(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: object,
+ **extra: Dict[str, Any],
):
- self.sub_key = sub_key
super().__init__(
default,
+ embed=sub_key,
+ media_type=media_type,
+ alias=alias,
+ title=title,
+ description=description,
+ gt=gt,
+ ge=ge,
+ lt=lt,
+ le=le,
+ min_length=min_length,
+ max_length=max_length,
+ regex=regex,
+ **extra,
+ )
+
+
+class File(Form):
+ def __init__(
+ self,
+ default,
+ *,
+ sub_key=False,
+ media_type: str = "multipart/form-data",
+ alias: str = None,
+ title: str = None,
+ description: str = None,
+ gt: float = None,
+ ge: float = None,
+ lt: float = None,
+ le: float = None,
+ min_length: int = None,
+ max_length: int = None,
+ regex: str = None,
+ **extra: Dict[str, Any],
+ ):
+ super().__init__(
+ default,
+ embed=sub_key,
+ media_type=media_type,
alias=alias,
title=title,
description=description,
@@ -235,12 +313,11 @@ class Body(Schema):
class Depends:
- def __init__(self, dependency = None):
+ def __init__(self, dependency=None):
self.dependency = dependency
-class Security:
- def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
- self.security_scheme = security_scheme
- self.scopes = scopes
-
+class Security(Depends):
+ def __init__(self, dependency=None, scopes: Sequence[str] = None):
+ self.scopes = scopes or []
+ super().__init__(dependency=dependency)
diff --git a/fastapi/routing.py b/fastapi/routing.py
index 8c95b3327..6f7d592e5 100644
--- a/fastapi/routing.py
+++ b/fastapi/routing.py
@@ -1,341 +1,66 @@
import asyncio
import inspect
-import re
-import typing
-from copy import deepcopy
+from typing import Callable, List, Type
from starlette import routing
-from starlette.routing import get_name, request_response
-from starlette.requests import Request
-from starlette.responses import Response, JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
+from starlette.formparsers import UploadFile
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+from starlette.routing import get_name, request_response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
-
-from pydantic.fields import Field, Required
-from pydantic.schema import get_annotation_from_schema
-from pydantic import BaseConfig, BaseModel, create_model, Schema
+from fastapi import params
+from fastapi.dependencies.models import Dependant
+from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
+from fastapi.encoders import jsonable_encoder
+from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.errors import MissingError
+from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
-from .pydantic_utils import jsonable_encoder
-
-from fastapi import params
-from fastapi.security.base import SecurityBase
-
-param_supported_types = (str, int, float, bool)
-
-class Dependant:
- def __init__(
- self,
- *,
- path_params: typing.List[Field] = None,
- query_params: typing.List[Field] = None,
- header_params: typing.List[Field] = None,
- cookie_params: typing.List[Field] = None,
- body_params: typing.List[Field] = None,
- dependencies: typing.List["Dependant"] = None,
- security_schemes: typing.List[Field] = None,
- name: str = None,
- call: typing.Callable = None,
- request_param_name: str = None,
- ) -> None:
- self.path_params: typing.List[Field] = path_params or []
- self.query_params: typing.List[Field] = query_params or []
- self.header_params: typing.List[Field] = header_params or []
- self.cookie_params: typing.List[Field] = cookie_params or []
- self.body_params: typing.List[Field] = body_params or []
- self.dependencies: typing.List[Dependant] = dependencies or []
- self.security_schemes: typing.List[Field] = security_schemes or []
- self.request_param_name = request_param_name
- self.name = name
- self.call: typing.Callable = call
-
-
-def request_params_to_args(
- required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
-) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
- values = {}
- errors = []
- for field in required_params:
- value = received_params.get(field.alias)
- if value is None:
- if field.required:
- errors.append(
- ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
- )
- else:
- values[field.name] = deepcopy(field.default)
- continue
- v_, errors_ = field.validate(
- value, values, loc=(field.schema.in_.value, field.alias)
- )
+def serialize_response(*, field: Field = None, response):
+ if field:
+ errors = []
+ value, errors_ = field.validate(response, {}, loc=("response",))
if isinstance(errors_, ErrorWrapper):
- errors_: ErrorWrapper
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
- else:
- values[field.name] = v_
- return values, errors
-
-
-def request_body_to_args(
- required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
-) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
- values = {}
- errors = []
- if required_params:
- field = required_params[0]
- sub_key = getattr(field.schema, "sub_key", None)
- if len(required_params) == 1 and not sub_key:
- received_body = {field.alias: received_body}
- for field in required_params:
- value = received_body.get(field.alias)
- if value is None:
- if field.required:
- errors.append(
- ErrorWrapper(
- MissingError(), loc=("body", field.alias), config=BaseConfig
- )
- )
- else:
- values[field.name] = deepcopy(field.default)
- continue
-
- v_, errors_ = field.validate(value, values, loc=("body", field.alias))
- if isinstance(errors_, ErrorWrapper):
- errors_: ErrorWrapper
- errors.append(errors_)
- elif isinstance(errors_, list):
- errors.extend(errors_)
- else:
- values[field.name] = v_
- return values, errors
-
-
-def add_param_to_fields(
- *,
- param: inspect.Parameter,
- dependant: Dependant,
- default_schema=params.Param,
- force_type: params.ParamTypes = None,
-):
- default_value = Required
- if not param.default == param.empty:
- default_value = param.default
- if isinstance(default_value, params.Param):
- schema = default_value
- default_value = schema.default
- if schema.in_ is None:
- schema.in_ = default_schema.in_
- if force_type:
- schema.in_ = force_type
- else:
- schema = default_schema(default_value)
- required = default_value == Required
- annotation = typing.Any
- if not param.annotation == param.empty:
- annotation = param.annotation
- annotation = get_annotation_from_schema(annotation, schema)
- Config = BaseConfig
- field = Field(
- name=param.name,
- type_=annotation,
- default=None if required else default_value,
- alias=schema.alias or param.name,
- required=required,
- model_config=Config,
- class_validators=[],
- schema=schema,
- )
- if schema.in_ == params.ParamTypes.path:
- dependant.path_params.append(field)
- elif schema.in_ == params.ParamTypes.query:
- dependant.query_params.append(field)
- elif schema.in_ == params.ParamTypes.header:
- dependant.header_params.append(field)
- else:
- assert (
- schema.in_ == params.ParamTypes.cookie
- ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
- dependant.cookie_params.append(field)
-
-
-def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
- default_value = Required
- if not param.default == param.empty:
- default_value = param.default
- if isinstance(default_value, Schema):
- schema = default_value
- default_value = schema.default
+ if errors:
+ raise ValidationError(errors)
+ return jsonable_encoder(value)
else:
- schema = Schema(default_value)
- required = default_value == Required
- annotation = get_annotation_from_schema(param.annotation, schema)
- field = Field(
- name=param.name,
- type_=annotation,
- default=None if required else default_value,
- alias=schema.alias or param.name,
- required=required,
- model_config=BaseConfig,
- class_validators=[],
- schema=schema,
- )
- dependant.body_params.append(field)
+ return jsonable_encoder(response)
-def get_sub_dependant(
- *, param: inspect.Parameter, path: str
+def get_app(
+ dependant: Dependant,
+ body_field: Field = None,
+ response_code: str = 200,
+ response_wrapper: Type[Response] = JSONResponse,
+ response_field: Type[Field] = None,
):
- depends: params.Depends = param.default
- if depends.dependency:
- dependency = depends.dependency
- else:
- dependency = param.annotation
- assert callable(dependency)
- sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
- if isinstance(dependency, SecurityBase):
- sub_dependant.security_schemes.append(dependency)
- return sub_dependant
-
-
-def get_flat_dependant(dependant: Dependant):
- flat_dependant = Dependant(
- path_params=dependant.path_params.copy(),
- query_params=dependant.query_params.copy(),
- header_params=dependant.header_params.copy(),
- cookie_params=dependant.cookie_params.copy(),
- body_params=dependant.body_params.copy(),
- security_schemes=dependant.security_schemes.copy(),
- )
- for sub_dependant in dependant.dependencies:
- if sub_dependant is dependant:
- raise ValueError("recursion", dependant.dependencies)
- flat_sub = get_flat_dependant(sub_dependant)
- flat_dependant.path_params.extend(flat_sub.path_params)
- flat_dependant.query_params.extend(flat_sub.query_params)
- flat_dependant.header_params.extend(flat_sub.header_params)
- flat_dependant.cookie_params.extend(flat_sub.cookie_params)
- flat_dependant.body_params.extend(flat_sub.body_params)
- flat_dependant.security_schemes.extend(flat_sub.security_schemes)
- return flat_dependant
-
-
-def get_path_param_names(path: str):
- return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
-
-
-def get_dependant(*, path: str, call: typing.Callable, name: str = None):
- path_param_names = get_path_param_names(path)
- endpoint_signature = inspect.signature(call)
- signature_params = endpoint_signature.parameters
- dependant = Dependant(call=call, name=name)
- for param_name in signature_params:
- param = signature_params[param_name]
- if isinstance(param.default, params.Depends):
- sub_dependant = get_sub_dependant(param=param, path=path)
- dependant.dependencies.append(sub_dependant)
- for param_name in signature_params:
- param = signature_params[param_name]
- if (
- (param.default == param.empty) or isinstance(param.default, params.Path)
- ) and (param_name in path_param_names):
- assert lenient_issubclass(
- param.annotation, param_supported_types
- ), f"Path params must be of type str, int, float or boot: {param}"
- param = signature_params[param_name]
- add_param_to_fields(
- param=param,
- dependant=dependant,
- default_schema=params.Path,
- force_type=params.ParamTypes.path,
- )
- elif (param.default == param.empty or param.default is None) and (
- param.annotation == param.empty
- or lenient_issubclass(param.annotation, param_supported_types)
- ):
- add_param_to_fields(
- param=param, dependant=dependant, default_schema=params.Query
- )
- elif isinstance(param.default, params.Param):
- if param.annotation != param.empty:
- assert lenient_issubclass(
- param.annotation, param_supported_types
- ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
- add_param_to_fields(
- param=param, dependant=dependant, default_schema=params.Query
- )
- elif lenient_issubclass(param.annotation, Request):
- dependant.request_param_name = param_name
- elif not isinstance(param.default, params.Depends):
- add_param_to_body_fields(param=param, dependant=dependant)
- return dependant
-
-
-def is_coroutine_callable(call: typing.Callable):
- if inspect.isfunction(call):
- return asyncio.iscoroutinefunction(call)
- elif inspect.isclass(call):
- return False
- else:
- call = getattr(call, "__call__", None)
- if not call:
- return False
- else:
- return asyncio.iscoroutinefunction(call)
-
-
-async def solve_dependencies(*, request: Request, dependant: Dependant):
- values = {}
- errors = []
- for sub_dependant in dependant.dependencies:
- sub_values, sub_errors = await solve_dependencies(
- request=request, dependant=sub_dependant
- )
- if sub_errors:
- return {}, errors
- if is_coroutine_callable(sub_dependant.call):
- solved = await sub_dependant.call(**sub_values)
- else:
- solved = await run_in_threadpool(sub_dependant.call, **sub_values)
- values[sub_dependant.name] = solved
- path_values, path_errors = request_params_to_args(
- dependant.path_params, request.path_params
- )
- query_values, query_errors = request_params_to_args(
- dependant.query_params, request.query_params
- )
- header_values, header_errors = request_params_to_args(
- dependant.header_params, request.headers
- )
- cookie_values, cookie_errors = request_params_to_args(
- dependant.cookie_params, request.cookies
- )
- values.update(path_values)
- values.update(query_values)
- values.update(header_values)
- values.update(cookie_values)
- errors = path_errors + query_errors + header_errors + cookie_errors
- if dependant.body_params:
- body = await request.json()
- body_values, body_errors = request_body_to_args(dependant.body_params, body)
- values.update(body_values)
- errors.extend(body_errors)
- if dependant.request_param_name:
- values[dependant.request_param_name] = request
- return values, errors
-
-
-def get_app(dependant: Dependant):
- is_coroutine = asyncio.iscoroutinefunction(dependant.call)
+ is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
async def app(request: Request) -> Response:
- values, errors = await solve_dependencies(request=request, dependant=dependant)
+ body = None
+ if body_field:
+ if isinstance(body_field.schema, params.Form):
+ raw_body = await request.form()
+ body = {}
+ for field, value in raw_body.items():
+ if isinstance(value, UploadFile):
+ body[field] = await value.read()
+ else:
+ body[field] = value
+ else:
+ body = await request.json()
+ values, errors = await solve_dependencies(
+ request=request, dependant=dependant, body=body
+ )
if errors:
errors_out = ValidationError(errors)
raise HTTPException(
@@ -348,36 +73,56 @@ def get_app(dependant: Dependant):
raw_response = await run_in_threadpool(dependant.call, **values)
if isinstance(raw_response, Response):
return raw_response
- else:
- return JSONResponse(content=jsonable_encoder(raw_response))
- return app
-
+ if isinstance(raw_response, BaseModel):
+ return response_wrapper(
+ content=jsonable_encoder(raw_response), status_code=response_code
+ )
+ errors = []
+ try:
+ return response_wrapper(
+ content=serialize_response(
+ field=response_field, response=raw_response
+ ),
+ status_code=response_code,
+ )
+ except Exception as e:
+ errors.append(e)
+ try:
+ response = dict(raw_response)
+ return response_wrapper(
+ content=serialize_response(field=response_field, response=response),
+ status_code=response_code,
+ )
+ except Exception as e:
+ errors.append(e)
+ try:
+ response = vars(raw_response)
+ return response_wrapper(
+ content=serialize_response(field=response_field, response=response),
+ status_code=response_code,
+ )
+ except Exception as e:
+ errors.append(e)
+ raise ValueError(errors)
-def get_openapi_params(dependant: Dependant):
- flat_dependant = get_flat_dependant(dependant)
- return (
- flat_dependant.path_params
- + flat_dependant.query_params
- + flat_dependant.header_params
- + flat_dependant.cookie_params
- )
+ return app
class APIRoute(routing.Route):
def __init__(
self,
path: str,
- endpoint: typing.Callable,
+ endpoint: Callable,
*,
- methods: typing.List[str] = None,
+ methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -392,12 +137,12 @@ class APIRoute(routing.Route):
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
- self.tags = tags
+ self.tags = tags or []
self.summary = summary
- self.description = description
+ self.description = description or self.endpoint.__doc__
self.operation_id = operation_id
self.deprecated = deprecated
- self.request_body: typing.Union[BaseModel, Field, None] = None
+ self.body_field: Field = None
self.response_description = response_description
self.response_code = response_code
self.response_wrapper = response_wrapper
@@ -430,53 +175,32 @@ class APIRoute(routing.Route):
), f"An endpoint must be a function or method"
self.dependant = get_dependant(path=path, call=self.endpoint)
- # flat_dependant = get_flat_dependant(self.dependant)
- # path_param_names = get_path_param_names(path)
- # for path_param in path_param_names:
- # assert path_param in {
- # f.alias for f in flat_dependant.path_params
- # }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
-
- if self.dependant.body_params:
- first_param = self.dependant.body_params[0]
- sub_key = getattr(first_param.schema, "sub_key", None)
- if len(self.dependant.body_params) == 1 and not sub_key:
- self.request_body = first_param
- else:
- model_name = "Body_" + self.name
- BodyModel = create_model(model_name)
- for f in self.dependant.body_params:
- BodyModel.__fields__[f.name] = f
- required = any(True for f in self.dependant.body_params if f.required)
- field = Field(
- name="body",
- type_=BodyModel,
- default=None,
- required=required,
- model_config=BaseConfig,
- class_validators=[],
- alias="body",
- schema=Schema(None),
- )
- self.request_body = field
-
- self.app = request_response(get_app(dependant=self.dependant))
+ self.body_field = get_body_field(dependant=self.dependant, name=self.name)
+ self.app = request_response(
+ get_app(
+ dependant=self.dependant,
+ body_field=self.body_field,
+ response_code=self.response_code,
+ response_wrapper=self.response_wrapper,
+ response_field=self.response_field,
+ )
+ )
class APIRouter(routing.Router):
def add_api_route(
self,
path: str,
- endpoint: typing.Callable,
- methods: typing.List[str] = None,
+ endpoint: Callable,
+ methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -487,7 +211,7 @@ class APIRouter(routing.Router):
methods=methods,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -502,27 +226,27 @@ class APIRouter(routing.Router):
def api_route(
self,
path: str,
- methods: typing.List[str] = None,
+ methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
- ) -> typing.Callable:
- def decorator(func: typing.Callable) -> typing.Callable:
+ ) -> Callable:
+ def decorator(func: Callable) -> Callable:
self.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -541,12 +265,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -556,7 +280,7 @@ class APIRouter(routing.Router):
methods=["GET"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -572,12 +296,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -587,7 +311,7 @@ class APIRouter(routing.Router):
methods=["PUT"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -603,12 +327,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -618,7 +342,7 @@ class APIRouter(routing.Router):
methods=["POST"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -634,12 +358,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -649,7 +373,7 @@ class APIRouter(routing.Router):
methods=["DELETE"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -665,12 +389,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -680,7 +404,7 @@ class APIRouter(routing.Router):
methods=["OPTIONS"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -696,12 +420,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -711,7 +435,7 @@ class APIRouter(routing.Router):
methods=["HEAD"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -727,12 +451,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -742,7 +466,7 @@ class APIRouter(routing.Router):
methods=["PATCH"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@@ -758,12 +482,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
- tags: typing.List[str] = [],
+ tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
- response_type: typing.Type = None,
+ response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@@ -773,7 +497,7 @@ class APIRouter(routing.Router):
methods=["TRACE"],
name=name,
include_in_schema=include_in_schema,
- tags=tags,
+ tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
index 4b6766eb7..c0354fea7 100644
--- a/fastapi/security/api_key.py
+++ b/fastapi/security/api_key.py
@@ -1,11 +1,10 @@
-from starlette.requests import Request
+from enum import Enum
from pydantic import Schema
-from enum import Enum
-from .base import SecurityBase, Types
-__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
+from starlette.requests import Request
+from .base import SecurityBase, Types
class APIKeyIn(Enum):
query = "query"
@@ -21,7 +20,7 @@ class APIKeyBase(SecurityBase):
class APIKeyQuery(APIKeyBase):
in_ = Schema(APIKeyIn.query, alias="in")
-
+
async def __call__(self, requests: Request):
return requests.query_params.get(self.name)
diff --git a/fastapi/security/base.py b/fastapi/security/base.py
index 37433ff25..9ba430df9 100644
--- a/fastapi/security/base.py
+++ b/fastapi/security/base.py
@@ -1,7 +1,6 @@
from enum import Enum
-from pydantic import BaseModel, Schema
-__all__ = ["Types", "SecurityBase"]
+from pydantic import BaseModel, Schema
class Types(Enum):
diff --git a/fastapi/security/http.py b/fastapi/security/http.py
index aaaf86618..7a8bcfe48 100644
--- a/fastapi/security/http.py
+++ b/fastapi/security/http.py
@@ -1,9 +1,8 @@
+from pydantic import Schema
from starlette.requests import Request
-from pydantic import Schema
-from .base import SecurityBase, Types
-__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
+from .base import SecurityBase, Types
class HTTPBase(SecurityBase):
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
index a6607ef52..4febdafc2 100644
--- a/fastapi/security/oauth2.py
+++ b/fastapi/security/oauth2.py
@@ -3,10 +3,8 @@ from typing import Dict
from pydantic import BaseModel, Schema
from starlette.requests import Request
-from .base import SecurityBase, Types
-
-# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
+from .base import SecurityBase, Types
class OAuthFlow(BaseModel):
refreshUrl: str = None
diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py
index 2e7791a7a..c84c56de8 100644
--- a/fastapi/security/open_id_connect_url.py
+++ b/fastapi/security/open_id_connect_url.py
@@ -2,6 +2,7 @@ from starlette.requests import Request
from .base import SecurityBase, Types
+
class OpenIdConnect(SecurityBase):
type_ = Types.openIdConnect
openIdConnectUrl: str
diff --git a/fastapi/utils.py b/fastapi/utils.py
new file mode 100644
index 000000000..091f868fe
--- /dev/null
+++ b/fastapi/utils.py
@@ -0,0 +1,46 @@
+import re
+from typing import Dict, Sequence, Set, Type
+
+from starlette.routing import BaseRoute
+
+from fastapi import routing
+from fastapi.openapi.constants import REF_PREFIX
+from pydantic import BaseModel
+from pydantic.fields import Field
+from pydantic.schema import get_flat_models_from_fields, model_process_schema
+
+
+def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
+ body_fields_from_routes = []
+ responses_from_routes = []
+ for route in routes:
+ if route.include_in_schema and isinstance(route, routing.APIRoute):
+ if route.body_field:
+ assert isinstance(
+ route.body_field, Field
+ ), "A request body must be a Pydantic Field"
+ body_fields_from_routes.append(route.body_field)
+ if route.response_field:
+ responses_from_routes.append(route.response_field)
+ flat_models = get_flat_models_from_fields(
+ body_fields_from_routes + responses_from_routes
+ )
+ return flat_models
+
+
+def get_model_definitions(
+ *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
+):
+ definitions: Dict[str, Dict] = {}
+ for model in flat_models:
+ m_schema, m_definitions = model_process_schema(
+ model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+ )
+ definitions.update(m_definitions)
+ model_name = model_name_map[model]
+ definitions[model_name] = m_schema
+ return definitions
+
+
+def get_path_param_names(path: str):
+ return {item.strip("{}") for item in re.findall("{[^}]*}", path)}