diff --git a/docs/src/custom_request_and_route/tutorial001.py b/docs/src/custom_request_and_route/tutorial001.py new file mode 100644 index 000000000..cd21582e3 --- /dev/null +++ b/docs/src/custom_request_and_route/tutorial001.py @@ -0,0 +1,37 @@ +import gzip +from typing import Callable, List + +from fastapi import Body, FastAPI +from fastapi.routing import APIRoute +from starlette.requests import Request +from starlette.responses import Response + + +class GzipRequest(Request): + async def body(self) -> bytes: + if not hasattr(self, "_body"): + body = await super().body() + if "gzip" in self.headers.getlist("Content-Encoding"): + body = gzip.decompress(body) + self._body = body + return self._body + + +class GzipRoute(APIRoute): + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + request = GzipRequest(request.scope, request.receive) + return await original_route_handler(request) + + return custom_route_handler + + +app = FastAPI() +app.router.route_class = GzipRoute + + +@app.post("/sum") +async def sum_numbers(numbers: List[int] = Body(...)): + return {"sum": sum(numbers)} diff --git a/docs/src/custom_request_and_route/tutorial002.py b/docs/src/custom_request_and_route/tutorial002.py new file mode 100644 index 000000000..95cad99b1 --- /dev/null +++ b/docs/src/custom_request_and_route/tutorial002.py @@ -0,0 +1,31 @@ +from typing import Callable, List + +from fastapi import Body, FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.routing import APIRoute +from starlette.requests import Request +from starlette.responses import Response + + +class ValidationErrorLoggingRoute(APIRoute): + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + try: + return await original_route_handler(request) + except RequestValidationError as exc: + body = await request.body() + detail = {"errors": exc.errors(), "body": body.decode()} + raise HTTPException(status_code=422, detail=detail) + + return custom_route_handler + + +app = FastAPI() +app.router.route_class = ValidationErrorLoggingRoute + + +@app.post("/") +async def sum_numbers(numbers: List[int] = Body(...)): + return sum(numbers) diff --git a/docs/src/custom_request_and_route/tutorial003.py b/docs/src/custom_request_and_route/tutorial003.py new file mode 100644 index 000000000..4497736a5 --- /dev/null +++ b/docs/src/custom_request_and_route/tutorial003.py @@ -0,0 +1,41 @@ +import time +from typing import Callable + +from fastapi import APIRouter, FastAPI +from fastapi.routing import APIRoute +from starlette.requests import Request +from starlette.responses import Response + + +class TimedRoute(APIRoute): + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + before = time.time() + response: Response = await original_route_handler(request) + duration = time.time() - before + response.headers["X-Response-Time"] = str(duration) + print(f"route duration: {duration}") + print(f"route response: {response}") + print(f"route response headers: {response.headers}") + return response + + return custom_route_handler + + +app = FastAPI() +router = APIRouter(route_class=TimedRoute) + + +@app.get("/") +async def not_timed(): + return {"message": "Not timed"} + + +@router.get("/timed") +async def timed(): + return {"message": "It's the time of my life"} + + +app.include_router(router) diff --git a/docs/tutorial/custom-request-and-route.md b/docs/tutorial/custom-request-and-route.md new file mode 100644 index 000000000..49cca992f --- /dev/null +++ b/docs/tutorial/custom-request-and-route.md @@ -0,0 +1,100 @@ +In some cases, you may want to override the logic used by the `Request` and `APIRoute` classes. + +In particular, this may be a good alternative to logic in a middleware. + +For example, if you want to read or manipulate the request body before it is processed by your application. + +!!! danger + This is an "advanced" feature. + + If you are just starting with **FastAPI** you might want to skip this section. + +## Use cases + +Some use cases include: + +* Converting non-JSON request bodies to JSON (e.g. [`msgpack`](https://msgpack.org/index.html)). +* Decompressing gzip-compressed request bodies. +* Automatically logging all request bodies. +* Accessing the request body in an exception handler. + +## Handling custom request body encodings + +Let's see how to make use of a custom `Request` subclass to decompress gzip requests. + +And an `APIRoute` subclass to use that custom request class. + +### Create a custom `GzipRequest` class + +First, we create a `GzipRequest` class, which will overwrite the `Request.body()` method to decompress the body in the presence of an appropriate header. + +If there's no `gzip` in the header, it will not try to decompress the body. + +That way, the same route class can handle gzip compressed or uncompressed requests. + +```Python hl_lines="10 11 12 13 14 15 16 17" +{!./src/custom_request_and_route/tutorial001.py!} +``` + +### Create a custom `GzipRoute` class + +Next, we create a custom subclass of `fastapi.routing.APIRoute` that will make use of the `GzipRequest`. + +This time, it will overwrite the method `APIRoute.get_route_handler()`. + +This method returns a function. And that function is what will receive a request and return a response. + +Here we use it to create a `GzipRequest` from the original request. + +```Python hl_lines="20 21 22 23 24 25 26 27 28" +{!./src/custom_request_and_route/tutorial001.py!} +``` + +!!! note "Technical Details" + A `Request` has a `request.scope` attribute, that's just a Python `dict` containing the metadata related to the request. + + A `Request` also has a `request.receive`, that's a function to "receive" the body of the request. + + The `scope` `dict` and `receive` function are both part of the ASGI specification. + + And those two things, `scope` and `receive`, are what is needed to create a new `Request` instance. + + To learn more about the `Request` check Starlette's docs about Requests. + +The only thing the function returned by `GzipRequest.get_route_handler` does differently is convert the `Request` to a `GzipRequest`. + +Doing this, our `GzipRequest` will take care of decompressing the data (if necessary) before passing it to our *path operations*. + +After that, all of the processing logic is the same. + +But because of our changes in `GzipRequest.body`, the request body will be automatically decompressed when it is loaded by **FastAPI** when needed. + +## Accessing the request body in an exception handler + +We can also use this same approach to access the request body in an exception handler. + +All we need to do is handle the request inside a `try`/`except` block: + +```Python hl_lines="15 17" +{!./src/custom_request_and_route/tutorial002.py!} +``` + +If an exception occurs, the`Request` instance will still be in scope, so we can read and make use of the request body when handling the error: + +```Python hl_lines="18 19 20" +{!./src/custom_request_and_route/tutorial002.py!} +``` + +## Custom `APIRoute` class in a router + +You can also set the `route_class` parameter of an `APIRouter`: + +```Python hl_lines="25" +{!./src/custom_request_and_route/tutorial003.py!} +``` + +In this example, the *path operations* under the `router` will use the custom `TimedRoute` class, and will have an extra `X-Response-Time` header in the response with the time it took to generate the response: + +```Python hl_lines="15 16 17 18 19" +{!./src/custom_request_and_route/tutorial003.py!} +``` diff --git a/fastapi/routing.py b/fastapi/routing.py index b0902310c..2a4e0bc8d 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -65,7 +65,7 @@ def serialize_response( return jsonable_encoder(response) -def get_app( +def get_request_handler( dependant: Dependant, body_field: Field = None, status_code: int = 200, @@ -294,19 +294,20 @@ class APIRoute(routing.Route): ) self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) self.dependency_overrides_provider = dependency_overrides_provider - self.app = request_response( - get_app( - dependant=self.dependant, - body_field=self.body_field, - status_code=self.status_code, - response_class=self.response_class or JSONResponse, - response_field=self.secure_cloned_response_field, - response_model_include=self.response_model_include, - response_model_exclude=self.response_model_exclude, - response_model_by_alias=self.response_model_by_alias, - response_model_skip_defaults=self.response_model_skip_defaults, - dependency_overrides_provider=self.dependency_overrides_provider, - ) + self.app = request_response(self.get_route_handler()) + + def get_route_handler(self) -> Callable: + return get_request_handler( + dependant=self.dependant, + body_field=self.body_field, + status_code=self.status_code, + response_class=self.response_class or JSONResponse, + response_field=self.secure_cloned_response_field, + response_model_include=self.response_model_include, + response_model_exclude=self.response_model_exclude, + response_model_by_alias=self.response_model_by_alias, + response_model_skip_defaults=self.response_model_skip_defaults, + dependency_overrides_provider=self.dependency_overrides_provider, ) diff --git a/mkdocs.yml b/mkdocs.yml index a61deedb0..b4c2ec152 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - GraphQL: 'tutorial/graphql.md' - WebSockets: 'tutorial/websockets.md' - 'Events: startup - shutdown': 'tutorial/events.md' + - Custom Request and APIRoute class: 'tutorial/custom-request-and-route.md' - Testing: 'tutorial/testing.md' - Testing Dependencies with Overrides: 'tutorial/testing-dependencies.md' - Debugging: 'tutorial/debugging.md' diff --git a/tests/test_tutorial/test_custom_request_and_route/__init__.py b/tests/test_tutorial/test_custom_request_and_route/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py new file mode 100644 index 000000000..2b4b474cb --- /dev/null +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py @@ -0,0 +1,34 @@ +import gzip +import json + +import pytest +from starlette.requests import Request +from starlette.testclient import TestClient + +from custom_request_and_route.tutorial001 import app + + +@app.get("/check-class") +async def check_gzip_request(request: Request): + return {"request_class": type(request).__name__} + + +client = TestClient(app) + + +@pytest.mark.parametrize("compress", [True, False]) +def test_gzip_request(compress): + n = 1000 + headers = {} + body = [1] * n + data = json.dumps(body).encode() + if compress: + data = gzip.compress(data) + headers["Content-Encoding"] = "gzip" + response = client.post("/sum", data=data, headers=headers) + assert response.json() == {"sum": n} + + +def test_request_class(): + response = client.get("/check-class") + assert response.json() == {"request_class": "GzipRequest"} diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py new file mode 100644 index 000000000..a50760b23 --- /dev/null +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py @@ -0,0 +1,27 @@ +from starlette.testclient import TestClient + +from custom_request_and_route.tutorial002 import app + +client = TestClient(app) + + +def test_endpoint_works(): + response = client.post("/", json=[1, 2, 3]) + assert response.json() == 6 + + +def test_exception_handler_body_access(): + response = client.post("/", json={"numbers": [1, 2, 3]}) + + assert response.json() == { + "detail": { + "body": '{"numbers": [1, 2, 3]}', + "errors": [ + { + "loc": ["body", "numbers"], + "msg": "value is not a valid list", + "type": "type_error.list", + } + ], + } + } diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py new file mode 100644 index 000000000..bc4ccacbd --- /dev/null +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py @@ -0,0 +1,18 @@ +from starlette.testclient import TestClient + +from custom_request_and_route.tutorial003 import app + +client = TestClient(app) + + +def test_get(): + response = client.get("/") + assert response.json() == {"message": "Not timed"} + assert "X-Response-Time" not in response.headers + + +def test_get_timed(): + response = client.get("/timed") + assert response.json() == {"message": "It's the time of my life"} + assert "X-Response-Time" in response.headers + assert float(response.headers["X-Response-Time"]) > 0