Skip to content

feat(redis): add TTL support for Redis checkpoint storage (#27) #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe
# ... rest of the implementation follows similar pattern
```

## Redis Checkpoint TTL Support

Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration:

```python
# Configure TTL for checkpoint savers
ttl_config = {
"default_ttl": 60, # Default TTL in minutes
"refresh_on_read": True, # Refresh TTL when checkpoint is read
}

# Use with any checkpoint saver implementation
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as checkpointer:
checkpointer.setup()
# Use the checkpointer...
```

This makes it easy to manage storage usage and ensures that ephemeral data is cleaned up automatically.

## Redis Stores

Redis Stores provide a persistent key-value store with optional vector search capabilities.
Expand All @@ -169,9 +188,19 @@ index_config = {
"fields": ["text"], # Fields to index
}

with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) as store:
# With TTL configuration
ttl_config = {
"default_ttl": 60, # Default TTL in minutes
"refresh_on_read": True, # Refresh TTL when store entries are read
}

with RedisStore.from_conn_string(
"redis://localhost:6379",
index=index_config,
ttl=ttl_config
) as store:
store.setup()
# Use the store with vector search capabilities...
# Use the store with vector search and TTL capabilities...
```

### Async Implementation
Expand All @@ -180,7 +209,16 @@ with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) a
from langgraph.store.redis.aio import AsyncRedisStore

async def main():
async with AsyncRedisStore.from_conn_string("redis://localhost:6379") as store:
# TTL also works with async implementations
ttl_config = {
"default_ttl": 60, # Default TTL in minutes
"refresh_on_read": True, # Refresh TTL when store entries are read
}

async with AsyncRedisStore.from_conn_string(
"redis://localhost:6379",
ttl=ttl_config
) as store:
await store.setup()
# Use the store asynchronously...

Expand Down Expand Up @@ -235,6 +273,16 @@ For Redis Stores with vector search:
1. **Store Index**: Main key-value store
2. **Vector Index**: Optional vector embeddings for similarity search

### TTL Implementation

Both Redis checkpoint savers and stores leverage Redis's native key expiration:

- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command
- **Automatic Cleanup**: Redis automatically removes expired keys
- **Configurable Default TTL**: Set a default TTL, in minutes, that is applied to all keys
- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed
- **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes)

## Contributing

We welcome contributions! Here's how you can help:
Expand Down
50 changes: 42 additions & 8 deletions langgraph/checkpoint/redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,15 +253,16 @@ def put(
checkpoint_data["source"] = metadata["source"]
checkpoint_data["step"] = metadata["step"] # type: ignore

# Create the checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
)

self.checkpoints_index.load(
[checkpoint_data],
keys=[
BaseRedisSaver._make_redis_checkpoint_key(
storage_safe_thread_id,
storage_safe_checkpoint_ns,
storage_safe_checkpoint_id,
)
],
keys=[checkpoint_key],
)

# Store blob values.
Expand All @@ -272,10 +273,16 @@ def put(
new_versions,
)

blob_keys = []
if blobs:
# Unzip the list of tuples into separate lists for keys and data
keys, data = zip(*blobs)
self.checkpoint_blobs_index.load(list(data), keys=list(keys))
blob_keys = list(keys)
self.checkpoint_blobs_index.load(list(data), keys=blob_keys)

# Apply TTL to checkpoint and blob keys if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
self._apply_ttl_to_keys(checkpoint_key, blob_keys)

return next_config

Expand Down Expand Up @@ -332,6 +339,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
# Get the checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(doc_thread_id),
to_storage_safe_str(doc_checkpoint_ns),
to_storage_safe_id(doc_checkpoint_id),
)

# Get all blob keys related to this checkpoint
from langgraph.checkpoint.redis.base import (
CHECKPOINT_BLOB_PREFIX,
CHECKPOINT_WRITE_PREFIX,
)

# Get the blob keys
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]

# Also get checkpoint write keys that should have the same TTL
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)]

# Apply TTL to checkpoint, blob keys, and write keys
all_related_keys = blob_keys + write_keys
self._apply_ttl_to_keys(checkpoint_key, all_related_keys)

# Fetch channel_values
channel_values = self.get_channel_values(
thread_id=doc_thread_id,
Expand Down
69 changes: 69 additions & 0 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,45 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])

# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
# Get the checkpoint key
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
to_storage_safe_id(doc_thread_id),
to_storage_safe_str(doc_checkpoint_ns),
to_storage_safe_id(doc_checkpoint_id),
)

# Get all blob keys related to this checkpoint
from langgraph.checkpoint.redis.base import (
CHECKPOINT_BLOB_PREFIX,
CHECKPOINT_WRITE_PREFIX,
)

# Get the blob keys
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
blob_keys = await self._redis.keys(blob_key_pattern)
blob_keys = [key.decode() for key in blob_keys]

# Also get checkpoint write keys that should have the same TTL
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
write_keys = await self._redis.keys(write_key_pattern)
write_keys = [key.decode() for key in write_keys]

