Skip to content

SQLAlchemySession

Bases: SessionABC

SQLAlchemy implementation of `agents.memory.session.Session`.

Source code in src/agents/extensions/memory/sqlalchemy_session.py
class SQLAlchemySession(SessionABC):
    """SQLAlchemy implementation of :py:class:`agents.memory.session.Session`.

    Each conversation item is serialized to a JSON string and stored as one row of
    the messages table, keyed by ``session_id``. A matching row in the sessions
    table tracks ``created_at``/``updated_at`` for the conversation.

    Items are ordered by ``created_at`` with the autoincrement ``id`` as a
    tie-breaker. The timestamps come from ``CURRENT_TIMESTAMP``, which can give
    several rows the same value: for example, every row written by one
    ``add_items`` transaction on PostgreSQL, or rows written within the same
    second on SQLite. ``created_at`` alone therefore does not define a stable order.
    """

    _metadata: MetaData
    _sessions: Table
    _messages: Table

    def __init__(
        self,
        session_id: str,
        *,
        engine: AsyncEngine,
        create_tables: bool = False,
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initializes a new SQLAlchemySession.

        Args:
            session_id (str): Unique identifier for the conversation.
            engine (AsyncEngine): A pre-configured SQLAlchemy async engine. The engine
                must be created with an async driver (e.g., 'postgresql+asyncpg://',
                'mysql+aiomysql://', or 'sqlite+aiosqlite://').
            create_tables (bool, optional): Whether to automatically create the required
                tables and indexes. Defaults to False for production use. Set to True for
                development and testing when migrations aren't used.
            sessions_table (str, optional): Override the default table name for sessions if needed.
            messages_table (str, optional): Override the default table name for messages if needed.
        """
        self.session_id = session_id
        self._engine = engine
        # Serializes one-time table creation in _ensure_tables.
        self._lock = asyncio.Lock()

        self._metadata = MetaData()
        self._sessions = Table(
            sessions_table,
            self._metadata,
            Column("session_id", String, primary_key=True),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            Column(
                "updated_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                onupdate=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
        )

        self._messages = Table(
            messages_table,
            self._metadata,
            Column("id", Integer, primary_key=True, autoincrement=True),
            Column(
                "session_id",
                String,
                ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
                nullable=False,
            ),
            Column("message_data", Text, nullable=False),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            Index(
                f"idx_{messages_table}_session_time",
                "session_id",
                "created_at",
            ),
            # AUTOINCREMENT keeps SQLite from reusing ids of deleted rows, so ``id``
            # stays monotonic and usable as an ordering tie-breaker.
            sqlite_autoincrement=True,
        )

        # Async session factory
        self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

        self._create_tables = create_tables

    # ---------------------------------------------------------------------
    # Convenience constructors
    # ---------------------------------------------------------------------
    @classmethod
    def from_url(
        cls,
        session_id: str,
        *,
        url: str,
        engine_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> SQLAlchemySession:
        """Create a session from a database URL string.

        Args:
            session_id (str): Conversation ID.
            url (str): Any SQLAlchemy async URL, e.g. "postgresql+asyncpg://user:pass@host/db".
            engine_kwargs (dict[str, Any] | None): Additional keyword arguments forwarded to
                sqlalchemy.ext.asyncio.create_async_engine.
            **kwargs: Additional keyword arguments forwarded to the main constructor
                (e.g., create_tables, custom table names, etc.).

        Returns:
            SQLAlchemySession: An instance of SQLAlchemySession configured for the specified
            database.
        """
        engine_kwargs = engine_kwargs or {}
        engine = create_async_engine(url, **engine_kwargs)
        return cls(session_id, engine=engine, **kwargs)

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, item: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(item)  # type: ignore[no-any-return]

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------
    async def _ensure_tables(self) -> None:
        """Create the tables on first use when ``create_tables=True`` was requested.

        The flag is re-checked under ``self._lock`` so that concurrent first calls
        on this instance run ``create_all`` only once.
        """
        if not self._create_tables:
            return  # Fast path: nothing to do, no lock needed.
        async with self._lock:
            if not self._create_tables:
                return  # Another task finished creation while we waited.
            async with self._engine.begin() as conn:
                await conn.run_sync(self._metadata.create_all)
            self._create_tables = False  # Only create once

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                   When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history, oldest first.
            Rows whose deserialization raises ``json.JSONDecodeError`` are skipped.
        """
        await self._ensure_tables()
        msgs = self._messages
        stmt = select(msgs.c.message_data).where(msgs.c.session_id == self.session_id)
        if limit is None:
            stmt = stmt.order_by(msgs.c.created_at.asc(), msgs.c.id.asc())
        else:
            # Newest-first + LIMIT selects the latest N; the rows are reversed
            # below to restore chronological order. ``id`` breaks timestamp ties.
            stmt = stmt.order_by(msgs.c.created_at.desc(), msgs.c.id.desc()).limit(limit)

        async with self._session_factory() as sess:
            result = await sess.execute(stmt)
            rows: list[str] = [row[0] for row in result.all()]

        if limit is not None:
            rows.reverse()

        items: list[TResponseInputItem] = []
        for raw in rows:
            try:
                items.append(await self._deserialize_item(raw))
            except json.JSONDecodeError:
                # Skip corrupted rows rather than failing the whole read.
                continue
        return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append new items to the conversation history.

        A single transaction does three things: creates the parent session row if
        it is missing, bulk-inserts all items, and refreshes the session's
        ``updated_at``.

        Args:
            items: List of input items to add to the history. An empty list is a no-op.
        """
        if not items:
            return

        await self._ensure_tables()
        payload = [
            {
                "session_id": self.session_id,
                "message_data": await self._serialize_item(item),
            }
            for item in items
        ]

        async with self._session_factory() as sess:
            async with sess.begin():
                # Ensure the parent session row exists. A select-then-insert is used
                # instead of a dialect-specific upsert so this works on every backend.
                # NOTE(review): this is not atomic; concurrent first writes for the
                # same session_id may both attempt the insert.
                existing = await sess.execute(
                    select(self._sessions.c.session_id).where(
                        self._sessions.c.session_id == self.session_id
                    )
                )
                # Compare with None explicitly: a falsy id such as "" is still an
                # existing key and must not be inserted a second time.
                if existing.scalar_one_or_none() is None:
                    await sess.execute(
                        insert(self._sessions).values({"session_id": self.session_id})
                    )

                # Insert all messages in one bulk statement.
                await sess.execute(insert(self._messages), payload)

                # Touch updated_at column
                await sess.execute(
                    update(self._sessions)
                    .where(self._sessions.c.session_id == self.session_id)
                    .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
                )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If the
            stored payload raises ``json.JSONDecodeError`` on deserialization, the
            row is still deleted and None is returned.
        """
        await self._ensure_tables()
        msgs = self._messages
        async with self._session_factory() as sess:
            async with sess.begin():
                # Fetch the newest row's id and payload in one round trip; ``id``
                # breaks ties between rows sharing a created_at value.
                res = await sess.execute(
                    select(msgs.c.id, msgs.c.message_data)
                    .where(msgs.c.session_id == self.session_id)
                    .order_by(msgs.c.created_at.desc(), msgs.c.id.desc())
                    .limit(1)
                )
                newest = res.first()
                if newest is None:
                    return None
                row_id, raw = newest
                await sess.execute(delete(msgs).where(msgs.c.id == row_id))

                # Decoding happens inside the transaction, so any exception other
                # than JSONDecodeError (e.g. from a subclass override) rolls the
                # delete back.
                try:
                    return await self._deserialize_item(raw)
                except json.JSONDecodeError:
                    return None

    async def clear_session(self) -> None:
        """Clear all items for this session and remove its session row."""
        await self._ensure_tables()
        async with self._session_factory() as sess:
            async with sess.begin():
                # Delete messages explicitly rather than relying on ON DELETE CASCADE,
                # which some backends (e.g. SQLite without foreign-key enforcement) ignore.
                await sess.execute(
                    delete(self._messages).where(self._messages.c.session_id == self.session_id)
                )
                await sess.execute(
                    delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
                )

__init__

__init__(
    session_id: str,
    *,
    engine: AsyncEngine,
    create_tables: bool = False,
    sessions_table: str = "agent_sessions",
    messages_table: str = "agent_messages",
)

Initializes a new SQLAlchemySession.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session_id` | `str` | Unique identifier for the conversation. | *required* |
| `engine` | `AsyncEngine` | A pre-configured SQLAlchemy async engine. The engine must be created with an async driver (e.g., `postgresql+asyncpg://`, `mysql+aiomysql://`, or `sqlite+aiosqlite://`). | *required* |
| `create_tables` | `bool` | Whether to automatically create the required tables and indexes. Defaults to `False` for production use. Set to `True` for development and testing when migrations aren't used. | `False` |
| `sessions_table` | `str` | Override the default table name for sessions if needed. | `'agent_sessions'` |
| `messages_table` | `str` | Override the default table name for messages if needed. | `'agent_messages'` |
Source code in src/agents/extensions/memory/sqlalchemy_session.py
def __init__(
    self,
    session_id: str,
    *,
    engine: AsyncEngine,
    create_tables: bool = False,
    sessions_table: str = "agent_sessions",
    messages_table: str = "agent_messages",
):
    """Set up table metadata and an async session factory for one conversation.

    No database I/O happens here. The ``create_tables`` flag is only recorded,
    so table creation is deferred until the session is first used.

    Args:
        session_id (str): Identifier of the conversation this object reads and writes.
        engine (AsyncEngine): An already-configured SQLAlchemy async engine. It must use
            an async driver such as 'postgresql+asyncpg://', 'mysql+aiomysql://', or
            'sqlite+aiosqlite://'.
        create_tables (bool, optional): When True, create the tables and index on first
            use. Leave False (the default) when the schema is managed by migrations.
        sessions_table (str, optional): Name of the sessions table ("agent_sessions").
        messages_table (str, optional): Name of the messages table ("agent_messages").
    """
    self.session_id = session_id
    self._engine = engine
    self._create_tables = create_tables
    self._lock = asyncio.Lock()
    self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

    metadata = MetaData()
    self._metadata = metadata

    def _now():
        # Fresh server-side "current time" clause for each column default.
        return sql_text("CURRENT_TIMESTAMP")

    self._sessions = Table(
        sessions_table,
        metadata,
        Column("session_id", String, primary_key=True),
        Column("created_at", TIMESTAMP(timezone=False), server_default=_now(), nullable=False),
        Column(
            "updated_at",
            TIMESTAMP(timezone=False),
            server_default=_now(),
            onupdate=_now(),
            nullable=False,
        ),
    )

    parent_fk = ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE")
    self._messages = Table(
        messages_table,
        metadata,
        Column("id", Integer, primary_key=True, autoincrement=True),
        Column("session_id", String, parent_fk, nullable=False),
        Column("message_data", Text, nullable=False),
        Column("created_at", TIMESTAMP(timezone=False), server_default=_now(), nullable=False),
        Index(f"idx_{messages_table}_session_time", "session_id", "created_at"),
        sqlite_autoincrement=True,
    )

from_url classmethod

from_url(
    session_id: str,
    *,
    url: str,
    engine_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> SQLAlchemySession

Create a session from a database URL string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session_id` | `str` | Conversation ID. | *required* |
| `url` | `str` | Any SQLAlchemy async URL, e.g. `"postgresql+asyncpg://user:pass@host/db"`. | *required* |
| `engine_kwargs` | `dict[str, Any] \| None` | Additional keyword arguments forwarded to `sqlalchemy.ext.asyncio.create_async_engine`. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments forwarded to the main constructor (e.g., `create_tables`, custom table names). | `{}` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `SQLAlchemySession` | `SQLAlchemySession` | An instance of SQLAlchemySession configured to use the specified database. |

Source code in src/agents/extensions/memory/sqlalchemy_session.py
@classmethod
def from_url(
    cls,
    session_id: str,
    *,
    url: str,
    engine_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> SQLAlchemySession:
    """Build a session whose async engine is created from a database URL.

    Args:
        session_id (str): Conversation ID.
        url (str): SQLAlchemy async URL such as "postgresql+asyncpg://user:pass@host/db".
        engine_kwargs (dict[str, Any] | None): Extra keyword arguments passed through to
            sqlalchemy.ext.asyncio.create_async_engine; None means no extras.
        **kwargs: Extra keyword arguments passed through to the constructor
            (for example create_tables or custom table names).

    Returns:
        SQLAlchemySession: A new instance bound to the freshly created engine.
    """
    return cls(session_id, engine=create_async_engine(url, **(engine_kwargs or {})), **kwargs)

get_items async

get_items(
    limit: int | None = None,
) -> list[TResponseInputItem]

Retrieve the conversation history for this session.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `limit` | `int \| None` | Maximum number of items to retrieve. If `None`, retrieves all items. When specified, returns the latest N items in chronological order. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[TResponseInputItem]` | List of input items representing the conversation history. |

Source code in src/agents/extensions/memory/sqlalchemy_session.py
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
    """Retrieve the conversation history for this session.

    Rows are ordered by ``created_at`` and then by the autoincrement ``id``.
    ``CURRENT_TIMESTAMP`` defaults can give several rows the same timestamp
    (e.g. all rows of one add_items transaction), so without the ``id``
    tie-breaker the order of those rows would be undefined.

    Args:
        limit: Maximum number of items to retrieve. If None, retrieves all items.
               When specified, returns the latest N items in chronological order.

    Returns:
        List of input items representing the conversation history, oldest first.
        Rows whose deserialization raises ``json.JSONDecodeError`` are skipped.
    """
    await self._ensure_tables()
    msgs = self._messages
    stmt = select(msgs.c.message_data).where(msgs.c.session_id == self.session_id)
    if limit is None:
        stmt = stmt.order_by(msgs.c.created_at.asc(), msgs.c.id.asc())
    else:
        # Newest-first + LIMIT selects the latest N; the rows are reversed
        # below to restore chronological order.
        stmt = stmt.order_by(msgs.c.created_at.desc(), msgs.c.id.desc()).limit(limit)

    async with self._session_factory() as sess:
        result = await sess.execute(stmt)
        rows: list[str] = [row[0] for row in result.all()]

    if limit is not None:
        rows.reverse()

    items: list[TResponseInputItem] = []
    for raw in rows:
        try:
            items.append(await self._deserialize_item(raw))
        except json.JSONDecodeError:
            # Skip corrupted rows rather than failing the whole read.
            continue
    return items

add_items async

add_items(items: list[TResponseInputItem]) -> None

Add new items to the conversation history.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `items` | `list[TResponseInputItem]` | List of input items to add to the history. | *required* |
Source code in src/agents/extensions/memory/sqlalchemy_session.py
async def add_items(self, items: list[TResponseInputItem]) -> None:
    """Append new items to the conversation history.

    A single transaction does three things: creates the parent session row if
    it is missing, bulk-inserts all items, and refreshes the session's
    ``updated_at``.

    Args:
        items: List of input items to add to the history. An empty list is a no-op.
    """
    if not items:
        return

    await self._ensure_tables()
    payload = [
        {
            "session_id": self.session_id,
            "message_data": await self._serialize_item(item),
        }
        for item in items
    ]

    async with self._session_factory() as sess:
        async with sess.begin():
            # Ensure the parent session row exists. A select-then-insert is used
            # instead of a dialect-specific upsert so this works on every backend.
            # NOTE(review): this is not atomic; concurrent first writes for the
            # same session_id may both attempt the insert.
            existing = await sess.execute(
                select(self._sessions.c.session_id).where(
                    self._sessions.c.session_id == self.session_id
                )
            )
            # Compare with None explicitly: a falsy id such as "" is still an
            # existing key and must not be inserted a second time.
            if existing.scalar_one_or_none() is None:
                await sess.execute(
                    insert(self._sessions).values({"session_id": self.session_id})
                )

            # Insert all messages in one bulk statement.
            await sess.execute(insert(self._messages), payload)

            # Touch updated_at column
            await sess.execute(
                update(self._sessions)
                .where(self._sessions.c.session_id == self.session_id)
                .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
            )

pop_item async

pop_item() -> TResponseInputItem | None

Remove and return the most recent item from the session.

Returns:

| Type | Description |
| --- | --- |
| `TResponseInputItem \| None` | The most recent item if it exists, `None` if the session is empty. |

Source code in src/agents/extensions/memory/sqlalchemy_session.py
async def pop_item(self) -> TResponseInputItem | None:
    """Remove and return the most recent item from the session.

    "Most recent" means the greatest ``created_at``, with the autoincrement
    ``id`` breaking ties between rows that share a timestamp.

    Returns:
        The most recent item if it exists, None if the session is empty. If the
        stored payload raises ``json.JSONDecodeError`` on deserialization, the
        row is still deleted and None is returned.
    """
    await self._ensure_tables()
    msgs = self._messages
    async with self._session_factory() as sess:
        async with sess.begin():
            # Fetch the newest row's id and payload in one round trip.
            res = await sess.execute(
                select(msgs.c.id, msgs.c.message_data)
                .where(msgs.c.session_id == self.session_id)
                .order_by(msgs.c.created_at.desc(), msgs.c.id.desc())
                .limit(1)
            )
            newest = res.first()
            if newest is None:
                return None
            row_id, raw = newest
            await sess.execute(delete(msgs).where(msgs.c.id == row_id))

            # Decoding happens inside the transaction, so any exception other
            # than JSONDecodeError (e.g. from a subclass override) rolls the
            # delete back.
            try:
                return await self._deserialize_item(raw)
            except json.JSONDecodeError:
                return None

clear_session async

clear_session() -> None

Clear all items for this session.

Source code in src/agents/extensions/memory/sqlalchemy_session.py
async def clear_session(self) -> None:
    """Delete every stored item for this session, then its session row."""
    await self._ensure_tables()
    msgs, sessions = self._messages, self._sessions
    wipe_messages = delete(msgs).where(msgs.c.session_id == self.session_id)
    wipe_session = delete(sessions).where(sessions.c.session_id == self.session_id)
    async with self._session_factory() as sess, sess.begin():
        # Messages are deleted explicitly instead of relying on ON DELETE CASCADE,
        # which some backends (e.g. SQLite without foreign-key enforcement) ignore.
        await sess.execute(wipe_messages)
        await sess.execute(wipe_session)