diff --git a/fastapi/applications.py b/fastapi/applications.py index 3f5a45b73..bb21076df 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -8,7 +8,8 @@ from starlette.responses import JSONResponse from fastapi import routing -from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html +from fastapi.openapi.utils import get_openapi +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html async def http_exception(request, exc: HTTPException): @@ -154,8 +155,10 @@ class FastAPI(Starlette): response_wrapper=response_wrapper, ) return func - return decorator + + def include_router(self, router: "APIRouter", *, prefix=""): + self.router.include_router(router, prefix=prefix) def get( self, diff --git a/fastapi/openapi/docs.py b/fastapi/openapi/docs.py new file mode 100644 index 000000000..c8a1d6178 --- /dev/null +++ b/fastapi/openapi/docs.py @@ -0,0 +1,78 @@ +from starlette.responses import HTMLResponse + +def get_swagger_ui_html(*, openapi_url: str, title: str): + return HTMLResponse( + """ + + + + + + """ + + title + + """ + + + +
+
+ + + + + + """, + media_type="text/html", + ) + + +def get_redoc_html(*, openapi_url: str, title: str): + return HTMLResponse( + """ + + + + + """ + + title + + """ + + + + + + + + + + + + + + + """, + media_type="text/html", + ) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 3cf800740..7dbeece73 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -1,4 +1,8 @@ -from typing import Any, Dict, Sequence, Type +from typing import Any, Dict, Sequence, Type, List + +from pydantic.fields import Field +from pydantic.schema import field_schema, get_model_name_map +from pydantic.utils import lenient_issubclass from starlette.responses import HTMLResponse, JSONResponse from starlette.routing import BaseRoute @@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY from fastapi.openapi.models import OpenAPI from fastapi.params import Body from fastapi.utils import get_flat_models_from_routes, get_model_definitions -from pydantic.fields import Field -from pydantic.schema import field_schema, get_model_name_map -from pydantic.utils import lenient_issubclass + validation_error_definition = { "title": "ValidationError", @@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant): + flat_dependant.cookie_params ) + +def get_openapi_security_definitions(flat_dependant: Dependant): + security_definitions = {} + operation_security = [] + for security_requirement in flat_dependant.security_requirements: + security_definition = jsonable_encoder( + security_requirement.security_scheme.model, + by_alias=True, + include_none=False, + ) + security_name = ( + security_requirement.security_scheme.scheme_name + + ) + security_definitions[security_name] = security_definition + operation_security.append({security_name: security_requirement.scopes}) + return security_definitions, operation_security + + +def get_openapi_operation_parameters(all_route_params: List[Field]): + definitions: Dict[str, Dict] = {} + parameters = [] + for param in all_route_params: + if "ValidationError" not in definitions: + definitions["ValidationError"] = validation_error_definition + definitions["HTTPValidationError"] = validation_error_response_definition + parameter = { + "name": param.alias, + "in": param.schema.in_.value, + "required": param.required, + "schema": field_schema(param, model_name_map={})[0], + } + if param.schema.description: + parameter["description"] = param.schema.description + if param.schema.deprecated: + parameter["deprecated"] = param.schema.deprecated + parameters.append(parameter) + return definitions, parameters + + +def get_openapi_operation_request_body( + *, body_field: Field, model_name_map: Dict[Type, str] +): + if not body_field: + return None + assert isinstance(body_field, Field) + body_schema, _ = field_schema( + body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX + ) + if isinstance(body_field.schema, Body): + request_media_type = body_field.schema.media_type + else: + # Includes not declared media types (Schema) + request_media_type = "application/json" + required = body_field.required + request_body_oai = {} + if required: + request_body_oai["required"] = required + request_body_oai["content"] = {request_media_type: {"schema": body_schema}} + return request_body_oai + + +def generate_operation_id(*, route: routing.APIRoute, method: str): + if route.operation_id: + return route.operation_id + path: str = route.path + operation_id = route.name + path + operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_") + operation_id = operation_id + "_" + method.lower() + return operation_id + + +def generate_operation_summary(*, route: routing.APIRoute, method: str): + if route.summary: + return route.summary + return method.title() + " " + route.name.replace("_", " ").title() + +def get_openapi_operation_metadata(*, route: BaseRoute, method: str): + operation: Dict[str, Any] = {} + if route.tags: + operation["tags"] = route.tags + operation["summary"] = generate_operation_summary(route=route, method=method) + if route.description: + operation["description"] = route.description + operation["operationId"] = generate_operation_id(route=route, method=method) + if route.deprecated: + operation["deprecated"] = route.deprecated + return operation + + def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): if not (route.include_in_schema and isinstance(route, routing.APIRoute)): return None path = {} - security_schemes = {} - definitions = {} + security_schemes: Dict[str, Any] = {} + definitions: Dict[str, Any] = {} for method in route.methods: - operation: Dict[str, Any] = {} - if route.tags: - operation["tags"] = route.tags - if route.summary: - operation["summary"] = route.summary - if route.description: - operation["description"] = route.description - if route.operation_id: - operation["operationId"] = route.operation_id - else: - operation["operationId"] = route.name - if route.deprecated: - operation["deprecated"] = route.deprecated - parameters = [] + operation = get_openapi_operation_metadata(route=route, method=method) + parameters: List[Dict] = [] flat_dependant = get_flat_dependant(route.dependant) - security_definitions = {} - for security_requirement in flat_dependant.security_requirements: - security_definition = jsonable_encoder( - security_requirement.security_scheme, - exclude={"scheme_name"}, - by_alias=True, - include_none=False, - ) - security_name = ( - getattr( - security_requirement.security_scheme, "scheme_name", None - ) - or security_requirement.security_scheme.__class__.__name__ - ) - security_definitions[security_name] = security_definition - operation.setdefault("security", []).append( - {security_name: security_requirement.scopes} - ) + security_definitions, operation_security = get_openapi_security_definitions( + flat_dependant=flat_dependant + ) + if operation_security: + operation.setdefault("security", []).extend(operation_security) if security_definitions: - security_schemes.update( - security_definitions - ) + security_schemes.update(security_definitions) all_route_params = get_openapi_params(route.dependant) - for param in all_route_params: - if "ValidationError" not in definitions: - definitions["ValidationError"] = validation_error_definition - definitions[ - "HTTPValidationError" - ] = validation_error_response_definition - parameter = { - "name": param.alias, - "in": param.schema.in_.value, - "required": param.required, - "schema": field_schema(param, model_name_map={})[0], - } - if param.schema.description: - parameter["description"] = param.schema.description - if param.schema.deprecated: - parameter["deprecated"] = param.schema.deprecated - parameters.append(parameter) + validation_definitions, operation_parameters = get_openapi_operation_parameters( + all_route_params=all_route_params + ) + definitions.update(validation_definitions) + parameters.extend(operation_parameters) if parameters: operation["parameters"] = parameters if method in METHODS_WITH_BODY: - body_field = route.body_field - if body_field: - assert isinstance(body_field, Field) - body_schema, _ = field_schema( - body_field, - model_name_map=model_name_map, - ref_prefix=REF_PREFIX, - ) - if isinstance(body_field.schema, Body): - request_media_type = body_field.schema.media_type - else: - # Includes not declared media types (Schema) - request_media_type = "application/json" - required = body_field.required - request_body_oai = {} - if required: - request_body_oai["required"] = required - request_body_oai["content"] = { - request_media_type: {"schema": body_schema} - } + request_body_oai = get_openapi_operation_request_body( + body_field=route.body_field, model_name_map=model_name_map + ) + if request_body_oai: operation["requestBody"] = request_body_oai response_code = str(route.response_code) response_schema = {"type": "string"} @@ -206,75 +243,3 @@ def get_openapi( output["components"] = components output["paths"] = paths return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False) - - -def get_swagger_ui_html(*, openapi_url: str, title: str): - return HTMLResponse( - """ - - - - - - """ + title + """ - - - -
-
- - - - - - """, - media_type="text/html", - ) - - -def get_redoc_html(*, openapi_url: str, title: str): - return HTMLResponse( - """ - - - - - """ + title + """ - - - - - - - - - - - - - - - """, - media_type="text/html", - ) diff --git a/fastapi/routing.py b/fastapi/routing.py index 6f7d592e5..22a62a53a 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -2,6 +2,11 @@ import asyncio import inspect from typing import Callable, List, Type +from pydantic import BaseConfig, BaseModel, Schema +from pydantic.error_wrappers import ErrorWrapper, ValidationError +from pydantic.fields import Field +from pydantic.utils import lenient_issubclass + from starlette import routing from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException @@ -15,10 +20,6 @@ from fastapi import params from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies from fastapi.encoders import jsonable_encoder -from pydantic import BaseConfig, BaseModel, Schema -from pydantic.error_wrappers import ErrorWrapper, ValidationError -from pydantic.fields import Field -from pydantic.utils import lenient_issubclass def serialize_response(*, field: Field = None, response): @@ -44,11 +45,12 @@ def get_app( response_field: Type[Field] = None, ): is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call) + is_body_form = body_field and isinstance(body_field.schema, params.Form) async def app(request: Request) -> Response: body = None if body_field: - if isinstance(body_field.schema, params.Form): + if is_body_form: raw_body = await request.form() body = {} for field, value in raw_body.items(): @@ -127,12 +129,7 @@ class APIRoute(routing.Route): response_code=200, response_wrapper=JSONResponse, ) -> None: - # TODO define how to read and provide security params, and how to have them globally too - # TODO implement dependencies and injection - # TODO refactor code structure - # TODO create testing - # TODO testing coverage - assert path.startswith("/"), "Routed paths must always start '/'" + assert path.startswith("/"), "Routed paths must always start with '/'" self.path = path self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name @@ -260,6 +257,39 @@ class APIRouter(routing.Router): return decorator + def include_router(self, router: "APIRouter", *, prefix=""): + if prefix: + assert prefix.startswith("/"), "A path prefix must start with '/'" + assert not prefix.endswith( + "/" + ), "A path prefix must not end with '/', as the routes will start with '/'" + for route in router.routes: + if isinstance(route, APIRoute): + self.add_api_route( + prefix + route.path, + route.endpoint, + methods=route.methods, + name=route.name, + include_in_schema=route.include_in_schema, + tags=route.tags, + summary=route.summary, + description=route.description, + operation_id=route.operation_id, + deprecated=route.deprecated, + response_type=route.response_type, + response_description=route.response_description, + response_code=route.response_code, + response_wrapper=route.response_wrapper, + ) + elif isinstance(route, routing.Route): + self.add_route( + prefix + route.path, + route.endpoint, + methods=route.methods, + name=route.name, + include_in_schema=route.include_in_schema, + ) + def get( self, path: str, diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index c0354fea7..047898dfe 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,39 +1,34 @@ -from enum import Enum - -from pydantic import Schema - from starlette.requests import Request -from .base import SecurityBase, Types - -class APIKeyIn(Enum): - query = "query" - header = "header" - cookie = "cookie" - +from .base import SecurityBase +from fastapi.openapi.models import APIKeyIn, APIKey class APIKeyBase(SecurityBase): - type_ = Schema(Types.apiKey, alias="type") - in_: str = Schema(..., alias="in") - name: str - + pass class APIKeyQuery(APIKeyBase): - in_ = Schema(APIKeyIn.query, alias="in") + + def __init__(self, *, name: str, scheme_name: str = None): + self.model = APIKey(in_=APIKeyIn.query, name=name) + self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, requests: Request): - return requests.query_params.get(self.name) + return requests.query_params.get(self.model.name) class APIKeyHeader(APIKeyBase): - in_ = Schema(APIKeyIn.header, alias="in") + def __init__(self, *, name: str, scheme_name: str = None): + self.model = APIKey(in_=APIKeyIn.header, name=name) + self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, requests: Request): - return requests.headers.get(self.name) + return requests.headers.get(self.model.name) class APIKeyCookie(APIKeyBase): - in_ = Schema(APIKeyIn.cookie, alias="in") + def __init__(self, *, name: str, scheme_name: str = None): + self.model = APIKey(in_=APIKeyIn.cookie, name=name) + self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, requests: Request): - return requests.cookies.get(self.name) + return requests.cookies.get(self.model.name) diff --git a/fastapi/security/base.py b/fastapi/security/base.py index 9ba430df9..8589da0be 100644 --- a/fastapi/security/base.py +++ b/fastapi/security/base.py @@ -1,16 +1,6 @@ -from enum import Enum +from pydantic import BaseModel -from pydantic import BaseModel, Schema +from fastapi.openapi.models import SecurityBase as SecurityBaseModel - -class Types(Enum): - apiKey = "apiKey" - http = "http" - oauth2 = "oauth2" - openIdConnect = "openIdConnect" - - -class SecurityBase(BaseModel): - scheme_name: str = None - type_: Types = Schema(..., alias="type") - description: str = None +class SecurityBase: + pass diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 7a8bcfe48..cee42b868 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -1,26 +1,40 @@ -from pydantic import Schema - from starlette.requests import Request -from .base import SecurityBase, Types +from .base import SecurityBase +from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel class HTTPBase(SecurityBase): - type_ = Schema(Types.http, alias="type") - scheme: str + def __init__(self, *, scheme: str, scheme_name: str = None): + self.model = HTTPBaseModel(scheme=scheme) + self.scheme_name = scheme_name or self.__class__.__name__ async def __call__(self, request: Request): return request.headers.get("Authorization") class HTTPBasic(HTTPBase): - scheme = "basic" + def __init__(self, *, scheme_name: str = None): + self.model = HTTPBaseModel(scheme="basic") + self.scheme_name = scheme_name or self.__class__.__name__ + + async def __call__(self, request: Request): + return request.headers.get("Authorization") class HTTPBearer(HTTPBase): - scheme = "bearer" - bearerFormat: str = None + def __init__(self, *, bearerFormat: str = None, scheme_name: str = None): + self.model = HTTPBearerModel(bearerFormat=bearerFormat) + self.scheme_name = scheme_name or self.__class__.__name__ + + async def __call__(self, request: Request): + return request.headers.get("Authorization") class HTTPDigest(HTTPBase): - scheme = "digest" + def __init__(self, *, scheme_name: str = None): + self.model = HTTPBaseModel(scheme="digest") + self.scheme_name = scheme_name or self.__class__.__name__ + + async def __call__(self, request: Request): + return request.headers.get("Authorization") diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 4febdafc2..65517e962 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -1,43 +1,13 @@ -from typing import Dict - -from pydantic import BaseModel, Schema - from starlette.requests import Request -from .base import SecurityBase, Types - -class OAuthFlow(BaseModel): - refreshUrl: str = None - scopes: Dict[str, str] = {} - - -class OAuthFlowImplicit(OAuthFlow): - authorizationUrl: str - - -class OAuthFlowPassword(OAuthFlow): - tokenUrl: str - - -class OAuthFlowClientCredentials(OAuthFlow): - tokenUrl: str - - -class OAuthFlowAuthorizationCode(OAuthFlow): - authorizationUrl: str - tokenUrl: str - - -class OAuthFlows(BaseModel): - implicit: OAuthFlowImplicit = None - password: OAuthFlowPassword = None - clientCredentials: OAuthFlowClientCredentials = None - authorizationCode: OAuthFlowAuthorizationCode = None +from .base import SecurityBase +from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel class OAuth2(SecurityBase): - type_ = Schema(Types.oauth2, alias="type") - flows: OAuthFlows - + def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None): + self.model = OAuth2Model(flows=flows) + self.scheme_name = scheme_name or self.__class__.__name__ + async def __call__(self, request: Request): return request.headers.get("Authorization") diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index c84c56de8..49c5aae2d 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -1,11 +1,13 @@ from starlette.requests import Request -from .base import SecurityBase, Types +from .base import SecurityBase +from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel class OpenIdConnect(SecurityBase): - type_ = Types.openIdConnect - openIdConnectUrl: str - + def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None): + self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl) + self.scheme_name = scheme_name or self.__class__.__name__ + async def __call__(self, request: Request): return request.headers.get("Authorization")