You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

294 lines
8.8 KiB

from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import (
Annotated,
Any,
BinaryIO,
TypeVar,
cast,
)
from annotated_doc import Doc
from pydantic import GetJsonSchemaHandler
from starlette.datastructures import URL as URL # noqa: F401
from starlette.datastructures import Address as Address # noqa: F401
from starlette.datastructures import FormData as FormData # noqa: F401
from starlette.datastructures import Headers as Headers # noqa: F401
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException
@dataclass
class ValidationResult:
    """
    Summary of an upload-file validation check.

    `UploadFile.validate` builds one of these (with ``is_valid=True``) when the
    file satisfies its configured constraints; failures are reported by raising
    ``HTTPException`` rather than by returning ``is_valid=False``.
    """

    # Required flag: did every configured check pass?
    is_valid: Annotated[bool, Doc("Whether the file passed all validation checks.")]
    # Optional metadata copied from the file at validation time.
    file_size: Annotated[int | None, Doc("The size of the file in bytes.")] = None
    content_type: Annotated[
        str | None, Doc("The content type of the request, from the headers.")
    ] = None
class UploadFile(StarletteUploadFile):
    """
    A file uploaded in a request.

    Define it as a *path operation function* (or dependency) parameter.

    If you are using a regular `def` function, you can use the `upload_file.file`
    attribute to access the raw standard Python file (blocking, not async), useful and
    needed for non-async code.

    Read more about it in the
    [FastAPI docs for Request Files](https://fastapi.tiangolo.com/tutorial/request-files/).

    ## Example

    ```python
    from typing import Annotated

    from fastapi import FastAPI, File, UploadFile

    app = FastAPI()


    @app.post("/files/")
    async def create_file(file: Annotated[bytes, File()]):
        return {"file_size": len(file)}


    @app.post("/uploadfile/")
    async def create_upload_file(file: UploadFile):
        return {"filename": file.filename}
    ```
    """

    file: Annotated[
        BinaryIO,
        Doc("The standard Python file object (non-async)."),
    ]
    filename: Annotated[str | None, Doc("The original file name.")]
    size: Annotated[int | None, Doc("The size of the file in bytes.")]
    headers: Annotated[Headers, Doc("The headers of the request.")]
    content_type: Annotated[
        str | None, Doc("The content type of the request, from the headers.")
    ]
    max_size: Annotated[
        int | None,
        Doc(
            """
            Maximum allowed file size in bytes. When set, files exceeding this size
            will fail validation with a 413 error.
            """
        ),
    ] = None
    allowed_content_types: Annotated[
        list[str] | None,
        Doc(
            """
            List of allowed MIME types for the uploaded file. When set, files with
            a content type not in this list will fail validation with a 415 error.
            """
        ),
    ] = None

    async def write(
        self,
        data: Annotated[
            bytes,
            Doc(
                """
                The bytes to write to the file.
                """
            ),
        ],
    ) -> None:
        """
        Write some bytes to the file.

        You normally wouldn't use this from a file you read in a request.

        To be awaitable, compatible with async, this is run in threadpool.
        """
        return await super().write(data)

    async def read(
        self,
        size: Annotated[
            int,
            Doc(
                """
                The number of bytes to read from the file.
                """
            ),
        ] = -1,
    ) -> bytes:
        """
        Read some bytes from the file.

        To be awaitable, compatible with async, this is run in threadpool.
        """
        return await super().read(size)

    async def seek(
        self,
        offset: Annotated[
            int,
            Doc(
                """
                The position in bytes to seek to in the file.
                """
            ),
        ],
    ) -> None:
        """
        Move to a position in the file.

        Any next read or write will be done from that position.

        To be awaitable, compatible with async, this is run in threadpool.
        """
        return await super().seek(offset)

    async def close(self) -> None:
        """
        Close the file.

        To be awaitable, compatible with async, this is run in threadpool.
        """
        return await super().close()

    def __init__(
        self,
        file: Annotated[
            BinaryIO,
            Doc("The file-like object to upload."),
        ],
        *,
        size: Annotated[
            int | None,
            Doc("The size of the file in bytes."),
        ] = None,
        filename: Annotated[
            str | None,
            Doc("The original file name."),
        ] = None,
        headers: Annotated[
            Headers | None,
            Doc("The headers of the request."),
        ] = None,
        max_size: Annotated[
            int | None,
            Doc(
                """
                Maximum allowed file size in bytes. When set, the `validate` method
                will raise ``HTTPException`` 413 if the file exceeds this limit.
                """
            ),
        ] = None,
        allowed_content_types: Annotated[
            list[str] | None,
            Doc(
                """
                List of allowed MIME types. When set, the `validate` method will
                raise ``HTTPException`` 415 if the file type is not in this list.
                """
            ),
        ] = None,
    ) -> None:
        super().__init__(file=file, size=size, filename=filename, headers=headers)
        self.max_size = max_size
        self.allowed_content_types = allowed_content_types

    @staticmethod
    def _media_type(content_type: str | None) -> str | None:
        """
        Reduce a Content-Type value to its lower-cased ``type/subtype``.

        Parameters after ``;`` (e.g. ``charset=utf-8``) are dropped and
        surrounding whitespace is stripped. ``None`` is passed through unchanged.
        """
        if content_type is None:
            return None
        return content_type.split(";", 1)[0].strip().lower()

    async def validate(
        self,
    ) -> ValidationResult:
        """
        Validate the file against the configured size and content type constraints.

        If validation passes, returns a :class:`ValidationResult` with ``is_valid``
        set to ``True`` and the file's metadata.

        If ``max_size`` is set and the file's size is known and exceeds it, raises
        ``HTTPException`` with status 413. When ``size`` is ``None`` the size
        check is skipped.

        If ``allowed_content_types`` is set and the file's media type is not in
        the list, raises ``HTTPException`` with status 415. The comparison
        ignores MIME parameters and letter case, so a file sent as
        ``"Text/Plain; charset=utf-8"`` matches an allowed ``"text/plain"``.
        """
        if (
            self.max_size is not None
            and self.size is not None
            and self.size > self.max_size
        ):
            raise HTTPException(
                status_code=413,
                detail=f"File size {self.size} exceeds the maximum allowed size of {self.max_size} bytes",
            )
        if self.allowed_content_types is not None:
            # Media type/subtype tokens are case-insensitive and may carry
            # parameters, so compare normalized forms on both sides. An exact
            # raw match still matches after normalization.
            allowed = {self._media_type(t) for t in self.allowed_content_types}
            if self._media_type(self.content_type) not in allowed:
                raise HTTPException(
                    status_code=415,
                    detail=(
                        f"Content type '{self.content_type}' is not allowed. "
                        f"Must be one of: {self.allowed_content_types}"
                    ),
                )
        return ValidationResult(
            is_valid=True, file_size=self.size, content_type=self.content_type
        )

    @classmethod
    def _validate(cls, __input_value: Any, _: Any) -> "UploadFile":
        """Pydantic validator: accept any Starlette ``UploadFile`` instance."""
        if not isinstance(__input_value, StarletteUploadFile):
            raise ValueError(f"Expected UploadFile, received: {type(__input_value)}")
        # NOTE(review): `cast` is type-only; the object returned is the original
        # Starlette instance, which may not be this subclass (and so may lack
        # `validate`, `max_size`, `allowed_content_types`) — confirm callers.
        return cast(UploadFile, __input_value)

    @classmethod
    def __get_pydantic_json_schema__(
        cls, core_schema: Mapping[str, Any], handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """Describe uploads in JSON Schema as an opaque binary string."""
        return {"type": "string", "contentMediaType": "application/octet-stream"}

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[Any], handler: Callable[[Any], Mapping[str, Any]]
    ) -> Mapping[str, Any]:
        """Validate with ``_validate`` via a plain (info-aware) validator."""
        from ._compat.v2 import with_info_plain_validator_function

        return with_info_plain_validator_function(cls._validate)
