commit 406c092a3bf65bbd4405ce87611a7e0b9c0ae706 Author: Sebastián Ramírez Date: Wed Dec 5 10:56:50 2018 +0400 :tada: Start tracking messy initial stage ...before refactoring and breaking something diff --git a/fastapi/__init__.py b/fastapi/__init__.py new file mode 100644 index 000000000..a52bbccf6 --- /dev/null +++ b/fastapi/__init__.py @@ -0,0 +1,3 @@ +"""Fast API framework, fast high performance, fast to learn, fast to code""" + +__version__ = '0.1' diff --git a/fastapi/applications.py b/fastapi/applications.py new file mode 100644 index 000000000..2e1875aa1 --- /dev/null +++ b/fastapi/applications.py @@ -0,0 +1,581 @@ +import typing +import inspect + +from starlette.applications import Starlette +from starlette.middleware.lifespan import LifespanMiddleware +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.exceptions import ExceptionMiddleware +from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse +from starlette.requests import Request +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY + +from pydantic import BaseModel, BaseConfig, Schema +from pydantic.utils import lenient_issubclass +from pydantic.fields import Field +from pydantic.schema import ( + field_schema, + get_flat_models_from_models, + get_flat_models_from_fields, + get_model_name_map, + schema, + model_process_schema, +) + +from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant +from .pydantic_utils import jsonable_encoder + + +def docs(openapi_url): + return HTMLResponse( + """ + + + + + + +
+
+ + + + + + """, + media_type="text/html", + ) + + +class FastAPI(Starlette): + def __init__( + self, + debug: bool = False, + template_directory: str = None, + title: str = "Fast API", + description: str = "", + version: str = "0.1.0", + openapi_url: str = "/openapi.json", + docs_url: str = "/docs", + **extra: typing.Dict[str, typing.Any], + ) -> None: + self._debug = debug + self.router = APIRouter() + self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.error_middleware = ServerErrorMiddleware( + self.exception_middleware, debug=debug + ) + self.lifespan_middleware = LifespanMiddleware(self.error_middleware) + self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] + self.template_env = self.load_template_env(template_directory) + + self.title = title + self.description = description + self.version = version + self.openapi_url = openapi_url + self.docs_url = docs_url + self.extra = extra + + self.openapi_version = "3.0.2" + + if self.openapi_url: + assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'" + assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'" + + if self.docs_url: + assert self.openapi_url, "The openapi_url is required for the docs" + + self.add_route( + self.openapi_url, + lambda req: JSONResponse(self.openapi()), + include_in_schema=False, + ) + self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False) + + def add_api_route( + self, + path: str, + endpoint: typing.Callable, + methods: typing.List[str] = None, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ) -> None: + self.router.add_api_route( + path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def api_route( + self, + path: str, + methods: typing.List[str] = None, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: + self.router.add_api_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + return func + + return decorator + + def get( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.get( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def put( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.put( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def post( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.post( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def delete( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.delete( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def options( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.options( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def head( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.head( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def patch( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.patch( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def trace( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.router.trace( + path=path, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def openapi(self): + info = {"title": self.title, "version": self.version} + if self.description: + info["description"] = self.description + output = {"openapi": self.openapi_version, "info": info} + components = {} + paths = {} + methods_with_body = set(("POST", "PUT")) + body_fields_from_routes = [] + responses_from_routes = [] + ref_prefix = "#/components/schemas/" + for route in self.routes: + route: APIRoute + if route.include_in_schema and isinstance(route, APIRoute): + if route.request_body: + assert isinstance( + route.request_body, Field + ), "A request body must be a Pydantic BaseModel or Field" + body_fields_from_routes.append(route.request_body) + if route.response_field: + responses_from_routes.append(route.response_field) + flat_models = get_flat_models_from_fields( + body_fields_from_routes + responses_from_routes + ) + model_name_map = get_model_name_map(flat_models) + definitions = {} + for model in flat_models: + m_schema, m_definitions = model_process_schema( + model, model_name_map=model_name_map, ref_prefix=ref_prefix + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + definitions[model_name] = m_schema + validation_error_definition = { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], + } + validation_error_response_definition = { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": ref_prefix + "ValidationError"}, + } + }, + } + for route in self.routes: + route: APIRoute + if route.include_in_schema and isinstance(route, APIRoute): + path = paths.get(route.path, {}) + for method in route.methods: + operation = {} + if route.tags: + operation["tags"] = route.tags + if route.summary: + operation["summary"] = route.summary + if route.description: + operation["description"] = route.description + if route.operation_id: + operation["operationId"] = route.operation_id + else: + operation["operationId"] = route.name + if route.deprecated: + operation["deprecated"] = route.deprecated + parameters = [] + flat_dependant = get_flat_dependant(route.dependant) + security_definitions = {} + for security_scheme in flat_dependant.security_schemes: + security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False) + security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__ + security_definitions[security_name] = security_definition + if security_definitions: + components.setdefault("securitySchemes", {}).update(security_definitions) + operation["security"] = [{name: []} for name in security_definitions] + all_route_params = get_openapi_params(route.dependant) + for param in all_route_params: + if "ValidationError" not in definitions: + definitions["ValidationError"] = validation_error_definition + definitions[ + "HTTPValidationError" + ] = validation_error_response_definition + parameter = { + "name": param.alias, + "in": param.schema.in_.value, + "required": param.required, + "schema": field_schema(param, model_name_map={})[0], + } + if param.schema.description: + parameter["description"] = param.schema.description + if param.schema.deprecated: + parameter["deprecated"] = param.schema.deprecated + parameters.append(parameter) + if parameters: + operation["parameters"] = parameters + if method in methods_with_body: + request_body = getattr(route, "request_body", None) + if request_body: + assert isinstance(request_body, Field) + body_schema, _ = field_schema( + request_body, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ) + required = request_body.required + request_body_oai = {} + if required: + request_body_oai["required"] = required + request_body_oai["content"] = { + "application/json": {"schema": body_schema} + } + operation["requestBody"] = request_body_oai + response_code = str(route.response_code) + response_schema = {"type": "string"} + if lenient_issubclass(route.response_wrapper, JSONResponse): + response_media_type = "application/json" + if route.response_field: + response_schema, _ = field_schema( + route.response_field, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ) + else: + response_schema = {} + elif lenient_issubclass(route.response_wrapper, HTMLResponse): + response_media_type = "text/html" + else: + response_media_type = "text/plain" + content = {response_media_type: {"schema": response_schema}} + operation["responses"] = { + response_code: { + "description": route.response_description, + "content": content, + } + } + if all_route_params or getattr(route, "request_body", None): + operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": ref_prefix + "HTTPValidationError" + } + } + }, + } + path[method.lower()] = operation + paths[route.path] = path + if definitions: + components.setdefault("schemas", {}).update(definitions) + if components: + output["components"] = components + output["paths"] = paths + return output diff --git a/fastapi/params.py b/fastapi/params.py new file mode 100644 index 000000000..98b80943c --- /dev/null +++ b/fastapi/params.py @@ -0,0 +1,246 @@ +from typing import Sequence +from enum import Enum +from pydantic import Schema + + +class ParamTypes(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +class Param(Schema): + in_: ParamTypes + def __init__( + self, + default, + *, + deprecated: bool = None, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.deprecated = deprecated + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Path(Param): + in_ = ParamTypes.path + + def __init__( + self, + default, + *, + deprecated: bool = None, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.description = description + self.deprecated = deprecated + self.in_ = self.in_ + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Query(Param): + in_ = ParamTypes.query + + def __init__( + self, + default, + *, + deprecated: bool = None, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.description = description + self.deprecated = deprecated + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Header(Param): + in_ = ParamTypes.header + + def __init__( + self, + default, + *, + deprecated: bool = None, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.description = description + self.deprecated = deprecated + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Cookie(Param): + in_ = ParamTypes.cookie + + def __init__( + self, + default, + *, + deprecated: bool = None, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.description = description + self.deprecated = deprecated + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Body(Schema): + def __init__( + self, + default, + *, + sub_key=False, + alias: str = None, + title: str = None, + description: str = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + min_length: int = None, + max_length: int = None, + regex: str = None, + **extra: object, + ): + self.sub_key = sub_key + super().__init__( + default, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + **extra, + ) + + +class Depends: + def __init__(self, dependency = None): + self.dependency = dependency + + +class Security: + def __init__(self, security_scheme = None, scopes: Sequence[str] = None): + self.security_scheme = security_scheme + self.scopes = scopes + diff --git a/fastapi/pydantic_utils.py b/fastapi/pydantic_utils.py new file mode 100644 index 000000000..8fc6589a4 --- /dev/null +++ b/fastapi/pydantic_utils.py @@ -0,0 +1,33 @@ +from types import GeneratorType +from typing import Set +from pydantic import BaseModel +from enum import Enum +from pydantic.json import pydantic_encoder + + +def jsonable_encoder( + obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True, +): + if isinstance(obj, BaseModel): + return jsonable_encoder( + obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none + ) + elif isinstance(obj, Enum): + return obj.value + if isinstance(obj, (str, int, float, type(None))): + return obj + if isinstance(obj, dict): + return { + jsonable_encoder( + key, by_alias=by_alias, include_none=include_none, + ): jsonable_encoder( + value, by_alias=by_alias, include_none=include_none, + ) + for key, value in obj.items() if value is not None or include_none + } + if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): + return [ + jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none) + for item in obj + ] + return pydantic_encoder(obj) diff --git a/fastapi/routing.py b/fastapi/routing.py new file mode 100644 index 000000000..8c95b3327 --- /dev/null +++ b/fastapi/routing.py @@ -0,0 +1,785 @@ +import asyncio +import inspect +import re +import typing +from copy import deepcopy + +from starlette import routing +from starlette.routing import get_name, request_response +from starlette.requests import Request +from starlette.responses import Response, JSONResponse +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY + + +from pydantic.fields import Field, Required +from pydantic.schema import get_annotation_from_schema +from pydantic import BaseConfig, BaseModel, create_model, Schema +from pydantic.error_wrappers import ErrorWrapper, ValidationError +from pydantic.errors import MissingError +from pydantic.utils import lenient_issubclass +from .pydantic_utils import jsonable_encoder + +from fastapi import params +from fastapi.security.base import SecurityBase + + +param_supported_types = (str, int, float, bool) + + +class Dependant: + def __init__( + self, + *, + path_params: typing.List[Field] = None, + query_params: typing.List[Field] = None, + header_params: typing.List[Field] = None, + cookie_params: typing.List[Field] = None, + body_params: typing.List[Field] = None, + dependencies: typing.List["Dependant"] = None, + security_schemes: typing.List[Field] = None, + name: str = None, + call: typing.Callable = None, + request_param_name: str = None, + ) -> None: + self.path_params: typing.List[Field] = path_params or [] + self.query_params: typing.List[Field] = query_params or [] + self.header_params: typing.List[Field] = header_params or [] + self.cookie_params: typing.List[Field] = cookie_params or [] + self.body_params: typing.List[Field] = body_params or [] + self.dependencies: typing.List[Dependant] = dependencies or [] + self.security_schemes: typing.List[Field] = security_schemes or [] + self.request_param_name = request_param_name + self.name = name + self.call: typing.Callable = call + + +def request_params_to_args( + required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any] +) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: + values = {} + errors = [] + for field in required_params: + value = received_params.get(field.alias) + if value is None: + if field.required: + errors.append( + ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig) + ) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate( + value, values, loc=(field.schema.in_.value, field.alias) + ) + if isinstance(errors_, ErrorWrapper): + errors_: ErrorWrapper + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +def request_body_to_args( + required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any] +) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]: + values = {} + errors = [] + if required_params: + field = required_params[0] + sub_key = getattr(field.schema, "sub_key", None) + if len(required_params) == 1 and not sub_key: + received_body = {field.alias: received_body} + for field in required_params: + value = received_body.get(field.alias) + if value is None: + if field.required: + errors.append( + ErrorWrapper( + MissingError(), loc=("body", field.alias), config=BaseConfig + ) + ) + else: + values[field.name] = deepcopy(field.default) + continue + + v_, errors_ = field.validate(value, values, loc=("body", field.alias)) + if isinstance(errors_, ErrorWrapper): + errors_: ErrorWrapper + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[field.name] = v_ + return values, errors + + +def add_param_to_fields( + *, + param: inspect.Parameter, + dependant: Dependant, + default_schema=params.Param, + force_type: params.ParamTypes = None, +): + default_value = Required + if not param.default == param.empty: + default_value = param.default + if isinstance(default_value, params.Param): + schema = default_value + default_value = schema.default + if schema.in_ is None: + schema.in_ = default_schema.in_ + if force_type: + schema.in_ = force_type + else: + schema = default_schema(default_value) + required = default_value == Required + annotation = typing.Any + if not param.annotation == param.empty: + annotation = param.annotation + annotation = get_annotation_from_schema(annotation, schema) + Config = BaseConfig + field = Field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=schema.alias or param.name, + required=required, + model_config=Config, + class_validators=[], + schema=schema, + ) + if schema.in_ == params.ParamTypes.path: + dependant.path_params.append(field) + elif schema.in_ == params.ParamTypes.query: + dependant.query_params.append(field) + elif schema.in_ == params.ParamTypes.header: + dependant.header_params.append(field) + else: + assert ( + schema.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {param.name}" + dependant.cookie_params.append(field) + + +def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant): + default_value = Required + if not param.default == param.empty: + default_value = param.default + if isinstance(default_value, Schema): + schema = default_value + default_value = schema.default + else: + schema = Schema(default_value) + required = default_value == Required + annotation = get_annotation_from_schema(param.annotation, schema) + field = Field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=schema.alias or param.name, + required=required, + model_config=BaseConfig, + class_validators=[], + schema=schema, + ) + dependant.body_params.append(field) + + +def get_sub_dependant( + *, param: inspect.Parameter, path: str +): + depends: params.Depends = param.default + if depends.dependency: + dependency = depends.dependency + else: + dependency = param.annotation + assert callable(dependency) + sub_dependant = get_dependant(path=path, call=dependency, name=param.name) + if isinstance(dependency, SecurityBase): + sub_dependant.security_schemes.append(dependency) + return sub_dependant + + +def get_flat_dependant(dependant: Dependant): + flat_dependant = Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + security_schemes=dependant.security_schemes.copy(), + ) + for sub_dependant in dependant.dependencies: + if sub_dependant is dependant: + raise ValueError("recursion", dependant.dependencies) + flat_sub = get_flat_dependant(sub_dependant) + flat_dependant.path_params.extend(flat_sub.path_params) + flat_dependant.query_params.extend(flat_sub.query_params) + flat_dependant.header_params.extend(flat_sub.header_params) + flat_dependant.cookie_params.extend(flat_sub.cookie_params) + flat_dependant.body_params.extend(flat_sub.body_params) + flat_dependant.security_schemes.extend(flat_sub.security_schemes) + return flat_dependant + + +def get_path_param_names(path: str): + return {item.strip("{}") for item in re.findall("{[^}]*}", path)} + + +def get_dependant(*, path: str, call: typing.Callable, name: str = None): + path_param_names = get_path_param_names(path) + endpoint_signature = inspect.signature(call) + signature_params = endpoint_signature.parameters + dependant = Dependant(call=call, name=name) + for param_name in signature_params: + param = signature_params[param_name] + if isinstance(param.default, params.Depends): + sub_dependant = get_sub_dependant(param=param, path=path) + dependant.dependencies.append(sub_dependant) + for param_name in signature_params: + param = signature_params[param_name] + if ( + (param.default == param.empty) or isinstance(param.default, params.Path) + ) and (param_name in path_param_names): + assert lenient_issubclass( + param.annotation, param_supported_types + ), f"Path params must be of type str, int, float or boot: {param}" + param = signature_params[param_name] + add_param_to_fields( + param=param, + dependant=dependant, + default_schema=params.Path, + force_type=params.ParamTypes.path, + ) + elif (param.default == param.empty or param.default is None) and ( + param.annotation == param.empty + or lenient_issubclass(param.annotation, param_supported_types) + ): + add_param_to_fields( + param=param, dependant=dependant, default_schema=params.Query + ) + elif isinstance(param.default, params.Param): + if param.annotation != param.empty: + assert lenient_issubclass( + param.annotation, param_supported_types + ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}" + add_param_to_fields( + param=param, dependant=dependant, default_schema=params.Query + ) + elif lenient_issubclass(param.annotation, Request): + dependant.request_param_name = param_name + elif not isinstance(param.default, params.Depends): + add_param_to_body_fields(param=param, dependant=dependant) + return dependant + + +def is_coroutine_callable(call: typing.Callable): + if inspect.isfunction(call): + return asyncio.iscoroutinefunction(call) + elif inspect.isclass(call): + return False + else: + call = getattr(call, "__call__", None) + if not call: + return False + else: + return asyncio.iscoroutinefunction(call) + + +async def solve_dependencies(*, request: Request, dependant: Dependant): + values = {} + errors = [] + for sub_dependant in dependant.dependencies: + sub_values, sub_errors = await solve_dependencies( + request=request, dependant=sub_dependant + ) + if sub_errors: + return {}, errors + if is_coroutine_callable(sub_dependant.call): + solved = await sub_dependant.call(**sub_values) + else: + solved = await run_in_threadpool(sub_dependant.call, **sub_values) + values[sub_dependant.name] = solved + path_values, path_errors = request_params_to_args( + dependant.path_params, request.path_params + ) + query_values, query_errors = request_params_to_args( + dependant.query_params, request.query_params + ) + header_values, header_errors = request_params_to_args( + dependant.header_params, request.headers + ) + cookie_values, cookie_errors = request_params_to_args( + dependant.cookie_params, request.cookies + ) + values.update(path_values) + values.update(query_values) + values.update(header_values) + values.update(cookie_values) + errors = path_errors + query_errors + header_errors + cookie_errors + if dependant.body_params: + body = await request.json() + body_values, body_errors = request_body_to_args(dependant.body_params, body) + values.update(body_values) + errors.extend(body_errors) + if dependant.request_param_name: + values[dependant.request_param_name] = request + return values, errors + + +def get_app(dependant: Dependant): + is_coroutine = asyncio.iscoroutinefunction(dependant.call) + + async def app(request: Request) -> Response: + values, errors = await solve_dependencies(request=request, dependant=dependant) + if errors: + errors_out = ValidationError(errors) + raise HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors() + ) + else: + if is_coroutine: + raw_response = await dependant.call(**values) + else: + raw_response = await run_in_threadpool(dependant.call, **values) + if isinstance(raw_response, Response): + return raw_response + else: + return JSONResponse(content=jsonable_encoder(raw_response)) + return app + + +def get_openapi_params(dependant: Dependant): + flat_dependant = get_flat_dependant(dependant) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + +class APIRoute(routing.Route): + def __init__( + self, + path: str, + endpoint: typing.Callable, + *, + methods: typing.List[str] = None, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ) -> None: + # TODO define how to read and provide security params, and how to have them globally too + # TODO implement dependencies and injection + # TODO refactor code structure + # TODO create testing + # TODO testing coverage + assert path.startswith("/"), "Routed paths must always start '/'" + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + self.include_in_schema = include_in_schema + self.tags = tags + self.summary = summary + self.description = description + self.operation_id = operation_id + self.deprecated = deprecated + self.request_body: typing.Union[BaseModel, Field, None] = None + self.response_description = response_description + self.response_code = response_code + self.response_wrapper = response_wrapper + self.response_field = None + if response_type: + assert lenient_issubclass( + response_wrapper, JSONResponse + ), "To declare a type the response must be a JSON response" + self.response_type = response_type + response_name = "Response_" + self.name + self.response_field = Field( + name=response_name, + type_=self.response_type, + class_validators=[], + default=None, + required=False, + model_config=BaseConfig(), + schema=Schema(None), + ) + else: + self.response_type = None + if methods is None: + methods = ["GET"] + self.methods = methods + self.path_regex, self.path_format, self.param_convertors = self.compile_path( + path + ) + assert inspect.isfunction(endpoint) or inspect.ismethod( + endpoint + ), f"An endpoint must be a function or method" + + self.dependant = get_dependant(path=path, call=self.endpoint) + # flat_dependant = get_flat_dependant(self.dependant) + # path_param_names = get_path_param_names(path) + # for path_param in path_param_names: + # assert path_param in { + # f.alias for f in flat_dependant.path_params + # }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}" + + if self.dependant.body_params: + first_param = self.dependant.body_params[0] + sub_key = getattr(first_param.schema, "sub_key", None) + if len(self.dependant.body_params) == 1 and not sub_key: + self.request_body = first_param + else: + model_name = "Body_" + self.name + BodyModel = create_model(model_name) + for f in self.dependant.body_params: + BodyModel.__fields__[f.name] = f + required = any(True for f in self.dependant.body_params if f.required) + field = Field( + name="body", + type_=BodyModel, + default=None, + required=required, + model_config=BaseConfig, + class_validators=[], + alias="body", + schema=Schema(None), + ) + self.request_body = field + + self.app = request_response(get_app(dependant=self.dependant)) + + +class APIRouter(routing.Router): + def add_api_route( + self, + path: str, + endpoint: typing.Callable, + methods: typing.List[str] = None, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ) -> None: + route = APIRoute( + path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + self.routes.append(route) + + def api_route( + self, + path: str, + methods: typing.List[str] = None, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: + self.add_api_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + return func + + return decorator + + def get( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["GET"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def put( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["PUT"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def post( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["POST"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def delete( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["DELETE"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def options( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["OPTIONS"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def head( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["HEAD"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def patch( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["PATCH"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) + + def trace( + self, + path: str, + name: str = None, + include_in_schema: bool = True, + tags: typing.List[str] = [], + summary: str = None, + description: str = None, + operation_id: str = None, + deprecated: bool = None, + response_type: typing.Type = None, + response_description: str = "Successful Response", + response_code=200, + response_wrapper=JSONResponse, + ): + return self.api_route( + path=path, + methods=["TRACE"], + name=name, + include_in_schema=include_in_schema, + tags=tags, + summary=summary, + description=description, + operation_id=operation_id, + deprecated=deprecated, + response_type=response_type, + response_description=response_description, + response_code=response_code, + response_wrapper=response_wrapper, + ) diff --git a/fastapi/security/__init__.py b/fastapi/security/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py new file mode 100644 index 000000000..4b6766eb7 --- /dev/null +++ b/fastapi/security/api_key.py @@ -0,0 +1,40 @@ +from starlette.requests import Request + +from pydantic import Schema +from enum import Enum +from .base import SecurityBase, Types + +__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"] + + +class APIKeyIn(Enum): + query = "query" + header = "header" + cookie = "cookie" + + +class APIKeyBase(SecurityBase): + type_ = Schema(Types.apiKey, alias="type") + in_: str = Schema(..., alias="in") + name: str + + +class APIKeyQuery(APIKeyBase): + in_ = Schema(APIKeyIn.query, alias="in") + + async def __call__(self, requests: Request): + return requests.query_params.get(self.name) + + +class APIKeyHeader(APIKeyBase): + in_ = Schema(APIKeyIn.header, alias="in") + + async def __call__(self, requests: Request): + return requests.headers.get(self.name) + + +class APIKeyCookie(APIKeyBase): + in_ = Schema(APIKeyIn.cookie, alias="in") + + async def __call__(self, requests: Request): + return requests.cookies.get(self.name) diff --git a/fastapi/security/base.py b/fastapi/security/base.py new file mode 100644 index 000000000..37433ff25 --- /dev/null +++ b/fastapi/security/base.py @@ -0,0 +1,17 @@ +from enum import Enum +from pydantic import BaseModel, Schema + +__all__ = ["Types", "SecurityBase"] + + +class Types(Enum): + apiKey = "apiKey" + http = "http" + oauth2 = "oauth2" + openIdConnect = "openIdConnect" + + +class SecurityBase(BaseModel): + scheme_name: str = None + type_: Types = Schema(..., alias="type") + description: str = None diff --git a/fastapi/security/http.py b/fastapi/security/http.py new file mode 100644 index 000000000..aaaf86618 --- /dev/null +++ b/fastapi/security/http.py @@ -0,0 +1,27 @@ + +from starlette.requests import Request +from pydantic import Schema +from .base import SecurityBase, Types + +__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] + + +class HTTPBase(SecurityBase): + type_ = Schema(Types.http, alias="type") + scheme: str + + async def __call__(self, request: Request): + return request.headers.get("Authorization") + + +class HTTPBasic(HTTPBase): + scheme = "basic" + + +class HTTPBearer(HTTPBase): + scheme = "bearer" + bearerFormat: str = None + + +class HTTPDigest(HTTPBase): + scheme = "digest" diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py new file mode 100644 index 000000000..a6607ef52 --- /dev/null +++ b/fastapi/security/oauth2.py @@ -0,0 +1,45 @@ +from typing import Dict + +from pydantic import BaseModel, Schema + +from starlette.requests import Request +from .base import SecurityBase, Types + +# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"] + + +class OAuthFlow(BaseModel): + refreshUrl: str = None + scopes: Dict[str, str] = {} + + +class OAuthFlowImplicit(OAuthFlow): + authorizationUrl: str + + +class OAuthFlowPassword(OAuthFlow): + tokenUrl: str + + +class OAuthFlowClientCredentials(OAuthFlow): + tokenUrl: str + + +class OAuthFlowAuthorizationCode(OAuthFlow): + authorizationUrl: str + tokenUrl: str + + +class OAuthFlows(BaseModel): + implicit: OAuthFlowImplicit = None + password: OAuthFlowPassword = None + clientCredentials: OAuthFlowClientCredentials = None + authorizationCode: OAuthFlowAuthorizationCode = None + + +class OAuth2(SecurityBase): + type_ = Schema(Types.oauth2, alias="type") + flows: OAuthFlows + + async def __call__(self, request: Request): + return request.headers.get("Authorization") diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py new file mode 100644 index 000000000..2e7791a7a --- /dev/null +++ b/fastapi/security/open_id_connect_url.py @@ -0,0 +1,10 @@ +from starlette.requests import Request + +from .base import SecurityBase, Types + +class OpenIdConnect(SecurityBase): + type_ = Types.openIdConnect + openIdConnectUrl: str + + async def __call__(self, request: Request): + return request.headers.get("Authorization")