Make SubFox production-ready with parallel translation and UI controls

This commit is contained in:
Eddie Nielsen 2026-03-25 11:24:54 +00:00
parent c40b8bed2b
commit 2b1d05f02c
6046 changed files with 798327 additions and 0 deletions

View file

@ -0,0 +1,396 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import os as _os
import typing as _t
from typing_extensions import override
from . import types
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from ._exceptions import (
APIError,
OpenAIError,
ConflictError,
NotFoundError,
APIStatusError,
RateLimitError,
APITimeoutError,
BadRequestError,
APIConnectionError,
AuthenticationError,
InternalServerError,
PermissionDeniedError,
LengthFinishReasonError,
UnprocessableEntityError,
APIResponseValidationError,
InvalidWebhookSignatureError,
ContentFilterFinishReasonError,
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent
__all__ = [
"types",
"__version__",
"__title__",
"NoneType",
"Transport",
"ProxiesTypes",
"NotGiven",
"NOT_GIVEN",
"not_given",
"Omit",
"omit",
"OpenAIError",
"APIError",
"APIStatusError",
"APITimeoutError",
"APIConnectionError",
"APIResponseValidationError",
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"LengthFinishReasonError",
"ContentFilterFinishReasonError",
"InvalidWebhookSignatureError",
"Timeout",
"RequestOptions",
"Client",
"AsyncClient",
"Stream",
"AsyncStream",
"OpenAI",
"AsyncOpenAI",
"file_from_path",
"BaseModel",
"DEFAULT_TIMEOUT",
"DEFAULT_MAX_RETRIES",
"DEFAULT_CONNECTION_LIMITS",
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
]
if not _t.TYPE_CHECKING:
from ._utils._resources_proxy import resources as resources
from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
from .version import VERSION as VERSION
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
from .lib._old_api import *
from .lib.streaming import (
AssistantEventHandler as AssistantEventHandler,
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
)
_setup_logging()
# Update the __module__ attribute for exported symbols so that
# error messages point to this module instead of the module
# it was originally defined in, e.g.
# openai._exceptions.NotFoundError -> openai.NotFoundError
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
try:
__locals[__name].__module__ = "openai"
except (TypeError, AttributeError):
# Some of our exported symbols are builtins which we can't set attributes for.
pass
# ------ Module level client ------
import typing as _t
import typing_extensions as _te
import httpx as _httpx
from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
api_key: str | None = None
organization: str | None = None
project: str | None = None
webhook_secret: str | None = None
base_url: str | _httpx.URL | None = None
timeout: float | Timeout | None = DEFAULT_TIMEOUT
max_retries: int = DEFAULT_MAX_RETRIES
default_headers: _t.Mapping[str, str] | None = None
default_query: _t.Mapping[str, object] | None = None
http_client: _httpx.Client | None = None
_ApiType = _te.Literal["openai", "azure"]
api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))
api_version: str | None = _os.environ.get("OPENAI_API_VERSION")
azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")
azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
class _ModuleClient(OpenAI):
# Note: we have to use type: ignores here as overriding class members
# with properties is technically unsafe but it is fine for our use case
@property # type: ignore
@override
def api_key(self) -> str | None:
return api_key
@api_key.setter # type: ignore
def api_key(self, value: str | None) -> None: # type: ignore
global api_key
api_key = value
@property # type: ignore
@override
def organization(self) -> str | None:
return organization
@organization.setter # type: ignore
def organization(self, value: str | None) -> None: # type: ignore
global organization
organization = value
@property # type: ignore
@override
def project(self) -> str | None:
return project
@project.setter # type: ignore
def project(self, value: str | None) -> None: # type: ignore
global project
project = value
@property # type: ignore
@override
def webhook_secret(self) -> str | None:
return webhook_secret
@webhook_secret.setter # type: ignore
def webhook_secret(self, value: str | None) -> None: # type: ignore
global webhook_secret
webhook_secret = value
@property
@override
def base_url(self) -> _httpx.URL:
if base_url is not None:
return _httpx.URL(base_url)
return super().base_url
@base_url.setter
def base_url(self, url: _httpx.URL | str) -> None:
super().base_url = url # type: ignore[misc]
@property # type: ignore
@override
def timeout(self) -> float | Timeout | None:
return timeout
@timeout.setter # type: ignore
def timeout(self, value: float | Timeout | None) -> None: # type: ignore
global timeout
timeout = value
@property # type: ignore
@override
def max_retries(self) -> int:
return max_retries
@max_retries.setter # type: ignore
def max_retries(self, value: int) -> None: # type: ignore
global max_retries
max_retries = value
@property # type: ignore
@override
def _custom_headers(self) -> _t.Mapping[str, str] | None:
return default_headers
@_custom_headers.setter # type: ignore
def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore
global default_headers
default_headers = value
@property # type: ignore
@override
def _custom_query(self) -> _t.Mapping[str, object] | None:
return default_query
@_custom_query.setter # type: ignore
def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore
global default_query
default_query = value
@property # type: ignore
@override
def _client(self) -> _httpx.Client:
return http_client or super()._client
@_client.setter # type: ignore
def _client(self, value: _httpx.Client) -> None: # type: ignore
global http_client
http_client = value
class _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore
...
class _AmbiguousModuleClientUsageError(OpenAIError):
def __init__(self) -> None:
super().__init__(
"Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`"
)
def _has_openai_credentials() -> bool:
return _os.environ.get("OPENAI_API_KEY") is not None
def _has_azure_credentials() -> bool:
return azure_endpoint is not None or _os.environ.get("AZURE_OPENAI_API_KEY") is not None
def _has_azure_ad_credentials() -> bool:
return (
_os.environ.get("AZURE_OPENAI_AD_TOKEN") is not None
or azure_ad_token is not None
or azure_ad_token_provider is not None
)
_client: OpenAI | None = None
def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
global _client
if _client is None:
global api_type, azure_endpoint, azure_ad_token, api_version
if azure_endpoint is None:
azure_endpoint = _os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_ad_token is None:
azure_ad_token = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_version is None:
api_version = _os.environ.get("OPENAI_API_VERSION")
if api_type is None:
has_openai = _has_openai_credentials()
has_azure = _has_azure_credentials()
has_azure_ad = _has_azure_ad_credentials()
if has_openai and (has_azure or has_azure_ad):
raise _AmbiguousModuleClientUsageError()
if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(
"AZURE_OPENAI_API_KEY"
) is not None:
raise _AmbiguousModuleClientUsageError()
if has_azure or has_azure_ad:
api_type = "azure"
else:
api_type = "openai"
if api_type == "azure":
_client = _AzureModuleClient( # type: ignore
api_version=api_version,
azure_endpoint=azure_endpoint,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
)
return _client
_client = _ModuleClient(
api_key=api_key,
organization=organization,
project=project,
webhook_secret=webhook_secret,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
)
return _client
return _client
def _reset_client() -> None: # type: ignore[reportUnusedFunction]
global _client
_client = None
from ._module_client import (
beta as beta,
chat as chat,
audio as audio,
evals as evals,
files as files,
images as images,
models as models,
skills as skills,
videos as videos,
batches as batches,
uploads as uploads,
realtime as realtime,
webhooks as webhooks,
responses as responses,
containers as containers,
embeddings as embeddings,
completions as completions,
fine_tuning as fine_tuning,
moderations as moderations,
conversations as conversations,
vector_stores as vector_stores,
)

View file

@ -0,0 +1,3 @@
from .cli import main
main()

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,238 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self, Literal, TypedDict
import pydantic
from pydantic.fields import FieldInfo
from ._types import IncEx, StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2, v3 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
...
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
...
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
...
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
...
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
...
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
else:
# v1 re-exports
if PYDANTIC_V1:
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
from ._utils import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
parse_date as parse_date,
is_typeddict as is_typeddict,
parse_datetime as parse_datetime,
is_literal_type as is_literal_type,
)
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
else:
if PYDANTIC_V1:
# TODO: provide an error message here?
ConfigDict = None
else:
from pydantic import ConfigDict as ConfigDict
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V1:
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
else:
return model.model_validate(value)
def field_is_required(field: FieldInfo) -> bool:
if PYDANTIC_V1:
return field.required # type: ignore
return field.is_required()
def field_get_default(field: FieldInfo) -> Any:
value = field.get_default()
if PYDANTIC_V1:
return value
from pydantic_core import PydanticUndefined
if value == PydanticUndefined:
return None
return value
def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V1:
return field.outer_type_ # type: ignore
return field.annotation
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V1:
return model.__config__ # type: ignore
return model.model_config
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V1:
return model.__fields__ # type: ignore
return model.model_fields
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
if PYDANTIC_V1:
return model.copy(deep=deep) # type: ignore
return model.model_copy(deep=deep)
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V1:
return model.json(indent=indent) # type: ignore
return model.model_dump_json(indent=indent)
class _ModelDumpKwargs(TypedDict, total=False):
by_alias: bool
def model_dump(
model: pydantic.BaseModel,
*,
exclude: IncEx | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
warnings: bool = True,
mode: Literal["json", "python"] = "python",
by_alias: bool | None = None,
) -> dict[str, Any]:
if (not PYDANTIC_V1) or hasattr(model, "model_dump"):
kwargs: _ModelDumpKwargs = {}
if by_alias is not None:
kwargs["by_alias"] = by_alias
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
# warnings are not supported in Pydantic v1
warnings=True if PYDANTIC_V1 else warnings,
**kwargs,
)
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, by_alias=bool(by_alias)
),
)
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V1:
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
return model.model_validate(data)
def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
if PYDANTIC_V1:
return model.parse_raw(data) # pyright: ignore[reportDeprecated]
return model.model_validate_json(data)
def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
if PYDANTIC_V1:
return model.schema() # pyright: ignore[reportDeprecated]
return model.model_json_schema()
# generic models
if TYPE_CHECKING:
class GenericModel(pydantic.BaseModel): ...
else:
if PYDANTIC_V1:
import pydantic.generics
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
else:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
# cached properties
if TYPE_CHECKING:
cached_property = property
# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.
class typed_cached_property(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None
def __init__(self, func: Callable[[Any], _T]) -> None: ...
@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()
def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
from functools import cached_property as cached_property
typed_cached_property = cached_property

View file

@ -0,0 +1,14 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
# default timeout is 10 minutes
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0)
DEFAULT_MAX_RETRIES = 2
DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=1000, max_keepalive_connections=100)
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0

View file

@ -0,0 +1,161 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, cast
from typing_extensions import Literal
import httpx
from ._utils import is_dict
from ._models import construct_type
if TYPE_CHECKING:
from .types.chat import ChatCompletion
__all__ = [
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"LengthFinishReasonError",
"ContentFilterFinishReasonError",
"InvalidWebhookSignatureError",
]
class OpenAIError(Exception):
pass
class APIError(OpenAIError):
message: str
request: httpx.Request
body: object | None
"""The API response body.
If the API responded with a valid JSON structure then this property will be the
decoded result.
If it isn't a valid JSON structure then this will be the raw response.
If there was no response associated with this error then it will be `None`.
"""
code: Optional[str] = None
param: Optional[str] = None
type: Optional[str]
def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:
super().__init__(message)
self.request = request
self.message = message
self.body = body
if is_dict(body):
self.code = cast(Any, construct_type(type_=Optional[str], value=body.get("code")))
self.param = cast(Any, construct_type(type_=Optional[str], value=body.get("param")))
self.type = cast(Any, construct_type(type_=str, value=body.get("type")))
else:
self.code = None
self.param = None
self.type = None
class APIResponseValidationError(APIError):
response: httpx.Response
status_code: int
def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:
super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body)
self.response = response
self.status_code = response.status_code
class APIStatusError(APIError):
"""Raised when an API response has a status code of 4xx or 5xx."""
response: httpx.Response
status_code: int
request_id: str | None
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
super().__init__(message, response.request, body=body)
self.response = response
self.status_code = response.status_code
self.request_id = response.headers.get("x-request-id")
class APIConnectionError(APIError):
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
super().__init__(message, request, body=None)
class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
class BadRequestError(APIStatusError):
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
class AuthenticationError(APIStatusError):
status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
class PermissionDeniedError(APIStatusError):
status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
class NotFoundError(APIStatusError):
status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
class ConflictError(APIStatusError):
status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
class UnprocessableEntityError(APIStatusError):
status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
class RateLimitError(APIStatusError):
status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
class InternalServerError(APIStatusError):
pass
class LengthFinishReasonError(OpenAIError):
completion: ChatCompletion
"""The completion that caused this error.
Note: this will *not* be a complete `ChatCompletion` object when streaming as `usage`
will not be included.
"""
def __init__(self, *, completion: ChatCompletion) -> None:
msg = "Could not parse response content as the length limit was reached"
if completion.usage:
msg += f" - {completion.usage}"
super().__init__(msg)
self.completion = completion
class ContentFilterFinishReasonError(OpenAIError):
def __init__(self) -> None:
super().__init__(
f"Could not parse response content as the request was rejected by the content filter",
)
class InvalidWebhookSignatureError(ValueError):
"""Raised when a webhook signature is invalid, meaning the computed signature does not match the expected signature."""

