Browse Source

🎉 Start tracking messy initial stage

...before refactoring and breaking something
pull/1/head
Sebastián Ramírez 6 years ago
commit
406c092a3b
  1. 3
      fastapi/__init__.py
  2. 581
      fastapi/applications.py
  3. 246
      fastapi/params.py
  4. 33
      fastapi/pydantic_utils.py
  5. 785
      fastapi/routing.py
  6. 0
      fastapi/security/__init__.py
  7. 40
      fastapi/security/api_key.py
  8. 17
      fastapi/security/base.py
  9. 27
      fastapi/security/http.py
  10. 45
      fastapi/security/oauth2.py
  11. 10
      fastapi/security/open_id_connect_url.py

3
fastapi/__init__.py

@ -0,0 +1,3 @@
"""Fast API framework, fast high performance, fast to learn, fast to code"""
__version__ = '0.1'

581
fastapi/applications.py

@ -0,0 +1,581 @@
import typing
import inspect
from starlette.applications import Starlette
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.exceptions import ExceptionMiddleware
from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
from starlette.requests import Request
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic import BaseModel, BaseConfig, Schema
from pydantic.utils import lenient_issubclass
from pydantic.fields import Field
from pydantic.schema import (
field_schema,
get_flat_models_from_models,
get_flat_models_from_fields,
get_model_name_map,
schema,
model_process_schema,
)
from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
from .pydantic_utils import jsonable_encoder
def docs(openapi_url):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '""" + openapi_url + """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
class FastAPI(Starlette):
def __init__(
self,
debug: bool = False,
template_directory: str = None,
title: str = "Fast API",
description: str = "",
version: str = "0.1.0",
openapi_url: str = "/openapi.json",
docs_url: str = "/docs",
**extra: typing.Dict[str, typing.Any],
) -> None:
self._debug = debug
self.router = APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
self.template_env = self.load_template_env(template_directory)
self.title = title
self.description = description
self.version = version
self.openapi_url = openapi_url
self.docs_url = docs_url
self.extra = extra
self.openapi_version = "3.0.2"
if self.openapi_url:
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
if self.docs_url:
assert self.openapi_url, "The openapi_url is required for the docs"
self.add_route(
self.openapi_url,
lambda req: JSONResponse(self.openapi()),
include_in_schema=False,
)
self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
def add_api_route(
self,
path: str,
endpoint: typing.Callable,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> None:
self.router.add_api_route(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def api_route(
self,
path: str,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.router.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
return func
return decorator
def get(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.get(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def put(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.put(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def post(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.post(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def delete(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.delete(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def options(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.options(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def head(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.head(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def patch(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.patch(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def trace(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.router.trace(
path=path,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def openapi(self):
info = {"title": self.title, "version": self.version}
if self.description:
info["description"] = self.description
output = {"openapi": self.openapi_version, "info": info}
components = {}
paths = {}
methods_with_body = set(("POST", "PUT"))
body_fields_from_routes = []
responses_from_routes = []
ref_prefix = "#/components/schemas/"
for route in self.routes:
route: APIRoute
if route.include_in_schema and isinstance(route, APIRoute):
if route.request_body:
assert isinstance(
route.request_body, Field
), "A request body must be a Pydantic BaseModel or Field"
body_fields_from_routes.append(route.request_body)
if route.response_field:
responses_from_routes.append(route.response_field)
flat_models = get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes
)
model_name_map = get_model_name_map(flat_models)
definitions = {}
for model in flat_models:
m_schema, m_definitions = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=ref_prefix
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": ref_prefix + "ValidationError"},
}
},
}
for route in self.routes:
route: APIRoute
if route.include_in_schema and isinstance(route, APIRoute):
path = paths.get(route.path, {})
for method in route.methods:
operation = {}
if route.tags:
operation["tags"] = route.tags
if route.summary:
operation["summary"] = route.summary
if route.description:
operation["description"] = route.description
if route.operation_id:
operation["operationId"] = route.operation_id
else:
operation["operationId"] = route.name
if route.deprecated:
operation["deprecated"] = route.deprecated
parameters = []
flat_dependant = get_flat_dependant(route.dependant)
security_definitions = {}
for security_scheme in flat_dependant.security_schemes:
security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
security_definitions[security_name] = security_definition
if security_definitions:
components.setdefault("securitySchemes", {}).update(security_definitions)
operation["security"] = [{name: []} for name in security_definitions]
all_route_params = get_openapi_params(route.dependant)
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions[
"HTTPValidationError"
] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
if parameters:
operation["parameters"] = parameters
if method in methods_with_body:
request_body = getattr(route, "request_body", None)
if request_body:
assert isinstance(request_body, Field)
body_schema, _ = field_schema(
request_body,
model_name_map=model_name_map,
ref_prefix=ref_prefix,
)
required = request_body.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {
"application/json": {"schema": body_schema}
}
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
response_schema = {"type": "string"}
if lenient_issubclass(route.response_wrapper, JSONResponse):
response_media_type = "application/json"
if route.response_field:
response_schema, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=ref_prefix,
)
else:
response_schema = {}
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
response_media_type = "text/html"
else:
response_media_type = "text/plain"
content = {response_media_type: {"schema": response_schema}}
operation["responses"] = {
response_code: {
"description": route.response_description,
"content": content,
}
}
if all_route_params or getattr(route, "request_body", None):
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": ref_prefix + "HTTPValidationError"
}
}
},
}
path[method.lower()] = operation
paths[route.path] = path
if definitions:
components.setdefault("schemas", {}).update(definitions)
if components:
output["components"] = components
output["paths"] = paths
return output

246
fastapi/params.py

@ -0,0 +1,246 @@
from typing import Sequence
from enum import Enum
from pydantic import Schema
class ParamTypes(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Param(Schema):
in_: ParamTypes
def __init__(
self,
default,
*,
deprecated: bool = None,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Path(Param):
in_ = ParamTypes.path
def __init__(
self,
default,
*,
deprecated: bool = None,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.description = description
self.deprecated = deprecated
self.in_ = self.in_
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Query(Param):
in_ = ParamTypes.query
def __init__(
self,
default,
*,
deprecated: bool = None,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.description = description
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Header(Param):
in_ = ParamTypes.header
def __init__(
self,
default,
*,
deprecated: bool = None,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.description = description
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Cookie(Param):
in_ = ParamTypes.cookie
def __init__(
self,
default,
*,
deprecated: bool = None,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.description = description
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Body(Schema):
def __init__(
self,
default,
*,
sub_key=False,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: object,
):
self.sub_key = sub_key
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Depends:
def __init__(self, dependency = None):
self.dependency = dependency
class Security:
def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
self.security_scheme = security_scheme
self.scopes = scopes

33
fastapi/pydantic_utils.py

@ -0,0 +1,33 @@
from types import GeneratorType
from typing import Set
from pydantic import BaseModel
from enum import Enum
from pydantic.json import pydantic_encoder
def jsonable_encoder(
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
):
if isinstance(obj, BaseModel):
return jsonable_encoder(
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
)
elif isinstance(obj, Enum):
return obj.value
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
return {
jsonable_encoder(
key, by_alias=by_alias, include_none=include_none,
): jsonable_encoder(
value, by_alias=by_alias, include_none=include_none,
)
for key, value in obj.items() if value is not None or include_none
}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
for item in obj
]
return pydantic_encoder(obj)

785
fastapi/routing.py

@ -0,0 +1,785 @@
import asyncio
import inspect
import re
import typing
from copy import deepcopy
from starlette import routing
from starlette.routing import get_name, request_response
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
from pydantic import BaseConfig, BaseModel, create_model, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.errors import MissingError
from pydantic.utils import lenient_issubclass
from .pydantic_utils import jsonable_encoder
from fastapi import params
from fastapi.security.base import SecurityBase
param_supported_types = (str, int, float, bool)
class Dependant:
def __init__(
self,
*,
path_params: typing.List[Field] = None,
query_params: typing.List[Field] = None,
header_params: typing.List[Field] = None,
cookie_params: typing.List[Field] = None,
body_params: typing.List[Field] = None,
dependencies: typing.List["Dependant"] = None,
security_schemes: typing.List[Field] = None,
name: str = None,
call: typing.Callable = None,
request_param_name: str = None,
) -> None:
self.path_params: typing.List[Field] = path_params or []
self.query_params: typing.List[Field] = query_params or []
self.header_params: typing.List[Field] = header_params or []
self.cookie_params: typing.List[Field] = cookie_params or []
self.body_params: typing.List[Field] = body_params or []
self.dependencies: typing.List[Dependant] = dependencies or []
self.security_schemes: typing.List[Field] = security_schemes or []
self.request_param_name = request_param_name
self.name = name
self.call: typing.Callable = call
def request_params_to_args(
required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
values = {}
errors = []
for field in required_params:
value = received_params.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(
value, values, loc=(field.schema.in_.value, field.alias)
)
if isinstance(errors_, ErrorWrapper):
errors_: ErrorWrapper
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def request_body_to_args(
required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
values = {}
errors = []
if required_params:
field = required_params[0]
sub_key = getattr(field.schema, "sub_key", None)
if len(required_params) == 1 and not sub_key:
received_body = {field.alias: received_body}
for field in required_params:
value = received_body.get(field.alias)
if value is None:
if field.required:
errors.append(
ErrorWrapper(
MissingError(), loc=("body", field.alias), config=BaseConfig
)
)
else:
values[field.name] = deepcopy(field.default)
continue
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
if isinstance(errors_, ErrorWrapper):
errors_: ErrorWrapper
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def add_param_to_fields(
*,
param: inspect.Parameter,
dependant: Dependant,
default_schema=params.Param,
force_type: params.ParamTypes = None,
):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, params.Param):
schema = default_value
default_value = schema.default
if schema.in_ is None:
schema.in_ = default_schema.in_
if force_type:
schema.in_ = force_type
else:
schema = default_schema(default_value)
required = default_value == Required
annotation = typing.Any
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema)
Config = BaseConfig
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=Config,
class_validators=[],
schema=schema,
)
if schema.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif schema.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif schema.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
schema.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, Schema):
schema = default_value
default_value = schema.default
else:
schema = Schema(default_value)
required = default_value == Required
annotation = get_annotation_from_schema(param.annotation, schema)
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=BaseConfig,
class_validators=[],
schema=schema,
)
dependant.body_params.append(field)
def get_sub_dependant(
*, param: inspect.Parameter, path: str
):
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
else:
dependency = param.annotation
assert callable(dependency)
sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
if isinstance(dependency, SecurityBase):
sub_dependant.security_schemes.append(dependency)
return sub_dependant
def get_flat_dependant(dependant: Dependant):
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_schemes.copy(),
)
for sub_dependant in dependant.dependencies:
if sub_dependant is dependant:
raise ValueError("recursion", dependant.dependencies)
flat_sub = get_flat_dependant(sub_dependant)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_schemes.extend(flat_sub.security_schemes)
return flat_dependant
def get_path_param_names(path: str):
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
def get_dependant(*, path: str, call: typing.Callable, name: str = None):
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name)
for param_name in signature_params:
param = signature_params[param_name]
if isinstance(param.default, params.Depends):
sub_dependant = get_sub_dependant(param=param, path=path)
dependant.dependencies.append(sub_dependant)
for param_name in signature_params:
param = signature_params[param_name]
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name]
add_param_to_fields(
param=param,
dependant=dependant,
default_schema=params.Path,
force_type=params.ParamTypes.path,
)
elif (param.default == param.empty or param.default is None) and (
param.annotation == param.empty
or lenient_issubclass(param.annotation, param_supported_types)
):
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif isinstance(param.default, params.Param):
if param.annotation != param.empty:
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
elif not isinstance(param.default, params.Depends):
add_param_to_body_fields(param=param, dependant=dependant)
return dependant
def is_coroutine_callable(call: typing.Callable):
if inspect.isfunction(call):
return asyncio.iscoroutinefunction(call)
elif inspect.isclass(call):
return False
else:
call = getattr(call, "__call__", None)
if not call:
return False
else:
return asyncio.iscoroutinefunction(call)
async def solve_dependencies(*, request: Request, dependant: Dependant):
values = {}
errors = []
for sub_dependant in dependant.dependencies:
sub_values, sub_errors = await solve_dependencies(
request=request, dependant=sub_dependant
)
if sub_errors:
return {}, errors
if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
values[sub_dependant.name] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors = path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
body = await request.json()
body_values, body_errors = request_body_to_args(dependant.body_params, body)
values.update(body_values)
errors.extend(body_errors)
if dependant.request_param_name:
values[dependant.request_param_name] = request
return values, errors
def get_app(dependant: Dependant):
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
async def app(request: Request) -> Response:
values, errors = await solve_dependencies(request=request, dependant=dependant)
if errors:
errors_out = ValidationError(errors)
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
)
else:
if is_coroutine:
raw_response = await dependant.call(**values)
else:
raw_response = await run_in_threadpool(dependant.call, **values)
if isinstance(raw_response, Response):
return raw_response
else:
return JSONResponse(content=jsonable_encoder(raw_response))
return app
def get_openapi_params(dependant: Dependant):
flat_dependant = get_flat_dependant(dependant)
return (
flat_dependant.path_params
+ flat_dependant.query_params
+ flat_dependant.header_params
+ flat_dependant.cookie_params
)
class APIRoute(routing.Route):
def __init__(
self,
path: str,
endpoint: typing.Callable,
*,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> None:
# TODO define how to read and provide security params, and how to have them globally too
# TODO implement dependencies and injection
# TODO refactor code structure
# TODO create testing
# TODO testing coverage
assert path.startswith("/"), "Routed paths must always start '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
self.tags = tags
self.summary = summary
self.description = description
self.operation_id = operation_id
self.deprecated = deprecated
self.request_body: typing.Union[BaseModel, Field, None] = None
self.response_description = response_description
self.response_code = response_code
self.response_wrapper = response_wrapper
self.response_field = None
if response_type:
assert lenient_issubclass(
response_wrapper, JSONResponse
), "To declare a type the response must be a JSON response"
self.response_type = response_type
response_name = "Response_" + self.name
self.response_field = Field(
name=response_name,
type_=self.response_type,
class_validators=[],
default=None,
required=False,
model_config=BaseConfig(),
schema=Schema(None),
)
else:
self.response_type = None
if methods is None:
methods = ["GET"]
self.methods = methods
self.path_regex, self.path_format, self.param_convertors = self.compile_path(
path
)
assert inspect.isfunction(endpoint) or inspect.ismethod(
endpoint
), f"An endpoint must be a function or method"
self.dependant = get_dependant(path=path, call=self.endpoint)
# flat_dependant = get_flat_dependant(self.dependant)
# path_param_names = get_path_param_names(path)
# for path_param in path_param_names:
# assert path_param in {
# f.alias for f in flat_dependant.path_params
# }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
if self.dependant.body_params:
first_param = self.dependant.body_params[0]
sub_key = getattr(first_param.schema, "sub_key", None)
if len(self.dependant.body_params) == 1 and not sub_key:
self.request_body = first_param
else:
model_name = "Body_" + self.name
BodyModel = create_model(model_name)
for f in self.dependant.body_params:
BodyModel.__fields__[f.name] = f
required = any(True for f in self.dependant.body_params if f.required)
field = Field(
name="body",
type_=BodyModel,
default=None,
required=required,
model_config=BaseConfig,
class_validators=[],
alias="body",
schema=Schema(None),
)
self.request_body = field
self.app = request_response(get_app(dependant=self.dependant))
class APIRouter(routing.Router):
def add_api_route(
self,
path: str,
endpoint: typing.Callable,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> None:
route = APIRoute(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
self.routes.append(route)
def api_route(
self,
path: str,
methods: typing.List[str] = None,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.add_api_route(
path,
func,
methods=methods,
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
return func
return decorator
def get(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["GET"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def put(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["PUT"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def post(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["POST"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def delete(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["DELETE"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def options(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["OPTIONS"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def head(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["HEAD"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def patch(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["PATCH"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)
def trace(
self,
path: str,
name: str = None,
include_in_schema: bool = True,
tags: typing.List[str] = [],
summary: str = None,
description: str = None,
operation_id: str = None,
deprecated: bool = None,
response_type: typing.Type = None,
response_description: str = "Successful Response",
response_code=200,
response_wrapper=JSONResponse,
):
return self.api_route(
path=path,
methods=["TRACE"],
name=name,
include_in_schema=include_in_schema,
tags=tags,
summary=summary,
description=description,
operation_id=operation_id,
deprecated=deprecated,
response_type=response_type,
response_description=response_description,
response_code=response_code,
response_wrapper=response_wrapper,
)

0
fastapi/security/__init__.py

40
fastapi/security/api_key.py

@ -0,0 +1,40 @@
from starlette.requests import Request
from pydantic import Schema
from enum import Enum
from .base import SecurityBase, Types
__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKeyBase(SecurityBase):
type_ = Schema(Types.apiKey, alias="type")
in_: str = Schema(..., alias="in")
name: str
class APIKeyQuery(APIKeyBase):
in_ = Schema(APIKeyIn.query, alias="in")
async def __call__(self, requests: Request):
return requests.query_params.get(self.name)
class APIKeyHeader(APIKeyBase):
in_ = Schema(APIKeyIn.header, alias="in")
async def __call__(self, requests: Request):
return requests.headers.get(self.name)
class APIKeyCookie(APIKeyBase):
in_ = Schema(APIKeyIn.cookie, alias="in")
async def __call__(self, requests: Request):
return requests.cookies.get(self.name)

17
fastapi/security/base.py

@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel, Schema
__all__ = ["Types", "SecurityBase"]
class Types(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
scheme_name: str = None
type_: Types = Schema(..., alias="type")
description: str = None

27
fastapi/security/http.py

@ -0,0 +1,27 @@
from starlette.requests import Request
from pydantic import Schema
from .base import SecurityBase, Types
__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
class HTTPBase(SecurityBase):
type_ = Schema(Types.http, alias="type")
scheme: str
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPBasic(HTTPBase):
scheme = "basic"
class HTTPBearer(HTTPBase):
scheme = "bearer"
bearerFormat: str = None
class HTTPDigest(HTTPBase):
scheme = "digest"

45
fastapi/security/oauth2.py

@ -0,0 +1,45 @@
from typing import Dict
from pydantic import BaseModel, Schema
from starlette.requests import Request
from .base import SecurityBase, Types
# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
class OAuthFlow(BaseModel):
refreshUrl: str = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: OAuthFlowImplicit = None
password: OAuthFlowPassword = None
clientCredentials: OAuthFlowClientCredentials = None
authorizationCode: OAuthFlowAuthorizationCode = None
class OAuth2(SecurityBase):
type_ = Schema(Types.oauth2, alias="type")
flows: OAuthFlows
async def __call__(self, request: Request):
return request.headers.get("Authorization")

10
fastapi/security/open_id_connect_url.py

@ -0,0 +1,10 @@
from starlette.requests import Request
from .base import SecurityBase, Types
class OpenIdConnect(SecurityBase):
type_ = Types.openIdConnect
openIdConnectUrl: str
async def __call__(self, request: Request):
return request.headers.get("Authorization")
Loading…
Cancel
Save