Skip to content

refactor: align types #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, List, Optional, Tuple, cast
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand Down Expand Up @@ -42,19 +41,21 @@ def __init__(
redis_url: Optional[str] = None,
*,
redis_client: Optional[Redis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
)

def configure_client(
self,
redis_url: Optional[str] = None,
redis_client: Optional[Redis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Configure the Redis client."""
self._owns_its_client = redis_client is None
Expand Down Expand Up @@ -395,7 +396,8 @@ def from_conn_string(
redis_url: Optional[str] = None,
*,
redis_client: Optional[Redis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> Iterator[RedisSaver]:
"""Create a new RedisSaver instance."""
saver: Optional[RedisSaver] = None
Expand All @@ -404,6 +406,7 @@ def from_conn_string(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
)

yield saver
Expand All @@ -414,7 +417,7 @@ def from_conn_string(

def get_channel_values(
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""Retrieve channel_values dictionary with properly constructed message objects."""
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
Expand Down
19 changes: 11 additions & 8 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import asyncio
import json
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from types import TracebackType
from typing import Any, List, Optional, Sequence, Tuple, Type, cast
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Type, cast

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand Down Expand Up @@ -42,7 +41,7 @@
async def _write_obj_tx(
pipe: Pipeline,
key: str,
write_obj: dict[str, Any],
write_obj: Dict[str, Any],
upsert_case: bool,
) -> None:
exists: int = await pipe.exists(key)
Expand Down Expand Up @@ -73,20 +72,22 @@ def __init__(
redis_url: Optional[str] = None,
*,
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
)
self.loop = asyncio.get_running_loop()

def configure_client(
self,
redis_url: Optional[str] = None,
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Configure the Redis client."""
self._owns_its_client = redis_client is None
Expand Down Expand Up @@ -706,18 +707,20 @@ async def from_conn_string(
redis_url: Optional[str] = None,
*,
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[AsyncRedisSaver]:
async with cls(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
) as saver:
yield saver

async def aget_channel_values(
self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = ""
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""Retrieve channel_values dictionary with properly constructed message objects."""
storage_safe_thread_id = to_storage_safe_id(thread_id)
storage_safe_checkpoint_ns = to_storage_safe_str(checkpoint_ns)
Expand Down Expand Up @@ -767,7 +770,7 @@ async def aget_channel_values(

async def _aload_pending_sends(
self, thread_id: str, checkpoint_ns: str = "", parent_checkpoint_id: str = ""
) -> list[tuple[str, bytes]]:
) -> List[Tuple[str, bytes]]:
"""Load pending sends for a parent checkpoint.

Args:
Expand Down
60 changes: 49 additions & 11 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import json
import random
from abc import abstractmethod
from collections.abc import Sequence
from typing import Any, Generic, List, Optional, cast
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, cast

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
Expand Down Expand Up @@ -100,12 +99,16 @@ def __init__(
redis_url: Optional[str] = None,
*,
redis_client: Optional[RedisClientType] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
ttl: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(serde=JsonPlusRedisSerializer())
if redis_url is None and redis_client is None:
raise ValueError("Either redis_url or redis_client must be provided")

# Store TTL configuration
self.ttl_config = ttl

self.configure_client(
redis_url=redis_url,
redis_client=redis_client,
Expand All @@ -128,7 +131,7 @@ def configure_client(
self,
redis_url: Optional[str] = None,
redis_client: Optional[RedisClientType] = None,
connection_args: Optional[dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Configure the Redis client."""
pass
Expand Down Expand Up @@ -180,11 +183,46 @@ def setup(self) -> None:
self.checkpoint_blobs_index.create(overwrite=False)
self.checkpoint_writes_index.create(overwrite=False)

def _apply_ttl_to_keys(
self,
main_key: str,
related_keys: Optional[List[str]] = None,
ttl_minutes: Optional[float] = None,
) -> Any:
"""Apply Redis native TTL to keys.

Args:
main_key: The primary Redis key
related_keys: Additional Redis keys that should expire at the same time
ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided

Returns:
Result of the Redis operation
"""
if ttl_minutes is None:
# Check if there's a default TTL in config
if self.ttl_config and "default_ttl" in self.ttl_config:
ttl_minutes = self.ttl_config.get("default_ttl")

if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
pipeline = self._redis.pipeline()

# Set TTL for main key
pipeline.expire(main_key, ttl_seconds)

# Set TTL for related keys
if related_keys:
for key in related_keys:
pipeline.expire(key, ttl_seconds)

return pipeline.execute()

def _load_checkpoint(
self,
checkpoint: dict[str, Any],
channel_values: dict[str, Any],
pending_sends: list[Any],
checkpoint: Dict[str, Any],
channel_values: Dict[str, Any],
pending_sends: List[Any],
) -> Checkpoint:
if not checkpoint:
return {}
Expand Down Expand Up @@ -218,7 +256,7 @@ def _load_blobs(self, blob_values: dict[str, Any]) -> dict[str, Any]:
if v["type"] != "empty"
}

def _get_type_and_blob(self, value: Any) -> tuple[str, Optional[bytes]]:
def _get_type_and_blob(self, value: Any) -> Tuple[str, Optional[bytes]]:
"""Helper to get type and blob from a value."""
t, b = self.serde.dumps_typed(value)
return t, b
Expand All @@ -227,9 +265,9 @@ def _dump_blobs(
self,
thread_id: str,
checkpoint_ns: str,
values: dict[str, Any],
values: Dict[str, Any],
versions: ChannelVersions,
) -> list[tuple[str, dict[str, Any]]]:
) -> List[Tuple[str, Dict[str, Any]]]:
"""Convert blob data for Redis storage."""
if not versions:
return []
Expand Down Expand Up @@ -337,7 +375,7 @@ def _decode_blob(self, blob: str) -> bytes:
# Handle both malformed base64 data and incorrect input types
return blob.encode() if isinstance(blob, str) else blob

def _load_writes_from_redis(self, write_key: str) -> list[tuple[str, str, Any]]:
def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
"""Load writes from Redis JSON storage by key."""
if not write_key:
return []
Expand Down