Browse Source

Refactor, update code, several features

pull/1/head
Sebastián Ramírez 6 years ago
parent
commit
addfa89b0f
  1. 7
      fastapi/applications.py
  2. 78
      fastapi/openapi/docs.py
  3. 265
      fastapi/openapi/utils.py
  4. 52
      fastapi/routing.py
  5. 37
      fastapi/security/api_key.py
  6. 18
      fastapi/security/base.py
  7. 32
      fastapi/security/http.py
  8. 42
      fastapi/security/oauth2.py
  9. 10
      fastapi/security/open_id_connect_url.py

7
fastapi/applications.py

@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
from fastapi import routing
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
async def http_exception(request, exc: HTTPException):
@ -154,8 +155,10 @@ class FastAPI(Starlette):
response_wrapper=response_wrapper,
)
return func
return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
self.router.include_router(router, prefix=prefix)
def get(
self,

78
fastapi/openapi/docs.py

@ -0,0 +1,78 @@
from starlette.responses import HTMLResponse
def get_swagger_ui_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<title>
"""
+ title
+ """
</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '"""
+ openapi_url
+ """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
def get_redoc_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<title>
"""
+ title
+ """
</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {
margin: 0;
padding: 0;
}
</style>
</head>
<body>
<redoc spec-url='"""
+ openapi_url
+ """'></redoc>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
)

265
fastapi/openapi/utils.py

@ -1,4 +1,8 @@
from typing import Any, Dict, Sequence, Type
from typing import Any, Dict, Sequence, Type, List
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import BaseRoute
@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
from pydantic.fields import Field
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
validation_error_definition = {
"title": "ValidationError",
@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant):
+ flat_dependant.cookie_params
)
def get_openapi_security_definitions(flat_dependant: Dependant):
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
by_alias=True,
include_none=False,
)
security_name = (
security_requirement.security_scheme.scheme_name
)
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
def get_openapi_operation_parameters(all_route_params: List[Field]):
definitions: Dict[str, Dict] = {}
parameters = []
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions["HTTPValidationError"] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
return definitions, parameters
def get_openapi_operation_request_body(
*, body_field: Field, model_name_map: Dict[Type, str]
):
if not body_field:
return None
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
def generate_operation_id(*, route: routing.APIRoute, method: str):
if route.operation_id:
return route.operation_id
path: str = route.path
operation_id = route.name + path
operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
operation_id = operation_id + "_" + method.lower()
return operation_id
def generate_operation_summary(*, route: routing.APIRoute, method: str):
if route.summary:
return route.summary
return method.title() + " " + route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation["operationId"] = generate_operation_id(route=route, method=method)
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
return None
path = {}
security_schemes = {}
definitions = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
for method in route.methods:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
if route.summary:
operation["summary"] = route.summary
if route.description:
operation["description"] = route.description
if route.operation_id:
operation["operationId"] = route.operation_id
else:
operation["operationId"] = route.name
if route.deprecated:
operation["deprecated"] = route.deprecated
parameters = []
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
flat_dependant = get_flat_dependant(route.dependant)
security_definitions = {}
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme,
exclude={"scheme_name"},
by_alias=True,
include_none=False,
)
security_name = (
getattr(
security_requirement.security_scheme, "scheme_name", None
)
or security_requirement.security_scheme.__class__.__name__
)
security_definitions[security_name] = security_definition
operation.setdefault("security", []).append(
{security_name: security_requirement.scopes}
)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
if operation_security:
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(
security_definitions
)
security_schemes.update(security_definitions)
all_route_params = get_openapi_params(route.dependant)
for param in all_route_params:
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions[
"HTTPValidationError"
] = validation_error_response_definition
parameter = {
"name": param.alias,
"in": param.schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
if param.schema.description:
parameter["description"] = param.schema.description
if param.schema.deprecated:
parameter["deprecated"] = param.schema.deprecated
parameters.append(parameter)
validation_definitions, operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params
)
definitions.update(validation_definitions)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = parameters
if method in METHODS_WITH_BODY:
body_field = route.body_field
if body_field:
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
if isinstance(body_field.schema, Body):
request_media_type = body_field.schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
request_body_oai = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {
request_media_type: {"schema": body_schema}
}
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field, model_name_map=model_name_map
)
if request_body_oai:
operation["requestBody"] = request_body_oai
response_code = str(route.response_code)
response_schema = {"type": "string"}
@ -206,75 +243,3 @@ def get_openapi(
output["components"] = components
output["paths"] = paths
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
def get_swagger_ui_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<! doctype html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<title>
""" + title + """
</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({
url: '"""
+ openapi_url
+ """',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout"
})
</script>
</body>
</html>
""",
media_type="text/html",
)
def get_redoc_html(*, openapi_url: str, title: str):
return HTMLResponse(
"""
<!DOCTYPE html>
<html>
<head>
<title>
""" + title + """
</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {
margin: 0;
padding: 0;
}
</style>
</head>
<body>
<redoc spec-url='""" + openapi_url + """'></redoc>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
""",
media_type="text/html",
)

