Browse Source

🔧 Add ty configs to check docs sources (#15770)

pull/15771/head
Sebastián Ramírez 2 weeks ago
committed by GitHub
parent
commit
b7de2b7feb
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 2
      .pre-commit-config.yaml
  2. 2
      scripts/contributors.py
  3. 2
      scripts/deploy_docs_status.py
  4. 4
      scripts/doc_parsing_utils.py
  5. 4
      scripts/docs.py
  6. 2
      scripts/label_approved.py
  7. 2
      scripts/lint.sh
  8. 3
      scripts/notify_translations.py
  9. 2
      scripts/people.py
  10. 2
      scripts/sponsors.py
  11. 2
      scripts/topic_repos.py
  12. 18
      tests/test_compat.py
  13. 4
      tests/test_custom_middleware_exception.py
  14. 7
      tests/test_datastructures.py
  15. 2
      tests/test_default_response_class.py
  16. 10
      tests/test_deprecated_responses.py
  17. 2
      tests/test_inherited_custom_class.py
  18. 10
      tests/test_jsonable_encoder.py
  19. 20
      tests/test_local_docs.py
  20. 2
      tests/test_openapi_schema_type.py
  21. 4
      tests/test_orjson_response_class.py
  22. 26
      tests/test_response_model_as_return_annotation.py
  23. 18
      tests/test_router_events.py
  24. 2
      tests/test_schema_compat_pydantic_v2.py
  25. 4
      tests/test_serialize_response_model.py
  26. 2
      tests/test_skip_defaults.py
  27. 4
      tests/test_sse.py
  28. 4
      tests/test_starlette_urlconvertors.py
  29. 9
      tests/test_stream_cancellation.py
  30. 6
      tests/test_swagger_ui_escape.py
  31. 2
      tests/test_tutorial/test_body/test_tutorial001.py
  32. 3
      tests/test_tutorial/test_body_nested_models/test_tutorial001_tutorial002_tutorial003.py
  33. 3
      tests/test_tutorial/test_custom_response/test_tutorial008.py
  34. 3
      tests/test_tutorial/test_custom_response/test_tutorial009.py
  35. 3
      tests/test_tutorial/test_custom_response/test_tutorial009b.py
  36. 10
      tests/test_tutorial/test_debugging/test_tutorial001.py
  37. 5
      tests/test_tutorial/test_openapi_webhooks/test_tutorial001.py
  38. 2
      tests/test_tutorial/test_python_types/test_tutorial003.py
  39. 8
      tests/test_tutorial/test_python_types/test_tutorial005.py
  40. 8
      tests/test_tutorial/test_security/test_tutorial005.py
  41. 14
      tests/test_tutorial/test_sql_databases/test_tutorial001.py
  42. 14
      tests/test_tutorial/test_sql_databases/test_tutorial002.py
  43. 5
      tests/test_webhooks_security.py
  44. 6
      tests/utils.py

2
.pre-commit-config.yaml

@ -45,7 +45,7 @@ repos:
- id: local-ty - id: local-ty
name: ty check name: ty check
entry: uv run ty check fastapi docs_src --force-exclude entry: uv run ty check
require_serial: true require_serial: true
language: unsupported language: unsupported
pass_filenames: false pass_filenames: false

2
scripts/contributors.py

@ -237,7 +237,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None: def main() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}") logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.github_token.get_secret_value()) g = Github(settings.github_token.get_secret_value())
repo = g.get_repo(settings.github_repository) repo = g.get_repo(settings.github_repository)

2
scripts/deploy_docs_status.py

@ -24,7 +24,7 @@ class LinkData(BaseModel):
def main() -> None: def main() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}") logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(auth=Auth.Token(settings.github_token.get_secret_value())) g = Github(auth=Auth.Token(settings.github_token.get_secret_value()))

4
scripts/doc_parsing_utils.py

@ -625,14 +625,14 @@ def replace_multiline_code_block(
_line_b_code, line_b_comment = _split_hash_comment(line_b) _line_b_code, line_b_comment = _split_hash_comment(line_b)
res_line = line_b res_line = line_b
if line_b_comment: if line_b_comment:
res_line = res_line.replace(line_b_comment, line_a_comment, 1) res_line = res_line.replace(line_b_comment, line_a_comment or "", 1)
code_block.append(res_line) code_block.append(res_line)
elif block_language in {"console", "json", "slash-style-comments"}: elif block_language in {"console", "json", "slash-style-comments"}:
_line_a_code, line_a_comment = _split_slashes_comment(line_a) _line_a_code, line_a_comment = _split_slashes_comment(line_a)
_line_b_code, line_b_comment = _split_slashes_comment(line_b) _line_b_code, line_b_comment = _split_slashes_comment(line_b)
res_line = line_b res_line = line_b
if line_b_comment: if line_b_comment:
res_line = res_line.replace(line_b_comment, line_a_comment, 1) res_line = res_line.replace(line_b_comment, line_a_comment or "", 1)
code_block.append(res_line) code_block.append(res_line)
else: else:
code_block.append(line_b) code_block.append(line_b)

4
scripts/docs.py

@ -155,7 +155,7 @@ def build_lang(
""" """
build_zensical_lang_to_stage(lang) build_zensical_lang_to_stage(lang)
copy_zensical_stage_to_site(lang) copy_zensical_stage_to_site(lang)
typer.secho(f"Successfully built docs for: {lang}", color=typer.colors.GREEN) typer.secho(f"Successfully built docs for: {lang}", fg=typer.colors.GREEN)
def split_markdown_header(markdown: str) -> tuple[str, str]: def split_markdown_header(markdown: str) -> tuple[str, str]:
@ -408,7 +408,7 @@ def build_all() -> None:
for lang in langs: for lang in langs:
if lang != "en": if lang != "en":
copy_zensical_stage_to_site(lang) copy_zensical_stage_to_site(lang)
typer.secho("Successfully built all docs", color=typer.colors.GREEN) typer.secho("Successfully built all docs", fg=typer.colors.GREEN)
@app.command() @app.command()

2
scripts/label_approved.py

@ -22,7 +22,7 @@ class Settings(BaseSettings):
config: dict[str, LabelSettings] | Literal[""] = default_config config: dict[str, LabelSettings] | Literal[""] = default_config
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
if settings.debug: if settings.debug:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
else: else:

2
scripts/lint.sh

@ -4,6 +4,6 @@ set -e
set -x set -x
mypy fastapi mypy fastapi
ty check fastapi docs_src --force-exclude ty check
ruff check fastapi tests docs_src scripts ruff check fastapi tests docs_src scripts
ruff format fastapi tests --check ruff format fastapi tests --check

3
scripts/notify_translations.py

@ -304,7 +304,7 @@ def update_comment(*, settings: Settings, comment_id: str, body: str) -> Comment
def main() -> None: def main() -> None:
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
if settings.debug: if settings.debug:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
else: else:
@ -324,6 +324,7 @@ def main() -> None:
) or settings.number ) or settings.number
if number is None: if number is None:
raise RuntimeError("No PR number available") raise RuntimeError("No PR number available")
number = cast(int, number)
# Avoid race conditions with multiple labels # Avoid race conditions with multiple labels
sleep_time = random.random() * 10 # random number between 0 and 10 seconds sleep_time = random.random() * 10 # random number between 0 and 10 seconds

