diff --git a/fastapi/applications.py b/fastapi/applications.py index 3306aab3d..c21087911 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -38,6 +38,7 @@ class FastAPI(Starlette): version: str = "0.1.0", openapi_url: Optional[str] = "/openapi.json", openapi_tags: Optional[List[Dict[str, Any]]] = None, + servers: Optional[List[Dict[str, Union[str, Any]]]] = None, default_response_class: Type[Response] = JSONResponse, docs_url: Optional[str] = "/docs", redoc_url: Optional[str] = "/redoc", @@ -70,6 +71,7 @@ class FastAPI(Starlette): self.title = title self.description = description self.version = version + self.servers = servers self.openapi_url = openapi_url self.openapi_tags = openapi_tags # TODO: remove when discarding the openapi_prefix parameter @@ -106,6 +108,7 @@ class FastAPI(Starlette): routes=self.routes, openapi_prefix=openapi_prefix, tags=self.openapi_tags, + servers=self.servers, ) return self.openapi_schema diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py index a7c4460fa..13dc59f18 100644 --- a/fastapi/openapi/models.py +++ b/fastapi/openapi/models.py @@ -63,7 +63,7 @@ class ServerVariable(BaseModel): class Server(BaseModel): - url: AnyUrl + url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index b6221ca20..5a0c89a89 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -86,7 +86,7 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L def get_openapi_operation_parameters( *, all_route_params: Sequence[ModelField], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str] + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], ) -> List[Dict[str, Any]]: parameters = [] for param in all_route_params: @@ -112,7 +112,7 @@ def get_openapi_operation_parameters( def get_openapi_operation_request_body( *, body_field: Optional[ModelField], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str] + model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], ) -> Optional[Dict]: if not body_field: return None @@ -318,12 +318,15 @@ def get_openapi( description: str = None, routes: Sequence[BaseRoute], openapi_prefix: str = "", - tags: Optional[List[Dict[str, Any]]] = None + tags: Optional[List[Dict[str, Any]]] = None, + servers: Optional[List[Dict[str, Union[str, Any]]]] = None, ) -> Dict: info = {"title": title, "version": version} if description: info["description"] = description output: Dict[str, Any] = {"openapi": openapi_version, "info": info} + if servers: + output["servers"] = servers components: Dict[str, Dict] = {} paths: Dict[str, Dict] = {} flat_models = get_flat_models_from_routes(routes) diff --git a/tests/test_openapi_servers.py b/tests/test_openapi_servers.py new file mode 100644 index 000000000..a210154f6 --- /dev/null +++ b/tests/test_openapi_servers.py @@ -0,0 +1,60 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI( + servers=[ + {"url": "/", "description": "Default, relative server"}, + { + "url": "http://staging.localhost.tiangolo.com:8000", + "description": "Staging but actually localhost still", + }, + {"url": "https://prod.example.com"}, + ] +) + + +@app.get("/foo") +def foo(): + return {"message": "Hello World"} + + +client = TestClient(app) + + +openapi_schema = { + "openapi": "3.0.2", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "servers": [ + {"url": "/", "description": "Default, relative server"}, + { + "url": "http://staging.localhost.tiangolo.com:8000", + "description": "Staging but actually localhost still", + }, + {"url": "https://prod.example.com"}, + ], + "paths": { + "/foo": { + "get": { + "summary": "Foo", + "operationId": "foo_foo_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + } + }, +} + + +def test_openapi_servers(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == openapi_schema + + +def test_app(): + response = client.get("/foo") + assert response.status_code == 200, response.text