Make SubFox production-ready with parallel translation and UI controls

This commit is contained in:
Eddie Nielsen 2026-03-25 11:24:54 +00:00
parent c40b8bed2b
commit 2b1d05f02c
6046 changed files with 798327 additions and 0 deletions

View file

@ -0,0 +1,436 @@
---
name: fastapi
description: FastAPI best practices and conventions. Use when working with FastAPI APIs and Pydantic models for them. Keeps FastAPI code clean and up to date with the latest features and patterns, updated with new versions. Write new code or refactor and update old code.
---
# FastAPI
Official FastAPI skill to write code with best practices, keeping up to date with new versions and features.
## Use the `fastapi` CLI
Run the development server on localhost with reload:
```bash
fastapi dev
```
Run the production server:
```bash
fastapi run
```
### Add an entrypoint in `pyproject.toml`
FastAPI CLI will read the entrypoint in `pyproject.toml` to know where the FastAPI app is declared.
```toml
[tool.fastapi]
entrypoint = "my_app.main:app"
```
### Use `fastapi` with a path
When adding the entrypoint to `pyproject.toml` is not possible, or the user explicitly asks not to, or it's running an independent small app, you can pass the app file path to the `fastapi` command:
```bash
fastapi dev my_app/main.py
```
Prefer to set the entrypoint in `pyproject.toml` when possible.
## Use `Annotated`
Always prefer the `Annotated` style for parameter and dependency declarations.
It keeps the function signatures working in other contexts, respects the types, allows reusability.
### In Parameter Declarations
Use `Annotated` for parameter declarations, including `Path`, `Query`, `Header`, etc.:
```python
from typing import Annotated
from fastapi import FastAPI, Path, Query
app = FastAPI()
@app.get("/items/{item_id}")
async def read_item(
item_id: Annotated[int, Path(ge=1, description="The item ID")],
q: Annotated[str | None, Query(max_length=50)] = None,
):
return {"message": "Hello World"}
```
instead of:
```python
# DO NOT DO THIS
@app.get("/items/{item_id}")
async def read_item(
item_id: int = Path(ge=1, description="The item ID"),
q: str | None = Query(default=None, max_length=50),
):
return {"message": "Hello World"}
```
### For Dependencies
Use `Annotated` for dependencies with `Depends()`.
Unless asked not to, create a new type alias for the dependency to allow re-using it.
```python
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
def get_current_user():
return {"username": "johndoe"}
CurrentUserDep = Annotated[dict, Depends(get_current_user)]
@app.get("/items/")
async def read_item(current_user: CurrentUserDep):
return {"message": "Hello World"}
```
instead of:
```python
# DO NOT DO THIS
@app.get("/items/")
async def read_item(current_user: dict = Depends(get_current_user)):
return {"message": "Hello World"}
```
## Do not use Ellipsis for *path operations* or Pydantic models
Do not use `...` as a default value for required parameters, it's not needed and not recommended.
Do this, without Ellipsis (`...`):
```python
from typing import Annotated
from fastapi import FastAPI, Query
from pydantic import BaseModel, Field
class Item(BaseModel):
name: str
description: str | None = None
price: float = Field(gt=0)
app = FastAPI()
@app.post("/items/")
async def create_item(item: Item, project_id: Annotated[int, Query()]): ...
```
instead of this:
```python
# DO NOT DO THIS
class Item(BaseModel):
name: str = ...
description: str | None = None
price: float = Field(..., gt=0)
app = FastAPI()
@app.post("/items/")
async def create_item(item: Item, project_id: Annotated[int, Query(...)]): ...
```
## Return Type or Response Model
When possible, include a return type. It will be used to validate, filter, document, and serialize the response.
```python
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
description: str | None = None
@app.get("/items/me")
async def get_item() -> Item:
return Item(name="Plumbus", description="All-purpose home device")
```
**Important**: Return types or response models are what filter data ensuring no sensitive information is exposed. And they are used to serialize data with Pydantic (in Rust), this is the main idea that can increase response performance.
The return type doesn't have to be a Pydantic model, it could be a different type, like a list of integers, or a dict, etc.
### When to use `response_model` instead
If the return type is not the same as the type that you want to use to validate, filter, or serialize, use the `response_model` parameter on the decorator instead.
```python
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
description: str | None = None
@app.get("/items/me", response_model=Item)
async def get_item() -> Any:
return {"name": "Foo", "description": "A very nice Item"}
```
This can be particularly useful when filtering data to expose only the public fields and avoid exposing sensitive information.
```python
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class InternalItem(BaseModel):
name: str
description: str | None = None
secret_key: str
class Item(BaseModel):
name: str
description: str | None = None
@app.get("/items/me", response_model=Item)
async def get_item() -> Any:
item = InternalItem(
name="Foo", description="A very nice Item", secret_key="supersecret"
)
return item
```
## Performance
Do not use `ORJSONResponse` or `UJSONResponse`, they are deprecated.
Instead, declare a return type or response model. Pydantic will handle the data serialization on the Rust side.
## Including Routers
When declaring routers, prefer to add router level parameters like prefix, tags, etc. to the router itself, instead of in `include_router()`.
Do this:
```python
from fastapi import APIRouter, FastAPI
app = FastAPI()
router = APIRouter(prefix="/items", tags=["items"])
@router.get("/")
async def list_items():
return []
# In main.py
app.include_router(router)
```
instead of this:
```python
# DO NOT DO THIS
from fastapi import APIRouter, FastAPI
app = FastAPI()
router = APIRouter()
@router.get("/")
async def list_items():
return []
# In main.py
app.include_router(router, prefix="/items", tags=["items"])
```
There could be exceptions, but try to follow this convention.
Apply shared dependencies at the router level via `dependencies=[Depends(...)]`.
## Dependency Injection
See [the dependency injection reference](references/dependencies.md) for detailed patterns including `yield` with `scope`, and class dependencies.
Use dependencies when the logic can't be declared in Pydantic validation, depends on external resources, needs cleanup (with `yield`), or is shared across endpoints.
Apply shared dependencies at the router level via `dependencies=[Depends(...)]`.
## Async vs Sync *path operations*
Use `async` *path operations* only when fully certain that the logic called inside is compatible with async and await (it's called with `await`) or that doesn't block.
```python
from fastapi import FastAPI
app = FastAPI()
# Use async def when calling async code
@app.get("/async-items/")
async def read_async_items():
data = await some_async_library.fetch_items()
return data
# Use plain def when calling blocking/sync code or when in doubt
@app.get("/items/")
def read_items():
data = some_blocking_library.fetch_items()
return data
```
In case of doubt, or by default, use regular `def` functions, those will be run in a threadpool so they don't block the event loop.
The same rules apply to dependencies.
Make sure blocking code is not run inside of `async` functions. The logic will work, but will damage the performance heavily.
When needing to mix blocking and async code, see Asyncer in [the other tools reference](references/other-tools.md).
## Streaming (JSON Lines, SSE, bytes)
See [the streaming reference](references/streaming.md) for JSON Lines, Server-Sent Events (`EventSourceResponse`, `ServerSentEvent`), and byte streaming (`StreamingResponse`) patterns.
## Tooling
See [the other tools reference](references/other-tools.md) for details on uv, Ruff, ty for package management, linting, type checking, formatting, etc.
## Other Libraries
See [the other tools reference](references/other-tools.md) for details on other libraries:
* Asyncer for handling async and await, concurrency, mixing async and blocking code, prefer it over AnyIO or asyncio.
* SQLModel for working with SQL databases, prefer it over SQLAlchemy.
* HTTPX for interacting with HTTP (other APIs), prefer it over Requests.
## Do not use Pydantic RootModels
Do not use Pydantic `RootModel`, instead use regular type annotations with `Annotated` and Pydantic validation utilities.
For example, for a list with validations you could do:
```python
from typing import Annotated
from fastapi import Body, FastAPI
from pydantic import Field
app = FastAPI()
@app.post("/items/")
async def create_items(items: Annotated[list[int], Field(min_length=1), Body()]):
return items
```
instead of:
```python
# DO NOT DO THIS
from typing import Annotated
from fastapi import FastAPI
from pydantic import Field, RootModel
app = FastAPI()
class ItemList(RootModel[Annotated[list[int], Field(min_length=1)]]):
pass
@app.post("/items/")
async def create_items(items: ItemList):
return items
```
FastAPI supports these type annotations and will create a Pydantic `TypeAdapter` for them, so that types can work as normally and there's no need for the custom logic and types in RootModels.
## Use one HTTP operation per function
Don't mix HTTP operations in a single function, having one function per HTTP operation helps separate concerns and organize the code.
Do this:
```python
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
@app.get("/items/")
async def list_items():
return []
@app.post("/items/")
async def create_item(item: Item):
return item
```
instead of this:
```python
# DO NOT DO THIS
from fastapi import FastAPI, Request
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
@app.api_route("/items/", methods=["GET", "POST"])
async def handle_items(request: Request):
if request.method == "GET":
return []
```

View file

@ -0,0 +1,142 @@
# Dependency Injection
Use dependencies when:
* They can't be declared in Pydantic validation and require additional logic
* The logic depends on external resources or could block in any other way
* Other dependencies need their results (it's a sub-dependency)
* The logic can be shared by multiple endpoints to do things like error early, authentication, etc.
* They need to handle cleanup (e.g., DB sessions, file handles), using dependencies with `yield`
* Their logic needs input data from the request, like headers, query parameters, etc.
## Dependencies with `yield` and `scope`
When using dependencies with `yield`, they can have a `scope` that defines when the exit code is run.
Use the default scope `"request"` to run the exit code after the response is sent back.
```python
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
def get_db():
db = DBSession()
try:
yield db
finally:
db.close()
DBDep = Annotated[DBSession, Depends(get_db)]
@app.get("/items/")
async def read_items(db: DBDep):
return db.query(Item).all()
```
Use the scope `"function"` when they should run the exit code after the response data is generated but before the response is sent back to the client.
```python
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
def get_username():
try:
yield "Rick"
finally:
print("Cleanup up before response is sent")
UserNameDep = Annotated[str, Depends(get_username, scope="function")]
@app.get("/users/me")
def get_user_me(username: UserNameDep):
return username
```
## Class Dependencies
Avoid creating class dependencies when possible.
If a class is needed, instead create a regular function dependency that returns a class instance.
Do this:
```python
from dataclasses import dataclass
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
@dataclass
class DatabasePaginator:
offset: int = 0
limit: int = 100
q: str | None = None
def get_page(self) -> dict:
# Simulate a page of data
return {
"offset": self.offset,
"limit": self.limit,
"q": self.q,
"items": [],
}
def get_db_paginator(
offset: int = 0, limit: int = 100, q: str | None = None
) -> DatabasePaginator:
return DatabasePaginator(offset=offset, limit=limit, q=q)
PaginatorDep = Annotated[DatabasePaginator, Depends(get_db_paginator)]
@app.get("/items/")
async def read_items(paginator: PaginatorDep):
return paginator.get_page()
```
instead of this:
```python
# DO NOT DO THIS
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
class DatabasePaginator:
def __init__(self, offset: int = 0, limit: int = 100, q: str | None = None):
self.offset = offset
self.limit = limit
self.q = q
def get_page(self) -> dict:
# Simulate a page of data
return {
"offset": self.offset,
"limit": self.limit,
"q": self.q,
"items": [],
}
@app.get("/items/")
async def read_items(paginator: Annotated[DatabasePaginator, Depends()]):
return paginator.get_page()
```

View file

@ -0,0 +1,76 @@
# Other Tools
## uv
If uv is available, use it to manage dependencies.
## Ruff
If Ruff is available, use it to lint and format the code. Consider enabling the FastAPI rules.
## ty
If ty is available, use it to check types.
## Asyncer
When needing to run blocking code inside of async functions, or async code inside of blocking functions, suggest using Asyncer.
Prefer it over AnyIO or asyncio.
Install:
```bash
uv add asyncer
```
Run blocking sync code inside of async with `asyncify()`:
```python
from asyncer import asyncify
from fastapi import FastAPI
app = FastAPI()
def do_blocking_work(name: str) -> str:
# Some blocking I/O operation
return f"Hello {name}"
@app.get("/items/")
async def read_items():
result = await asyncify(do_blocking_work)(name="World")
return {"message": result}
```
And run async code inside of blocking sync code with `syncify()`:
```python
from asyncer import syncify
from fastapi import FastAPI
app = FastAPI()
async def do_async_work(name: str) -> str:
return f"Hello {name}"
@app.get("/items/")
def read_items():
result = syncify(do_async_work)(name="World")
return {"message": result}
```
## SQLModel for SQL databases
When working with SQL databases, prefer using SQLModel as it is integrated with Pydantic and will allow declaring data validation with the same models.
Prefer it over SQLAlchemy.
## HTTPX
Use HTTPX for handling HTTP communication (e.g. with other APIs). It support sync and async usage.
Prefer it over Requests.

View file

@ -0,0 +1,105 @@
# Streaming
## Stream JSON Lines
To stream JSON Lines, declare the return type and use `yield` to return the data.
```python
@app.get("/items/stream")
async def stream_items() -> AsyncIterable[Item]:
for item in items:
yield item
```
## Server-Sent Events (SSE)
To stream Server-Sent Events, use `response_class=EventSourceResponse` and `yield` items from the endpoint.
Plain objects are automatically JSON-serialized as `data:` fields, declare the return type so the serialization is done by Pydantic:
```python
from collections.abc import AsyncIterable
from fastapi import FastAPI
from fastapi.sse import EventSourceResponse
from pydantic import BaseModel
app = FastAPI()
class Item(BaseModel):
name: str
price: float
@app.get("/items/stream", response_class=EventSourceResponse)
async def stream_items() -> AsyncIterable[Item]:
yield Item(name="Plumbus", price=32.99)
yield Item(name="Portal Gun", price=999.99)
```
For full control over SSE fields (`event`, `id`, `retry`, `comment`), yield `ServerSentEvent` instances:
```python
from collections.abc import AsyncIterable
from fastapi import FastAPI
from fastapi.sse import EventSourceResponse, ServerSentEvent
app = FastAPI()
@app.get("/events", response_class=EventSourceResponse)
async def stream_events() -> AsyncIterable[ServerSentEvent]:
yield ServerSentEvent(data={"status": "started"}, event="status", id="1")
yield ServerSentEvent(data={"progress": 50}, event="progress", id="2")
```
Use `raw_data` instead of `data` to send pre-formatted strings without JSON encoding:
```python
yield ServerSentEvent(raw_data="plain text line", event="log")
```
## Stream bytes
To stream bytes, declare a `response_class=` of `StreamingResponse` or a sub-class, and use `yield` to return the data.
```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from app.utils import read_image
app = FastAPI()
class PNGStreamingResponse(StreamingResponse):
media_type = "image/png"
@app.get("/image", response_class=PNGStreamingResponse)
def stream_image_no_async_no_annotation():
with read_image() as image_file:
yield from image_file
```
prefer this over returning a `StreamingResponse` directly:
```python
# DO NOT DO THIS
import anyio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from app.utils import read_image
app = FastAPI()
class PNGStreamingResponse(StreamingResponse):
media_type = "image/png"
@app.get("/")
async def main():
return PNGStreamingResponse(read_image())
```

View file

@ -0,0 +1,25 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.135.2"
from starlette import status as status
from .applications import FastAPI as FastAPI
from .background import BackgroundTasks as BackgroundTasks
from .datastructures import UploadFile as UploadFile
from .exceptions import HTTPException as HTTPException
from .exceptions import WebSocketException as WebSocketException
from .param_functions import Body as Body
from .param_functions import Cookie as Cookie
from .param_functions import Depends as Depends
from .param_functions import File as File
from .param_functions import Form as Form
from .param_functions import Header as Header
from .param_functions import Path as Path
from .param_functions import Query as Query
from .param_functions import Security as Security
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter
from .websockets import WebSocket as WebSocket
from .websockets import WebSocketDisconnect as WebSocketDisconnect

View file

@ -0,0 +1,3 @@
from fastapi.cli import main
main()

View file

@ -0,0 +1,40 @@
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import (
field_annotation_is_scalar_sequence as field_annotation_is_scalar_sequence,
)
from .shared import field_annotation_is_sequence as field_annotation_is_sequence
from .shared import (
is_bytes_or_nonable_bytes_annotation as is_bytes_or_nonable_bytes_annotation,
)
from .shared import is_bytes_sequence_annotation as is_bytes_sequence_annotation
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
)
from .shared import (
is_uploadfile_sequence_annotation as is_uploadfile_sequence_annotation,
)
from .shared import lenient_issubclass as lenient_issubclass
from .shared import sequence_types as sequence_types
from .shared import value_is_sequence as value_is_sequence
from .v2 import ModelField as ModelField
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import Url as Url
from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model
from .v2 import evaluate_forwardref as evaluate_forwardref # ty: ignore[deprecated]
from .v2 import get_cached_model_fields as get_cached_model_fields
from .v2 import get_definitions as get_definitions
from .v2 import get_flat_models_from_fields as get_flat_models_from_fields
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_model_name_map as get_model_name_map
from .v2 import get_schema_from_model_field as get_schema_from_model_field
from .v2 import is_scalar_field as is_scalar_field
from .v2 import serialize_sequence_value as serialize_sequence_value
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)

