Browse Source

Add support for strings and __future__ type annotations (#451)

* Add support for strings and __future__ annotations

* Add comments indicating reason for string annotations

* Fix ignores (including removing some unused ignores)
pull/568/head
dmontagu 6 years ago
committed by Sebastián Ramírez
parent
commit
d8fe307d61
  1. 38
      fastapi/dependencies/utils.py
  2. 2
      fastapi/openapi/models.py
  3. 4
      fastapi/utils.py
  4. 9
      tests/test_security_oauth2.py

38
fastapi/dependencies/utils.py

@ -26,7 +26,7 @@ from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required, Shape
from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
from pydantic.utils import ForwardRef, evaluate_forwardref, lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
@ -171,6 +171,30 @@ def is_scalar_sequence_field(field: Field) -> bool:
return False
def get_typed_signature(call: Callable) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
def get_dependant(
*,
path: str,
@ -180,7 +204,7 @@ def get_dependant(
use_cache: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items():
@ -329,8 +353,12 @@ async def solve_dependencies(
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
response = response or Response( # type: ignore
content=None, status_code=None, headers=None, media_type=None, background=None
response = response or Response(
content=None,
status_code=None, # type: ignore
headers=None,
media_type=None,
background=None,
)
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
@ -405,7 +433,7 @@ async def solve_dependencies(
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
body_values, body_errors = await request_body_to_args( # body_params checked above
required_params=dependant.body_params, received_body=body
)
values.update(body_values)

2
fastapi/openapi/models.py

@ -11,7 +11,7 @@ try:
import email_validator
assert email_validator # make autoflake ignore the unused import
from pydantic.types import EmailStr # type: ignore
from pydantic.types import EmailStr
except ImportError: # pragma: no cover
logger.warning(
"email-validator not installed, email fields will be treated as str.\n"

4
fastapi/utils.py

@ -58,10 +58,10 @@ def create_cloned_field(field: Field) -> Field:
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
use_type = create_model( # type: ignore
use_type = create_model(
original_type.__name__,
__config__=original_type.__config__,
__validators__=original_type.__validators__,
__validators__=original_type.__validators__, # type: ignore
)
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = f

9
tests/test_security_oauth2.py

@ -21,18 +21,21 @@ class User(BaseModel):
username: str
def get_current_user(oauth_header: str = Security(reusable_oauth2)):
# Here we use string annotations to test them
def get_current_user(oauth_header: "str" = Security(reusable_oauth2)):
user = User(username=oauth_header)
return user
@app.post("/login")
def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
# Here we use string annotations to test them
def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
return form_data
@app.get("/users/me")
def read_current_user(current_user: User = Depends(get_current_user)):
# Here we use string annotations to test them
def read_current_user(current_user: "User" = Depends(get_current_user)):
return current_user

Loading…
Cancel
Save