# Apply TTL to checkpoint, blob keys, and write keys
ttl_minutes = self.ttl_config.get("default_ttl")
if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
pipeline = self._redis.pipeline()
pipeline.expire(checkpoint_key, ttl_seconds)

# Combine blob keys and write keys for TTL refresh
all_related_keys = blob_keys + write_keys
for key in all_related_keys:
pipeline.expire(key, ttl_seconds)

await pipeline.execute()

# Fetch channel_values
channel_values = await self.aget_channel_values(
thread_id=doc_thread_id,
Expand Down Expand Up @@ -476,6 +515,22 @@ async def aput(
# Execute all operations atomically
await pipeline.execute()

# Apply TTL to checkpoint and blob keys if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
all_keys = (
[checkpoint_key] + [key for key, _ in blobs]
if blobs
else [checkpoint_key]
)
ttl_minutes = self.ttl_config.get("default_ttl")
ttl_seconds = int(ttl_minutes * 60)

# Use a new pipeline for TTL operations
ttl_pipeline = self._redis.pipeline()
for key in all_keys:
ttl_pipeline.expire(key, ttl_seconds)
await ttl_pipeline.execute()

return next_config

except asyncio.CancelledError:
Expand Down Expand Up @@ -575,6 +630,7 @@ async def aput_writes(

# Determine if this is an upsert case
upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
created_keys = []

# Add all write operations to the pipeline
for write_obj in writes_objects:
Expand All @@ -599,15 +655,28 @@ async def aput_writes(
else:
# Create new key
await pipeline.json().set(key, "$", write_obj)
created_keys.append(key)
else:
# For non-upsert case, only set if key doesn't exist
exists = await self._redis.exists(key)
if not exists:
await pipeline.json().set(key, "$", write_obj)
created_keys.append(key)

# Execute all operations atomically
await pipeline.execute()

# Apply TTL to newly created keys
if created_keys and self.ttl_config and "default_ttl" in self.ttl_config:
ttl_minutes = self.ttl_config.get("default_ttl")
ttl_seconds = int(ttl_minutes * 60)

# Use a new pipeline for TTL operations
ttl_pipeline = self._redis.pipeline()
for key in created_keys:
ttl_pipeline.expire(key, ttl_seconds)
await ttl_pipeline.execute()

except asyncio.CancelledError:
# Handle cancellation/interruption
# Pipeline will be automatically discarded
Expand Down
49 changes: 49 additions & 0 deletions langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ def __init__(
*,
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
ttl: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
)
self.loop = asyncio.get_running_loop()

Expand Down Expand Up @@ -149,12 +151,14 @@ async def from_conn_string(
*,
redis_client: Optional[AsyncRedis] = None,
connection_args: Optional[dict[str, Any]] = None,
ttl: Optional[dict[str, Any]] = None,
) -> AsyncIterator[AsyncShallowRedisSaver]:
"""Create a new AsyncShallowRedisSaver instance."""
async with cls(
redis_url=redis_url,
redis_client=redis_client,
connection_args=connection_args,
ttl=ttl,
) as saver:
yield saver

Expand Down Expand Up @@ -279,6 +283,22 @@ async def aput(
# Execute all operations atomically
await pipeline.execute()

# Apply TTL to checkpoint and blob keys if configured
if self.ttl_config and "default_ttl" in self.ttl_config:
# Prepare the list of keys to apply TTL
ttl_keys = [checkpoint_key]
if blobs:
ttl_keys.extend([key for key, _ in blobs])

# Apply TTL to all keys
ttl_minutes = self.ttl_config.get("default_ttl")
ttl_seconds = int(ttl_minutes * 60)

ttl_pipeline = self._redis.pipeline()
for key in ttl_keys:
ttl_pipeline.expire(key, ttl_seconds)
await ttl_pipeline.execute()

return next_config

except asyncio.CancelledError:
Expand Down Expand Up @@ -389,6 +409,35 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:

doc = results.docs[0]

# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
thread_id = getattr(doc, "thread_id", "")
checkpoint_ns = getattr(doc, "checkpoint_ns", "")

# Get the checkpoint key
checkpoint_key = AsyncShallowRedisSaver._make_shallow_redis_checkpoint_key(
thread_id, checkpoint_ns
)

# Get all blob keys related to this checkpoint
blob_key_pattern = (
AsyncShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern(
thread_id, checkpoint_ns
)
)
blob_keys = await self._redis.keys(blob_key_pattern)
blob_keys = [key.decode() for key in blob_keys]

# Apply TTL
ttl_minutes = self.ttl_config.get("default_ttl")
if ttl_minutes is not None:
ttl_seconds = int(ttl_minutes * 60)
pipeline = self._redis.pipeline()
pipeline.expire(checkpoint_key, ttl_seconds)
for key in blob_keys:
pipeline.expire(key, ttl_seconds)
await pipeline.execute()

checkpoint = json.loads(doc["$.checkpoint"])

# Fetch channel_values
Expand Down
Loading