Browse Source

♻️ Refactor internals, cleanup unneeded Pydantic v1 specific logic (#14856)

pull/14857/head
Sebastián Ramírez 5 months ago
committed by GitHub
parent
commit
3c49346238
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 1
      fastapi/_compat/__init__.py
  2. 1
      fastapi/_compat/shared.py
  3. 10
      fastapi/_compat/v2.py
  4. 15
      fastapi/routing.py
  5. 16
      fastapi/utils.py

1
fastapi/_compat/__init__.py

@ -1,4 +1,3 @@
from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1 from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar from .shared import field_annotation_is_scalar as field_annotation_is_scalar

1
fastapi/_compat/shared.py

@ -28,7 +28,6 @@ else:
) # pyright: ignore[reportAttributeAccessIssue] ) # pyright: ignore[reportAttributeAccessIssue]
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = { sequence_annotation_to_type = {

10
fastapi/_compat/v2.py

@ -500,14 +500,8 @@ def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap: def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
all_flat_models: TypeModelSet = set() flat_models = get_flat_models_from_fields(fields, known_models=set())
return get_model_name_map(flat_models)
v2_model_fields = [field for field in fields if isinstance(field, ModelField)]
v2_flat_models = get_flat_models_from_fields(v2_model_fields, known_models=set())
all_flat_models = all_flat_models.union(v2_flat_models)
model_name_map = get_model_name_map(all_flat_models)
return model_name_map
def get_flat_models_from_model( def get_flat_models_from_model(

15
fastapi/routing.py

@ -59,7 +59,6 @@ from fastapi.exceptions import (
) )
from fastapi.types import DecoratedCallable, IncEx from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import ( from fastapi.utils import (
create_cloned_field,
create_model_field, create_model_field,
generate_unique_id, generate_unique_id,
get_value_or_default, get_value_or_default,
@ -652,20 +651,8 @@ class APIRoute(routing.Route):
type_=self.response_model, type_=self.response_model,
mode="serialization", mode="serialization",
) )
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else: else:
self.response_field = None # type: ignore self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text, # if a "form feed" character (page break) is found in the description text,
@ -720,7 +707,7 @@ class APIRoute(routing.Route):
body_field=self.body_field, body_field=self.body_field,
status_code=self.status_code, status_code=self.status_code,
response_class=self.response_class, response_class=self.response_class,
response_field=self.secure_cloned_response_field, response_field=self.response_field,
response_model_include=self.response_model_include, response_model_include=self.response_model_include,
response_model_exclude=self.response_model_exclude, response_model_exclude=self.response_model_exclude,
response_model_by_alias=self.response_model_by_alias, response_model_by_alias=self.response_model_by_alias,

16
fastapi/utils.py

@ -1,13 +1,11 @@
import re import re
import warnings import warnings
from collections.abc import MutableMapping
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Optional, Optional,
Union, Union,
) )
from weakref import WeakKeyDictionary
import fastapi import fastapi
from fastapi._compat import ( from fastapi._compat import (
@ -21,7 +19,6 @@ from fastapi._compat import (
) )
from fastapi.datastructures import DefaultPlaceholder, DefaultType from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.exceptions import FastAPIDeprecationWarning, PydanticV1NotSupportedError from fastapi.exceptions import FastAPIDeprecationWarning, PydanticV1NotSupportedError
from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Literal from typing_extensions import Literal
@ -30,11 +27,6 @@ from ._compat import v2
if TYPE_CHECKING: # pragma: nocover if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute from .routing import APIRoute
# Cache for `create_cloned_field`
_CLONED_TYPES_CACHE: MutableMapping[type[BaseModel], type[BaseModel]] = (
WeakKeyDictionary()
)
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
if status_code is None: if status_code is None:
@ -97,14 +89,6 @@ def create_model_field(
) from None ) from None
def create_cloned_field(
field: ModelField,
*,
cloned_types: Optional[MutableMapping[type[BaseModel], type[BaseModel]]] = None,
) -> ModelField:
return field
def generate_operation_id_for_path( def generate_operation_id_for_path(
*, name: str, path: str, method: str *, name: str, path: str, method: str
) -> str: # pragma: nocover ) -> str: # pragma: nocover

Loading…
Cancel
Save