"""Asyncmy database configuration."""
import inspect
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from weakref import WeakSet
import asyncmy
from mypy_extensions import mypyc_attr
from typing_extensions import NotRequired
from sqlspec.adapters.asyncmy._typing import (
AsyncmyConnection,
AsyncmyCursor,
AsyncmyDictCursor,
AsyncmyPool,
AsyncmyRawCursor,
AsyncmySessionContext,
)
from sqlspec.adapters.asyncmy.core import apply_driver_features, default_statement_config
from sqlspec.adapters.asyncmy.driver import AsyncmyDriver, AsyncmyExceptionHandler
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config
if TYPE_CHECKING:
import ssl
from collections.abc import Awaitable, Callable, Mapping
from types import TracebackType
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
__all__ = ("AsyncmyConfig", "AsyncmyConnectionParams", "AsyncmyDriverFeatures", "AsyncmyPoolParams", "AsyncmySSLParams")
# Pool sizing/recycling keys; together with "echo" they form the set of keys
# that ``_split_asyncmy_pool_config`` routes to the pool kwargs.
_ASYNCMY_POOL_ONLY_KEYS = frozenset(("minsize", "maxsize", "pool_recycle"))
# Full set of keys treated as pool settings when the config is split.
_ASYNCMY_POOL_KEYS = _ASYNCMY_POOL_ONLY_KEYS | {"echo"}
# Opt-in key that must be truthy for ``local_infile=True`` to be accepted; the
# normalizer pops it from the config after checking it.
_ASYNCMY_LOCAL_INFILE_GATE = "allow_local_infile"
def _get_asyncmy_connect_parameter_names() -> "frozenset[str]":
    """Return the parameter names declared by ``asyncmy.connect``.

    Returns:
        The names found by introspecting the callable's signature, or an empty
        frozenset when ``inspect.signature`` cannot introspect it.
    """
    try:
        connect_signature = inspect.signature(asyncmy.connect)
    except (TypeError, ValueError):
        # Signature introspection failed; report "nothing known".
        return frozenset()
    return frozenset(connect_signature.parameters)


# Computed once at import time; consulted when deciding whether to forward "write_timeout".
_ASYNCMY_CONNECT_PARAMETER_NAMES = _get_asyncmy_connect_parameter_names()
class AsyncmySSLParams(TypedDict):
    """Asyncmy TLS parameters.

    Every key is optional. This mapping is one of the accepted shapes for the
    ``ssl`` entry of ``AsyncmyConnectionParams``.
    """

    ca: NotRequired[str]
    capath: NotRequired[str]
    cert: NotRequired[str]
    key: NotRequired[str]
    cipher: NotRequired[str]
    check_hostname: NotRequired[bool]
    verify_mode: NotRequired[bool | int | str]
class AsyncmyConnectionParams(TypedDict):
    """Asyncmy connection parameters.

    Every key is optional. A few keys are handled specially by the helpers in
    this module before the values are handed to asyncmy; those are annotated
    with comments below.
    """

    host: NotRequired[str]
    user: NotRequired[str]
    password: NotRequired[str]
    database: NotRequired[str]
    db: NotRequired[str]
    port: NotRequired[int]
    unix_socket: NotRequired[str]
    charset: NotRequired[str]
    connect_timeout: NotRequired[int | float]
    read_default_file: NotRequired[str]
    read_default_group: NotRequired[str]
    autocommit: NotRequired[bool]
    # Gate consumed by _normalize_asyncmy_connection_config: it is popped from
    # the config and must be truthy whenever ``local_infile`` is truthy.
    allow_local_infile: NotRequired[bool]
    # Rejected with ImproperConfigurationError unless ``allow_local_infile`` is
    # also set; the normalizer always stores an explicit boolean for this key.
    local_infile: NotRequired[bool]
    ssl: NotRequired["AsyncmySSLParams | ssl.SSLContext | dict[str, Any]"]
    sql_mode: NotRequired[str]
    init_command: NotRequired[str]
    auth_plugin_map: NotRequired["dict[str | bytes, type[Any]]"]
    binary_prefix: NotRequired[bool]
    client_flag: NotRequired[int]
    conv: NotRequired["dict[Any, Any]"]
    # Legacy spelling: the normalizer moves this value to ``cursor_cls`` and
    # raises if both are given with different objects.
    cursor_class: NotRequired[type["AsyncmyRawCursor"] | type["AsyncmyDictCursor"]]
    cursor_cls: NotRequired[type["AsyncmyRawCursor"] | type["AsyncmyDictCursor"]]
    max_allowed_packet: NotRequired[int]
    program_name: NotRequired[str]
    read_timeout: NotRequired[int | float]
    server_public_key: NotRequired[str | bytes]
    use_unicode: NotRequired[bool]
    # Dropped by _split_asyncmy_pool_config when ``asyncmy.connect`` does not
    # declare a ``write_timeout`` parameter.
    write_timeout: NotRequired[int | float]
    extra: NotRequired["dict[str, Any]"]
