From d8fe307d61a55148a4d95c550f0ef33148ba8681 Mon Sep 17 00:00:00 2001 From: dmontagu <35119617+dmontagu@users.noreply.github.com> Date: Sun, 29 Sep 2019 14:19:09 -0700 Subject: [PATCH] :sparkles: Add support for strings and __future__ type annotations (#451) * Add support for strings and __future__ annotations * Add comments indicating reason for string annotations * Fix ignores (including removing some unused ignores) --- fastapi/dependencies/utils.py | 38 ++++++++++++++++++++++++++++++----- fastapi/openapi/models.py | 2 +- fastapi/utils.py | 4 ++-- tests/test_security_oauth2.py | 9 ++++++--- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 7f0f59092..852f1e025 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -26,7 +26,7 @@ from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError from pydantic.fields import Field, Required, Shape from pydantic.schema import get_annotation_from_schema -from pydantic.utils import lenient_issubclass +from pydantic.utils import ForwardRef, evaluate_forwardref, lenient_issubclass from starlette.background import BackgroundTasks from starlette.concurrency import run_in_threadpool from starlette.datastructures import FormData, Headers, QueryParams, UploadFile @@ -171,6 +171,30 @@ def is_scalar_sequence_field(field: Field) -> bool: return False +def get_typed_signature(call: Callable) -> inspect.Signature: + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param, globalns), + ) + for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any: + annotation = param.annotation + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + def get_dependant( *, path: str, @@ -180,7 +204,7 @@ def get_dependant( use_cache: bool = True, ) -> Dependant: path_param_names = get_path_param_names(path) - endpoint_signature = inspect.signature(call) + endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache) for param_name, param in signature_params.items(): @@ -329,8 +353,12 @@ async def solve_dependencies( ]: values: Dict[str, Any] = {} errors: List[ErrorWrapper] = [] - response = response or Response( # type: ignore - content=None, status_code=None, headers=None, media_type=None, background=None + response = response or Response( + content=None, + status_code=None, # type: ignore + headers=None, + media_type=None, + background=None, ) dependency_cache = dependency_cache or {} sub_dependant: Dependant @@ -405,7 +433,7 @@ async def solve_dependencies( values.update(cookie_values) errors += path_errors + query_errors + header_errors + cookie_errors if dependant.body_params: - body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above + body_values, body_errors = await request_body_to_args( # body_params checked above required_params=dependant.body_params, received_body=body ) values.update(body_values) diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py index 3dd9f04dc..e5c50070e 100644 --- a/fastapi/openapi/models.py +++ b/fastapi/openapi/models.py @@ -11,7 +11,7 @@ try: import email_validator assert email_validator # make autoflake ignore the unused import - from pydantic.types import EmailStr # type: ignore + from pydantic.types import EmailStr except ImportError: # pragma: no cover logger.warning( "email-validator not installed, email fields will be treated as str.\n" diff --git a/fastapi/utils.py b/fastapi/utils.py index 17a16b522..8cb0ec123 100644 --- a/fastapi/utils.py +++ b/fastapi/utils.py @@ -58,10 +58,10 @@ def create_cloned_field(field: Field) -> Field: use_type = original_type if lenient_issubclass(original_type, BaseModel): original_type = cast(Type[BaseModel], original_type) - use_type = create_model( # type: ignore + use_type = create_model( original_type.__name__, __config__=original_type.__config__, - __validators__=original_type.__validators__, + __validators__=original_type.__validators__, # type: ignore ) for f in original_type.__fields__.values(): use_type.__fields__[f.name] = f diff --git a/tests/test_security_oauth2.py b/tests/test_security_oauth2.py index 890613b29..5cf2592f3 100644 --- a/tests/test_security_oauth2.py +++ b/tests/test_security_oauth2.py @@ -21,18 +21,21 @@ class User(BaseModel): username: str -def get_current_user(oauth_header: str = Security(reusable_oauth2)): +# Here we use string annotations to test them +def get_current_user(oauth_header: "str" = Security(reusable_oauth2)): user = User(username=oauth_header) return user @app.post("/login") -def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()): +# Here we use string annotations to test them +def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()): return form_data @app.get("/users/me") -def read_current_user(current_user: User = Depends(get_current_user)): +# Here we use string annotations to test them +def read_current_user(current_user: "User" = Depends(get_current_user)): return current_user