Browse Source

♻️ Refactor, fix and update code

pull/1/head
Sebastián Ramírez 6 years ago
parent
commit
b9d912c638
  1. 2
      fastapi/__init__.py
  2. 345
      fastapi/applications.py
  3. 0
      fastapi/dependencies/__init__.py
  4. 46
      fastapi/dependencies/models.py
  5. 327
      fastapi/dependencies/utils.py
  6. 31
      fastapi/encoders.py
  7. 0
      fastapi/openapi/__init__.py
  8. 2
      fastapi/openapi/constants.py
  9. 347
      fastapi/openapi/models.py
  10. 280
      fastapi/openapi/utils.py
  11. 105
      fastapi/params.py
  12. 526
      fastapi/routing.py
  13. 9
      fastapi/security/api_key.py
  14. 3
      fastapi/security/base.py
  15. 5
      fastapi/security/http.py
  16. 4
      fastapi/security/oauth2.py
  17. 1
      fastapi/security/open_id_connect_url.py
  18. 46
      fastapi/utils.py

2
fastapi/__init__.py

@ -1,3 +1,3 @@
"""Fast API framework, fast high performance, fast to learn, fast to code"""
__version__ = '0.1'
__version__ = "0.1"

345
fastapi/applications.py

@ -1,61 +1,19 @@
import typing
import inspect
from typing import Any, Callable, Dict, List, Type
from starlette.applications import Starlette
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.exceptions import ExceptionMiddleware
from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
from starlette.requests import Request
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import JSONResponse
from pydantic import BaseModel, BaseConfig, Schema
from pydantic.utils import lenient_issubclass
from pydantic.fields import Field
from pydantic.schema import (
field_schema,
get_flat_models_from_models,
get_flat_models_from_fields,
get_model_name_map,
schema,
model_process_schema,
)
from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
from .pydantic_utils import jsonable_encoder
from fastapi import routing
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
def docs(openapi_url):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '""" + openapi_url + """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
async def http_exception(request, exc: HTTPException):
print(exc)
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
class FastAPI(Starlette):
@ -67,24 +25,26 @@ class FastAPI(Starlette):
description: str = "",
version: str = "0.1.0",
openapi_url: str = "/openapi.json",
docs_url: str = "/docs",
**extra: typing.Dict[str, typing.Any],
swagger_ui_url: str = "/docs",
redoc_url: str = "/redoc",
**extra: Dict[str, Any],
) -> None:
self._debug = debug
self.router = APIRouter()
self.router = routing.APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
self.schema_generator = None
self.template_env = self.load_template_env(template_directory)
self.title = title
self.description = description
self.version = version
self.openapi_url = openapi_url
self.docs_url = docs_url
self.swagger_ui_url = swagger_ui_url
self.redoc_url = redoc_url
self.extra = extra
self.openapi_version = "3.0.2"
@ -93,29 +53,52 @@ class FastAPI(Starlette):
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
if self.docs_url:
if self.swagger_ui_url or self.redoc_url:
assert self.openapi_url, "The openapi_url is required for the docs"
self.setup()
self.add_route(
self.openapi_url,
lambda req: JSONResponse(self.openapi()),
include_in_schema=False,
)
self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
def setup(self):
if self.openapi_url:
self.add_route(
self.openapi_url,
lambda req: JSONResponse(
get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
)
),
include_in_schema=False,
)
if self.swagger_ui_url:
self.add_route(
self.swagger_ui_url,
lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
include_in_schema=False,
)
if self.redoc_url:
self.add_route(
self.redoc_url,
lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
include_in_schema=False,
)
self.add_exception_handler(HTTPException, http_exception)
def add_api_route(
self,
path: str,
endpoint: typing.Callable,
methods: typing.List[str] = None,
endpoint: Callable,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -126,7 +109,7 @@ class FastAPI(Starlette):
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -140,27 +123,27 @@ class FastAPI(Starlette):
def api_route(
self,
path: str,
methods: typing.List[str] = None,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
) -> Callable:
def decorator(func: Callable) -> Callable:
self.router.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -179,12 +162,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -193,7 +176,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -209,12 +192,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -223,7 +206,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -239,12 +222,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -253,7 +236,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -269,12 +252,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -283,7 +266,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -299,12 +282,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -313,7 +296,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -329,12 +312,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -343,7 +326,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -359,12 +342,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -373,7 +356,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -389,12 +372,12 @@ class FastAPI(Starlette):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -403,7 +386,7 @@ class FastAPI(Starlette):
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -413,169 +396,3 @@ class FastAPI(Starlette):
response_code=response_code,
response_wrapper=response_wrapper,
)
def openapi(self):
info = {"title": self.title, "version": self.version}
if self.description:
info["description"] = self.description
output = {"openapi": self.openapi_version, "info": info}
components = {}
paths = {}
methods_with_body = set(("POST", "PUT"))
body_fields_from_routes = []
responses_from_routes = []
ref_prefix = "#/components/schemas/"
for route in self.routes:
route: APIRoute
if route.include_in_schema and isinstance(route, APIRoute):
if route.request_body:
assert isinstance(
route.request_body, Field
), "A request body must be a Pydantic BaseModel or Field"
body_fields_from_routes.append(route.request_body)
if route.response_field:
responses_from_routes.append(route.response_field)
flat_models = get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes
)
model_name_map = get_model_name_map(flat_models)
definitions = {}
for model in flat_models:
m_schema, m_definitions = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=ref_prefix
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": ref_prefix + "ValidationError"},
}
},
}
for route in self.routes:
route: APIRoute
if route.include_in_schema and isinstance(route, APIRoute):
path = paths.get(route.path, {})
for method in route.methods:
operation = {}
if route.tags:
operation["tags"] = route.tags
if route.summary:
operation["summary"] = route.summary
if route.description:
operation["description"] = route.description
if route.operation_id:
operation["operationId"] = route.operation_id
else:
operation["operationId"] = route.name
if route.deprecated:
operation["deprecated"] = route.deprecated
parameters = []
flat_dependant = get_flat_dependant(route.dependant)
security_definitions = {}
for security_scheme in flat_dependant.security_schemes:
security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
security_definitions[security_name] = security_definition
if security_definitions:
components.setdefault("securitySchemes", {}).update(security_definitions)
operation["security"] = [{name: []} for name in security_definitions]
all_route_params = get_openapi_params(route.dependant)
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions[
"HTTPValidationError"
] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
if parameters:
operation["parameters"] = parameters
if method in methods_with_body:
request_body = getattr(route, "request_body", None)
if request_body:
assert isinstance(request_body, Field)
body_schema, _ = field_schema(
request_body,
model_name_map=model_name_map,
ref_prefix=ref_prefix,
)
required = request_body.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {
"application/json": {"schema": body_schema}
}
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
response_schema = {"type": "string"}
if lenient_issubclass(route.response_wrapper, JSONResponse):
response_media_type = "application/json"
if route.response_field:
response_schema, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=ref_prefix,
)
else:
response_schema = {}
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
response_media_type = "text/html"
else:
response_media_type = "text/plain"
content = {response_media_type: {"schema": response_schema}}
operation["responses"] = {
response_code: {
"description": route.response_description,
"content": content,
}
}
if all_route_params or getattr(route, "request_body", None):
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": ref_prefix + "HTTPValidationError"
}
}
},
}
path[method.lower()] = operation
paths[route.path] = path
if definitions:
components.setdefault("schemas", {}).update(definitions)
if components:
output["components"] = components
output["paths"] = paths
return output

