diff --git a/fastapi/__init__.py b/fastapi/__init__.py index a52bbccf6..2bb1b27c2 100644 --- a/fastapi/__init__.py +++ b/fastapi/__init__.py @@ -1,3 +1,3 @@ """Fast API framework, fast high performance, fast to learn, fast to code""" -__version__ = '0.1' +__version__ = "0.1" diff --git a/fastapi/applications.py b/fastapi/applications.py index 2e1875aa1..3f5a45b73 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1,61 +1,19 @@ -import typing -import inspect +from typing import Any, Callable, Dict, List, Type from starlette.applications import Starlette -from starlette.middleware.lifespan import LifespanMiddleware +from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.middleware.errors import ServerErrorMiddleware -from starlette.exceptions import ExceptionMiddleware -from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse -from starlette.requests import Request -from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from starlette.middleware.lifespan import LifespanMiddleware +from starlette.responses import JSONResponse -from pydantic import BaseModel, BaseConfig, Schema -from pydantic.utils import lenient_issubclass -from pydantic.fields import Field -from pydantic.schema import ( - field_schema, - get_flat_models_from_models, - get_flat_models_from_fields, - get_model_name_map, - schema, - model_process_schema, -) -from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant -from .pydantic_utils import jsonable_encoder +from fastapi import routing +from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html -def docs(openapi_url): - return HTMLResponse( - """ - - - - - - -
-
- - - - - - """, - media_type="text/html", - ) +async def http_exception(request, exc: HTTPException): + print(exc) + return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) class FastAPI(Starlette): @@ -67,24 +25,26 @@ class FastAPI(Starlette): description: str = "", version: str = "0.1.0", openapi_url: str = "/openapi.json", - docs_url: str = "/docs", - **extra: typing.Dict[str, typing.Any], + swagger_ui_url: str = "/docs", + redoc_url: str = "/redoc", + **extra: Dict[str, Any], ) -> None: self._debug = debug - self.router = APIRouter() + self.router = routing.APIRouter() self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) self.error_middleware = ServerErrorMiddleware( self.exception_middleware, debug=debug ) self.lifespan_middleware = LifespanMiddleware(self.error_middleware) - self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] + self.schema_generator = None self.template_env = self.load_template_env(template_directory) self.title = title self.description = description self.version = version self.openapi_url = openapi_url - self.docs_url = docs_url + self.swagger_ui_url = swagger_ui_url + self.redoc_url = redoc_url self.extra = extra self.openapi_version = "3.0.2" @@ -93,29 +53,52 @@ class FastAPI(Starlette): assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'" assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'" - if self.docs_url: + if self.swagger_ui_url or self.redoc_url: assert self.openapi_url, "The openapi_url is required for the docs" + self.setup() - self.add_route( - self.openapi_url, - lambda req: JSONResponse(self.openapi()), - include_in_schema=False, - ) - self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False) + def setup(self): + if self.openapi_url: + self.add_route( + self.openapi_url, + lambda req: JSONResponse( + get_openapi( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + description=self.description, + routes=self.routes, + ) + ), + include_in_schema=False, + ) + if self.swagger_ui_url: + self.add_route( + self.swagger_ui_url, + lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"), + include_in_schema=False, + ) + if self.redoc_url: + self.add_route( + self.redoc_url, + lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"), + include_in_schema=False, + ) + self.add_exception_handler(HTTPException, http_exception) def add_api_route( self, path: str, - endpoint: typing.Callable, - methods: typing.List[str] = None, + endpoint: Callable, + methods: List[str] = None, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -126,7 +109,7 @@ class FastAPI(Starlette): methods=methods, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -140,27 +123,27 @@ class FastAPI(Starlette): def api_route( self, path: str, - methods: typing.List[str] = None, + methods: List[str] = None, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, - ) -> typing.Callable: - def decorator(func: typing.Callable) -> typing.Callable: + ) -> Callable: + def decorator(func: Callable) -> Callable: self.router.add_api_route( path, func, methods=methods, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -179,12 +162,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -193,7 +176,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -209,12 +192,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -223,7 +206,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -239,12 +222,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -253,7 +236,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -269,12 +252,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -283,7 +266,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -299,12 +282,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -313,7 +296,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -329,12 +312,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -343,7 +326,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -359,12 +342,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -373,7 +356,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -389,12 +372,12 @@ class FastAPI(Starlette): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -403,7 +386,7 @@ class FastAPI(Starlette): path=path, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -413,169 +396,3 @@ class FastAPI(Starlette): response_code=response_code, response_wrapper=response_wrapper, ) - - def openapi(self): - info = {"title": self.title, "version": self.version} - if self.description: - info["description"] = self.description - output = {"openapi": self.openapi_version, "info": info} - components = {} - paths = {} - methods_with_body = set(("POST", "PUT")) - body_fields_from_routes = [] - responses_from_routes = [] - ref_prefix = "#/components/schemas/" - for route in self.routes: - route: APIRoute - if route.include_in_schema and isinstance(route, APIRoute): - if route.request_body: - assert isinstance( - route.request_body, Field - ), "A request body must be a Pydantic BaseModel or Field" - body_fields_from_routes.append(route.request_body) - if route.response_field: - responses_from_routes.append(route.response_field) - flat_models = get_flat_models_from_fields( - body_fields_from_routes + responses_from_routes - ) - model_name_map = get_model_name_map(flat_models) - definitions = {} - for model in flat_models: - m_schema, m_definitions = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=ref_prefix - ) - definitions.update(m_definitions) - model_name = model_name_map[model] - definitions[model_name] = m_schema - validation_error_definition = { - "title": "ValidationError", - "type": "object", - "properties": { - "loc": { - "title": "Location", - "type": "array", - "items": {"type": "string"}, - }, - "msg": {"title": "Message", "type": "string"}, - "type": {"title": "Error Type", "type": "string"}, - }, - "required": ["loc", "msg", "type"], - } - validation_error_response_definition = { - "title": "HTTPValidationError", - "type": "object", - "properties": { - "detail": { - "title": "Detail", - "type": "array", - "items": {"$ref": ref_prefix + "ValidationError"}, - } - }, - } - for route in self.routes: - route: APIRoute - if route.include_in_schema and isinstance(route, APIRoute): - path = paths.get(route.path, {}) - for method in route.methods: - operation = {} - if route.tags: - operation["tags"] = route.tags - if route.summary: - operation["summary"] = route.summary - if route.description: - operation["description"] = route.description - if route.operation_id: - operation["operationId"] = route.operation_id - else: - operation["operationId"] = route.name - if route.deprecated: - operation["deprecated"] = route.deprecated - parameters = [] - flat_dependant = get_flat_dependant(route.dependant) - security_definitions = {} - for security_scheme in flat_dependant.security_schemes: - security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False) - security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__ - security_definitions[security_name] = security_definition - if security_definitions: - components.setdefault("securitySchemes", {}).update(security_definitions) - operation["security"] = [{name: []} for name in security_definitions] - all_route_params = get_openapi_params(route.dependant) - for param in all_route_params: - if "ValidationError" not in definitions: - definitions["ValidationError"] = validation_error_definition - definitions[ - "HTTPValidationError" - ] = validation_error_response_definition - parameter = { - "name": param.alias, - "in": param.schema.in_.value, - "required": param.required, - "schema": field_schema(param, model_name_map={})[0], - } - if param.schema.description: - parameter["description"] = param.schema.description - if param.schema.deprecated: - parameter["deprecated"] = param.schema.deprecated - parameters.append(parameter) - if parameters: - operation["parameters"] = parameters - if method in methods_with_body: - request_body = getattr(route, "request_body", None) - if request_body: - assert isinstance(request_body, Field) - body_schema, _ = field_schema( - request_body, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ) - required = request_body.required - request_body_oai = {} - if required: - request_body_oai["required"] = required - request_body_oai["content"] = { - "application/json": {"schema": body_schema} - } - operation["requestBody"] = request_body_oai - response_code = str(route.response_code) - response_schema = {"type": "string"} - if lenient_issubclass(route.response_wrapper, JSONResponse): - response_media_type = "application/json" - if route.response_field: - response_schema, _ = field_schema( - route.response_field, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ) - else: - response_schema = {} - elif lenient_issubclass(route.response_wrapper, HTMLResponse): - response_media_type = "text/html" - else: - response_media_type = "text/plain" - content = {response_media_type: {"schema": response_schema}} - operation["responses"] = { - response_code: { - "description": route.response_description, - "content": content, - } - } - if all_route_params or getattr(route, "request_body", None): - operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": ref_prefix + "HTTPValidationError" - } - } - }, - } - path[method.lower()] = operation - paths[route.path] = path - if definitions: - components.setdefault("schemas", {}).update(definitions) - if components: - output["components"] = components - output["paths"] = paths - return output diff --git a/fastapi/dependencies/__init__.py b/fastapi/dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py new file mode 100644 index 000000000..ad9419db5 --- /dev/null +++ b/fastapi/dependencies/models.py @@ -0,0 +1,46 @@ +from typing import Any, Callable, Dict, List, Sequence, Tuple + +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request + +from fastapi.security.base import SecurityBase +from pydantic import BaseConfig, Schema +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import Field, Required +from pydantic.schema import get_annotation_from_schema + +param_supported_types = (str, int, float, bool) + + +class SecurityRequirement: + def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None): + self.security_scheme = security_scheme + self.scopes = scopes + + +class Dependant: + def __init__( + self, + *, + path_params: List[Field] = None, + query_params: List[Field] = None, + header_params: List[Field] = None, + cookie_params: List[Field] = None, + body_params: List[Field] = None, + dependencies: List["Dependant"] = None, + security_schemes: List[SecurityRequirement] = None, + name: str = None, + call: Callable = None, + request_param_name: str = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.dependencies = dependencies or [] + self.security_requirements = security_schemes or [] + self.request_param_name = request_param_name + self.name = name + self.call = call diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py new file mode 100644 index 000000000..6e86de5a5 --- /dev/null +++ b/fastapi/dependencies/utils.py @@ -0,0 +1,327 @@ +import asyncio +import inspect +from copy import deepcopy +from typing import Any, Callable, Dict, List, Tuple + +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request + +from fastapi import params +from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.security.base import SecurityBase +from fastapi.utils import get_path_param_names +from pydantic import BaseConfig, Schema, create_model +from pydantic.error_wrappers import ErrorWrapper +from pydantic.errors import MissingError +from pydantic.fields import Field, Required +from pydantic.schema import get_annotation_from_schema +from pydantic.utils import lenient_issubclass + +param_supported_types = (str, int, float, bool) + + +def get_sub_dependant(*, param: inspect.Parameter, path: str): + depends: params.Depends = param.default + if depends.dependency: + dependency = depends.dependency + else: + dependency = param.annotation + assert callable(dependency) + sub_dependant = get_dependant(path=path, call=dependency, name=param.name) + if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase): + security_requirement = SecurityRequirement( + security_scheme=dependency, scopes=depends.scopes + ) + sub_dependant.security_requirements.append(security_requirement) + return sub_dependant + + +def get_flat_dependant(dependant: Dependant): + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_schemes=dependant.security_requirements.copy(), + ) + for sub_dependant in dependant.dependencies: + if sub_dependant is dependant: + raise ValueError("recursion", dependant.dependencies) + flat_sub = get_flat_dependant(sub_dependant) + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + flat_dependant.security_requirements.extend(flat_sub.security_requirements) + return flat_dependant + + +def get_dependant(*, path: str, call: Callable, name: str = None): + path_param_names = get_path_param_names(path) + endpoint_signature = inspect.signature(call) + signature_params = endpoint_signature.parameters + dependant = Dependant(call=call, name=name) + for param_name in signature_params: + param = signature_params[param_name] + if isinstance(param.default, params.Depends): + sub_dependant = get_sub_dependant(param=param, path=path) + dependant.dependencies.append(sub_dependant) + for param_name in signature_params: + param = signature_params[param_name] + if ( + (param.default == param.empty) or isinstance(param.default, params.Path) + ) and (param_name in path_param_names): + assert lenient_issubclass( + param.annotation, param_supported_types + ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}" + param = signature_params[param_name] + add_param_to_fields( + param=param, + dependant=dependant, + default_schema=params.Path, + force_type=params.ParamTypes.path, + ) + elif (param.default == param.empty or param.default is None) and ( + param.annotation == param.empty + or lenient_issubclass(param.annotation, param_supported_types) + ): + add_param_to_fields( + param=param, dependant=dependant, default_schema=params.Query + ) + elif isinstance(param.default, params.Param): + if param.annotation != param.empty: + assert lenient_issubclass( + param.annotation, param_supported_types + ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" + add_param_to_fields( + param=param, dependant=dependant, default_schema=params.Query + ) + elif lenient_issubclass(param.annotation, Request): + dependant.request_param_name = param_name + elif not isinstance(param.default, params.Depends): + add_param_to_body_fields(param=param, dependant=dependant) + return dependant + + +def add_param_to_fields( + *, + param: inspect.Parameter, + dependant: Dependant, + default_schema=params.Param, + force_type: params.ParamTypes = None, +): + default_value = Required + if not param.default == param.empty: + default_value = param.default + if isinstance(default_value, params.Param): + schema = default_value + default_value = schema.default + if schema.in_ is None: + schema.in_ = default_schema.in_ + if force_type: + schema.in_ = force_type + else: + schema = default_schema(default_value) + required = default_value == Required + annotation = Any + if not param.annotation == param.empty: + annotation = param.annotation + annotation = get_annotation_from_schema(annotation, schema) + field = Field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=schema.alias or param.name, + required=required, + model_config=BaseConfig(), + class_validators=[], + schema=schema, + ) + if schema.in_ == params.ParamTypes.path: + dependant.path_params.append(field) + elif schema.in_ == params.ParamTypes.query: + dependant.query_params.append(field) + elif schema.in_ == params.ParamTypes.header: + dependant.header_params.append(field) + else: + assert ( + schema.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {param.name}" + dependant.cookie_params.append(field) + + +def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): + default_value = Required + if not param.default == param.empty: + default_value = param.default + if isinstance(default_value, Schema): + schema = default_value + default_value = schema.default + else: + schema = Schema(default_value) + required = default_value == Required + annotation = get_annotation_from_schema(param.annotation, schema) + field = Field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=schema.alias or param.name, + required=required, + model_config=BaseConfig, + class_validators=[], + schema=schema, + ) + dependant.body_params.append(field) + + +def is_coroutine_callable(call: Callable = None): + if not call: + return False + if inspect.isfunction(call): + return asyncio.iscoroutinefunction(call) + if inspect.isclass(call): + return False + call = getattr(call, "__call__", None) + if not call: + return False + return asyncio.iscoroutinefunction(call) + + +async def solve_dependencies( + *, request: Request, dependant: Dependant, body: Dict[str, Any] = None +): + values: Dict[str, Any] = {} + errors: List[ErrorWrapper] = [] + for sub_dependant in dependant.dependencies: + sub_values, sub_errors = await solve_dependencies( + request=request, dependant=sub_dependant, body=body + ) + if sub_errors: + return {}, errors + if sub_dependant.call and is_coroutine_callable(sub_dependant.call): + solved = await sub_dependant.call(**sub_values) + else: + solved = await run_in_threadpool(sub_dependant.call, **sub_values) + values[ + sub_dependant.name + ] = solved # type: ignore # Sub-dependants always have a name + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) + values.update(path_values) + values.update(query_values) + values.update(header_values) + values.update(cookie_values) + errors = path_errors + query_errors + header_errors + cookie_errors + if dependant.body_params: + body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above + dependant.body_params, body + ) + values.update(body_values) + errors.extend(body_errors) + if dependant.request_param_name: + values[dependant.request_param_name] = request + return values, errors + + +def request_params_to_args( + required_params: List[Field], received_params: Dict[str, Any] +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + for field in required_params: + value = received_params.get(field.alias) + if value is None: + if field.required: + errors.append( + ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) + ) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate( + value, values, loc=(field.schema.in_.value, field.alias) + ) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +async def request_body_to_args( + required_params: List[Field], received_body: Dict[str, Any] +) -> Tuple[Dict[str, Any], List[ErrorWrapper]]: + values = {} + errors = [] + if required_params: + field = required_params[0] + embed = getattr(field.schema, "embed", None) + if len(required_params) == 1 and not embed: + received_body = {field.alias: received_body} + for field in required_params: + value = received_body.get(field.alias) + if value is None: + if field.required: + errors.append( + ErrorWrapper( + MissingError(), loc=("body", field.alias), config=BaseConfig + ) + ) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate(value, values, loc=("body", field.alias)) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +def get_body_field(*, dependant: Dependant, name: str): + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + first_param = flat_dependant.body_params[0] + embed = getattr(first_param.schema, "embed", None) + if len(flat_dependant.body_params) == 1 and not embed: + return first_param + model_name = "Body_" + name + BodyModel = create_model(model_name) + for f in flat_dependant.body_params: + BodyModel.__fields__[f.name] = f + required = any(True for f in flat_dependant.body_params if f.required) + if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params): + BodySchema = params.File + elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params): + BodySchema = params.Form + else: + BodySchema = params.Body + + field = Field( + name="body", + type_=BodyModel, + default=None, + required=required, + model_config=BaseConfig, + class_validators=[], + alias="body", + schema=BodySchema(None), + ) + return field diff --git a/fastapi/pydantic_utils.py b/fastapi/encoders.py similarity index 56% rename from fastapi/pydantic_utils.py rename to fastapi/encoders.py index 8fc6589a4..95ce4479e 100644 --- a/fastapi/pydantic_utils.py +++ b/fastapi/encoders.py @@ -1,33 +1,44 @@ +from enum import Enum from types import GeneratorType from typing import Set + from pydantic import BaseModel -from enum import Enum from pydantic.json import pydantic_encoder def jsonable_encoder( - obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True, + obj, + include: Set[str] = None, + exclude: Set[str] = set(), + by_alias: bool = False, + include_none=True, ): if isinstance(obj, BaseModel): return jsonable_encoder( - obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none + obj.dict(include=include, exclude=exclude, by_alias=by_alias), + include_none=include_none, ) - elif isinstance(obj, Enum): + if isinstance(obj, Enum): return obj.value if isinstance(obj, (str, int, float, type(None))): return obj if isinstance(obj, dict): return { jsonable_encoder( - key, by_alias=by_alias, include_none=include_none, - ): jsonable_encoder( - value, by_alias=by_alias, include_none=include_none, - ) - for key, value in obj.items() if value is not None or include_none + key, by_alias=by_alias, include_none=include_none + ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none) + for key, value in obj.items() + if value is not None or include_none } if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [ - jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none) + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + include_none=include_none, + ) for item in obj ] return pydantic_encoder(obj) diff --git a/fastapi/openapi/__init__.py b/fastapi/openapi/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastapi/openapi/constants.py b/fastapi/openapi/constants.py new file mode 100644 index 000000000..1d94a3377 --- /dev/null +++ b/fastapi/openapi/constants.py @@ -0,0 +1,2 @@ +METHODS_WITH_BODY = set(("POST", "PUT")) +REF_PREFIX = "#/components/schemas/" diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py new file mode 100644 index 000000000..e3d96bd7f --- /dev/null +++ b/fastapi/openapi/models.py @@ -0,0 +1,347 @@ +import logging +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Schema as PSchema +from pydantic.types import UrlStr + +try: + import pydantic.types.EmailStr + from pydantic.types import EmailStr +except ImportError: + logging.warning( + "email-validator not installed, email fields will be treated as str" + ) + + class EmailStr(str): + pass + + +class Contact(BaseModel): + name: Optional[str] = None + url: Optional[UrlStr] = None + email: Optional[EmailStr] = None + + +class License(BaseModel): + name: str + url: Optional[UrlStr] = None + + +class Info(BaseModel): + title: str + description: Optional[str] = None + termsOfService: Optional[str] = None + contact: Optional[Contact] = None + license: Optional[License] = None + version: str + + +class ServerVariable(BaseModel): + enum: Optional[List[str]] = None + default: str + description: Optional[str] = None + + +class Server(BaseModel): + url: UrlStr + description: Optional[str] = None + variables: Optional[Dict[str, ServerVariable]] = None + + +class Reference(BaseModel): + ref: str = PSchema(..., alias="$ref") + + +class Discriminator(BaseModel): + propertyName: str + mapping: Optional[Dict[str, str]] = None + + +class XML(BaseModel): + name: Optional[str] = None + namespace: Optional[str] = None + prefix: Optional[str] = None + attribute: Optional[bool] = None + wrapped: Optional[bool] = None + + +class ExternalDocumentation(BaseModel): + description: Optional[str] = None + url: UrlStr + + +class SchemaBase(BaseModel): + ref: Optional[str] = PSchema(None, alias="$ref") + title: Optional[str] = None + multipleOf: Optional[float] = None + maximum: Optional[float] = None + exclusiveMaximum: Optional[float] = None + minimum: Optional[float] = None + exclusiveMinimum: Optional[float] = None + maxLength: Optional[int] = PSchema(None, gte=0) + minLength: Optional[int] = PSchema(None, gte=0) + pattern: Optional[str] = None + maxItems: Optional[int] = PSchema(None, gte=0) + minItems: Optional[int] = PSchema(None, gte=0) + uniqueItems: Optional[bool] = None + maxProperties: Optional[int] = PSchema(None, gte=0) + minProperties: Optional[int] = PSchema(None, gte=0) + required: Optional[List[str]] = None + enum: Optional[List[str]] = None + type: Optional[str] = None + allOf: Optional[List[Any]] = None + oneOf: Optional[List[Any]] = None + anyOf: Optional[List[Any]] = None + not_: Optional[List[Any]] = PSchema(None, alias="not") + items: Optional[Any] = None + properties: Optional[Dict[str, Any]] = None + additionalProperties: Optional[Union[bool, Any]] = None + description: Optional[str] = None + format: Optional[str] = None + default: Optional[Any] = None + nullable: Optional[bool] = None + discriminator: Optional[Discriminator] = None + readOnly: Optional[bool] = None + writeOnly: Optional[bool] = None + xml: Optional[XML] = None + externalDocs: Optional[ExternalDocumentation] = None + example: Optional[Any] = None + deprecated: Optional[bool] = None + + +class Schema(SchemaBase): + allOf: Optional[List[SchemaBase]] = None + oneOf: Optional[List[SchemaBase]] = None + anyOf: Optional[List[SchemaBase]] = None + not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") + items: Optional[SchemaBase] = None + properties: Optional[Dict[str, SchemaBase]] = None + additionalProperties: Optional[Union[bool, SchemaBase]] = None + + +class Example(BaseModel): + summary: Optional[str] = None + description: Optional[str] = None + value: Optional[Any] = None + externalValue: Optional[UrlStr] = None + + +class ParameterInType(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +class Encoding(BaseModel): + contentType: Optional[str] = None + # Workaround OpenAPI recursive reference, using Any + headers: Optional[Dict[str, Union[Any, Reference]]] = None + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + + +class MediaType(BaseModel): + schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") + example: Optional[Any] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + encoding: Optional[Dict[str, Encoding]] = None + + +class ParameterBase(BaseModel): + description: Optional[str] = None + required: Optional[bool] = None + deprecated: Optional[bool] = None + # Serialization rules for simple scenarios + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema") + example: Optional[Any] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + # Serialization rules for more complex scenarios + content: Optional[Dict[str, MediaType]] = None + + +class Parameter(ParameterBase): + name: str + in_: ParameterInType = PSchema(..., alias="in") + + +class Header(ParameterBase): + pass + + +# Workaround OpenAPI recursive reference +class EncodingWithHeaders(Encoding): + headers: Optional[Dict[str, Union[Header, Reference]]] = None + + +class RequestBody(BaseModel): + description: Optional[str] = None + content: Dict[str, MediaType] + required: Optional[bool] = None + + +class Link(BaseModel): + operationRef: Optional[str] = None + operationId: Optional[str] = None + parameters: Optional[Dict[str, Union[Any, str]]] = None + requestBody: Optional[Union[Any, str]] = None + description: Optional[str] = None + server: Optional[Server] = None + + +class Response(BaseModel): + description: str + headers: Optional[Dict[str, Union[Header, Reference]]] = None + content: Optional[Dict[str, MediaType]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + + +class Responses(BaseModel): + default: Response + + +class Operation(BaseModel): + tags: Optional[List[str]] = None + summary: Optional[str] = None + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + operationId: Optional[str] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + requestBody: Optional[Union[RequestBody, Reference]] = None + responses: Union[Responses, Dict[Union[str], Response]] + # Workaround OpenAPI recursive reference + callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None + deprecated: Optional[bool] = None + security: Optional[List[Dict[str, List[str]]]] = None + servers: Optional[List[Server]] = None + + +class PathItem(BaseModel): + ref: Optional[str] = PSchema(None, alias="$ref") + summary: Optional[str] = None + description: Optional[str] = None + get: Optional[Operation] = None + put: Optional[Operation] = None + post: Optional[Operation] = None + delete: Optional[Operation] = None + options: Optional[Operation] = None + head: Optional[Operation] = None + patch: Optional[Operation] = None + trace: Optional[Operation] = None + servers: Optional[List[Server]] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + + +# Workaround OpenAPI recursive reference +class OperationWithCallbacks(BaseModel): + callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None + + +class SecuritySchemeType(Enum): + apiKey = "apiKey" + http = "http" + oauth2 = "oauth2" + openIdConnect = "openIdConnect" + + +class SecurityBase(BaseModel): + type_: SecuritySchemeType = PSchema(..., alias="type") + description: Optional[str] = None + + +class APIKeyIn(Enum): + query = "query" + header = "header" + cookie = "cookie" + + +class APIKey(SecurityBase): + type_ = PSchema(SecuritySchemeType.apiKey, alias="type") + in_: APIKeyIn = PSchema(..., alias="in") + name: str + + +class HTTPBase(SecurityBase): + type_ = PSchema(SecuritySchemeType.http, alias="type") + scheme: str + + +class HTTPBearer(HTTPBase): + scheme = "bearer" + bearerFormat: Optional[str] = None + + +class OAuthFlow(BaseModel): + refreshUrl: Optional[str] = None + scopes: Dict[str, str] = {} + + +class OAuthFlowImplicit(OAuthFlow): + authorizationUrl: str + + +class OAuthFlowPassword(OAuthFlow): + tokenUrl: str + + +class OAuthFlowClientCredentials(OAuthFlow): + tokenUrl: str + + +class OAuthFlowAuthorizationCode(OAuthFlow): + authorizationUrl: str + tokenUrl: str + + +class OAuthFlows(BaseModel): + implicit: Optional[OAuthFlowImplicit] = None + password: Optional[OAuthFlowPassword] = None + clientCredentials: Optional[OAuthFlowClientCredentials] = None + authorizationCode: Optional[OAuthFlowAuthorizationCode] = None + + +class OAuth2(SecurityBase): + type_ = PSchema(SecuritySchemeType.oauth2, alias="type") + flows: OAuthFlows + + +class OpenIdConnect(SecurityBase): + type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") + openIdConnectUrl: str + + +SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect] + + +class Components(BaseModel): + schemas: Optional[Dict[str, Union[Schema, Reference]]] = None + responses: Optional[Dict[str, Union[Response, Reference]]] = None + parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None + headers: Optional[Dict[str, Union[Header, Reference]]] = None + securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None + + +class Tag(BaseModel): + name: str + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + + +class OpenAPI(BaseModel): + openapi: str + info: Info + servers: Optional[List[Server]] = None + paths: Dict[str, PathItem] + components: Optional[Components] = None + security: Optional[List[Dict[str, List[str]]]] = None + tags: Optional[List[Tag]] = None + externalDocs: Optional[ExternalDocumentation] = None diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py new file mode 100644 index 000000000..3cf800740 --- /dev/null +++ b/fastapi/openapi/utils.py @@ -0,0 +1,280 @@ +from typing import Any, Dict, Sequence, Type + +from starlette.responses import HTMLResponse, JSONResponse +from starlette.routing import BaseRoute +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY + +from fastapi import routing +from fastapi.dependencies.models import Dependant +from fastapi.dependencies.utils import get_flat_dependant +from fastapi.encoders import jsonable_encoder +from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY +from fastapi.openapi.models import OpenAPI +from fastapi.params import Body +from fastapi.utils import get_flat_models_from_routes, get_model_definitions +from pydantic.fields import Field +from pydantic.schema import field_schema, get_model_name_map +from pydantic.utils import lenient_issubclass + +validation_error_definition = { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": {"title": "Location", "type": "array", "items": {"type": "string"}}, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], +} + +validation_error_response_definition = { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": REF_PREFIX + "ValidationError"}, + } + }, +} + + +def get_openapi_params(dependant: Dependant): + flat_dependant = get_flat_dependant(dependant) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + +def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): + if not (route.include_in_schema and isinstance(route, routing.APIRoute)): + return None + path = {} + security_schemes = {} + definitions = {} + for method in route.methods: + operation: Dict[str, Any] = {} + if route.tags: + operation["tags"] = route.tags + if route.summary: + operation["summary"] = route.summary + if route.description: + operation["description"] = route.description + if route.operation_id: + operation["operationId"] = route.operation_id + else: + operation["operationId"] = route.name + if route.deprecated: + operation["deprecated"] = route.deprecated + parameters = [] + flat_dependant = get_flat_dependant(route.dependant) + security_definitions = {} + for security_requirement in flat_dependant.security_requirements: + security_definition = jsonable_encoder( + security_requirement.security_scheme, + exclude={"scheme_name"}, + by_alias=True, + include_none=False, + ) + security_name = ( + getattr( + security_requirement.security_scheme, "scheme_name", None + ) + or security_requirement.security_scheme.__class__.__name__ + ) + security_definitions[security_name] = security_definition + operation.setdefault("security", []).append( + {security_name: security_requirement.scopes} + ) + if security_definitions: + security_schemes.update( + security_definitions + ) + all_route_params = get_openapi_params(route.dependant) + for param in all_route_params: + if "ValidationError" not in definitions: + definitions["ValidationError"] = validation_error_definition + definitions[ + "HTTPValidationError" + ] = validation_error_response_definition + parameter = { + "name": param.alias, + "in": param.schema.in_.value, + "required": param.required, + "schema": field_schema(param, model_name_map={})[0], + } + if param.schema.description: + parameter["description"] = param.schema.description + if param.schema.deprecated: + parameter["deprecated"] = param.schema.deprecated + parameters.append(parameter) + if parameters: + operation["parameters"] = parameters + if method in METHODS_WITH_BODY: + body_field = route.body_field + if body_field: + assert isinstance(body_field, Field) + body_schema, _ = field_schema( + body_field, + model_name_map=model_name_map, + ref_prefix=REF_PREFIX, + ) + if isinstance(body_field.schema, Body): + request_media_type = body_field.schema.media_type + else: + # Includes not declared media types (Schema) + request_media_type = "application/json" + required = body_field.required + request_body_oai = {} + if required: + request_body_oai["required"] = required + request_body_oai["content"] = { + request_media_type: {"schema": body_schema} + } + operation["requestBody"] = request_body_oai + response_code = str(route.response_code) + response_schema = {"type": "string"} + if lenient_issubclass(route.response_wrapper, JSONResponse): + response_media_type = "application/json" + if route.response_field: + response_schema, _ = field_schema( + route.response_field, + model_name_map=model_name_map, + ref_prefix=REF_PREFIX, + ) + else: + response_schema = {} + elif lenient_issubclass(route.response_wrapper, HTMLResponse): + response_media_type = "text/html" + else: + response_media_type = "text/plain" + content = {response_media_type: {"schema": response_schema}} + operation["responses"] = { + response_code: { + "description": route.response_description, + "content": content, + } + } + if all_route_params or route.body_field: + operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { + "description": "Validation Error", + "content": { + "application/json": { + "schema": {"$ref": REF_PREFIX + "HTTPValidationError"} + } + }, + } + path[method.lower()] = operation + return path, security_schemes, definitions + + +def get_openapi( + *, + title: str, + version: str, + openapi_version: str = "3.0.2", + description: str = None, + routes: Sequence[BaseRoute] +): + info = {"title": title, "version": version} + if description: + info["description"] = description + output = {"openapi": openapi_version, "info": info} + components: Dict[str, Dict] = {} + paths: Dict[str, Dict] = {} + flat_models = get_flat_models_from_routes(routes) + model_name_map = get_model_name_map(flat_models) + definitions = get_model_definitions( + flat_models=flat_models, model_name_map=model_name_map + ) + for route in routes: + result = get_openapi_path(route=route, model_name_map=model_name_map) + if result: + path, security_schemes, path_definitions = result + if path: + paths.setdefault(route.path, {}).update(path) + if security_schemes: + components.setdefault("securitySchemes", {}).update(security_schemes) + if path_definitions: + definitions.update(path_definitions) + if definitions: + components.setdefault("schemas", {}).update(definitions) + if components: + output["components"] = components + output["paths"] = paths + return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False) + + +def get_swagger_ui_html(*, openapi_url: str, title: str): + return HTMLResponse( + """ + + + + + + """ + title + """ + + + +
+
+ + + + + + """, + media_type="text/html", + ) + + +def get_redoc_html(*, openapi_url: str, title: str): + return HTMLResponse( + """ + + + + + """ + title + """ + + + + + + + + + + + + + + + """, + media_type="text/html", + ) diff --git a/fastapi/params.py b/fastapi/params.py index 98b80943c..abbce8aeb 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -1,5 +1,6 @@ -from typing import Sequence from enum import Enum +from typing import Sequence, Any, Dict + from pydantic import Schema @@ -12,6 +13,7 @@ class ParamTypes(Enum): class Param(Schema): in_: ParamTypes + def __init__( self, default, @@ -27,7 +29,7 @@ class Param(Schema): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): self.deprecated = deprecated super().__init__( @@ -64,7 +66,7 @@ class Path(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): self.description = description self.deprecated = deprecated @@ -103,7 +105,7 @@ class Query(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): self.description = description self.deprecated = deprecated @@ -141,7 +143,7 @@ class Header(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): self.description = description self.deprecated = deprecated @@ -179,7 +181,7 @@ class Cookie(Param): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): self.description = description self.deprecated = deprecated @@ -200,11 +202,49 @@ class Cookie(Param): class Body(Schema): + def __init__( + self, + default, + *, + embed=False, + media_type: str = "application/json", + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: Dict[str, Any], + ): + self.embed = embed + self.media_type = media_type + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Form(Body): def __init__( self, default, *, sub_key=False, + media_type: str = "application/x-www-form-urlencoded", alias: str = None, title: str = None, description: str = None, @@ -215,11 +255,49 @@ class Body(Schema): min_length: int = None, max_length: int = None, regex: str = None, - **extra: object, + **extra: Dict[str, Any], ): - self.sub_key = sub_key super().__init__( default, + embed=sub_key, + media_type=media_type, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class File(Form): + def __init__( + self, + default, + *, + sub_key=False, + media_type: str = "multipart/form-data", + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: Dict[str, Any], + ): + super().__init__( + default, + embed=sub_key, + media_type=media_type, alias=alias, title=title, description=description, @@ -235,12 +313,11 @@ class Body(Schema): class Depends: - def __init__(self, dependency = None): + def __init__(self, dependency=None): self.dependency = dependency -class Security: - def __init__(self, security_scheme = None, scopes: Sequence[str] = None): - self.security_scheme = security_scheme - self.scopes = scopes - +class Security(Depends): + def __init__(self, dependency=None, scopes: Sequence[str] = None): + self.scopes = scopes or [] + super().__init__(dependency=dependency) diff --git a/fastapi/routing.py b/fastapi/routing.py index 8c95b3327..6f7d592e5 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,341 +1,66 @@ import asyncio import inspect -import re -import typing -from copy import deepcopy +from typing import Callable, List, Type from starlette import routing -from starlette.routing import get_name, request_response -from starlette.requests import Request -from starlette.responses import Response, JSONResponse from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException +from starlette.formparsers import UploadFile +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import get_name, request_response from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY - -from pydantic.fields import Field, Required -from pydantic.schema import get_annotation_from_schema -from pydantic import BaseConfig, BaseModel, create_model, Schema +from fastapi import params +from fastapi.dependencies.models import Dependant +from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies +from fastapi.encoders import jsonable_encoder +from pydantic import BaseConfig, BaseModel, Schema from pydantic.error_wrappers import ErrorWrapper, ValidationError -from pydantic.errors import MissingError +from pydantic.fields import Field from pydantic.utils import lenient_issubclass -from .pydantic_utils import jsonable_encoder - -from fastapi import params -from fastapi.security.base import SecurityBase - -param_supported_types = (str, int, float, bool) - -class Dependant: - def __init__( - self, - *, - path_params: typing.List[Field] = None, - query_params: typing.List[Field] = None, - header_params: typing.List[Field] = None, - cookie_params: typing.List[Field] = None, - body_params: typing.List[Field] = None, - dependencies: typing.List["Dependant"] = None, - security_schemes: typing.List[Field] = None, - name: str = None, - call: typing.Callable = None, - request_param_name: str = None, - ) -> None: - self.path_params: typing.List[Field] = path_params or [] - self.query_params: typing.List[Field] = query_params or [] - self.header_params: typing.List[Field] = header_params or [] - self.cookie_params: typing.List[Field] = cookie_params or [] - self.body_params: typing.List[Field] = body_params or [] - self.dependencies: typing.List[Dependant] = dependencies or [] - self.security_schemes: typing.List[Field] = security_schemes or [] - self.request_param_name = request_param_name - self.name = name - self.call: typing.Callable = call - - -def request_params_to_args( - required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any] -) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: - values = {} - errors = [] - for field in required_params: - value = received_params.get(field.alias) - if value is None: - if field.required: - errors.append( - ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) - ) - else: - values[field.name] = deepcopy(field.default) - continue - v_, errors_ = field.validate( - value, values, loc=(field.schema.in_.value, field.alias) - ) +def serialize_response(*, field: Field = None, response): + if field: + errors = [] + value, errors_ = field.validate(response, {}, loc=("response",)) if isinstance(errors_, ErrorWrapper): - errors_: ErrorWrapper errors.append(errors_) elif isinstance(errors_, list): errors.extend(errors_) - else: - values[field.name] = v_ - return values, errors - - -def request_body_to_args( - required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any] -) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: - values = {} - errors = [] - if required_params: - field = required_params[0] - sub_key = getattr(field.schema, "sub_key", None) - if len(required_params) == 1 and not sub_key: - received_body = {field.alias: received_body} - for field in required_params: - value = received_body.get(field.alias) - if value is None: - if field.required: - errors.append( - ErrorWrapper( - MissingError(), loc=("body", field.alias), config=BaseConfig - ) - ) - else: - values[field.name] = deepcopy(field.default) - continue - - v_, errors_ = field.validate(value, values, loc=("body", field.alias)) - if isinstance(errors_, ErrorWrapper): - errors_: ErrorWrapper - errors.append(errors_) - elif isinstance(errors_, list): - errors.extend(errors_) - else: - values[field.name] = v_ - return values, errors - - -def add_param_to_fields( - *, - param: inspect.Parameter, - dependant: Dependant, - default_schema=params.Param, - force_type: params.ParamTypes = None, -): - default_value = Required - if not param.default == param.empty: - default_value = param.default - if isinstance(default_value, params.Param): - schema = default_value - default_value = schema.default - if schema.in_ is None: - schema.in_ = default_schema.in_ - if force_type: - schema.in_ = force_type - else: - schema = default_schema(default_value) - required = default_value == Required - annotation = typing.Any - if not param.annotation == param.empty: - annotation = param.annotation - annotation = get_annotation_from_schema(annotation, schema) - Config = BaseConfig - field = Field( - name=param.name, - type_=annotation, - default=None if required else default_value, - alias=schema.alias or param.name, - required=required, - model_config=Config, - class_validators=[], - schema=schema, - ) - if schema.in_ == params.ParamTypes.path: - dependant.path_params.append(field) - elif schema.in_ == params.ParamTypes.query: - dependant.query_params.append(field) - elif schema.in_ == params.ParamTypes.header: - dependant.header_params.append(field) - else: - assert ( - schema.in_ == params.ParamTypes.cookie - ), f"non-body parameters must be in path, query, header or cookie: {param.name}" - dependant.cookie_params.append(field) - - -def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): - default_value = Required - if not param.default == param.empty: - default_value = param.default - if isinstance(default_value, Schema): - schema = default_value - default_value = schema.default + if errors: + raise ValidationError(errors) + return jsonable_encoder(value) else: - schema = Schema(default_value) - required = default_value == Required - annotation = get_annotation_from_schema(param.annotation, schema) - field = Field( - name=param.name, - type_=annotation, - default=None if required else default_value, - alias=schema.alias or param.name, - required=required, - model_config=BaseConfig, - class_validators=[], - schema=schema, - ) - dependant.body_params.append(field) + return jsonable_encoder(response) -def get_sub_dependant( - *, param: inspect.Parameter, path: str +def get_app( + dependant: Dependant, + body_field: Field = None, + response_code: str = 200, + response_wrapper: Type[Response] = JSONResponse, + response_field: Type[Field] = None, ): - depends: params.Depends = param.default - if depends.dependency: - dependency = depends.dependency - else: - dependency = param.annotation - assert callable(dependency) - sub_dependant = get_dependant(path=path, call=dependency, name=param.name) - if isinstance(dependency, SecurityBase): - sub_dependant.security_schemes.append(dependency) - return sub_dependant - - -def get_flat_dependant(dependant: Dependant): - flat_dependant = Dependant( - path_params=dependant.path_params.copy(), - query_params=dependant.query_params.copy(), - header_params=dependant.header_params.copy(), - cookie_params=dependant.cookie_params.copy(), - body_params=dependant.body_params.copy(), - security_schemes=dependant.security_schemes.copy(), - ) - for sub_dependant in dependant.dependencies: - if sub_dependant is dependant: - raise ValueError("recursion", dependant.dependencies) - flat_sub = get_flat_dependant(sub_dependant) - flat_dependant.path_params.extend(flat_sub.path_params) - flat_dependant.query_params.extend(flat_sub.query_params) - flat_dependant.header_params.extend(flat_sub.header_params) - flat_dependant.cookie_params.extend(flat_sub.cookie_params) - flat_dependant.body_params.extend(flat_sub.body_params) - flat_dependant.security_schemes.extend(flat_sub.security_schemes) - return flat_dependant - - -def get_path_param_names(path: str): - return {item.strip("{}") for item in re.findall("{[^}]*}", path)} - - -def get_dependant(*, path: str, call: typing.Callable, name: str = None): - path_param_names = get_path_param_names(path) - endpoint_signature = inspect.signature(call) - signature_params = endpoint_signature.parameters - dependant = Dependant(call=call, name=name) - for param_name in signature_params: - param = signature_params[param_name] - if isinstance(param.default, params.Depends): - sub_dependant = get_sub_dependant(param=param, path=path) - dependant.dependencies.append(sub_dependant) - for param_name in signature_params: - param = signature_params[param_name] - if ( - (param.default == param.empty) or isinstance(param.default, params.Path) - ) and (param_name in path_param_names): - assert lenient_issubclass( - param.annotation, param_supported_types - ), f"Path params must be of type str, int, float or boot: {param}" - param = signature_params[param_name] - add_param_to_fields( - param=param, - dependant=dependant, - default_schema=params.Path, - force_type=params.ParamTypes.path, - ) - elif (param.default == param.empty or param.default is None) and ( - param.annotation == param.empty - or lenient_issubclass(param.annotation, param_supported_types) - ): - add_param_to_fields( - param=param, dependant=dependant, default_schema=params.Query - ) - elif isinstance(param.default, params.Param): - if param.annotation != param.empty: - assert lenient_issubclass( - param.annotation, param_supported_types - ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" - add_param_to_fields( - param=param, dependant=dependant, default_schema=params.Query - ) - elif lenient_issubclass(param.annotation, Request): - dependant.request_param_name = param_name - elif not isinstance(param.default, params.Depends): - add_param_to_body_fields(param=param, dependant=dependant) - return dependant - - -def is_coroutine_callable(call: typing.Callable): - if inspect.isfunction(call): - return asyncio.iscoroutinefunction(call) - elif inspect.isclass(call): - return False - else: - call = getattr(call, "__call__", None) - if not call: - return False - else: - return asyncio.iscoroutinefunction(call) - - -async def solve_dependencies(*, request: Request, dependant: Dependant): - values = {} - errors = [] - for sub_dependant in dependant.dependencies: - sub_values, sub_errors = await solve_dependencies( - request=request, dependant=sub_dependant - ) - if sub_errors: - return {}, errors - if is_coroutine_callable(sub_dependant.call): - solved = await sub_dependant.call(**sub_values) - else: - solved = await run_in_threadpool(sub_dependant.call, **sub_values) - values[sub_dependant.name] = solved - path_values, path_errors = request_params_to_args( - dependant.path_params, request.path_params - ) - query_values, query_errors = request_params_to_args( - dependant.query_params, request.query_params - ) - header_values, header_errors = request_params_to_args( - dependant.header_params, request.headers - ) - cookie_values, cookie_errors = request_params_to_args( - dependant.cookie_params, request.cookies - ) - values.update(path_values) - values.update(query_values) - values.update(header_values) - values.update(cookie_values) - errors = path_errors + query_errors + header_errors + cookie_errors - if dependant.body_params: - body = await request.json() - body_values, body_errors = request_body_to_args(dependant.body_params, body) - values.update(body_values) - errors.extend(body_errors) - if dependant.request_param_name: - values[dependant.request_param_name] = request - return values, errors - - -def get_app(dependant: Dependant): - is_coroutine = asyncio.iscoroutinefunction(dependant.call) + is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call) async def app(request: Request) -> Response: - values, errors = await solve_dependencies(request=request, dependant=dependant) + body = None + if body_field: + if isinstance(body_field.schema, params.Form): + raw_body = await request.form() + body = {} + for field, value in raw_body.items(): + if isinstance(value, UploadFile): + body[field] = await value.read() + else: + body[field] = value + else: + body = await request.json() + values, errors = await solve_dependencies( + request=request, dependant=dependant, body=body + ) if errors: errors_out = ValidationError(errors) raise HTTPException( @@ -348,36 +73,56 @@ def get_app(dependant: Dependant): raw_response = await run_in_threadpool(dependant.call, **values) if isinstance(raw_response, Response): return raw_response - else: - return JSONResponse(content=jsonable_encoder(raw_response)) - return app - + if isinstance(raw_response, BaseModel): + return response_wrapper( + content=jsonable_encoder(raw_response), status_code=response_code + ) + errors = [] + try: + return response_wrapper( + content=serialize_response( + field=response_field, response=raw_response + ), + status_code=response_code, + ) + except Exception as e: + errors.append(e) + try: + response = dict(raw_response) + return response_wrapper( + content=serialize_response(field=response_field, response=response), + status_code=response_code, + ) + except Exception as e: + errors.append(e) + try: + response = vars(raw_response) + return response_wrapper( + content=serialize_response(field=response_field, response=response), + status_code=response_code, + ) + except Exception as e: + errors.append(e) + raise ValueError(errors) -def get_openapi_params(dependant: Dependant): - flat_dependant = get_flat_dependant(dependant) - return ( - flat_dependant.path_params - + flat_dependant.query_params - + flat_dependant.header_params - + flat_dependant.cookie_params - ) + return app class APIRoute(routing.Route): def __init__( self, path: str, - endpoint: typing.Callable, + endpoint: Callable, *, - methods: typing.List[str] = None, + methods: List[str] = None, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -392,12 +137,12 @@ class APIRoute(routing.Route): self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name self.include_in_schema = include_in_schema - self.tags = tags + self.tags = tags or [] self.summary = summary - self.description = description + self.description = description or self.endpoint.__doc__ self.operation_id = operation_id self.deprecated = deprecated - self.request_body: typing.Union[BaseModel, Field, None] = None + self.body_field: Field = None self.response_description = response_description self.response_code = response_code self.response_wrapper = response_wrapper @@ -430,53 +175,32 @@ class APIRoute(routing.Route): ), f"An endpoint must be a function or method" self.dependant = get_dependant(path=path, call=self.endpoint) - # flat_dependant = get_flat_dependant(self.dependant) - # path_param_names = get_path_param_names(path) - # for path_param in path_param_names: - # assert path_param in { - # f.alias for f in flat_dependant.path_params - # }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}" - - if self.dependant.body_params: - first_param = self.dependant.body_params[0] - sub_key = getattr(first_param.schema, "sub_key", None) - if len(self.dependant.body_params) == 1 and not sub_key: - self.request_body = first_param - else: - model_name = "Body_" + self.name - BodyModel = create_model(model_name) - for f in self.dependant.body_params: - BodyModel.__fields__[f.name] = f - required = any(True for f in self.dependant.body_params if f.required) - field = Field( - name="body", - type_=BodyModel, - default=None, - required=required, - model_config=BaseConfig, - class_validators=[], - alias="body", - schema=Schema(None), - ) - self.request_body = field - - self.app = request_response(get_app(dependant=self.dependant)) + self.body_field = get_body_field(dependant=self.dependant, name=self.name) + self.app = request_response( + get_app( + dependant=self.dependant, + body_field=self.body_field, + response_code=self.response_code, + response_wrapper=self.response_wrapper, + response_field=self.response_field, + ) + ) class APIRouter(routing.Router): def add_api_route( self, path: str, - endpoint: typing.Callable, - methods: typing.List[str] = None, + endpoint: Callable, + methods: List[str] = None, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -487,7 +211,7 @@ class APIRouter(routing.Router): methods=methods, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -502,27 +226,27 @@ class APIRouter(routing.Router): def api_route( self, path: str, - methods: typing.List[str] = None, + methods: List[str] = None, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, - ) -> typing.Callable: - def decorator(func: typing.Callable) -> typing.Callable: + ) -> Callable: + def decorator(func: Callable) -> Callable: self.add_api_route( path, func, methods=methods, name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -541,12 +265,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -556,7 +280,7 @@ class APIRouter(routing.Router): methods=["GET"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -572,12 +296,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -587,7 +311,7 @@ class APIRouter(routing.Router): methods=["PUT"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -603,12 +327,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -618,7 +342,7 @@ class APIRouter(routing.Router): methods=["POST"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -634,12 +358,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -649,7 +373,7 @@ class APIRouter(routing.Router): methods=["DELETE"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -665,12 +389,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -680,7 +404,7 @@ class APIRouter(routing.Router): methods=["OPTIONS"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -696,12 +420,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -711,7 +435,7 @@ class APIRouter(routing.Router): methods=["HEAD"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -727,12 +451,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -742,7 +466,7 @@ class APIRouter(routing.Router): methods=["PATCH"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, @@ -758,12 +482,12 @@ class APIRouter(routing.Router): path: str, name: str = None, include_in_schema: bool = True, - tags: typing.List[str] = [], + tags: List[str] = None, summary: str = None, description: str = None, operation_id: str = None, deprecated: bool = None, - response_type: typing.Type = None, + response_type: Type = None, response_description: str = "Successful Response", response_code=200, response_wrapper=JSONResponse, @@ -773,7 +497,7 @@ class APIRouter(routing.Router): methods=["TRACE"], name=name, include_in_schema=include_in_schema, - tags=tags, + tags=tags or [], summary=summary, description=description, operation_id=operation_id, diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 4b6766eb7..c0354fea7 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -1,11 +1,10 @@ -from starlette.requests import Request +from enum import Enum from pydantic import Schema -from enum import Enum -from .base import SecurityBase, Types -__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"] +from starlette.requests import Request +from .base import SecurityBase, Types class APIKeyIn(Enum): query = "query" @@ -21,7 +20,7 @@ class APIKeyBase(SecurityBase): class APIKeyQuery(APIKeyBase): in_ = Schema(APIKeyIn.query, alias="in") - + async def __call__(self, requests: Request): return requests.query_params.get(self.name) diff --git a/fastapi/security/base.py b/fastapi/security/base.py index 37433ff25..9ba430df9 100644 --- a/fastapi/security/base.py +++ b/fastapi/security/base.py @@ -1,7 +1,6 @@ from enum import Enum -from pydantic import BaseModel, Schema -__all__ = ["Types", "SecurityBase"] +from pydantic import BaseModel, Schema class Types(Enum): diff --git a/fastapi/security/http.py b/fastapi/security/http.py index aaaf86618..7a8bcfe48 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -1,9 +1,8 @@ +from pydantic import Schema from starlette.requests import Request -from pydantic import Schema -from .base import SecurityBase, Types -__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] +from .base import SecurityBase, Types class HTTPBase(SecurityBase): diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index a6607ef52..4febdafc2 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -3,10 +3,8 @@ from typing import Dict from pydantic import BaseModel, Schema from starlette.requests import Request -from .base import SecurityBase, Types - -# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] +from .base import SecurityBase, Types class OAuthFlow(BaseModel): refreshUrl: str = None diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index 2e7791a7a..c84c56de8 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -2,6 +2,7 @@ from starlette.requests import Request from .base import SecurityBase, Types + class OpenIdConnect(SecurityBase): type_ = Types.openIdConnect openIdConnectUrl: str diff --git a/fastapi/utils.py b/fastapi/utils.py new file mode 100644 index 000000000..091f868fe --- /dev/null +++ b/fastapi/utils.py @@ -0,0 +1,46 @@ +import re +from typing import Dict, Sequence, Set, Type + +from starlette.routing import BaseRoute + +from fastapi import routing +from fastapi.openapi.constants import REF_PREFIX +from pydantic import BaseModel +from pydantic.fields import Field +from pydantic.schema import get_flat_models_from_fields, model_process_schema + + +def get_flat_models_from_routes(routes: Sequence[BaseRoute]): + body_fields_from_routes = [] + responses_from_routes = [] + for route in routes: + if route.include_in_schema and isinstance(route, routing.APIRoute): + if route.body_field: + assert isinstance( + route.body_field, Field + ), "A request body must be a Pydantic Field" + body_fields_from_routes.append(route.body_field) + if route.response_field: + responses_from_routes.append(route.response_field) + flat_models = get_flat_models_from_fields( + body_fields_from_routes + responses_from_routes + ) + return flat_models + + +def get_model_definitions( + *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str] +): + definitions: Dict[str, Dict] = {} + for model in flat_models: + m_schema, m_definitions = model_process_schema( + model, model_name_map=model_name_map, ref_prefix=REF_PREFIX + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + definitions[model_name] = m_schema + return definitions + + +def get_path_param_names(path: str): + return {item.strip("{}") for item in re.findall("{[^}]*}", path)}