diff --git a/docs/src/request_files/tutorial002.py b/docs/src/request_files/tutorial002.py new file mode 100644 index 000000000..bc665b259 --- /dev/null +++ b/docs/src/request_files/tutorial002.py @@ -0,0 +1,33 @@ +from typing import List + +from fastapi import FastAPI, File, UploadFile +from starlette.responses import HTMLResponse + +app = FastAPI() + + +@app.post("/files/") +async def create_files(files: List[bytes] = File(...)): + return {"file_sizes": [len(file) for file in files]} + + +@app.post("/uploadfiles/") +async def create_upload_files(files: List[UploadFile] = File(...)): + return {"filenames": [file.filename for file in files]} + + +@app.get("/") +async def main(): + content = """ + +
+ + +
+
+ + +
+ + """ + return HTMLResponse(content=content) diff --git a/docs/tutorial/request-files.md b/docs/tutorial/request-files.md index bc73dd25e..2e0c2f802 100644 --- a/docs/tutorial/request-files.md +++ b/docs/tutorial/request-files.md @@ -43,7 +43,7 @@ Using `UploadFile` has several advantages over `bytes`: * It uses a "spooled" file: * A file stored in memory up to a maximum size limit, and after passing this limit it will be stored in disk. -* This means that it will work well for large files like images, videos, large binaries, etc. All without consuming all the memory. +* This means that it will work well for large files like images, videos, large binaries, etc. without consuming all the memory. * You can get metadata from the uploaded file. * It has a file-like `async` interface. * It exposes an actual Python `SpooledTemporaryFile` object that you can pass directly to other libraries that expect a file-like object. @@ -107,6 +107,20 @@ The way HTML forms (`
`) sends the data to the server normally uses This is not a limitation of **FastAPI**, it's part of the HTTP protocol. +## Multiple file uploads + +It's possible to upload several files at the same time. + +They would be associated to the same "form field" sent using "form data". + +To use that, declare a `List` of `bytes` or `UploadFile`: + +```Python hl_lines="10 15" +{!./src/request_files/tutorial002.py!} +``` + +You will receive, as declared, a `list` of `bytes` or `UploadFile`s. + ## Recap Use `File` to declare files to be uploaded as input parameters (as form data). diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 4cf737d67..c9f618132 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -31,8 +31,8 @@ from pydantic.schema import get_annotation_from_schema from pydantic.utils import lenient_issubclass from starlette.background import BackgroundTasks from starlette.concurrency import run_in_threadpool -from starlette.datastructures import UploadFile -from starlette.requests import Headers, QueryParams, Request +from starlette.datastructures import FormData, Headers, QueryParams, UploadFile +from starlette.requests import Request param_supported_types = ( str, @@ -47,6 +47,10 @@ param_supported_types = ( Decimal, ) +sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE} +sequence_types = (list, set, tuple) +sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple} + def get_sub_dependant( *, param: inspect.Parameter, path: str, security_scopes: List[str] = None @@ -318,7 +322,7 @@ def request_params_to_args( values = {} errors = [] for field in required_params: - if field.shape in {Shape.LIST, Shape.SET, Shape.TUPLE} and isinstance( + if field.shape in sequence_shapes and isinstance( received_params, (QueryParams, Headers) ): value = received_params.getlist(field.alias) @@ -358,11 +362,20 @@ async def request_body_to_args( embed = getattr(field.schema, "embed", None) if len(required_params) == 1 and not embed: received_body = {field.alias: received_body} - elif received_body is None: - received_body = {} for field in required_params: - value = received_body.get(field.alias) - if value is None or (isinstance(field.schema, params.Form) and value == ""): + if field.shape in sequence_shapes and isinstance(received_body, FormData): + value = received_body.getlist(field.alias) + else: + value = received_body.get(field.alias) + if ( + value is None + or (isinstance(field.schema, params.Form) and value == "") + or ( + isinstance(field.schema, params.Form) + and field.shape in sequence_shapes + and len(value) == 0 + ) + ): if field.required: errors.append( ErrorWrapper( @@ -380,6 +393,15 @@ async def request_body_to_args( and isinstance(value, UploadFile) ): value = await value.read() + elif ( + field.shape in sequence_shapes + and isinstance(field.schema, params.File) + and lenient_issubclass(field.type_, bytes) + and isinstance(value, sequence_types) + ): + awaitables = [sub_value.read() for sub_value in value] + contents = await asyncio.gather(*awaitables) + value = sequence_shape_to_type[field.shape](contents) v_, errors_ = field.validate(value, values, loc=("body", field.alias)) if isinstance(errors_, ErrorWrapper): errors.append(errors_) @@ -391,10 +413,14 @@ async def request_body_to_args( def get_schema_compatible_field(*, field: Field) -> Field: + out_field = field if lenient_issubclass(field.type_, UploadFile): - return Field( + use_type: type = bytes + if field.shape in sequence_shapes: + use_type = List[bytes] + out_field = Field( name=field.name, - type_=bytes, + type_=use_type, class_validators=field.class_validators, model_config=field.model_config, default=field.default, @@ -402,10 +428,10 @@ def get_schema_compatible_field(*, field: Field) -> Field: alias=field.alias, schema=field.schema, ) - return field + return out_field -def get_body_field(*, dependant: Dependant, name: str) -> Field: +def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]: flat_dependant = get_flat_dependant(dependant) if not flat_dependant.body_params: return None diff --git a/fastapi/routing.py b/fastapi/routing.py index 2bdf46ddc..a078662d8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -53,12 +53,7 @@ def get_app( body = None if body_field: if is_body_form: - raw_body = await request.form() - form_fields = {} - for field, value in raw_body.items(): - form_fields[field] = value - if form_fields: - body = form_fields + body = await request.form() else: body_bytes = await request.body() if body_bytes: diff --git a/tests/test_tutorial/test_request_files/test_tutorial002.py b/tests/test_tutorial/test_request_files/test_tutorial002.py new file mode 100644 index 000000000..e6b7ba479 --- /dev/null +++ b/tests/test_tutorial/test_request_files/test_tutorial002.py @@ -0,0 +1,219 @@ +import os + +from starlette.testclient import TestClient + +from request_files.tutorial002 import app + +client = TestClient(app) + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "Fast API", "version": "0.1.0"}, + "paths": { + "/files/": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Create Files", + "operationId": "create_files_files__post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": {"$ref": "#/components/schemas/Body_create_files"} + } + }, + "required": True, + }, + } + }, + "/uploadfiles/": { + "post": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "summary": "Create Upload Files", + "operationId": "create_upload_files_uploadfiles__post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_create_upload_files" + } + } + }, + "required": True, + }, + } + }, + "/": { + "get": { + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "summary": "Main", + "operationId": "main__get", + } + }, + }, + "components": { + "schemas": { + "Body_create_files": { + "title": "Body_create_files", + "required": ["files"], + "type": "object", + "properties": { + "files": { + "title": "Files", + "type": "array", + "items": {"type": "string", "format": "binary"}, + } + }, + }, + "Body_create_upload_files": { + "title": "Body_create_upload_files", + "required": ["files"], + "type": "object", + "properties": { + "files": { + "title": "Files", + "type": "array", + "items": {"type": "string", "format": "binary"}, + } + }, + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + } + }, +} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200 + assert response.json() == openapi_schema + + +file_required = { + "detail": [ + { + "loc": ["body", "files"], + "msg": "field required", + "type": "value_error.missing", + } + ] +} + + +def test_post_form_no_body(): + response = client.post("/files/") + assert response.status_code == 422 + assert response.json() == file_required + + +def test_post_body_json(): + response = client.post("/files/", json={"file": "Foo"}) + print(response) + print(response.content) + assert response.status_code == 422 + assert response.json() == file_required + + +def test_post_files(tmpdir): + path = os.path.join(tmpdir, "test.txt") + with open(path, "wb") as file: + file.write(b"") + path2 = os.path.join(tmpdir, "test2.txt") + with open(path2, "wb") as file: + file.write(b"") + + client = TestClient(app) + response = client.post( + "/files/", + files=( + ("files", ("test.txt", open(path, "rb"))), + ("files", ("test2.txt", open(path2, "rb"))), + ), + ) + assert response.status_code == 200 + assert response.json() == {"file_sizes": [14, 15]} + + +def test_post_upload_file(tmpdir): + path = os.path.join(tmpdir, "test.txt") + with open(path, "wb") as file: + file.write(b"") + path2 = os.path.join(tmpdir, "test2.txt") + with open(path2, "wb") as file: + file.write(b"") + + client = TestClient(app) + response = client.post( + "/uploadfiles/", + files=( + ("files", ("test.txt", open(path, "rb"))), + ("files", ("test2.txt", open(path2, "rb"))), + ), + ) + assert response.status_code == 200 + assert response.json() == {"filenames": ["test.txt", "test2.txt"]} + + +def test_get_root(): + client = TestClient(app) + response = client.get("/") + assert response.status_code == 200 + assert b"