class AsyncmyPoolParams(AsyncmyConnectionParams):
    """Asyncmy pool parameters.

    Extends the connection parameters with the four keys that
    ``_split_asyncmy_pool_config`` classifies as pool settings. All keys are
    optional.
    """

    minsize: NotRequired[int]
    maxsize: NotRequired[int]
    echo: NotRequired[bool]
    pool_recycle: NotRequired[int]
def _normalize_asyncmy_connection_config(connection_config: "Mapping[str, Any] | None") -> "dict[str, Any]":
    """Normalize SQLSpec asyncmy config keys before storing them.

    Two SQLSpec-level conveniences are resolved here:

    * the legacy ``cursor_class`` key is folded into ``cursor_cls``;
    * ``local_infile`` is only honoured when the ``allow_local_infile`` gate is
      also set, and the gate key itself is removed from the result.

    Args:
        connection_config: Raw user configuration, or ``None``.

    Returns:
        The normalized configuration dictionary, always carrying an explicit
        boolean ``local_infile`` entry.

    Raises:
        ImproperConfigurationError: If ``cursor_cls`` and ``cursor_class`` name
            different objects, or if ``local_infile`` is requested without the gate.
    """
    normalized = normalize_connection_config(connection_config)

    if "cursor_class" in normalized:
        legacy_cursor = normalized.pop("cursor_class")
        current_cursor = normalized.get("cursor_cls")
        # Identity comparison: passing the very same class under both keys is allowed.
        if not (current_cursor is None or current_cursor is legacy_cursor):
            msg = "Asyncmy connection_config received conflicting 'cursor_cls' and legacy 'cursor_class' values."
            raise ImproperConfigurationError(msg)
        normalized["cursor_cls"] = legacy_cursor

    gate_open = bool(normalized.pop(_ASYNCMY_LOCAL_INFILE_GATE, False))
    wants_local_infile = bool(normalized.get("local_infile", False))
    if wants_local_infile and not gate_open:
        msg = "Asyncmy local_infile=True requires allow_local_infile=True because LOAD DATA LOCAL INFILE can read client files."
        raise ImproperConfigurationError(msg)

    normalized["local_infile"] = wants_local_infile and gate_open
    return normalized
def _split_asyncmy_pool_config(connection_config: "Mapping[str, Any]") -> "tuple[dict[str, Any], dict[str, Any]]":
    """Split pool constructor settings from connection settings.

    Entries whose value is ``None`` are discarded. ``write_timeout`` is also
    discarded when ``asyncmy.connect`` was not found to declare that parameter.

    Args:
        connection_config: The normalized configuration mapping.

    Returns:
        A ``(pool_kwargs, connection_kwargs)`` pair.
    """
    pool_options: dict[str, Any] = {}
    connect_options: dict[str, Any] = {}
    write_timeout_supported = "write_timeout" in _ASYNCMY_CONNECT_PARAMETER_NAMES

    for name, setting in connection_config.items():
        if setting is None:
            continue
        if name in _ASYNCMY_POOL_KEYS:
            pool_options[name] = setting
        elif name != "write_timeout" or write_timeout_supported:
            connect_options[name] = setting

    return pool_options, connect_options
def _build_asyncmy_pool_config(connection_config: "Mapping[str, Any]") -> "dict[str, Any]":
    """Build the flat keyword-argument dict used to create the pool.

    Args:
        connection_config: The normalized configuration mapping.

    Returns:
        Connection settings followed by pool settings, merged into one dict.
    """
    pool_options, connect_options = _split_asyncmy_pool_config(connection_config)
    merged: dict[str, Any] = dict(connect_options)
    merged.update(pool_options)
    return merged