View file

@ -0,0 +1,3 @@
from .numpy_proxy import numpy as numpy, has_numpy as has_numpy
from .pandas_proxy import pandas as pandas
from .sounddevice_proxy import sounddevice as sounddevice

View file

@ -0,0 +1,21 @@
from .._exceptions import OpenAIError
INSTRUCTIONS = """
OpenAI error:
missing `{library}`
This feature requires additional dependencies:
$ pip install openai[{extra}]
"""
def format_instructions(*, library: str, extra: str) -> str:
return INSTRUCTIONS.format(library=library, extra=extra)
class MissingDependencyError(OpenAIError):
pass

View file

@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import numpy as numpy
NUMPY_INSTRUCTIONS = format_instructions(library="numpy", extra="voice_helpers")
class NumpyProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import numpy
except ImportError as err:
raise MissingDependencyError(NUMPY_INSTRUCTIONS) from err
return numpy
if not TYPE_CHECKING:
numpy = NumpyProxy()
def has_numpy() -> bool:
try:
import numpy # noqa: F401 # pyright: ignore[reportUnusedImport]
except ImportError:
return False
return True

View file

@ -0,0 +1,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import pandas as pandas
PANDAS_INSTRUCTIONS = format_instructions(library="pandas", extra="datalib")
class PandasProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import pandas
except ImportError as err:
raise MissingDependencyError(PANDAS_INSTRUCTIONS) from err
return pandas
if not TYPE_CHECKING:
pandas = PandasProxy()

View file

@ -0,0 +1,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from ._common import MissingDependencyError, format_instructions
if TYPE_CHECKING:
import sounddevice as sounddevice # type: ignore
SOUNDDEVICE_INSTRUCTIONS = format_instructions(library="sounddevice", extra="voice_helpers")
class SounddeviceProxy(LazyProxy[Any]):
@override
def __load__(self) -> Any:
try:
import sounddevice # type: ignore
except ImportError as err:
raise MissingDependencyError(SOUNDDEVICE_INSTRUCTIONS) from err
return sounddevice
if not TYPE_CHECKING:
sounddevice = SounddeviceProxy()

View file

