diff --git a/docs/img/tutorial/path-params/image03.png b/docs/img/tutorial/path-params/image03.png new file mode 100644 index 000000000..d08645d1c Binary files /dev/null and b/docs/img/tutorial/path-params/image03.png differ diff --git a/docs/src/path_params/tutorial005.py b/docs/src/path_params/tutorial005.py new file mode 100644 index 000000000..d4f24bce7 --- /dev/null +++ b/docs/src/path_params/tutorial005.py @@ -0,0 +1,21 @@ +from enum import Enum + +from fastapi import FastAPI + + +class ModelName(Enum): + alexnet = "alexnet" + resnet = "resnet" + lenet = "lenet" + + +app = FastAPI() + + +@app.get("/model/{model_name}") +async def get_model(model_name: ModelName): + if model_name == ModelName.alexnet: + return {"model_name": model_name, "message": "Deep Learning FTW!"} + if model_name.value == "lenet": + return {"model_name": model_name, "message": "LeCNN all the images"} + return {"model_name": model_name, "message": "Have some residuals"} diff --git a/docs/src/query_params/tutorial007.py b/docs/src/query_params/tutorial007.py new file mode 100644 index 000000000..8ef5b3004 --- /dev/null +++ b/docs/src/query_params/tutorial007.py @@ -0,0 +1,11 @@ +from typing import Optional + +from fastapi import FastAPI + +app = FastAPI() + + +@app.get("/items/{item_id}") +async def read_user_item(item_id: str, limit: Optional[int] = None): + item = {"item_id": item_id, "limit": limit} + return item diff --git a/docs/src/query_params_str_validations/tutorial012.py b/docs/src/query_params_str_validations/tutorial012.py new file mode 100644 index 000000000..7ea9f017d --- /dev/null +++ b/docs/src/query_params_str_validations/tutorial012.py @@ -0,0 +1,11 @@ +from typing import List + +from fastapi import FastAPI, Query + +app = FastAPI() + + +@app.get("/items/") +async def read_items(q: List[str] = Query(["foo", "bar"])): + query_items = {"q": q} + return query_items diff --git a/docs/tutorial/path-params.md b/docs/tutorial/path-params.md index b7cf9d4df..96e29366e 100644 --- a/docs/tutorial/path-params.md +++ b/docs/tutorial/path-params.md @@ -35,7 +35,7 @@ If you run this example and open your browser at "parsing". ## Data validation @@ -61,12 +61,11 @@ because the path parameter `item_id` had a value of `"foo"`, which is not an `in The same error would appear if you provided a `float` instead of an int, as in: http://127.0.0.1:8000/items/4.2 - !!! check So, with the same Python type declaration, **FastAPI** gives you data validation. - Notice that the error also clearly states exactly the point where the validation didn't pass. - + Notice that the error also clearly states exactly the point where the validation didn't pass. + This is incredibly helpful while developing and debugging code that interacts with your API. ## Documentation @@ -96,8 +95,7 @@ All the data validation is performed under the hood by `Enum`. + +### Create an `Enum` class + +Import `Enum` and create a sub-class that inherits from it. + +And create class attributes with fixed values, those fixed values will be the available valid values: + +```Python hl_lines="1 6 7 8 9" +{!./src/path_params/tutorial005.py!} +``` + +!!! info + Enumerations (or enums) are available in Python since version 3.4. + +!!! tip + If you are wondering, "AlexNet", "ResNet", and "LeNet" are just names of Machine Learning models. + +### Declare a *path parameter* + +Then create a *path parameter* with a type annotation using the enum class you created (`ModelName`): + +```Python hl_lines="16" +{!./src/path_params/tutorial005.py!} +``` + +### Check the docs + +Because the available values for the *path parameter* are specified, the interactive docs can show them nicely: + + + +### Working with Python *enumerations* + +The value of the *path parameter* will be an *enumeration member*. + +#### Compare *enumeration members* + +You can compare it with the *enumeration member* in your created enum `ModelName`: + +```Python hl_lines="17" +{!./src/path_params/tutorial005.py!} +``` + +#### Get the *enumeration value* + +You can get the actual value (a `str` in this case) using `model_name.value`, or in general, `your_enum_member.value`: + +```Python hl_lines="19" +{!./src/path_params/tutorial005.py!} +``` + +!!! tip + You could also access the value `"lenet"` with `ModelName.lenet.value`. + +#### Return *enumeration members* + +You can return *enum members* from your *path operation*, even nested in a JSON body (e.g. a `dict`). + +They will be converted to their corresponding values before returning them to the client: + +```Python hl_lines="18 20 21" +{!./src/path_params/tutorial005.py!} +``` + ## Path parameters containing paths Let's say you have a *path operation* with a path `/files/{file_path}`. diff --git a/docs/tutorial/query-params-str-validations.md b/docs/tutorial/query-params-str-validations.md index a82018437..4258a71fd 100644 --- a/docs/tutorial/query-params-str-validations.md +++ b/docs/tutorial/query-params-str-validations.md @@ -12,7 +12,6 @@ The query parameter `q` is of type `str`, and by default is `None`, so it is opt We are going to enforce that even though `q` is optional, whenever it is provided, it **doesn't exceed a length of 50 characters**. - ### Import `Query` To achieve that, first import `Query` from `fastapi`: @@ -29,7 +28,7 @@ And now use it as the default value of your parameter, setting the parameter `ma {!./src/query_params_str_validations/tutorial002.py!} ``` -As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value. +As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value. So: @@ -41,7 +40,7 @@ q: str = Query(None) ```Python q: str = None -``` +``` But it declares it explicitly as being a query parameter. @@ -53,7 +52,6 @@ q: str = Query(None, max_length=50) This will validate the data, show a clear error when the data is not valid, and document the parameter in the OpenAPI schema path operation. - ## Add more validations You can also add a parameter `min_length`: @@ -119,7 +117,7 @@ So, when you need to declare a value as required while using `Query`, you can us {!./src/query_params_str_validations/tutorial006.py!} ``` -!!! info +!!! info If you hadn't seen that `...` before: it is a a special single value, it is part of Python and is called "Ellipsis". This will let **FastAPI** know that this parameter is required. @@ -156,11 +154,35 @@ So, the response to that URL would be: !!! tip To declare a query parameter with a type of `list`, like in the example above, you need to explicitly use `Query`, otherwise it would be interpreted as a request body. - The interactive API docs will update accordingly, to allow multiple values: +### Query parameter list / multiple values with defaults + +And you can also define a default `list` of values if none are provided: + +```Python hl_lines="9" +{!./src/query_params_str_validations/tutorial012.py!} +``` + +If you go to: + +``` +http://localhost:8000/items/ +``` + +the default of `q` will be: `["foo", "bar"]` and your response will be: + +```JSON +{ + "q": [ + "foo", + "bar" + ] +} +``` + ## Declare more metadata You can add more information about the parameter. diff --git a/docs/tutorial/query-params.md b/docs/tutorial/query-params.md index 54a71f36d..85a69205d 100644 --- a/docs/tutorial/query-params.md +++ b/docs/tutorial/query-params.md @@ -186,3 +186,39 @@ In this case, there are 3 query parameters: * `needy`, a required `str`. * `skip`, an `int` with a default value of `0`. * `limit`, an optional `int`. + +!!! tip + You could also use `Enum`s the same way as with *path parameters*. + +## Optional type declarations + +!!! warning + This might be an advanced use case. + + You might want to skip it. + +If you are using `mypy` it could complain with type declarations like: + +```Python +limit: int = None +``` + +With an error like: + +``` +Incompatible types in assignment (expression has type "None", variable has type "int") +``` + +In those cases you can use `Optional` to tell `mypy` that the value could be `None`, like: + +```Python +from typing import Optional + +limit: Optional[int] = None +``` + +In a *path operation* that could look like: + +```Python hl_lines="9" +{!./src/query_params/tutorial007.py!} +``` diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 194187f28..2596d5754 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,8 +1,6 @@ import asyncio import inspect from copy import deepcopy -from datetime import date, datetime, time, timedelta -from decimal import Decimal from typing import ( Any, Callable, @@ -14,8 +12,8 @@ from typing import ( Tuple, Type, Union, + cast, ) -from uuid import UUID from fastapi import params from fastapi.dependencies.models import Dependant, SecurityRequirement @@ -23,7 +21,7 @@ from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.utils import get_path_param_names -from pydantic import BaseConfig, Schema, create_model +from pydantic import BaseConfig, BaseModel, Schema, create_model from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError from pydantic.fields import Field, Required, Shape @@ -35,22 +33,21 @@ from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import Request from starlette.websockets import WebSocket -param_supported_types = ( - str, - int, - float, - bool, - UUID, - date, - datetime, - time, - timedelta, - Decimal, -) - -sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE} +sequence_shapes = { + Shape.LIST, + Shape.SET, + Shape.TUPLE, + Shape.SEQUENCE, + Shape.TUPLE_ELLIPS, +} sequence_types = (list, set, tuple) -sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple} +sequence_shape_to_type = { + Shape.LIST: list, + Shape.SET: set, + Shape.TUPLE: tuple, + Shape.SEQUENCE: list, + Shape.TUPLE_ELLIPS: list, +} def get_param_sub_dependant( @@ -126,6 +123,26 @@ def get_flat_dependant(dependant: Dependant) -> Dependant: return flat_dependant +def is_scalar_field(field: Field) -> bool: + return ( + field.shape == Shape.SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not isinstance(field.schema, params.Body) + ) + + +def is_scalar_sequence_field(field: Field) -> bool: + if field.shape in sequence_shapes and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + return False + + def get_dependant( *, path: str, call: Callable, name: str = None, security_scopes: List[str] = None ) -> Dependant: @@ -133,83 +150,78 @@ def get_dependant( endpoint_signature = inspect.signature(call) signature_params = endpoint_signature.parameters dependant = Dependant(call=call, name=name) - for param_name in signature_params: - param = signature_params[param_name] + for param_name, param in signature_params.items(): if isinstance(param.default, params.Depends): sub_dependant = get_param_sub_dependant( param=param, path=path, security_scopes=security_scopes ) dependant.dependencies.append(sub_dependant) - for param_name in signature_params: - param = signature_params[param_name] - if ( - (param.default == param.empty) or isinstance(param.default, params.Path) - ) and (param_name in path_param_names): - assert ( - lenient_issubclass(param.annotation, param_supported_types) - or param.annotation == param.empty + for param_name, param in signature_params.items(): + if isinstance(param.default, params.Depends): + continue + if add_non_field_param_to_dependency(param=param, dependant=dependant): + continue + param_field = get_param_field(param=param, default_schema=params.Query) + if param_name in path_param_names: + assert param.default == param.empty or isinstance( + param.default, params.Path + ), "Path params must have no defaults or use Path(...)" + assert is_scalar_field( + field=param_field ), f"Path params must be of one of the supported types" - add_param_to_fields( + param_field = get_param_field( param=param, - dependant=dependant, default_schema=params.Path, force_type=params.ParamTypes.path, ) - elif ( - param.default == param.empty - or param.default is None - or isinstance(param.default, param_supported_types) - ) and ( - param.annotation == param.empty - or lenient_issubclass(param.annotation, param_supported_types) - ): - add_param_to_fields( - param=param, dependant=dependant, default_schema=params.Query - ) - elif isinstance(param.default, params.Param): - if param.annotation != param.empty: - origin = getattr(param.annotation, "__origin__", None) - param_all_types = param_supported_types + (list, tuple, set) - if isinstance(param.default, (params.Query, params.Header)): - assert lenient_issubclass( - param.annotation, param_all_types - ) or lenient_issubclass( - origin, param_all_types - ), f"Parameters for Query and Header must be of type str, int, float, bool, list, tuple or set: {param}" - else: - assert lenient_issubclass( - param.annotation, param_supported_types - ), f"Parameters for Path and Cookies must be of type str, int, float, bool: {param}" - add_param_to_fields( - param=param, dependant=dependant, default_schema=params.Query - ) - elif lenient_issubclass(param.annotation, Request): - dependant.request_param_name = param_name - elif lenient_issubclass(param.annotation, WebSocket): - dependant.websocket_param_name = param_name - elif lenient_issubclass(param.annotation, BackgroundTasks): - dependant.background_tasks_param_name = param_name - elif lenient_issubclass(param.annotation, SecurityScopes): - dependant.security_scopes_param_name = param_name - elif not isinstance(param.default, params.Depends): - add_param_to_body_fields(param=param, dependant=dependant) + add_param_to_fields(field=param_field, dependant=dependant) + elif is_scalar_field(field=param_field): + add_param_to_fields(field=param_field, dependant=dependant) + elif isinstance( + param.default, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): + add_param_to_fields(field=param_field, dependant=dependant) + else: + assert isinstance( + param_field.schema, params.Body + ), f"Param: {param_field.name} can only be a request body, using Body(...)" + dependant.body_params.append(param_field) return dependant -def add_param_to_fields( +def add_non_field_param_to_dependency( + *, param: inspect.Parameter, dependant: Dependant +) -> Optional[bool]: + if lenient_issubclass(param.annotation, Request): + dependant.request_param_name = param.name + return True + elif lenient_issubclass(param.annotation, WebSocket): + dependant.websocket_param_name = param.name + return True + elif lenient_issubclass(param.annotation, BackgroundTasks): + dependant.background_tasks_param_name = param.name + return True + elif lenient_issubclass(param.annotation, SecurityScopes): + dependant.security_scopes_param_name = param.name + return True + return None + + +def get_param_field( *, param: inspect.Parameter, - dependant: Dependant, - default_schema: Type[Schema] = params.Param, + default_schema: Type[params.Param] = params.Param, force_type: params.ParamTypes = None, -) -> None: +) -> Field: default_value = Required + had_schema = False if not param.default == param.empty: default_value = param.default - if isinstance(default_value, params.Param): + if isinstance(default_value, Schema): + had_schema = True schema = default_value default_value = schema.default - if getattr(schema, "in_", None) is None: + if isinstance(schema, params.Param) and getattr(schema, "in_", None) is None: schema.in_ = default_schema.in_ if force_type: schema.in_ = force_type @@ -234,43 +246,26 @@ def add_param_to_fields( class_validators={}, schema=schema, ) - if schema.in_ == params.ParamTypes.path: + if not had_schema and not is_scalar_field(field=field): + field.schema = params.Body(schema.default) + return field + + +def add_param_to_fields(*, field: Field, dependant: Dependant) -> None: + field.schema = cast(params.Param, field.schema) + if field.schema.in_ == params.ParamTypes.path: dependant.path_params.append(field) - elif schema.in_ == params.ParamTypes.query: + elif field.schema.in_ == params.ParamTypes.query: dependant.query_params.append(field) - elif schema.in_ == params.ParamTypes.header: + elif field.schema.in_ == params.ParamTypes.header: dependant.header_params.append(field) else: assert ( - schema.in_ == params.ParamTypes.cookie - ), f"non-body parameters must be in path, query, header or cookie: {param.name}" + field.schema.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {field.name}" dependant.cookie_params.append(field) -def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None: - default_value = Required - if not param.default == param.empty: - default_value = param.default - if isinstance(default_value, Schema): - schema = default_value - default_value = schema.default - else: - schema = Schema(default_value) - required = default_value == Required - annotation = get_annotation_from_schema(param.annotation, schema) - field = Field( - name=param.name, - type_=annotation, - default=None if required else default_value, - alias=schema.alias or param.name, - required=required, - model_config=BaseConfig, - class_validators={}, - schema=schema, - ) - dependant.body_params.append(field) - - def is_coroutine_callable(call: Callable) -> bool: if inspect.isfunction(call): return asyncio.iscoroutinefunction(call) @@ -354,7 +349,7 @@ def request_params_to_args( if field.shape in sequence_shapes and isinstance( received_params, (QueryParams, Headers) ): - value = received_params.getlist(field.alias) + value = received_params.getlist(field.alias) or field.default else: value = received_params.get(field.alias) schema: params.Param = field.schema diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 87e223cb6..26d491bea 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast from fastapi import routing from fastapi.dependencies.models import Dependant @@ -9,7 +9,7 @@ from fastapi.openapi.models import OpenAPI from fastapi.params import Body, Param from fastapi.utils import get_flat_models_from_routes, get_model_definitions from pydantic.fields import Field -from pydantic.schema import Schema, field_schema, get_model_name_map +from pydantic.schema import field_schema, get_model_name_map from pydantic.utils import lenient_issubclass from starlette.responses import JSONResponse from starlette.routing import BaseRoute @@ -97,12 +97,8 @@ def get_openapi_operation_request_body( body_schema, _ = field_schema( body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) - schema: Schema = body_field.schema - if isinstance(schema, Body): - request_media_type = schema.media_type - else: - # Includes not declared media types (Schema) - request_media_type = "application/json" + body_field.schema = cast(Body, body_field.schema) + request_media_type = body_field.schema.media_type required = body_field.required request_body_oai: Dict[str, Any] = {} if required: diff --git a/pyproject.toml b/pyproject.toml index e8be41838..8700b0ef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ ] requires = [ "starlette >=0.11.1,<=0.12.0", - "pydantic >=0.17,<=0.26.0" + "pydantic >=0.26,<=0.26.0" ] description-file = "README.md" requires-python = ">=3.6" diff --git a/tests/test_invalid_sequence_param.py b/tests/test_invalid_sequence_param.py new file mode 100644 index 000000000..bdc4b1bcb --- /dev/null +++ b/tests/test_invalid_sequence_param.py @@ -0,0 +1,29 @@ +from typing import List, Tuple + +import pytest +from fastapi import FastAPI, Query +from pydantic import BaseModel + + +def test_invalid_sequence(): + with pytest.raises(AssertionError): + app = FastAPI() + + class Item(BaseModel): + title: str + + @app.get("/items/") + def read_items(q: List[Item] = Query(None)): + pass # pragma: no cover + + +def test_invalid_tuple(): + with pytest.raises(AssertionError): + app = FastAPI() + + class Item(BaseModel): + title: str + + @app.get("/items/") + def read_items(q: Tuple[Item, Item] = Query(None)): + pass # pragma: no cover diff --git a/tests/test_tutorial/test_path_params/test_tutorial005.py b/tests/test_tutorial/test_path_params/test_tutorial005.py new file mode 100644 index 000000000..3245cdceb --- /dev/null +++ b/tests/test_tutorial/test_path_params/test_tutorial005.py @@ -0,0 +1,120 @@ +import pytest +from starlette.testclient import TestClient + +from path_params.tutorial005 import app + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/model/{model_name}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Get Model", + "operationId": "get_model_model__model_name__get", + "parameters": [ + { + "required": True, + "schema": { + "title": "Model_Name", + "enum": ["alexnet", "resnet", "lenet"], + }, + "name": "model_name", + "in": "path", + } + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_openapi(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +@pytest.mark.parametrize( + "url,status_code,expected", + [ + ( + "/model/alexnet", + 200, + {"model_name": "alexnet", "message": "Deep Learning FTW!"}, + ), + ( + "/model/lenet", + 200, + {"model_name": "lenet", "message": "LeCNN all the images"}, + ), + ( + "/model/resnet", + 200, + {"model_name": "resnet", "message": "Have some residuals"}, + ), + ( + "/model/foo", + 422, + { + "detail": [ + { + "loc": ["path", "model_name"], + "msg": "value is not a valid enumeration member", + "type": "type_error.enum", + } + ] + }, + ), + ], +) +def test_get_enums(url, status_code, expected): + response = client.get(url) + assert response.status_code == status_code + assert response.json() == expected diff --git a/tests/test_tutorial/test_query_params/test_tutorial007.py b/tests/test_tutorial/test_query_params/test_tutorial007.py new file mode 100644 index 000000000..a0fb23850 --- /dev/null +++ b/tests/test_tutorial/test_query_params/test_tutorial007.py @@ -0,0 +1,95 @@ +from starlette.testclient import TestClient + +from query_params.tutorial007 import app + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/{item_id}": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read User Item", + "operationId": "read_user_item_items__item_id__get", + "parameters": [ + { + "required": True, + "schema": {"title": "Item_Id", "type": "string"}, + "name": "item_id", + "in": "path", + }, + { + "required": False, + "schema": {"title": "Limit", "type": "integer"}, + "name": "limit", + "in": "query", + }, + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_openapi(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_read_item(): + response = client.get("/items/foo") + assert response.status_code == 200 + assert response.json() == {"item_id": "foo", "limit": None} + + +def test_read_item_query(): + response = client.get("/items/foo?limit=5") + assert response.status_code == 200 + assert response.json() == {"item_id": "foo", "limit": 5} diff --git a/tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py b/tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py new file mode 100644 index 000000000..1e00c5017 --- /dev/null +++ b/tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py @@ -0,0 +1,96 @@ +from starlette.testclient import TestClient + +from query_params_str_validations.tutorial012 import app + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/items/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Read Items", + "operationId": "read_items_items__get", + "parameters": [ + { + "required": False, + "schema": { + "title": "Q", + "type": "array", + "items": {"type": "string"}, + "default": ["foo", "bar"], + }, + "name": "q", + "in": "query", + } + ], + } + } + }, + "components": { + "schemas": { + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +def test_default_query_values(): + url = "/items/" + response = client.get(url) + assert response.status_code == 200 + assert response.json() == {"q": ["foo", "bar"]} + + +def test_multi_query_values(): + url = "/items/?q=baz&q=foobar" + response = client.get(url) + assert response.status_code == 200 + assert response.json() == {"q": ["baz", "foobar"]}