class AsyncmyDriverFeatures(TypedDict):
"""Asyncmy driver feature flags.
MySQL/MariaDB handle JSON natively, but custom serializers can be provided
for specialized use cases.
json_serializer: Custom JSON serializer function.
Defaults to sqlspec.utils.serializers.to_json.
Use for performance (orjson) or custom encoding.
json_deserializer: Custom JSON deserializer function.
Defaults to sqlspec.utils.serializers.from_json.
Use for performance (orjson) or custom decoding.
on_connection_create: Async callback executed when a connection is acquired from pool.
Receives the raw asyncmy connection for low-level driver configuration.
Called exactly once per physical connection using WeakSet tracking.
enable_events: Enable database event channel support.
Defaults to True when extension_config["events"] is configured.
Provides pub/sub capabilities via table-backed queue (MySQL/MariaDB have no native pub/sub).
Requires extension_config["events"] for migration setup.
events_backend: Event channel backend selection.
Only option: "table_queue" (durable table-backed queue with retries and exactly-once delivery).
MySQL/MariaDB do not have native pub/sub, so table_queue is the only backend.
Defaults to "table_queue".
"""
json_serializer: NotRequired["Callable[[Any], str]"]
json_deserializer: NotRequired["Callable[[str], Any]"]
on_connection_create: "NotRequired[Callable[[AsyncmyConnection], Awaitable[None]]]"
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
class _AsyncmySessionFactory(AsyncPoolSessionFactory):
    """Session factory that checks connections out of the asyncmy pool.

    The pool's acquire context manager is kept in ``_ctx`` between
    ``acquire_connection`` and ``release_connection`` so the connection can be
    handed back through the same object that produced it.
    """

    __slots__ = ("_ctx",)

    def __init__(self, config: "AsyncmyConfig") -> None:
        super().__init__(config)
        self._ctx: Any | None = None

    async def acquire_connection(self) -> "AsyncmyConnection":
        """Acquire a pooled connection, creating the pool on first use.

        Returns:
            The acquired connection, after the config's connection-initialization
            step has run for it.

        Raises:
            BaseException: Whatever the initialization step raises; the
                connection is handed back to the pool before re-raising.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            self._config.connection_instance = pool
        ctx = pool.acquire()
        connection = cast("AsyncmyConnection", await ctx.__aenter__())
        try:
            await self._config._ensure_connection_initialized(connection)  # pyright: ignore[reportPrivateUsage]
        except BaseException as exc:
            # The caller never receives the connection, so nobody else can
            # release it: exit the acquire context here to avoid a pool leak.
            await ctx.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        # Record the context only once it is entered and usable, so a failed
        # acquire never leaves a stale context for release_connection.
        self._ctx = ctx
        return connection

    async def release_connection(self, _conn: "AsyncmyConnection", **kwargs: Any) -> None:
        """Return the previously acquired connection to the pool, if any."""
        ctx = self._ctx
        if ctx is None:
            return
        # Clear first so a failing release cannot be attempted twice.
        self._ctx = None
        await ctx.__aexit__(None, None, None)
class AsyncmyConnectionContext(AsyncPoolConnectionContext):
    """Async context manager for Asyncmy connections.

    Wraps the pool's own acquire context manager, creating the pool on first
    use and running the config's connection-initialization step on entry.
    """

    __slots__ = ("_ctx",)

    def __init__(self, config: "AsyncmyConfig") -> None:
        super().__init__(config)
        self._ctx: Any = None

    async def __aenter__(self) -> AsyncmyConnection:
        """Acquire and initialize a pooled connection.

        Returns:
            The acquired connection.

        Raises:
            BaseException: Whatever the initialization step raises; the
                connection is handed back to the pool before re-raising.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            self._config.connection_instance = pool
        ctx = pool.acquire()
        connection = cast("AsyncmyConnection", await ctx.__aenter__())
        try:
            await self._config._ensure_connection_initialized(connection)  # pyright: ignore[reportPrivateUsage]
        except BaseException as exc:
            # ``async with`` does not call __aexit__ when __aenter__ raises, so
            # the pooled connection must be released here.
            await ctx.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        # Record the context only once it has been entered successfully.
        self._ctx = ctx
        return connection

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> bool | None:
        """Release the connection by exiting the pool's acquire context.

        Returns:
            Whatever the wrapped context manager's ``__aexit__`` returns, or
            ``None`` when no connection is currently held.
        """
        ctx = self._ctx
        if ctx is None:
            return None
        # Forget the context before exiting it so it is never exited twice.
        self._ctx = None
        return cast("bool | None", await ctx.__aexit__(exc_type, exc_val, exc_tb))
