Browse Source

Add support for multi-file uploads (#158)

pull/163/head
Sebastián Ramírez 6 years ago
committed by GitHub
parent
commit
aad6b123f7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 33
      docs/src/request_files/tutorial002.py
  2. 16
      docs/tutorial/request-files.md
  3. 48
      fastapi/dependencies/utils.py
  4. 7
      fastapi/routing.py
  5. 219
      tests/test_tutorial/test_request_files/test_tutorial002.py

33
docs/src/request_files/tutorial002.py

@ -0,0 +1,33 @@
from typing import List
from fastapi import FastAPI, File, UploadFile
from starlette.responses import HTMLResponse
app = FastAPI()
@app.post("/files/")
async def create_files(files: List[bytes] = File(...)):
return {"file_sizes": [len(file) for file in files]}
@app.post("/uploadfiles/")
async def create_upload_files(files: List[UploadFile] = File(...)):
return {"filenames": [file.filename for file in files]}
@app.get("/")
async def main():
content = """
<body>
<form action="/files/" enctype="multipart/form-data" method="post">
<input name="files" type="file" multiple>
<input type="submit">
</form>
<form action="/uploadfiles/" enctype="multipart/form-data" method="post">
<input name="files" type="file" multiple>
<input type="submit">
</form>
</body>
"""
return HTMLResponse(content=content)

16
docs/tutorial/request-files.md

@ -43,7 +43,7 @@ Using `UploadFile` has several advantages over `bytes`:
* It uses a "spooled" file:
* A file stored in memory up to a maximum size limit, and after passing this limit it will be stored in disk.
* This means that it will work well for large files like images, videos, large binaries, etc. All without consuming all the memory.
* This means that it will work well for large files like images, videos, large binaries, etc. without consuming all the memory.
* You can get metadata from the uploaded file.
* It has a <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> `async` interface.
* It exposes an actual Python <a href="https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile" target="_blank">`SpooledTemporaryFile`</a> object that you can pass directly to other libraries that expect a file-like object.
@ -107,6 +107,20 @@ The way HTML forms (`<form></form>`) sends the data to the server normally uses
This is not a limitation of **FastAPI**, it's part of the HTTP protocol.
## Multiple file uploads
It's possible to upload several files at the same time.
They would be associated to the same "form field" sent using "form data".
To use that, declare a `List` of `bytes` or `UploadFile`:
```Python hl_lines="10 15"
{!./src/request_files/tutorial002.py!}
```
You will receive, as declared, a `list` of `bytes` or `UploadFile`s.
## Recap
Use `File` to declare files to be uploaded as input parameters (as form data).

48
fastapi/dependencies/utils.py

@ -31,8 +31,8 @@ from pydantic.schema import get_annotation_from_schema
from pydantic.utils import lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile
from starlette.requests import Headers, QueryParams, Request
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request
param_supported_types = (
str,
@ -47,6 +47,10 @@ param_supported_types = (
Decimal,
)
sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
sequence_types = (list, set, tuple)
sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
def get_sub_dependant(
*, param: inspect.Parameter, path: str, security_scopes: List[str] = None
@ -318,7 +322,7 @@ def request_params_to_args(
values = {}
errors = []
for field in required_params:
if field.shape in {Shape.LIST, Shape.SET, Shape.TUPLE} and isinstance(
if field.shape in sequence_shapes and isinstance(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias)
@ -358,11 +362,20 @@ async def request_body_to_args(
embed = getattr(field.schema, "embed", None)
if len(required_params) == 1 and not embed:
received_body = {field.alias: received_body}
elif received_body is None:
received_body = {}
for field in required_params:
value = received_body.get(field.alias)
if value is None or (isinstance(field.schema, params.Form) and value == ""):
if field.shape in sequence_shapes and isinstance(received_body, FormData):
value = received_body.getlist(field.alias)
else:
value = received_body.get(field.alias)
if (
value is None
or (isinstance(field.schema, params.Form) and value == "")
or (
isinstance(field.schema, params.Form)
and field.shape in sequence_shapes
and len(value) == 0
)
):
if field.required:
errors.append(
ErrorWrapper(
@ -380,6 +393,15 @@ async def request_body_to_args(
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
field.shape in sequence_shapes
and isinstance(field.schema, params.File)
and lenient_issubclass(field.type_, bytes)
and isinstance(value, sequence_types)
):
awaitables = [sub_value.read() for sub_value in value]
contents = await asyncio.gather(*awaitables)
value = sequence_shape_to_type[field.shape](contents)
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
@ -391,10 +413,14 @@ async def request_body_to_args(
def get_schema_compatible_field(*, field: Field) -> Field:
out_field = field
if lenient_issubclass(field.type_, UploadFile):
return Field(
use_type: type = bytes
if field.shape in sequence_shapes:
use_type = List[bytes]
out_field = Field(
name=field.name,
type_=bytes,
type_=use_type,
class_validators=field.class_validators,
model_config=field.model_config,
default=field.default,
@ -402,10 +428,10 @@ def get_schema_compatible_field(*, field: Field) -> Field:
alias=field.alias,
schema=field.schema,
)
return field
return out_field
def get_body_field(*, dependant: Dependant, name: str) -> Field:
def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None

7
fastapi/routing.py

@ -53,12 +53,7 @@ def get_app(
body = None
if body_field:
if is_body_form:
raw_body = await request.form()
form_fields = {}
for field, value in raw_body.items():
form_fields[field] = value
if form_fields:
body = form_fields
body = await request.form()
else:
body_bytes = await request.body()
if body_bytes:

219
tests/test_tutorial/test_request_files/test_tutorial002.py

@ -0,0 +1,219 @@
import os
from starlette.testclient import TestClient
from request_files.tutorial002 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/files/": {
"post": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Create Files",
"operationId": "create_files_files__post",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {"$ref": "#/components/schemas/Body_create_files"}
}
},
"required": True,
},
}
},
"/uploadfiles/": {
"post": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Create Upload Files",
"operationId": "create_upload_files_uploadfiles__post",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_create_upload_files"
}
}
},
"required": True,
},
}
},
"/": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Main",
"operationId": "main__get",
}
},
},
"components": {
"schemas": {
"Body_create_files": {
"title": "Body_create_files",
"required": ["files"],
"type": "object",
"properties": {
"files": {
"title": "Files",
"type": "array",
"items": {"type": "string", "format": "binary"},
}
},
},
"Body_create_upload_files": {
"title": "Body_create_upload_files",
"required": ["files"],
"type": "object",
"properties": {
"files": {
"title": "Files",
"type": "array",
"items": {"type": "string", "format": "binary"},
}
},
},
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
file_required = {
"detail": [
{
"loc": ["body", "files"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
def test_post_form_no_body():
response = client.post("/files/")
assert response.status_code == 422
assert response.json() == file_required
def test_post_body_json():
response = client.post("/files/", json={"file": "Foo"})
print(response)
print(response.content)
assert response.status_code == 422
assert response.json() == file_required
def test_post_files(tmpdir):
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"<file content>")
path2 = os.path.join(tmpdir, "test2.txt")
with open(path2, "wb") as file:
file.write(b"<file content2>")
client = TestClient(app)
response = client.post(
"/files/",
files=(
("files", ("test.txt", open(path, "rb"))),
("files", ("test2.txt", open(path2, "rb"))),
),
)
assert response.status_code == 200
assert response.json() == {"file_sizes": [14, 15]}
def test_post_upload_file(tmpdir):
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"<file content>")
path2 = os.path.join(tmpdir, "test2.txt")
with open(path2, "wb") as file:
file.write(b"<file content2>")
client = TestClient(app)
response = client.post(
"/uploadfiles/",
files=(
("files", ("test.txt", open(path, "rb"))),
("files", ("test2.txt", open(path2, "rb"))),
),
)
assert response.status_code == 200
assert response.json() == {"filenames": ["test.txt", "test2.txt"]}
def test_get_root():
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert b"<form" in response.content
Loading…
Cancel
Save