Browse Source

⬆ Upgrade Starlette to `0.21.0`, including the new [`TestClient` based on HTTPX](https://github.com/encode/starlette/releases/tag/0.21.0) (#5471)

Co-authored-by: Paweł Rubin <[email protected]>
Co-authored-by: Sebastián Ramírez <[email protected]>
pull/5628/head
Paweł Rubin 2 years ago
committed by GitHub
parent
commit
fdbd48be5f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      fastapi/security/api_key.py
  2. 8
      fastapi/security/http.py
  3. 6
      fastapi/security/oauth2.py
  4. 2
      fastapi/security/open_id_connect_url.py
  5. 6
      fastapi/security/utils.py
  6. 2
      pyproject.toml
  7. 2
      tests/test_enforce_once_required_parameter.py
  8. 2
      tests/test_extra_routes.py
  9. 2
      tests/test_get_request_body.py
  10. 9
      tests/test_param_include_in_schema.py
  11. 7
      tests/test_security_api_key_cookie.py
  12. 7
      tests/test_security_api_key_cookie_description.py
  13. 7
      tests/test_security_api_key_cookie_optional.py
  14. 8
      tests/test_tuples.py
  15. 2
      tests/test_tutorial/test_advanced_middleware/test_tutorial001.py
  16. 16
      tests/test_tutorial/test_body/test_tutorial001.py
  17. 16
      tests/test_tutorial/test_body/test_tutorial001_py310.py
  18. 5
      tests/test_tutorial/test_cookie_params/test_tutorial001.py
  19. 15
      tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py
  20. 2
      tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py
  21. 2
      tests/test_tutorial/test_custom_response/test_tutorial006.py
  22. 2
      tests/test_tutorial/test_custom_response/test_tutorial006b.py
  23. 2
      tests/test_tutorial/test_custom_response/test_tutorial006c.py
  24. 2
      tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py
  25. 6
      tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py
  26. 12
      tests/test_tutorial/test_websockets/test_tutorial002.py

2
fastapi/security/api_key.py

@ -54,7 +54,7 @@ class APIKeyHeader(APIKeyBase):
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.headers.get(self.model.name)
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(

8
fastapi/security/http.py

@ -38,7 +38,7 @@ class HTTPBase(SecurityBase):
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
@ -67,7 +67,7 @@ class HTTPBasic(HTTPBase):
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
@ -113,7 +113,7 @@ class HTTPBearer(HTTPBase):
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
@ -148,7 +148,7 @@ class HTTPDigest(HTTPBase):
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:

6
fastapi/security/oauth2.py

@ -126,7 +126,7 @@ class OAuth2(SecurityBase):
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
@ -157,7 +157,7 @@ class OAuth2PasswordBearer(OAuth2):
)
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
@ -200,7 +200,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
)
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:

2
fastapi/security/open_id_connect_url.py

@ -23,7 +23,7 @@ class OpenIdConnect(SecurityBase):
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(

6
fastapi/security/utils.py

@ -1,7 +1,9 @@
from typing import Tuple
from typing import Optional, Tuple
def get_authorization_scheme_param(authorization_header_value: str) -> Tuple[str, str]:
def get_authorization_scheme_param(
authorization_header_value: Optional[str],
) -> Tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")

2
pyproject.toml

@ -39,7 +39,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP",
]
dependencies = [
"starlette==0.20.4",
"starlette==0.21.0",
"pydantic >=1.6.2,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0",
]
dynamic = ["version"]

2
tests/test_enforce_once_required_parameter.py

@ -101,7 +101,7 @@ def test_schema():
def test_get_invalid():
response = client.get("/foo", params={"client_id": None})
response = client.get("/foo")
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

2
tests/test_extra_routes.py

@ -333,7 +333,7 @@ def test_get_api_route_not_decorated():
def test_delete():
response = client.delete("/items/foo", json={"name": "Foo"})
response = client.request("DELETE", "/items/foo", json={"name": "Foo"})
assert response.status_code == 200, response.text
assert response.json() == {"item_id": "foo", "item": {"name": "Foo", "price": None}}

2
tests/test_get_request_body.py

@ -104,5 +104,5 @@ def test_openapi_schema():
def test_get_with_body():
body = {"name": "Foo", "description": "Some description", "price": 5.5}
response = client.get("/product", json=body)
response = client.request("GET", "/product", json=body)
assert response.json() == body

9
tests/test_param_include_in_schema.py

@ -33,8 +33,6 @@ async def hidden_query(
return {"hidden_query": hidden_query}
client = TestClient(app)
openapi_shema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
@ -161,6 +159,7 @@ openapi_shema = {
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_shema
@ -184,7 +183,8 @@ def test_openapi_schema():
],
)
def test_hidden_cookie(path, cookies, expected_status, expected_response):
response = client.get(path, cookies=cookies)
client = TestClient(app, cookies=cookies)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response
@ -207,12 +207,14 @@ def test_hidden_cookie(path, cookies, expected_status, expected_response):
],
)
def test_hidden_header(path, headers, expected_status, expected_response):
client = TestClient(app)
response = client.get(path, headers=headers)
assert response.status_code == expected_status
assert response.json() == expected_response
def test_hidden_path():
client = TestClient(app)
response = client.get("/hidden_path/hidden_path")
assert response.status_code == 200
assert response.json() == {"hidden_path": "hidden_path"}
@ -234,6 +236,7 @@ def test_hidden_path():
],
)
def test_hidden_query(path, expected_status, expected_response):
client = TestClient(app)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response