@mypyc_attr(native_class=False)
class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", AsyncmyDriver]):  # pyright: ignore
    """Configuration for Asyncmy database connections.

    Holds the normalized connection/pool settings, creates and closes the
    asyncmy pool, and runs an optional user-supplied ``on_connection_create``
    hook for each connection it hands out.
    """

    driver_type: ClassVar[type[AsyncmyDriver]] = AsyncmyDriver
    connection_type: "ClassVar[type[Any]]" = cast("type[Any]", AsyncmyConnection)
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_import: ClassVar[bool] = True
    supports_native_row_streaming: ClassVar[bool] = True
    _connection_context_class: "ClassVar[type[AsyncmyConnectionContext]]" = AsyncmyConnectionContext
    _session_factory_class: "ClassVar[type[_AsyncmySessionFactory]]" = _AsyncmySessionFactory
    _session_context_class: "ClassVar[type[AsyncmySessionContext]]" = AsyncmySessionContext
    _default_statement_config = default_statement_config

    def __init__(
        self,
        *,
        connection_config: "AsyncmyPoolParams | dict[str, Any] | None" = None,
        connection_instance: "AsyncmyPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "AsyncmyDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Asyncmy configuration.

        Args:
            connection_config: Connection and pool configuration parameters.
                ``host`` defaults to ``"localhost"`` and ``port`` to ``3306``
                when they are not supplied.
            connection_instance: Existing pool instance to use
            migration_config: Migration configuration
            statement_config: Statement configuration override
            driver_features: Driver feature configuration (TypedDict or dict)
            bind_key: Optional unique identifier for this configuration
            extension_config: Extension-specific configuration
            observability_config: Adapter-level observability overrides for lifecycle hooks and observers
            **kwargs: Additional keyword arguments

        Raises:
            ImproperConfigurationError: If the connection configuration is
                rejected during normalization (conflicting cursor class keys,
                or ``local_infile`` without ``allow_local_infile``).
        """
        connection_config = _normalize_asyncmy_connection_config(connection_config)
        connection_config.setdefault("host", "localhost")
        connection_config.setdefault("port", 3306)

        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)

        # Extract user connection hook before storing driver_features
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: "Callable[[AsyncmyConnection], Awaitable[None]] | None" = features_dict.pop(
            "on_connection_create", None
        )
        # Track initialized connections to ensure callback runs exactly once per physical connection
        self._initialized_connections: WeakSet[Any] = WeakSet()

        super().__init__(
            connection_config=connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    async def _create_pool(self) -> "AsyncmyPool":
        """Create the actual async connection pool.

        MySQL/MariaDB handle JSON types natively without requiring connection-level
        type handlers. JSON serialization is handled via type_coercion_map in the
        driver's statement_config (see driver.py).

        Future driver_features can be added here if needed.
        """
        return cast("AsyncmyPool", await asyncmy.create_pool(**_build_asyncmy_pool_config(self.connection_config)))

    async def _ensure_connection_initialized(self, connection: "AsyncmyConnection") -> None:
        """Ensure connection callback has been called exactly once for this connection.

        Uses WeakSet tracking to ensure the callback runs once per physical connection.
        The connection is recorded only after the callback completes, so a
        callback that raises is attempted again the next time the same
        connection is handed out.
        """
        if self._user_connection_hook is None:
            return
        if connection not in self._initialized_connections:
            await self._user_connection_hook(connection)
            self._initialized_connections.add(connection)

    async def _close_pool(self) -> None:
        """Close the actual async connection pool."""
        pool = self.connection_instance
        # Explicit None check, matching the pool checks elsewhere in this
        # module, so closing never depends on the pool object's truthiness.
        if pool is None:
            return
        pool.close()
        await pool.wait_closed()
        self.connection_instance = None

    async def create_connection(self) -> AsyncmyConnection:
        """Acquire a single async connection from the pool.

        The pool is created on first use. The connection is checked out
        directly rather than through a context manager, and nothing in this
        method releases it on success — presumably the caller is expected to
        hand it back to the pool when done.

        Returns:
            An Asyncmy connection instance.

        Raises:
            BaseException: Whatever the connection-initialization hook raises;
                the connection is released back to the pool before re-raising.
        """
        pool = self.connection_instance
        if pool is None:
            pool = await self.create_pool()
            self.connection_instance = pool
        connection = cast("AsyncmyConnection", await pool.acquire())
        try:
            await self._ensure_connection_initialized(connection)
        except BaseException:
            # The caller never sees this connection, so release it here
            # instead of leaving it checked out forever.
            await pool.release(connection)
            raise
        return connection

    async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncmyPool":
        """Provide async pool instance.

        Returns:
            The async connection pool.
        """
        if self.connection_instance is None:
            self.connection_instance = await self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for Asyncmy types.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "AsyncmyConnectionContext": AsyncmyConnectionContext,
            "AsyncmyConnection": AsyncmyConnection,
            "AsyncmyConnectionParams": AsyncmyConnectionParams,
            "AsyncmyCursor": AsyncmyCursor,
            "AsyncmyDictCursor": AsyncmyDictCursor,
            "AsyncmyDriver": AsyncmyDriver,
            "AsyncmyDriverFeatures": AsyncmyDriverFeatures,
            "AsyncmyExceptionHandler": AsyncmyExceptionHandler,
            "AsyncmyPool": AsyncmyPool,
            "AsyncmyPoolParams": AsyncmyPoolParams,
            "AsyncmyRawCursor": AsyncmyRawCursor,
            "AsyncmySessionContext": AsyncmySessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return queue polling defaults for Asyncmy adapters."""
        return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True)