2
scripts/people.py

@ -394,7 +394,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None: def main() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}") logging.info(f"Using config: {settings.model_dump_json()}")
rate_limiter.speed_multiplier = settings.speed_multiplier rate_limiter.speed_multiplier = settings.speed_multiplier
g = Github(settings.github_token.get_secret_value()) g = Github(settings.github_token.get_secret_value())

2
scripts/sponsors.py

@ -158,7 +158,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None: def main() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}") logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.pr_token.get_secret_value()) g = Github(settings.pr_token.get_secret_value())
repo = g.get_repo(settings.github_repository) repo = g.get_repo(settings.github_repository)

2
scripts/topic_repos.py

@ -24,7 +24,7 @@ class Repo(BaseModel):
def main() -> None: def main() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
settings = Settings() settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}") logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.github_token.get_secret_value(), per_page=100) g = Github(settings.github_token.get_secret_value(), per_page=100)

18
tests/test_compat.py

@ -1,3 +1,5 @@
from typing import Any, cast
from fastapi import FastAPI, UploadFile from fastapi import FastAPI, UploadFile
from fastapi._compat import ( from fastapi._compat import (
Undefined, Undefined,
@ -56,9 +58,15 @@ def test_propagates_pydantic2_model_config():
@app.post("/") @app.post("/")
def foo(req: Model) -> dict[str, str | None]: def foo(req: Model) -> dict[str, str | None]:
value = req.value
if isinstance(value, Missing):
value = None
embedded_value = req.embedded_model.value
if isinstance(embedded_value, Missing):
embedded_value = None
return { return {
"value": req.value or None, "value": value,
"embedded_value": req.embedded_model.value or None, "embedded_value": embedded_value,
} }
client = TestClient(app) client = TestClient(app)
@ -100,7 +108,7 @@ def test_serialize_sequence_value_with_optional_list():
"""Test that serialize_sequence_value handles optional lists correctly.""" """Test that serialize_sequence_value handles optional lists correctly."""
from fastapi._compat import v2 from fastapi._compat import v2
field_info = FieldInfo(annotation=list[str] | None) field_info = FieldInfo(annotation=cast(Any, list[str] | None))
field = v2.ModelField(name="items", field_info=field_info) field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"]) result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"])
assert result == ["a", "b", "c"] assert result == ["a", "b", "c"]
@ -111,7 +119,7 @@ def test_serialize_sequence_value_with_optional_list_pipe_union():
"""Test that serialize_sequence_value handles optional lists correctly (with new syntax).""" """Test that serialize_sequence_value handles optional lists correctly (with new syntax)."""
from fastapi._compat import v2 from fastapi._compat import v2
field_info = FieldInfo(annotation=list[str] | None) field_info = FieldInfo(annotation=cast(Any, list[str] | None))
field = v2.ModelField(name="items", field_info=field_info) field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"]) result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"])
assert result == ["a", "b", "c"] assert result == ["a", "b", "c"]
@ -125,7 +133,7 @@ def test_serialize_sequence_value_with_none_first_in_union():
from fastapi._compat import v2 from fastapi._compat import v2
# Use Union[None, list[str]] to ensure None comes first in the union args # Use Union[None, list[str]] to ensure None comes first in the union args
field_info = FieldInfo(annotation=Union[None, list[str]]) # noqa: UP007 field_info = FieldInfo(annotation=cast(Any, Union[None, list[str]])) # noqa: UP007
field = v2.ModelField(name="items", field_info=field_info) field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["x", "y"]) result = v2.serialize_sequence_value(field=field, value=["x", "y"])
assert result == ["x", "y"] assert result == ["x", "y"]

4
tests/test_custom_middleware_exception.py

