committed by
Sebastián Ramírez
10 changed files with 304 additions and 14 deletions
@ -0,0 +1,37 @@ |
|||||
|
import gzip |
||||
|
from typing import Callable, List |
||||
|
|
||||
|
from fastapi import Body, FastAPI |
||||
|
from fastapi.routing import APIRoute |
||||
|
from starlette.requests import Request |
||||
|
from starlette.responses import Response |
||||
|
|
||||
|
|
||||
|
class GzipRequest(Request): |
||||
|
async def body(self) -> bytes: |
||||
|
if not hasattr(self, "_body"): |
||||
|
body = await super().body() |
||||
|
if "gzip" in self.headers.getlist("Content-Encoding"): |
||||
|
body = gzip.decompress(body) |
||||
|
self._body = body |
||||
|
return self._body |
||||
|
|
||||
|
|
||||
|
class GzipRoute(APIRoute): |
||||
|
def get_route_handler(self) -> Callable: |
||||
|
original_route_handler = super().get_route_handler() |
||||
|
|
||||
|
async def custom_route_handler(request: Request) -> Response: |
||||
|
request = GzipRequest(request.scope, request.receive) |
||||
|
return await original_route_handler(request) |
||||
|
|
||||
|
return custom_route_handler |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
app.router.route_class = GzipRoute |
||||
|
|
||||
|
|
||||
|
@app.post("/sum") |
||||
|
async def sum_numbers(numbers: List[int] = Body(...)): |
||||
|
return {"sum": sum(numbers)} |
@ -0,0 +1,31 @@ |
|||||
|
from typing import Callable, List |
||||
|
|
||||
|
from fastapi import Body, FastAPI, HTTPException |
||||
|
from fastapi.exceptions import RequestValidationError |
||||
|
from fastapi.routing import APIRoute |
||||
|
from starlette.requests import Request |
||||
|
from starlette.responses import Response |
||||
|
|
||||
|
|
||||
|
class ValidationErrorLoggingRoute(APIRoute): |
||||
|
def get_route_handler(self) -> Callable: |
||||
|
original_route_handler = super().get_route_handler() |
||||
|
|
||||
|
async def custom_route_handler(request: Request) -> Response: |
||||
|
try: |
||||
|
return await original_route_handler(request) |
||||
|
except RequestValidationError as exc: |
||||
|
body = await request.body() |
||||
|
detail = {"errors": exc.errors(), "body": body.decode()} |
||||
|
raise HTTPException(status_code=422, detail=detail) |
||||
|
|
||||
|
return custom_route_handler |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
app.router.route_class = ValidationErrorLoggingRoute |
||||
|
|
||||
|
|
||||
|
@app.post("/") |
||||
|
async def sum_numbers(numbers: List[int] = Body(...)): |
||||
|
return sum(numbers) |
@ -0,0 +1,41 @@ |
|||||
|
import time |
||||
|
from typing import Callable |
||||
|
|
||||
|
from fastapi import APIRouter, FastAPI |
||||
|
from fastapi.routing import APIRoute |
||||
|
from starlette.requests import Request |
||||
|
from starlette.responses import Response |
||||
|
|
||||
|
|
||||
|
class TimedRoute(APIRoute): |
||||
|
def get_route_handler(self) -> Callable: |
||||
|
original_route_handler = super().get_route_handler() |
||||
|
|
||||
|
async def custom_route_handler(request: Request) -> Response: |
||||
|
before = time.time() |
||||
|
response: Response = await original_route_handler(request) |
||||
|
duration = time.time() - before |
||||
|
response.headers["X-Response-Time"] = str(duration) |
||||
|
print(f"route duration: {duration}") |
||||
|
print(f"route response: {response}") |
||||
|
print(f"route response headers: {response.headers}") |
||||
|
return response |
||||
|
|
||||
|
return custom_route_handler |
||||
|
|
||||
|
|
||||
|
app = FastAPI() |
||||
|
router = APIRouter(route_class=TimedRoute) |
||||
|
|
||||
|
|
||||
|
@app.get("/") |
||||
|
async def not_timed(): |
||||
|
return {"message": "Not timed"} |
||||
|
|
||||
|
|
||||
|
@router.get("/timed") |
||||
|
async def timed(): |
||||
|
return {"message": "It's the time of my life"} |
||||
|
|
||||
|
|
||||
|
app.include_router(router) |
@ -0,0 +1,100 @@ |
|||||
|
In some cases, you may want to override the logic used by the `Request` and `APIRoute` classes. |
||||
|
|
||||
|
In particular, this may be a good alternative to logic in a middleware. |
||||
|
|
||||
|
For example, if you want to read or manipulate the request body before it is processed by your application. |
||||
|
|
||||
|
!!! danger |
||||
|
This is an "advanced" feature. |
||||
|
|
||||
|
If you are just starting with **FastAPI** you might want to skip this section. |
||||
|
|
||||
|
## Use cases |
||||
|
|
||||
|
Some use cases include: |
||||
|
|
||||
|
* Converting non-JSON request bodies to JSON (e.g. [`msgpack`](https://msgpack.org/index.html)). |
||||
|
* Decompressing gzip-compressed request bodies. |
||||
|
* Automatically logging all request bodies. |
||||
|
* Accessing the request body in an exception handler. |
||||
|
|
||||
|
## Handling custom request body encodings |
||||
|
|
||||
|
Let's see how to make use of a custom `Request` subclass to decompress gzip requests. |
||||
|
|
||||
|
And an `APIRoute` subclass to use that custom request class. |
||||
|
|
||||
|
### Create a custom `GzipRequest` class |
||||
|
|
||||
|
First, we create a `GzipRequest` class, which will overwrite the `Request.body()` method to decompress the body in the presence of an appropriate header. |
||||
|
|
||||
|
If there's no `gzip` in the header, it will not try to decompress the body. |
||||
|
|
||||
|
That way, the same route class can handle gzip compressed or uncompressed requests. |
||||
|
|
||||
|
```Python hl_lines="10 11 12 13 14 15 16 17" |
||||
|
{!./src/custom_request_and_route/tutorial001.py!} |
||||
|
``` |
||||
|
|
||||
|
### Create a custom `GzipRoute` class |
||||
|
|
||||
|
Next, we create a custom subclass of `fastapi.routing.APIRoute` that will make use of the `GzipRequest`. |
||||
|
|
||||
|
This time, it will overwrite the method `APIRoute.get_route_handler()`. |
||||
|
|
||||
|
This method returns a function. And that function is what will receive a request and return a response. |
||||
|
|
||||
|
Here we use it to create a `GzipRequest` from the original request. |
||||
|
|
||||
|
```Python hl_lines="20 21 22 23 24 25 26 27 28" |
||||
|
{!./src/custom_request_and_route/tutorial001.py!} |
||||
|
``` |
||||
|
|
||||
|
!!! note "Technical Details" |
||||
|
A `Request` has a `request.scope` attribute, that's just a Python `dict` containing the metadata related to the request. |
||||
|
|
||||
|
A `Request` also has a `request.receive`, that's a function to "receive" the body of the request. |
||||
|
|
||||
|
The `scope` `dict` and `receive` function are both part of the ASGI specification. |
||||
|
|
||||
|
And those two things, `scope` and `receive`, are what is needed to create a new `Request` instance. |
||||
|
|
||||
|
To learn more about the `Request` check <a href="https://www.starlette.io/requests/" target="_blank">Starlette's docs about Requests</a>. |
||||
|
|
||||
|
The only thing the function returned by `GzipRequest.get_route_handler` does differently is convert the `Request` to a `GzipRequest`. |
||||
|
|
||||
|
Doing this, our `GzipRequest` will take care of decompressing the data (if necessary) before passing it to our *path operations*. |
||||
|
|
||||
|
After that, all of the processing logic is the same. |
||||
|
|
||||
|
But because of our changes in `GzipRequest.body`, the request body will be automatically decompressed when it is loaded by **FastAPI** when needed. |
||||
|
|
||||
|
## Accessing the request body in an exception handler |
||||
|
|
||||
|
We can also use this same approach to access the request body in an exception handler. |
||||
|
|
||||
|
All we need to do is handle the request inside a `try`/`except` block: |
||||
|
|
||||
|
```Python hl_lines="15 17" |
||||
|
{!./src/custom_request_and_route/tutorial002.py!} |
||||
|
``` |
||||
|
|
||||
|
If an exception occurs, the`Request` instance will still be in scope, so we can read and make use of the request body when handling the error: |
||||
|
|
||||
|
```Python hl_lines="18 19 20" |
||||
|
{!./src/custom_request_and_route/tutorial002.py!} |
||||
|
``` |
||||
|
|
||||
|
## Custom `APIRoute` class in a router |
||||
|
|
||||
|
You can also set the `route_class` parameter of an `APIRouter`: |
||||
|
|
||||
|
```Python hl_lines="25" |
||||
|
{!./src/custom_request_and_route/tutorial003.py!} |
||||
|
``` |
||||
|
|
||||
|
In this example, the *path operations* under the `router` will use the custom `TimedRoute` class, and will have an extra `X-Response-Time` header in the response with the time it took to generate the response: |
||||
|
|
||||
|
```Python hl_lines="15 16 17 18 19" |
||||
|
{!./src/custom_request_and_route/tutorial003.py!} |
||||
|
``` |
@ -0,0 +1,34 @@ |
|||||
|
import gzip |
||||
|
import json |
||||
|
|
||||
|
import pytest |
||||
|
from starlette.requests import Request |
||||
|
from starlette.testclient import TestClient |
||||
|
|
||||
|
from custom_request_and_route.tutorial001 import app |
||||
|
|
||||
|
|
||||
|
@app.get("/check-class") |
||||
|
async def check_gzip_request(request: Request): |
||||
|
return {"request_class": type(request).__name__} |
||||
|
|
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
@pytest.mark.parametrize("compress", [True, False]) |
||||
|
def test_gzip_request(compress): |
||||
|
n = 1000 |
||||
|
headers = {} |
||||
|
body = [1] * n |
||||
|
data = json.dumps(body).encode() |
||||
|
if compress: |
||||
|
data = gzip.compress(data) |
||||
|
headers["Content-Encoding"] = "gzip" |
||||
|
response = client.post("/sum", data=data, headers=headers) |
||||
|
assert response.json() == {"sum": n} |
||||
|
|
||||
|
|
||||
|
def test_request_class(): |
||||
|
response = client.get("/check-class") |
||||
|
assert response.json() == {"request_class": "GzipRequest"} |
@ -0,0 +1,27 @@ |
|||||
|
from starlette.testclient import TestClient |
||||
|
|
||||
|
from custom_request_and_route.tutorial002 import app |
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_endpoint_works(): |
||||
|
response = client.post("/", json=[1, 2, 3]) |
||||
|
assert response.json() == 6 |
||||
|
|
||||
|
|
||||
|
def test_exception_handler_body_access(): |
||||
|
response = client.post("/", json={"numbers": [1, 2, 3]}) |
||||
|
|
||||
|
assert response.json() == { |
||||
|
"detail": { |
||||
|
"body": '{"numbers": [1, 2, 3]}', |
||||
|
"errors": [ |
||||
|
{ |
||||
|
"loc": ["body", "numbers"], |
||||
|
"msg": "value is not a valid list", |
||||
|
"type": "type_error.list", |
||||
|
} |
||||
|
], |
||||
|
} |
||||
|
} |
@ -0,0 +1,18 @@ |
|||||
|
from starlette.testclient import TestClient |
||||
|
|
||||
|
from custom_request_and_route.tutorial003 import app |
||||
|
|
||||
|
client = TestClient(app) |
||||
|
|
||||
|
|
||||
|
def test_get(): |
||||
|
response = client.get("/") |
||||
|
assert response.json() == {"message": "Not timed"} |
||||
|
assert "X-Response-Time" not in response.headers |
||||
|
|
||||
|
|
||||
|
def test_get_timed(): |
||||
|
response = client.get("/timed") |
||||
|
assert response.json() == {"message": "It's the time of my life"} |
||||
|
assert "X-Response-Time" in response.headers |
||||
|
assert float(response.headers["X-Response-Time"]) > 0 |
Loading…
Reference in new issue