```console
-$ pip install -e ."[dev,doc,test]"
+$ pip install -r requirements.txt
---> 100%
```
diff --git a/docs/ru/docs/contributing.md b/docs/ru/docs/contributing.md
index f61ef1cb6..f9b8912e5 100644
--- a/docs/ru/docs/contributing.md
+++ b/docs/ru/docs/contributing.md
@@ -108,7 +108,7 @@ $ python -m pip install --upgrade pip
```console
-$ pip install -e ."[dev,doc,test]"
+$ pip install -r requirements.txt
---> 100%
```
diff --git a/docs/ru/docs/tutorial/index.md b/docs/ru/docs/tutorial/index.md
new file mode 100644
index 000000000..4277a6c4f
--- /dev/null
+++ b/docs/ru/docs/tutorial/index.md
@@ -0,0 +1,80 @@
+# Учебник - Руководство пользователя - Введение
+
+В этом руководстве шаг за шагом показано, как использовать **FastApi** с большинством его функций.
+
+Каждый раздел постепенно основывается на предыдущих, но он структурирован по отдельным темам, так что вы можете перейти непосредственно к конкретной теме для решения ваших конкретных потребностей в API.
+
+Он также создан для использования в качестве будущего справочника.
+
+Так что вы можете вернуться и посмотреть именно то, что вам нужно.
+
+## Запустите код
+
+Все блоки кода можно копировать и использовать напрямую (на самом деле это проверенные файлы Python).
+
+Чтобы запустить любой из примеров, скопируйте код в файл `main.py` и запустите `uvicorn` с параметрами:
+
+
+
+```console
+$ uvicorn main:app --reload
+
+INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
+INFO: Started reloader process [28720]
+INFO: Started server process [28722]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+```
+
+
+
+**НАСТОЯТЕЛЬНО рекомендуется**, чтобы вы написали или скопировали код, отредактировали его и запустили локально.
+
+Использование кода в вашем редакторе — это то, что действительно показывает вам преимущества FastAPI, видя, как мало кода вам нужно написать, все проверки типов, автодополнение и т.д.
+
+---
+
+## Установка FastAPI
+
+Первый шаг — установить FastAPI.
+
+Для руководства вы, возможно, захотите установить его со всеми дополнительными зависимостями и функциями:
+
+
+
+```console
+$ pip install "fastapi[all]"
+
+---> 100%
+```
+
+
+
+...это также включает `uvicorn`, который вы можете использовать в качестве сервера, который запускает ваш код.
+
+!!! note "Технические детали"
+ Вы также можете установить его по частям.
+
+ Это то, что вы, вероятно, сделаете, когда захотите развернуть свое приложение в рабочей среде:
+
+ ```
+ pip install fastapi
+ ```
+
+ Также установите `uvicorn` для работы в качестве сервера:
+
+ ```
+ pip install "uvicorn[standard]"
+ ```
+
+ И то же самое для каждой из необязательных зависимостей, которые вы хотите использовать.
+
+## Продвинутое руководство пользователя
+
+Существует также **Продвинутое руководство пользователя**, которое вы сможете прочитать после руководства **Учебник - Руководство пользователя**.
+
+**Продвинутое руководство пользователя** основано на этом, использует те же концепции и учит вас некоторым дополнительным функциям.
+
+Но вы должны сначала прочитать **Учебник - Руководство пользователя** (то, что вы читаете прямо сейчас).
+
+Он разработан таким образом, что вы можете создать полноценное приложение, используя только **Учебник - Руководство пользователя**, а затем расширить его различными способами, в зависимости от ваших потребностей, используя некоторые дополнительные идеи из **Продвинутого руководства пользователя**.
diff --git a/docs/ru/docs/tutorial/schema-extra-example.md b/docs/ru/docs/tutorial/schema-extra-example.md
new file mode 100644
index 000000000..a0363b9ba
--- /dev/null
+++ b/docs/ru/docs/tutorial/schema-extra-example.md
@@ -0,0 +1,189 @@
+# Объявление примера запроса данных
+
+Вы можете объявлять примеры данных, которые ваше приложение может получать.
+
+Вот несколько способов, как это можно сделать.
+
+## Pydantic `schema_extra`
+
+Вы можете объявить ключ `example` для модели Pydantic, используя класс `Config` и переменную `schema_extra`, как описано в
Pydantic документации: Настройка схемы:
+
+=== "Python 3.10+"
+
+ ```Python hl_lines="13-21"
+ {!> ../../../docs_src/schema_extra_example/tutorial001_py310.py!}
+ ```
+
+=== "Python 3.6+"
+
+ ```Python hl_lines="15-23"
+ {!> ../../../docs_src/schema_extra_example/tutorial001.py!}
+ ```
+
+Эта дополнительная информация будет включена в **JSON Schema** выходных данных для этой модели, и она будет использоваться в документации к API.
+
+!!! tip Подсказка
+ Вы можете использовать тот же метод для расширения JSON-схемы и добавления своей собственной дополнительной информации.
+
+ Например, вы можете использовать это для добавления дополнительной информации для пользовательского интерфейса в вашем веб-приложении и т.д.
+
+## Дополнительные аргументы поля `Field`
+
+При использовании `Field()` с моделями Pydantic, вы также можете объявлять дополнительную информацию для **JSON Schema**, передавая любые другие произвольные аргументы в функцию.
+
+Вы можете использовать это, чтобы добавить аргумент `example` для каждого поля:
+
+=== "Python 3.10+"
+
+ ```Python hl_lines="2 8-11"
+ {!> ../../../docs_src/schema_extra_example/tutorial002_py310.py!}
+ ```
+
+=== "Python 3.6+"
+
+ ```Python hl_lines="4 10-13"
+ {!> ../../../docs_src/schema_extra_example/tutorial002.py!}
+ ```
+
+!!! warning Внимание
+ Имейте в виду, что эти дополнительные переданные аргументы не добавляют никакой валидации, только дополнительную информацию для документации.
+
+## Использование `example` и `examples` в OpenAPI
+
+При использовании любой из этих функций:
+
+* `Path()`
+* `Query()`
+* `Header()`
+* `Cookie()`
+* `Body()`
+* `Form()`
+* `File()`
+
+вы также можете добавить аргумент, содержащий `example` или группу `examples` с дополнительной информацией, которая будет добавлена в **OpenAPI**.
+
+### Параметр `Body` с аргументом `example`
+
+Здесь мы передаём аргумент `example`, как пример данных ожидаемых в параметре `Body()`:
+
+=== "Python 3.10+"
+
+ ```Python hl_lines="22-27"
+ {!> ../../../docs_src/schema_extra_example/tutorial003_an_py310.py!}
+ ```
+
+=== "Python 3.9+"
+
+ ```Python hl_lines="22-27"
+ {!> ../../../docs_src/schema_extra_example/tutorial003_an_py39.py!}
+ ```
+
+=== "Python 3.6+"
+
+ ```Python hl_lines="23-28"
+ {!> ../../../docs_src/schema_extra_example/tutorial003_an.py!}
+ ```
+
+=== "Python 3.10+ non-Annotated"
+
+ !!! tip Заметка
+ Рекомендуется использовать версию с `Annotated`, если это возможно.
+
+ ```Python hl_lines="18-23"
+ {!> ../../../docs_src/schema_extra_example/tutorial003_py310.py!}
+ ```
+
+=== "Python 3.6+ non-Annotated"
+
+ !!! tip Заметка
+ Рекомендуется использовать версию с `Annotated`, если это возможно.
+
+ ```Python hl_lines="20-25"
+ {!> ../../../docs_src/schema_extra_example/tutorial003.py!}
+ ```
+
+### Аргумент "example" в UI документации
+
+С любым из вышеуказанных методов это будет выглядеть так в `/docs`:
+
+