@ -3,6 +3,7 @@ from pathlib import Path
from fastapi import APIRouter, FastAPI, File, UploadFile from fastapi import APIRouter, FastAPI, File, UploadFile
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.types import ASGIApp
app = FastAPI() app = FastAPI()
@ -16,7 +17,7 @@ class ContentSizeLimitMiddleware:
max_content_size (optional): the maximum content size allowed in bytes, None for no limit max_content_size (optional): the maximum content size allowed in bytes, None for no limit
""" """
def __init__(self, app: APIRouter, max_content_size: int | None = None): def __init__(self, app: ASGIApp, max_content_size: int | None = None):
self.app = app self.app = app
self.max_content_size = max_content_size self.max_content_size = max_content_size
@ -31,6 +32,7 @@ class ContentSizeLimitMiddleware:
body_len = len(message.get("body", b"")) body_len = len(message.get("body", b""))
received += body_len received += body_len
assert self.max_content_size is not None
if received > self.max_content_size: if received > self.max_content_size:
raise HTTPException( raise HTTPException(
422, 422,

7
tests/test_datastructures.py

@ -1,9 +1,10 @@
import io import io
from pathlib import Path from pathlib import Path
from typing import cast
import pytest import pytest
from fastapi import FastAPI, UploadFile from fastapi import FastAPI, UploadFile
from fastapi.datastructures import Default from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -13,8 +14,8 @@ def test_upload_file_invalid_pydantic_v2():
def test_default_placeholder_equals(): def test_default_placeholder_equals():
placeholder_1 = Default("a") placeholder_1 = cast(DefaultPlaceholder, Default("a"))
placeholder_2 = Default("a") placeholder_2 = cast(DefaultPlaceholder, Default("a"))
assert placeholder_1 == placeholder_2 assert placeholder_1 == placeholder_2
assert placeholder_1.value == placeholder_2.value assert placeholder_1.value == placeholder_2.value

2
tests/test_default_response_class.py

@ -11,7 +11,7 @@ class ORJSONResponse(JSONResponse):
media_type = "application/x-orjson" media_type = "application/x-orjson"
def render(self, content: Any) -> bytes: def render(self, content: Any) -> bytes:
import orjson import orjson # ty: ignore[unresolved-import]
return orjson.dumps(content) return orjson.dumps(content)

10
tests/test_deprecated_responses.py

@ -3,7 +3,7 @@ import warnings
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.responses import ORJSONResponse, UJSONResponse from fastapi.responses import ORJSONResponse, UJSONResponse # ty: ignore[deprecated]
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel
@ -21,7 +21,7 @@ class Item(BaseModel):
def _make_orjson_app() -> FastAPI: def _make_orjson_app() -> FastAPI:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning) warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=ORJSONResponse) app = FastAPI(default_response_class=ORJSONResponse) # ty: ignore[deprecated]
@app.get("/items") @app.get("/items")
def get_items() -> Item: def get_items() -> Item:
@ -44,7 +44,7 @@ def test_orjson_response_returns_correct_data():
@needs_orjson @needs_orjson
def test_orjson_response_emits_deprecation_warning(): def test_orjson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="ORJSONResponse is deprecated"): with pytest.warns(FastAPIDeprecationWarning, match="ORJSONResponse is deprecated"):
ORJSONResponse(content={"hello": "world"}) ORJSONResponse(content={"hello": "world"}) # ty: ignore[deprecated]
# UJSON # UJSON
@ -53,7 +53,7 @@ def test_orjson_response_emits_deprecation_warning():
def _make_ujson_app() -> FastAPI: def _make_ujson_app() -> FastAPI:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning) warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=UJSONResponse) app = FastAPI(default_response_class=UJSONResponse) # ty: ignore[deprecated]
@app.get("/items") @app.get("/items")
def get_items() -> Item: def get_items() -> Item:
@ -76,4 +76,4 @@ def test_ujson_response_returns_correct_data():
@needs_ujson @needs_ujson
def test_ujson_response_emits_deprecation_warning(): def test_ujson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"): with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"):
UJSONResponse(content={"hello": "world"}) UJSONResponse(content={"hello": "world"}) # ty: ignore[deprecated]

2
tests/test_inherited_custom_class.py

@ -13,7 +13,7 @@ class MyUuid:
def __str__(self): def __str__(self):
return self.uuid return self.uuid
@property # type: ignore @property
def __class__(self): def __class__(self):
return uuid.UUID return uuid.UUID

10
tests/test_jsonable_encoder.py

@ -87,10 +87,10 @@ def test_encode_dict():
def test_encode_dict_include_exclude_list(): def test_encode_dict_include_exclude_list():
pet = {"name": "Firulais", "owner": {"name": "Foo"}} pet = {"name": "Firulais", "owner": {"name": "Foo"}}
assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}} assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}}
assert jsonable_encoder(pet, include=["name"]) == {"name": "Firulais"} assert jsonable_encoder(pet, include=["name"]) == {"name": "Firulais"} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, exclude=["owner"]) == {"name": "Firulais"} assert jsonable_encoder(pet, exclude=["owner"]) == {"name": "Firulais"} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, include=[]) == {} assert jsonable_encoder(pet, include=[]) == {} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, exclude=[]) == { assert jsonable_encoder(pet, exclude=[]) == { # ty: ignore[invalid-argument-type]
"name": "Firulais", "name": "Firulais",
"owner": {"name": "Foo"}, "owner": {"name": "Foo"},
} }
@ -176,7 +176,7 @@ def test_encode_model_with_config():
def test_encode_model_with_alias_raises(): def test_encode_model_with_alias_raises():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
ModelWithAlias(foo="Bar") ModelWithAlias(foo="Bar") # ty: ignore[missing-argument, unknown-argument]
def test_encode_model_with_alias(): def test_encode_model_with_alias():

20
tests/test_local_docs.py

