Browse Source

Fix downcast from Starlette UploadFile to FastAPI UploadFile

pull/13605/head
Matthew Batema 4 months ago
committed by Matthew Batema
parent
commit
4aebae3f24
  1. 12
      fastapi/datastructures.py
  2. 16
      fastapi/routing.py
  3. 63
      tests/test_request_uploadfile_type.py

12
fastapi/datastructures.py

@ -5,6 +5,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
Optional, Optional,
Self,
Type, Type,
TypeVar, TypeVar,
cast, cast,
@ -72,6 +73,17 @@ class UploadFile(StarletteUploadFile):
Optional[str], Doc("The content type of the request, from the headers.") Optional[str], Doc("The content type of the request, from the headers.")
] ]
@classmethod
def from_starlette(
cls: type[Self], starlette_uploadfile: StarletteUploadFile
) -> Self:
return cls(
file=starlette_uploadfile.file,
size=starlette_uploadfile.size,
filename=starlette_uploadfile.filename,
headers=starlette_uploadfile.headers,
)
async def write( async def write(
self, self,
data: Annotated[ data: Annotated[

16
fastapi/routing.py

@ -31,7 +31,7 @@ from fastapi._compat import (
_normalize_errors, _normalize_errors,
lenient_issubclass, lenient_issubclass,
) )
from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.datastructures import Default, DefaultPlaceholder, UploadFile
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import ( from fastapi.dependencies.utils import (
_should_embed_body_fields, _should_embed_body_fields,
@ -60,6 +60,7 @@ from fastapi.utils import (
from pydantic import BaseModel from pydantic import BaseModel
from starlette import routing from starlette import routing
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
@ -209,6 +210,19 @@ async def run_endpoint_function(
# facilitate profiling endpoints, since inner functions are harder to profile. # facilitate profiling endpoints, since inner functions are harder to profile.
assert dependant.call is not None, "dependant.call must be a function" assert dependant.call is not None, "dependant.call must be a function"
# Convert all Starlette UploadFiles to FastAPI UploadFiles
for key, value in values.items():
if isinstance(value, StarletteUploadFile) and not isinstance(value, UploadFile):
values[key] = UploadFile.from_starlette(value)
elif isinstance(value, list):
values[key] = [
UploadFile.from_starlette(item)
if isinstance(item, StarletteUploadFile)
and not isinstance(item, UploadFile)
else item
for item in value
]
if is_coroutine: if is_coroutine:
return await dependant.call(**values) return await dependant.call(**values)
else: else:

63
tests/test_request_uploadfile_type.py

@ -0,0 +1,63 @@
import io
from typing import Any
from fastapi import FastAPI, File, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import UploadFile as StarletteUploadFile
app = FastAPI()
@app.post("/uploadfile")
async def uploadfile(uploadfile: UploadFile = File(...)) -> dict[str, Any]:
return {
"filename": uploadfile.filename,
"is_fastapi_uploadfile": isinstance(uploadfile, UploadFile),
"is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile),
"class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}",
}
@app.post("/uploadfiles")
async def uploadfiles(
uploadfiles: list[UploadFile] = File(...),
) -> list[dict[str, Any]]:
return [
{
"filename": uploadfile.filename,
"is_fastapi_uploadfile": isinstance(uploadfile, UploadFile),
"is_starlette_uploadfile": isinstance(uploadfile, StarletteUploadFile),
"class": f"{uploadfile.__class__.__module__}.{uploadfile.__class__.__name__}",
}
for uploadfile in uploadfiles
]
def test_uploadfile_type() -> None:
client = TestClient(app)
files = {"uploadfile": ("example.txt", io.BytesIO(b"test content"), "text/plain")}
response = client.post("/uploadfile/", files=files)
data = response.json()
assert data["filename"] == "example.txt"
assert data["is_fastapi_uploadfile"] is True
assert data["is_starlette_uploadfile"] is True
assert data["class"].startswith("fastapi.")
def test_uploadfiles_type() -> None:
client = TestClient(app)
files = [
("uploadfiles", ("example.txt", io.BytesIO(b"test content"), "text/plain"))
]
response = client.post("/uploadfiles/", files=files)
files_data = response.json()
assert len(files_data) == 1
data = files_data[0]
assert data["filename"] == "example.txt"
assert data["is_fastapi_uploadfile"] is True
assert data["is_starlette_uploadfile"] is True
assert data["class"].startswith("fastapi.")
Loading…
Cancel
Save