Browse Source

Refactor, update code, several features

pull/1/head
Sebastián Ramírez 6 years ago
parent
commit
addfa89b0f
  1. 7
      fastapi/applications.py
  2. 78
      fastapi/openapi/docs.py
  3. 265
      fastapi/openapi/utils.py
  4. 52
      fastapi/routing.py
  5. 37
      fastapi/security/api_key.py
  6. 18
      fastapi/security/base.py
  7. 32
      fastapi/security/http.py
  8. 42
      fastapi/security/oauth2.py
  9. 10
      fastapi/security/open_id_connect_url.py

7
fastapi/applications.py

@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
from fastapi import routing from fastapi import routing
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
async def http_exception(request, exc: HTTPException): async def http_exception(request, exc: HTTPException):
@ -154,8 +155,10 @@ class FastAPI(Starlette):
response_wrapper=response_wrapper, response_wrapper=response_wrapper,
) )
return func return func
return decorator return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
self.router.include_router(router, prefix=prefix)
def get( def get(
self, self,

78
fastapi/openapi/docs.py

@ -0,0 +1,78 @@
from starlette.responses import HTMLResponse
def get_swagger_ui_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<title>
"""
+ title
+ """
</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '"""
+ openapi_url
+ """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
def get_redoc_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<title>
"""
+ title
+ """
</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {
margin: 0;
padding: 0;
}
</style>
</head>
<body>
<redoc spec-url='"""
+ openapi_url
+ """'></redoc>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
)

265
fastapi/openapi/utils.py

@ -1,4 +1,8 @@
from typing import Any, Dict, Sequence, Type from typing import Any, Dict, Sequence, Type, List
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
from starlette.responses import HTMLResponse, JSONResponse from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute from starlette.routing import BaseRoute
@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
from fastapi.openapi.models import OpenAPI from fastapi.openapi.models import OpenAPI
from fastapi.params import Body from fastapi.params import Body
from fastapi.utils import get_flat_models_from_routes, get_model_definitions from fastapi.utils import get_flat_models_from_routes, get_model_definitions
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
validation_error_definition = { validation_error_definition = {
"title": "ValidationError", "title": "ValidationError",
@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant):
+ flat_dependant.cookie_params + flat_dependant.cookie_params
) )
def get_openapi_security_definitions(flat_dependant: Dependant):
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
by_alias=True,
include_none=False,
)
security_name = (
security_requirement.security_scheme.scheme_name
)
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
def get_openapi_operation_parameters(all_route_params: List[Field]):
definitions: Dict[str, Dict] = {}
parameters = []
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions["HTTPValidationError"] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
return definitions, parameters
def get_openapi_operation_request_body(
*, body_field: Field, model_name_map: Dict[Type, str]
):
if not body_field:
return None
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str):
if route.operation_id:
return route.operation_id
path: str = route.path
operation_id = route.name + path
operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
operation_id = operation_id + "_" + method.lower()
return operation_id
def generate_operation_summary(*, route: routing.APIRoute, method: str):
if route.summary:
return route.summary
return method.title() + " " + route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation["operationId"] = generate_operation_id(route=route, method=method)
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]): def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
if not (route.include_in_schema and isinstance(route, routing.APIRoute)): if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
return None return None
path = {} path = {}
security_schemes = {} security_schemes: Dict[str, Any] = {}
definitions = {} definitions: Dict[str, Any] = {}
for method in route.methods: for method in route.methods:
operation: Dict[str, Any] = {} operation = get_openapi_operation_metadata(route=route, method=method)
if route.tags: parameters: List[Dict] = []
operation["tags"] = route.tags
if route.summary:
operation["summary"] = route.summary
if route.description:
operation["description"] = route.description
if route.operation_id:
operation["operationId"] = route.operation_id
else:
operation["operationId"] = route.name
if route.deprecated:
operation["deprecated"] = route.deprecated
parameters = []
flat_dependant = get_flat_dependant(route.dependant) flat_dependant = get_flat_dependant(route.dependant)
security_definitions = {} security_definitions, operation_security = get_openapi_security_definitions(
for security_requirement in flat_dependant.security_requirements: flat_dependant=flat_dependant
security_definition = jsonable_encoder( )
security_requirement.security_scheme, if operation_security:
exclude={"scheme_name"}, operation.setdefault("security", []).extend(operation_security)
by_alias=True,
include_none=False,
)
security_name = (
getattr(
security_requirement.security_scheme, "scheme_name", None
)
or security_requirement.security_scheme.__class__.__name__
)
security_definitions[security_name] = security_definition
operation.setdefault("security", []).append(
{security_name: security_requirement.scopes}
)
if security_definitions: if security_definitions:
security_schemes.update( security_schemes.update(security_definitions)
security_definitions
)
all_route_params = get_openapi_params(route.dependant) all_route_params = get_openapi_params(route.dependant)
for param in all_route_params: validation_definitions, operation_parameters = get_openapi_operation_parameters(
if "ValidationError" not in definitions: all_route_params=all_route_params
definitions["ValidationError"] = validation_error_definition )
definitions[ definitions.update(validation_definitions)
"HTTPValidationError" parameters.extend(operation_parameters)
] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
if parameters: if parameters:
operation["parameters"] = parameters operation["parameters"] = parameters
if method in METHODS_WITH_BODY: if method in METHODS_WITH_BODY:
body_field = route.body_field request_body_oai = get_openapi_operation_request_body(
if body_field: body_field=route.body_field, model_name_map=model_name_map
assert isinstance(body_field, Field) )
body_schema, _ = field_schema( if request_body_oai:
body_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {
request_media_type: {"schema": body_schema}
}
operation["requestBody"] = request_body_oai operation["requestBody"] = request_body_oai
response_code = str(route.response_code) response_code = str(route.response_code)
response_schema = {"type": "string"} response_schema = {"type": "string"}
@ -206,75 +243,3 @@ def get_openapi(
output["components"] = components output["components"] = components
output["paths"] = paths output["paths"] = paths
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False) return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
def get_swagger_ui_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<title>
""" + title + """
</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '"""
+ openapi_url
+ """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
def get_redoc_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<title>
""" + title + """
</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {
margin: 0;
padding: 0;
}
</style>
</head>
<body>
<redoc spec-url='""" + openapi_url + """'></redoc>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
)

52
fastapi/routing.py

@ -2,6 +2,11 @@ import asyncio
import inspect import inspect
from typing import Callable, List, Type from typing import Callable, List, Type
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from starlette import routing from starlette import routing
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -15,10 +20,6 @@ from fastapi import params
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
def serialize_response(*, field: Field = None, response): def serialize_response(*, field: Field = None, response):
@ -44,11 +45,12 @@ def get_app(
response_field: Type[Field] = None, response_field: Type[Field] = None,
): ):
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call) is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response: async def app(request: Request) -> Response:
body = None body = None
if body_field: if body_field:
if isinstance(body_field.schema, params.Form): if is_body_form:
raw_body = await request.form() raw_body = await request.form()
body = {} body = {}
for field, value in raw_body.items(): for field, value in raw_body.items():
@ -127,12 +129,7 @@ class APIRoute(routing.Route):
response_code=200, response_code=200,
response_wrapper=JSONResponse, response_wrapper=JSONResponse,
) -> None: ) -> None:
# TODO define how to read and provide security params, and how to have them globally too assert path.startswith("/"), "Routed paths must always start with '/'"
# TODO implement dependencies and injection
# TODO refactor code structure
# TODO create testing
# TODO testing coverage
assert path.startswith("/"), "Routed paths must always start '/'"
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
@ -260,6 +257,39 @@ class APIRouter(routing.Router):
return decorator return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
for route in router.routes:
if isinstance(route, APIRoute):
self.add_api_route(
prefix + route.path,
route.endpoint,
methods=route.methods,
name=route.name,
include_in_schema=route.include_in_schema,
tags=route.tags,
summary=route.summary,
description=route.description,
operation_id=route.operation_id,
deprecated=route.deprecated,
response_type=route.response_type,
response_description=route.response_description,
response_code=route.response_code,
response_wrapper=route.response_wrapper,
)
elif isinstance(route, routing.Route):
self.add_route(
prefix + route.path,
route.endpoint,
methods=route.methods,
name=route.name,
include_in_schema=route.include_in_schema,
)
def get( def get(
self, self,
path: str, path: str,

37
fastapi/security/api_key.py

@ -1,39 +1,34 @@
from enum import Enum
from pydantic import Schema
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase, Types from .base import SecurityBase
from fastapi.openapi.models import APIKeyIn, APIKey
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKeyBase(SecurityBase): class APIKeyBase(SecurityBase):
type_ = Schema(Types.apiKey, alias="type") pass
in_: str = Schema(..., alias="in")
name: str
class APIKeyQuery(APIKeyBase): class APIKeyQuery(APIKeyBase):
in_ = Schema(APIKeyIn.query, alias="in")
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request):
return requests.query_params.get(self.name) return requests.query_params.get(self.model.name)
class APIKeyHeader(APIKeyBase): class APIKeyHeader(APIKeyBase):
in_ = Schema(APIKeyIn.header, alias="in") def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.header, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request):
return requests.headers.get(self.name) return requests.headers.get(self.model.name)
class APIKeyCookie(APIKeyBase): class APIKeyCookie(APIKeyBase):
in_ = Schema(APIKeyIn.cookie, alias="in") def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request): async def __call__(self, requests: Request):
return requests.cookies.get(self.name) return requests.cookies.get(self.model.name)

18
fastapi/security/base.py

@ -1,16 +1,6 @@
from enum import Enum from pydantic import BaseModel
from pydantic import BaseModel, Schema from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class SecurityBase:
class Types(Enum): pass
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
scheme_name: str = None
type_: Types = Schema(..., alias="type")
description: str = None

32
fastapi/security/http.py

@ -1,26 +1,40 @@
from pydantic import Schema
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase, Types from .base import SecurityBase
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
class HTTPBase(SecurityBase): class HTTPBase(SecurityBase):
type_ = Schema(Types.http, alias="type") def __init__(self, *, scheme: str, scheme_name: str = None):
scheme: str self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request):
return request.headers.get("Authorization") return request.headers.get("Authorization")
class HTTPBasic(HTTPBase): class HTTPBasic(HTTPBase):
scheme = "basic" def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPBearer(HTTPBase): class HTTPBearer(HTTPBase):
scheme = "bearer" def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
bearerFormat: str = None self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPDigest(HTTPBase): class HTTPDigest(HTTPBase):
scheme = "digest" def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")

42
fastapi/security/oauth2.py

@ -1,43 +1,13 @@
from typing import Dict
from pydantic import BaseModel, Schema
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase, Types from .base import SecurityBase
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
class OAuthFlow(BaseModel):
refreshUrl: str = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: OAuthFlowImplicit = None
password: OAuthFlowPassword = None
clientCredentials: OAuthFlowClientCredentials = None
authorizationCode: OAuthFlowAuthorizationCode = None
class OAuth2(SecurityBase): class OAuth2(SecurityBase):
type_ = Schema(Types.oauth2, alias="type") def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
flows: OAuthFlows self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request):
return request.headers.get("Authorization") return request.headers.get("Authorization")

10
fastapi/security/open_id_connect_url.py

@ -1,11 +1,13 @@
from starlette.requests import Request from starlette.requests import Request
from .base import SecurityBase, Types from .base import SecurityBase
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
class OpenIdConnect(SecurityBase): class OpenIdConnect(SecurityBase):
type_ = Types.openIdConnect def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
openIdConnectUrl: str self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request): async def __call__(self, request: Request):
return request.headers.get("Authorization") return request.headers.get("Authorization")

Loading…
Cancel
Save