@ -9,7 +9,7 @@ def test_strings_in_generated_swagger():
swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore
html = get_swagger_ui_html(openapi_url="/docs", title="title") html = get_swagger_ui_html(openapi_url="/docs", title="title")
body_content = html.body.decode() body_content = bytes(html.body).decode()
assert swagger_js_url in body_content assert swagger_js_url in body_content
assert swagger_css_url in body_content assert swagger_css_url in body_content
assert swagger_favicon_url in body_content assert swagger_favicon_url in body_content
@ -26,7 +26,7 @@ def test_strings_in_custom_swagger():
swagger_css_url=swagger_css_url, swagger_css_url=swagger_css_url,
swagger_favicon_url=swagger_favicon_url, swagger_favicon_url=swagger_favicon_url,
) )
body_content = html.body.decode() body_content = bytes(html.body).decode()
assert swagger_js_url in body_content assert swagger_js_url in body_content
assert swagger_css_url in body_content assert swagger_css_url in body_content
assert swagger_favicon_url in body_content assert swagger_favicon_url in body_content
@ -37,7 +37,7 @@ def test_strings_in_generated_redoc():
redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore
html = get_redoc_html(openapi_url="/docs", title="title") html = get_redoc_html(openapi_url="/docs", title="title")
body_content = html.body.decode() body_content = bytes(html.body).decode()
assert redoc_js_url in body_content assert redoc_js_url in body_content
assert redoc_favicon_url in body_content assert redoc_favicon_url in body_content
@ -51,17 +51,17 @@ def test_strings_in_custom_redoc():
redoc_js_url=redoc_js_url, redoc_js_url=redoc_js_url,
redoc_favicon_url=redoc_favicon_url, redoc_favicon_url=redoc_favicon_url,
) )
body_content = html.body.decode() body_content = bytes(html.body).decode()
assert redoc_js_url in body_content assert redoc_js_url in body_content
assert redoc_favicon_url in body_content assert redoc_favicon_url in body_content
def test_google_fonts_in_generated_redoc(): def test_google_fonts_in_generated_redoc():
body_with_google_fonts = get_redoc_html( body_with_google_fonts = bytes(
openapi_url="/docs", title="title" get_redoc_html(openapi_url="/docs", title="title").body
).body.decode() ).decode()
assert "fonts.googleapis.com" in body_with_google_fonts assert "fonts.googleapis.com" in body_with_google_fonts
body_without_google_fonts = get_redoc_html( body_without_google_fonts = bytes(
openapi_url="/docs", title="title", with_google_fonts=False get_redoc_html(openapi_url="/docs", title="title", with_google_fonts=False).body
).body.decode() ).decode()
assert "fonts.googleapis.com" not in body_without_google_fonts assert "fonts.googleapis.com" not in body_without_google_fonts

2
tests/test_openapi_schema_type.py

@ -21,4 +21,4 @@ def test_allowed_schema_type(
def test_invalid_type_value() -> None: def test_invalid_type_value() -> None:
"""Test that Schema raises ValueError for invalid type values.""" """Test that Schema raises ValueError for invalid type values."""
with pytest.raises(ValueError, match="2 validation errors for Schema"): with pytest.raises(ValueError, match="2 validation errors for Schema"):
Schema(type=True) # type: ignore[arg-type] Schema(type=True) # type: ignore[arg-type] # ty: ignore[invalid-argument-type]

4
tests/test_orjson_response_class.py

@ -6,13 +6,13 @@ pytest.importorskip("orjson")
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse # ty: ignore[deprecated]
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import quoted_name
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning) warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=ORJSONResponse) app = FastAPI(default_response_class=ORJSONResponse) # ty: ignore[deprecated]
@app.get("/orjson_non_str_keys") @app.get("/orjson_non_str_keys")

26
tests/test_response_model_as_return_annotation.py