0
fastapi/dependencies/__init__.py

46
fastapi/dependencies/models.py

@ -0,0 +1,46 @@
from typing import Any, Callable, Dict, List, Sequence, Tuple
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.security.base import SecurityBase
from pydantic import BaseConfig, Schema
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
param_supported_types = (str, int, float, bool)
class SecurityRequirement:
def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None):
self.security_scheme = security_scheme
self.scopes = scopes
class Dependant:
def __init__(
self,
*,
path_params: List[Field] = None,
query_params: List[Field] = None,
header_params: List[Field] = None,
cookie_params: List[Field] = None,
body_params: List[Field] = None,
dependencies: List["Dependant"] = None,
security_schemes: List[SecurityRequirement] = None,
name: str = None,
call: Callable = None,
request_param_name: str = None,
) -> None:
self.path_params = path_params or []
self.query_params = query_params or []
self.header_params = header_params or []
self.cookie_params = cookie_params or []
self.body_params = body_params or []
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.name = name
self.call = call

327
fastapi/dependencies/utils.py

@ -0,0 +1,327 @@
import asyncio
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi import params
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase
from fastapi.utils import get_path_param_names
from pydantic import BaseConfig, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
param_supported_types = (str, int, float, bool)
def get_sub_dependant(*, param: inspect.Parameter, path: str):
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
assert callable(dependency)
sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase):
security_requirement = SecurityRequirement(
security_scheme=dependency, scopes=depends.scopes
)
sub_dependant.security_requirements.append(security_requirement)
return sub_dependant
def get_flat_dependant(dependant: Dependant):
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_requirements.copy(),
)
for sub_dependant in dependant.dependencies:
if sub_dependant is dependant:
raise ValueError("recursion", dependant.dependencies)
flat_sub = get_flat_dependant(sub_dependant)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
return flat_dependant
def get_dependant(*, path: str, call: Callable, name: str = None):
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name)
for param_name in signature_params:
param = signature_params[param_name]
if isinstance(param.default, params.Depends):
sub_dependant = get_sub_dependant(param=param, path=path)
dependant.dependencies.append(sub_dependant)
for param_name in signature_params:
param = signature_params[param_name]
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
assert lenient_issubclass(
param.annotation, param_supported_types
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name]
add_param_to_fields(
param=param,
dependant=dependant,
default_schema=params.Path,
force_type=params.ParamTypes.path,
)
elif (param.default == param.empty or param.default is None) and (
param.annotation == param.empty
or lenient_issubclass(param.annotation, param_supported_types)
):
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif isinstance(param.default, params.Param):
if param.annotation != param.empty:
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
elif not isinstance(param.default, params.Depends):
add_param_to_body_fields(param=param, dependant=dependant)
return dependant
def add_param_to_fields(
*,
param: inspect.Parameter,
dependant: Dependant,
default_schema=params.Param,
force_type: params.ParamTypes = None,
):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, params.Param):
schema = default_value
default_value = schema.default
if schema.in_ is None:
schema.in_ = default_schema.in_
if force_type:
schema.in_ = force_type
else:
schema = default_schema(default_value)
required = default_value == Required
annotation = Any
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema)
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=BaseConfig(),
class_validators=[],
schema=schema,
)
if schema.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif schema.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif schema.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
schema.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, Schema):
schema = default_value
default_value = schema.default
else:
schema = Schema(default_value)
required = default_value == Required
annotation = get_annotation_from_schema(param.annotation, schema)
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=BaseConfig,
class_validators=[],
schema=schema,
)
dependant.body_params.append(field)
def is_coroutine_callable(call: Callable = None):
if not call:
return False
if inspect.isfunction(call):
return asyncio.iscoroutinefunction(call)
if inspect.isclass(call):
return False
call = getattr(call, "__call__", None)
if not call:
return False
return asyncio.iscoroutinefunction(call)
async def solve_dependencies(
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
):
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
for sub_dependant in dependant.dependencies:
sub_values, sub_errors = await solve_dependencies(
request=request, dependant=sub_dependant, body=body
)
if sub_errors:
return {}, errors
if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
values[
sub_dependant.name
] = solved # type: ignore # Sub-dependants always have a name
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors = path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
dependant.body_params, body
)
values.update(body_values)
errors.extend(body_errors)
if dependant.request_param_name:
values[dependant.request_param_name] = request
return values, errors
def request_params_to_args(
required_params: List[Field], received_params: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
for field in required_params:
value = received_params.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field.schema.in_.value, field.alias)
)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
async def request_body_to_args(
required_params: List[Field], received_body: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
if required_params:
field = required_params[0]
embed = getattr(field.schema, "embed", None)
if len(required_params) == 1 and not embed:
received_body = {field.alias: received_body}
for field in required_params:
value = received_body.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(
MissingError(), loc=("body", field.alias), config=BaseConfig
)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def get_body_field(*, dependant: Dependant, name: str):
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
embed = getattr(first_param.schema, "embed", None)
if len(flat_dependant.body_params) == 1 and not embed:
return first_param
model_name = "Body_" + name
BodyModel = create_model(model_name)
for f in flat_dependant.body_params:
BodyModel.__fields__[f.name] = f
required = any(True for f in flat_dependant.body_params if f.required)
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
BodySchema = params.File
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
BodySchema = params.Form
else:
BodySchema = params.Body
field = Field(
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators=[],
alias="body",
schema=BodySchema(None),
)
return field

31
fastapi/pydantic_utils.py → fastapi/encoders.py

@ -1,33 +1,44 @@
from enum import Enum
from types import GeneratorType
from typing import Set
from pydantic import BaseModel
from enum import Enum
from pydantic.json import pydantic_encoder
def jsonable_encoder(
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
obj,
include: Set[str] = None,
exclude: Set[str] = set(),
by_alias: bool = False,
include_none=True,
):
if isinstance(obj, BaseModel):
return jsonable_encoder(
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
obj.dict(include=include, exclude=exclude, by_alias=by_alias),
include_none=include_none,
)
elif isinstance(obj, Enum):
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
return {
jsonable_encoder(
key, by_alias=by_alias, include_none=include_none,
): jsonable_encoder(
value, by_alias=by_alias, include_none=include_none,
)
for key, value in obj.items() if value is not None or include_none
key, by_alias=by_alias, include_none=include_none
): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
for key, value in obj.items()
if value is not None or include_none
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
include_none=include_none,
)
for item in obj
]
return pydantic_encoder(obj)

