From 4aebae3f24ababe8d421c6271919cf715f253e3a Mon Sep 17 00:00:00 2001 From: Matthew Batema Date: Fri, 11 Apr 2025 10:04:49 -0700 Subject: [PATCH] Fix downcast from Starlette UploadFile to FastAPI UploadFile --- fastapi/datastructures.py | 12 +++++ fastapi/routing.py | 16 ++++++- tests/test_request_uploadfile_type.py | 63 +++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 tests/test_request_uploadfile_type.py diff --git a/fastapi/datastructures.py b/fastapi/datastructures.py index cf8406b0f..47412ab1f 100644 --- a/fastapi/datastructures.py +++ b/fastapi/datastructures.py @@ -5,6 +5,7 @@ from typing import ( Dict, Iterable, Optional, + Self, Type, TypeVar, cast, @@ -72,6 +73,17 @@ class UploadFile(StarletteUploadFile): Optional[str], Doc("The content type of the request, from the headers.") ] + @classmethod + def from_starlette( + cls: type[Self], starlette_uploadfile: StarletteUploadFile + ) -> Self: + return cls( + file=starlette_uploadfile.file, + size=starlette_uploadfile.size, + filename=starlette_uploadfile.filename, + headers=starlette_uploadfile.headers, + ) + async def write( self, data: Annotated[ diff --git a/fastapi/routing.py b/fastapi/routing.py index 54c75a027..546e105fc 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -31,7 +31,7 @@ from fastapi._compat import ( _normalize_errors, lenient_issubclass, ) -from fastapi.datastructures import Default, DefaultPlaceholder +from fastapi.datastructures import Default, DefaultPlaceholder, UploadFile from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( _should_embed_body_fields, @@ -60,6 +60,7 @@ from fastapi.utils import ( from pydantic import BaseModel from starlette import routing from starlette.concurrency import run_in_threadpool +from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -209,6 +210,19 @@ async def run_endpoint_function( # facilitate profiling endpoints, since inner functions are harder to profile. assert dependant.call is not None, "dependant.call must be a function" + # Convert all Starlette UploadFiles to FastAPI UploadFiles + for key, value in values.items(): + if isinstance(value, StarletteUploadFile) and not isinstance(value, UploadFile): + values[key] = UploadFile.from_starlette(value) + elif isinstance(value, list): + values[key] = [ + UploadFile.from_starlette(item) + if isinstance(item, StarletteUploadFile) + and not isinstance(item, UploadFile) + else item + for item in value + ] + if is_coroutine: return await dependant.call(**values) else: diff --git a/tests/test_request_uploadfile_type.py b/tests/test_request_uploadfile_type.py new file mode 100644 index 000000000..adeca0ed6 --- /dev/null +++ b/tests/test_request_uploadfile_type.py @@ -0,0 +1,63 @@ +import io +from typing import Any + +from fastapi import FastAPI, File, UploadFile +from fastapi.testclient import TestClient +from starlette.datastructures import UploadFile as StarletteUploadFile + +app = FastAPI() + + +@app.post("/uploadfile") +async def uploadfile(uploadfile: UploadFile = File(...)) -> dict[str, Any]: + return { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + + +@app.post("/uploadfiles") +async def uploadfiles( + uploadfiles: list[UploadFile] = File(...), +) -> list[dict[str, Any]]: + return [ + { + "filename": uploadfile.filename, + "is_fastapi_uploadfile": isinstance(uploadfile, UploadFile), + "is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile), + "class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}", + } + for uploadfile in uploadfiles + ] + + +def test_uploadfile_type() -> None: + client = TestClient(app) + files = {"uploadfile": ("example.txt", io.BytesIO(b"test content"), "text/plain")} + response = client.post("/uploadfile/", files=files) + data = response.json() + + assert data["filename"] == "example.txt" + assert data["is_fastapi_uploadfile"] is True + assert data["is_starlette_uploadfile"] is True + assert data["class"].startswith("fastapi.") + + +def test_uploadfiles_type() -> None: + client = TestClient(app) + files = [ + ("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain")) + ] + response = client.post("/uploadfiles/", files=files) + files_data = response.json() + + assert len(files_data) == 1 + + data = files_data[0] + + assert data["filename"] == "example.txt" + assert data["is_fastapi_uploadfile"] is True + assert data["is_starlette_uploadfile"] is True + assert data["class"].startswith("fastapi.")