From b0eedbb5804a6ac32e4ee8d029d462d950ff8848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 11 Sep 2024 09:45:30 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Improve=20performance=20in?= =?UTF-8?q?=20request=20body=20parsing=20with=20a=20cache=20for=20internal?= =?UTF-8?q?=20model=20fields=20(#12184)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/_compat.py | 6 ++++++ fastapi/dependencies/utils.py | 4 ++-- tests/test_compat.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index f940d6597..4b07b44fa 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -2,6 +2,7 @@ from collections import deque from copy import copy from dataclasses import dataclass, is_dataclass from enum import Enum +from functools import lru_cache from typing import ( Any, Callable, @@ -649,3 +650,8 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool: is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation) for sub_annotation in get_args(annotation) ) + + +@lru_cache +def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]: + return get_model_fields(model) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 6083b7319..f18eace9d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -32,8 +32,8 @@ from fastapi._compat import ( evaluate_forwardref, field_annotation_is_scalar, get_annotation_from_field_info, + get_cached_model_fields, get_missing_field_error, - get_model_fields, is_bytes_field, is_bytes_sequence_field, is_scalar_field, @@ -810,7 +810,7 @@ async def request_body_to_args( fields_to_extract: List[ModelField] = body_fields if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_model_fields(first_field.type_) + fields_to_extract = get_cached_model_fields(first_field.type_) if isinstance(received_body, FormData): body_to_process = await _extract_form_body(fields_to_extract, received_body) diff --git a/tests/test_compat.py b/tests/test_compat.py index 270475bf3..f4a3093c5 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -5,6 +5,7 @@ from fastapi._compat import ( ModelField, Undefined, _get_model_config, + get_cached_model_fields, get_model_fields, is_bytes_sequence_annotation, is_scalar_field, @@ -102,3 +103,18 @@ def test_is_pv1_scalar_field(): fields = get_model_fields(Model) assert not is_scalar_field(fields[0]) + + +def test_get_model_fields_cached(): + class Model(BaseModel): + foo: str + + non_cached_fields = get_model_fields(Model) + non_cached_fields2 = get_model_fields(Model) + cached_fields = get_cached_model_fields(Model) + cached_fields2 = get_cached_model_fields(Model) + for f1, f2 in zip(cached_fields, cached_fields2): + assert f1 is f2 + + assert non_cached_fields is not non_cached_fields2 + assert cached_fields is cached_fields2