diff --git a/fastapi/applications.py b/fastapi/applications.py
index 3f5a45b73..bb21076df 100644
--- a/fastapi/applications.py
+++ b/fastapi/applications.py
@@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
from fastapi import routing
-from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
+from fastapi.openapi.utils import get_openapi
+from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
async def http_exception(request, exc: HTTPException):
@@ -154,8 +155,10 @@ class FastAPI(Starlette):
response_wrapper=response_wrapper,
)
return func
-
return decorator
+
+ def include_router(self, router: "APIRouter", *, prefix=""):
+ self.router.include_router(router, prefix=prefix)
def get(
self,
diff --git a/fastapi/openapi/docs.py b/fastapi/openapi/docs.py
new file mode 100644
index 000000000..c8a1d6178
--- /dev/null
+++ b/fastapi/openapi/docs.py
@@ -0,0 +1,78 @@
+from starlette.responses import HTMLResponse
+
+def get_swagger_ui_html(*, openapi_url: str, title: str):
+ return HTMLResponse(
+ """
+
+
+
+
+
+ """
+ + title
+ + """
+
+
+
+
+
+
+
+
+
+
+ """,
+ media_type="text/html",
+ )
+
+
+def get_redoc_html(*, openapi_url: str, title: str):
+ return HTMLResponse(
+ """
+
+
+
+
+ """
+ + title
+ + """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ media_type="text/html",
+ )
diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py
index 3cf800740..7dbeece73 100644
--- a/fastapi/openapi/utils.py
+++ b/fastapi/openapi/utils.py
@@ -1,4 +1,8 @@
-from typing import Any, Dict, Sequence, Type
+from typing import Any, Dict, Sequence, Type, List
+
+from pydantic.fields import Field
+from pydantic.schema import field_schema, get_model_name_map
+from pydantic.utils import lenient_issubclass
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute
@@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
-from pydantic.fields import Field
-from pydantic.schema import field_schema, get_model_name_map
-from pydantic.utils import lenient_issubclass
+
validation_error_definition = {
"title": "ValidationError",
@@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant):
+ flat_dependant.cookie_params
)
+
+def get_openapi_security_definitions(flat_dependant: Dependant):
+ security_definitions = {}
+ operation_security = []
+ for security_requirement in flat_dependant.security_requirements:
+ security_definition = jsonable_encoder(
+ security_requirement.security_scheme.model,
+ by_alias=True,
+ include_none=False,
+ )
+ security_name = (
+ security_requirement.security_scheme.scheme_name
+
+ )
+ security_definitions[security_name] = security_definition
+ operation_security.append({security_name: security_requirement.scopes})
+ return security_definitions, operation_security
+
+
+def get_openapi_operation_parameters(all_route_params: List[Field]):
+ definitions: Dict[str, Dict] = {}
+ parameters = []
+ for param in all_route_params:
+ if "ValidationError" not in definitions:
+ definitions["ValidationError"] = validation_error_definition
+ definitions["HTTPValidationError"] = validation_error_response_definition
+ parameter = {
+ "name": param.alias,
+ "in": param.schema.in_.value,
+ "required": param.required,
+ "schema": field_schema(param, model_name_map={})[0],
+ }
+ if param.schema.description:
+ parameter["description"] = param.schema.description
+ if param.schema.deprecated:
+ parameter["deprecated"] = param.schema.deprecated
+ parameters.append(parameter)
+ return definitions, parameters
+
+
+def get_openapi_operation_request_body(
+ *, body_field: Field, model_name_map: Dict[Type, str]
+):
+ if not body_field:
+ return None
+ assert isinstance(body_field, Field)
+ body_schema, _ = field_schema(
+ body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+ )
+ if isinstance(body_field.schema, Body):
+ request_media_type = body_field.schema.media_type
+ else:
+ # Includes not declared media types (Schema)
+ request_media_type = "application/json"
+ required = body_field.required
+ request_body_oai = {}
+ if required:
+ request_body_oai["required"] = required
+ request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
+ return request_body_oai
+
+
+def generate_operation_id(*, route: routing.APIRoute, method: str):
+ if route.operation_id:
+ return route.operation_id
+ path: str = route.path
+ operation_id = route.name + path
+ operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
+ operation_id = operation_id + "_" + method.lower()
+ return operation_id
+
+
+def generate_operation_summary(*, route: routing.APIRoute, method: str):
+ if route.summary:
+ return route.summary
+ return method.title() + " " + route.name.replace("_", " ").title()
+
+def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
+ operation: Dict[str, Any] = {}
+ if route.tags:
+ operation["tags"] = route.tags
+ operation["summary"] = generate_operation_summary(route=route, method=method)
+ if route.description:
+ operation["description"] = route.description
+ operation["operationId"] = generate_operation_id(route=route, method=method)
+ if route.deprecated:
+ operation["deprecated"] = route.deprecated
+ return operation
+
+
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
return None
path = {}
- security_schemes = {}
- definitions = {}
+ security_schemes: Dict[str, Any] = {}
+ definitions: Dict[str, Any] = {}
for method in route.methods:
- operation: Dict[str, Any] = {}
- if route.tags:
- operation["tags"] = route.tags
- if route.summary:
- operation["summary"] = route.summary
- if route.description:
- operation["description"] = route.description
- if route.operation_id:
- operation["operationId"] = route.operation_id
- else:
- operation["operationId"] = route.name
- if route.deprecated:
- operation["deprecated"] = route.deprecated
- parameters = []
+ operation = get_openapi_operation_metadata(route=route, method=method)
+ parameters: List[Dict] = []
flat_dependant = get_flat_dependant(route.dependant)
- security_definitions = {}
- for security_requirement in flat_dependant.security_requirements:
- security_definition = jsonable_encoder(
- security_requirement.security_scheme,
- exclude={"scheme_name"},
- by_alias=True,
- include_none=False,
- )
- security_name = (
- getattr(
- security_requirement.security_scheme, "scheme_name", None
- )
- or security_requirement.security_scheme.__class__.__name__
- )
- security_definitions[security_name] = security_definition
- operation.setdefault("security", []).append(
- {security_name: security_requirement.scopes}
- )
+ security_definitions, operation_security = get_openapi_security_definitions(
+ flat_dependant=flat_dependant
+ )
+ if operation_security:
+ operation.setdefault("security", []).extend(operation_security)
if security_definitions:
- security_schemes.update(
- security_definitions
- )
+ security_schemes.update(security_definitions)
all_route_params = get_openapi_params(route.dependant)
- for param in all_route_params:
- if "ValidationError" not in definitions:
- definitions["ValidationError"] = validation_error_definition
- definitions[
- "HTTPValidationError"
- ] = validation_error_response_definition
- parameter = {
- "name": param.alias,
- "in": param.schema.in_.value,
- "required": param.required,
- "schema": field_schema(param, model_name_map={})[0],
- }
- if param.schema.description:
- parameter["description"] = param.schema.description
- if param.schema.deprecated:
- parameter["deprecated"] = param.schema.deprecated
- parameters.append(parameter)
+ validation_definitions, operation_parameters = get_openapi_operation_parameters(
+ all_route_params=all_route_params
+ )
+ definitions.update(validation_definitions)
+ parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = parameters
if method in METHODS_WITH_BODY:
- body_field = route.body_field
- if body_field:
- assert isinstance(body_field, Field)
- body_schema, _ = field_schema(
- body_field,
- model_name_map=model_name_map,
- ref_prefix=REF_PREFIX,
- )
- if isinstance(body_field.schema, Body):
- request_media_type = body_field.schema.media_type
- else:
- # Includes not declared media types (Schema)
- request_media_type = "application/json"
- required = body_field.required
- request_body_oai = {}
- if required:
- request_body_oai["required"] = required
- request_body_oai["content"] = {
- request_media_type: {"schema": body_schema}
- }
+ request_body_oai = get_openapi_operation_request_body(
+ body_field=route.body_field, model_name_map=model_name_map
+ )
+ if request_body_oai:
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
response_schema = {"type": "string"}
@@ -206,75 +243,3 @@ def get_openapi(
output["components"] = components
output["paths"] = paths
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
-
-
-def get_swagger_ui_html(*, openapi_url: str, title: str):
- return HTMLResponse(
- """
-
-
-
-
-
- """ + title + """
-
-
-
-
-
-
-
-
-
-
- """,
- media_type="text/html",
- )
-
-
-def get_redoc_html(*, openapi_url: str, title: str):
- return HTMLResponse(
- """
-
-
-
-
- """ + title + """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- """,
- media_type="text/html",
- )
diff --git a/fastapi/routing.py b/fastapi/routing.py
index 6f7d592e5..22a62a53a 100644
--- a/fastapi/routing.py
+++ b/fastapi/routing.py
@@ -2,6 +2,11 @@ import asyncio
import inspect
from typing import Callable, List, Type
+from pydantic import BaseConfig, BaseModel, Schema
+from pydantic.error_wrappers import ErrorWrapper, ValidationError
+from pydantic.fields import Field
+from pydantic.utils import lenient_issubclass
+
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
@@ -15,10 +20,6 @@ from fastapi import params
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder
-from pydantic import BaseConfig, BaseModel, Schema
-from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.fields import Field
-from pydantic.utils import lenient_issubclass
def serialize_response(*, field: Field = None, response):
@@ -44,11 +45,12 @@ def get_app(
response_field: Type[Field] = None,
):
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
+ is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response:
body = None
if body_field:
- if isinstance(body_field.schema, params.Form):
+ if is_body_form:
raw_body = await request.form()
body = {}
for field, value in raw_body.items():
@@ -127,12 +129,7 @@ class APIRoute(routing.Route):
response_code=200,
response_wrapper=JSONResponse,
) -> None:
- # TODO define how to read and provide security params, and how to have them globally too
- # TODO implement dependencies and injection
- # TODO refactor code structure
- # TODO create testing
- # TODO testing coverage
- assert path.startswith("/"), "Routed paths must always start '/'"
+ assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
@@ -260,6 +257,39 @@ class APIRouter(routing.Router):
return decorator
+ def include_router(self, router: "APIRouter", *, prefix=""):
+ if prefix:
+ assert prefix.startswith("/"), "A path prefix must start with '/'"
+ assert not prefix.endswith(
+ "/"
+ ), "A path prefix must not end with '/', as the routes will start with '/'"
+ for route in router.routes:
+ if isinstance(route, APIRoute):
+ self.add_api_route(
+ prefix + route.path,
+ route.endpoint,
+ methods=route.methods,
+ name=route.name,
+ include_in_schema=route.include_in_schema,
+ tags=route.tags,
+ summary=route.summary,
+ description=route.description,
+ operation_id=route.operation_id,
+ deprecated=route.deprecated,
+ response_type=route.response_type,
+ response_description=route.response_description,
+ response_code=route.response_code,
+ response_wrapper=route.response_wrapper,
+ )
+ elif isinstance(route, routing.Route):
+ self.add_route(
+ prefix + route.path,
+ route.endpoint,
+ methods=route.methods,
+ name=route.name,
+ include_in_schema=route.include_in_schema,
+ )
+
def get(
self,
path: str,
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
index c0354fea7..047898dfe 100644
--- a/fastapi/security/api_key.py
+++ b/fastapi/security/api_key.py
@@ -1,39 +1,34 @@
-from enum import Enum
-
-from pydantic import Schema
-
from starlette.requests import Request
-from .base import SecurityBase, Types
-
-class APIKeyIn(Enum):
- query = "query"
- header = "header"
- cookie = "cookie"
-
+from .base import SecurityBase
+from fastapi.openapi.models import APIKeyIn, APIKey
class APIKeyBase(SecurityBase):
- type_ = Schema(Types.apiKey, alias="type")
- in_: str = Schema(..., alias="in")
- name: str
-
+ pass
class APIKeyQuery(APIKeyBase):
- in_ = Schema(APIKeyIn.query, alias="in")
+
+ def __init__(self, *, name: str, scheme_name: str = None):
+ self.model = APIKey(in_=APIKeyIn.query, name=name)
+ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
- return requests.query_params.get(self.name)
+ return requests.query_params.get(self.model.name)
class APIKeyHeader(APIKeyBase):
- in_ = Schema(APIKeyIn.header, alias="in")
+ def __init__(self, *, name: str, scheme_name: str = None):
+ self.model = APIKey(in_=APIKeyIn.header, name=name)
+ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
- return requests.headers.get(self.name)
+ return requests.headers.get(self.model.name)
class APIKeyCookie(APIKeyBase):
- in_ = Schema(APIKeyIn.cookie, alias="in")
+ def __init__(self, *, name: str, scheme_name: str = None):
+ self.model = APIKey(in_=APIKeyIn.cookie, name=name)
+ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
- return requests.cookies.get(self.name)
+ return requests.cookies.get(self.model.name)
diff --git a/fastapi/security/base.py b/fastapi/security/base.py
index 9ba430df9..8589da0be 100644
--- a/fastapi/security/base.py
+++ b/fastapi/security/base.py
@@ -1,16 +1,6 @@
-from enum import Enum
+from pydantic import BaseModel
-from pydantic import BaseModel, Schema
+from fastapi.openapi.models import SecurityBase as SecurityBaseModel
-
-class Types(Enum):
- apiKey = "apiKey"
- http = "http"
- oauth2 = "oauth2"
- openIdConnect = "openIdConnect"
-
-
-class SecurityBase(BaseModel):
- scheme_name: str = None
- type_: Types = Schema(..., alias="type")
- description: str = None
+class SecurityBase:
+ pass
diff --git a/fastapi/security/http.py b/fastapi/security/http.py
index 7a8bcfe48..cee42b868 100644
--- a/fastapi/security/http.py
+++ b/fastapi/security/http.py
@@ -1,26 +1,40 @@
-from pydantic import Schema
-
from starlette.requests import Request
-from .base import SecurityBase, Types
+from .base import SecurityBase
+from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
class HTTPBase(SecurityBase):
- type_ = Schema(Types.http, alias="type")
- scheme: str
+ def __init__(self, *, scheme: str, scheme_name: str = None):
+ self.model = HTTPBaseModel(scheme=scheme)
+ self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPBasic(HTTPBase):
- scheme = "basic"
+ def __init__(self, *, scheme_name: str = None):
+ self.model = HTTPBaseModel(scheme="basic")
+ self.scheme_name = scheme_name or self.__class__.__name__
+
+ async def __call__(self, request: Request):
+ return request.headers.get("Authorization")
class HTTPBearer(HTTPBase):
- scheme = "bearer"
- bearerFormat: str = None
+ def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
+ self.model = HTTPBearerModel(bearerFormat=bearerFormat)
+ self.scheme_name = scheme_name or self.__class__.__name__
+
+ async def __call__(self, request: Request):
+ return request.headers.get("Authorization")
class HTTPDigest(HTTPBase):
- scheme = "digest"
+ def __init__(self, *, scheme_name: str = None):
+ self.model = HTTPBaseModel(scheme="digest")
+ self.scheme_name = scheme_name or self.__class__.__name__
+
+ async def __call__(self, request: Request):
+ return request.headers.get("Authorization")
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
index 4febdafc2..65517e962 100644
--- a/fastapi/security/oauth2.py
+++ b/fastapi/security/oauth2.py
@@ -1,43 +1,13 @@
-from typing import Dict
-
-from pydantic import BaseModel, Schema
-
from starlette.requests import Request
-from .base import SecurityBase, Types
-
-class OAuthFlow(BaseModel):
- refreshUrl: str = None
- scopes: Dict[str, str] = {}
-
-
-class OAuthFlowImplicit(OAuthFlow):
- authorizationUrl: str
-
-
-class OAuthFlowPassword(OAuthFlow):
- tokenUrl: str
-
-
-class OAuthFlowClientCredentials(OAuthFlow):
- tokenUrl: str
-
-
-class OAuthFlowAuthorizationCode(OAuthFlow):
- authorizationUrl: str
- tokenUrl: str
-
-
-class OAuthFlows(BaseModel):
- implicit: OAuthFlowImplicit = None
- password: OAuthFlowPassword = None
- clientCredentials: OAuthFlowClientCredentials = None
- authorizationCode: OAuthFlowAuthorizationCode = None
+from .base import SecurityBase
+from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
class OAuth2(SecurityBase):
- type_ = Schema(Types.oauth2, alias="type")
- flows: OAuthFlows
-
+ def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
+ self.model = OAuth2Model(flows=flows)
+ self.scheme_name = scheme_name or self.__class__.__name__
+
async def __call__(self, request: Request):
return request.headers.get("Authorization")
diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py
index c84c56de8..49c5aae2d 100644
--- a/fastapi/security/open_id_connect_url.py
+++ b/fastapi/security/open_id_connect_url.py
@@ -1,11 +1,13 @@
from starlette.requests import Request
-from .base import SecurityBase, Types
+from .base import SecurityBase
+from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
class OpenIdConnect(SecurityBase):
- type_ = Types.openIdConnect
- openIdConnectUrl: str
-
+ def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
+ self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
+ self.scheme_name = scheme_name or self.__class__.__name__
+
async def __call__(self, request: Request):
return request.headers.get("Authorization")