@ -0,0 +1,123 @@
from __future__ import annotations
import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard
import anyio
from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
)
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads"
) from None
@overload
def to_httpx_files(files: None) -> None: ...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: _transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
return files
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = pathlib.Path(file)
return (path.name, path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
def read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
return file
@overload
async def async_to_httpx_files(files: None) -> None: ...
@overload
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if is_mapping_t(files):
files = {key: await _async_transform_file(file) for key, file in files.items()}
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
return files
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = anyio.Path(file)
return (path.name, await path.read_bytes())
return file
if is_tuple_t(file):
return (file[0], await async_read_file_content(file[1]), *file[2:])
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
async def async_read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return await anyio.Path(file).read_bytes()
return file

View file

@ -0,0 +1,491 @@
from __future__ import annotations
import os
import inspect
import logging
import datetime
import functools
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
import anyio
import httpx
import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type
from ._models import BaseModel, is_basemodel, add_request_id
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
log: logging.Logger = logging.getLogger(__name__)
class LegacyAPIResponse(Generic[R]):
"""This is a legacy class as it will be replaced by `APIResponse`
and `AsyncAPIResponse` in the `_response.py` file in the next major
release.
For the sync client this will mostly be the same with the exception
of `content` & `text` will be methods instead of properties. In the
async client, all methods will be async.
A migration script will be provided & the migration in general should
be smooth.
"""
_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""
def __init__(
self,
*,
raw: httpx.Response,
cast_to: type[R],
client: BaseClient[Any, Any],
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
NOTE: For the async client: this will become a coroutine in the next major version.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
"""Return the binary response content.
NOTE: this will be removed in favour of `.read()` in the
next major version.
"""
return self.http_response.content
@property
def text(self) -> str:
"""Return the decoded response content.
NOTE: this will be turned into a method in the next major version.
"""
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def is_closed(self) -> bool:
return self.http_response.is_closed
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
cast_to = to if to is not None else self._cast_to
# unwrap `TypeAlias('Name', T)` -> `T`
if is_type_alias_type(cast_to):
cast_to = cast_to.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if cast_to and is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)
origin = get_origin(cast_to) or cast_to
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_to=cast_to,
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if cast_to is NoneType:
return cast(R, None)
response = self.http_response
if cast_to == str:
return cast(R, response.text)
if cast_to == int:
return cast(R, int(response.text))
if cast_to == float:
return cast(R, float(response.text))
if cast_to == bool:
return cast(R, response.text.lower() == "true")
if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent):
return cast(R, cast_to(response)) # type: ignore
if origin == LegacyAPIResponse:
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
if inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
if (
inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
)
and not issubclass(origin, BaseModel)
and issubclass(origin, pydantic.BaseModel)
):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if not content_type.endswith("json"):
if is_basemodel(cast_to):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
@override
def __repr__(self) -> str:
return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
)
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], func(*args, **kwargs))
return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[LegacyAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "true"
kwargs["extra_headers"] = extra_headers
return cast(LegacyAPIResponse[R], await func(*args, **kwargs))
return wrapped
class HttpxBinaryResponseContent:
response: httpx.Response
def __init__(self, response: httpx.Response) -> None:
self.response = response
@property
def content(self) -> bytes:
return self.response.content
@property
def text(self) -> str:
return self.response.text
@property
def encoding(self) -> str | None:
return self.response.encoding
@property
def charset_encoding(self) -> str | None:
return self.response.charset_encoding
def json(self, **kwargs: Any) -> Any:
return self.response.json(**kwargs)
def read(self) -> bytes:
return self.response.read()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_bytes(chunk_size)
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
return self.response.iter_text(chunk_size)
def iter_lines(self) -> Iterator[str]:
return self.response.iter_lines()
def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
return self.response.iter_raw(chunk_size)
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
"""
with open(file, mode="wb") as f:
for data in self.response.iter_bytes():
f.write(data)
@deprecated(
"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
)
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
with open(file, mode="wb") as f:
for data in self.response.iter_bytes(chunk_size):
f.write(data)
def close(self) -> None:
return self.response.close()
async def aread(self) -> bytes:
return await self.response.aread()
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_bytes(chunk_size)
async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
return self.response.aiter_text(chunk_size)
async def aiter_lines(self) -> AsyncIterator[str]:
return self.response.aiter_lines()
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
return self.response.aiter_raw(chunk_size)
@deprecated(
"Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
)
async def astream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.response.aiter_bytes(chunk_size):
await f.write(data)
async def aclose(self) -> None:
return await self.response.aclose()

View file

@ -0,0 +1,915 @@
from __future__ import annotations
import os
import inspect
import weakref
from typing import (
IO,
TYPE_CHECKING,
Any,
Type,
Tuple,
Union,
Generic,
TypeVar,
Callable,
Iterable,
Optional,
AsyncIterable,
cast,
)
from datetime import date, datetime
from typing_extensions import (
List,
Unpack,
Literal,
ClassVar,
Protocol,
Required,
Sequence,
ParamSpec,
TypedDict,
TypeGuard,
final,
override,
runtime_checkable,
)
import pydantic
from pydantic.fields import FieldInfo
from ._types import (
Body,
IncEx,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
AnyMapping,
HttpxRequestFiles,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
json_safe,
lru_cache,
is_mapping,
parse_date,
coerce_boolean,
parse_datetime,
strip_not_given,
extract_type_arg,
is_annotated_type,
is_type_alias_type,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V1,
ConfigDict,
GenericModel as BaseGenericModel,
get_args,
is_union,
parse_obj,
get_origin,
is_literal_type,
get_model_config,
get_model_fields,
field_get_default,
)
from ._constants import RAW_RESPONSE_HEADER
if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
__all__ = ["BaseModel", "GenericModel"]
_T = TypeVar("_T")
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
P = ParamSpec("P")
ReprArgs = Sequence[Tuple[Optional[str], Any]]
@runtime_checkable
class _ConfigProtocol(Protocol):
allow_population_by_field_name: bool
class BaseModel(pydantic.BaseModel):
if PYDANTIC_V1:
@property
@override
def model_fields_set(self) -> set[str]:
# a forwards-compat shim for pydantic v2
return self.__fields_set__ # type: ignore
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
@override
def __repr_args__(self) -> ReprArgs:
# we don't want these attributes to be included when something like `rich.print` is used
return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
else:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
)
if TYPE_CHECKING:
_request_id: Optional[str] = None
"""The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
This will **only** be set for the top-level response object, it will not be defined for nested objects. For example:
```py
completion = await client.chat.completions.create(...)
completion._request_id # req_id_xxx
completion.usage._request_id # raises `AttributeError`
```
Note: unlike other properties that use an `_` prefix, this property
*is* public. Unless documented otherwise, all other `_` prefix properties,
methods and modules are *private*.
"""
def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
"""
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
"""
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)
@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
return f"{self.__repr_name__()}({self.__repr_str__(', ')})" # type: ignore[misc]
# Override the 'construct' method in a way that supports recursive parsing without validation.
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
@classmethod
@override
def construct( # pyright: ignore[reportIncompatibleMethodOverride]
__cls: Type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = __cls.__new__(__cls)
fields_values: dict[str, object] = {}
config = get_model_config(__cls)
populate_by_name = (
config.allow_population_by_field_name
if isinstance(config, _ConfigProtocol)
else config.get("populate_by_name")
)
if _fields_set is None:
_fields_set = set()
model_fields = get_model_fields(__cls)
for name, field in model_fields.items():
key = field.alias
if key is None or (key not in values and populate_by_name):
key = name
if key in values:
fields_values[name] = _construct_field(value=values[key], field=field, key=key)
_fields_set.add(name)
else:
fields_values[name] = field_get_default(field)
extra_field_type = _get_extra_fields_type(__cls)
_extra = {}
for key, value in values.items():
if key not in model_fields:
parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
if PYDANTIC_V1:
_fields_set.add(key)
fields_values[key] = parsed
else:
_extra[key] = parsed
object.__setattr__(m, "__dict__", fields_values)
if PYDANTIC_V1:
# init_private_attributes() does not exist in v2
m._init_private_attributes() # type: ignore
# copied from Pydantic v1's `construct()` method
object.__setattr__(m, "__fields_set__", _fields_set)
else:
# these properties are copied from Pydantic's `model_construct()` method
object.__setattr__(m, "__pydantic_private__", None)
object.__setattr__(m, "__pydantic_extra__", _extra)
object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
return m
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
# because the type signatures are technically different
# although not in practice
model_construct = construct
if PYDANTIC_V1:
# we define aliases for some of the new pydantic v2 methods so
# that we can just document these methods without having to specify
# a specific pydantic version as some users may not know which
# pydantic version they are currently using
@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Args:
mode: The mode in which `to_python` should run.
If mode is 'json', the output will only contain JSON serializable types.
If mode is 'python', the output may contain non-JSON-serializable Python objects.
include: A set of fields to include in the output.
exclude: A set of fields to exclude from the output.
context: Additional context to pass to the serializer.
by_alias: Whether to use the field's alias in the dictionary key if defined.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value.
exclude_none: Whether to exclude fields that have a value of `None`.
exclude_computed_fields: Whether to exclude computed fields.
While this can be useful for round-tripping, it is usually recommended to use the dedicated
`round_trip` parameter instead.
round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
fallback: A function to call when an unknown value is encountered. If not provided,
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
Returns:
A dictionary representation of the model.
"""
if mode not in {"json", "python"}:
raise ValueError("mode must be either 'json' or 'python'")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
if exclude_computed_fields != False:
raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped
@override
def model_dump_json(
self,
*,
indent: int | None = None,
ensure_ascii: bool = False,
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Generates a JSON representation of the model using Pydantic's `to_json` method.
Args:
indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
include: Field(s) to include in the JSON output. Can take either a string or set of strings.
exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
by_alias: Whether to serialize using field aliases.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to use serialization/deserialization between JSON and class instance.
warnings: Whether to show any warnings that occurred during serialization.
Returns:
A JSON string representation of the model.
"""
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
raise ValueError("warnings is only supported in Pydantic v2")
if context is not None:
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
if ensure_ascii != False:
raise ValueError("ensure_ascii is only supported in Pydantic v2")
if exclude_computed_fields != False:
raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias if by_alias is not None else False,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
if PYDANTIC_V1:
type_ = cast(type, field.outer_type_) # type: ignore
else:
type_ = field.annotation # type: ignore
if type_ is None:
raise RuntimeError(f"Unexpected field type is None for {key}")
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
if PYDANTIC_V1:
# TODO
return None
schema = cls.__pydantic_core_schema__
if schema["type"] == "model":
fields = schema["schema"]
if fields["type"] == "model-fields":
extras = fields.get("extras_schema")
if extras and "cls" in extras:
# mypy can't narrow the type
return extras["cls"] # type: ignore[no-any-return]
return None
def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
return True
return False
return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
origin = get_origin(type_) or type_
if not inspect.isclass(origin):
return False
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
This is useful for cases where you need to instantiate a `BaseModel`
from an API response as this provides type-safe params which isn't supported
by helpers like `construct_type()`.
```py
build(MyModel, my_field_a="foo", my_field_b=123)
```
"""
if args:
raise TypeError(
"Received positional arguments which are not supported; Keyword arguments must be used instead",
)
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
"""Loose coercion to the expected type with construction of nested values.
Note: the returned value from this function is not guaranteed to match the
given type.
"""
return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
"""Loose coercion to the expected type with construction of nested values.
If the given value does not match the expected type then it is returned as-is.
"""
# store a reference to the original type we were given before we extract any inner
# types so that we can properly resolve forward references in `TypeAliasType` annotations
original_type = None
# we allow `object` as the input type because otherwise, passing things like
# `Literal['value']` will be reported as a type error by type checkers
type_ = cast("type[object]", type_)
if is_type_alias_type(type_):
original_type = type_ # type: ignore[unreachable]
type_ = type_.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if metadata is not None and len(metadata) > 0:
meta: tuple[Any, ...] = tuple(metadata)
elif is_annotated_type(type_):
meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()
# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
origin = get_origin(type_) or type_
args = get_args(type_)
if is_union(origin):
try:
return validate_type(type_=cast("type[object]", original_type or type_), value=value)
except Exception:
pass
# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)
# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
return construct_type(value=value, type_=variant)
except Exception:
continue
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
if origin == dict:
if not is_mapping(value):
return value
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
if (
not is_literal_type(type_)
and inspect.isclass(origin)
and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
if is_mapping(value):
if issubclass(type_, BaseModel):
return type_.construct(**value) # type: ignore[arg-type]
return cast(Any, type_).construct(**value)
if origin == list:
if not is_list(value):
return value
inner_type = args[0] # List[inner_type]
return [construct_type(value=entry, type_=inner_type) for entry in value]
if origin == float:
if isinstance(value, int):
coerced = float(value)
if coerced != value:
return value
return coerced
return value
if type_ == datetime:
try:
return parse_datetime(value) # type: ignore
except Exception:
return value
if type_ == date:
try:
return parse_date(value) # type: ignore
except Exception:
return value
return value
@runtime_checkable
class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
```py
class Foo(BaseModel):
type: Literal['foo']
```
Will result in field_name='type'
"""
field_alias_from: str | None
"""The name of the discriminator field in the API response, e.g.
```py
class Foo(BaseModel):
type: Literal['foo'] = Field(alias='type_from_api')
```
Will result in field_alias_from='type_from_api'
"""
mapping: dict[str, type]
"""Mapping of discriminator value to variant type, e.g.
{'foo': FooVariant, 'bar': BarVariant}
"""
def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
self.field_alias_from = discriminator_alias
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
cached = DISCRIMINATOR_CACHE.get(union)
if cached is not None:
return cached
discriminator_field_name: str | None = None
for annotation in meta_annotations:
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
discriminator_field_name = annotation.discriminator
break
if not discriminator_field_name:
return None
mapping: dict[str, type] = {}
discriminator_alias: str | None = None
for variant in get_args(union):
variant = strip_annotated_type(variant)
if is_basemodel_type(variant):
if PYDANTIC_V1:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
else:
field = _extract_field_schema_pv2(variant, discriminator_field_name)
if not field:
continue
# Note: if one variant defines an alias then they all should
discriminator_alias = field.get("serialization_alias")
field_schema = field["schema"]
if field_schema["type"] == "literal":
for entry in cast("LiteralSchema", field_schema)["expected"]:
if isinstance(entry, str):
mapping[entry] = variant
if not mapping:
return None
details = DiscriminatorDetails(
mapping=mapping,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
DISCRIMINATOR_CACHE.setdefault(union, details)
return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] == "definitions":
schema = schema["schema"]
if schema["type"] != "model":
return None
schema = cast("ModelSchema", schema)
fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None
fields_schema = cast("ModelFieldsSchema", fields_schema)
field = fields_schema["fields"].get(field_name)
if not field:
return None
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
return cast(_T, parse_obj(type_, value))
return cast(_T, _validate_non_model_type(type_=type_, value=value))
def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
"""Add a pydantic config for the given type.
Note: this is a no-op on Pydantic v1.
"""
setattr(typ, "__pydantic_config__", config) # noqa: B010
def add_request_id(obj: BaseModel, request_id: str | None) -> None:
obj._request_id = request_id
# in Pydantic v1, using setattr like we do above causes the attribute
# to be included when serializing the model which we don't want in this
# case so we need to explicitly exclude it
if PYDANTIC_V1:
try:
exclude_fields = obj.__exclude_fields__ # type: ignore
except AttributeError:
cast(Any, obj).__exclude_fields__ = {"_request_id", "__exclude_fields__"}
else:
cast(Any, obj).__exclude_fields__ = {*(exclude_fields or {}), "_request_id", "__exclude_fields__"}
# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
GenericModel = BaseModel
else:
class GenericModel(BaseGenericModel, BaseModel):
pass
if not PYDANTIC_V1:
from pydantic import TypeAdapter as _TypeAdapter
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
if TYPE_CHECKING:
from pydantic import TypeAdapter
else:
TypeAdapter = _CachedTypeAdapter
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
return TypeAdapter(type_).validate_python(value)
elif not TYPE_CHECKING: # TODO: condition is weird
class RootModel(GenericModel, Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
__root__: _T
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
model = _create_pydantic_model(type_).validate(value)
return cast(_T, model.__root__)
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
return RootModel[type_] # type: ignore
class FinalRequestOptionsInput(TypedDict, total=False):
method: Required[str]
url: Required[str]
params: Query
headers: Headers
max_retries: int
timeout: float | Timeout | None
files: HttpxRequestFiles | None
idempotency_key: str
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None]
json_data: Body
extra_json: AnyMapping
follow_redirects: bool
synthesize_event_and_data: bool
@final
class FinalRequestOptions(pydantic.BaseModel):
method: str
url: str
params: Query = {}
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
follow_redirects: Union[bool, None] = None
synthesize_event_and_data: Optional[bool] = None
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None
# It should be noted that we cannot use `json` here as that would override
# a BaseModel method in an incompatible fashion.
json_data: Union[Body, None] = None
extra_json: Union[AnyMapping, None] = None
if PYDANTIC_V1:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
arbitrary_types_allowed: bool = True
else:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries: int) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
def _strip_raw_response_header(self) -> None:
if not is_given(self.headers):
return
if self.headers.get(RAW_RESPONSE_HEADER):
self.headers = {**self.headers}
self.headers.pop(RAW_RESPONSE_HEADER)
# override the `construct` method so that we can run custom transformations.
# this is necessary as we don't want to do any actual runtime type checking
# (which means we can't use validators) but we do want to ensure that `NotGiven`
# values are not present
#
# type ignore required because we're adding explicit types to `**values`
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[FinalRequestOptionsInput],
) -> FinalRequestOptions:
kwargs: dict[str, Any] = {
# we unconditionally call `strip_not_given` on any value
# as it will just ignore any non-mapping types
key: strip_not_given(value)
for key, value in values.items()
}
if PYDANTIC_V1:
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
return super().model_construct(_fields_set, **kwargs)
if not TYPE_CHECKING:
# type checkers incorrectly complain about this assignment
model_construct = construct

View file

@ -0,0 +1,181 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
if TYPE_CHECKING:
from .resources.files import Files
from .resources.images import Images
from .resources.models import Models
from .resources.videos import Videos
from .resources.batches import Batches
from .resources.beta.beta import Beta
from .resources.chat.chat import Chat
from .resources.embeddings import Embeddings
from .resources.audio.audio import Audio
from .resources.completions import Completions
from .resources.evals.evals import Evals
from .resources.moderations import Moderations
from .resources.skills.skills import Skills
from .resources.uploads.uploads import Uploads
from .resources.realtime.realtime import Realtime
from .resources.webhooks.webhooks import Webhooks
from .resources.responses.responses import Responses
from .resources.containers.containers import Containers
from .resources.fine_tuning.fine_tuning import FineTuning
from .resources.conversations.conversations import Conversations
from .resources.vector_stores.vector_stores import VectorStores
from . import _load_client
from ._utils import LazyProxy
class ChatProxy(LazyProxy["Chat"]):
@override
def __load__(self) -> Chat:
return _load_client().chat
class BetaProxy(LazyProxy["Beta"]):
@override
def __load__(self) -> Beta:
return _load_client().beta
class FilesProxy(LazyProxy["Files"]):
@override
def __load__(self) -> Files:
return _load_client().files
class AudioProxy(LazyProxy["Audio"]):
@override
def __load__(self) -> Audio:
return _load_client().audio
class EvalsProxy(LazyProxy["Evals"]):
@override
def __load__(self) -> Evals:
return _load_client().evals
class ImagesProxy(LazyProxy["Images"]):
@override
def __load__(self) -> Images:
return _load_client().images
class ModelsProxy(LazyProxy["Models"]):
@override
def __load__(self) -> Models:
return _load_client().models
class SkillsProxy(LazyProxy["Skills"]):
@override
def __load__(self) -> Skills:
return _load_client().skills
class VideosProxy(LazyProxy["Videos"]):
@override
def __load__(self) -> Videos:
return _load_client().videos
class BatchesProxy(LazyProxy["Batches"]):
@override
def __load__(self) -> Batches:
return _load_client().batches
class UploadsProxy(LazyProxy["Uploads"]):
@override
def __load__(self) -> Uploads:
return _load_client().uploads
class WebhooksProxy(LazyProxy["Webhooks"]):
@override
def __load__(self) -> Webhooks:
return _load_client().webhooks
class RealtimeProxy(LazyProxy["Realtime"]):
@override
def __load__(self) -> Realtime:
return _load_client().realtime
class ResponsesProxy(LazyProxy["Responses"]):
@override
def __load__(self) -> Responses:
return _load_client().responses
class EmbeddingsProxy(LazyProxy["Embeddings"]):
@override
def __load__(self) -> Embeddings:
return _load_client().embeddings
class ContainersProxy(LazyProxy["Containers"]):
@override
def __load__(self) -> Containers:
return _load_client().containers
class CompletionsProxy(LazyProxy["Completions"]):
@override
def __load__(self) -> Completions:
return _load_client().completions
class ModerationsProxy(LazyProxy["Moderations"]):
@override
def __load__(self) -> Moderations:
return _load_client().moderations
class FineTuningProxy(LazyProxy["FineTuning"]):
@override
def __load__(self) -> FineTuning:
return _load_client().fine_tuning
class VectorStoresProxy(LazyProxy["VectorStores"]):
@override
def __load__(self) -> VectorStores:
return _load_client().vector_stores
class ConversationsProxy(LazyProxy["Conversations"]):
@override
def __load__(self) -> Conversations:
return _load_client().conversations
chat: Chat = ChatProxy().__as_proxied__()
beta: Beta = BetaProxy().__as_proxied__()
files: Files = FilesProxy().__as_proxied__()
audio: Audio = AudioProxy().__as_proxied__()
evals: Evals = EvalsProxy().__as_proxied__()
images: Images = ImagesProxy().__as_proxied__()
models: Models = ModelsProxy().__as_proxied__()
skills: Skills = SkillsProxy().__as_proxied__()
videos: Videos = VideosProxy().__as_proxied__()
batches: Batches = BatchesProxy().__as_proxied__()
uploads: Uploads = UploadsProxy().__as_proxied__()
webhooks: Webhooks = WebhooksProxy().__as_proxied__()
realtime: Realtime = RealtimeProxy().__as_proxied__()
responses: Responses = ResponsesProxy().__as_proxied__()
embeddings: Embeddings = EmbeddingsProxy().__as_proxied__()
containers: Containers = ContainersProxy().__as_proxied__()
completions: Completions = CompletionsProxy().__as_proxied__()
moderations: Moderations = ModerationsProxy().__as_proxied__()
fine_tuning: FineTuning = FineTuningProxy().__as_proxied__()
vector_stores: VectorStores = VectorStoresProxy().__as_proxied__()
conversations: Conversations = ConversationsProxy().__as_proxied__()

View file

@ -0,0 +1,150 @@
from __future__ import annotations
from typing import Any, List, Tuple, Union, Mapping, TypeVar
from urllib.parse import parse_qs, urlencode
from typing_extensions import Literal, get_args
from ._types import NotGiven, not_given
from ._utils import flatten
_T = TypeVar("_T")
ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
NestedFormat = Literal["dots", "brackets"]
PrimitiveData = Union[str, int, float, bool, None]
# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
# https://github.com/microsoft/pyright/issues/3555
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
Params = Mapping[str, Data]
class Querystring:
array_format: ArrayFormat
nested_format: NestedFormat
def __init__(
self,
*,
array_format: ArrayFormat = "repeat",
nested_format: NestedFormat = "brackets",
) -> None:
self.array_format = array_format
self.nested_format = nested_format
def parse(self, query: str) -> Mapping[str, object]:
# Note: custom format syntax is not supported yet
return parse_qs(query)
def stringify(
self,
params: Params,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> str:
return urlencode(
self.stringify_items(
params,
array_format=array_format,
nested_format=nested_format,
)
)
def stringify_items(
self,
params: Params,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> list[tuple[str, str]]:
opts = Options(
qs=self,
array_format=array_format,
nested_format=nested_format,
)
return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
def _stringify_item(
self,
key: str,
value: Data,
opts: Options,
) -> list[tuple[str, str]]:
if isinstance(value, Mapping):
items: list[tuple[str, str]] = []
nested_format = opts.nested_format
for subkey, subvalue in value.items():
items.extend(
self._stringify_item(
# TODO: error if unknown format
f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
subvalue,
opts,
)
)
return items
if isinstance(value, (list, tuple)):
array_format = opts.array_format
if array_format == "comma":
return [
(
key,
",".join(self._primitive_value_to_str(item) for item in value if item is not None),
),
]
elif array_format == "repeat":
items = []
for item in value:
items.extend(self._stringify_item(key, item, opts))
return items
elif array_format == "indices":
raise NotImplementedError("The array indices format is not supported yet")
elif array_format == "brackets":
items = []
key = key + "[]"
for item in value:
items.extend(self._stringify_item(key, item, opts))
return items
else:
raise NotImplementedError(
f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
)
serialised = self._primitive_value_to_str(value)
if not serialised:
return []
return [(key, serialised)]
def _primitive_value_to_str(self, value: PrimitiveData) -> str:
# copied from httpx
if value is True:
return "true"
elif value is False:
return "false"
elif value is None:
return ""
return str(value)
_qs = Querystring()
parse = _qs.parse
stringify = _qs.stringify
stringify_items = _qs.stringify_items
class Options:
array_format: ArrayFormat
nested_format: NestedFormat
def __init__(
self,
qs: Querystring = _qs,
*,
array_format: ArrayFormat | NotGiven = not_given,
nested_format: NestedFormat | NotGiven = not_given,
) -> None:
self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format

View file

@ -0,0 +1,43 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import time
from typing import TYPE_CHECKING
import anyio
if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI
class SyncAPIResource:
_client: OpenAI
def __init__(self, client: OpenAI) -> None:
self._client = client
self._get = client.get
self._post = client.post
self._patch = client.patch
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
def _sleep(self, seconds: float) -> None:
time.sleep(seconds)
class AsyncAPIResource:
_client: AsyncOpenAI
def __init__(self, client: AsyncOpenAI) -> None:
self._client = client
self._get = client.get
self._post = client.post
self._patch = client.patch
self._put = client.put
self._delete = client.delete
self._get_api_list = client.get_api_list
async def _sleep(self, seconds: float) -> None:
await anyio.sleep(seconds)

View file

@ -0,0 +1,851 @@
from __future__ import annotations
import os
import inspect
import logging
import datetime
import functools
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, get_origin
import anyio
import httpx
import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel, add_request_id
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import OpenAIError, APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed_by_type: dict[type[Any], Any]
_is_sse_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
http_response: httpx.Response
retries_taken: int
"""The number of retries made. If no retries happened this will be `0`"""
def __init__(
self,
*,
raw: httpx.Response,
cast_to: type[R],
client: BaseClient[Any, Any],
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
options: FinalRequestOptions,
retries_taken: int = 0,
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed_by_type = {}
self._is_sse_stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw
self.retries_taken = retries_taken
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
"""Returns the httpx Request instance associated with the current response."""
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
"""Returns the URL for which the request was made."""
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed
@property
def is_closed(self) -> bool:
"""Whether or not the response body has been closed.
If this is False then there is response data that has not been read yet.
You must either fully consume the response body or call `.close()`
before discarding the response to prevent resource leaks.
"""
return self.http_response.is_closed
@override
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
)
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
cast_to = to if to is not None else self._cast_to
# unwrap `TypeAlias('Name', T)` -> `T`
if is_type_alias_type(cast_to):
cast_to = cast_to.__value__ # type: ignore[unreachable]
# unwrap `Annotated[T, ...]` -> `T`
if cast_to and is_annotated_type(cast_to):
cast_to = extract_type_arg(cast_to, 0)
origin = get_origin(cast_to) or cast_to
if self._is_sse_stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return cast(
R,
stream_cls(
cast_to=cast_to,
response=self.http_response,
client=cast(Any, self._client),
options=self._options,
),
)
if cast_to is NoneType:
return cast(R, None)
response = self.http_response
if cast_to == str:
return cast(R, response.text)
if cast_to == bytes:
return cast(R, response.content)
if cast_to == int:
return cast(R, int(response.text))
if cast_to == float:
return cast(R, float(response.text))
if cast_to == bool:
return cast(R, response.text.lower() == "true")
# handle the legacy binary response case
if inspect.isclass(cast_to) and cast_to.__name__ == "HttpxBinaryResponseContent":
return cast(R, cast_to(response)) # type: ignore
if origin == APIResponse:
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)
if (
inspect.isclass(
origin # pyright: ignore[reportUnknownArgumentType]
)
and not issubclass(origin, BaseModel)
and issubclass(origin, pydantic.BaseModel)
):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
if (
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)
# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type", "*").split(";")
if not content_type.endswith("json"):
if is_basemodel(cast_to):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)
# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
data = response.json()
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
class APIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
def parse(self, *, to: type[_T]) -> _T: ...
@overload
def parse(self) -> R: ...
def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `int`
- `float`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return self.http_response.read()
except httpx.StreamConsumed as exc:
# The default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message.
raise StreamAlreadyConsumed() from exc
def text(self) -> str:
"""Read and decode the response content into a string."""
self.read()
return self.http_response.text
def json(self) -> object:
"""Read and decode the JSON response content."""
self.read()
return self.http_response.json()
def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.http_response.close()
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
for chunk in self.http_response.iter_bytes(chunk_size):
yield chunk
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
for chunk in self.http_response.iter_text(chunk_size):
yield chunk
def iter_lines(self) -> Iterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
for chunk in self.http_response.iter_lines():
yield chunk
class AsyncAPIResponse(BaseAPIResponse[R]):
@property
def request_id(self) -> str | None:
return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
@overload
async def parse(self, *, to: type[_T]) -> _T: ...
@overload
async def parse(self) -> R: ...
async def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `httpx.Response`
"""
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]
if not self._is_sse_stream:
await self.read()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)
if isinstance(parsed, BaseModel):
add_request_id(parsed, self.request_id)
self._parsed_by_type[cache_key] = parsed
return cast(R, parsed)
async def read(self) -> bytes:
"""Read and return the binary response content."""
try:
return await self.http_response.aread()
except httpx.StreamConsumed as exc:
# the default error raised by httpx isn't very
# helpful in our case so we re-raise it with
# a different error message
raise StreamAlreadyConsumed() from exc
async def text(self) -> str:
"""Read and decode the response content into a string."""
await self.read()
return self.http_response.text
async def json(self) -> object:
"""Read and decode the JSON response content."""
await self.read()
return self.http_response.json()
async def close(self) -> None:
"""Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.http_response.aclose()
async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This automatically handles gzip, deflate and brotli encoded responses.
"""
async for chunk in self.http_response.aiter_bytes(chunk_size):
yield chunk
async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
"""A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
async for chunk in self.http_response.aiter_text(chunk_size):
yield chunk
async def iter_lines(self) -> AsyncIterator[str]:
"""Like `iter_text()` but will only yield chunks for each line"""
async for chunk in self.http_response.aiter_lines():
yield chunk
class BinaryAPIResponse(APIResponse[bytes]):
"""Subclass of APIResponse providing helpers for dealing with binary data.
Note: If you want to stream the response data instead of eagerly reading it
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
with open(file, mode="wb") as f:
for data in self.iter_bytes():
f.write(data)
class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
"""Subclass of APIResponse providing helpers for dealing with binary data.
Note: If you want to stream the response data instead of eagerly reading it
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
async def write_to_file(
self,
file: str | os.PathLike[str],
) -> None:
"""Write the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
Note: if you want to stream the data to the file instead of writing
all at once then you should use `.with_streaming_response` when making
the API request, e.g. `.with_streaming_response.get_binary_response()`
"""
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.iter_bytes():
await f.write(data)
class StreamedBinaryAPIResponse(APIResponse[bytes]):
def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
"""Streams the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
"""
with open(file, mode="wb") as f:
for data in self.iter_bytes(chunk_size):
f.write(data)
class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
async def stream_to_file(
self,
file: str | os.PathLike[str],
*,
chunk_size: int | None = None,
) -> None:
"""Streams the output to the given file.
Accepts a filename or any path-like object, e.g. pathlib.Path
"""
path = anyio.Path(file)
async with await path.open(mode="wb") as f:
async for data in self.iter_bytes(chunk_size):
await f.write(data)
class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
)
class StreamAlreadyConsumed(OpenAIError):
"""
Attempted to read or stream content, but the content has already
been streamed.
This can happen if you use a method like `.iter_lines()` and then attempt
to read th entire response body afterwards, e.g.
```py
response = await client.post(...)
async for line in response.iter_lines():
... # do something with `line`
content = await response.read()
# ^ error
```
If you want this behaviour you'll need to either manually accumulate the response
content or call `await response.read()` before iterating over the stream.
"""
def __init__(self) -> None:
message = (
"Attempted to read or stream some content, but the content has "
"already been streamed. "
"This could be due to attempting to stream the response "
"content more than once."
"\n\n"
"You can fix this by manually accumulating the response content while streaming "
"or by calling `.read()` before starting to stream."
)
super().__init__(message)
class ResponseContextManager(Generic[_APIResponseT]):
"""Context manager for ensuring that a request is not made
until it is entered and that the response will always be closed
when the context manager exits
"""
def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
self._request_func = request_func
self.__response: _APIResponseT | None = None
def __enter__(self) -> _APIResponseT:
self.__response = self._request_func()
return self.__response
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__response is not None:
self.__response.close()
class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
"""Context manager for ensuring that a request is not made
until it is entered and that the response will always be closed
when the context manager exits
"""
def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
self._api_request = api_request
self.__response: _AsyncAPIResponseT | None = None
async def __aenter__(self) -> _AsyncAPIResponseT:
self.__response = await self._api_request
return self.__response
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__response is not None:
await self.__response.close()
def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support streaming and returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
kwargs["extra_headers"] = extra_headers
make_request = functools.partial(func, *args, **kwargs)
return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))
return wrapped
def async_to_streamed_response_wrapper(
func: Callable[P, Awaitable[R]],
) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support streaming and returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
kwargs["extra_headers"] = extra_headers
make_request = func(*args, **kwargs)
return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))
return wrapped
def to_custom_streamed_response_wrapper(
func: Callable[P, object],
response_cls: type[_APIResponseT],
) -> Callable[P, ResponseContextManager[_APIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support streaming and returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
make_request = functools.partial(func, *args, **kwargs)
return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))
return wrapped
def async_to_custom_streamed_response_wrapper(
func: Callable[P, Awaitable[object]],
response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support streaming and returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "stream"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
make_request = func(*args, **kwargs)
return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))
return wrapped
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
kwargs["extra_headers"] = extra_headers
return cast(APIResponse[R], func(*args, **kwargs))
return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
"""
@functools.wraps(func)
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
kwargs["extra_headers"] = extra_headers
return cast(AsyncAPIResponse[R], await func(*args, **kwargs))
return wrapped
def to_custom_raw_response_wrapper(
func: Callable[P, object],
response_cls: type[_APIResponseT],
) -> Callable[P, _APIResponseT]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
return cast(_APIResponseT, func(*args, **kwargs))
return wrapped
def async_to_custom_raw_response_wrapper(
func: Callable[P, Awaitable[object]],
response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
and wraps the method to support returning the given response class directly.
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
"""
@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
extra_headers[RAW_RESPONSE_HEADER] = "raw"
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
kwargs["extra_headers"] = extra_headers
return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))
return wrapped
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(APIResponse[bytes]):
...
extract_response_type(MyResponse) -> bytes
```
"""
return extract_type_var_from_base(
typ,
generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),
index=0,
)

View file

@ -0,0 +1,427 @@
# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
from __future__ import annotations
import json
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, Optional, AsyncIterator, cast
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
import httpx
from ._utils import is_mapping, extract_type_var_from_base
from ._exceptions import APIError
if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI
from ._models import FinalRequestOptions
_T = TypeVar("_T")
class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""
response: httpx.Response
_options: Optional[FinalRequestOptions] = None
_decoder: SSEBytesDecoder
def __init__(
self,
*,
cast_to: type[_T],
response: httpx.Response,
client: OpenAI,
options: Optional[FinalRequestOptions] = None,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client
self._options = options
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
def __next__(self) -> _T:
return self._iterator.__next__()
def __iter__(self) -> Iterator[_T]:
for item in self._iterator:
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter_bytes(self.response.iter_bytes())
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
try:
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
else data,
cast_to=cast_to,
response=response,
)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
response.close()
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.response.close()
class AsyncStream(Generic[_T]):
"""Provides the core interface to iterate over an asynchronous stream response."""
response: httpx.Response
_options: Optional[FinalRequestOptions] = None
_decoder: SSEDecoder | SSEBytesDecoder
def __init__(
self,
*,
cast_to: type[_T],
response: httpx.Response,
client: AsyncOpenAI,
options: Optional[FinalRequestOptions] = None,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client
self._options = options
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
async def __anext__(self) -> _T:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[_T]:
async for item in self._iterator:
yield item
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
try:
async for sse in iterator:
if sse.data.startswith("[DONE]"):
break
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()
if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"
raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)
yield process_data(
data={"data": data, "event": sse.event}
if self._options is not None and self._options.synthesize_event_and_data
else data,
cast_to=cast_to,
response=response,
)
finally:
# Ensure the response is closed even if the consumer doesn't read all data
await response.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.response.aclose()
class ServerSentEvent:
def __init__(
self,
*,
event: str | None = None,
data: str | None = None,
id: str | None = None,
retry: int | None = None,
) -> None:
if data is None:
data = ""
self._id = id
self._data = data
self._event = event or None
self._retry = retry
@property
def event(self) -> str | None:
return self._event
@property
def id(self) -> str | None:
return self._id
@property
def retry(self) -> int | None:
return self._retry
@property
def data(self) -> str:
return self._data
def json(self) -> Any:
return json.loads(self.data)
@override
def __repr__(self) -> str:
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
_data: list[str]
_event: str | None
_retry: int | None
_last_event_id: str | None
def __init__(self) -> None:
self._event = None
self._data = []
self._last_event_id = None
self._retry = None
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
for chunk in self._iter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
sse = self.decode(line)
if sse:
yield sse
def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
data = b""
for chunk in iterator:
for line in chunk.splitlines(keepends=True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
async for chunk in self._aiter_chunks(iterator):
# Split before decoding so splitlines() only uses \r and \n
for raw_line in chunk.splitlines():
line = raw_line.decode("utf-8")
sse = self.decode(line)
if sse:
yield sse
async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
data = b""
async for chunk in iterator:
for line in chunk.splitlines(keepends=True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
def decode(self, line: str) -> ServerSentEvent | None:
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
if not line:
if not self._event and not self._data and not self._last_event_id and self._retry is None:
return None
sse = ServerSentEvent(
event=self._event,
data="\n".join(self._data),
id=self._last_event_id,
retry=self._retry,
)
# NOTE: as per the SSE spec, do not reset last_event_id.
self._event = None
self._data = []
self._retry = None
return sse
if line.startswith(":"):
return None
fieldname, _, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
if fieldname == "event":
self._event = value
elif fieldname == "data":
self._data.append(value)
elif fieldname == "id":
if "\0" in value:
pass
else:
self._last_event_id = value
elif fieldname == "retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
else:
pass # Field is ignored.
return None
@runtime_checkable
class SSEBytesDecoder(Protocol):
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
...
def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
...
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
def extract_stream_chunk_type(
stream_cls: type,
*,
failure_message: str | None = None,
) -> type:
"""Given a type like `Stream[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyStream(Stream[bytes]):
...
extract_stream_chunk_type(MyStream) -> bytes
```
"""
from ._base_client import Stream, AsyncStream
return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
failure_message=failure_message,
)

View file

@ -0,0 +1,275 @@
from __future__ import annotations
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
AsyncIterable,
)
from typing_extensions import (
Set,
Literal,
Protocol,
TypeAlias,
TypedDict,
SupportsIndex,
overload,
override,
runtime_checkable,
)
import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
if TYPE_CHECKING:
from ._models import BaseModel
from ._response import APIResponse, AsyncAPIResponse
from ._legacy_response import HttpxBinaryResponseContent
Transport = BaseTransport
AsyncTransport = AsyncBaseTransport
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
# Used for sending raw binary data / streaming data in request bodies
# e.g. for file uploads without multipart encoding
BinaryTypes = Union[bytes, bytearray, IO[bytes], Iterable[bytes]]
AsyncBinaryTypes = Union[bytes, bytearray, IO[bytes], AsyncIterable[bytes]]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
# duplicate of the above but without our custom file support
HttpxFileContent = Union[IO[bytes], bytes]
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
Tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
# passing `None`, overloads would have to be defined for every
# method that uses `ResponseT` which would lead to an unacceptable
# amount of code duplication and make it unreadable. See _base_client.py
# for example usage.
#
# This unfortunately means that you will either have
# to import this type and pass it explicitly:
#
# from openai import NoneType
# client.get('/foo', cast_to=NoneType)
#
# or build it yourself:
#
# client.get('/foo', cast_to=type(None))
if TYPE_CHECKING:
NoneType: Type[None]
else:
NoneType = type(None)
class RequestOptions(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
params: Query
extra_json: AnyMapping
idempotency_key: str
follow_redirects: bool
synthesize_event_and_data: bool
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
For parameters with a meaningful None value, we need to distinguish between
the user explicitly passing None, and the user not passing the parameter at
all.
User code shouldn't need to use not_given directly.
For example:
```py
def create(timeout: Timeout | None | NotGiven = not_given): ...
create(timeout=1) # 1s timeout
create(timeout=None) # No timeout
create() # Default timeout behavior
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
not_given = NotGiven()
# for backwards compatibility:
NOT_GIVEN = NotGiven()
class Omit:
"""
To explicitly omit something from being sent in a request, use `omit`.
```py
# as the default `Content-Type` header is `application/json` that will be sent
client.post("/upload/files", files={"file": b"my raw file content"})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={"Content-Type": "multipart/form-data"})
# instead you can remove the default `application/json` header by passing omit
client.post(..., headers={"Content-Type": omit})
```
"""
def __bool__(self) -> Literal[False]:
return False
omit = Omit()
Omittable = Union[_T, Omit]
@runtime_checkable
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...
Headers = Mapping[str, Union[str, Omit]]
class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None: ...
HeadersLike = Union[Headers, HeadersLikeProtocol]
ResponseT = TypeVar(
"ResponseT",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
"APIResponse[Any]",
"AsyncAPIResponse[Any]",
"HttpxBinaryResponseContent",
],
)
StrBytesIntFloat = Union[str, bytes, int, float]
# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79
IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]]
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""
__orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
__origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
follow_redirects: bool
_T_co = TypeVar("_T_co", covariant=True)
if TYPE_CHECKING:
# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
# https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
#
# Note: index() and count() methods are intentionally omitted to allow pyright to properly
# infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr.
class SequenceNotStr(Protocol[_T_co]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
def __reversed__(self) -> Iterator[_T_co]: ...
else:
# just point this to a normal `Sequence` at runtime to avoid having to special case
# deserializing our custom sequence type
SequenceNotStr = Sequence

View file

@ -0,0 +1,67 @@
from ._logs import SensitiveHeadersFilter as SensitiveHeadersFilter
from ._sync import asyncify as asyncify
from ._proxy import LazyProxy as LazyProxy
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
is_azure_client as is_azure_client,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
is_async_azure_client as is_async_azure_client,
)
from ._compat import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_sequence_type as is_sequence_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import (
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)
from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime

View file

@ -0,0 +1,45 @@
from __future__ import annotations
import sys
import typing_extensions
from typing import Any, Type, Union, Literal, Optional
from datetime import date, datetime
from typing_extensions import get_args as _get_args, get_origin as _get_origin
from .._types import StrBytesIntFloat
from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime
_LITERAL_TYPES = {Literal, typing_extensions.Literal}
def get_args(tp: type[Any]) -> tuple[Any, ...]:
return _get_args(tp)
def get_origin(tp: type[Any]) -> type[Any] | None:
return _get_origin(tp)
def is_union(tp: Optional[Type[Any]]) -> bool:
if sys.version_info < (3, 10):
return tp is Union # type: ignore[comparison-overlap]
else:
import types
return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap]
def is_typeddict(tp: Type[Any]) -> bool:
return typing_extensions.is_typeddict(tp)
def is_literal_type(tp: Type[Any]) -> bool:
return get_origin(tp) in _LITERAL_TYPES
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
return _parse_date(value)
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
return _parse_datetime(value)

View file

@ -0,0 +1,136 @@
"""
This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py
without the Pydantic v1 specific errors.
"""
from __future__ import annotations
import re
from typing import Dict, Union, Optional
from datetime import date, datetime, timezone, timedelta
from .._types import StrBytesIntFloat
date_expr = r"(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})"
time_expr = (
r"(?P<hour>\d{1,2}):(?P<minute>\d{1,2})"
r"(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?"
r"(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$"
)
date_re = re.compile(f"{date_expr}$")
datetime_re = re.compile(f"{date_expr}[T ]{time_expr}")
EPOCH = datetime(1970, 1, 1)
# if greater than this, the number is in ms, if less than or equal it's in seconds
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
if isinstance(value, (int, float)):
return value
try:
return float(value)
except ValueError:
return None
except TypeError:
raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None
def _from_unix_seconds(seconds: Union[int, float]) -> datetime:
if seconds > MAX_NUMBER:
return datetime.max
elif seconds < -MAX_NUMBER:
return datetime.min
while abs(seconds) > MS_WATERSHED:
seconds /= 1000
dt = EPOCH + timedelta(seconds=seconds)
return dt.replace(tzinfo=timezone.utc)
def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]:
if value == "Z":
return timezone.utc
elif value is not None:
offset_mins = int(value[-2:]) if len(value) > 3 else 0
offset = 60 * int(value[1:3]) + offset_mins
if value[0] == "-":
offset = -offset
return timezone(timedelta(minutes=offset))
else:
return None
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
"""
Parse a datetime/int/float/string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
Raise ValueError if the input is well formatted but not a valid datetime.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, datetime):
return value
number = _get_numeric(value, "datetime")
if number is not None:
return _from_unix_seconds(number)
if isinstance(value, bytes):
value = value.decode()
assert not isinstance(value, (float, int))
match = datetime_re.match(value)
if match is None:
raise ValueError("invalid datetime format")
kw = match.groupdict()
if kw["microsecond"]:
kw["microsecond"] = kw["microsecond"].ljust(6, "0")
tzinfo = _parse_timezone(kw.pop("tzinfo"))
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_["tzinfo"] = tzinfo
return datetime(**kw_) # type: ignore
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
"""
Parse a date/int/float/string and return a datetime.date.
Raise ValueError if the input is well formatted but not a valid date.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, date):
if isinstance(value, datetime):
return value.date()
else:
return value
number = _get_numeric(value, "date")
if number is not None:
return _from_unix_seconds(number).date()
if isinstance(value, bytes):
value = value.decode()
assert not isinstance(value, (float, int))
match = date_re.match(value)
if match is None:
raise ValueError("invalid date format")
kw = {k: int(v) for k, v in match.groupdict().items()}
try:
return date(**kw)
except ValueError:
raise ValueError("invalid date format") from None

View file

@ -0,0 +1,35 @@
import json
from typing import Any
from datetime import datetime
from typing_extensions import override
import pydantic
from .._compat import model_dump
def openapi_dumps(obj: Any) -> bytes:
"""
Serialize an object to UTF-8 encoded JSON bytes.
Extends the standard json.dumps with support for additional types
commonly used in the SDK, such as `datetime`, `pydantic.BaseModel`, etc.
"""
return json.dumps(
obj,
cls=_CustomEncoder,
# Uses the same defaults as httpx's JSON serialization
ensure_ascii=False,
separators=(",", ":"),
allow_nan=False,
).encode()
class _CustomEncoder(json.JSONEncoder):
@override
def default(self, o: Any) -> Any:
if isinstance(o, datetime):
return o.isoformat()
if isinstance(o, pydantic.BaseModel):
return model_dump(o, exclude_unset=True, mode="json", by_alias=True)
return super().default(o)

View file

@ -0,0 +1,42 @@
import os
import logging
from typing_extensions import override
from ._utils import is_dict
logger: logging.Logger = logging.getLogger("openai")
httpx_logger: logging.Logger = logging.getLogger("httpx")
SENSITIVE_HEADERS = {"api-key", "authorization"}
def _basic_config() -> None:
# e.g. [2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
logging.basicConfig(
format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def setup_logging() -> None:
env = os.environ.get("OPENAI_LOG")
if env == "debug":
_basic_config()
logger.setLevel(logging.DEBUG)
httpx_logger.setLevel(logging.DEBUG)
elif env == "info":
_basic_config()
logger.setLevel(logging.INFO)
httpx_logger.setLevel(logging.INFO)
class SensitiveHeadersFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
if is_dict(record.args) and "headers" in record.args and is_dict(record.args["headers"]):
headers = record.args["headers"] = {**record.args["headers"]}
for header in headers:
if str(header).lower() in SENSITIVE_HEADERS:
headers[header] = "<redacted>"
return True

View file

@ -0,0 +1,65 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Iterable, cast
from typing_extensions import override
T = TypeVar("T")
class LazyProxy(Generic[T], ABC):
"""Implements data methods to pretend that an instance is another instance.
This includes forwarding attribute access and other methods.
"""
# Note: we have to special case proxies that themselves return proxies
# to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`
def __getattr__(self, attr: str) -> object:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied # pyright: ignore
return getattr(proxied, attr)
@override
def __repr__(self) -> str:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied.__class__.__name__
return repr(self.__get_proxied__())
@override
def __str__(self) -> str:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return proxied.__class__.__name__
return str(proxied)
@override
def __dir__(self) -> Iterable[str]:
proxied = self.__get_proxied__()
if isinstance(proxied, LazyProxy):
return []
return proxied.__dir__()
@property # type: ignore
@override
def __class__(self) -> type: # pyright: ignore
try:
proxied = self.__get_proxied__()
except Exception:
return type(self)
if issubclass(type(proxied), LazyProxy):
return type(proxied)
return proxied.__class__
def __get_proxied__(self) -> T:
return self.__load__()
def __as_proxied__(self) -> T:
"""Helper method that returns the current proxy, typed as the loaded object"""
return cast(T, self)
@abstractmethod
def __load__(self) -> T: ...

View file

@ -0,0 +1,45 @@
from __future__ import annotations
import inspect
from typing import Any, Callable
def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
"""Returns whether or not the given function has a specific parameter"""
sig = inspect.signature(func)
return arg_name in sig.parameters
def assert_signatures_in_sync(
source_func: Callable[..., Any],
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
description: str = "",
) -> None:
"""Ensure that the signature of the second function matches the first."""
check_sig = inspect.signature(check_func)
source_sig = inspect.signature(source_func)
errors: list[str] = []
for name, source_param in source_sig.parameters.items():
if name in exclude_params:
continue
custom_param = check_sig.parameters.get(name)
if not custom_param:
errors.append(f"the `{name}` param is missing")
continue
if custom_param.annotation != source_param.annotation:
errors.append(
f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}"
)
continue
if errors:
raise AssertionError(
f"{len(errors)} errors encountered when comparing signatures{description}:\n\n" + "\n\n".join(errors)
)

