Browse Source

Add support for WebSockets with dependencies, params, etc #166 (#178)

pull/262/head
James Kaplan 6 years ago
committed by Sebastián Ramírez
parent
commit
b087246f26
  1. 0
      docs/src/websockets/__init__.py
  2. 3
      docs/src/websockets/tutorial001.py
  3. 78
      docs/src/websockets/tutorial002.py
  4. 39
      docs/tutorial/websockets.md
  5. 12
      fastapi/applications.py
  6. 2
      fastapi/dependencies/models.py
  7. 9
      fastapi/dependencies/utils.py
  8. 57
      fastapi/routing.py
  9. 0
      tests/test_tutorial/test_websockets/__init__.py
  10. 25
      tests/test_tutorial/test_websockets/test_tutorial001.py
  11. 83
      tests/test_tutorial/test_websockets/test_tutorial002.py
  12. 14
      tests/test_ws_router.py

0
docs/src/websockets/__init__.py

3
docs/src/websockets/tutorial001.py

@ -44,10 +44,9 @@ async def get():
return HTMLResponse(html) return HTMLResponse(html)
@app.websocket_route("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
await websocket.send_text(f"Message text was: {data}") await websocket.send_text(f"Message text was: {data}")
await websocket.close()

78
docs/src/websockets/tutorial002.py

@ -0,0 +1,78 @@
from fastapi import Cookie, Depends, FastAPI, Header
from starlette.responses import HTMLResponse
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
app = FastAPI()
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="sendMessage(event)">
<label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
<button onclick="connect(event)">Connect</button>
<br>
<label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
var ws = null;
function connect(event) {
var input = document.getElementById("itemId")
ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
var content = document.createTextNode(event.data)
message.appendChild(content)
messages.appendChild(message)
};
}
function sendMessage(event) {
var input = document.getElementById("messageText")
ws.send(input.value)
input.value = ''
event.preventDefault()
}
</script>
</body>
</html>
"""
@app.get("/")
async def get():
return HTMLResponse(html)
async def get_cookie_or_client(
websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
):
if session is None and x_client is None:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
return session or x_client
@app.websocket("/items/{item_id}/ws")
async def websocket_endpoint(
websocket: WebSocket,
item_id: int,
q: str = None,
cookie_or_client: str = Depends(get_cookie_or_client),
):
await websocket.accept()
while True:
data = await websocket.receive_text()
await websocket.send_text(
f"Session Cookie or X-Client Header value is: {cookie_or_client}"
)
if q is not None:
await websocket.send_text(f"Query parameter q is: {q}")
await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")

39
docs/tutorial/websockets.md

@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
{!./src/websockets/tutorial001.py!} {!./src/websockets/tutorial001.py!}
``` ```
## Create a `websocket_route` ## Create a `websocket`
In your **FastAPI** application, create a `websocket_route`: In your **FastAPI** application, create a `websocket`:
```Python hl_lines="3 47 48" ```Python hl_lines="3 47 48"
{!./src/websockets/tutorial001.py!} {!./src/websockets/tutorial001.py!}
@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
!!! tip !!! tip
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function. In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
That is not required, but it's recommended as it will provide you completion and checks inside the function.
!!! info
This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>.
That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
## Await for messages and send messages ## Await for messages and send messages
In your WebSocket route you can `await` for messages and send messages. In your WebSocket route you can `await` for messages and send messages.
@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
You can receive and send binary, text, and JSON data. You can receive and send binary, text, and JSON data.
## Using `Depends` and others
In WebSocket endpoints you can import from `fastapi` and use:
* `Depends`
* `Security`
* `Cookie`
* `Header`
* `Path`
* `Query`
They work the same way as for other FastAPI endpoints/*path operations*:
```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
{!./src/websockets/tutorial002.py!}
```
!!! info
In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
## More info
To learn more about the options, check Starlette's documentation for: To learn more about the options, check Starlette's documentation for:
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>. * <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.

12
fastapi/applications.py

@ -203,6 +203,18 @@ class FastAPI(Starlette):
return decorator return decorator
def add_api_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)
def websocket(self, path: str, name: str = None) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_websocket_route(path, func, name=name)
return func
return decorator
def include_router( def include_router(
self, self,
router: routing.APIRouter, router: routing.APIRouter,

2
fastapi/dependencies/models.py

@ -26,6 +26,7 @@ class Dependant:
name: str = None, name: str = None,
call: Callable = None, call: Callable = None,
request_param_name: str = None, request_param_name: str = None,
websocket_param_name: str = None,
background_tasks_param_name: str = None, background_tasks_param_name: str = None,
security_scopes_param_name: str = None, security_scopes_param_name: str = None,
security_scopes: List[str] = None, security_scopes: List[str] = None,
@ -38,6 +39,7 @@ class Dependant:
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.security_requirements = security_schemes or [] self.security_requirements = security_schemes or []
self.request_param_name = request_param_name self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.background_tasks_param_name = background_tasks_param_name self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name self.security_scopes_param_name = security_scopes_param_name

9
fastapi/dependencies/utils.py

@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request from starlette.requests import Request
from starlette.websockets import WebSocket
param_supported_types = ( param_supported_types = (
str, str,
@ -184,6 +185,8 @@ def get_dependant(
) )
elif lenient_issubclass(param.annotation, Request): elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name dependant.request_param_name = param_name
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param_name
elif lenient_issubclass(param.annotation, BackgroundTasks): elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name dependant.background_tasks_param_name = param_name
elif lenient_issubclass(param.annotation, SecurityScopes): elif lenient_issubclass(param.annotation, SecurityScopes):
@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
async def solve_dependencies( async def solve_dependencies(
*, *,
request: Request, request: Union[Request, WebSocket],
dependant: Dependant, dependant: Dependant,
body: Dict[str, Any] = None, body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None, background_tasks: BackgroundTasks = None,
@ -326,8 +329,10 @@ async def solve_dependencies(
) )
values.update(body_values) values.update(body_values)
errors.extend(body_errors) errors.extend(body_errors)
if dependant.request_param_name: if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name: if dependant.background_tasks_param_name:
if background_tasks is None: if background_tasks is None:
background_tasks = BackgroundTasks() background_tasks = BackgroundTasks()

57
fastapi/routing.py

@ -1,6 +1,7 @@
import asyncio import asyncio
import inspect import inspect
import logging import logging
import re
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
from fastapi import params from fastapi import params
@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.routing import compile_path, get_name, request_response from starlette.routing import (
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY compile_path,
get_name,
request_response,
websocket_session,
)
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
def serialize_response(*, field: Field = None, response: Response) -> Any: def serialize_response(*, field: Field = None, response: Response) -> Any:
@ -97,6 +104,35 @@ def get_app(
return app return app
def get_websocket_app(dependant: Dependant) -> Callable:
async def app(websocket: WebSocket) -> None:
values, errors, _ = await solve_dependencies(
request=websocket, dependant=dependant
)
if errors:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
errors_out = ValidationError(errors)
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
)
assert dependant.call is not None, "dependant.call must me a function"
await dependant.call(**values)
return app
class APIWebSocketRoute(routing.WebSocketRoute):
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependant = get_dependant(path=path, call=self.endpoint)
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
regex = "^" + path + "$"
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
class APIRoute(routing.Route): class APIRoute(routing.Route):
def __init__( def __init__(
self, self,
@ -281,6 +317,19 @@ class APIRouter(routing.Router):
return decorator return decorator
def add_api_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)
def websocket(self, path: str, name: str = None) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_websocket_route(path, func, name=name)
return func
return decorator
def include_router( def include_router(
self, self,
router: "APIRouter", router: "APIRouter",
@ -326,6 +375,10 @@ class APIRouter(routing.Router):
include_in_schema=route.include_in_schema, include_in_schema=route.include_in_schema,
name=route.name, name=route.name,
) )
elif isinstance(route, APIWebSocketRoute):
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)
elif isinstance(route, routing.WebSocketRoute): elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route( self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name prefix + route.path, route.endpoint, name=route.name

0
tests/test_tutorial/test_websockets/__init__.py

25
tests/test_tutorial/test_websockets/test_tutorial001.py

@ -0,0 +1,25 @@
import pytest
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from websockets.tutorial001 import app
client = TestClient(app)
def test_main():
response = client.get("/")
assert response.status_code == 200
assert b"<!DOCTYPE html>" in response.content
def test_websocket():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws") as websocket:
message = "Message one"
websocket.send_text(message)
data = websocket.receive_text()
assert data == f"Message text was: {message}"
message = "Message two"
websocket.send_text(message)
data = websocket.receive_text()
assert data == f"Message text was: {message}"

83
tests/test_tutorial/test_websockets/test_tutorial002.py

@ -0,0 +1,83 @@
import pytest
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from websockets.tutorial002 import app
client = TestClient(app)
def test_main():
response = client.get("/")
assert response.status_code == 200
assert b"<!DOCTYPE html>" in response.content
def test_websocket_with_cookie():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(
"/items/1/ws", cookies={"session": "fakesession"}
) as websocket:
message = "Message one"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: fakesession"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 1"
message = "Message two"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: fakesession"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 1"
def test_websocket_with_header():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(
"/items/2/ws", headers={"X-Client": "xmen"}
) as websocket:
message = "Message one"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: xmen"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 2"
message = "Message two"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: xmen"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 2"
def test_websocket_with_header_and_query():
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect(
"/items/2/ws?q=baz", headers={"X-Client": "xmen"}
) as websocket:
message = "Message one"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: xmen"
data = websocket.receive_text()
assert data == "Query parameter q is: baz"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 2"
message = "Message two"
websocket.send_text(message)
data = websocket.receive_text()
assert data == "Session Cookie or X-Client Header value is: xmen"
data = websocket.receive_text()
assert data == "Query parameter q is: baz"
data = websocket.receive_text()
assert data == f"Message text was: {message}, for item ID: 2"
def test_websocket_no_credentials():
with pytest.raises(WebSocketDisconnect):
client.websocket_connect("/items/2/ws")
def test_websocket_invalid_data():
with pytest.raises(WebSocketDisconnect):
client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})

14
tests/test_ws_router.py

@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket):
await websocket.close() await websocket.close()
@router.websocket("/router2")
async def routerindex(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Hello, router!")
await websocket.close()
app.include_router(router) app.include_router(router)
app.include_router(prefix_router, prefix="/prefix") app.include_router(prefix_router, prefix="/prefix")
@ -51,3 +58,10 @@ def test_prefix_router():
with client.websocket_connect("/prefix/") as websocket: with client.websocket_connect("/prefix/") as websocket:
data = websocket.receive_text() data = websocket.receive_text()
assert data == "Hello, router with prefix!" assert data == "Hello, router with prefix!"
def test_router2():
client = TestClient(app)
with client.websocket_connect("/router2") as websocket:
data = websocket.receive_text()
assert data == "Hello, router!"

Loading…
Cancel
Save