class DefaultPlaceholder:
    """
    You shouldn't use this class directly.

    It's used internally to recognize when a default value has been overwritten, even
    if the overridden default value was truthy.

    The real default is stored in ``value``; truthiness and equality between
    placeholders are delegated to that wrapped value.
    """

    def __init__(self, value: Any):
        self.value = value

    def __bool__(self) -> bool:
        return bool(self.value)

    def __eq__(self, o: object) -> bool:
        # Placeholders compare equal when they wrap equal values. For any other
        # type return NotImplemented so Python can try the reflected comparison
        # and otherwise fall back to identity (i.e. False).
        # Defining __eq__ without __hash__ leaves instances unhashable; that is
        # the pre-existing behavior and is intentionally unchanged.
        if not isinstance(o, DefaultPlaceholder):
            return NotImplemented
        return o.value == self.value

    def __repr__(self) -> str:
        # Make placeholders readable in error messages and signatures.
        return f"{type(self).__name__}({self.value!r})"
DefaultType = TypeVar("DefaultType")


def Default(value: DefaultType) -> DefaultType:
    """
    You shouldn't use this function directly.

    It's used internally to recognize when a default value has been overwritten, even
    if the overridden default value was truthy.

    At runtime this returns ``DefaultPlaceholder(value)``; the signature claims
    the argument's own type so call sites type-check as if the bare default
    had been used.
    """
    # `cast` is a runtime no-op; it only tells the type checker to treat the
    # placeholder as the wrapped value's type.
    return cast(DefaultType, DefaultPlaceholder(value))
# Sentinel meaning "parameter not provided", used as a default in
# Param/FieldInfo. `Default()` is annotated to return its argument's type, so
# type checkers (e.g. ty) see this as None, while at runtime it is a
# DefaultPlaceholder wrapping None.
_Unset = Default(value=None)