View file

@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Any
from typing_extensions import override
from ._proxy import LazyProxy
class ResourcesProxy(LazyProxy[Any]):
"""A proxy for the `openai.resources` module.
This is used so that we can lazily import `openai.resources` only when
needed *and* so that users can just import `openai` and reference `openai.resources`
"""
@override
def __load__(self) -> Any:
import importlib
mod = importlib.import_module("openai.resources")
return mod
resources = ResourcesProxy().__as_proxied__()

View file

@ -0,0 +1,12 @@
from typing import Any
from typing_extensions import Iterator, AsyncIterator
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
for _ in iterator:
...
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
async for _ in iterator:
...

View file

@ -0,0 +1,58 @@
from __future__ import annotations
import asyncio
import functools
from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
import anyio
import sniffio
import anyio.to_thread
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
async def to_thread(
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
) -> T_Retval:
if sniffio.current_async_library() == "asyncio":
return await asyncio.to_thread(func, *args, **kwargs)
return await anyio.to_thread.run_sync(
functools.partial(func, *args, **kwargs),
)
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
positional and keyword arguments.
Usage:
```python
def blocking_func(arg1, arg2, kwarg1=None):
# blocking code
return result
result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1)
```
## Arguments
`function`: a blocking regular callable (e.g. a function)
## Return
An async function that takes the same positional and keyword arguments as the
original one, that when called runs the same original function in a thread worker
and returns the result.
"""
async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
return await to_thread(function, *args, **kwargs)
return wrapper

View file

@ -0,0 +1,457 @@
from __future__ import annotations
import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
is_given,
lru_cache,
is_mapping,
is_iterable,
is_sequence,
)
from .._files import is_base64_file_input
from ._compat import get_origin, is_typeddict
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_sequence_type,
is_annotated_type,
strip_annotated_type,
)
_T = TypeVar("_T")
# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
"""Metadata class to be used in Annotated types to provide information about a given type.
For example:
class MyParams(TypedDict):
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
"""
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None
def __init__(
self,
*,
alias: str | None = None,
format: PropertyFormat | None = None,
format_template: str | None = None,
discriminator: str | None = None,
) -> None:
self.alias = alias
self.format = format
self.format_template = format_template
self.discriminator = discriminator
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
def maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `transform()` that allows `None` to be passed.
See `transform()` for more details.
"""
if data is None:
return None
return transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that does not have type information given will be included as is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
"""
if is_required_type(type_):
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
type_ = get_args(type_)[0]
if is_annotated_type(type_):
return type_
return None
def _maybe_transform_key(key: str, type_: type) -> str:
"""Transform the given `data` based on the annotations provided in `type_`.
Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
"""
annotated_type = _get_annotated_type(type_)
if annotated_type is None:
# no `Annotated` definition for this type, no transformation needed
return key
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
return annotation.alias
return key
def _no_transform_needed(annotation: type) -> bool:
return annotation == float or annotation == int
def _transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
from .._compat import model_dump
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)
if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return _format_data(data, annotation.format, annotation.format_template)
return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
def _transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
# be stripped out before the request is sent anyway
continue
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
async def async_maybe_transform(
data: object,
expected_type: object,
) -> Any | None:
"""Wrapper over `async_transform()` that allows `None` to be passed.
See `async_transform()` for more details.
"""
if data is None:
return None
return await async_transform(data, expected_type)
async def async_transform(
data: _T,
expected_type: object,
) -> _T:
"""Transform dictionaries based off of type information from the given type, for example:
```py
class Params(TypedDict, total=False):
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
transformed = transform({"card_id": "<my card ID>"}, Params)
# {'cardID': '<my card ID>'}
```
Any keys / data that does not have type information given will be included as is.
It should be noted that the transformations that this function does are not represented in the type system.
"""
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
return cast(_T, transformed)
async def _async_transform_recursive(
data: object,
*,
annotation: type,
inner_type: type | None = None,
) -> object:
"""Transform the given data against the expected type.
Args:
annotation: The direct type annotation given to the particular piece of data.
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
the list can be transformed using the metadata from the container type.
Defaults to the same value as the `annotation` argument.
"""
from .._compat import model_dump
if inner_type is None:
inner_type = annotation
stripped_type = strip_annotated_type(inner_type)
origin = get_origin(stripped_type) or stripped_type
if is_typeddict(stripped_type) and is_mapping(data):
return await _async_transform_typeddict(data, stripped_type)
if origin == dict and is_mapping(data):
items_type = get_args(stripped_type)[1]
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
# Sequence[T]
or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
):
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
# intended as an iterable, so we don't transform it.
if isinstance(data, dict):
return cast(object, data)
inner_type = extract_type_arg(stripped_type, 0)
if _no_transform_needed(inner_type):
# for some types there is no need to transform anything, so we can get a small
# perf boost from skipping that work.
#
# but we still need to convert to a list to ensure the data is json-serializable
if is_list(data):
return data
return list(data)
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
if is_union_type(stripped_type):
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
#
# TODO: there may be edge cases where the same normalized field name will transform to two different names
# in different subtypes.
for subtype in get_args(stripped_type):
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json")
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
# ignore the first argument as it is the actual type
annotations = get_args(annotated_type)[1:]
for annotation in annotations:
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
return await _async_format_data(data, annotation.format, annotation.format_template)
return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
if isinstance(data, (date, datetime)):
if format_ == "iso8601":
return data.isoformat()
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None
if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()
if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()
if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
return base64.b64encode(binary).decode("ascii")
return data
async def _async_transform_typeddict(
data: Mapping[str, object],
expected_type: type,
) -> Mapping[str, object]:
result: dict[str, object] = {}
annotations = get_type_hints(expected_type, include_extras=True)
for key, value in data.items():
if not is_given(value):
# we don't need to include omitted values here as they'll
# be stripped out before the request is sent anyway
continue
type_ = annotations.get(key)
if type_ is None:
# we do not have a type annotation for this field, leave it as is
result[key] = value
else:
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
return result
@lru_cache(maxsize=8096)
def get_type_hints(
obj: Any,
globalns: dict[str, Any] | None = None,
localns: Mapping[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)

View file

@ -0,0 +1,156 @@
from __future__ import annotations
import sys
import typing
import typing_extensions
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import (
TypeIs,
Required,
Annotated,
get_args,
get_origin,
)
from ._utils import lru_cache
from .._types import InheritsGeneric
from ._compat import is_union as _is_union
def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated
def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
def is_sequence_type(typ: type) -> bool:
origin = get_origin(typ) or typ
return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin == Iterable or origin == _c_abc.Iterable
def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
"""Return whether the provided argument is an instance of `TypeAliasType`.
```python
type Int = int
is_type_alias_type(Int)
# > True
Str = TypeAliasType("Str", str)
is_type_alias_type(Str)
# > True
```
"""
return isinstance(tp, _TYPE_ALIAS_TYPES)
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))
return typ
def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
typ: type,
*,
generic_bases: tuple[type, ...],
index: int,
failure_message: str | None = None,
) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
# we're given the class directly
return extract_type_arg(typ, index)
# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break
if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)
return extracted
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")

View file

@ -0,0 +1,437 @@
from __future__ import annotations
import os
import re
import inspect
import functools
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Sequence,
cast,
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard
import sniffio
from .._types import Omit, NotGiven, FileTypes, HeadersLike
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
if TYPE_CHECKING:
from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
def extract_files(
# TODO: this needs to take Dict but variance issues.....
# create protocol type ?
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
A path may look like this ['foo', 'files', '<array>', 'data'].
Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
def _extract_items(
obj: object,
path: Sequence[str],
*,
index: int,
flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if not is_given(obj):
# no value was provided - we can safely ignore
return []
# cyclical import
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
assert flattened_key is not None
if is_list(obj):
files: list[tuple[str, FileTypes]] = []
for entry in obj:
assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
files.append((flattened_key + "[]", cast(FileTypes, entry)))
return files
assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]
index += 1
if is_dict(obj):
try:
# We are at the last entry in the path so we must remove the field
if (len(path)) == index:
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []
return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
)
# Something unexpected was passed, just ignore it.
return []
def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
return not isinstance(obj, NotGiven) and not isinstance(obj, Omit)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
return isinstance(obj, tuple)
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
return isinstance(obj, tuple)
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
return isinstance(obj, Sequence)
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
return isinstance(obj, Sequence)
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(obj, Mapping)
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
return isinstance(obj, Mapping)
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
- mappings, e.g. `dict`
- list
This is done for performance reasons.
"""
if is_mapping(item):
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
if is_list(item):
return cast(_T, [deepcopy_minimal(entry) for entry in item])
return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""
if size == 1:
return seq[0]
if size == 2:
return f"{seq[0]} {final} {seq[1]}"
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str: ...
@overload
def foo(*, b: bool) -> str: ...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
```
"""
def inner(func: CallableT) -> CallableT:
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]
@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
raise TypeError(
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
) from None
for key in kwargs.keys():
given_params.add(key)
for variant in variants:
matches = all((param in given_params for param in variant))
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
assert len(variants) > 0
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)
return wrapper # type: ignore
return inner
_K = TypeVar("_K")
_V = TypeVar("_V")
@overload
def strip_not_given(obj: None) -> None: ...
@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
@overload
def strip_not_given(obj: object) -> object: ...
def strip_not_given(obj: object | None) -> object:
"""Remove all top-level keys where their values are instances of `NotGiven`"""
if obj is None:
return None
if not is_mapping(obj):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
def coerce_integer(val: str) -> int:
return int(val, base=10)
def coerce_float(val: str) -> float:
return float(val)
def coerce_boolean(val: str) -> bool:
return val == "true" or val == "1" or val == "on"
def maybe_coerce_integer(val: str | None) -> int | None:
if val is None:
return None
return coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
if val is None:
return None
return coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
if val is None:
return None
return coerce_boolean(val)
def removeprefix(string: str, prefix: str) -> str:
"""Remove a prefix from a string.
Backport of `str.removeprefix` for Python < 3.9
"""
if string.startswith(prefix):
return string[len(prefix) :]
return string
def removesuffix(string: str, suffix: str) -> str:
"""Remove a suffix from a string.
Backport of `str.removesuffix` for Python < 3.9
"""
if string.endswith(suffix):
return string[: -len(suffix)]
return string
def file_from_path(path: str) -> FileTypes:
contents = Path(path).read_bytes()
file_name = os.path.basename(path)
return (file_name, contents)
def get_required_header(headers: HeadersLike, header: str) -> str:
lower_header = header.lower()
if is_mapping_t(headers):
# mypy doesn't understand the type narrowing here
for k, v in headers.items(): # type: ignore
if k.lower() == lower_header and isinstance(v, str):
return v
# to deal with the case where the header looks like Stainless-Event-Id
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
value = headers.get(normalized_header)
if value:
return value
raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
try:
return sniffio.current_async_library()
except Exception:
return "false"
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
"""A version of functools.lru_cache that retains the type signature
for the wrapped function arguments.
"""
wrapper = functools.lru_cache( # noqa: TID251
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]
def json_safe(data: object) -> object:
"""Translates a mapping / sequence recursively in the same fashion
as `pydantic` v2's `model_dump(mode="json")`.
"""
if is_mapping(data):
return {json_safe(key): json_safe(value) for key, value in data.items()}
if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
return [json_safe(item) for item in data]
if isinstance(data, (datetime, date)):
return data.isoformat()
return data
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
from ..lib.azure import AzureOpenAI
return isinstance(client, AzureOpenAI)
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
from ..lib.azure import AsyncAzureOpenAI
return isinstance(client, AsyncAzureOpenAI)

View file

@ -0,0 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "openai"
__version__ = "2.29.0" # x-release-please-version

View file

@ -0,0 +1 @@
from ._cli import main as main

View file

@ -0,0 +1 @@
from ._main import register_commands as register_commands

View file

@ -0,0 +1,17 @@
from __future__ import annotations
from argparse import ArgumentParser
from . import chat, audio, files, image, models, completions, fine_tuning
def register_commands(parser: ArgumentParser) -> None:
subparsers = parser.add_subparsers(help="All API subcommands")
chat.register(subparsers)
image.register(subparsers)
audio.register(subparsers)
files.register(subparsers)
models.register(subparsers)
completions.register(subparsers)
fine_tuning.register(subparsers)

View file

@ -0,0 +1,108 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, Optional, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from ..._types import omit
from .._models import BaseModel
from .._progress import BufferReader
from ...types.audio import Transcription
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
# transcriptions
sub = subparser.add_parser("audio.transcriptions.create")
# Required
sub.add_argument("-m", "--model", type=str, default="whisper-1")
sub.add_argument("-f", "--file", type=str, required=True)
# Optional
sub.add_argument("--response-format", type=str)
sub.add_argument("--language", type=str)
sub.add_argument("-t", "--temperature", type=float)
sub.add_argument("--prompt", type=str)
sub.set_defaults(func=CLIAudio.transcribe, args_model=CLITranscribeArgs)
# translations
sub = subparser.add_parser("audio.translations.create")
# Required
sub.add_argument("-f", "--file", type=str, required=True)
# Optional
sub.add_argument("-m", "--model", type=str, default="whisper-1")
sub.add_argument("--response-format", type=str)
# TODO: doesn't seem to be supported by the API
# sub.add_argument("--language", type=str)
sub.add_argument("-t", "--temperature", type=float)
sub.add_argument("--prompt", type=str)
sub.set_defaults(func=CLIAudio.translate, args_model=CLITranslationArgs)
class CLITranscribeArgs(BaseModel):
model: str
file: str
response_format: Optional[str] = None
language: Optional[str] = None
temperature: Optional[float] = None
prompt: Optional[str] = None
class CLITranslationArgs(BaseModel):
model: str
file: str
response_format: Optional[str] = None
language: Optional[str] = None
temperature: Optional[float] = None
prompt: Optional[str] = None
class CLIAudio:
@staticmethod
def transcribe(args: CLITranscribeArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
model = cast(
"Transcription | str",
get_client().audio.transcriptions.create(
file=(args.file, buffer_reader),
model=args.model,
language=args.language or omit,
temperature=args.temperature or omit,
prompt=args.prompt or omit,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)
@staticmethod
def translate(args: CLITranslationArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
model = cast(
"Transcription | str",
get_client().audio.translations.create(
file=(args.file, buffer_reader),
model=args.model,
temperature=args.temperature or omit,
prompt=args.prompt or omit,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)

View file

@ -0,0 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from . import completions
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
completions.register(subparser)

View file

@ -0,0 +1,160 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, List, Optional, cast
from argparse import ArgumentParser
from typing_extensions import Literal, NamedTuple
from ..._utils import get_client
from ..._models import BaseModel
from ...._streaming import Stream
from ....types.chat import (
ChatCompletionRole,
ChatCompletionChunk,
CompletionCreateParams,
)
from ....types.chat.completion_create_params import (
CompletionCreateParamsStreaming,
CompletionCreateParamsNonStreaming,
)
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("chat.completions.create")
sub._action_groups.pop()
req = sub.add_argument_group("required arguments")
opt = sub.add_argument_group("optional arguments")
req.add_argument(
"-g",
"--message",
action="append",
nargs=2,
metavar=("ROLE", "CONTENT"),
help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
required=True,
)
req.add_argument(
"-m",
"--model",
help="The model to use.",
required=True,
)
opt.add_argument(
"-n",
"--n",
help="How many completions to generate for the conversation.",
type=int,
)
opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
opt.add_argument(
"-t",
"--temperature",
help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
Mutually exclusive with `top_p`.""",
type=float,
)
opt.add_argument(
"-P",
"--top_p",
help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
Mutually exclusive with `temperature`.""",
type=float,
)
opt.add_argument(
"--stop",
help="A stop sequence at which to stop generating tokens for the message.",
)
opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
class CLIMessage(NamedTuple):
role: ChatCompletionRole
content: str
class CLIChatCompletionCreateArgs(BaseModel):
message: List[CLIMessage]
model: str
n: Optional[int] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
stop: Optional[str] = None
stream: bool = False
class CLIChatCompletion:
@staticmethod
def create(args: CLIChatCompletionCreateArgs) -> None:
params: CompletionCreateParams = {
"model": args.model,
"messages": [
{"role": cast(Literal["user"], message.role), "content": message.content} for message in args.message
],
# type checkers are not good at inferring union types so we have to set stream afterwards
"stream": False,
}
if args.temperature is not None:
params["temperature"] = args.temperature
if args.stop is not None:
params["stop"] = args.stop
if args.top_p is not None:
params["top_p"] = args.top_p
if args.n is not None:
params["n"] = args.n
if args.stream:
params["stream"] = args.stream # type: ignore
if args.max_tokens is not None:
params["max_tokens"] = args.max_tokens
if args.stream:
return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
@staticmethod
def _create(params: CompletionCreateParamsNonStreaming) -> None:
completion = get_client().chat.completions.create(**params)
should_print_header = len(completion.choices) > 1
for choice in completion.choices:
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
content = choice.message.content if choice.message.content is not None else "None"
sys.stdout.write(content)
if should_print_header or not content.endswith("\n"):
sys.stdout.write("\n")
sys.stdout.flush()
@staticmethod
def _stream_create(params: CompletionCreateParamsStreaming) -> None:
# cast is required for mypy
stream = cast( # pyright: ignore[reportUnnecessaryCast]
Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
)
for chunk in stream:
should_print_header = len(chunk.choices) > 1
for choice in chunk.choices:
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
content = choice.delta.content or ""
sys.stdout.write(content)
if should_print_header:
sys.stdout.write("\n")
sys.stdout.flush()
sys.stdout.write("\n")

View file

@ -0,0 +1,173 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Optional, cast
from argparse import ArgumentParser
from functools import partial
from openai.types.completion import Completion
from .._utils import get_client
from ..._types import Omittable, omit
from ..._utils import is_given
from .._errors import CLIError
from .._models import BaseModel
from ..._streaming import Stream
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("completions.create")
# Required
sub.add_argument(
"-m",
"--model",
help="The model to use",
required=True,
)
# Optional
sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
sub.add_argument(
"-t",
"--temperature",
help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
Mutually exclusive with `top_p`.""",
type=float,
)
sub.add_argument(
"-P",
"--top_p",
help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
Mutually exclusive with `temperature`.""",
type=float,
)
sub.add_argument(
"-n",
"--n",
help="How many sub-completions to generate for each prompt.",
type=int,
)
sub.add_argument(
"--logprobs",
help="Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
type=int,
)
sub.add_argument(
"--best_of",
help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
type=int,
)
sub.add_argument(
"--echo",
help="Echo back the prompt in addition to the completion",
action="store_true",
)
sub.add_argument(
"--frequency_penalty",
help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
type=float,
)
sub.add_argument(
"--presence_penalty",
help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
type=float,
)
sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
sub.add_argument(
"--user",
help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
)
# TODO: add support for logit_bias
sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
class CLICompletionCreateArgs(BaseModel):
model: str
stream: bool = False
prompt: Optional[str] = None
n: Omittable[int] = omit
stop: Omittable[str] = omit
user: Omittable[str] = omit
echo: Omittable[bool] = omit
suffix: Omittable[str] = omit
best_of: Omittable[int] = omit
top_p: Omittable[float] = omit
logprobs: Omittable[int] = omit
max_tokens: Omittable[int] = omit
temperature: Omittable[float] = omit
presence_penalty: Omittable[float] = omit
frequency_penalty: Omittable[float] = omit
class CLICompletions:
@staticmethod
def create(args: CLICompletionCreateArgs) -> None:
if is_given(args.n) and args.n > 1 and args.stream:
raise CLIError("Can't stream completions with n>1 with the current CLI")
make_request = partial(
get_client().completions.create,
n=args.n,
echo=args.echo,
stop=args.stop,
user=args.user,
model=args.model,
top_p=args.top_p,
prompt=args.prompt,
suffix=args.suffix,
best_of=args.best_of,
logprobs=args.logprobs,
max_tokens=args.max_tokens,
temperature=args.temperature,
presence_penalty=args.presence_penalty,
frequency_penalty=args.frequency_penalty,
)
if args.stream:
return CLICompletions._stream_create(
# mypy doesn't understand the `partial` function but pyright does
cast(Stream[Completion], make_request(stream=True)) # pyright: ignore[reportUnnecessaryCast]
)
return CLICompletions._create(make_request())
@staticmethod
def _create(completion: Completion) -> None:
should_print_header = len(completion.choices) > 1
for choice in completion.choices:
if should_print_header:
sys.stdout.write("===== Completion {} =====\n".format(choice.index))
sys.stdout.write(choice.text)
if should_print_header or not choice.text.endswith("\n"):
sys.stdout.write("\n")
sys.stdout.flush()
@staticmethod
def _stream_create(stream: Stream[Completion]) -> None:
for completion in stream:
should_print_header = len(completion.choices) > 1
for choice in sorted(completion.choices, key=lambda c: c.index):
if should_print_header:
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
sys.stdout.write(choice.text)
if should_print_header:
sys.stdout.write("\n")
sys.stdout.flush()
sys.stdout.write("\n")

View file

@ -0,0 +1,80 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from .._models import BaseModel
from .._progress import BufferReader
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("files.create")
sub.add_argument(
"-f",
"--file",
required=True,
help="File to upload",
)
sub.add_argument(
"-p",
"--purpose",
help="Why are you uploading this file? (see https://platform.openai.com/docs/api-reference/ for purposes)",
required=True,
)
sub.set_defaults(func=CLIFile.create, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.retrieve")
sub.add_argument("-i", "--id", required=True, help="The files ID")
sub.set_defaults(func=CLIFile.get, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.delete")
sub.add_argument("-i", "--id", required=True, help="The files ID")
sub.set_defaults(func=CLIFile.delete, args_model=CLIFileCreateArgs)
sub = subparser.add_parser("files.list")
sub.set_defaults(func=CLIFile.list)
class CLIFileIDArgs(BaseModel):
id: str
class CLIFileCreateArgs(BaseModel):
file: str
purpose: str
class CLIFile:
@staticmethod
def create(args: CLIFileCreateArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
file = get_client().files.create(
file=(args.file, buffer_reader),
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
purpose=cast(Any, args.purpose),
)
print_model(file)
@staticmethod
def get(args: CLIFileIDArgs) -> None:
file = get_client().files.retrieve(file_id=args.id)
print_model(file)
@staticmethod
def delete(args: CLIFileIDArgs) -> None:
file = get_client().files.delete(file_id=args.id)
print_model(file)
@staticmethod
def list() -> None:
files = get_client().files.list()
for file in files:
print_model(file)

View file

@ -0,0 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from . import jobs
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
jobs.register(subparser)

View file

@ -0,0 +1,170 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from argparse import ArgumentParser
from ..._utils import get_client, print_model
from ...._types import Omittable, omit
from ...._utils import is_given
from ..._models import BaseModel
from ....pagination import SyncCursorPage
from ....types.fine_tuning import (
FineTuningJob,
FineTuningJobEvent,
)
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("fine_tuning.jobs.create")
sub.add_argument(
"-m",
"--model",
help="The model to fine-tune.",
required=True,
)
sub.add_argument(
"-F",
"--training-file",
help="The training file to fine-tune the model on.",
required=True,
)
sub.add_argument(
"-H",
"--hyperparameters",
help="JSON string of hyperparameters to use for fine-tuning.",
type=str,
)
sub.add_argument(
"-s",
"--suffix",
help="A suffix to add to the fine-tuned model name.",
)
sub.add_argument(
"-V",
"--validation-file",
help="The validation file to use for fine-tuning.",
)
sub.set_defaults(func=CLIFineTuningJobs.create, args_model=CLIFineTuningJobsCreateArgs)
sub = subparser.add_parser("fine_tuning.jobs.retrieve")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to retrieve.",
required=True,
)
sub.set_defaults(func=CLIFineTuningJobs.retrieve, args_model=CLIFineTuningJobsRetrieveArgs)
sub = subparser.add_parser("fine_tuning.jobs.list")
sub.add_argument(
"-a",
"--after",
help="Identifier for the last job from the previous pagination request. If provided, only jobs created after this job will be returned.",
)
sub.add_argument(
"-l",
"--limit",
help="Number of fine-tuning jobs to retrieve.",
type=int,
)
sub.set_defaults(func=CLIFineTuningJobs.list, args_model=CLIFineTuningJobsListArgs)
sub = subparser.add_parser("fine_tuning.jobs.cancel")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to cancel.",
required=True,
)
sub.set_defaults(func=CLIFineTuningJobs.cancel, args_model=CLIFineTuningJobsCancelArgs)
sub = subparser.add_parser("fine_tuning.jobs.list_events")
sub.add_argument(
"-i",
"--id",
help="The ID of the fine-tuning job to list events for.",
required=True,
)
sub.add_argument(
"-a",
"--after",
help="Identifier for the last event from the previous pagination request. If provided, only events created after this event will be returned.",
)
sub.add_argument(
"-l",
"--limit",
help="Number of fine-tuning job events to retrieve.",
type=int,
)
sub.set_defaults(func=CLIFineTuningJobs.list_events, args_model=CLIFineTuningJobsListEventsArgs)
class CLIFineTuningJobsCreateArgs(BaseModel):
model: str
training_file: str
hyperparameters: Omittable[str] = omit
suffix: Omittable[str] = omit
validation_file: Omittable[str] = omit
class CLIFineTuningJobsRetrieveArgs(BaseModel):
id: str
class CLIFineTuningJobsListArgs(BaseModel):
after: Omittable[str] = omit
limit: Omittable[int] = omit
class CLIFineTuningJobsCancelArgs(BaseModel):
id: str
class CLIFineTuningJobsListEventsArgs(BaseModel):
id: str
after: Omittable[str] = omit
limit: Omittable[int] = omit
class CLIFineTuningJobs:
@staticmethod
def create(args: CLIFineTuningJobsCreateArgs) -> None:
hyperparameters = json.loads(str(args.hyperparameters)) if is_given(args.hyperparameters) else omit
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.create(
model=args.model,
training_file=args.training_file,
hyperparameters=hyperparameters,
suffix=args.suffix,
validation_file=args.validation_file,
)
print_model(fine_tuning_job)
@staticmethod
def retrieve(args: CLIFineTuningJobsRetrieveArgs) -> None:
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.retrieve(fine_tuning_job_id=args.id)
print_model(fine_tuning_job)
@staticmethod
def list(args: CLIFineTuningJobsListArgs) -> None:
fine_tuning_jobs: SyncCursorPage[FineTuningJob] = get_client().fine_tuning.jobs.list(
after=args.after or omit, limit=args.limit or omit
)
print_model(fine_tuning_jobs)
@staticmethod
def cancel(args: CLIFineTuningJobsCancelArgs) -> None:
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.cancel(fine_tuning_job_id=args.id)
print_model(fine_tuning_job)
@staticmethod
def list_events(args: CLIFineTuningJobsListEventsArgs) -> None:
fine_tuning_job_events: SyncCursorPage[FineTuningJobEvent] = get_client().fine_tuning.jobs.list_events(
fine_tuning_job_id=args.id,
after=args.after or omit,
limit=args.limit or omit,
)
print_model(fine_tuning_job_events)

View file

@ -0,0 +1,139 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from argparse import ArgumentParser
from .._utils import get_client, print_model
from ..._types import Omit, Omittable, omit
from .._models import BaseModel
from .._progress import BufferReader
if TYPE_CHECKING:
from argparse import _SubParsersAction
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
sub = subparser.add_parser("images.generate")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)
sub = subparser.add_parser("images.edit")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-p", "--prompt", type=str, required=True)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.add_argument(
"-M",
"--mask",
type=str,
required=False,
help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
)
sub.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)
sub = subparser.add_parser("images.create_variation")
sub.add_argument("-m", "--model", type=str)
sub.add_argument("-n", "--num-images", type=int, default=1)
sub.add_argument(
"-I",
"--image",
type=str,
required=True,
help="Image to modify. Should be a local path and a PNG encoded image.",
)
sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
sub.add_argument("--response-format", type=str, default="url")
sub.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
class CLIImageCreateArgs(BaseModel):
prompt: str
num_images: int
size: str
response_format: str
model: Omittable[str] = omit
class CLIImageCreateVariationArgs(BaseModel):
image: str
num_images: int
size: str
response_format: str
model: Omittable[str] = omit
class CLIImageEditArgs(BaseModel):
image: str
num_images: int
size: str
response_format: str
prompt: str
mask: Omittable[str] = omit
model: Omittable[str] = omit
class CLIImage:
@staticmethod
def create(args: CLIImageCreateArgs) -> None:
image = get_client().images.generate(
model=args.model,
prompt=args.prompt,
n=args.num_images,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)
@staticmethod
def create_variation(args: CLIImageCreateVariationArgs) -> None:
with open(args.image, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
image = get_client().images.create_variation(
model=args.model,
image=("image", buffer_reader),
n=args.num_images,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)
@staticmethod
def edit(args: CLIImageEditArgs) -> None:
with open(args.image, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Image upload progress")
if isinstance(args.mask, Omit):
mask: Omittable[BufferReader] = omit
else:
with open(args.mask, "rb") as file_reader:
mask = BufferReader(file_reader.read(), desc="Mask progress")
image = get_client().images.edit(
model=args.model,
prompt=args.prompt,
image=("image", buffer_reader),
n=args.num_images,
mask=("mask", mask) if not isinstance(mask, Omit) else mask,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
size=cast(Any, args.size),
response_format=cast(Any, args.response_format),
)
print_model(image)

Some files were not shown because too many files have changed in this diff Show more