+
+### `Body` с аргументом `examples`
+
+В качестве альтернативы одному аргументу `example`, вы можете передавать `examples` используя тип данных `dict` с **несколькими примерами**, каждый из которых содержит дополнительную информацию, которая также будет добавлена в **OpenAPI**.
+
+Ключи `dict` указывают на каждый пример, а значения для каждого из них - на еще один тип `dict` с дополнительной информацией.
+
+Каждый конкретный пример типа `dict` в аргументе `examples` может содержать:
+
+* `summary`: Краткое описание для примера.
+* `description`: Полное описание, которое может содержать текст в формате Markdown.
+* `value`: Это конкретный пример, который отображается, например, в виде типа `dict`.
+* `externalValue`: альтернатива параметру `value`, URL-адрес, указывающий на пример. Хотя это может не поддерживаться таким же количеством инструментов разработки и тестирования API, как параметр `value`.
+
+=== "Python 3.10+"
+
+ ```Python hl_lines="23-49"
+ {!> ../../../docs_src/schema_extra_example/tutorial004_an_py310.py!}
+ ```
+
+=== "Python 3.9+"
+
+ ```Python hl_lines="23-49"
+ {!> ../../../docs_src/schema_extra_example/tutorial004_an_py39.py!}
+ ```
+
+=== "Python 3.6+"
+
+ ```Python hl_lines="24-50"
+ {!> ../../../docs_src/schema_extra_example/tutorial004_an.py!}
+ ```
+
+=== "Python 3.10+ non-Annotated"
+
+ !!! tip Заметка
+ Рекомендуется использовать версию с `Annotated`, если это возможно.
+
+ ```Python hl_lines="19-45"
+ {!> ../../../docs_src/schema_extra_example/tutorial004_py310.py!}
+ ```
+
+=== "Python 3.6+ non-Annotated"
+
+ !!! tip Заметка
+ Рекомендуется использовать версию с `Annotated`, если это возможно.
+
+ ```Python hl_lines="21-47"
+ {!> ../../../docs_src/schema_extra_example/tutorial004.py!}
+ ```
+
+### Аргумент "examples" в UI документации
+
+С аргументом `examples`, добавленным в `Body()`, страница документации `/docs` будет выглядеть так:
+
+