View file

@ -0,0 +1,214 @@
import types
import typing
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import is_dataclass
from typing import (
Annotated,
Any,
TypeGuard,
TypeVar,
Union,
get_args,
get_origin,
)
from fastapi.types import UnionType
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
_T = TypeVar("_T")
# Copy from Pydantic: pydantic/_internal/_typing_extra.py
WithArgsTypes: tuple[Any, ...] = (
typing._GenericAlias, # type: ignore[attr-defined]
types.GenericAlias,
types.UnionType,
) # pyright: ignore[reportAttributeAccessIssue]
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
sequence_annotation_to_type = {
Sequence: list,
list: list,
tuple: tuple,
set: set,
frozenset: frozenset,
deque: deque,
}
sequence_types: tuple[type[Any], ...] = tuple(sequence_annotation_to_type.keys())
# Copy of Pydantic: pydantic/_internal/_utils.py with added TypeGuard
def lenient_issubclass(
cls: Any, class_or_tuple: type[_T] | tuple[type[_T], ...] | None
) -> TypeGuard[type[_T]]:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError: # pragma: no cover
if isinstance(cls, WithArgsTypes):
return False
raise # pragma: no cover
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if field_annotation_is_sequence(arg):
return True
return False
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if origin is Annotated:
return field_annotation_is_complex(get_args(annotation)[0])
return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
or hasattr(origin, "__pydantic_core_schema__")
or hasattr(origin, "__get_pydantic_core_schema__")
)
def field_annotation_is_scalar(annotation: Any) -> bool:
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
for arg in get_args(annotation):
if field_annotation_is_scalar_sequence(arg):
at_least_one_scalar_sequence = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_sequence
return field_annotation_is_sequence(annotation) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, bytes):
return True
return False
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, UploadFile):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, UploadFile):
return True
return False
def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_pydantic_v1_model_instance(obj: Any) -> bool:
# TODO: remove this function once the required version of Pydantic fully
# removes pydantic.v1
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
except ImportError: # pragma: no cover
return False
return isinstance(obj, v1.BaseModel)
def is_pydantic_v1_model_class(cls: Any) -> bool:
# TODO: remove this function once the required version of Pydantic fully
# removes pydantic.v1
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
except ImportError: # pragma: no cover
return False
return lenient_issubclass(cls, v1.BaseModel)
def annotation_is_pydantic_v1(annotation: Any) -> bool:
if is_pydantic_v1_model_class(annotation):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if is_pydantic_v1_model_class(arg):
return True
if field_annotation_is_sequence(annotation):
for sub_annotation in get_args(annotation):
if annotation_is_pydantic_v1(sub_annotation):
return True
return False

View file