@ -78,22 +78,22 @@ def no_response_model_annotation_return_same_model() -> User:
@app.get("/no_response_model-annotation-return_exact_dict") @app.get("/no_response_model-annotation-return_exact_dict")
def no_response_model_annotation_return_exact_dict() -> User: def no_response_model_annotation_return_exact_dict() -> User:
return {"name": "John", "surname": "Doe"} return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_invalid_dict") @app.get("/no_response_model-annotation-return_invalid_dict")
def no_response_model_annotation_return_invalid_dict() -> User: def no_response_model_annotation_return_invalid_dict() -> User:
return {"name": "John"} return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_invalid_model") @app.get("/no_response_model-annotation-return_invalid_model")
def no_response_model_annotation_return_invalid_model() -> User: def no_response_model_annotation_return_invalid_model() -> User:
return Item(name="Foo", price=42.0) return Item(name="Foo", price=42.0) # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_dict_with_extra_data") @app.get("/no_response_model-annotation-return_dict_with_extra_data")
def no_response_model_annotation_return_dict_with_extra_data() -> User: def no_response_model_annotation_return_dict_with_extra_data() -> User:
return {"name": "John", "surname": "Doe", "password_hash": "secret"} return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_submodel_with_extra_data") @app.get("/no_response_model-annotation-return_submodel_with_extra_data")
@ -108,24 +108,24 @@ def response_model_none_annotation_return_same_model() -> User:
@app.get("/response_model_none-annotation-return_exact_dict", response_model=None) @app.get("/response_model_none-annotation-return_exact_dict", response_model=None)
def response_model_none_annotation_return_exact_dict() -> User: def response_model_none_annotation_return_exact_dict() -> User:
return {"name": "John", "surname": "Doe"} return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get("/response_model_none-annotation-return_invalid_dict", response_model=None) @app.get("/response_model_none-annotation-return_invalid_dict", response_model=None)
def response_model_none_annotation_return_invalid_dict() -> User: def response_model_none_annotation_return_invalid_dict() -> User:
return {"name": "John"} return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get("/response_model_none-annotation-return_invalid_model", response_model=None) @app.get("/response_model_none-annotation-return_invalid_model", response_model=None)
def response_model_none_annotation_return_invalid_model() -> User: def response_model_none_annotation_return_invalid_model() -> User:
return Item(name="Foo", price=42.0) return Item(name="Foo", price=42.0) # ty: ignore[invalid-return-type]
@app.get( @app.get(
"/response_model_none-annotation-return_dict_with_extra_data", response_model=None "/response_model_none-annotation-return_dict_with_extra_data", response_model=None
) )
def response_model_none_annotation_return_dict_with_extra_data() -> User: def response_model_none_annotation_return_dict_with_extra_data() -> User:
return {"name": "John", "surname": "Doe", "password_hash": "secret"} return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get( @app.get(
@ -140,21 +140,21 @@ def response_model_none_annotation_return_submodel_with_extra_data() -> User:
"/response_model_model1-annotation_model2-return_same_model", response_model=User "/response_model_model1-annotation_model2-return_same_model", response_model=User
) )
def response_model_model1_annotation_model2_return_same_model() -> Item: def response_model_model1_annotation_model2_return_same_model() -> Item:
return User(name="John", surname="Doe") return User(name="John", surname="Doe") # ty: ignore[invalid-return-type]
@app.get( @app.get(
"/response_model_model1-annotation_model2-return_exact_dict", response_model=User "/response_model_model1-annotation_model2-return_exact_dict", response_model=User
) )
def response_model_model1_annotation_model2_return_exact_dict() -> Item: def response_model_model1_annotation_model2_return_exact_dict() -> Item:
return {"name": "John", "surname": "Doe"} return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get( @app.get(
"/response_model_model1-annotation_model2-return_invalid_dict", response_model=User "/response_model_model1-annotation_model2-return_invalid_dict", response_model=User
) )
def response_model_model1_annotation_model2_return_invalid_dict() -> Item: def response_model_model1_annotation_model2_return_invalid_dict() -> Item:
return {"name": "John"} return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get( @app.get(
@ -169,7 +169,7 @@ def response_model_model1_annotation_model2_return_invalid_model() -> Item:
response_model=User, response_model=User,
) )
def response_model_model1_annotation_model2_return_dict_with_extra_data() -> Item: def response_model_model1_annotation_model2_return_dict_with_extra_data() -> Item:
return {"name": "John", "surname": "Doe", "password_hash": "secret"} return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get( @app.get(
@ -177,7 +177,7 @@ def response_model_model1_annotation_model2_return_dict_with_extra_data() -> Ite
response_model=User, response_model=User,
) )
def response_model_model1_annotation_model2_return_submodel_with_extra_data() -> Item: def response_model_model1_annotation_model2_return_submodel_with_extra_data() -> Item:
return DBUser(name="John", surname="Doe", password_hash="secret") return DBUser(name="John", surname="Doe", password_hash="secret") # ty: ignore[invalid-return-type]
@app.get( @app.get(

18
tests/test_router_events.py

@ -31,31 +31,31 @@ def test_router_events(state: State) -> None:
def main() -> dict[str, str]: def main() -> dict[str, str]:
return {"message": "Hello World"} return {"message": "Hello World"}
@app.on_event("startup") @app.on_event("startup") # ty: ignore[deprecated]
def app_startup() -> None: def app_startup() -> None:
state.app_startup = True state.app_startup = True
@app.on_event("shutdown") @app.on_event("shutdown") # ty: ignore[deprecated]
def app_shutdown() -> None: def app_shutdown() -> None:
state.app_shutdown = True state.app_shutdown = True
router = APIRouter() router = APIRouter()
@router.on_event("startup") @router.on_event("startup") # ty: ignore[deprecated]
def router_startup() -> None: def router_startup() -> None:
state.router_startup = True state.router_startup = True
@router.on_event("shutdown") @router.on_event("shutdown") # ty: ignore[deprecated]
def router_shutdown() -> None: def router_shutdown() -> None:
state.router_shutdown = True state.router_shutdown = True
sub_router = APIRouter() sub_router = APIRouter()
@sub_router.on_event("startup") @sub_router.on_event("startup") # ty: ignore[deprecated]
def sub_router_startup() -> None: def sub_router_startup() -> None:
state.sub_router_startup = True state.sub_router_startup = True
@sub_router.on_event("shutdown") @sub_router.on_event("shutdown") # ty: ignore[deprecated]
def sub_router_shutdown() -> None: def sub_router_shutdown() -> None:
state.sub_router_shutdown = True state.sub_router_shutdown = True
@ -253,7 +253,7 @@ def test_router_async_shutdown_handler(state: State) -> None:
def main() -> dict[str, str]: def main() -> dict[str, str]:
return {"message": "Hello World"} return {"message": "Hello World"}
@app.on_event("shutdown") @app.on_event("shutdown") # ty: ignore[deprecated]
async def app_shutdown() -> None: async def app_shutdown() -> None:
state.app_shutdown = True state.app_shutdown = True
@ -274,7 +274,7 @@ def test_router_sync_generator_lifespan(state: State) -> None:
yield yield
state.app_shutdown = True state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] app = FastAPI(lifespan=lifespan) # type: ignore[invalid-argument-type] # ty: ignore[invalid-argument-type]
@app.get("/") @app.get("/")
def main() -> dict[str, str]: def main() -> dict[str, str]:
@ -300,7 +300,7 @@ def test_router_async_generator_lifespan(state: State) -> None:
yield yield
state.app_shutdown = True state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] app = FastAPI(lifespan=lifespan) # type: ignore[invalid-argument-type] # ty: ignore[invalid-argument-type]
@app.get("/") @app.get("/")
def main() -> dict[str, str]: def main() -> dict[str, str]:

2
tests/test_schema_compat_pydantic_v2.py

@ -26,7 +26,7 @@ def get_client():
@app.get("/users") @app.get("/users")
async def get_user() -> User: async def get_user() -> User:
return {"username": "alice", "role": "admin"} return {"username": "alice", "role": "admin"} # ty: ignore[invalid-return-type]
client = TestClient(app) client = TestClient(app)
return client return client