0
fastapi/openapi/__init__.py

2
fastapi/openapi/constants.py

@ -0,0 +1,2 @@
METHODS_WITH_BODY = set(("POST", "PUT"))
REF_PREFIX = "#/components/schemas/"

347
fastapi/openapi/models.py

@ -0,0 +1,347 @@
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Schema as PSchema
from pydantic.types import UrlStr
try:
import pydantic.types.EmailStr
from pydantic.types import EmailStr
except ImportError:
logging.warning(
"email-validator not installed, email fields will be treated as str"
)
class EmailStr(str):
pass
class Contact(BaseModel):
name: Optional[str] = None
url: Optional[UrlStr] = None
email: Optional[EmailStr] = None
class License(BaseModel):
name: str
url: Optional[UrlStr] = None
class Info(BaseModel):
title: str
description: Optional[str] = None
termsOfService: Optional[str] = None
contact: Optional[Contact] = None
license: Optional[License] = None
version: str
class ServerVariable(BaseModel):
enum: Optional[List[str]] = None
default: str
description: Optional[str] = None
class Server(BaseModel):
url: UrlStr
description: Optional[str] = None
variables: Optional[Dict[str, ServerVariable]] = None
class Reference(BaseModel):
ref: str = PSchema(..., alias="$ref")
class Discriminator(BaseModel):
propertyName: str
mapping: Optional[Dict[str, str]] = None
class XML(BaseModel):
name: Optional[str] = None
namespace: Optional[str] = None
prefix: Optional[str] = None
attribute: Optional[bool] = None
wrapped: Optional[bool] = None
class ExternalDocumentation(BaseModel):
description: Optional[str] = None
url: UrlStr
class SchemaBase(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref")
title: Optional[str] = None
multipleOf: Optional[float] = None
maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None
maxLength: Optional[int] = PSchema(None, gte=0)
minLength: Optional[int] = PSchema(None, gte=0)
pattern: Optional[str] = None
maxItems: Optional[int] = PSchema(None, gte=0)
minItems: Optional[int] = PSchema(None, gte=0)
uniqueItems: Optional[bool] = None
maxProperties: Optional[int] = PSchema(None, gte=0)
minProperties: Optional[int] = PSchema(None, gte=0)
required: Optional[List[str]] = None
enum: Optional[List[str]] = None
type: Optional[str] = None
allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None
not_: Optional[List[Any]] = PSchema(None, alias="not")
items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[bool, Any]] = None
description: Optional[str] = None
format: Optional[str] = None
default: Optional[Any] = None
nullable: Optional[bool] = None
discriminator: Optional[Discriminator] = None
readOnly: Optional[bool] = None
writeOnly: Optional[bool] = None
xml: Optional[XML] = None
externalDocs: Optional[ExternalDocumentation] = None
example: Optional[Any] = None
deprecated: Optional[bool] = None
class Schema(SchemaBase):
allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[bool, SchemaBase]] = None
class Example(BaseModel):
summary: Optional[str] = None
description: Optional[str] = None
value: Optional[Any] = None
externalValue: Optional[UrlStr] = None
class ParameterInType(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Encoding(BaseModel):
contentType: Optional[str] = None
# Workaround OpenAPI recursive reference, using Any
headers: Optional[Dict[str, Union[Any, Reference]]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
class MediaType(BaseModel):
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
class ParameterBase(BaseModel):
description: Optional[str] = None
required: Optional[bool] = None
deprecated: Optional[bool] = None
# Serialization rules for simple scenarios
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
content: Optional[Dict[str, MediaType]] = None
class Parameter(ParameterBase):
name: str
in_: ParameterInType = PSchema(..., alias="in")
class Header(ParameterBase):
pass
# Workaround OpenAPI recursive reference
class EncodingWithHeaders(Encoding):
headers: Optional[Dict[str, Union[Header, Reference]]] = None
class RequestBody(BaseModel):
description: Optional[str] = None
content: Dict[str, MediaType]
required: Optional[bool] = None
class Link(BaseModel):
operationRef: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[Dict[str, Union[Any, str]]] = None
requestBody: Optional[Union[Any, str]] = None
description: Optional[str] = None
server: Optional[Server] = None
class Response(BaseModel):
description: str
headers: Optional[Dict[str, Union[Header, Reference]]] = None
content: Optional[Dict[str, MediaType]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
class Responses(BaseModel):
default: Response
class Operation(BaseModel):
tags: Optional[List[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
operationId: Optional[str] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
requestBody: Optional[Union[RequestBody, Reference]] = None
responses: Union[Responses, Dict[Union[str], Response]]
# Workaround OpenAPI recursive reference
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
deprecated: Optional[bool] = None
security: Optional[List[Dict[str, List[str]]]] = None
servers: Optional[List[Server]] = None
class PathItem(BaseModel):
ref: Optional[str] = PSchema(None, alias="$ref")
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
put: Optional[Operation] = None
post: Optional[Operation] = None
delete: Optional[Operation] = None
options: Optional[Operation] = None
head: Optional[Operation] = None
patch: Optional[Operation] = None
trace: Optional[Operation] = None
servers: Optional[List[Server]] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
# Workaround OpenAPI recursive reference
class OperationWithCallbacks(BaseModel):
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class SecuritySchemeType(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
type_: SecuritySchemeType = PSchema(..., alias="type")
description: Optional[str] = None
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKey(SecurityBase):
type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = PSchema(..., alias="in")
name: str
class HTTPBase(SecurityBase):
type_ = PSchema(SecuritySchemeType.http, alias="type")
scheme: str
class HTTPBearer(HTTPBase):
scheme = "bearer"
bearerFormat: Optional[str] = None
class OAuthFlow(BaseModel):
refreshUrl: Optional[str] = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: Optional[OAuthFlowImplicit] = None
password: Optional[OAuthFlowPassword] = None
clientCredentials: Optional[OAuthFlowClientCredentials] = None
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
class OAuth2(SecurityBase):
type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
openIdConnectUrl: str
SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect]
class Components(BaseModel):
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
responses: Optional[Dict[str, Union[Response, Reference]]] = None
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
headers: Optional[Dict[str, Union[Header, Reference]]] = None
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
class OpenAPI(BaseModel):
openapi: str
info: Info
servers: Optional[List[Server]] = None
paths: Dict[str, PathItem]
components: Optional[Components] = None
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[Tag]] = None
externalDocs: Optional[ExternalDocumentation] = None

280
fastapi/openapi/utils.py

@ -0,0 +1,280 @@
from typing import Any, Dict, Sequence, Type
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from fastapi import routing
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
}
},
}
def get_openapi_params(dependant: Dependant):
flat_dependant = get_flat_dependant(dependant)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
return None
path = {}
security_schemes = {}
definitions = {}
for method in route.methods:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
if route.summary:
operation["summary"] = route.summary
if route.description:
operation["description"] = route.description
if route.operation_id:
operation["operationId"] = route.operation_id
else:
operation["operationId"] = route.name
if route.deprecated:
operation["deprecated"] = route.deprecated
parameters = []
flat_dependant = get_flat_dependant(route.dependant)
security_definitions = {}
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme,
exclude={"scheme_name"},
by_alias=True,
include_none=False,
)
security_name = (
getattr(
security_requirement.security_scheme, "scheme_name", None
)
or security_requirement.security_scheme.__class__.__name__
)
security_definitions[security_name] = security_definition
operation.setdefault("security", []).append(
{security_name: security_requirement.scopes}
)
if security_definitions:
security_schemes.update(
security_definitions
)
all_route_params = get_openapi_params(route.dependant)
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions[
"HTTPValidationError"
] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
if parameters:
operation["parameters"] = parameters
if method in METHODS_WITH_BODY:
body_field = route.body_field
if body_field:
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {
request_media_type: {"schema": body_schema}
}
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
response_schema = {"type": "string"}
if lenient_issubclass(route.response_wrapper, JSONResponse):
response_media_type = "application/json"
if route.response_field:
response_schema, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
else:
response_schema = {}
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
response_media_type = "text/html"
else:
response_media_type = "text/plain"
content = {response_media_type: {"schema": response_schema}}
operation["responses"] = {
response_code: {
"description": route.response_description,
"content": content,
}
}
if all_route_params or route.body_field:
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
}
},
}
path[method.lower()] = operation
return path, security_schemes, definitions
def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.0.2",
description: str = None,
routes: Sequence[BaseRoute]
):
info = {"title": title, "version": version}
if description:
info["description"] = description
output = {"openapi": openapi_version, "info": info}
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(security_schemes)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components.setdefault("schemas", {}).update(definitions)
if components:
output["components"] = components
output["paths"] = paths
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
def get_swagger_ui_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<title>
""" + title + """
</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '"""
+ openapi_url
+ """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
def get_redoc_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<title>
""" + title + """
</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {
margin: 0;
padding: 0;
}
</style>
</head>
<body>
<redoc spec-url='""" + openapi_url + """'></redoc>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
)

105
fastapi/params.py

@ -1,5 +1,6 @@
from typing import Sequence
from enum import Enum
from typing import Sequence, Any, Dict
from pydantic import Schema
@ -12,6 +13,7 @@ class ParamTypes(Enum):
class Param(Schema):
in_: ParamTypes
def __init__(
self,
default,
@ -27,7 +29,7 @@ class Param(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.deprecated = deprecated
super().__init__(
@ -64,7 +66,7 @@ class Path(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@ -103,7 +105,7 @@ class Query(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@ -141,7 +143,7 @@ class Header(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@ -179,7 +181,7 @@ class Cookie(Param):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.description = description
self.deprecated = deprecated
@ -200,11 +202,49 @@ class Cookie(Param):
class Body(Schema):
def __init__(
self,
default,
*,
embed=False,
media_type: str = "application/json",
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
):
self.embed = embed
self.media_type = media_type
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Form(Body):
def __init__(
self,
default,
*,
sub_key=False,
media_type: str = "application/x-www-form-urlencoded",
alias: str = None,
title: str = None,
description: str = None,
@ -215,11 +255,49 @@ class Body(Schema):
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
**extra: Dict[str, Any],
):
self.sub_key = sub_key
super().__init__(
default,
embed=sub_key,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class File(Form):
def __init__(
self,
default,
*,
sub_key=False,
media_type: str = "multipart/form-data",
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Dict[str, Any],
):
super().__init__(
default,
embed=sub_key,
media_type=media_type,
alias=alias,
title=title,
description=description,
@ -235,12 +313,11 @@ class Body(Schema):
class Depends:
def __init__(self, dependency = None):
def __init__(self, dependency=None):
self.dependency = dependency
class Security:
def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
self.security_scheme = security_scheme
self.scopes = scopes
class Security(Depends):
def __init__(self, dependency=None, scopes: Sequence[str] = None):
self.scopes = scopes or []
super().__init__(dependency=dependency)

526
fastapi/routing.py

@ -1,341 +1,66 @@
import asyncio
import inspect
import re
import typing
from copy import deepcopy
from typing import Callable, List, Type
from starlette import routing
from starlette.routing import get_name, request_response
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.formparsers import UploadFile
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import get_name, request_response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic import BaseConfig, BaseModel, create_model, Schema
from fastapi import params
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.errors import MissingError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from .pydantic_utils import jsonable_encoder
from fastapi import params
from fastapi.security.base import SecurityBase
param_supported_types = (str, int, float, bool)
class Dependant:
def __init__(
self,
*,
path_params: typing.List[Field] = None,
query_params: typing.List[Field] = None,
header_params: typing.List[Field] = None,
cookie_params: typing.List[Field] = None,
body_params: typing.List[Field] = None,
dependencies: typing.List["Dependant"] = None,
security_schemes: typing.List[Field] = None,
name: str = None,
call: typing.Callable = None,
request_param_name: str = None,
) -> None:
self.path_params: typing.List[Field] = path_params or []
self.query_params: typing.List[Field] = query_params or []
self.header_params: typing.List[Field] = header_params or []
self.cookie_params: typing.List[Field] = cookie_params or []
self.body_params: typing.List[Field] = body_params or []
self.dependencies: typing.List[Dependant] = dependencies or []
self.security_schemes: typing.List[Field] = security_schemes or []
self.request_param_name = request_param_name
self.name = name
self.call: typing.Callable = call
def request_params_to_args(
required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
values = {}
errors = []
for field in required_params:
value = received_params.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field.schema.in_.value, field.alias)
)
def serialize_response(*, field: Field = None, response):
if field:
errors = []
value, errors_ = field.validate(response, {}, loc=("response",))
if isinstance(errors_, ErrorWrapper):
errors_: ErrorWrapper
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def request_body_to_args(
required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
values = {}
errors = []
if required_params:
field = required_params[0]
sub_key = getattr(field.schema, "sub_key", None)
if len(required_params) == 1 and not sub_key:
received_body = {field.alias: received_body}
for field in required_params:
value = received_body.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(
MissingError(), loc=("body", field.alias), config=BaseConfig
)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
if isinstance(errors_, ErrorWrapper):
errors_: ErrorWrapper
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def add_param_to_fields(
*,
param: inspect.Parameter,
dependant: Dependant,
default_schema=params.Param,
force_type: params.ParamTypes = None,
):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, params.Param):
schema = default_value
default_value = schema.default
if schema.in_ is None:
schema.in_ = default_schema.in_
if force_type:
schema.in_ = force_type
else:
schema = default_schema(default_value)
required = default_value == Required
annotation = typing.Any
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema)
Config = BaseConfig
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=Config,
class_validators=[],
schema=schema,
)
if schema.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif schema.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif schema.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
schema.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, Schema):
schema = default_value
default_value = schema.default
if errors:
raise ValidationError(errors)
return jsonable_encoder(value)
else:
schema = Schema(default_value)
required = default_value == Required
annotation = get_annotation_from_schema(param.annotation, schema)
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=BaseConfig,
class_validators=[],
schema=schema,
)
dependant.body_params.append(field)
return jsonable_encoder(response)
def get_sub_dependant(
*, param: inspect.Parameter, path: str
def get_app(
dependant: Dependant,
body_field: Field = None,
response_code: str = 200,
response_wrapper: Type[Response] = JSONResponse,
response_field: Type[Field] = None,
):
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
assert callable(dependency)
sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
if isinstance(dependency, SecurityBase):
sub_dependant.security_schemes.append(dependency)
return sub_dependant
def get_flat_dependant(dependant: Dependant):
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_schemes.copy(),
)
for sub_dependant in dependant.dependencies:
if sub_dependant is dependant:
raise ValueError("recursion", dependant.dependencies)
flat_sub = get_flat_dependant(sub_dependant)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_schemes.extend(flat_sub.security_schemes)
return flat_dependant
def get_path_param_names(path: str):
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
def get_dependant(*, path: str, call: typing.Callable, name: str = None):
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name)
for param_name in signature_params:
param = signature_params[param_name]
if isinstance(param.default, params.Depends):
sub_dependant = get_sub_dependant(param=param, path=path)
dependant.dependencies.append(sub_dependant)
for param_name in signature_params:
param = signature_params[param_name]
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name]
add_param_to_fields(
param=param,
dependant=dependant,
default_schema=params.Path,
force_type=params.ParamTypes.path,
)
elif (param.default == param.empty or param.default is None) and (
param.annotation == param.empty
or lenient_issubclass(param.annotation, param_supported_types)
):
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif isinstance(param.default, params.Param):
if param.annotation != param.empty:
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
elif not isinstance(param.default, params.Depends):
add_param_to_body_fields(param=param, dependant=dependant)
return dependant
def is_coroutine_callable(call: typing.Callable):
if inspect.isfunction(call):
return asyncio.iscoroutinefunction(call)
elif inspect.isclass(call):
return False
else:
call = getattr(call, "__call__", None)
if not call:
return False
else:
return asyncio.iscoroutinefunction(call)
async def solve_dependencies(*, request: Request, dependant: Dependant):
values = {}
errors = []
for sub_dependant in dependant.dependencies:
sub_values, sub_errors = await solve_dependencies(
request=request, dependant=sub_dependant
)
if sub_errors:
return {}, errors
if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
values[sub_dependant.name] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors = path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
body = await request.json()
body_values, body_errors = request_body_to_args(dependant.body_params, body)
values.update(body_values)
errors.extend(body_errors)
if dependant.request_param_name:
values[dependant.request_param_name] = request
return values, errors
def get_app(dependant: Dependant):
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
async def app(request: Request) -> Response:
values, errors = await solve_dependencies(request=request, dependant=dependant)
body = None
if body_field:
if isinstance(body_field.schema, params.Form):
raw_body = await request.form()
body = {}
for field, value in raw_body.items():
if isinstance(value, UploadFile):
body[field] = await value.read()
else:
body[field] = value
else:
body = await request.json()
values, errors = await solve_dependencies(
request=request, dependant=dependant, body=body
)
if errors:
errors_out = ValidationError(errors)
raise HTTPException(
@ -348,36 +73,56 @@ def get_app(dependant: Dependant):
raw_response = await run_in_threadpool(dependant.call, **values)
if isinstance(raw_response, Response):
return raw_response
else:
return JSONResponse(content=jsonable_encoder(raw_response))
return app
if isinstance(raw_response, BaseModel):
return response_wrapper(
content=jsonable_encoder(raw_response), status_code=response_code
)
errors = []
try:
return response_wrapper(
content=serialize_response(
field=response_field, response=raw_response
),
status_code=response_code,
)
except Exception as e:
errors.append(e)
try:
response = dict(raw_response)
return response_wrapper(
content=serialize_response(field=response_field, response=response),
status_code=response_code,
)
except Exception as e:
errors.append(e)
try:
response = vars(raw_response)
return response_wrapper(
content=serialize_response(field=response_field, response=response),
status_code=response_code,
)
except Exception as e:
errors.append(e)
raise ValueError(errors)
def get_openapi_params(dependant: Dependant):
flat_dependant = get_flat_dependant(dependant)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
return app
class APIRoute(routing.Route):
def __init__(
self,
path: str,
endpoint: typing.Callable,
endpoint: Callable,
*,
methods: typing.List[str] = None,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -392,12 +137,12 @@ class APIRoute(routing.Route):
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
self.tags = tags
self.tags = tags or []
self.summary = summary
self.description = description
self.description = description or self.endpoint.__doc__
self.operation_id = operation_id
self.deprecated = deprecated
self.request_body: typing.Union[BaseModel, Field, None] = None
self.body_field: Field = None
self.response_description = response_description
self.response_code = response_code
self.response_wrapper = response_wrapper
@ -430,53 +175,32 @@ class APIRoute(routing.Route):
), f"An endpoint must be a function or method"
self.dependant = get_dependant(path=path, call=self.endpoint)
# flat_dependant = get_flat_dependant(self.dependant)
# path_param_names = get_path_param_names(path)
# for path_param in path_param_names:
# assert path_param in {
# f.alias for f in flat_dependant.path_params
# }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
if self.dependant.body_params:
first_param = self.dependant.body_params[0]
sub_key = getattr(first_param.schema, "sub_key", None)
if len(self.dependant.body_params) == 1 and not sub_key:
self.request_body = first_param
else:
model_name = "Body_" + self.name
BodyModel = create_model(model_name)
for f in self.dependant.body_params:
BodyModel.__fields__[f.name] = f
required = any(True for f in self.dependant.body_params if f.required)
field = Field(
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators=[],
alias="body",
schema=Schema(None),
)
self.request_body = field
self.app = request_response(get_app(dependant=self.dependant))
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
self.app = request_response(
get_app(
dependant=self.dependant,
body_field=self.body_field,
response_code=self.response_code,
response_wrapper=self.response_wrapper,
response_field=self.response_field,
)
)
class APIRouter(routing.Router):
def add_api_route(
self,
path: str,
endpoint: typing.Callable,
methods: typing.List[str] = None,
endpoint: Callable,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -487,7 +211,7 @@ class APIRouter(routing.Router):
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -502,27 +226,27 @@ class APIRouter(routing.Router):
def api_route(
self,
path: str,
methods: typing.List[str] = None,
methods: List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -541,12 +265,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -556,7 +280,7 @@ class APIRouter(routing.Router):
methods=["GET"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -572,12 +296,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -587,7 +311,7 @@ class APIRouter(routing.Router):
methods=["PUT"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -603,12 +327,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -618,7 +342,7 @@ class APIRouter(routing.Router):
methods=["POST"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -634,12 +358,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -649,7 +373,7 @@ class APIRouter(routing.Router):
methods=["DELETE"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -665,12 +389,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -680,7 +404,7 @@ class APIRouter(routing.Router):
methods=["OPTIONS"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -696,12 +420,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -711,7 +435,7 @@ class APIRouter(routing.Router):
methods=["HEAD"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -727,12 +451,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -742,7 +466,7 @@ class APIRouter(routing.Router):
methods=["PATCH"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,
@ -758,12 +482,12 @@ class APIRouter(routing.Router):
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
tags: List[str] = None,
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_type: Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
@ -773,7 +497,7 @@ class APIRouter(routing.Router):
methods=["TRACE"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
tags=tags or [],
summary=summary,
description=description,
operation_id=operation_id,

9
fastapi/security/api_key.py

@ -1,11 +1,10 @@
from starlette.requests import Request
from enum import Enum
from pydantic import Schema
from enum import Enum
from .base import SecurityBase, Types
__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
from starlette.requests import Request
from .base import SecurityBase, Types
class APIKeyIn(Enum):
query = "query"
@ -21,7 +20,7 @@ class APIKeyBase(SecurityBase):
class APIKeyQuery(APIKeyBase):
in_ = Schema(APIKeyIn.query, alias="in")
async def __call__(self, requests: Request):
return requests.query_params.get(self.name)

3
fastapi/security/base.py

@ -1,7 +1,6 @@
from enum import Enum
from pydantic import BaseModel, Schema
__all__ = ["Types", "SecurityBase"]
from pydantic import BaseModel, Schema
class Types(Enum):

5
fastapi/security/http.py

@ -1,9 +1,8 @@
from pydantic import Schema
from starlette.requests import Request
from pydantic import Schema
from .base import SecurityBase, Types
__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
from .base import SecurityBase, Types
class HTTPBase(SecurityBase):

4
fastapi/security/oauth2.py

@ -3,10 +3,8 @@ from typing import Dict
from pydantic import BaseModel, Schema
from starlette.requests import Request
from .base import SecurityBase, Types
# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
from .base import SecurityBase, Types
class OAuthFlow(BaseModel):
refreshUrl: str = None

1
fastapi/security/open_id_connect_url.py

@ -2,6 +2,7 @@ from starlette.requests import Request
from .base import SecurityBase, Types
class OpenIdConnect(SecurityBase):
type_ = Types.openIdConnect
openIdConnectUrl: str

46
fastapi/utils.py

@ -0,0 +1,46 @@
import re
from typing import Dict, Sequence, Set, Type
from starlette.routing import BaseRoute
from fastapi import routing
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseModel
from pydantic.fields import Field
from pydantic.schema import get_flat_models_from_fields, model_process_schema
def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
body_fields_from_routes = []
responses_from_routes = []
for route in routes:
if route.include_in_schema and isinstance(route, routing.APIRoute):
if route.body_field:
assert isinstance(
route.body_field, Field
), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
flat_models = get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes
)
return flat_models
def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
):
definitions: Dict[str, Dict] = {}
for model in flat_models:
m_schema, m_definitions = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
return definitions
def get_path_param_names(path: str):
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
Loading…
Cancel
Save