@ -0,0 +1,480 @@
import re
import warnings
from collections.abc import Sequence
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Annotated,
Any,
Literal,
Union,
cast,
get_args,
get_origin,
)
from fastapi._compat import lenient_issubclass, shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined] # ty: ignore[unused-ignore-comment]
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient # ty: ignore[deprecated]
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.json_schema import GenerateJsonSchema as _GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined
from pydantic_core import Url as Url
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
evaluate_forwardref = eval_type_lenient # ty: ignore[deprecated]
class GenerateJsonSchema(_GenerateJsonSchema):
# TODO: remove when this is merged (or equivalent): https://github.com/pydantic/pydantic/pull/12841
# and dropping support for any version of Pydantic before that one (so, in a very long time)
def bytes_schema(self, schema: CoreSchema) -> JsonSchemaValue:
json_schema = {"type": "string", "contentMediaType": "application/octet-stream"}
bytes_mode = (
self._config.ser_json_bytes
if self.mode == "serialization"
else self._config.val_json_bytes
)
if bytes_mode == "base64":
json_schema["contentEncoding"] = "base64"
self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes)
return json_schema
# TODO: remove when dropping support for Pydantic < v2.12.3
_Attrs = {
"default": ...,
"default_factory": None,
"alias": None,
"alias_priority": None,
"validation_alias": None,
"serialization_alias": None,
"title": None,
"field_title_generator": None,
"description": None,
"examples": None,
"exclude": None,
"exclude_if": None,
"discriminator": None,
"deprecated": None,
"json_schema_extra": None,
"frozen": None,
"validate_default": None,
"repr": True,
"init": None,
"init_var": None,
"kw_only": None,
}
# TODO: remove when dropping support for Pydantic < v2.12.3
def asdict(field_info: FieldInfo) -> dict[str, Any]:
attributes = {}
for attr in _Attrs:
value = getattr(field_info, attr, Undefined)
if value is not Undefined:
attributes[attr] = value
return {
"annotation": field_info.annotation,
"metadata": field_info.metadata,
"attributes": attributes,
}
@dataclass
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
config: ConfigDict | None = None
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def validation_alias(self) -> str | None:
va = self.field_info.validation_alias
if isinstance(va, str) and va:
return va
return None
@property
def serialization_alias(self) -> str | None:
sa = self.field_info.serialization_alias
return sa or None
@property
def default(self) -> Any:
return self.get_default()
def __post_init__(self) -> None:
with warnings.catch_warnings():
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
# (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
# end up building the type adapter from a model field annotation so we
# need to ignore the warning:
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
from pydantic.warnings import UnsupportedFieldAttributeWarning
warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning
)
# TODO: remove after setting the min Pydantic to v2.12.3
# that adds asdict(), and use self.field_info.asdict() instead
field_dict = asdict(self.field_info)
annotated_args = (
field_dict["annotation"],
*field_dict["metadata"],
# this FieldInfo needs to be created again so that it doesn't include
# the old field info metadata and only the rest of the attributes
Field(**field_dict["attributes"]),
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotated_args], # ty: ignore[invalid-type-form]
config=self.config,
)
def get_default(self) -> Any:
if self.field_info.is_required():
return Undefined
return self.field_info.get_default(call_default_factory=True)
def validate(
self,
value: Any,
values: dict[str, Any] = {}, # noqa: B006
*,
loc: tuple[int | str, ...] = (),
) -> tuple[Any, list[dict[str, Any]]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
[],
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def serialize_json(
self,
value: Any,
*,
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> bytes:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
# This uses Pydantic's dump_json() which serializes directly to JSON
# bytes in one pass (via Rust), avoiding the intermediate Python dict
# step of dump_python(mode="json") + json.dumps().
return self._type_adapter.dump_json(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema.
return id(self)
def _has_computed_fields(field: ModelField) -> bool:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
return len(computed_fields) > 0
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
override_mode: Literal["validation"] | None = (
None
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
)
field_alias = (
(field.validation_alias or field.alias)
if field.mode == "validation"
else (field.serialization_alias or field.alias)
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = field.field_info.title or field_alias.title().replace(
"_", " "
)
return json_schema
def get_definitions(
*,
fields: Sequence[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, dict[str, Any]],
]:
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields(
validation_fields, known_models=set()
)
flat_serialization_models = get_flat_models_from_fields(
serialization_fields, known_models=set()
)
flat_validation_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="validation",
)
for model in flat_validation_models
]
flat_serialization_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="serialization",
)
for model in flat_serialization_models
]
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
input_types = {f.field_info.annotation for f in fields}
unique_flat_model_fields = {
f for f in flat_model_fields if f.field_info.annotation not in input_types
}
inputs = [
(
field,
(
field.mode
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
),
field._type_adapter.core_schema,
)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(dict[str, dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
# definitions: dict[DefsRef, dict[str, Any]]
# but mypy complains about general str in other places that are not declared as
# DefsRef, although DefsRef is just str:
# DefsRef = NewType('DefsRef', str)
# So, a cast to simplify the types here
return field_mapping, cast(dict[str, dict[str, Any]], definitions)
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return shared.field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
new_field_info = copy(field_info)
new_field_info.metadata = merged_field_info.metadata
new_field_info.annotation = merged_field_info.annotation
return new_field_info
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
if origin_type is Union or origin_type is UnionType: # Handle optional sequences
union_args = get_args(field.field_info.annotation)
for union_arg in union_args:
if union_arg is type(None):
continue
origin_type = get_origin(union_arg) or union_arg
break
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
def get_missing_field_error(loc: tuple[int | str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
model_fields: list[ModelField] = []
for name, field_info in model.model_fields.items():
type_ = field_info.annotation
if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
model_config = None
else:
model_config = model.model_config
model_fields.append(
ModelField(
field_info=field_info,
name=name,
config=model_config,
)
)
return model_fields
@lru_cache
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
return get_model_fields(model)
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
# Pydantic v2 and allow mixing the models
TypeModelOrEnum = type["BaseModel"] | type[Enum]
TypeModelSet = set[TypeModelOrEnum]
def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str]:
name_model_map = {}
for model in unique_models:
model_name = normalize_name(model.__name__)
name_model_map[model_name] = model
return {v: k for k, v in name_model_map.items()}
def get_flat_models_from_model(
model: type["BaseModel"], known_models: TypeModelSet | None = None
) -> TypeModelSet:
known_models = known_models or set()
fields = get_model_fields(model)
get_flat_models_from_fields(fields, known_models=known_models)
return known_models
def get_flat_models_from_annotation(
annotation: Any, known_models: TypeModelSet
) -> TypeModelSet:
origin = get_origin(annotation)
if origin is not None:
for arg in get_args(annotation):
if lenient_issubclass(arg, (BaseModel, Enum)):
if arg not in known_models:
known_models.add(arg) # type: ignore[arg-type] # ty: ignore[unused-ignore-comment]
if lenient_issubclass(arg, BaseModel):
get_flat_models_from_model(arg, known_models=known_models)
else:
get_flat_models_from_annotation(arg, known_models=known_models)
return known_models
def get_flat_models_from_field(
field: ModelField, known_models: TypeModelSet
) -> TypeModelSet:
field_type = field.field_info.annotation
if lenient_issubclass(field_type, BaseModel):
if field_type in known_models:
return known_models
known_models.add(field_type)
get_flat_models_from_model(field_type, known_models=known_models)
elif lenient_issubclass(field_type, Enum):
known_models.add(field_type)
else:
get_flat_models_from_annotation(field_type, known_models=known_models)
return known_models
def get_flat_models_from_fields(
fields: Sequence[ModelField], known_models: TypeModelSet
) -> TypeModelSet:
for field in fields:
get_flat_models_from_field(field, known_models=known_models)
return known_models
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]
) -> list[dict[str, Any]]:
updated_loc_errors: list[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
]
return updated_loc_errors

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,61 @@
from collections.abc import Callable
from typing import Annotated, Any
from annotated_doc import Doc
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from typing_extensions import ParamSpec
P = ParamSpec("P")
class BackgroundTasks(StarletteBackgroundTasks):
"""
A collection of background tasks that will be called after a response has been
sent to the client.
Read more about it in the
[FastAPI docs for Background Tasks](https://fastapi.tiangolo.com/tutorial/background-tasks/).
## Example
```python
from fastapi import BackgroundTasks, FastAPI
app = FastAPI()
def write_notification(email: str, message=""):
with open("log.txt", mode="w") as email_file:
content = f"notification for {email}: {message}"
email_file.write(content)
@app.post("/send-notification/{email}")
async def send_notification(email: str, background_tasks: BackgroundTasks):
background_tasks.add_task(write_notification, email, message="some notification")
return {"message": "Notification sent in the background"}
```
"""
def add_task(
self,
func: Annotated[
Callable[P, Any],
Doc(
"""
The function to call after the response is sent.
It can be a regular `def` function or an `async def` function.
"""
),
],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Add a function to be called in the background after the response is sent.
Read more about it in the
[FastAPI docs for Background Tasks](https://fastapi.tiangolo.com/tutorial/background-tasks/).
"""
return super().add_task(func, *args, **kwargs)

View file

@ -0,0 +1,13 @@
try:
from fastapi_cli.cli import main as cli_main
except ImportError: # pragma: no cover
cli_main = None # type: ignore
def main() -> None:
if not cli_main: # type: ignore[truthy-function] # ty: ignore[unused-ignore-comment]
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
print(message)
raise RuntimeError(message) # noqa: B904
cli_main()

View file

@ -0,0 +1,41 @@
from collections.abc import AsyncGenerator
from contextlib import AbstractContextManager
from contextlib import asynccontextmanager as asynccontextmanager
from typing import TypeVar
import anyio.to_thread
from anyio import CapacityLimiter
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
run_until_first_complete as run_until_first_complete,
)
_T = TypeVar("_T")
@asynccontextmanager
async def contextmanager_in_threadpool(
cm: AbstractContextManager[_T],
) -> AsyncGenerator[_T, None]:
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself
# has its own internal pool (e.g. a database connection pool)
# to avoid this we let __exit__ run without a capacity limit
# since we're creating a new limiter for each call, any non-zero limit
# works (1 is arbitrary)
exit_limiter = CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
ok = bool(
await anyio.to_thread.run_sync(
cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
)
)
if not ok:
raise e
else:
await anyio.to_thread.run_sync(
cm.__exit__, None, None, None, limiter=exit_limiter
)

View file

@ -0,0 +1,186 @@
from collections.abc import Callable, Mapping
from typing import (
Annotated,
Any,
BinaryIO,
TypeVar,
cast,
)
from annotated_doc import Doc
from pydantic import GetJsonSchemaHandler
from starlette.datastructures import URL as URL # noqa: F401
from starlette.datastructures import Address as Address # noqa: F401
from starlette.datastructures import FormData as FormData # noqa: F401
from starlette.datastructures import Headers as Headers # noqa: F401
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
class UploadFile(StarletteUploadFile):
"""
A file uploaded in a request.
Define it as a *path operation function* (or dependency) parameter.
If you are using a regular `def` function, you can use the `upload_file.file`
attribute to access the raw standard Python file (blocking, not async), useful and
needed for non-async code.
Read more about it in the
[FastAPI docs for Request Files](https://fastapi.tiangolo.com/tutorial/request-files/).
## Example
```python
from typing import Annotated
from fastapi import FastAPI, File, UploadFile
app = FastAPI()
@app.post("/files/")
async def create_file(file: Annotated[bytes, File()]):
return {"file_size": len(file)}
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile):
return {"filename": file.filename}
```
"""
file: Annotated[
BinaryIO,
Doc("The standard Python file object (non-async)."),
]
filename: Annotated[str | None, Doc("The original file name.")]
size: Annotated[int | None, Doc("The size of the file in bytes.")]
headers: Annotated[Headers, Doc("The headers of the request.")]
content_type: Annotated[
str | None, Doc("The content type of the request, from the headers.")
]
async def write(
self,
data: Annotated[
bytes,
Doc(
"""
The bytes to write to the file.
"""
),
],
) -> None:
"""
Write some bytes to the file.
You normally wouldn't use this from a file you read in a request.
To be awaitable, compatible with async, this is run in threadpool.
"""
return await super().write(data)
async def read(
self,
size: Annotated[
int,
Doc(
"""
The number of bytes to read from the file.
"""
),
] = -1,
) -> bytes:
"""
Read some bytes from the file.
To be awaitable, compatible with async, this is run in threadpool.
"""
return await super().read(size)
async def seek(
self,
offset: Annotated[
int,
Doc(
"""
The position in bytes to seek to in the file.
"""
),
],
) -> None:
"""
Move to a position in the file.
Any next read or write will be done from that position.
To be awaitable, compatible with async, this is run in threadpool.
"""
return await super().seek(offset)
async def close(self) -> None:
"""
Close the file.
To be awaitable, compatible with async, this is run in threadpool.
"""
return await super().close()
@classmethod
def _validate(cls, __input_value: Any, _: Any) -> "UploadFile":
if not isinstance(__input_value, StarletteUploadFile):
raise ValueError(f"Expected UploadFile, received: {type(__input_value)}")
return cast(UploadFile, __input_value)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: Mapping[str, Any], handler: GetJsonSchemaHandler
) -> dict[str, Any]:
return {"type": "string", "contentMediaType": "application/octet-stream"}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type[Any], handler: Callable[[Any], Mapping[str, Any]]
) -> Mapping[str, Any]:
from ._compat.v2 import with_info_plain_validator_function
return with_info_plain_validator_function(cls._validate)
class DefaultPlaceholder:
"""
You shouldn't use this class directly.
It's used internally to recognize when a default value has been overwritten, even
if the overridden default value was truthy.
"""
def __init__(self, value: Any):
self.value = value
def __bool__(self) -> bool:
return bool(self.value)
def __eq__(self, o: object) -> bool:
return isinstance(o, DefaultPlaceholder) and o.value == self.value
DefaultType = TypeVar("DefaultType")
def Default(value: DefaultType) -> DefaultType:
"""
You shouldn't use this function directly.
It's used internally to recognize when a default value has been overwritten, even
if the overridden default value was truthy.
"""
return DefaultPlaceholder(value) # type: ignore
# Sentinel for "parameter not provided" in Param/FieldInfo.
# Typed as None to satisfy ty
_Unset = Default(None)

View file

@ -0,0 +1,193 @@
import inspect
import sys
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cached_property, partial
from typing import Any, Literal
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from fastapi.types import DependencyCacheKey
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
def _unwrapped_call(call: Callable[..., Any] | None) -> Any:
if call is None:
return call # pragma: no cover
unwrapped = inspect.unwrap(_impartial(call))
return unwrapped
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
while isinstance(func, partial):
func = func.func
return func
@dataclass
class Dependant:
path_params: list[ModelField] = field(default_factory=list)
query_params: list[ModelField] = field(default_factory=list)
header_params: list[ModelField] = field(default_factory=list)
cookie_params: list[ModelField] = field(default_factory=list)
body_params: list[ModelField] = field(default_factory=list)
dependencies: list["Dependant"] = field(default_factory=list)
name: str | None = None
call: Callable[..., Any] | None = None
request_param_name: str | None = None
websocket_param_name: str | None = None
http_connection_param_name: str | None = None
response_param_name: str | None = None
background_tasks_param_name: str | None = None
security_scopes_param_name: str | None = None
own_oauth_scopes: list[str] | None = None
parent_oauth_scopes: list[str] | None = None
use_cache: bool = True
path: str | None = None
scope: Literal["function", "request"] | None = None
@cached_property
def oauth_scopes(self) -> list[str]:
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
# This doesn't use a set to preserve order, just in case
for scope in self.own_oauth_scopes or []:
if scope not in scopes:
scopes.append(scope)
return scopes
@cached_property
def cache_key(self) -> DependencyCacheKey:
scopes_for_cache = (
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
)
return (
self.call,
scopes_for_cache,
self.computed_scope or "",
)
@cached_property
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
if self.security_scopes_param_name is not None:
return True
if self._is_security_scheme:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property
def _is_security_scheme(self) -> bool:
if self.call is None:
return False # pragma: no cover
unwrapped = _unwrapped_call(self.call)
return isinstance(unwrapped, SecurityBase)
# Mainly to get the type of SecurityBase, but it's the same self.call
@cached_property
def _security_scheme(self) -> SecurityBase:
unwrapped = _unwrapped_call(self.call)
assert isinstance(unwrapped, SecurityBase)
return unwrapped
@cached_property
def _security_dependencies(self) -> list["Dependant"]:
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
return security_deps
@cached_property
def is_gen_callable(self) -> bool:
if self.call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(self.call)
) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
return True
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_async_gen_callable(self) -> bool:
if self.call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(self.call)
) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
return True
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_coroutine_callable(self) -> bool:
if self.call is None:
return False # pragma: no cover
if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
_impartial(self.call)
):
return True
if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
_unwrapped_call(self.call)
):
return True
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
_unwrapped_call(dunder_call)
):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if iscoroutinefunction(
_impartial(dunder_unwrapped_call)
) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def computed_scope(self) -> str | None:
if self.scope:
return self.scope
if self.is_gen_callable or self.is_async_gen_callable:
return "request"
return None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,347 @@
import dataclasses
import datetime
from collections import defaultdict, deque
from collections.abc import Callable
from decimal import Decimal
from enum import Enum
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Annotated, Any
from uuid import UUID
from annotated_doc import Doc
from fastapi.exceptions import PydanticV1NotSupportedError
from fastapi.types import IncEx
from pydantic import BaseModel
from pydantic.color import Color # ty: ignore[deprecated]
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from pydantic_core import PydanticUndefinedType
from ._compat import (
Url,
is_pydantic_v1_model_instance,
)
# Taken from Pydantic v1 as is
def isoformat(o: datetime.date | datetime.time) -> str:
return o.isoformat()
# Adapted from Pydantic v1
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> int | float:
"""
Encodes a Decimal as int if there's no exponent, otherwise float
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
>>> decimal_encoder(Decimal("NaN"))
nan
"""
exponent = dec_value.as_tuple().exponent
if isinstance(exponent, int) and exponent >= 0:
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str, # ty: ignore[deprecated]
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
Url: str,
AnyUrl: str,
}
def generate_encoders_by_class_tuples(
type_encoder_map: dict[Any, Callable[[Any], Any]],
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Annotated[
Any,
Doc(
"""
The input object to convert to JSON.
"""
),
],
include: Annotated[
IncEx | None,
Doc(
"""
Pydantic's `include` parameter, passed to Pydantic models to set the
fields to include.
"""
),
] = None,
exclude: Annotated[
IncEx | None,
Doc(
"""
Pydantic's `exclude` parameter, passed to Pydantic models to set the
fields to exclude.
"""
),
] = None,
by_alias: Annotated[
bool,
Doc(
"""
Pydantic's `by_alias` parameter, passed to Pydantic models to define if
the output should use the alias names (when provided) or the Python
attribute names. In an API, if you set an alias, it's probably because you
want to use it in the result, so you probably want to leave this set to
`True`.
"""
),
] = True,
exclude_unset: Annotated[
bool,
Doc(
"""
Pydantic's `exclude_unset` parameter, passed to Pydantic models to define
if it should exclude from the output the fields that were not explicitly
set (and that only had their default values).
"""
),
] = False,
exclude_defaults: Annotated[
bool,
Doc(
"""
Pydantic's `exclude_defaults` parameter, passed to Pydantic models to define
if it should exclude from the output the fields that had the same default
value, even when they were explicitly set.
"""
),
] = False,
exclude_none: Annotated[
bool,
Doc(
"""
Pydantic's `exclude_none` parameter, passed to Pydantic models to define
if it should exclude from the output any fields that have a `None` value.
"""
),
] = False,
custom_encoder: Annotated[
dict[Any, Callable[[Any], Any]] | None,
Doc(
"""
Pydantic's `custom_encoder` parameter, passed to Pydantic models to define
a custom encoder.
"""
),
] = None,
sqlalchemy_safe: Annotated[
bool,
Doc(
"""
Exclude from the output any fields that start with the name `_sa`.
This is mainly a hack for compatibility with SQLAlchemy objects, they
store internal SQLAlchemy-specific state in attributes named with `_sa`,
and those objects can't (and shouldn't be) serialized to JSON.
"""
),
] = True,
) -> Any:
"""
Convert any object to something that can be encoded in JSON.
This is used internally by FastAPI to make sure anything you return can be
encoded as JSON before it is sent to the client.
You can also use it yourself, for example to convert objects before saving them
in a database that supports only JSON.
Read more about it in the
[FastAPI docs for JSON Compatible Encoder](https://fastapi.tiangolo.com/tutorial/encoder/).
"""
custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder_instance in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder_instance(obj)
if include is not None and not isinstance(include, (set, dict)):
include = set(include) # type: ignore[assignment] # ty: ignore[unused-ignore-comment]
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude) # type: ignore[assignment] # ty: ignore[unused-ignore-comment]
if isinstance(obj, BaseModel):
obj_dict = obj.model_dump(
mode="json",
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
assert not isinstance(obj, type)
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, PydanticUndefinedType):
return None
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
if include is not None:
allowed_keys &= set(include)
if exclude is not None:
allowed_keys -= set(exclude)
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and key in allowed_keys
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
if is_pydantic_v1_model_instance(obj):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" Please update the model {obj!r}."
)
try:
data = dict(obj)
except Exception as e:
errors: list[Exception] = []
errors.append(e)
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors) from e
return jsonable_encoder(
data,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

View file

@ -0,0 +1,34 @@
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code
from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import WS_1008_POLICY_VIOLATION
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
headers = getattr(exc, "headers", None)
if not is_body_allowed_for_status_code(exc.status_code):
return Response(status_code=exc.status_code, headers=headers)
return JSONResponse(
{"detail": exc.detail}, status_code=exc.status_code, headers=headers
)
async def request_validation_exception_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
return JSONResponse(
status_code=422,
content={"detail": jsonable_encoder(exc.errors())},
)
async def websocket_request_validation_exception_handler(
websocket: WebSocket, exc: WebSocketRequestValidationError
) -> None:
await websocket.close(
code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
)

View file

@ -0,0 +1,256 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any, TypedDict
from annotated_doc import Doc
from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
class EndpointContext(TypedDict, total=False):
function: str
path: str
file: str
line: int
class HTTPException(StarletteHTTPException):
"""
An HTTP exception you can raise in your own code to show errors to the client.
This is for client errors, invalid authentication, invalid data, etc. Not for server
errors in your code.
Read more about it in the
[FastAPI docs for Handling Errors](https://fastapi.tiangolo.com/tutorial/handling-errors/).
## Example
```python
from fastapi import FastAPI, HTTPException
app = FastAPI()
items = {"foo": "The Foo Wrestlers"}
@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(status_code=404, detail="Item not found")
return {"item": items[item_id]}
```
"""
def __init__(
self,
status_code: Annotated[
int,
Doc(
"""
HTTP status code to send to the client.
Read more about it in the
[FastAPI docs for Handling Errors](https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception)
"""
),
],
detail: Annotated[
Any,
Doc(
"""
Any data to be sent to the client in the `detail` key of the JSON
response.
Read more about it in the
[FastAPI docs for Handling Errors](https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception)
"""
),
] = None,
headers: Annotated[
Mapping[str, str] | None,
Doc(
"""
Any headers to send to the client in the response.
Read more about it in the
[FastAPI docs for Handling Errors](https://fastapi.tiangolo.com/tutorial/handling-errors/#add-custom-headers)
"""
),
] = None,
) -> None:
super().__init__(status_code=status_code, detail=detail, headers=headers)
class WebSocketException(StarletteWebSocketException):
"""
A WebSocket exception you can raise in your own code to show errors to the client.
This is for client errors, invalid authentication, invalid data, etc. Not for server
errors in your code.
Read more about it in the
[FastAPI docs for WebSockets](https://fastapi.tiangolo.com/advanced/websockets/).
## Example
```python
from typing import Annotated
from fastapi import (
Cookie,
FastAPI,
WebSocket,
WebSocketException,
status,
)
app = FastAPI()
@app.websocket("/items/{item_id}/ws")
async def websocket_endpoint(
*,
websocket: WebSocket,
session: Annotated[str | None, Cookie()] = None,
item_id: str,
):
if session is None:
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
await websocket.accept()
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Session cookie is: {session}")
await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
```
"""
def __init__(
self,
code: Annotated[
int,
Doc(
"""
A closing code from the
[valid codes defined in the specification](https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1).
"""
),
],
reason: Annotated[
str | None,
Doc(
"""
The reason to close the WebSocket connection.
It is UTF-8-encoded data. The interpretation of the reason is up to the
application, it is not specified by the WebSocket specification.
It could contain text that could be human-readable or interpretable
by the client code, etc.
"""
),
] = None,
) -> None:
super().__init__(code=code, reason=reason)
RequestErrorModel: type[BaseModel] = create_model("Request")
WebSocketErrorModel: type[BaseModel] = create_model("WebSocket")
class FastAPIError(RuntimeError):
"""
A generic, FastAPI-specific error.
"""
class DependencyScopeError(FastAPIError):
"""
A dependency declared that it depends on another dependency with an invalid
(narrower) scope.
"""
class ValidationException(Exception):
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: EndpointContext | None = None,
) -> None:
self._errors = errors
self.endpoint_ctx = endpoint_ctx
ctx = endpoint_ctx or {}
self.endpoint_function = ctx.get("function")
self.endpoint_path = ctx.get("path")
self.endpoint_file = ctx.get("file")
self.endpoint_line = ctx.get("line")
def errors(self) -> Sequence[Any]:
return self._errors
def _format_endpoint_context(self) -> str:
if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
if self.endpoint_path:
return f"\n Endpoint: {self.endpoint_path}"
return ""
context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
if self.endpoint_path:
context += f"\n {self.endpoint_path}"
return context
def __str__(self) -> str:
message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
for err in self._errors:
message += f" {err}\n"
message += self._format_endpoint_context()
return message.rstrip()
class RequestValidationError(ValidationException):
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: EndpointContext | None = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
class WebSocketRequestValidationError(ValidationException):
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: EndpointContext | None = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
class ResponseValidationError(ValidationException):
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: EndpointContext | None = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
class PydanticV1NotSupportedError(FastAPIError):
"""
A pydantic.v1 model is used, which is no longer supported.
"""
class FastAPIDeprecationWarning(UserWarning):
"""
A custom deprecation warning as DeprecationWarning is ignored
Ref: https://sethmlarson.dev/deprecations-via-warnings-dont-work-for-python-libraries
"""

View file

@ -0,0 +1,3 @@
import logging
logger = logging.getLogger("fastapi")

View file

@ -0,0 +1 @@
from starlette.middleware import Middleware as Middleware

View file

@ -0,0 +1,18 @@
from contextlib import AsyncExitStack
from starlette.types import ASGIApp, Receive, Scope, Send
# Used mainly to close files after the request is done, dependencies are closed
# in their own AsyncExitStack
class AsyncExitStackMiddleware:
def __init__(
self, app: ASGIApp, context_name: str = "fastapi_middleware_astack"
) -> None:
self.app = app
self.context_name = context_name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with AsyncExitStack() as stack:
scope[self.context_name] = stack
await self.app(scope, receive, send)

View file

@ -0,0 +1 @@
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa

View file

@ -0,0 +1 @@
from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware # noqa

View file

@ -0,0 +1,3 @@
from starlette.middleware.httpsredirect import ( # noqa
HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
)

View file

@ -0,0 +1,3 @@
from starlette.middleware.trustedhost import ( # noqa
TrustedHostMiddleware as TrustedHostMiddleware,
)

View file

@ -0,0 +1,3 @@
from starlette.middleware.wsgi import (
WSGIMiddleware as WSGIMiddleware,
) # pragma: no cover # noqa

View file

@ -0,0 +1,3 @@
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
REF_PREFIX = "#/components/schemas/"
REF_TEMPLATE = "#/components/schemas/{model}"

View file

@ -0,0 +1,389 @@
import json
from typing import Annotated, Any
from annotated_doc import Doc
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
def _html_safe_json(value: Any) -> str:
"""Serialize a value to JSON with HTML special characters escaped.
This prevents injection when the JSON is embedded inside a <script> tag.
"""
return (
json.dumps(value)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
)
swagger_ui_default_parameters: Annotated[
dict[str, Any],
Doc(
"""
Default configurations for Swagger UI.
You can use it as a template to add any other configurations needed.
"""
),
] = {
"dom_id": "#swagger-ui",
"layout": "BaseLayout",
"deepLinking": True,
"showExtensions": True,
"showCommonExtensions": True,
}
def get_swagger_ui_html(
*,
openapi_url: Annotated[
str,
Doc(
"""
The OpenAPI URL that Swagger UI should load and use.
This is normally done automatically by FastAPI using the default URL
`/openapi.json`.
Read more about it in the
[FastAPI docs for Conditional OpenAPI](https://fastapi.tiangolo.com/how-to/conditional-openapi/#conditional-openapi-from-settings-and-env-vars)
"""
),
],
title: Annotated[
str,
Doc(
"""
The HTML `<title>` content, normally shown in the browser tab.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
],
swagger_js_url: Annotated[
str,
Doc(
"""
The URL to use to load the Swagger UI JavaScript.
It is normally set to a CDN URL.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
swagger_css_url: Annotated[
str,
Doc(
"""
The URL to use to load the Swagger UI CSS.
It is normally set to a CDN URL.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
swagger_favicon_url: Annotated[
str,
Doc(
"""
The URL of the favicon to use. It is normally shown in the browser tab.
"""
),
] = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Annotated[
str | None,
Doc(
"""
The OAuth2 redirect URL, it is normally automatically handled by FastAPI.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
] = None,
init_oauth: Annotated[
dict[str, Any] | None,
Doc(
"""
A dictionary with Swagger UI OAuth2 initialization configurations.
Read more about the available configuration options in the
[Swagger UI docs](https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/).
"""
),
] = None,
swagger_ui_parameters: Annotated[
dict[str, Any] | None,
Doc(
"""
Configuration parameters for Swagger UI.
It defaults to [swagger_ui_default_parameters][fastapi.openapi.docs.swagger_ui_default_parameters].
Read more about it in the
[FastAPI docs about how to Configure Swagger UI](https://fastapi.tiangolo.com/how-to/configure-swagger-ui/).
"""
),
] = None,
) -> HTMLResponse:
"""
Generate and return the HTML that loads Swagger UI for the interactive
API docs (normally served at `/docs`).
You would only call this function yourself if you needed to override some parts,
for example the URLs to use to load Swagger UI's JavaScript and CSS.
Read more about it in the
[FastAPI docs for Configure Swagger UI](https://fastapi.tiangolo.com/how-to/configure-swagger-ui/)
and the [FastAPI docs for Custom Docs UI Static Assets (Self-Hosting)](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/).
"""
current_swagger_ui_parameters = swagger_ui_default_parameters.copy()
if swagger_ui_parameters:
current_swagger_ui_parameters.update(swagger_ui_parameters)
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
<link rel="shortcut icon" href="{swagger_favicon_url}">
<title>{title}</title>
</head>
<body>
<div id="swagger-ui">
</div>
<script src="{swagger_js_url}"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({{
url: '{openapi_url}',
"""
for key, value in current_swagger_ui_parameters.items():
html += f"{_html_safe_json(key)}: {_html_safe_json(jsonable_encoder(value))},\n"
if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
html += """
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
})"""
if init_oauth:
html += f"""
ui.initOAuth({_html_safe_json(jsonable_encoder(init_oauth))})
"""
html += """
</script>
</body>
</html>
"""
return HTMLResponse(html)
def get_redoc_html(
*,
openapi_url: Annotated[
str,
Doc(
"""
The OpenAPI URL that ReDoc should load and use.
This is normally done automatically by FastAPI using the default URL
`/openapi.json`.
Read more about it in the
[FastAPI docs for Conditional OpenAPI](https://fastapi.tiangolo.com/how-to/conditional-openapi/#conditional-openapi-from-settings-and-env-vars)
"""
),
],
title: Annotated[
str,
Doc(
"""
The HTML `<title>` content, normally shown in the browser tab.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
],
redoc_js_url: Annotated[
str,
Doc(
"""
The URL to use to load the ReDoc JavaScript.
It is normally set to a CDN URL.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/)
"""
),
] = "https://cdn.jsdelivr.net/npm/redoc@2/bundles/redoc.standalone.js",
redoc_favicon_url: Annotated[
str,
Doc(
"""
The URL of the favicon to use. It is normally shown in the browser tab.
"""
),
] = "https://fastapi.tiangolo.com/img/favicon.png",
with_google_fonts: Annotated[
bool,
Doc(
"""
Load and use Google Fonts.
"""
),
] = True,
) -> HTMLResponse:
"""
Generate and return the HTML response that loads ReDoc for the alternative
API docs (normally served at `/redoc`).
You would only call this function yourself if you needed to override some parts,
for example the URLs to use to load ReDoc's JavaScript and CSS.
Read more about it in the
[FastAPI docs for Custom Docs UI Static Assets (Self-Hosting)](https://fastapi.tiangolo.com/how-to/custom-docs-ui-assets/).
"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>{title}</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
"""
if with_google_fonts:
html += """
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
"""
html += f"""
<link rel="shortcut icon" href="{redoc_favicon_url}">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {{
margin: 0;
padding: 0;
}}
</style>
</head>
<body>
<noscript>
ReDoc requires Javascript to function. Please enable it to browse the documentation.
</noscript>
<redoc spec-url="{openapi_url}"></redoc>
<script src="{redoc_js_url}"> </script>
</body>
</html>
"""
return HTMLResponse(html)
def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
"""
Generate the HTML response with the OAuth2 redirection for Swagger UI.
You normally don't need to use or change this.
"""
# copied from https://github.com/swagger-api/swagger-ui/blob/v4.14.0/dist/oauth2-redirect.html
html = """
<!doctype html>
<html lang="en-US">
<head>
<title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1).replace('?', '&');
} else {
qp = location.search.substring(1);
}
arr = qp.split("&");
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value);
}
) : {};
isValid = qp.state === sentState;
if ((
oauth2.auth.schema.get("flow") === "accessCode" ||
oauth2.auth.schema.get("flow") === "authorizationCode" ||
oauth2.auth.schema.get("flow") === "authorization_code"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg;
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
if (document.readyState !== 'loading') {
run();
} else {
document.addEventListener('DOMContentLoaded', function () {
run();
});
}
</script>
</body>
</html>
"""
return HTMLResponse(content=html)

View file

@ -0,0 +1,435 @@
from collections.abc import Callable, Iterable, Mapping
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union
from fastapi._compat import with_info_plain_validator_function
from fastapi.logger import logger
from pydantic import (
AnyUrl,
BaseModel,
Field,
GetJsonSchemaHandler,
)
from typing_extensions import TypedDict
from typing_extensions import deprecated as typing_deprecated
try:
import email_validator
assert email_validator # make autoflake ignore the unused import
from pydantic import EmailStr
except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore # ty: ignore[unused-ignore-comment]
@classmethod
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod
def validate(cls, v: Any) -> str:
logger.warning(
"email-validator not installed, email fields will be treated as str.\n"
"To install, run: pip install email-validator"
)
return str(v)
@classmethod
def _validate(cls, __input_value: Any, _: Any) -> str:
logger.warning(
"email-validator not installed, email fields will be treated as str.\n"
"To install, run: pip install email-validator"
)
return str(__input_value)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: Mapping[str, Any], handler: GetJsonSchemaHandler
) -> dict[str, Any]:
return {"type": "string", "format": "email"}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type[Any], handler: Callable[[Any], Mapping[str, Any]]
) -> Mapping[str, Any]:
return with_info_plain_validator_function(cls._validate)
class BaseModelWithConfig(BaseModel):
model_config = {"extra": "allow"}
class Contact(BaseModelWithConfig):
name: str | None = None
url: AnyUrl | None = None
email: EmailStr | None = None
class License(BaseModelWithConfig):
name: str
identifier: str | None = None
url: AnyUrl | None = None
class Info(BaseModelWithConfig):
title: str
summary: str | None = None
description: str | None = None
termsOfService: str | None = None
contact: Contact | None = None
license: License | None = None
version: str
class ServerVariable(BaseModelWithConfig):
enum: Annotated[list[str] | None, Field(min_length=1)] = None
default: str
description: str | None = None
class Server(BaseModelWithConfig):
url: AnyUrl | str
description: str | None = None
variables: dict[str, ServerVariable] | None = None
class Reference(BaseModel):
ref: str = Field(alias="$ref")
class Discriminator(BaseModel):
propertyName: str
mapping: dict[str, str] | None = None
class XML(BaseModelWithConfig):
name: str | None = None
namespace: str | None = None
prefix: str | None = None
attribute: bool | None = None
wrapped: bool | None = None
class ExternalDocumentation(BaseModelWithConfig):
description: str | None = None
url: AnyUrl
# Ref JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation#name-type
SchemaType = Literal[
"array", "boolean", "integer", "null", "number", "object", "string"
]
class Schema(BaseModelWithConfig):
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu
# Core Vocabulary
schema_: str | None = Field(default=None, alias="$schema")
vocabulary: str | None = Field(default=None, alias="$vocabulary")
id: str | None = Field(default=None, alias="$id")
anchor: str | None = Field(default=None, alias="$anchor")
dynamicAnchor: str | None = Field(default=None, alias="$dynamicAnchor")
ref: str | None = Field(default=None, alias="$ref")
dynamicRef: str | None = Field(default=None, alias="$dynamicRef")
defs: dict[str, "SchemaOrBool"] | None = Field(default=None, alias="$defs")
comment: str | None = Field(default=None, alias="$comment")
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-a-vocabulary-for-applying-s
# A Vocabulary for Applying Subschemas
allOf: list["SchemaOrBool"] | None = None
anyOf: list["SchemaOrBool"] | None = None
oneOf: list["SchemaOrBool"] | None = None
not_: Optional["SchemaOrBool"] = Field(default=None, alias="not")
if_: Optional["SchemaOrBool"] = Field(default=None, alias="if")
then: Optional["SchemaOrBool"] = None
else_: Optional["SchemaOrBool"] = Field(default=None, alias="else")
dependentSchemas: dict[str, "SchemaOrBool"] | None = None
prefixItems: list["SchemaOrBool"] | None = None
items: Optional["SchemaOrBool"] = None
contains: Optional["SchemaOrBool"] = None
properties: dict[str, "SchemaOrBool"] | None = None
patternProperties: dict[str, "SchemaOrBool"] | None = None
additionalProperties: Optional["SchemaOrBool"] = None
propertyNames: Optional["SchemaOrBool"] = None
unevaluatedItems: Optional["SchemaOrBool"] = None
unevaluatedProperties: Optional["SchemaOrBool"] = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural
# A Vocabulary for Structural Validation
type: SchemaType | list[SchemaType] | None = None
enum: list[Any] | None = None
const: Any | None = None
multipleOf: float | None = Field(default=None, gt=0)
maximum: float | None = None
exclusiveMaximum: float | None = None
minimum: float | None = None
exclusiveMinimum: float | None = None
maxLength: int | None = Field(default=None, ge=0)
minLength: int | None = Field(default=None, ge=0)
pattern: str | None = None
maxItems: int | None = Field(default=None, ge=0)
minItems: int | None = Field(default=None, ge=0)
uniqueItems: bool | None = None
maxContains: int | None = Field(default=None, ge=0)
minContains: int | None = Field(default=None, ge=0)
maxProperties: int | None = Field(default=None, ge=0)
minProperties: int | None = Field(default=None, ge=0)
required: list[str] | None = None
dependentRequired: dict[str, set[str]] | None = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-vocabularies-for-semantic-c
# Vocabularies for Semantic Content With "format"
format: str | None = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-the-conten
# A Vocabulary for the Contents of String-Encoded Data
contentEncoding: str | None = None
contentMediaType: str | None = None
contentSchema: Optional["SchemaOrBool"] = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-basic-meta
# A Vocabulary for Basic Meta-Data Annotations
title: str | None = None
description: str | None = None
default: Any | None = None
deprecated: bool | None = None
readOnly: bool | None = None
writeOnly: bool | None = None
examples: list[Any] | None = None
# Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object
# Schema Object
discriminator: Discriminator | None = None
xml: XML | None = None
externalDocs: ExternalDocumentation | None = None
example: Annotated[
Any | None,
typing_deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = None
# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents
# A JSON Schema MUST be an object or a boolean.
SchemaOrBool = Schema | bool
class Example(TypedDict, total=False):
summary: str | None
description: str | None
value: Any | None
externalValue: AnyUrl | None
__pydantic_config__ = {"extra": "allow"} # type: ignore[misc]
class ParameterInType(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Encoding(BaseModelWithConfig):
contentType: str | None = None
headers: dict[str, Union["Header", Reference]] | None = None
style: str | None = None
explode: bool | None = None
allowReserved: bool | None = None
class MediaType(BaseModelWithConfig):
schema_: Schema | Reference | None = Field(default=None, alias="schema")
example: Any | None = None
examples: dict[str, Example | Reference] | None = None
encoding: dict[str, Encoding] | None = None
class ParameterBase(BaseModelWithConfig):
description: str | None = None
required: bool | None = None
deprecated: bool | None = None
# Serialization rules for simple scenarios
style: str | None = None
explode: bool | None = None
allowReserved: bool | None = None
schema_: Schema | Reference | None = Field(default=None, alias="schema")
example: Any | None = None
examples: dict[str, Example | Reference] | None = None
# Serialization rules for more complex scenarios
content: dict[str, MediaType] | None = None
class Parameter(ParameterBase):
name: str
in_: ParameterInType = Field(alias="in")
class Header(ParameterBase):
pass
class RequestBody(BaseModelWithConfig):
description: str | None = None
content: dict[str, MediaType]
required: bool | None = None
class Link(BaseModelWithConfig):
operationRef: str | None = None
operationId: str | None = None
parameters: dict[str, Any | str] | None = None
requestBody: Any | str | None = None
description: str | None = None
server: Server | None = None
class Response(BaseModelWithConfig):
description: str
headers: dict[str, Header | Reference] | None = None
content: dict[str, MediaType] | None = None
links: dict[str, Link | Reference] | None = None
class Operation(BaseModelWithConfig):
tags: list[str] | None = None
summary: str | None = None
description: str | None = None
externalDocs: ExternalDocumentation | None = None
operationId: str | None = None
parameters: list[Parameter | Reference] | None = None
requestBody: RequestBody | Reference | None = None
# Using Any for Specification Extensions
responses: dict[str, Response | Any] | None = None
callbacks: dict[str, dict[str, "PathItem"] | Reference] | None = None
deprecated: bool | None = None
security: list[dict[str, list[str]]] | None = None
servers: list[Server] | None = None
class PathItem(BaseModelWithConfig):
ref: str | None = Field(default=None, alias="$ref")
summary: str | None = None
description: str | None = None
get: Operation | None = None
put: Operation | None = None
post: Operation | None = None
delete: Operation | None = None
options: Operation | None = None
head: Operation | None = None
patch: Operation | None = None
trace: Operation | None = None
servers: list[Server] | None = None
parameters: list[Parameter | Reference] | None = None
class SecuritySchemeType(Enum):
apiKey = "apiKey"
http = "http"
oauth2 = "oauth2"
openIdConnect = "openIdConnect"
class SecurityBase(BaseModelWithConfig):
type_: SecuritySchemeType = Field(alias="type")
description: str | None = None
class APIKeyIn(Enum):
query = "query"
header = "header"
cookie = "cookie"
class APIKey(SecurityBase):
type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type")
in_: APIKeyIn = Field(alias="in")
name: str
class HTTPBase(SecurityBase):
type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type")
scheme: str
class HTTPBearer(HTTPBase):
scheme: Literal["bearer"] = "bearer"
bearerFormat: str | None = None
class OAuthFlow(BaseModelWithConfig):
refreshUrl: str | None = None
scopes: dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
authorizationUrl: str
class OAuthFlowPassword(OAuthFlow):
tokenUrl: str
class OAuthFlowClientCredentials(OAuthFlow):
tokenUrl: str
class OAuthFlowAuthorizationCode(OAuthFlow):
authorizationUrl: str
tokenUrl: str
class OAuthFlows(BaseModelWithConfig):
implicit: OAuthFlowImplicit | None = None
password: OAuthFlowPassword | None = None
clientCredentials: OAuthFlowClientCredentials | None = None
authorizationCode: OAuthFlowAuthorizationCode | None = None
class OAuth2(SecurityBase):
type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type")
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
type_: SecuritySchemeType = Field(
default=SecuritySchemeType.openIdConnect, alias="type"
)
openIdConnectUrl: str
SecurityScheme = APIKey | HTTPBase | OAuth2 | OpenIdConnect | HTTPBearer
class Components(BaseModelWithConfig):
schemas: dict[str, Schema | Reference] | None = None
responses: dict[str, Response | Reference] | None = None
parameters: dict[str, Parameter | Reference] | None = None
examples: dict[str, Example | Reference] | None = None
requestBodies: dict[str, RequestBody | Reference] | None = None
headers: dict[str, Header | Reference] | None = None
securitySchemes: dict[str, SecurityScheme | Reference] | None = None
links: dict[str, Link | Reference] | None = None
# Using Any for Specification Extensions
callbacks: dict[str, dict[str, PathItem] | Reference | Any] | None = None
pathItems: dict[str, PathItem | Reference] | None = None
class Tag(BaseModelWithConfig):
name: str
description: str | None = None
externalDocs: ExternalDocumentation | None = None
class OpenAPI(BaseModelWithConfig):
openapi: str
info: Info
jsonSchemaDialect: str | None = None
servers: list[Server] | None = None
# Using Any for Specification Extensions
paths: dict[str, PathItem | Any] | None = None
webhooks: dict[str, PathItem | Reference] | None = None
components: Components | None = None
security: list[dict[str, list[str]]] | None = None
tags: list[Tag] | None = None
externalDocs: ExternalDocumentation | None = None
Schema.model_rebuild()
Operation.model_rebuild()
Encoding.model_rebuild()

View file

@ -0,0 +1,606 @@
import copy
import http.client
import inspect
import warnings
from collections.abc import Sequence
from typing import Any, Literal, cast
from fastapi import routing
from fastapi._compat import (
ModelField,
get_definitions,
get_flat_models_from_fields,
get_model_name_map,
get_schema_from_model_field,
lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder, _Unset
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
get_flat_params,
get_validation_alias,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, ParamTypes
from fastapi.responses import Response
from fastapi.sse import _SSE_EVENT_SCHEMA
from fastapi.types import ModelNameMap
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
"input": {"title": "Input"},
"ctx": {"title": "Context", "type": "object"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
}
},
}
status_code_ranges: dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
"4XX": "Client Error",
"5XX": "Server Error",
"DEFAULT": "Default Response",
}
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
security_definitions = {}
# Use a dict to merge scopes for same security scheme
operation_security_dict: dict[str, list[str]] = {}
for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_dependency._security_scheme.scheme_name
security_definitions[security_name] = security_definition
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
operation_security = [
{name: scopes} for name, scopes in operation_security_dict.items()
]
return security_definitions, operation_security
def _get_openapi_operation_parameters(
*,
dependant: Dependant,
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
parameters = []
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
parameter_groups = [
(ParamTypes.path, path_params),
(ParamTypes.query, query_params),
(ParamTypes.header, header_params),
(ParamTypes.cookie, cookie_params),
]
default_convert_underscores = True
if len(flat_dependant.header_params) == 1:
first_field = flat_dependant.header_params[0]
if lenient_issubclass(first_field.field_info.annotation, BaseModel):
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
for param_type, param_group in parameter_groups:
for param in param_group:
field_info = param.field_info
# field_info = cast(Param, field_info)
if not getattr(field_info, "include_in_schema", True):
continue
param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
name = get_validation_alias(param)
convert_underscores = getattr(
param.field_info,
"convert_underscores",
default_convert_underscores,
)
if (
param_type == ParamTypes.header
and name == param.name
and convert_underscores
):
name = param.name.replace("_", "-")
parameter = {
"name": name,
"in": param_type.value,
"required": param.field_info.is_required(),
"schema": param_schema,
}
if field_info.description:
parameter["description"] = field_info.description
openapi_examples = getattr(field_info, "openapi_examples", None)
example = getattr(field_info, "example", None)
if openapi_examples:
parameter["examples"] = jsonable_encoder(openapi_examples)
elif example is not _Unset:
parameter["example"] = jsonable_encoder(example)
if getattr(field_info, "deprecated", None):
parameter["deprecated"] = True
parameters.append(parameter)
return parameters
def get_openapi_operation_request_body(
*,
body_field: ModelField | None,
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> dict[str, Any] | None:
if not body_field:
return None
assert isinstance(body_field, ModelField)
body_schema = get_schema_from_model_field(
field=body_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.field_info.is_required()
request_body_oai: dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_media_content: dict[str, Any] = {"schema": body_schema}
if field_info.openapi_examples:
request_media_content["examples"] = jsonable_encoder(
field_info.openapi_examples
)
elif field_info.example is not _Unset:
request_media_content["example"] = jsonable_encoder(field_info.example)
request_body_oai["content"] = {request_media_type: request_media_content}
return request_body_oai
def generate_operation_id(
*, route: routing.APIRoute, method: str
) -> str: # pragma: nocover
warnings.warn(
message="fastapi.openapi.utils.generate_operation_id() was deprecated, "
"it is not used internally, and will be removed soon",
category=FastAPIDeprecationWarning,
stacklevel=2,
)
if route.operation_id:
return route.operation_id
path: str = route.path_format
return generate_operation_id_for_path(name=route.name, path=path, method=method)
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return route.name.replace("_", " ").title()
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str, operation_ids: set[str]
) -> dict[str, Any]:
operation: dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation_id = route.operation_id or route.unique_id
if operation_id in operation_ids:
endpoint_name = getattr(route.endpoint, "__name__", "<unnamed_endpoint>")
message = f"Duplicate Operation ID {operation_id} for function {endpoint_name}"
file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
if file_name:
message += f" at {file_name}"
warnings.warn(message, stacklevel=1)
operation_ids.add(operation_id)
operation["operationId"] = operation_id
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
def get_openapi_path(
*,
route: routing.APIRoute,
operation_ids: set[str],
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
path = {}
security_schemes: dict[str, Any] = {}
definitions: dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
route_response_media_type: str | None = current_response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(
route=route, method=method, operation_ids=operation_ids
)
parameters: list[dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
)
if operation_security:
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
operation_parameters = _get_openapi_operation_parameters(
dependant=route.dependant,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
parameters.extend(operation_parameters)
if parameters:
all_parameters = {
(param["in"], param["name"]): param for param in parameters
}
required_parameters = {
(param["in"], param["name"]): param
for param in parameters
if param.get("required")
}
# Make sure required definitions of the same parameter take precedence
# over non-required definitions
all_parameters.update(required_parameters)
operation["parameters"] = list(all_parameters.values())
if method in METHODS_WITH_BODY:
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
if request_body_oai:
operation["requestBody"] = request_body_oai
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
if isinstance(callback, routing.APIRoute):
(
cb_path,
cb_security_schemes,
cb_definitions,
) = get_openapi_path(
route=callback,
operation_ids=operation_ids,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
if route.status_code is not None:
status_code = str(route.status_code)
else:
# It would probably make more sense for all response classes to have an
# explicit default status_code, and to extract it from them, instead of
# doing this inspection tricks, that would probably be in the future
# TODO: probably make status_code a default class attribute for all
# responses in Starlette
response_signature = inspect.signature(current_response_class.__init__)
status_code_param = response_signature.parameters.get("status_code")
if status_code_param is not None:
if isinstance(status_code_param.default, int):
status_code = str(status_code_param.default)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if is_body_allowed_for_status_code(route.status_code):
# Check for JSONL streaming (generator endpoints)
if route.is_json_stream:
jsonl_content: dict[str, Any] = {}
if route.stream_item_field:
item_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
jsonl_content["itemSchema"] = item_schema
else:
jsonl_content["itemSchema"] = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["application/jsonl"] = jsonl_content
elif route.is_sse_stream:
sse_content: dict[str, Any] = {}
item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA)
if route.stream_item_field:
content_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
item_schema["required"] = ["data"]
item_schema["properties"]["data"] = {
"type": "string",
"contentMediaType": "application/json",
"contentSchema": content_schema,
}
sse_content["itemSchema"] = item_schema
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["text/event-stream"] = sse_content
elif route_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(
route_response_media_type, {}
)["schema"] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (
additional_status_code,
additional_response,
) in route.responses.items():
process_response = copy.deepcopy(additional_response)
process_response.pop("model", None)
status_code_key = str(additional_status_code).upper()
if status_code_key == "DEFAULT":
status_code_key = "default"
openapi_response = operation_responses.setdefault(
status_code_key, {}
)
assert isinstance(process_response, dict), (
"An additional response must be a dict"
)
field = route.response_fields.get(additional_status_code)
additional_field_schema: dict[str, Any] | None = None
if field:
additional_field_schema = get_schema_from_model_field(
field=field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
media_type = route_response_media_type or "application/json"
additional_schema = (
process_response.setdefault("content", {})
.setdefault(media_type, {})
.setdefault("schema", {})
)
deep_dict_update(additional_schema, additional_field_schema)
status_text: str | None = status_code_ranges.get(
str(additional_status_code).upper()
) or http.client.responses.get(int(additional_status_code))
description = (
process_response.get("description")
or openapi_response.get("description")
or status_text
or "Additional Response"
)
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = "422"
all_route_params = get_flat_params(route.dependant)
if (all_route_params or route.body_field) and not any(
status in operation["responses"]
for status in [http422, "4XX", "default"]
):
operation["responses"][http422] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
}
},
}
if "ValidationError" not in definitions:
definitions.update(
{
"ValidationError": validation_error_definition,
"HTTPValidationError": validation_error_response_definition,
}
)
if route.openapi_extra:
deep_dict_update(operation, route.openapi_extra)
path[method.lower()] = operation
return path, security_schemes, definitions
def get_fields_from_routes(
routes: Sequence[BaseRoute],
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route in routes:
if not isinstance(route, routing.APIRoute):
continue
if route.include_in_schema:
if route.body_field:
assert isinstance(route.body_field, ModelField), (
"A request body must be a Pydantic Field"
)
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.stream_item_field:
responses_from_routes.append(route.stream_item_field)
if route.callbacks:
callback_flat_models.extend(get_fields_from_routes(route.callbacks))
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models + list(
body_fields_from_routes + responses_from_routes + request_fields_from_routes
)
return flat_models
def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.1.0",
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
webhooks: Sequence[BaseRoute] | None = None,
tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
contact: dict[str, str | Any] | None = None,
license_info: dict[str, str | Any] | None = None,
separate_input_output_schemas: bool = True,
external_docs: dict[str, Any] | None = None,
) -> dict[str, Any]:
info: dict[str, Any] = {"title": title, "version": version}
if summary:
info["summary"] = summary
if description:
info["description"] = description
if terms_of_service:
info["termsOfService"] = terms_of_service
if contact:
info["contact"] = contact
if license_info:
info["license"] = license_info
output: dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: dict[str, dict[str, Any]] = {}
paths: dict[str, dict[str, Any]] = {}
webhook_paths: dict[str, dict[str, Any]] = {}
operation_ids: set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
flat_models = get_flat_models_from_fields(all_fields, known_models=set())
model_name_map = get_model_name_map(flat_models)
field_mapping, definitions = get_definitions(
fields=all_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
for route in routes or []:
if isinstance(route, routing.APIRoute):
result = get_openapi_path(
route=route,
operation_ids=operation_ids,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
if result:
path, security_schemes, path_definitions = result
if path:
paths.setdefault(route.path_format, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if path_definitions:
definitions.update(path_definitions)
for webhook in webhooks or []:
if isinstance(webhook, routing.APIRoute):
result = get_openapi_path(
route=webhook,
operation_ids=operation_ids,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
if result:
path, security_schemes, path_definitions = result
if path:
webhook_paths.setdefault(webhook.path_format, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if webhook_paths:
output["webhooks"] = webhook_paths
if tags:
output["tags"] = tags
if external_docs:
output["externalDocs"] = external_docs
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore # ty: ignore[unused-ignore-comment]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,754 @@
import warnings
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Annotated, Any, Literal
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.models import Example
from pydantic import AliasChoices, AliasPath
from pydantic.fields import FieldInfo
from typing_extensions import deprecated
from ._compat import (
Undefined,
)
from .datastructures import _Unset
class ParamTypes(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Param(FieldInfo): # type: ignore[misc]
in_: ParamTypes
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=FastAPIDeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=FastAPIDeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
kwargs["deprecated"] = deprecated
if serialization_alias in (_Unset, None) and isinstance(alias, str):
serialization_alias = alias
if validation_alias in (_Unset, None):
validation_alias = alias
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Path(Param): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
in_ = ParamTypes.path
def __init__(
self,
default: Any = ...,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
assert default is ..., "Path parameters cannot have a default value"
self.in_ = self.in_
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Query(Param): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
in_ = ParamTypes.query
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Header(Param): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
in_ = ParamTypes.header
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
convert_underscores: bool = True,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Cookie(Param): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
in_ = ParamTypes.cookie
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Body(FieldInfo): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
embed: bool | None = None,
media_type: str = "application/json",
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=FastAPIDeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=FastAPIDeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
kwargs["deprecated"] = deprecated
if serialization_alias in (_Unset, None) and isinstance(alias, str):
serialization_alias = alias
if validation_alias in (_Unset, None):
validation_alias = alias
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Form(Body): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
media_type: str = "application/x-www-form-urlencoded",
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class File(Form): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Callable[[], Any] | None = _Unset,
annotation: Any | None = None,
media_type: str = "multipart/form-data",
alias: str | None = None,
alias_priority: int | None = _Unset,
validation_alias: str | AliasPath | AliasChoices | None = None,
serialization_alias: str | None = None,
title: str | None = None,
description: str | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
regex: Annotated[
str | None,
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: str | None = None,
strict: bool | None = _Unset,
multiple_of: float | None = _Unset,
allow_inf_nan: bool | None = _Unset,
max_digits: int | None = _Unset,
decimal_places: int | None = _Unset,
examples: list[Any] | None = None,
example: Annotated[
Any | None,
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: dict[str, Example] | None = None,
deprecated: deprecated | str | bool | None = None,
include_in_schema: bool = True,
json_schema_extra: dict[str, Any] | None = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
@dataclass(frozen=True)
class Depends:
dependency: Callable[..., Any] | None = None
use_cache: bool = True
scope: Literal["function", "request"] | None = None
@dataclass(frozen=True)
class Security(Depends):
scopes: Sequence[str] | None = None

View file

@ -0,0 +1,2 @@
from starlette.requests import HTTPConnection as HTTPConnection # noqa: F401
from starlette.requests import Request as Request # noqa: F401

View file

@ -0,0 +1,85 @@
from typing import Any
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.sse import EventSourceResponse as EventSourceResponse # noqa
from starlette.responses import FileResponse as FileResponse # noqa
from starlette.responses import HTMLResponse as HTMLResponse # noqa
from starlette.responses import JSONResponse as JSONResponse # noqa
from starlette.responses import PlainTextResponse as PlainTextResponse # noqa
from starlette.responses import RedirectResponse as RedirectResponse # noqa
from starlette.responses import Response as Response # noqa
from starlette.responses import StreamingResponse as StreamingResponse # noqa
from typing_extensions import deprecated
try:
import ujson
except ImportError: # pragma: nocover
ujson = None # type: ignore
try:
import orjson
except ImportError: # pragma: nocover
orjson = None # type: ignore
@deprecated(
"UJSONResponse is deprecated, FastAPI now serializes data directly to JSON "
"bytes via Pydantic when a return type or response model is set, which is "
"faster and doesn't need a custom response class. Read more in the FastAPI "
"docs: https://fastapi.tiangolo.com/advanced/custom-response/#orjson-or-response-model "
"and https://fastapi.tiangolo.com/tutorial/response-model/",
category=FastAPIDeprecationWarning,
stacklevel=2,
)
class UJSONResponse(JSONResponse):
"""JSON response using the ujson library to serialize data to JSON.
**Deprecated**: `UJSONResponse` is deprecated. FastAPI now serializes data
directly to JSON bytes via Pydantic when a return type or response model is
set, which is faster and doesn't need a custom response class.
Read more in the
[FastAPI docs for Custom Response](https://fastapi.tiangolo.com/advanced/custom-response/#orjson-or-response-model)
and the
[FastAPI docs for Response Model](https://fastapi.tiangolo.com/tutorial/response-model/).
**Note**: `ujson` is not included with FastAPI and must be installed
separately, e.g. `pip install ujson`.
"""
def render(self, content: Any) -> bytes:
assert ujson is not None, "ujson must be installed to use UJSONResponse"
return ujson.dumps(content, ensure_ascii=False).encode("utf-8")
@deprecated(
"ORJSONResponse is deprecated, FastAPI now serializes data directly to JSON "
"bytes via Pydantic when a return type or response model is set, which is "
"faster and doesn't need a custom response class. Read more in the FastAPI "
"docs: https://fastapi.tiangolo.com/advanced/custom-response/#orjson-or-response-model "
"and https://fastapi.tiangolo.com/tutorial/response-model/",
category=FastAPIDeprecationWarning,
stacklevel=2,
)
class ORJSONResponse(JSONResponse):
"""JSON response using the orjson library to serialize data to JSON.
**Deprecated**: `ORJSONResponse` is deprecated. FastAPI now serializes data
directly to JSON bytes via Pydantic when a return type or response model is
set, which is faster and doesn't need a custom response class.
Read more in the
[FastAPI docs for Custom Response](https://fastapi.tiangolo.com/advanced/custom-response/#orjson-or-response-model)
and the
[FastAPI docs for Response Model](https://fastapi.tiangolo.com/tutorial/response-model/).
**Note**: `orjson` is not included with FastAPI and must be installed
separately, e.g. `pip install orjson`.
"""
def render(self, content: Any) -> bytes:
assert orjson is not None, "orjson must be installed to use ORJSONResponse"
return orjson.dumps(
content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,15 @@
from .api_key import APIKeyCookie as APIKeyCookie
from .api_key import APIKeyHeader as APIKeyHeader
from .api_key import APIKeyQuery as APIKeyQuery
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
from .http import HTTPBasic as HTTPBasic
from .http import HTTPBasicCredentials as HTTPBasicCredentials
from .http import HTTPBearer as HTTPBearer
from .http import HTTPDigest as HTTPDigest
from .oauth2 import OAuth2 as OAuth2
from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
from .oauth2 import SecurityScopes as SecurityScopes
from .open_id_connect_url import OpenIdConnect as OpenIdConnect

View file

@ -0,0 +1,320 @@
from typing import Annotated
from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
class APIKeyBase(SecurityBase):
model: APIKey
def __init__(
self,
location: APIKeyIn,
name: str,
description: str | None,
scheme_name: str | None,
auto_error: bool,
):
self.auto_error = auto_error
self.model: APIKey = APIKey(
**{"in": location}, # ty: ignore[invalid-argument-type]
name=name,
description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
def make_not_authenticated_error(self) -> HTTPException:
"""
The WWW-Authenticate header is not standardized for API Key authentication but
the HTTP specification requires that an error of 401 "Unauthorized" must
include a WWW-Authenticate header.
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
For this, this method sends a custom challenge `APIKey`.
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "APIKey"},
)
def check_api_key(self, api_key: str | None) -> str | None:
if not api_key:
if self.auto_error:
raise self.make_not_authenticated_error()
return None
return api_key
class APIKeyQuery(APIKeyBase):
"""
API key authentication using a query parameter.
This defines the name of the query parameter that should be provided in the request
with the API key and integrates that into the OpenAPI documentation. It extracts
the key value sent in the query parameter automatically and provides it as the
dependency result. But it doesn't define how to send that API key to the client.
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be a string containing the key value.
## Example
```python
from fastapi import Depends, FastAPI
from fastapi.security import APIKeyQuery
app = FastAPI()
query_scheme = APIKeyQuery(name="api_key")
@app.get("/items/")
async def read_items(api_key: str = Depends(query_scheme)):
return {"api_key": api_key}
```
"""
def __init__(
self,
*,
name: Annotated[
str,
Doc("Query parameter name."),
],
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the query parameter is not provided, `APIKeyQuery` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the query parameter is not
available, instead of erroring out, the dependency result will be
`None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in a query
parameter or in an HTTP Bearer token).
"""
),
] = True,
):
super().__init__(
location=APIKeyIn.query,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key)
class APIKeyHeader(APIKeyBase):
"""
API key authentication using a header.
This defines the name of the header that should be provided in the request with
the API key and integrates that into the OpenAPI documentation. It extracts
the key value sent in the header automatically and provides it as the dependency
result. But it doesn't define how to send that key to the client.
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be a string containing the key value.
## Example
```python
from fastapi import Depends, FastAPI
from fastapi.security import APIKeyHeader
app = FastAPI()
header_scheme = APIKeyHeader(name="x-key")
@app.get("/items/")
async def read_items(key: str = Depends(header_scheme)):
return {"key": key}
```
"""
def __init__(
self,
*,
name: Annotated[str, Doc("Header name.")],
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the header is not provided, `APIKeyHeader` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the header is not available,
instead of erroring out, the dependency result will be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in a header or
in an HTTP Bearer token).
"""
),
] = True,
):
super().__init__(
location=APIKeyIn.header,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key)
class APIKeyCookie(APIKeyBase):
"""
API key authentication using a cookie.
This defines the name of the cookie that should be provided in the request with
the API key and integrates that into the OpenAPI documentation. It extracts
the key value sent in the cookie automatically and provides it as the dependency
result. But it doesn't define how to set that cookie.
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be a string containing the key value.
## Example
```python
from fastapi import Depends, FastAPI
from fastapi.security import APIKeyCookie
app = FastAPI()
cookie_scheme = APIKeyCookie(name="session")
@app.get("/items/")
async def read_items(session: str = Depends(cookie_scheme)):
return {"session": session}
```
"""
def __init__(
self,
*,
name: Annotated[str, Doc("Cookie name.")],
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the cookie is not provided, `APIKeyCookie` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the cookie is not available,
instead of erroring out, the dependency result will be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in a cookie or
in an HTTP Bearer token).
"""
),
] = True,
):
super().__init__(
location=APIKeyIn.cookie,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key)

View file

@ -0,0 +1,6 @@
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
class SecurityBase:
model: SecurityBaseModel
scheme_name: str

View file

@ -0,0 +1,417 @@
import binascii
from base64 import b64decode
from typing import Annotated
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
class HTTPBasicCredentials(BaseModel):
"""
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
dependency.
Read more about it in the
[FastAPI docs for HTTP Basic Auth](https://fastapi.tiangolo.com/advanced/security/http-basic-auth/).
"""
username: Annotated[str, Doc("The HTTP Basic username.")]
password: Annotated[str, Doc("The HTTP Basic password.")]
class HTTPAuthorizationCredentials(BaseModel):
"""
The HTTP authorization credentials in the result of using `HTTPBearer` or
`HTTPDigest` in a dependency.
The HTTP authorization header value is split by the first space.
The first part is the `scheme`, the second part is the `credentials`.
For example, in an HTTP Bearer token scheme, the client will send a header
like:
```
Authorization: Bearer deadbeef12346
```
In this case:
* `scheme` will have the value `"Bearer"`
* `credentials` will have the value `"deadbeef12346"`
"""
scheme: Annotated[
str,
Doc(
"""
The HTTP authorization scheme extracted from the header value.
"""
),
]
credentials: Annotated[
str,
Doc(
"""
The HTTP authorization credentials extracted from the header value.
"""
),
]
class HTTPBase(SecurityBase):
model: HTTPBaseModel
def __init__(
self,
*,
scheme: str,
scheme_name: str | None = None,
description: str | None = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme=scheme, description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_authenticate_headers(self) -> dict[str, str]:
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=self.make_authenticate_headers(),
)
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPBasic(HTTPBase):
"""
HTTP Basic authentication.
Ref: https://datatracker.ietf.org/doc/html/rfc7617
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be an `HTTPBasicCredentials` object containing the
`username` and the `password`.
Read more about it in the
[FastAPI docs for HTTP Basic Auth](https://fastapi.tiangolo.com/advanced/security/http-basic-auth/).
## Example
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi.security import HTTPBasic, HTTPBasicCredentials
app = FastAPI()
security = HTTPBasic()
@app.get("/users/me")
def read_current_user(credentials: Annotated[HTTPBasicCredentials, Depends(security)]):
return {"username": credentials.username, "password": credentials.password}
```
"""
def __init__(
self,
*,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
realm: Annotated[
str | None,
Doc(
"""
HTTP Basic authentication realm.
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the HTTP Basic authentication is not provided (a
header), `HTTPBasic` will automatically cancel the request and send the
client an error.
If `auto_error` is set to `False`, when the HTTP Basic authentication
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in HTTP Basic
authentication or in an HTTP Bearer token).
"""
),
] = True,
):
self.model = HTTPBaseModel(scheme="basic", description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
self.auto_error = auto_error
def make_authenticate_headers(self) -> dict[str, str]:
if self.realm:
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
self, request: Request
) -> HTTPBasicCredentials | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
raise self.make_not_authenticated_error() from e
username, separator, password = data.partition(":")
if not separator:
raise self.make_not_authenticated_error()
return HTTPBasicCredentials(username=username, password=password)
class HTTPBearer(HTTPBase):
"""
HTTP Bearer token authentication.
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be an `HTTPAuthorizationCredentials` object containing
the `scheme` and the `credentials`.
## Example
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
app = FastAPI()
security = HTTPBearer()
@app.get("/users/me")
def read_current_user(
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]
):
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
```
"""
def __init__(
self,
*,
bearerFormat: Annotated[str | None, Doc("Bearer token format.")] = None,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the HTTP Bearer token is not provided (in an
`Authorization` header), `HTTPBearer` will automatically cancel the
request and send the client an error.
If `auto_error` is set to `False`, when the HTTP Bearer token
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in an HTTP
Bearer token or in a cookie).
"""
),
] = True,
):
self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
class HTTPDigest(HTTPBase):
"""
HTTP Digest authentication.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full Digest scheme, you would need to subclass it
and implement it in your code.
Ref: https://datatracker.ietf.org/doc/html/rfc7616
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
The dependency result will be an `HTTPAuthorizationCredentials` object containing
the `scheme` and the `credentials`.
## Example
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
app = FastAPI()
security = HTTPDigest()
@app.get("/users/me")
def read_current_user(
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]
):
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
```
"""
def __init__(
self,
*,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if the HTTP Digest is not provided, `HTTPDigest` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the HTTP Digest is not
available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, in HTTP
Digest or in a cookie).
"""
),
] = True,
):
self.model = HTTPBaseModel(scheme="digest", description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "digest":
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View file

@ -0,0 +1,693 @@
from typing import Annotated, Any, cast
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
class OAuth2PasswordRequestForm:
"""
This is a dependency class to collect the `username` and `password` as form data
for an OAuth2 password flow.
The OAuth2 specification dictates that for a password flow the data should be
collected using form data (instead of JSON) and that it should have the specific
fields `username` and `password`.
All the initialization parameters are extracted from the request.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
## Example
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordRequestForm
app = FastAPI()
@app.post("/login")
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
data = {}
data["scopes"] = []
for scope in form_data.scopes:
data["scopes"].append(scope)
if form_data.client_id:
data["client_id"] = form_data.client_id
if form_data.client_secret:
data["client_secret"] = form_data.client_secret
return data
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permissions, you could do it as well in your application, just
know that it is application specific, it's not part of the specification.
"""
def __init__(
self,
*,
grant_type: Annotated[
str | None,
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
"password". Nevertheless, this dependency class is permissive and
allows not passing it. If you want to enforce it, use instead the
`OAuth2PasswordRequestFormStrict` dependency.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
] = None,
username: Annotated[
str,
Form(),
Doc(
"""
`username` string. The OAuth2 spec requires the exact field name
`username`.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
password: Annotated[
str,
Form(json_schema_extra={"format": "password"}),
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password`.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
scope: Annotated[
str,
Form(),
Doc(
"""
A single string with actually several scopes separated by spaces. Each
scope is also a string.
For example, a single string with:
```python
"items:read items:write users:read profile openid"
````
would represent the scopes:
* `items:read`
* `items:write`
* `users:read`
* `profile`
* `openid`
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
] = "",
client_id: Annotated[
str | None,
Form(),
Doc(
"""
If there's a `client_id`, it can be sent as part of the form fields.
But the OAuth2 specification recommends sending the `client_id` and
`client_secret` (if any) using HTTP Basic auth.
"""
),
] = None,
client_secret: Annotated[
str | None,
Form(json_schema_extra={"format": "password"}),
Doc(
"""
If there's a `client_password` (and a `client_id`), they can be sent
as part of the form fields. But the OAuth2 specification recommends
sending the `client_id` and `client_secret` (if any) using HTTP Basic
auth.
"""
),
] = None,
):
self.grant_type = grant_type
self.username = username
self.password = password
self.scopes = scope.split()
self.client_id = client_id
self.client_secret = client_secret
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
"""
This is a dependency class to collect the `username` and `password` as form data
for an OAuth2 password flow.
The OAuth2 specification dictates that for a password flow the data should be
collected using form data (instead of JSON) and that it should have the specific
fields `username` and `password`.
All the initialization parameters are extracted from the request.
The only difference between `OAuth2PasswordRequestFormStrict` and
`OAuth2PasswordRequestForm` is that `OAuth2PasswordRequestFormStrict` requires the
client to send the form field `grant_type` with the value `"password"`, which
is required in the OAuth2 specification (it seems that for no particular reason),
while for `OAuth2PasswordRequestForm` `grant_type` is optional.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
## Example
```python
from typing import Annotated
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordRequestForm
app = FastAPI()
@app.post("/login")
def login(form_data: Annotated[OAuth2PasswordRequestFormStrict, Depends()]):
data = {}
data["scopes"] = []
for scope in form_data.scopes:
data["scopes"].append(scope)
if form_data.client_id:
data["client_id"] = form_data.client_id
if form_data.client_secret:
data["client_secret"] = form_data.client_secret
return data
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permissions, you could do it as well in your application, just
know that it is application specific, it's not part of the specification.
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
This dependency is strict about it. If you want to be permissive, use instead the
OAuth2PasswordRequestForm dependency class.
username: username string. The OAuth2 spec requires the exact field name "username".
password: password string. The OAuth2 spec requires the exact field name "password".
scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
"items:read items:write users:read profile openid"
client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
"""
def __init__(
self,
grant_type: Annotated[
str,
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
"password". This dependency is strict about it. If you want to be
permissive, use instead the `OAuth2PasswordRequestForm` dependency
class.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
username: Annotated[
str,
Form(),
Doc(
"""
`username` string. The OAuth2 spec requires the exact field name
`username`.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
password: Annotated[
str,
Form(),
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password`.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
scope: Annotated[
str,
Form(),
Doc(
"""
A single string with actually several scopes separated by spaces. Each
scope is also a string.
For example, a single string with:
```python
"items:read items:write users:read profile openid"
````
would represent the scopes:
* `items:read`
* `items:write`
* `users:read`
* `profile`
* `openid`
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
] = "",
client_id: Annotated[
str | None,
Form(),
Doc(
"""
If there's a `client_id`, it can be sent as part of the form fields.
But the OAuth2 specification recommends sending the `client_id` and
`client_secret` (if any) using HTTP Basic auth.
"""
),
] = None,
client_secret: Annotated[
str | None,
Form(),
Doc(
"""
If there's a `client_password` (and a `client_id`), they can be sent
as part of the form fields. But the OAuth2 specification recommends
sending the `client_id` and `client_secret` (if any) using HTTP Basic
auth.
"""
),
] = None,
):
super().__init__(
grant_type=grant_type,
username=username,
password=password,
scope=scope,
client_id=client_id,
client_secret=client_secret,
)
class OAuth2(SecurityBase):
"""
This is the base class for OAuth2 authentication, an instance of it would be used
as a dependency. All other OAuth2 classes inherit from it and customize it for
each OAuth2 flow.
You normally would not create a new class inheriting from it but use one of the
existing subclasses, and maybe compose them if you want to support multiple flows.
Read more about it in the
[FastAPI docs for Security](https://fastapi.tiangolo.com/tutorial/security/).
"""
def __init__(
self,
*,
flows: Annotated[
OAuthFlowsModel | dict[str, dict[str, Any]],
Doc(
"""
The dictionary of OAuth2 flows.
"""
),
] = OAuthFlowsModel(),
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OAuth2
or in a cookie).
"""
),
] = True,
):
self.model = OAuth2Model(
flows=cast(OAuthFlowsModel, flows), description=description
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
"""
The OAuth 2 specification doesn't define the challenge that should be used,
because a `Bearer` token is not really the only option to authenticate.
But declaring any other authentication challenge would be application-specific
as it's not defined in the specification.
For practical reasons, this method uses the `Bearer` challenge by default, as
it's probably the most common one.
If you are implementing an OAuth2 authentication scheme other than the provided
ones in FastAPI (based on bearer tokens), you might want to override this.
Ref: https://datatracker.ietf.org/doc/html/rfc6749
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return authorization
class OAuth2PasswordBearer(OAuth2):
"""
OAuth2 flow for authentication using a bearer token obtained with a password.
An instance of it would be used as a dependency.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
def __init__(
self,
tokenUrl: Annotated[
str,
Doc(
"""
The URL to obtain the OAuth2 token. This would be the *path operation*
that has `OAuth2PasswordRequestForm` as a dependency.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
],
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
scopes: Annotated[
dict[str, str] | None,
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
use this dependency.
Read more about it in the
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OAuth2
or in a cookie).
"""
),
] = True,
refreshUrl: Annotated[
str | None,
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
password=cast(
Any,
{
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return param
class OAuth2AuthorizationCodeBearer(OAuth2):
"""
OAuth2 flow for authentication using a bearer token obtained with an OAuth2 code
flow. An instance of it would be used as a dependency.
"""
def __init__(
self,
authorizationUrl: str,
tokenUrl: Annotated[
str,
Doc(
"""
The URL to obtain the OAuth2 token.
"""
),
],
refreshUrl: Annotated[
str | None,
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
scopes: Annotated[
dict[str, str] | None,
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
use this dependency.
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OAuth2
or in a cookie).
"""
),
] = True,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
authorizationCode=cast(
Any,
{
"authorizationUrl": authorizationUrl,
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None # pragma: nocover
return param
class SecurityScopes:
"""
This is a special class that you can define in a parameter in a dependency to
obtain the OAuth2 scopes required by all the dependencies in the same chain.
This way, multiple dependencies can have different scopes, even when used in the
same *path operation*. And with this, you can access all the scopes required in
all those dependencies in a single place.
Read more about it in the
[FastAPI docs for OAuth2 scopes](https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/).
"""
def __init__(
self,
scopes: Annotated[
list[str] | None,
Doc(
"""
This will be filled by FastAPI.
"""
),
] = None,
):
self.scopes: Annotated[
list[str],
Doc(
"""
The list of all the scopes required by dependencies.
"""
),
] = scopes or []
self.scope_str: Annotated[
str,
Doc(
"""
All the scopes required by all the dependencies in a single string
separated by spaces, as defined in the OAuth2 specification.
"""
),
] = " ".join(self.scopes)

View file

@ -0,0 +1,94 @@
from typing import Annotated
from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
class OpenIdConnect(SecurityBase):
"""
OpenID Connect authentication class. An instance of it would be used as a
dependency.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
the OpenIDConnect URL. You would need to subclass it and implement it in your
code.
"""
def __init__(
self,
*,
openIdConnectUrl: Annotated[
str,
Doc(
"""
The OpenID Connect URL.
"""
),
],
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OpenID Connect authentication, it will automatically cancel the request
and send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OpenID
Connect or in a cookie).
"""
),
] = True,
):
self.model = OpenIdConnectModel(
openIdConnectUrl=openIdConnectUrl, description=description
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return authorization

View file

@ -0,0 +1,7 @@
def get_authorization_scheme_param(
authorization_header_value: str | None,
) -> tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param.strip()

View file

@ -0,0 +1,222 @@
from typing import Annotated, Any
from annotated_doc import Doc
from pydantic import AfterValidator, BaseModel, Field, model_validator
from starlette.responses import StreamingResponse
# Canonical SSE event schema matching the OpenAPI 3.2 spec
# (Section 4.14.4 "Special Considerations for Server-Sent Events")
_SSE_EVENT_SCHEMA: dict[str, Any] = {
"type": "object",
"properties": {
"data": {"type": "string"},
"event": {"type": "string"},
"id": {"type": "string"},
"retry": {"type": "integer", "minimum": 0},
},
}
class EventSourceResponse(StreamingResponse):
"""Streaming response with `text/event-stream` media type.
Use as `response_class=EventSourceResponse` on a *path operation* that uses `yield`
to enable Server Sent Events (SSE) responses.
Works with **any HTTP method** (`GET`, `POST`, etc.), which makes it compatible
with protocols like MCP that stream SSE over `POST`.
The actual encoding logic lives in the FastAPI routing layer. This class
serves mainly as a marker and sets the correct `Content-Type`.
"""
media_type = "text/event-stream"
def _check_id_no_null(v: str | None) -> str | None:
if v is not None and "\0" in v:
raise ValueError("SSE 'id' must not contain null characters")
return v
class ServerSentEvent(BaseModel):
"""Represents a single Server-Sent Event.
When `yield`ed from a *path operation function* that uses
`response_class=EventSourceResponse`, each `ServerSentEvent` is encoded
into the [SSE wire format](https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream)
(`text/event-stream`).
If you yield a plain object (dict, Pydantic model, etc.) instead, it is
automatically JSON-encoded and sent as the `data:` field.
All `data` values **including plain strings** are JSON-serialized.
For example, `data="hello"` produces `data: "hello"` on the wire (with
quotes).
"""
data: Annotated[
Any,
Doc(
"""
The event payload.
Can be any JSON-serializable value: a Pydantic model, dict, list,
string, number, etc. It is **always** serialized to JSON: strings
are quoted (`"hello"` becomes `data: "hello"` on the wire).
Mutually exclusive with `raw_data`.
"""
),
] = None
raw_data: Annotated[
str | None,
Doc(
"""
Raw string to send as the `data:` field **without** JSON encoding.
Use this when you need to send pre-formatted text, HTML fragments,
CSV lines, or any non-JSON payload. The string is placed directly
into the `data:` field as-is.
Mutually exclusive with `data`.
"""
),
] = None
event: Annotated[
str | None,
Doc(
"""
Optional event type name.
Maps to `addEventListener(event, ...)` on the browser. When omitted,
the browser dispatches on the generic `message` event.
"""
),
] = None
id: Annotated[
str | None,
AfterValidator(_check_id_no_null),
Doc(
"""
Optional event ID.
The browser sends this value back as the `Last-Event-ID` header on
automatic reconnection. **Must not contain null (`\\0`) characters.**
"""
),
] = None
retry: Annotated[
int | None,
Field(ge=0),
Doc(
"""
Optional reconnection time in **milliseconds**.
Tells the browser how long to wait before reconnecting after the
connection is lost. Must be a non-negative integer.
"""
),
] = None
comment: Annotated[
str | None,
Doc(
"""
Optional comment line(s).
Comment lines start with `:` in the SSE wire format and are ignored by
`EventSource` clients. Useful for keep-alive pings to prevent
proxy/load-balancer timeouts.
"""
),
] = None
@model_validator(mode="after")
def _check_data_exclusive(self) -> "ServerSentEvent":
if self.data is not None and self.raw_data is not None:
raise ValueError(
"Cannot set both 'data' and 'raw_data' on the same "
"ServerSentEvent. Use 'data' for JSON-serialized payloads "
"or 'raw_data' for pre-formatted strings."
)
return self
def format_sse_event(
*,
data_str: Annotated[
str | None,
Doc(
"""
Pre-serialized data string to use as the `data:` field.
"""
),
] = None,
event: Annotated[
str | None,
Doc(
"""
Optional event type name (`event:` field).
"""
),
] = None,
id: Annotated[
str | None,
Doc(
"""
Optional event ID (`id:` field).
"""
),
] = None,
retry: Annotated[
int | None,
Doc(
"""
Optional reconnection time in milliseconds (`retry:` field).
"""
),
] = None,
comment: Annotated[
str | None,
Doc(
"""
Optional comment line(s) (`:` prefix).
"""
),
] = None,
) -> bytes:
"""Build SSE wire-format bytes from **pre-serialized** data.
The result always ends with `\n\n` (the event terminator).
"""
lines: list[str] = []
if comment is not None:
for line in comment.splitlines():
lines.append(f": {line}")
if event is not None:
lines.append(f"event: {event}")
if data_str is not None:
for line in data_str.splitlines():
lines.append(f"data: {line}")
if id is not None:
lines.append(f"id: {id}")
if retry is not None:
lines.append(f"retry: {retry}")
lines.append("")
lines.append("")
return "\n".join(lines).encode("utf-8")
# Keep-alive comment, per the SSE spec recommendation
KEEPALIVE_COMMENT = b": ping\n\n"
# Seconds between keep-alive pings when a generator is idle.
# Private but importable so tests can monkeypatch it.
_PING_INTERVAL: float = 15.0

View file

@ -0,0 +1 @@
from starlette.staticfiles import StaticFiles as StaticFiles # noqa

View file

@ -0,0 +1 @@
from starlette.templating import Jinja2Templates as Jinja2Templates # noqa

View file

@ -0,0 +1 @@
from starlette.testclient import TestClient as TestClient # noqa

View file

@ -0,0 +1,12 @@
import types
from collections.abc import Callable
from enum import Enum
from typing import Any, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import IncEx as IncEx
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = dict[type[BaseModel] | type[Enum], str]
DependencyCacheKey = tuple[Callable[..., Any] | None, tuple[str, ...], str]

View file

@ -0,0 +1,136 @@
import re
import warnings
from typing import (
TYPE_CHECKING,
Any,
Literal,
)
import fastapi
from fastapi._compat import (
ModelField,
PydanticSchemaGenerationError,
Undefined,
annotation_is_pydantic_v1,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.exceptions import FastAPIDeprecationWarning, PydanticV1NotSupportedError
from pydantic.fields import FieldInfo
from ._compat import v2
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
def is_body_allowed_for_status_code(status_code: int | str | None) -> bool:
if status_code is None:
return True
# Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
if status_code in {
"default",
"1XX",
"2XX",
"3XX",
"4XX",
"5XX",
}:
return True
current_status_code = int(status_code)
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
def get_path_param_names(path: str) -> set[str]:
return set(re.findall("{(.*?)}", path))
_invalid_args_message = (
"Invalid args for response field! Hint: "
"check that {type_} is a valid Pydantic field type. "
"If you are using a return type annotation that is not a valid Pydantic "
"field (e.g. Union[Response, dict, None]) you can disable generating the "
"response model from the type annotation with the path operation decorator "
"parameter response_model=None. Read more: "
"https://fastapi.tiangolo.com/tutorial/response-model/"
)
def create_model_field(
name: str,
type_: Any,
default: Any | None = Undefined,
field_info: FieldInfo | None = None,
alias: str | None = None,
mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
if annotation_is_pydantic_v1(type_):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" Please update the response model {type_!r}."
)
field_info = field_info or FieldInfo(annotation=type_, default=default, alias=alias)
try:
return v2.ModelField(mode=mode, name=name, field_info=field_info)
except PydanticSchemaGenerationError:
raise fastapi.exceptions.FastAPIError(
_invalid_args_message.format(type_=type_)
) from None
def generate_operation_id_for_path(
*, name: str, path: str, method: str
) -> str: # pragma: nocover
warnings.warn(
message="fastapi.utils.generate_operation_id_for_path() was deprecated, "
"it is not used internally, and will be removed soon",
category=FastAPIDeprecationWarning,
stacklevel=2,
)
operation_id = f"{name}{path}"
operation_id = re.sub(r"\W", "_", operation_id)
operation_id = f"{operation_id}_{method.lower()}"
return operation_id
def generate_unique_id(route: "APIRoute") -> str:
operation_id = f"{route.name}{route.path_format}"
operation_id = re.sub(r"\W", "_", operation_id)
assert route.methods
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
return operation_id
def deep_dict_update(main_dict: dict[Any, Any], update_dict: dict[Any, Any]) -> None:
for key, value in update_dict.items():
if (
key in main_dict
and isinstance(main_dict[key], dict)
and isinstance(value, dict)
):
deep_dict_update(main_dict[key], value)
elif (
key in main_dict
and isinstance(main_dict[key], list)
and isinstance(update_dict[key], list)
):
main_dict[key] = main_dict[key] + update_dict[key]
else:
main_dict[key] = value
def get_value_or_default(
first_item: DefaultPlaceholder | DefaultType,
*extra_items: DefaultPlaceholder | DefaultType,
) -> DefaultPlaceholder | DefaultType:
"""
Pass items or `DefaultPlaceholder`s by descending priority.
The first one to _not_ be a `DefaultPlaceholder` will be returned.
Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
"""
items = (first_item,) + extra_items
for item in items:
if not isinstance(item, DefaultPlaceholder):
return item
return first_item

Some files were not shown because too many files have changed in this diff Show more