From 6d4cc97266d9878c193d5d27c4a18b8c32702d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 8 Dec 2018 11:56:07 +0400 Subject: [PATCH] :white_check_mark: Add first tests, for path and query --- tests/__init__.py | 0 tests/endpoints/__init__.py | 0 tests/endpoints/a.py | 13 ++ tests/endpoints/b.py | 13 ++ tests/main.py | 350 ++++++++++++++++++++++++++++++++++++ tests/test_path.py | 73 ++++++++ tests/test_query.py | 44 +++++ 7 files changed, 493 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/endpoints/__init__.py create mode 100644 tests/endpoints/a.py create mode 100644 tests/endpoints/b.py create mode 100644 tests/main.py create mode 100644 tests/test_path.py create mode 100644 tests/test_query.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/endpoints/__init__.py b/tests/endpoints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/endpoints/a.py b/tests/endpoints/a.py new file mode 100644 index 000000000..be5663f59 --- /dev/null +++ b/tests/endpoints/a.py @@ -0,0 +1,13 @@ +from fastapi.routing import APIRouter + +router = APIRouter() + + +@router.get("/dog") +def get_a_dog(): + return "Woof" + + +@router.get("/cat") +def get_a_cat(): + return "Meow" diff --git a/tests/endpoints/b.py b/tests/endpoints/b.py new file mode 100644 index 000000000..7747fb2e4 --- /dev/null +++ b/tests/endpoints/b.py @@ -0,0 +1,13 @@ +from fastapi.routing import APIRouter + +router = APIRouter() + + +@router.get("/dog") +def get_b_dog(): + return "B Woof" + + +@router.get("/cat") +def get_b_cat(): + return "B Meow" diff --git a/tests/main.py b/tests/main.py new file mode 100644 index 000000000..4d2b199b2 --- /dev/null +++ b/tests/main.py @@ -0,0 +1,350 @@ +from fastapi.applications import FastAPI +from fastapi.params import ( + Body, + Cookie, + Depends, + File, + Form, + Header, + Param, + Path, + Query, + Security, +) +from fastapi.security.http import HTTPBasic +from fastapi.security.oauth2 import ( + OAuth2, + OAuth2PasswordRequestData, + OAuth2PasswordRequestForm, +) +from pydantic import BaseModel +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse +from starlette.status import HTTP_202_ACCEPTED +from starlette.testclient import TestClient + +from .endpoints.a import router as router_a +from .endpoints.b import router as router_b + +app = FastAPI() + + +app.include_router(router_a) +app.include_router(router_b, prefix="/b") + + +@app.get("/text") +def get_text(): + return "Hello World" + + +@app.get("/path/{item_id}") +def get_id(item_id): + return item_id + + +@app.get("/path/str/{item_id}") +def get_str_id(item_id: str): + return item_id + + +@app.get("/path/int/{item_id}") +def get_int_id(item_id: int): + return item_id + + +@app.get("/path/float/{item_id}") +def get_float_id(item_id: float): + return item_id + + +@app.get("/path/bool/{item_id}") +def get_bool_id(item_id: bool): + return item_id + + +@app.get("/path/param/{item_id}") +def get_path_param_id(item_id: str = Path(None)): + return item_id + + +@app.get("/path/param-required/{item_id}") +def get_path_param_required_id(item_id: str = Path(...)): + return item_id + + +@app.get("/query") +def get_query(query): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/optional") +def get_query_optional(query=None): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/int") +def get_query_type(query: int): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/int/optional") +def get_query_type_optional(query: int = None): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/param") +def get_query_param(query=Query(None)): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/param-required") +def get_query_param_required(query=Query(...)): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/query/param-required/int") +def get_query_param_required_type(query: int = Query(...)): + if query is None: + return "foo bar" + return f"foo bar {query}" + + +@app.get("/cookie") +def get_cookie(coo=Cookie(None)): + return coo + + +@app.get("/header") +def get_header(head_name=Header(None)): + return head_name + + +@app.get("/header_under") +def get_header(head_name=Header(None, convert_underscores=False)): + return head_name + + +@app.get("/param") +def get_param(par=Param(None)): + return par + + +@app.get("/security") +def get_security(sec=Security(HTTPBasic())): + return sec + + +reusable_oauth2 = OAuth2( + flows={ + "password": { + "tokenUrl": "/token", + "scopes": {"read:user": "Read a User", "write:user": "Create a user"}, + } + } +) + + +@app.get("/security/oauth2") +def get_security_oauth2(sec=Security(reusable_oauth2, scopes=["read:user"])): + return sec + + +@app.post("/token") +def post_token(request_data: OAuth2PasswordRequestForm = Form(...)): + print(request_data) + data = request_data.parse() + print(data) + + print(request_data()) + access_token = request_data.username + ":" + request_data.password + return {"access_token": access_token} + + +class Item(BaseModel): + name: str + price: float + is_offer: bool + + +@app.put("/items/{item_id}") +def put_item(item_id: str, item: Item): + return item + + +@app.post("/items/") +def post_item(item: Item): + return item + + +@app.post("/items-all-params/{item_id}") +def post_items_all_params( + item_id: str = Path(...), + body: Item = Body(...), + query_a: int = Query(None), + query_b=Query(None), + coo: str = Cookie(None), + x_head: int = Header(None), + x_under: str = Header(None, convert_underscores=False), +): + return { + "item_id": item_id, + "body": body, + "query_a": query_a, + "query_b": query_b, + "coo": coo, + "x_head": x_head, + "x_under": x_under, + } + + +@app.post("/items-all-params-defaults/{item_id}") +def post_items_all_params_default( + item_id: str, + body_item_a: Item, + body_item_b: Item, + query_a: int, + query_b: int, + coo: str = Cookie(None), + x_head: int = Header(None), + x_under: str = Header(None, convert_underscores=False), +): + return { + "item_id": item_id, + "body_item_a": body_item_a, + "body_item_b": body_item_b, + "query_a": query_a, + "query_b": query_b, + "coo": coo, + "x_head": x_head, + "x_under": x_under, + } + + +@app.delete("/items/{item_id}") +def delete_item(item_id: str): + return item_id + + +@app.options("/options/") +def options(): + return JSONResponse(headers={"x-fastapi": "fast"}) + + +@app.head("/head/") +def head(): + return {"not sent": "nope"} + + +@app.patch("/patch/{user_id}") +def patch(user_id: str, increment: float): + return {"user_id": user_id, "total": 5 + increment} + + +@app.trace("/trace/") +def trace(): + return PlainTextResponse(media_type="message/http") + + +@app.get("/model", response_model=Item, status_code=HTTP_202_ACCEPTED) +def model(): + return {"name": "Foo", "price": "5.0", "password": "not sent"} + + +@app.get( + "/metadata", + tags=["tag1", "tag2"], + summary="The summary", + description="The description", + response_description="Response description", + deprecated=True, + operation_id="a_very_long_and_strange_operation_id", +) +def get_meta(): + return "Foo" + + +@app.get("/html", content_type=HTMLResponse) +def get_html(): + return """ + + +

+ Some text inside +

+ + + """ + + +class FakeDB: + def __init__(self): + self.data = { + "johndoe": { + "username": "johndoe", + "password": "shouldbehashed", + "fist_name": "John", + "last_name": "Doe", + } + } + + +class DBConnectionManager: + def __init__(self): + self.db = FakeDB() + + def __call__(self): + return self.db + + +connection_manager = DBConnectionManager() + + +class TokenUserData(BaseModel): + username: str + password: str + + +class UserInDB(BaseModel): + username: str + password: str + fist_name: str + last_name: str + + +def require_token( + token: str = Security(reusable_oauth2, scopes=["read:user", "write:user"]) +): + raw_token = token.replace("Bearer ", "") + # Never do this plaintext password usage in production + username, password = raw_token.split(":") + return TokenUserData(username=username, password=password) + + +def require_user( + db: FakeDB = Depends(connection_manager), + user_data: TokenUserData = Depends(require_token), +): + return db.data[user_data.username] + + +class UserOut(BaseModel): + username: str + fist_name: str + last_name: str + + +@app.get("/dependency", response_model=UserOut) +def get_dependency(user: UserInDB = Depends(require_user)): + return user diff --git a/tests/test_path.py b/tests/test_path.py new file mode 100644 index 000000000..f271331f4 --- /dev/null +++ b/tests/test_path.py @@ -0,0 +1,73 @@ +import pytest +from starlette.testclient import TestClient + +from .main import app + +client = TestClient(app) + + +def test_text_get(): + response = client.get("/text") + assert response.status_code == 200 + assert response.json() == "Hello World" + + +def test_nonexistent(): + response = client.get("/nonexistent") + assert response.status_code == 404 + assert response.json() == {"detail": "Not Found"} + + +response_not_valid_int = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + +response_not_valid_float = { + "detail": [ + { + "loc": ["path", "item_id"], + "msg": "value is not a valid float", + "type": "type_error.float", + } + ] +} + + +@pytest.mark.parametrize( + "path,expected_status,expected_response", + [ + ("/path/foobar", 200, "foobar"), + ("/path/str/foobar", 200, "foobar"), + ("/path/str/42", 200, "42"), + ("/path/str/True", 200, "True"), + ("/path/int/foobar", 422, response_not_valid_int), + ("/path/int/True", 422, response_not_valid_int), + ("/path/int/42", 200, 42), + ("/path/int/42.5", 422, response_not_valid_int), + ("/path/float/foobar", 422, response_not_valid_float), + ("/path/float/True", 422, response_not_valid_float), + ("/path/float/42", 200, 42), + ("/path/float/42.5", 200, 42.5), + ("/path/bool/foobar", 200, False), + ("/path/bool/True", 200, True), + ("/path/bool/42", 200, False), + ("/path/bool/42.5", 200, False), + ("/path/bool/1", 200, True), + ("/path/bool/0", 200, False), + ("/path/bool/true", 200, True), + ("/path/bool/False", 200, False), + ("/path/bool/false", 200, False), + ("/path/param/foo", 200, "foo"), + ("/path/param-required/foo", 200, "foo"), + ], +) +def test_get_path(path, expected_status, expected_response): + response = client.get(path) + assert response.status_code == expected_status + assert response.json() == expected_response diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 000000000..fc792b843 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,44 @@ +import pytest +from starlette.testclient import TestClient + +from .main import app + +client = TestClient(app) + +response_missing = { + "detail": [ + {"loc": ["query"], "msg": "field required", "type": "value_error.missing"} + ] +} + +response_not_valid_int = { + "detail": [ + { + "loc": ["query", "query"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] +} + + +@pytest.mark.parametrize( + "path,expected_status,expected_response", + [ + ("/query", 422, response_missing), + ("/query?query=baz", 200, "foo bar baz"), + ("/query?not_declared=baz", 422, response_missing), + ("/query/optional", 200, "foo bar"), + ("/query/optional?query=baz", 200, "foo bar baz"), + ("/query/optional?not_declared=baz", 200, "foo bar"), + ("/query/int", 422, response_missing), + ("/query/int?query=42", 200, "foo bar 42"), + ("/query/int?query=42.5", 422, response_not_valid_int), + ("/query/int?query=baz", 422, response_not_valid_int), + ("/query/int?not_declared=baz", 422, response_missing), + ], +) +def test_get_path(path, expected_status, expected_response): + response = client.get(path) + assert response.status_code == expected_status + assert response.json() == expected_response