4
tests/test_serialize_response_model.py

@ -18,7 +18,7 @@ def get_valid():
@app.get("/items/coerce", response_model=Item) @app.get("/items/coerce", response_model=Item)
def get_coerce(): def get_coerce():
return Item(aliased_name="coerce", price="1.0") return Item(aliased_name="coerce", price="1.0") # ty: ignore[invalid-argument-type]
@app.get("/items/validlist", response_model=list[Item]) @app.get("/items/validlist", response_model=list[Item])
@ -52,7 +52,7 @@ def get_valid_exclude_unset():
response_model_exclude_unset=True, response_model_exclude_unset=True,
) )
def get_coerce_exclude_unset(): def get_coerce_exclude_unset():
return Item(aliased_name="coerce", price="1.0") return Item(aliased_name="coerce", price="1.0") # ty: ignore[invalid-argument-type]
@app.get( @app.get(

2
tests/test_skip_defaults.py

@ -29,7 +29,7 @@ class ModelDefaults(BaseModel):
@app.get("/", response_model=Model, response_model_exclude_unset=True) @app.get("/", response_model=Model, response_model_exclude_unset=True)
def get_root() -> ModelSubclass: def get_root() -> ModelSubclass:
return ModelSubclass(sub={}, y=1, z=0) return ModelSubclass(sub={}, y=1, z=0) # ty: ignore[invalid-argument-type]
@app.get( @app.get(

4
tests/test_sse.py

@ -227,7 +227,7 @@ def test_server_sent_event_single_line_fields_reject_newlines(
field_name: str, value: str field_name: str, value: str
): ):
with pytest.raises(ValueError, match=f"SSE '{field_name}' must be a single line"): with pytest.raises(ValueError, match=f"SSE '{field_name}' must be a single line"):
ServerSentEvent(data="test", **{field_name: value}) ServerSentEvent(data="test", **{field_name: value}) # ty: ignore[invalid-argument-type]
def test_server_sent_event_negative_retry_rejected(): def test_server_sent_event_negative_retry_rejected():
@ -237,7 +237,7 @@ def test_server_sent_event_negative_retry_rejected():
def test_server_sent_event_float_retry_rejected(): def test_server_sent_event_float_retry_rejected():
with pytest.raises(ValueError): with pytest.raises(ValueError):
ServerSentEvent(data="test", retry=1.5) # type: ignore[arg-type] ServerSentEvent(data="test", retry=1.5) # type: ignore[arg-type] # ty: ignore[invalid-argument-type]
def test_raw_data_sent_without_json_encoding(client: TestClient): def test_raw_data_sent_without_json_encoding(client: TestClient):

4
tests/test_starlette_urlconvertors.py

@ -32,7 +32,7 @@ def test_route_converters_int():
response = client.get("/int/5") response = client.get("/int/5")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"int": 5} assert response.json() == {"int": 5}
assert app.url_path_for("int_convertor", param=5) == "/int/5" # type: ignore assert app.url_path_for("int_convertor", param=5) == "/int/5"
def test_route_converters_float(): def test_route_converters_float():
@ -40,7 +40,7 @@ def test_route_converters_float():
response = client.get("/float/25.5") response = client.get("/float/25.5")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"float": 25.5} assert response.json() == {"float": 25.5}
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" # type: ignore assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"
def test_route_converters_path(): def test_route_converters_path():

9
tests/test_stream_cancellation.py

@ -10,6 +10,7 @@ import anyio
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from starlette.types import Message, Scope
pytestmark = [ pytestmark = [
pytest.mark.anyio, pytest.mark.anyio,
@ -45,16 +46,16 @@ async def _run_asgi_and_cancel(app: FastAPI, path: str, timeout: float) -> bool:
""" """
chunks: list[bytes] = [] chunks: list[bytes] = []
async def receive(): # type: ignore[no-untyped-def] async def receive() -> Message:
# Simulate a client that never disconnects, rely on cancellation # Simulate a client that never disconnects, rely on cancellation
await anyio.sleep(float("inf")) await anyio.sleep(float("inf"))
return {"type": "http.disconnect"} # pragma: no cover return {"type": "http.disconnect"} # pragma: no cover
async def send(message: dict) -> None: # type: ignore[type-arg] async def send(message: Message) -> None:
if message["type"] == "http.response.body": if message["type"] == "http.response.body":
chunks.append(message.get("body", b"")) chunks.append(message.get("body", b""))
scope = { scope: Scope = {
"type": "http", "type": "http",
"asgi": {"version": "3.0", "spec_version": "2.0"}, "asgi": {"version": "3.0", "spec_version": "2.0"},
"http_version": "1.1", "http_version": "1.1",
@ -67,7 +68,7 @@ async def _run_asgi_and_cancel(app: FastAPI, path: str, timeout: float) -> bool:
} }
with anyio.move_on_after(timeout) as cancel_scope: with anyio.move_on_after(timeout) as cancel_scope:
await app(scope, receive, send) # type: ignore[arg-type] await app(scope, receive, send)
# If we got here within the timeout the generator was cancellable. # If we got here within the timeout the generator was cancellable.
# cancel_scope.cancelled_caught is True when move_on_after fired. # cancel_scope.cancelled_caught is True when move_on_after fired.

6
tests/test_swagger_ui_escape.py

@ -8,7 +8,7 @@ def test_init_oauth_html_chars_are_escaped():
title="Test", title="Test",
init_oauth={"appName": xss_payload}, init_oauth={"appName": xss_payload},
) )
body = html.body.decode() body = bytes(html.body).decode()
assert "</script><script>" not in body assert "</script><script>" not in body
assert "\\u003c/script\\u003e\\u003cscript\\u003e" in body assert "\\u003c/script\\u003e\\u003cscript\\u003e" in body
@ -20,7 +20,7 @@ def test_swagger_ui_parameters_html_chars_are_escaped():
title="Test", title="Test",
swagger_ui_parameters={"customKey": "<img src=x onerror=alert(1)>"}, swagger_ui_parameters={"customKey": "<img src=x onerror=alert(1)>"},
) )
body = html.body.decode() body = bytes(html.body).decode()
assert "<img src=x onerror=alert(1)>" not in body assert "<img src=x onerror=alert(1)>" not in body
assert "\\u003cimg" in body assert "\\u003cimg" in body
@ -31,7 +31,7 @@ def test_normal_init_oauth_still_works():
title="Test", title="Test",
init_oauth={"clientId": "my-client", "appName": "My App"}, init_oauth={"clientId": "my-client", "appName": "My App"},
) )
body = html.body.decode() body = bytes(html.body).decode()
assert '"clientId": "my-client"' in body assert '"clientId": "my-client"' in body
assert '"appName": "My App"' in body assert '"appName": "My App"' in body
assert "ui.initOAuth" in body assert "ui.initOAuth" in body

