diff --git a/docs/src/websockets/__init__.py b/docs/src/websockets/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/src/websockets/tutorial001.py b/docs/src/websockets/tutorial001.py
index 2713550d3..3adfd49c1 100644
--- a/docs/src/websockets/tutorial001.py
+++ b/docs/src/websockets/tutorial001.py
@@ -44,10 +44,9 @@ async def get():
return HTMLResponse(html)
-@app.websocket_route("/ws")
+@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Message text was: {data}")
- await websocket.close()
diff --git a/docs/src/websockets/tutorial002.py b/docs/src/websockets/tutorial002.py
new file mode 100644
index 000000000..f57b927f8
--- /dev/null
+++ b/docs/src/websockets/tutorial002.py
@@ -0,0 +1,78 @@
+from fastapi import Cookie, Depends, FastAPI, Header
+from starlette.responses import HTMLResponse
+from starlette.status import WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+html = """
+
+
+
+ Chat
+
+
+ WebSocket Chat
+
+
+
+
+
+"""
+
+
+@app.get("/")
+async def get():
+ return HTMLResponse(html)
+
+
+async def get_cookie_or_client(
+ websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
+):
+ if session is None and x_client is None:
+ await websocket.close(code=WS_1008_POLICY_VIOLATION)
+ return session or x_client
+
+
+@app.websocket("/items/{item_id}/ws")
+async def websocket_endpoint(
+ websocket: WebSocket,
+ item_id: int,
+ q: str = None,
+ cookie_or_client: str = Depends(get_cookie_or_client),
+):
+ await websocket.accept()
+ while True:
+ data = await websocket.receive_text()
+ await websocket.send_text(
+ f"Session Cookie or X-Client Header value is: {cookie_or_client}"
+ )
+ if q is not None:
+ await websocket.send_text(f"Query parameter q is: {q}")
+ await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
diff --git a/docs/tutorial/websockets.md b/docs/tutorial/websockets.md
index 9bdb39a32..16bba8ee3 100644
--- a/docs/tutorial/websockets.md
+++ b/docs/tutorial/websockets.md
@@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
{!./src/websockets/tutorial001.py!}
```
-## Create a `websocket_route`
+## Create a `websocket`
-In your **FastAPI** application, create a `websocket_route`:
+In your **FastAPI** application, create a `websocket`:
```Python hl_lines="3 47 48"
{!./src/websockets/tutorial001.py!}
@@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
!!! tip
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
- That is not required, but it's recommended as it will provide you completion and checks inside the function.
-
-
-!!! info
- This `websocket_route` we are using comes directly from Starlette.
-
- That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
-
-
## Await for messages and send messages
In your WebSocket route you can `await` for messages and send messages.
@@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
You can receive and send binary, text, and JSON data.
+## Using `Depends` and others
+
+In WebSocket endpoints you can import from `fastapi` and use:
+
+* `Depends`
+* `Security`
+* `Cookie`
+* `Header`
+* `Path`
+* `Query`
+
+They work the same way as for other FastAPI endpoints/*path operations*:
+
+```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
+{!./src/websockets/tutorial002.py!}
+```
+
+!!! info
+ In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
+
+ You can use a closing code from the valid codes defined in the specification.
+
+ In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the PR #527 in Starlette.
+
+## More info
+
To learn more about the options, check Starlette's documentation for:
* Applications (`websocket_route`).
diff --git a/fastapi/applications.py b/fastapi/applications.py
index dd5633dd2..7041e91d6 100644
--- a/fastapi/applications.py
+++ b/fastapi/applications.py
@@ -203,6 +203,18 @@ class FastAPI(Starlette):
return decorator
+ def add_api_websocket_route(
+ self, path: str, endpoint: Callable, name: str = None
+ ) -> None:
+ self.router.add_api_websocket_route(path, endpoint, name=name)
+
+ def websocket(self, path: str, name: str = None) -> Callable:
+ def decorator(func: Callable) -> Callable:
+ self.add_api_websocket_route(path, func, name=name)
+ return func
+
+ return decorator
+
def include_router(
self,
router: routing.APIRouter,
diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py
index 8bba5e369..67eb094e8 100644
--- a/fastapi/dependencies/models.py
+++ b/fastapi/dependencies/models.py
@@ -26,6 +26,7 @@ class Dependant:
name: str = None,
call: Callable = None,
request_param_name: str = None,
+ websocket_param_name: str = None,
background_tasks_param_name: str = None,
security_scopes_param_name: str = None,
security_scopes: List[str] = None,
@@ -38,6 +39,7 @@ class Dependant:
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
+ self.websocket_param_name = websocket_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name
diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
index 0530fd209..194187f28 100644
--- a/fastapi/dependencies/utils.py
+++ b/fastapi/dependencies/utils.py
@@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request
+from starlette.websockets import WebSocket
param_supported_types = (
str,
@@ -184,6 +185,8 @@ def get_dependant(
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
+ elif lenient_issubclass(param.annotation, WebSocket):
+ dependant.websocket_param_name = param_name
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name
elif lenient_issubclass(param.annotation, SecurityScopes):
@@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
async def solve_dependencies(
*,
- request: Request,
+ request: Union[Request, WebSocket],
dependant: Dependant,
body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None,
@@ -326,8 +329,10 @@ async def solve_dependencies(
)
values.update(body_values)
errors.extend(body_errors)
- if dependant.request_param_name:
+ if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
+ elif dependant.websocket_param_name and isinstance(request, WebSocket):
+ values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
diff --git a/fastapi/routing.py b/fastapi/routing.py
index ef8d9bed2..c902bb2ad 100644
--- a/fastapi/routing.py
+++ b/fastapi/routing.py
@@ -1,6 +1,7 @@
import asyncio
import inspect
import logging
+import re
from typing import Any, Callable, Dict, List, Optional, Type, Union
from fastapi import params
@@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
-from starlette.routing import compile_path, get_name, request_response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.routing import (
+ compile_path,
+ get_name,
+ request_response,
+ websocket_session,
+)
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
def serialize_response(*, field: Field = None, response: Response) -> Any:
@@ -97,6 +104,35 @@ def get_app(
return app
+def get_websocket_app(dependant: Dependant) -> Callable:
+ async def app(websocket: WebSocket) -> None:
+ values, errors, _ = await solve_dependencies(
+ request=websocket, dependant=dependant
+ )
+ if errors:
+ await websocket.close(code=WS_1008_POLICY_VIOLATION)
+ errors_out = ValidationError(errors)
+ raise HTTPException(
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
+ )
+ assert dependant.call is not None, "dependant.call must me a function"
+ await dependant.call(**values)
+
+ return app
+
+
+class APIWebSocketRoute(routing.WebSocketRoute):
+ def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
+ self.path = path
+ self.endpoint = endpoint
+ self.name = get_name(endpoint) if name is None else name
+ self.dependant = get_dependant(path=path, call=self.endpoint)
+ self.app = websocket_session(get_websocket_app(dependant=self.dependant))
+ regex = "^" + path + "$"
+ regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
+ self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+
+
class APIRoute(routing.Route):
def __init__(
self,
@@ -281,6 +317,19 @@ class APIRouter(routing.Router):
return decorator
+ def add_api_websocket_route(
+ self, path: str, endpoint: Callable, name: str = None
+ ) -> None:
+ route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
+ self.routes.append(route)
+
+ def websocket(self, path: str, name: str = None) -> Callable:
+ def decorator(func: Callable) -> Callable:
+ self.add_api_websocket_route(path, func, name=name)
+ return func
+
+ return decorator
+
def include_router(
self,
router: "APIRouter",
@@ -326,6 +375,10 @@ class APIRouter(routing.Router):
include_in_schema=route.include_in_schema,
name=route.name,
)
+ elif isinstance(route, APIWebSocketRoute):
+ self.add_api_websocket_route(
+ prefix + route.path, route.endpoint, name=route.name
+ )
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
diff --git a/tests/test_tutorial/test_websockets/__init__.py b/tests/test_tutorial/test_websockets/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_tutorial/test_websockets/test_tutorial001.py b/tests/test_tutorial/test_websockets/test_tutorial001.py
new file mode 100644
index 000000000..e886140f8
--- /dev/null
+++ b/tests/test_tutorial/test_websockets/test_tutorial001.py
@@ -0,0 +1,25 @@
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial001 import app
+
+client = TestClient(app)
+
+
+def test_main():
+ response = client.get("/")
+ assert response.status_code == 200
+ assert b"" in response.content
+
+
+def test_websocket():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect("/ws") as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}"
diff --git a/tests/test_tutorial/test_websockets/test_tutorial002.py b/tests/test_tutorial/test_websockets/test_tutorial002.py
new file mode 100644
index 000000000..063f83c84
--- /dev/null
+++ b/tests/test_tutorial/test_websockets/test_tutorial002.py
@@ -0,0 +1,83 @@
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial002 import app
+
+client = TestClient(app)
+
+
+def test_main():
+ response = client.get("/")
+ assert response.status_code == 200
+ assert b"" in response.content
+
+
+def test_websocket_with_cookie():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/1/ws", cookies={"session": "fakesession"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: fakesession"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 1"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: fakesession"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 1"
+
+
+def test_websocket_with_header():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/2/ws", headers={"X-Client": "xmen"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_with_header_and_query():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/2/ws?q=baz", headers={"X-Client": "xmen"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == "Query parameter q is: baz"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == "Query parameter q is: baz"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_no_credentials():
+ with pytest.raises(WebSocketDisconnect):
+ client.websocket_connect("/items/2/ws")
+
+
+def test_websocket_invalid_data():
+ with pytest.raises(WebSocketDisconnect):
+ client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})
diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py
index d3c69ca1f..019d7cbc7 100644
--- a/tests/test_ws_router.py
+++ b/tests/test_ws_router.py
@@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket):
await websocket.close()
+@router.websocket("/router2")
+async def routerindex(websocket: WebSocket):
+ await websocket.accept()
+ await websocket.send_text("Hello, router!")
+ await websocket.close()
+
+
app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
@@ -51,3 +58,10 @@ def test_prefix_router():
with client.websocket_connect("/prefix/") as websocket:
data = websocket.receive_text()
assert data == "Hello, router with prefix!"
+
+
+def test_router2():
+ client = TestClient(app)
+ with client.websocket_connect("/router2") as websocket:
+ data = websocket.receive_text()
+ assert data == "Hello, router!"