52
fastapi/routing.py

@ -2,6 +2,11 @@ import asyncio
import inspect
from typing import Callable, List, Type
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
@ -15,10 +20,6 @@ from fastapi import params
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
def serialize_response(*, field: Field = None, response):
@ -44,11 +45,12 @@ def get_app(
response_field: Type[Field] = None,
):
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response:
body = None
if body_field:
if isinstance(body_field.schema, params.Form):
if is_body_form:
raw_body = await request.form()
body = {}
for field, value in raw_body.items():
@ -127,12 +129,7 @@ class APIRoute(routing.Route):
response_code=200,
response_wrapper=JSONResponse,
) -> None:
# TODO define how to read and provide security params, and how to have them globally too
# TODO implement dependencies and injection
# TODO refactor code structure
# TODO create testing
# TODO testing coverage
assert path.startswith("/"), "Routed paths must always start '/'"
assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
@ -260,6 +257,39 @@ class APIRouter(routing.Router):
return decorator
def include_router(self, router: "APIRouter", *, prefix=""):
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
for route in router.routes:
if isinstance(route, APIRoute):
self.add_api_route(
prefix + route.path,
route.endpoint,
methods=route.methods,
name=route.name,
include_in_schema=route.include_in_schema,
tags=route.tags,
summary=route.summary,
description=route.description,
operation_id=route.operation_id,
deprecated=route.deprecated,
response_type=route.response_type,
response_description=route.response_description,
response_code=route.response_code,
response_wrapper=route.response_wrapper,
)
elif isinstance(route, routing.Route):
self.add_route(
prefix + route.path,
route.endpoint,
methods=route.methods,
name=route.name,
include_in_schema=route.include_in_schema,
)
def get(
self,
path: str,

37
fastapi/security/api_key.py

@ -1,39 +1,34 @@
from enum import Enum
from pydantic import Schema
from starlette.requests import Request
from .base import SecurityBase, Types
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
from .base import SecurityBase
from fastapi.openapi.models import APIKeyIn, APIKey
class APIKeyBase(SecurityBase):
type_ = Schema(Types.apiKey, alias="type")
in_: str = Schema(..., alias="in")
name: str
pass
class APIKeyQuery(APIKeyBase):
in_ = Schema(APIKeyIn.query, alias="in")
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
return requests.query_params.get(self.name)
return requests.query_params.get(self.model.name)
class APIKeyHeader(APIKeyBase):
in_ = Schema(APIKeyIn.header, alias="in")
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.header, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
return requests.headers.get(self.name)
return requests.headers.get(self.model.name)
class APIKeyCookie(APIKeyBase):
in_ = Schema(APIKeyIn.cookie, alias="in")
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, requests: Request):
return requests.cookies.get(self.name)
return requests.cookies.get(self.model.name)

18
fastapi/security/base.py

@ -1,16 +1,6 @@
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Schema
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class Types(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModel):
scheme_name: str = None
type_: Types = Schema(..., alias="type")
description: str = None
class SecurityBase:
pass

32
fastapi/security/http.py

@ -1,26 +1,40 @@
from pydantic import Schema
from starlette.requests import Request
from .base import SecurityBase, Types
from .base import SecurityBase
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
class HTTPBase(SecurityBase):
type_ = Schema(Types.http, alias="type")
scheme: str
def __init__(self, *, scheme: str, scheme_name: str = None):
self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPBasic(HTTPBase):
scheme = "basic"
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPBearer(HTTPBase):
scheme = "bearer"
bearerFormat: str = None
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")
class HTTPDigest(HTTPBase):
scheme = "digest"
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")

42
fastapi/security/oauth2.py

@ -1,43 +1,13 @@
from typing import Dict
from pydantic import BaseModel, Schema
from starlette.requests import Request
from .base import SecurityBase, Types
class OAuthFlow(BaseModel):
refreshUrl: str = None
scopes: Dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModel):
implicit: OAuthFlowImplicit = None
password: OAuthFlowPassword = None
clientCredentials: OAuthFlowClientCredentials = None
authorizationCode: OAuthFlowAuthorizationCode = None
from .base import SecurityBase
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
class OAuth2(SecurityBase):
type_ = Schema(Types.oauth2, alias="type")
flows: OAuthFlows
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")

10
fastapi/security/open_id_connect_url.py

@ -1,11 +1,13 @@
from starlette.requests import Request
from .base import SecurityBase, Types
from .base import SecurityBase
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
class OpenIdConnect(SecurityBase):
type_ = Types.openIdConnect
openIdConnectUrl: str
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__
async def __call__(self, request: Request):
return request.headers.get("Authorization")

Loading…
Cancel
Save