7
tests/test_security_api_key_cookie.py

@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
@ -51,18 +49,21 @@ openapi_schema = {
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"}

7
tests/test_security_api_key_cookie_description.py

@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
@ -56,18 +54,21 @@ openapi_schema = {
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"}

7
tests/test_security_api_key_cookie_optional.py

@ -29,8 +29,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
@ -58,18 +56,21 @@ openapi_schema = {
def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema
def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}
def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"msg": "Create an account first"}

8
tests/test_tuples.py

@ -252,16 +252,14 @@ def test_tuple_with_model_invalid():
def test_tuple_form_valid():
response = client.post("/tuple-form/", data=[("values", "1"), ("values", "2")])
response = client.post("/tuple-form/", data={"values": ("1", "2")})
assert response.status_code == 200, response.text
assert response.json() == [1, 2]
def test_tuple_form_invalid():
response = client.post(
"/tuple-form/", data=[("values", "1"), ("values", "2"), ("values", "3")]
)
response = client.post("/tuple-form/", data={"values": ("1", "2", "3")})
assert response.status_code == 422, response.text
response = client.post("/tuple-form/", data=[("values", "1")])
response = client.post("/tuple-form/", data={"values": ("1")})
assert response.status_code == 422, response.text

2
tests/test_tutorial/test_advanced_middleware/test_tutorial001.py

@ -9,6 +9,6 @@ def test_middleware():
assert response.status_code == 200, response.text
client = TestClient(app)
response = client.get("/", allow_redirects=False)
response = client.get("/", follow_redirects=False)
assert response.status_code == 307, response.text
assert response.headers["location"] == "https://testserver/"

16
tests/test_tutorial/test_body/test_tutorial001.py

@ -176,7 +176,7 @@ def test_post_broken_body():
response = client.post(
"/items/",
headers={"content-type": "application/json"},
data="{some broken json}",
content="{some broken json}",
)
assert response.status_code == 422, response.text
assert response.json() == {
@ -214,7 +214,7 @@ def test_post_form_for_json():
def test_explicit_content_type():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200, response.text
@ -223,7 +223,7 @@ def test_explicit_content_type():
def test_geo_json():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/geo+json"},
)
assert response.status_code == 200, response.text
@ -232,7 +232,7 @@ def test_geo_json():
def test_no_content_type_is_json():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
)
assert response.status_code == 200, response.text
assert response.json() == {
@ -255,17 +255,19 @@ def test_wrong_headers():
]
}
response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"})
response = client.post(
"/items/", content=data, headers={"Content-Type": "text/plain"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
response = client.post(
"/items/", data=data, headers={"Content-Type": "application/geo+json-seq"}
"/items/", content=data, headers={"Content-Type": "application/geo+json-seq"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
response = client.post(
"/items/", data=data, headers={"Content-Type": "application/not-really-json"}
"/items/", content=data, headers={"Content-Type": "application/not-really-json"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict

16
tests/test_tutorial/test_body/test_tutorial001_py310.py

@ -185,7 +185,7 @@ def test_post_broken_body(client: TestClient):
response = client.post(
"/items/",
headers={"content-type": "application/json"},
data="{some broken json}",
content="{some broken json}",
)
assert response.status_code == 422, response.text
assert response.json() == {
@ -225,7 +225,7 @@ def test_post_form_for_json(client: TestClient):
def test_explicit_content_type(client: TestClient):
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200, response.text
@ -235,7 +235,7 @@ def test_explicit_content_type(client: TestClient):
def test_geo_json(client: TestClient):
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/geo+json"},
)
assert response.status_code == 200, response.text
@ -245,7 +245,7 @@ def test_geo_json(client: TestClient):
def test_no_content_type_is_json(client: TestClient):
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
)
assert response.status_code == 200, response.text
assert response.json() == {
@ -269,17 +269,19 @@ def test_wrong_headers(client: TestClient):
]
}
response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"})
response = client.post(
"/items/", content=data, headers={"Content-Type": "text/plain"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
response = client.post(
"/items/", data=data, headers={"Content-Type": "application/geo+json-seq"}
"/items/", content=data, headers={"Content-Type": "application/geo+json-seq"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
response = client.post(
"/items/", data=data, headers={"Content-Type": "application/not-really-json"}
"/items/", content=data, headers={"Content-Type": "application/not-really-json"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict

5
tests/test_tutorial/test_cookie_params/test_tutorial001.py

@ -3,8 +3,6 @@ from fastapi.testclient import TestClient
from docs_src.cookie_params.tutorial001 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
@ -88,6 +86,7 @@ openapi_schema = {
],
)
def test(path, cookies, expected_status, expected_response):
response = client.get(path, cookies=cookies)
client = TestClient(app, cookies=cookies)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response

15
tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py

@ -70,14 +70,6 @@ openapi_schema = {
}
@pytest.fixture(name="client")
def get_client():
from docs_src.cookie_params.tutorial001_py310 import app
client = TestClient(app)
return client
@needs_py310
@pytest.mark.parametrize(
"path,cookies,expected_status,expected_response",
@ -94,7 +86,10 @@ def get_client():
("/items", {"session": "cookiesession"}, 200, {"ads_id": None}),
],
)
def test(path, cookies, expected_status, expected_response, client: TestClient):
response = client.get(path, cookies=cookies)
def test(path, cookies, expected_status, expected_response):
from docs_src.cookie_params.tutorial001_py310 import app
client = TestClient(app, cookies=cookies)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response

2
tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py

@ -26,7 +26,7 @@ def test_gzip_request(compress):
data = gzip.compress(data)
headers["Content-Encoding"] = "gzip"
headers["Content-Type"] = "application/json"
response = client.post("/sum", data=data, headers=headers)
response = client.post("/sum", content=data, headers=headers)
assert response.json() == {"sum": n}

2
tests/test_tutorial/test_custom_response/test_tutorial006.py

@ -32,6 +32,6 @@ def test_openapi_schema():
def test_get():
response = client.get("/typer", allow_redirects=False)
response = client.get("/typer", follow_redirects=False)
assert response.status_code == 307, response.text
assert response.headers["location"] == "https://typer.tiangolo.com"

2
tests/test_tutorial/test_custom_response/test_tutorial006b.py

@ -27,6 +27,6 @@ def test_openapi_schema():
def test_redirect_response_class():
response = client.get("/fastapi", allow_redirects=False)
response = client.get("/fastapi", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://fastapi.tiangolo.com"

2
tests/test_tutorial/test_custom_response/test_tutorial006c.py

@ -27,6 +27,6 @@ def test_openapi_schema():
def test_redirect_status_code():
response = client.get("/pydantic", allow_redirects=False)
response = client.get("/pydantic", follow_redirects=False)
assert response.status_code == 302
assert response.headers["location"] == "https://pydantic-docs.helpmanual.io/"

2
tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py

@ -47,7 +47,7 @@ def test_openapi_schema():
def test_post():
response = client.post("/items/", data=b"this is actually not validated")
response = client.post("/items/", content=b"this is actually not validated")
assert response.status_code == 200, response.text
assert response.json() == {
"size": 30,

6
tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py

@ -58,7 +58,7 @@ def test_post():
- x-men
- x-avengers
"""
response = client.post("/items/", data=yaml_data)
response = client.post("/items/", content=yaml_data)
assert response.status_code == 200, response.text
assert response.json() == {
"name": "Deadpoolio",
@ -74,7 +74,7 @@ def test_post_broken_yaml():
x - x-men
x - x-avengers
"""
response = client.post("/items/", data=yaml_data)
response = client.post("/items/", content=yaml_data)
assert response.status_code == 422, response.text
assert response.json() == {"detail": "Invalid YAML"}
@ -88,7 +88,7 @@ def test_post_invalid():
- x-avengers
- sneaky: object
"""
response = client.post("/items/", data=yaml_data)
response = client.post("/items/", content=yaml_data)
assert response.status_code == 422, response.text
assert response.json() == {
"detail": [

12
tests/test_tutorial/test_websockets/test_tutorial002.py

@ -4,20 +4,18 @@ from fastapi.websockets import WebSocketDisconnect
from docs_src.websockets.tutorial002 import app
client = TestClient(app)
def test_main():
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200, response.text
assert b"<!DOCTYPE html>" in response.content
def test_websocket_with_cookie():
client = TestClient(app, cookies={"session": "fakesession"})
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(
"/items/foo/ws", cookies={"session": "fakesession"}
) as websocket:
with client.websocket_connect("/items/foo/ws") as websocket:
message = "Message one"
websocket.send_text(message)
data = websocket.receive_text()
@ -33,6 +31,7 @@ def test_websocket_with_cookie():
def test_websocket_with_header():
client = TestClient(app)
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/items/bar/ws?token=some-token") as websocket:
message = "Message one"
@ -50,6 +49,7 @@ def test_websocket_with_header():
def test_websocket_with_header_and_query():
client = TestClient(app)
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/items/2/ws?q=3&token=some-token") as websocket:
message = "Message one"
@ -71,6 +71,7 @@ def test_websocket_with_header_and_query():
def test_websocket_no_credentials():
client = TestClient(app)
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/items/foo/ws"):
pytest.fail(
@ -79,6 +80,7 @@ def test_websocket_no_credentials():
def test_websocket_invalid_data():
client = TestClient(app)
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/items/foo/ws?q=bar&token=some-token"):
pytest.fail(

Loading…
Cancel
Save