diff --git a/docs/src/websockets/__init__.py b/docs/src/websockets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/src/websockets/tutorial001.py b/docs/src/websockets/tutorial001.py index 2713550d3..3adfd49c1 100644 --- a/docs/src/websockets/tutorial001.py +++ b/docs/src/websockets/tutorial001.py @@ -44,10 +44,9 @@ async def get(): return HTMLResponse(html) -@app.websocket_route("/ws") +@app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() while True: data = await websocket.receive_text() await websocket.send_text(f"Message text was: {data}") - await websocket.close() diff --git a/docs/src/websockets/tutorial002.py b/docs/src/websockets/tutorial002.py new file mode 100644 index 000000000..f57b927f8 --- /dev/null +++ b/docs/src/websockets/tutorial002.py @@ -0,0 +1,78 @@ +from fastapi import Cookie, Depends, FastAPI, Header +from starlette.responses import HTMLResponse +from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.websockets import WebSocket + +app = FastAPI() + +html = """ + + + + Chat + + +

WebSocket Chat

+
+ + +
+ + +
+ + + + +""" + + +@app.get("/") +async def get(): + return HTMLResponse(html) + + +async def get_cookie_or_client( + websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None) +): + if session is None and x_client is None: + await websocket.close(code=WS_1008_POLICY_VIOLATION) + return session or x_client + + +@app.websocket("/items/{item_id}/ws") +async def websocket_endpoint( + websocket: WebSocket, + item_id: int, + q: str = None, + cookie_or_client: str = Depends(get_cookie_or_client), +): + await websocket.accept() + while True: + data = await websocket.receive_text() + await websocket.send_text( + f"Session Cookie or X-Client Header value is: {cookie_or_client}" + ) + if q is not None: + await websocket.send_text(f"Query parameter q is: {q}") + await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}") diff --git a/docs/tutorial/websockets.md b/docs/tutorial/websockets.md index 9bdb39a32..16bba8ee3 100644 --- a/docs/tutorial/websockets.md +++ b/docs/tutorial/websockets.md @@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w {!./src/websockets/tutorial001.py!} ``` -## Create a `websocket_route` +## Create a `websocket` -In your **FastAPI** application, create a `websocket_route`: +In your **FastAPI** application, create a `websocket`: ```Python hl_lines="3 47 48" {!./src/websockets/tutorial001.py!} @@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`: !!! tip In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function. - That is not required, but it's recommended as it will provide you completion and checks inside the function. - - -!!! info - This `websocket_route` we are using comes directly from Starlette. - - That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc). - - ## Await for messages and send messages In your WebSocket route you can `await` for messages and send messages. @@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages. You can receive and send binary, text, and JSON data. +## Using `Depends` and others + +In WebSocket endpoints you can import from `fastapi` and use: + +* `Depends` +* `Security` +* `Cookie` +* `Header` +* `Path` +* `Query` + +They work the same way as for other FastAPI endpoints/*path operations*: + +```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78" +{!./src/websockets/tutorial002.py!} +``` + +!!! info + In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly. + + You can use a closing code from the valid codes defined in the specification. + + In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the PR #527 in Starlette. + +## More info + To learn more about the options, check Starlette's documentation for: * Applications (`websocket_route`). diff --git a/fastapi/applications.py b/fastapi/applications.py index dd5633dd2..7041e91d6 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -203,6 +203,18 @@ class FastAPI(Starlette): return decorator + def add_api_websocket_route( + self, path: str, endpoint: Callable, name: str = None + ) -> None: + self.router.add_api_websocket_route(path, endpoint, name=name) + + def websocket(self, path: str, name: str = None) -> Callable: + def decorator(func: Callable) -> Callable: + self.add_api_websocket_route(path, func, name=name) + return func + + return decorator + def include_router( self, router: routing.APIRouter, diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 8bba5e369..67eb094e8 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -26,6 +26,7 @@ class Dependant: name: str = None, call: Callable = None, request_param_name: str = None, + websocket_param_name: str = None, background_tasks_param_name: str = None, security_scopes_param_name: str = None, security_scopes: List[str] = None, @@ -38,6 +39,7 @@ class Dependant: self.dependencies = dependencies or [] self.security_requirements = security_schemes or [] self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name self.background_tasks_param_name = background_tasks_param_name self.security_scopes = security_scopes self.security_scopes_param_name = security_scopes_param_name diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0530fd209..194187f28 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks from starlette.concurrency import run_in_threadpool from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import Request +from starlette.websockets import WebSocket param_supported_types = ( str, @@ -184,6 +185,8 @@ def get_dependant( ) elif lenient_issubclass(param.annotation, Request): dependant.request_param_name = param_name + elif lenient_issubclass(param.annotation, WebSocket): + dependant.websocket_param_name = param_name elif lenient_issubclass(param.annotation, BackgroundTasks): dependant.background_tasks_param_name = param_name elif lenient_issubclass(param.annotation, SecurityScopes): @@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool: async def solve_dependencies( *, - request: Request, + request: Union[Request, WebSocket], dependant: Dependant, body: Dict[str, Any] = None, background_tasks: BackgroundTasks = None, @@ -326,8 +329,10 @@ async def solve_dependencies( ) values.update(body_values) errors.extend(body_errors) - if dependant.request_param_name: + if dependant.request_param_name and isinstance(request, Request): values[dependant.request_param_name] = request + elif dependant.websocket_param_name and isinstance(request, WebSocket): + values[dependant.websocket_param_name] = request if dependant.background_tasks_param_name: if background_tasks is None: background_tasks = BackgroundTasks() diff --git a/fastapi/routing.py b/fastapi/routing.py index ef8d9bed2..c902bb2ad 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,6 +1,7 @@ import asyncio import inspect import logging +import re from typing import Any, Callable, Dict, List, Optional, Type, Union from fastapi import params @@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.routing import compile_path, get_name, request_response -from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from starlette.routing import ( + compile_path, + get_name, + request_response, + websocket_session, +) +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION +from starlette.websockets import WebSocket def serialize_response(*, field: Field = None, response: Response) -> Any: @@ -97,6 +104,35 @@ def get_app( return app +def get_websocket_app(dependant: Dependant) -> Callable: + async def app(websocket: WebSocket) -> None: + values, errors, _ = await solve_dependencies( + request=websocket, dependant=dependant + ) + if errors: + await websocket.close(code=WS_1008_POLICY_VIOLATION) + errors_out = ValidationError(errors) + raise HTTPException( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors() + ) + assert dependant.call is not None, "dependant.call must me a function" + await dependant.call(**values) + + return app + + +class APIWebSocketRoute(routing.WebSocketRoute): + def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None: + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + self.dependant = get_dependant(path=path, call=self.endpoint) + self.app = websocket_session(get_websocket_app(dependant=self.dependant)) + regex = "^" + path + "$" + regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex) + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + + class APIRoute(routing.Route): def __init__( self, @@ -281,6 +317,19 @@ class APIRouter(routing.Router): return decorator + def add_api_websocket_route( + self, path: str, endpoint: Callable, name: str = None + ) -> None: + route = APIWebSocketRoute(path, endpoint=endpoint, name=name) + self.routes.append(route) + + def websocket(self, path: str, name: str = None) -> Callable: + def decorator(func: Callable) -> Callable: + self.add_api_websocket_route(path, func, name=name) + return func + + return decorator + def include_router( self, router: "APIRouter", @@ -326,6 +375,10 @@ class APIRouter(routing.Router): include_in_schema=route.include_in_schema, name=route.name, ) + elif isinstance(route, APIWebSocketRoute): + self.add_api_websocket_route( + prefix + route.path, route.endpoint, name=route.name + ) elif isinstance(route, routing.WebSocketRoute): self.add_websocket_route( prefix + route.path, route.endpoint, name=route.name diff --git a/tests/test_tutorial/test_websockets/__init__.py b/tests/test_tutorial/test_websockets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_tutorial/test_websockets/test_tutorial001.py b/tests/test_tutorial/test_websockets/test_tutorial001.py new file mode 100644 index 000000000..e886140f8 --- /dev/null +++ b/tests/test_tutorial/test_websockets/test_tutorial001.py @@ -0,0 +1,25 @@ +import pytest +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect +from websockets.tutorial001 import app + +client = TestClient(app) + + +def test_main(): + response = client.get("/") + assert response.status_code == 200 + assert b"" in response.content + + +def test_websocket(): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect("/ws") as websocket: + message = "Message one" + websocket.send_text(message) + data = websocket.receive_text() + assert data == f"Message text was: {message}" + message = "Message two" + websocket.send_text(message) + data = websocket.receive_text() + assert data == f"Message text was: {message}" diff --git a/tests/test_tutorial/test_websockets/test_tutorial002.py b/tests/test_tutorial/test_websockets/test_tutorial002.py new file mode 100644 index 000000000..063f83c84 --- /dev/null +++ b/tests/test_tutorial/test_websockets/test_tutorial002.py @@ -0,0 +1,83 @@ +import pytest +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect +from websockets.tutorial002 import app + +client = TestClient(app) + + +def test_main(): + response = client.get("/") + assert response.status_code == 200 + assert b"" in response.content + + +def test_websocket_with_cookie(): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect( + "/items/1/ws", cookies={"session": "fakesession"} + ) as websocket: + message = "Message one" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: fakesession" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 1" + message = "Message two" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: fakesession" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 1" + + +def test_websocket_with_header(): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect( + "/items/2/ws", headers={"X-Client": "xmen"} + ) as websocket: + message = "Message one" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: xmen" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 2" + message = "Message two" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: xmen" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 2" + + +def test_websocket_with_header_and_query(): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect( + "/items/2/ws?q=baz", headers={"X-Client": "xmen"} + ) as websocket: + message = "Message one" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: xmen" + data = websocket.receive_text() + assert data == "Query parameter q is: baz" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 2" + message = "Message two" + websocket.send_text(message) + data = websocket.receive_text() + assert data == "Session Cookie or X-Client Header value is: xmen" + data = websocket.receive_text() + assert data == "Query parameter q is: baz" + data = websocket.receive_text() + assert data == f"Message text was: {message}, for item ID: 2" + + +def test_websocket_no_credentials(): + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/items/2/ws") + + +def test_websocket_invalid_data(): + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"}) diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index d3c69ca1f..019d7cbc7 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket): await websocket.close() +@router.websocket("/router2") +async def routerindex(websocket: WebSocket): + await websocket.accept() + await websocket.send_text("Hello, router!") + await websocket.close() + + app.include_router(router) app.include_router(prefix_router, prefix="/prefix") @@ -51,3 +58,10 @@ def test_prefix_router(): with client.websocket_connect("/prefix/") as websocket: data = websocket.receive_text() assert data == "Hello, router with prefix!" + + +def test_router2(): + client = TestClient(app) + with client.websocket_connect("/router2") as websocket: + data = websocket.receive_text() + assert data == "Hello, router!"