+
+## Технические Детали
+
+!!! warning Внимание
+ Эти технические детали относятся к стандартам **JSON Schema** и **OpenAPI**.
+
+ Если предложенные выше идеи уже работают для вас, возможно этого будет достаточно и эти детали вам не потребуются, можете спокойно их пропустить.
+
+Когда вы добавляете пример внутрь модели Pydantic, используя `schema_extra` или `Field(example="something")`, этот пример добавляется в **JSON Schema** для данной модели Pydantic.
+
+И эта **JSON Schema** модели Pydantic включается в **OpenAPI** вашего API, а затем используется в UI документации.
+
+Поля `example` как такового не существует в стандартах **JSON Schema**. В последних версиях JSON-схемы определено поле
`examples`, но OpenAPI 3.0.3 основан на более старой версии JSON-схемы, которая не имела поля `examples`.
+
+Таким образом, OpenAPI 3.0.3 определяет своё собственное поле
`example` для модифицированной версии **JSON Schema**, которую он использует чтобы достичь той же цели (однако это именно поле `example`, а не `examples`), и именно это используется API в UI документации (с интеграцией Swagger UI).
+
+Итак, хотя поле `example` не является частью JSON-схемы, оно является частью настраиваемой версии JSON-схемы в OpenAPI, и именно это поле будет использоваться в UI документации.
+
+Однако, когда вы используете поле `example` или `examples` с любой другой функцией (`Query()`, `Body()`, и т.д.), эти примеры не добавляются в JSON-схему, которая описывает эти данные (даже в собственную версию JSON-схемы OpenAPI), они добавляются непосредственно в объявление *операции пути* в OpenAPI (вне частей OpenAPI, которые используют JSON-схему).
+
+Для функций `Path()`, `Query()`, `Header()`, и `Cookie()`, аргументы `example` или `examples` добавляются в
определение OpenAPI, к объекту `Parameter Object` (в спецификации).
+
+И для функций `Body()`, `File()` и `Form()` аргументы `example` или `examples` аналогично добавляются в
определение OpenAPI, к объекту `Request Body Object`, в поле `content` в объекте `Media Type Object` (в спецификации).
+
+С другой стороны, существует более новая версия OpenAPI: **3.1.0**, недавно выпущенная. Она основана на последней версии JSON-схемы и большинство модификаций из OpenAPI JSON-схемы удалены в обмен на новые возможности из последней версии JSON-схемы, так что все эти мелкие отличия устранены. Тем не менее, Swagger UI в настоящее время не поддерживает OpenAPI 3.1.0, поэтому пока лучше продолжать использовать вышеупомянутые методы.
diff --git a/docs/ru/mkdocs.yml b/docs/ru/mkdocs.yml
index e41333894..9fb56ce1b 100644
--- a/docs/ru/mkdocs.yml
+++ b/docs/ru/mkdocs.yml
@@ -67,6 +67,7 @@ nav:
- fastapi-people.md
- python-types.md
- Учебник - руководство пользователя:
+ - tutorial/index.md
- tutorial/first-steps.md
- tutorial/path-params.md
- tutorial/query-params-str-validations.md
@@ -81,6 +82,7 @@ nav:
- tutorial/body-multiple-params.md
- tutorial/static-files.md
- tutorial/debugging.md
+ - tutorial/schema-extra-example.md
- async.md
- Развёртывание:
- deployment/index.md
diff --git a/docs/zh/docs/advanced/response-change-status-code.md b/docs/zh/docs/advanced/response-change-status-code.md
new file mode 100644
index 000000000..a289cf201
--- /dev/null
+++ b/docs/zh/docs/advanced/response-change-status-code.md
@@ -0,0 +1,31 @@
+# 响应 - 更改状态码
+
+你可能之前已经了解到,你可以设置默认的[响应状态码](../tutorial/response-status-code.md){.internal-link target=_blank}。
+
+但在某些情况下,你需要返回一个不同于默认值的状态码。
+
+## 使用场景
+
+例如,假设你想默认返回一个HTTP状态码为“OK”`200`。
+
+但如果数据不存在,你想创建它,并返回一个HTTP状态码为“CREATED”`201`。
+
+但你仍然希望能够使用`response_model`过滤和转换你返回的数据。
+
+对于这些情况,你可以使用一个`Response`参数。
+
+## 使用 `Response` 参数
+
+你可以在你的*路径操作函数*中声明一个`Response`类型的参数(就像你可以为cookies和头部做的那样)。
+
+然后你可以在这个*临时*响应对象中设置`status_code`。
+
+```Python hl_lines="1 9 12"
+{!../../../docs_src/response_change_status_code/tutorial001.py!}
+```
+
+然后你可以像平常一样返回任何你需要的对象(例如一个`dict`或者一个数据库模型)。如果你声明了一个`response_model`,它仍然会被用来过滤和转换你返回的对象。
+
+**FastAPI**将使用这个临时响应来提取状态码(也包括cookies和头部),并将它们放入包含你返回的值的最终响应中,该响应由任何`response_model`过滤。
+
+你也可以在依赖项中声明`Response`参数,并在其中设置状态码。但请注意,最后设置的状态码将会生效。
diff --git a/docs/zh/docs/advanced/response-headers.md b/docs/zh/docs/advanced/response-headers.md
new file mode 100644
index 000000000..85dab15ac
--- /dev/null
+++ b/docs/zh/docs/advanced/response-headers.md
@@ -0,0 +1,39 @@
+# 响应头
+
+## 使用 `Response` 参数
+
+你可以在你的*路径操作函数*中声明一个`Response`类型的参数(就像你可以为cookies做的那样)。
+
+然后你可以在这个*临时*响应对象中设置头部。
+```Python hl_lines="1 7-8"
+{!../../../docs_src/response_headers/tutorial002.py!}
+```
+
+然后你可以像平常一样返回任何你需要的对象(例如一个`dict`或者一个数据库模型)。如果你声明了一个`response_model`,它仍然会被用来过滤和转换你返回的对象。
+
+**FastAPI**将使用这个临时响应来提取头部(也包括cookies和状态码),并将它们放入包含你返回的值的最终响应中,该响应由任何`response_model`过滤。
+
+你也可以在依赖项中声明`Response`参数,并在其中设置头部(和cookies)。
+
+## 直接返回 `Response`
+
+你也可以在直接返回`Response`时添加头部。
+
+按照[直接返回响应](response-directly.md){.internal-link target=_blank}中所述创建响应,并将头部作为附加参数传递:
+```Python hl_lines="10-12"
+{!../../../docs_src/response_headers/tutorial001.py!}
+```
+
+
+!!! 注意 "技术细节"
+ 你也可以使用`from starlette.responses import Response`或`from starlette.responses import JSONResponse`。
+
+ **FastAPI**提供了与`fastapi.responses`相同的`starlette.responses`,只是为了方便开发者。但是,大多数可用的响应都直接来自Starlette。
+
+ 由于`Response`经常用于设置头部和cookies,因此**FastAPI**还在`fastapi.Response`中提供了它。
+
+## 自定义头部
+
+请注意,可以使用'X-'前缀添加自定义专有头部。
+
+但是,如果你有自定义头部,你希望浏览器中的客户端能够看到它们,你需要将它们添加到你的CORS配置中(在[CORS(跨源资源共享)](../tutorial/cors.md){.internal-link target=_blank}中阅读更多),使用在
Starlette的CORS文档中记录的`expose_headers`参数。
diff --git a/docs/zh/docs/contributing.md b/docs/zh/docs/contributing.md
index 36c3631c4..4ebd67315 100644
--- a/docs/zh/docs/contributing.md
+++ b/docs/zh/docs/contributing.md
@@ -97,7 +97,7 @@ $ python -m venv env
```console
-$ pip install -e ."[dev,doc,test]"
+$ pip install -r requirements.txt
---> 100%
```
diff --git a/docs/zh/mkdocs.yml b/docs/zh/mkdocs.yml
index 75bd2ccab..522c83766 100644
--- a/docs/zh/mkdocs.yml
+++ b/docs/zh/mkdocs.yml
@@ -117,6 +117,8 @@ nav:
- advanced/response-directly.md
- advanced/custom-response.md
- advanced/response-cookies.md
+ - advanced/response-change-status-code.md
+ - advanced/response-headers.md
- advanced/wsgi.md
- contributing.md
- help-fastapi.md
diff --git a/fastapi/__init__.py b/fastapi/__init__.py
index d564d5fa3..46a056363 100644
--- a/fastapi/__init__.py
+++ b/fastapi/__init__.py
@@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
-__version__ = "0.96.0"
+__version__ = "0.97.0"
from starlette import status as status
diff --git a/fastapi/applications.py b/fastapi/applications.py
index 8b3a74d3c..298aca921 100644
--- a/fastapi/applications.py
+++ b/fastapi/applications.py
@@ -19,8 +19,9 @@ from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
+ websocket_request_validation_exception_handler,
)
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
@@ -145,6 +146,11 @@ class FastAPI(Starlette):
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)
+ self.exception_handlers.setdefault(
+ WebSocketRequestValidationError,
+ # Starlette still has incorrect type specification for the handlers
+ websocket_request_validation_exception_handler, # type: ignore
+ )
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
@@ -395,15 +401,34 @@ class FastAPI(Starlette):
return decorator
def add_api_websocket_route(
- self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
+ self,
+ path: str,
+ endpoint: Callable[..., Any],
+ name: Optional[str] = None,
+ *,
+ dependencies: Optional[Sequence[Depends]] = None,
) -> None:
- self.router.add_api_websocket_route(path, endpoint, name=name)
+ self.router.add_api_websocket_route(
+ path,
+ endpoint,
+ name=name,
+ dependencies=dependencies,
+ )
def websocket(
- self, path: str, name: Optional[str] = None
+ self,
+ path: str,
+ name: Optional[str] = None,
+ *,
+ dependencies: Optional[Sequence[Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
- self.add_api_websocket_route(path, func, name=name)
+ self.add_api_websocket_route(
+ path,
+ func,
+ name=name,
+ dependencies=dependencies,
+ )
return func
return decorator
diff --git a/fastapi/exception_handlers.py b/fastapi/exception_handlers.py
index 4d7ea5ec2..6c2ba7fed 100644
--- a/fastapi/exception_handlers.py
+++ b/fastapi/exception_handlers.py
@@ -1,10 +1,11 @@
from fastapi.encoders import jsonable_encoder
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code
+from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
@@ -23,3 +24,11 @@ async def request_validation_exception_handler(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": jsonable_encoder(exc.errors())},
)
+
+
+async def websocket_request_validation_exception_handler(
+ websocket: WebSocket, exc: WebSocketRequestValidationError
+) -> None:
+ await websocket.close(
+ code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
+ )
diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py
index ca097b1ce..cac5330a2 100644
--- a/fastapi/exceptions.py
+++ b/fastapi/exceptions.py
@@ -11,7 +11,7 @@ class HTTPException(StarletteHTTPException):
self,
status_code: int,
detail: Any = None,
- headers: Optional[Dict[str, Any]] = None,
+ headers: Optional[Dict[str, str]] = None,
) -> None:
super().__init__(status_code=status_code, detail=detail, headers=headers)
diff --git a/fastapi/middleware/asyncexitstack.py b/fastapi/middleware/asyncexitstack.py
index 503a68ac7..30a0ae626 100644
--- a/fastapi/middleware/asyncexitstack.py
+++ b/fastapi/middleware/asyncexitstack.py
@@ -10,19 +10,16 @@ class AsyncExitStackMiddleware:
self.context_name = context_name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if AsyncExitStack:
- dependency_exception: Optional[Exception] = None
- async with AsyncExitStack() as stack:
- scope[self.context_name] = stack
- try:
- await self.app(scope, receive, send)
- except Exception as e:
- dependency_exception = e
- raise e
- if dependency_exception:
- # This exception was possibly handled by the dependency but it should
- # still bubble up so that the ServerErrorMiddleware can return a 500
- # or the ExceptionMiddleware can catch and handle any other exceptions
- raise dependency_exception
- else:
- await self.app(scope, receive, send) # pragma: no cover
+ dependency_exception: Optional[Exception] = None
+ async with AsyncExitStack() as stack:
+ scope[self.context_name] = stack
+ try:
+ await self.app(scope, receive, send)
+ except Exception as e:
+ dependency_exception = e
+ raise e
+ if dependency_exception:
+ # This exception was possibly handled by the dependency but it should
+ # still bubble up so that the ServerErrorMiddleware can return a 500
+ # or the ExceptionMiddleware can catch and handle any other exceptions
+ raise dependency_exception
diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py
index 11edfe38a..81a24f389 100644
--- a/fastapi/openapi/models.py
+++ b/fastapi/openapi/models.py
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
+from typing_extensions import Literal
try:
import email_validator # type: ignore
@@ -298,18 +299,18 @@ class APIKeyIn(Enum):
class APIKey(SecurityBase):
- type_ = Field(SecuritySchemeType.apiKey, alias="type")
+ type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = Field(alias="in")
name: str
class HTTPBase(SecurityBase):
- type_ = Field(SecuritySchemeType.http, alias="type")
+ type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type")
scheme: str
class HTTPBearer(HTTPBase):
- scheme = "bearer"
+ scheme: Literal["bearer"] = "bearer"
bearerFormat: Optional[str] = None
@@ -349,12 +350,14 @@ class OAuthFlows(BaseModel):
class OAuth2(SecurityBase):
- type_ = Field(SecuritySchemeType.oauth2, alias="type")
+ type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
- type_ = Field(SecuritySchemeType.openIdConnect, alias="type")
+ type_: SecuritySchemeType = Field(
+ default=SecuritySchemeType.openIdConnect, alias="type"
+ )
openIdConnectUrl: str
diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py
index 6bea7a713..5ec63af80 100644
--- a/fastapi/openapi/utils.py
+++ b/fastapi/openapi/utils.py
@@ -181,7 +181,7 @@ def get_openapi_operation_metadata(
file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
if file_name:
message += f" at {file_name}"
- warnings.warn(message)
+ warnings.warn(message, stacklevel=1)
operation_ids.add(operation_id)
operation["operationId"] = operation_id
if route.deprecated:
@@ -332,10 +332,8 @@ def get_openapi_path(
openapi_response["description"] = description
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
if (all_route_params or route.body_field) and not any(
- [
- status in operation["responses"]
- for status in [http422, "4XX", "default"]
- ]
+ status in operation["responses"]
+ for status in [http422, "4XX", "default"]
):
operation["responses"][http422] = {
"description": "Validation Error",
diff --git a/fastapi/responses.py b/fastapi/responses.py
index 88dba96e8..c0a13b755 100644
--- a/fastapi/responses.py
+++ b/fastapi/responses.py
@@ -27,8 +27,6 @@ class UJSONResponse(JSONResponse):
class ORJSONResponse(JSONResponse):
- media_type = "application/json"
-
def render(self, content: Any) -> bytes:
assert orjson is not None, "orjson must be installed to use ORJSONResponse"
return orjson.dumps(
diff --git a/fastapi/routing.py b/fastapi/routing.py
index 06c71bffa..ec8af99b3 100644
--- a/fastapi/routing.py
+++ b/fastapi/routing.py
@@ -30,7 +30,11 @@ from fastapi.dependencies.utils import (
solve_dependencies,
)
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
-from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
+from fastapi.exceptions import (
+ FastAPIError,
+ RequestValidationError,
+ WebSocketRequestValidationError,
+)
from fastapi.types import DecoratedCallable
from fastapi.utils import (
create_cloned_field,
@@ -48,15 +52,15 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
-from starlette.routing import BaseRoute, Match
-from starlette.routing import Mount as Mount # noqa
from starlette.routing import (
+ BaseRoute,
+ Match,
compile_path,
get_name,
request_response,
websocket_session,
)
-from starlette.status import WS_1008_POLICY_VIOLATION
+from starlette.routing import Mount as Mount # noqa
from starlette.types import ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket
@@ -283,7 +287,6 @@ def get_websocket_app(
)
values, errors, _, _2, _3 = solved_result
if errors:
- await websocket.close(code=WS_1008_POLICY_VIOLATION)
raise WebSocketRequestValidationError(errors)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
@@ -298,13 +301,21 @@ class APIWebSocketRoute(routing.WebSocketRoute):
endpoint: Callable[..., Any],
*,
name: Optional[str] = None,
+ dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
+ self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+ for depends in self.dependencies[::-1]:
+ self.dependant.dependencies.insert(
+ 0,
+ get_parameterless_sub_dependant(depends=depends, path=self.path_format),
+ )
+
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
@@ -418,10 +429,7 @@ class APIRoute(routing.Route):
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
- if dependencies:
- self.dependencies = list(dependencies)
- else:
- self.dependencies = []
+ self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
@@ -516,7 +524,7 @@ class APIRouter(routing.Router):
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
- self.dependencies = list(dependencies or []) or []
+ self.dependencies = list(dependencies or [])
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
@@ -690,21 +698,37 @@ class APIRouter(routing.Router):
return decorator
def add_api_websocket_route(
- self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
+ self,
+ path: str,
+ endpoint: Callable[..., Any],
+ name: Optional[str] = None,
+ *,
+ dependencies: Optional[Sequence[params.Depends]] = None,
) -> None:
+ current_dependencies = self.dependencies.copy()
+ if dependencies:
+ current_dependencies.extend(dependencies)
+
route = APIWebSocketRoute(
self.prefix + path,
endpoint=endpoint,
name=name,
+ dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider,
)
self.routes.append(route)
def websocket(
- self, path: str, name: Optional[str] = None
+ self,
+ path: str,
+ name: Optional[str] = None,
+ *,
+ dependencies: Optional[Sequence[params.Depends]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
- self.add_api_websocket_route(path, func, name=name)
+ self.add_api_websocket_route(
+ path, func, name=name, dependencies=dependencies
+ )
return func
return decorator
@@ -744,7 +768,7 @@ class APIRouter(routing.Router):
path = getattr(r, "path") # noqa: B009
name = getattr(r, "name", "unknown")
if path is not None and not path:
- raise Exception(
+ raise FastAPIError(
f"Prefix and path cannot be both empty (path operation: {name})"
)
if responses is None:
@@ -819,8 +843,16 @@ class APIRouter(routing.Router):
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
+ current_dependencies = []
+ if dependencies:
+ current_dependencies.extend(dependencies)
+ if route.dependencies:
+ current_dependencies.extend(route.dependencies)
self.add_api_websocket_route(
- prefix + route.path, route.endpoint, name=route.name
+ prefix + route.path,
+ route.endpoint,
+ dependencies=current_dependencies,
+ name=route.name,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
index 61730187a..8b2c5c080 100644
--- a/fastapi/security/api_key.py
+++ b/fastapi/security/api_key.py
@@ -21,7 +21,9 @@ class APIKeyQuery(APIKeyBase):
auto_error: bool = True,
):
self.model: APIKey = APIKey(
- **{"in": APIKeyIn.query}, name=name, description=description
+ **{"in": APIKeyIn.query}, # type: ignore[arg-type]
+ name=name,
+ description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@@ -48,7 +50,9 @@ class APIKeyHeader(APIKeyBase):
auto_error: bool = True,
):
self.model: APIKey = APIKey(
- **{"in": APIKeyIn.header}, name=name, description=description
+ **{"in": APIKeyIn.header}, # type: ignore[arg-type]
+ name=name,
+ description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@@ -75,7 +79,9 @@ class APIKeyCookie(APIKeyBase):
auto_error: bool = True,
):
self.model: APIKey = APIKey(
- **{"in": APIKeyIn.cookie}, name=name, description=description
+ **{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
+ name=name,
+ description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
index dc75dc9fe..938dec37c 100644
--- a/fastapi/security/oauth2.py
+++ b/fastapi/security/oauth2.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
@@ -121,7 +121,9 @@ class OAuth2(SecurityBase):
description: Optional[str] = None,
auto_error: bool = True,
):
- self.model = OAuth2Model(flows=flows, description=description)
+ self.model = OAuth2Model(
+ flows=cast(OAuthFlowsModel, flows), description=description
+ )
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@@ -148,7 +150,9 @@ class OAuth2PasswordBearer(OAuth2):
):
if not scopes:
scopes = {}
- flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
+ flows = OAuthFlowsModel(
+ password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
+ )
super().__init__(
flows=flows,
scheme_name=scheme_name,
@@ -185,12 +189,15 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
- authorizationCode={
- "authorizationUrl": authorizationUrl,
- "tokenUrl": tokenUrl,
- "refreshUrl": refreshUrl,
- "scopes": scopes,
- }
+ authorizationCode=cast(
+ Any,
+ {
+ "authorizationUrl": authorizationUrl,
+ "tokenUrl": tokenUrl,
+ "refreshUrl": refreshUrl,
+ "scopes": scopes,
+ },
+ )
)
super().__init__(
flows=flows,
diff --git a/pyproject.toml b/pyproject.toml
index 3bae6a3ef..547137144 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,47 +51,6 @@ Homepage = "https://github.com/tiangolo/fastapi"
Documentation = "https://fastapi.tiangolo.com/"
[project.optional-dependencies]
-test = [
- "pytest >=7.1.3,<8.0.0",
- "coverage[toml] >= 6.5.0,< 8.0",
- "mypy ==0.982",
- "ruff ==0.0.138",
- "black == 23.1.0",
- "isort >=5.0.6,<6.0.0",
- "httpx >=0.23.0,<0.24.0",
- "email_validator >=1.1.1,<2.0.0",
- # TODO: once removing databases from tutorial, upgrade SQLAlchemy
- # probably when including SQLModel
- "sqlalchemy >=1.3.18,<1.4.43",
- "peewee >=3.13.3,<4.0.0",
- "databases[sqlite] >=0.3.2,<0.7.0",
- "orjson >=3.2.1,<4.0.0",
- "ujson >=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0",
- "python-multipart >=0.0.5,<0.0.7",
- "flask >=1.1.2,<3.0.0",
- "anyio[trio] >=3.2.1,<4.0.0",
- "python-jose[cryptography] >=3.3.0,<4.0.0",
- "pyyaml >=5.3.1,<7.0.0",
- "passlib[bcrypt] >=1.7.2,<2.0.0",
-
- # types
- "types-ujson ==5.7.0.1",
- "types-orjson ==3.6.2",
-]
-doc = [
- "mkdocs >=1.1.2,<2.0.0",
- "mkdocs-material >=8.1.4,<9.0.0",
- "mdx-include >=1.4.1,<2.0.0",
- "mkdocs-markdownextradata-plugin >=0.1.7,<0.3.0",
- "typer-cli >=0.0.13,<0.0.14",
- "typer[all] >=0.6.1,<0.8.0",
- "pyyaml >=5.3.1,<7.0.0",
-]
-dev = [
- "ruff ==0.0.138",
- "uvicorn[standard] >=0.12.0,<0.21.0",
- "pre-commit >=2.17.0,<3.0.0",
-]
all = [
"httpx >=0.23.0",
"jinja2 >=2.11.2",
@@ -107,10 +66,6 @@ all = [
[tool.hatch.version]
path = "fastapi/__init__.py"
-[tool.isort]
-profile = "black"
-known_third_party = ["fastapi", "pydantic", "starlette"]
-
[tool.mypy]
strict = true
@@ -166,7 +121,7 @@ select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
- # "I", # isort
+ "I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
diff --git a/requirements-docs.txt b/requirements-docs.txt
new file mode 100644
index 000000000..e9d0567ed
--- /dev/null
+++ b/requirements-docs.txt
@@ -0,0 +1,8 @@
+-e .
+mkdocs >=1.1.2,<2.0.0
+mkdocs-material >=8.1.4,<9.0.0
+mdx-include >=1.4.1,<2.0.0
+mkdocs-markdownextradata-plugin >=0.1.7,<0.3.0
+typer-cli >=0.0.13,<0.0.14
+typer[all] >=0.6.1,<0.8.0
+pyyaml >=5.3.1,<7.0.0
diff --git a/requirements-tests.txt b/requirements-tests.txt
new file mode 100644
index 000000000..3ef3c4fd9
--- /dev/null
+++ b/requirements-tests.txt
@@ -0,0 +1,25 @@
+-e .
+pytest >=7.1.3,<8.0.0
+coverage[toml] >= 6.5.0,< 8.0
+mypy ==1.3.0
+ruff ==0.0.272
+black == 23.3.0
+httpx >=0.23.0,<0.24.0
+email_validator >=1.1.1,<2.0.0
+# TODO: once removing databases from tutorial, upgrade SQLAlchemy
+# probably when including SQLModel
+sqlalchemy >=1.3.18,<1.4.43
+peewee >=3.13.3,<4.0.0
+databases[sqlite] >=0.3.2,<0.7.0
+orjson >=3.2.1,<4.0.0
+ujson >=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0
+python-multipart >=0.0.5,<0.0.7
+flask >=1.1.2,<3.0.0
+anyio[trio] >=3.2.1,<4.0.0
+python-jose[cryptography] >=3.3.0,<4.0.0
+pyyaml >=5.3.1,<7.0.0
+passlib[bcrypt] >=1.7.2,<2.0.0
+
+# types
+types-ujson ==5.7.0.1
+types-orjson ==3.6.2
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 000000000..cb9abb44a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+-e .[all]
+-r requirements-tests.txt
+-r requirements-docs.txt
+uvicorn[standard] >=0.12.0,<0.21.0
+pre-commit >=2.17.0,<3.0.0
diff --git a/scripts/build-docs.sh b/scripts/build-docs.sh
index 383ad3f44..ebf864afa 100755
--- a/scripts/build-docs.sh
+++ b/scripts/build-docs.sh
@@ -3,4 +3,6 @@
set -e
set -x
+# Check README.md is up to date
+python ./scripts/docs.py verify-readme
python ./scripts/docs.py build-all
diff --git a/scripts/format.sh b/scripts/format.sh
index 3ac1fead8..3fb3eb4f1 100755
--- a/scripts/format.sh
+++ b/scripts/format.sh
@@ -3,4 +3,3 @@ set -x
ruff fastapi tests docs_src scripts --fix
black fastapi tests docs_src scripts
-isort fastapi tests docs_src scripts
diff --git a/scripts/lint.sh b/scripts/lint.sh
index 0feb973a8..4db5caa96 100755
--- a/scripts/lint.sh
+++ b/scripts/lint.sh
@@ -6,4 +6,3 @@ set -x
mypy fastapi
ruff fastapi tests docs_src scripts
black fastapi tests --check
-isort fastapi tests docs_src scripts --check-only
diff --git a/scripts/test.sh b/scripts/test.sh
index 62449ea41..7d17add8f 100755
--- a/scripts/test.sh
+++ b/scripts/test.sh
@@ -3,7 +3,5 @@
set -e
set -x
-# Check README.md is up to date
-python ./scripts/docs.py verify-readme
export PYTHONPATH=./docs_src
coverage run -m pytest tests ${@}
diff --git a/tests/test_empty_router.py b/tests/test_empty_router.py
index 186ceb347..1a40cbe30 100644
--- a/tests/test_empty_router.py
+++ b/tests/test_empty_router.py
@@ -1,5 +1,6 @@
import pytest
from fastapi import APIRouter, FastAPI
+from fastapi.exceptions import FastAPIError
from fastapi.testclient import TestClient
app = FastAPI()
@@ -31,5 +32,5 @@ def test_use_empty():
def test_include_empty():
# if both include and router.path are empty - it should raise exception
- with pytest.raises(Exception):
+ with pytest.raises(FastAPIError):
app.include_router(router)
diff --git a/tests/test_ws_dependencies.py b/tests/test_ws_dependencies.py
new file mode 100644
index 000000000..ccb1c4b7d
--- /dev/null
+++ b/tests/test_ws_dependencies.py
@@ -0,0 +1,73 @@
+import json
+from typing import List
+
+from fastapi import APIRouter, Depends, FastAPI, WebSocket
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+def dependency_list() -> List[str]:
+ return []
+
+
+DepList = Annotated[List[str], Depends(dependency_list)]
+
+
+def create_dependency(name: str):
+ def fun(deps: DepList):
+ deps.append(name)
+
+ return Depends(fun)
+
+
+router = APIRouter(dependencies=[create_dependency("router")])
+prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
+app = FastAPI(dependencies=[create_dependency("app")])
+
+
+@app.websocket("/", dependencies=[create_dependency("index")])
+async def index(websocket: WebSocket, deps: DepList):
+ await websocket.accept()
+ await websocket.send_text(json.dumps(deps))
+ await websocket.close()
+
+
+@router.websocket("/router", dependencies=[create_dependency("routerindex")])
+async def routerindex(websocket: WebSocket, deps: DepList):
+ await websocket.accept()
+ await websocket.send_text(json.dumps(deps))
+ await websocket.close()
+
+
+@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
+async def routerprefixindex(websocket: WebSocket, deps: DepList):
+ await websocket.accept()
+ await websocket.send_text(json.dumps(deps))
+ await websocket.close()
+
+
+app.include_router(router, dependencies=[create_dependency("router2")])
+app.include_router(
+ prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
+)
+
+
+def test_index():
+ client = TestClient(app)
+ with client.websocket_connect("/") as websocket:
+ data = json.loads(websocket.receive_text())
+ assert data == ["app", "index"]
+
+
+def test_routerindex():
+ client = TestClient(app)
+ with client.websocket_connect("/router") as websocket:
+ data = json.loads(websocket.receive_text())
+ assert data == ["app", "router2", "router", "routerindex"]
+
+
+def test_routerprefixindex():
+ client = TestClient(app)
+ with client.websocket_connect("/prefix/") as websocket:
+ data = json.loads(websocket.receive_text())
+ assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]
diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py
index c312821e9..240a42bb0 100644
--- a/tests/test_ws_router.py
+++ b/tests/test_ws_router.py
@@ -1,4 +1,16 @@
-from fastapi import APIRouter, Depends, FastAPI, WebSocket
+import functools
+
+import pytest
+from fastapi import (
+ APIRouter,
+ Depends,
+ FastAPI,
+ Header,
+ WebSocket,
+ WebSocketDisconnect,
+ status,
+)
+from fastapi.middleware import Middleware
from fastapi.testclient import TestClient
router = APIRouter()
@@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
await websocket.close()
-app.include_router(router)
-app.include_router(prefix_router, prefix="/prefix")
-app.include_router(native_prefix_route)
+async def ws_dependency_err():
+ raise NotImplementedError()
+
+
+@router.websocket("/depends-err/")
+async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
+ pass # pragma: no cover
+
+
+async def ws_dependency_validate(x_missing: str = Header()):
+ pass # pragma: no cover
+
+
+@router.websocket("/depends-validate/")
+async def router_ws_depends_validate(
+ websocket: WebSocket, data=Depends(ws_dependency_validate)
+):
+ pass # pragma: no cover
+
+
+class CustomError(Exception):
+ pass
+
+
+@router.websocket("/custom_error/")
+async def router_ws_custom_error(websocket: WebSocket):
+ raise CustomError()
+
+
+def make_app(app=None, **kwargs):
+ app = app or FastAPI(**kwargs)
+ app.include_router(router)
+ app.include_router(prefix_router, prefix="/prefix")
+ app.include_router(native_prefix_route)
+ return app
+
+
+app = make_app(app)
def test_app():
@@ -125,3 +172,100 @@ def test_router_with_params():
assert data == "path/to/file"
data = websocket.receive_text()
assert data == "a_query_param"
+
+
+def test_wrong_uri():
+ """
+ Verify that a websocket connection to a non-existent endpoing returns in a shutdown
+ """
+ client = TestClient(app)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/no-router/"):
+ pass # pragma: no cover
+ assert e.value.code == status.WS_1000_NORMAL_CLOSURE
+
+
+def websocket_middleware(middleware_func):
+ """
+ Helper to create a Starlette pure websocket middleware
+ """
+
+ def middleware_constructor(app):
+ @functools.wraps(app)
+ async def wrapped_app(scope, receive, send):
+ if scope["type"] != "websocket":
+ return await app(scope, receive, send) # pragma: no cover
+
+ async def call_next():
+ return await app(scope, receive, send)
+
+ websocket = WebSocket(scope, receive=receive, send=send)
+ return await middleware_func(websocket, call_next)
+
+ return wrapped_app
+
+ return middleware_constructor
+
+
+def test_depend_validation():
+ """
+ Verify that a validation in a dependency invokes the correct exception handler
+ """
+ caught = []
+
+ @websocket_middleware
+ async def catcher(websocket, call_next):
+ try:
+ return await call_next()
+ except Exception as e: # pragma: no cover
+ caught.append(e)
+ raise
+
+ myapp = make_app(middleware=[Middleware(catcher)])
+
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/depends-validate/"):
+ pass # pragma: no cover
+ # the validation error does produce a close message
+ assert e.value.code == status.WS_1008_POLICY_VIOLATION
+ # and no error is leaked
+ assert caught == []
+
+
+def test_depend_err_middleware():
+ """
+ Verify that it is possible to write custom WebSocket middleware to catch errors
+ """
+
+ @websocket_middleware
+ async def errorhandler(websocket: WebSocket, call_next):
+ try:
+ return await call_next()
+ except Exception as e:
+ await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
+
+ myapp = make_app(middleware=[Middleware(errorhandler)])
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/depends-err/"):
+ pass # pragma: no cover
+ assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
+ assert "NotImplementedError" in e.value.reason
+
+
+def test_depend_err_handler():
+ """
+ Verify that it is possible to write custom WebSocket middleware to catch errors
+ """
+
+ async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
+ await websocket.close(1002, "foo")
+
+ myapp = make_app(exception_handlers={CustomError: custom_handler})
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/custom_error/"):
+ pass # pragma: no cover
+ assert e.value.code == 1002
+ assert "foo" in e.value.reason