2
tests/test_tutorial/test_body/test_tutorial001.py

@ -157,7 +157,7 @@ def test_post_broken_body(client: TestClient):
def test_post_form_for_json(client: TestClient): def test_post_form_for_json(client: TestClient):
response = client.post("/items/", data={"name": "Foo", "price": 50.5}) response = client.post("/items/", data={"name": "Foo", "price": "50.5"})
assert response.status_code == 422, response.text assert response.status_code == 422, response.text
assert response.json() == { assert response.json() == {
"detail": [ "detail": [

3
tests/test_tutorial/test_body_nested_models/test_tutorial001_tutorial002_tutorial003.py

@ -1,4 +1,5 @@
import importlib import importlib
from typing import Any
import pytest import pytest
from dirty_equals import IsList from dirty_equals import IsList
@ -130,7 +131,7 @@ def test_put_missing_required(client: TestClient):
def test_openapi_schema(client: TestClient, mod_name: str): def test_openapi_schema(client: TestClient, mod_name: str):
tags_schema = {"default": [], "title": "Tags"} tags_schema: dict[str, Any] = {"default": [], "title": "Tags"}
if mod_name.startswith("tutorial001"): if mod_name.startswith("tutorial001"):
tags_schema.update(UNTYPED_LIST_SCHEMA) tags_schema.update(UNTYPED_LIST_SCHEMA)
elif mod_name.startswith("tutorial002"): elif mod_name.startswith("tutorial002"):

3
tests/test_tutorial/test_custom_response/test_tutorial008.py

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path): def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4" file_path: Path = tmp_path / "large-video-file.mp4"
tutorial008_py310.some_file_path = str(file_path) cast(Any, tutorial008_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes" test_content = b"Fake video bytes"
file_path.write_bytes(test_content) file_path.write_bytes(test_content)
response = client.get("/") response = client.get("/")

3
tests/test_tutorial/test_custom_response/test_tutorial009.py

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path): def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4" file_path: Path = tmp_path / "large-video-file.mp4"
tutorial009_py310.some_file_path = str(file_path) cast(Any, tutorial009_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes" test_content = b"Fake video bytes"
file_path.write_bytes(test_content) file_path.write_bytes(test_content)
response = client.get("/") response = client.get("/")

3
tests/test_tutorial/test_custom_response/test_tutorial009b.py

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path): def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4" file_path: Path = tmp_path / "large-video-file.mp4"
tutorial009b_py310.some_file_path = str(file_path) cast(Any, tutorial009b_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes" test_content = b"Fake video bytes"
file_path.write_bytes(test_content) file_path.write_bytes(test_content)
response = client.get("/") response = client.get("/")

10
tests/test_tutorial/test_debugging/test_tutorial001.py

@ -1,7 +1,7 @@
import importlib import importlib
import runpy import runpy
import sys import sys
import unittest from unittest import mock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -20,7 +20,7 @@ def get_client():
def test_uvicorn_run_is_not_called_on_import(): def test_uvicorn_run_is_not_called_on_import():
if sys.modules.get(MOD_NAME): if sys.modules.get(MOD_NAME):
del sys.modules[MOD_NAME] # pragma: no cover del sys.modules[MOD_NAME] # pragma: no cover
with unittest.mock.patch("uvicorn.run") as uvicorn_run_mock: with mock.patch("uvicorn.run") as uvicorn_run_mock:
importlib.import_module(MOD_NAME) importlib.import_module(MOD_NAME)
uvicorn_run_mock.assert_not_called() uvicorn_run_mock.assert_not_called()
@ -34,12 +34,10 @@ def test_get_root(client: TestClient):
def test_uvicorn_run_called_when_run_as_main(): # Just for coverage def test_uvicorn_run_called_when_run_as_main(): # Just for coverage
if sys.modules.get(MOD_NAME): if sys.modules.get(MOD_NAME):
del sys.modules[MOD_NAME] del sys.modules[MOD_NAME]
with unittest.mock.patch("uvicorn.run") as uvicorn_run_mock: with mock.patch("uvicorn.run") as uvicorn_run_mock:
runpy.run_module(MOD_NAME, run_name="__main__") runpy.run_module(MOD_NAME, run_name="__main__")
uvicorn_run_mock.assert_called_once_with( uvicorn_run_mock.assert_called_once_with(mock.ANY, host="0.0.0.0", port=8000)
unittest.mock.ANY, host="0.0.0.0", port=8000
)
def test_openapi_schema(client: TestClient): def test_openapi_schema(client: TestClient):

5
tests/test_tutorial/test_openapi_webhooks/test_tutorial001.py

@ -1,3 +1,4 @@
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from inline_snapshot import snapshot from inline_snapshot import snapshot
@ -14,7 +15,9 @@ def test_get():
def test_dummy_webhook(): def test_dummy_webhook():
# Just for coverage # Just for coverage
app.webhooks.routes[0].endpoint({}) route = app.webhooks.routes[0]
assert isinstance(route, APIRoute)
route.endpoint({})
def test_openapi_schema(): def test_openapi_schema():

2
tests/test_tutorial/test_python_types/test_tutorial003.py

@ -9,4 +9,4 @@ def test_get_name_with_age_pass_int():
def test_get_name_with_age_pass_str(): def test_get_name_with_age_pass_str():
assert get_name_with_age("John", "30") == "John is this old: 30" assert get_name_with_age("John", "30") == "John is this old: 30" # ty: ignore[invalid-argument-type]

8
tests/test_tutorial/test_python_types/test_tutorial005.py

@ -4,9 +4,9 @@ from docs_src.python_types.tutorial005_py310 import get_items
def test_get_items(): def test_get_items():
res = get_items( res = get_items(
"item_a", "item_a",
"item_b", "item_b", # ty: ignore[invalid-argument-type]
"item_c", "item_c", # ty: ignore[invalid-argument-type]
"item_d", "item_d", # ty: ignore[invalid-argument-type]
"item_e", "item_e", # ty: ignore[invalid-argument-type]
) )
assert res == ("item_a", "item_b", "item_c", "item_d", "item_e") assert res == ("item_a", "item_b", "item_c", "item_d", "item_e")

8
tests/test_tutorial/test_security/test_tutorial005.py

@ -1,6 +1,7 @@
import importlib import importlib
from functools import lru_cache from functools import lru_cache
from types import ModuleType from types import ModuleType
from typing import Any, cast
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -29,12 +30,13 @@ def cache_verify_password(mod: ModuleType):
f"Module {mod.__name__} does not have attribute 'verify_password'" f"Module {mod.__name__} does not have attribute 'verify_password'"
) )
original_func = mod.verify_password mod_any = cast(Any, mod)
original_func = mod_any.verify_password
cached_func = lru_cache()(original_func) cached_func = lru_cache()(original_func)
mod.verify_password = cached_func mod_any.verify_password = cached_func
yield yield
mod.verify_password = original_func mod_any.verify_password = original_func
def get_access_token( def get_access_token(

14
tests/test_tutorial/test_sql_databases/test_tutorial001.py

@ -1,5 +1,6 @@
import importlib import importlib
import warnings import warnings
from typing import Any, cast
import pytest import pytest
from dirty_equals import IsInt from dirty_equals import IsInt
@ -35,15 +36,18 @@ def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.sql_databases.{request.param}") mod = importlib.import_module(f"docs_src.sql_databases.{request.param}")
clear_sqlmodel() clear_sqlmodel()
importlib.reload(mod) importlib.reload(mod)
mod.sqlite_url = "sqlite://" mod_any = cast(Any, mod)
mod.engine = create_engine( mod_any.sqlite_url = "sqlite://"
mod.sqlite_url, connect_args={"check_same_thread": False}, poolclass=StaticPool mod_any.engine = create_engine(
mod_any.sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
) )
with TestClient(mod.app) as c: with TestClient(mod_any.app) as c:
yield c yield c
# Clean up connection explicitly to avoid resource warning # Clean up connection explicitly to avoid resource warning
mod.engine.dispose() mod_any.engine.dispose()
def test_crud_app(client: TestClient): def test_crud_app(client: TestClient):

14
tests/test_tutorial/test_sql_databases/test_tutorial002.py

@ -1,5 +1,6 @@
import importlib import importlib
import warnings import warnings
from typing import Any, cast
import pytest import pytest
from dirty_equals import IsInt from dirty_equals import IsInt
@ -35,15 +36,18 @@ def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.sql_databases.{request.param}") mod = importlib.import_module(f"docs_src.sql_databases.{request.param}")
clear_sqlmodel() clear_sqlmodel()
importlib.reload(mod) importlib.reload(mod)
mod.sqlite_url = "sqlite://" mod_any = cast(Any, mod)
mod.engine = create_engine( mod_any.sqlite_url = "sqlite://"
mod.sqlite_url, connect_args={"check_same_thread": False}, poolclass=StaticPool mod_any.engine = create_engine(
mod_any.sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
) )
with TestClient(mod.app) as c: with TestClient(mod_any.app) as c:
yield c yield c
# Clean up connection explicitly to avoid resource warning # Clean up connection explicitly to avoid resource warning
mod.engine.dispose() mod_any.engine.dispose()
def test_crud_app(client: TestClient): def test_crud_app(client: TestClient):

5
tests/test_webhooks_security.py

@ -33,7 +33,10 @@ client = TestClient(app)
def test_dummy_webhook(): def test_dummy_webhook():
# Just for coverage # Just for coverage
new_subscription(body={}, token="Bearer 123") new_subscription(
body=Subscription(username="rick", monthly_fee=9.99, start_date=datetime.now()),
token="Bearer 123",
)
def test_openapi_schema(): def test_openapi_schema():

6
tests/utils.py

@ -1,5 +1,5 @@
import importlib
import sys import sys
from importlib.util import find_spec
import pytest import pytest
@ -11,12 +11,12 @@ needs_py314 = pytest.mark.skipif(
) )
needs_orjson = pytest.mark.skipif( needs_orjson = pytest.mark.skipif(
importlib.util.find_spec("orjson") is None, find_spec("orjson") is None,
reason="requires orjson", reason="requires orjson",
) )
needs_ujson = pytest.mark.skipif( needs_ujson = pytest.mark.skipif(
importlib.util.find_spec("ujson") is None, find_spec("ujson") is None,
reason="requires ujson", reason="requires ujson",
) )

Loading…
Cancel
Save