Skip to content

Commit 6fed6ad

Browse files
committed
feat(redis): add TTL support for Redis checkpoint storage (#27)
Implements Time-To-Live (TTL) functionality for Redis-based checkpoint storage in LangGraph. This allows automatic expiration of checkpoint data after a configurable time period, helping to manage Redis memory usage.

Key features:

- Add TTL configuration option to RedisSaver and AsyncRedisSaver constructors
- Implement TTL refresh-on-read functionality to extend expiration when checkpoints are accessed
- Apply TTL consistently to all related keys (checkpoints, blobs, writes)
- Add comprehensive test suite for TTL functionality
- Update type hints to use explicit format (Dict, List, Tuple) for better clarity
1 parent 43f2b6b commit 6fed6ad

File tree

7 files changed

+668
-32
lines changed

7 files changed

+668
-32
lines changed

README.md

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,25 @@ with ShallowRedisSaver.from_conn_string("redis://localhost:6379") as checkpointe
148148
# ... rest of the implementation follows similar pattern
149149
```
150150

151+
## Redis Checkpoint TTL Support
152+
153+
Both Redis checkpoint savers and stores support Time-To-Live (TTL) functionality for automatic key expiration:
154+
155+
```python
156+
# Configure TTL for checkpoint savers
157+
ttl_config = {
158+
"default_ttl": 60, # Default TTL in minutes
159+
"refresh_on_read": True, # Refresh TTL when checkpoint is read
160+
}
161+
162+
# Use with any checkpoint saver implementation
163+
with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as checkpointer:
164+
checkpointer.setup()
165+
# Use the checkpointer...
166+
```
167+
168+
This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up.
169+
151170
## Redis Stores
152171

153172
Redis Stores provide a persistent key-value store with optional vector search capabilities.
@@ -169,9 +188,19 @@ index_config = {
169188
"fields": ["text"], # Fields to index
170189
}
171190

172-
with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) as store:
191+
# With TTL configuration
192+
ttl_config = {
193+
"default_ttl": 60, # Default TTL in minutes
194+
"refresh_on_read": True, # Refresh TTL when store entries are read
195+
}
196+
197+
with RedisStore.from_conn_string(
198+
"redis://localhost:6379",
199+
index=index_config,
200+
ttl=ttl_config
201+
) as store:
173202
store.setup()
174-
# Use the store with vector search capabilities...
203+
# Use the store with vector search and TTL capabilities...
175204
```
176205

177206
### Async Implementation
@@ -180,7 +209,16 @@ with RedisStore.from_conn_string("redis://localhost:6379", index=index_config) a
180209
from langgraph.store.redis.aio import AsyncRedisStore
181210

182211
async def main():
183-
async with AsyncRedisStore.from_conn_string("redis://localhost:6379") as store:
212+
# TTL also works with async implementations
213+
ttl_config = {
214+
"default_ttl": 60, # Default TTL in minutes
215+
"refresh_on_read": True, # Refresh TTL when store entries are read
216+
}
217+
218+
async with AsyncRedisStore.from_conn_string(
219+
"redis://localhost:6379",
220+
ttl=ttl_config
221+
) as store:
184222
await store.setup()
185223
# Use the store asynchronously...
186224

@@ -235,6 +273,16 @@ For Redis Stores with vector search:
235273
1. **Store Index**: Main key-value store
236274
2. **Vector Index**: Optional vector embeddings for similarity search
237275

276+
### TTL Implementation
277+
278+
Both Redis checkpoint savers and stores leverage Redis's native key expiration:
279+
280+
- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command
281+
- **Automatic Cleanup**: Redis automatically removes expired keys
282+
- **Configurable Default TTL**: Set a default TTL for all keys in minutes
283+
- **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed
284+
- **Applied to All Related Keys**: TTL is applied consistently to the checkpoint key and its associated blob and write keys
285+
238286
## Contributing
239287

240288
We welcome contributions! Here's how you can help:

langgraph/checkpoint/redis/__init__.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,16 @@ def put(
253253
checkpoint_data["source"] = metadata["source"]
254254
checkpoint_data["step"] = metadata["step"] # type: ignore
255255

256+
# Create the checkpoint key
257+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
258+
storage_safe_thread_id,
259+
storage_safe_checkpoint_ns,
260+
storage_safe_checkpoint_id,
261+
)
262+
256263
self.checkpoints_index.load(
257264
[checkpoint_data],
258-
keys=[
259-
BaseRedisSaver._make_redis_checkpoint_key(
260-
storage_safe_thread_id,
261-
storage_safe_checkpoint_ns,
262-
storage_safe_checkpoint_id,
263-
)
264-
],
265+
keys=[checkpoint_key],
265266
)
266267

267268
# Store blob values.
@@ -272,10 +273,16 @@ def put(
272273
new_versions,
273274
)
274275

276+
blob_keys = []
275277
if blobs:
276278
# Unzip the list of tuples into separate lists for keys and data
277279
keys, data = zip(*blobs)
278-
self.checkpoint_blobs_index.load(list(data), keys=list(keys))
280+
blob_keys = list(keys)
281+
self.checkpoint_blobs_index.load(list(data), keys=blob_keys)
282+
283+
# Apply TTL to checkpoint and blob keys if configured
284+
if self.ttl_config and "default_ttl" in self.ttl_config:
285+
self._apply_ttl_to_keys(checkpoint_key, blob_keys)
279286

280287
return next_config
281288

@@ -332,6 +339,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
332339
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
333340
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
334341

342+
# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
343+
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
344+
# Get the checkpoint key
345+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
346+
to_storage_safe_id(doc_thread_id),
347+
to_storage_safe_str(doc_checkpoint_ns),
348+
to_storage_safe_id(doc_checkpoint_id),
349+
)
350+
351+
# Get all blob keys related to this checkpoint
352+
from langgraph.checkpoint.redis.base import (
353+
CHECKPOINT_BLOB_PREFIX,
354+
CHECKPOINT_WRITE_PREFIX,
355+
)
356+
357+
# Get the blob keys
358+
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
359+
blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
360+
361+
# Also get checkpoint write keys that should have the same TTL
362+
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
363+
write_keys = [key.decode() for key in self._redis.keys(write_key_pattern)]
364+
365+
# Apply TTL to checkpoint, blob keys, and write keys
366+
all_related_keys = blob_keys + write_keys
367+
self._apply_ttl_to_keys(checkpoint_key, all_related_keys)
368+
335369
# Fetch channel_values
336370
channel_values = self.get_channel_values(
337371
thread_id=doc_thread_id,

langgraph/checkpoint/redis/aio.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,45 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
192192
doc_checkpoint_id = from_storage_safe_id(doc["checkpoint_id"])
193193
doc_parent_checkpoint_id = from_storage_safe_id(doc["parent_checkpoint_id"])
194194

195+
# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
196+
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
197+
# Get the checkpoint key
198+
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
199+
to_storage_safe_id(doc_thread_id),
200+
to_storage_safe_str(doc_checkpoint_ns),
201+
to_storage_safe_id(doc_checkpoint_id),
202+
)
203+
204+
# Get all blob keys related to this checkpoint
205+
from langgraph.checkpoint.redis.base import (
206+
CHECKPOINT_BLOB_PREFIX,
207+
CHECKPOINT_WRITE_PREFIX,
208+
)
209+
210+
# Get the blob keys
211+
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
212+
blob_keys = await self._redis.keys(blob_key_pattern)
213+
blob_keys = [key.decode() for key in blob_keys]
214+
215+
# Also get checkpoint write keys that should have the same TTL
216+
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
217+
write_keys = await self._redis.keys(write_key_pattern)
218+
write_keys = [key.decode() for key in write_keys]
219+
220+
# Apply TTL to checkpoint, blob keys, and write keys
221+
ttl_minutes = self.ttl_config.get("default_ttl")
222+
if ttl_minutes is not None:
223+
ttl_seconds = int(ttl_minutes * 60)
224+
pipeline = self._redis.pipeline()
225+
pipeline.expire(checkpoint_key, ttl_seconds)
226+
227+
# Combine blob keys and write keys for TTL refresh
228+
all_related_keys = blob_keys + write_keys
229+
for key in all_related_keys:
230+
pipeline.expire(key, ttl_seconds)
231+
232+
await pipeline.execute()
233+
195234
# Fetch channel_values
196235
channel_values = await self.aget_channel_values(
197236
thread_id=doc_thread_id,
@@ -476,6 +515,22 @@ async def aput(
476515
# Execute all operations atomically
477516
await pipeline.execute()
478517

518+
# Apply TTL to checkpoint and blob keys if configured
519+
if self.ttl_config and "default_ttl" in self.ttl_config:
520+
all_keys = (
521+
[checkpoint_key] + [key for key, _ in blobs]
522+
if blobs
523+
else [checkpoint_key]
524+
)
525+
ttl_minutes = self.ttl_config.get("default_ttl")
526+
ttl_seconds = int(ttl_minutes * 60)
527+
528+
# Use a new pipeline for TTL operations
529+
ttl_pipeline = self._redis.pipeline()
530+
for key in all_keys:
531+
ttl_pipeline.expire(key, ttl_seconds)
532+
await ttl_pipeline.execute()
533+
479534
return next_config
480535

481536
except asyncio.CancelledError:
@@ -575,6 +630,7 @@ async def aput_writes(
575630

576631
# Determine if this is an upsert case
577632
upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
633+
created_keys = []
578634

579635
# Add all write operations to the pipeline
580636
for write_obj in writes_objects:
@@ -599,15 +655,28 @@ async def aput_writes(
599655
else:
600656
# Create new key
601657
await pipeline.json().set(key, "$", write_obj)
658+
created_keys.append(key)
602659
else:
603660
# For non-upsert case, only set if key doesn't exist
604661
exists = await self._redis.exists(key)
605662
if not exists:
606663
await pipeline.json().set(key, "$", write_obj)
664+
created_keys.append(key)
607665

608666
# Execute all operations atomically
609667
await pipeline.execute()
610668

669+
# Apply TTL to newly created keys
670+
if created_keys and self.ttl_config and "default_ttl" in self.ttl_config:
671+
ttl_minutes = self.ttl_config.get("default_ttl")
672+
ttl_seconds = int(ttl_minutes * 60)
673+
674+
# Use a new pipeline for TTL operations
675+
ttl_pipeline = self._redis.pipeline()
676+
for key in created_keys:
677+
ttl_pipeline.expire(key, ttl_seconds)
678+
await ttl_pipeline.execute()
679+
611680
except asyncio.CancelledError:
612681
# Handle cancellation/interruption
613682
# Pipeline will be automatically discarded

langgraph/checkpoint/redis/ashallow.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,13 @@ def __init__(
108108
*,
109109
redis_client: Optional[AsyncRedis] = None,
110110
connection_args: Optional[dict[str, Any]] = None,
111+
ttl: Optional[dict[str, Any]] = None,
111112
) -> None:
112113
super().__init__(
113114
redis_url=redis_url,
114115
redis_client=redis_client,
115116
connection_args=connection_args,
117+
ttl=ttl,
116118
)
117119
self.loop = asyncio.get_running_loop()
118120

@@ -149,12 +151,14 @@ async def from_conn_string(
149151
*,
150152
redis_client: Optional[AsyncRedis] = None,
151153
connection_args: Optional[dict[str, Any]] = None,
154+
ttl: Optional[dict[str, Any]] = None,
152155
) -> AsyncIterator[AsyncShallowRedisSaver]:
153156
"""Create a new AsyncShallowRedisSaver instance."""
154157
async with cls(
155158
redis_url=redis_url,
156159
redis_client=redis_client,
157160
connection_args=connection_args,
161+
ttl=ttl,
158162
) as saver:
159163
yield saver
160164

@@ -279,6 +283,22 @@ async def aput(
279283
# Execute all operations atomically
280284
await pipeline.execute()
281285

286+
# Apply TTL to checkpoint and blob keys if configured
287+
if self.ttl_config and "default_ttl" in self.ttl_config:
288+
# Prepare the list of keys to apply TTL
289+
ttl_keys = [checkpoint_key]
290+
if blobs:
291+
ttl_keys.extend([key for key, _ in blobs])
292+
293+
# Apply TTL to all keys
294+
ttl_minutes = self.ttl_config.get("default_ttl")
295+
ttl_seconds = int(ttl_minutes * 60)
296+
297+
ttl_pipeline = self._redis.pipeline()
298+
for key in ttl_keys:
299+
ttl_pipeline.expire(key, ttl_seconds)
300+
await ttl_pipeline.execute()
301+
282302
return next_config
283303

284304
except asyncio.CancelledError:
@@ -389,6 +409,35 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
389409

390410
doc = results.docs[0]
391411

412+
# If refresh_on_read is enabled, refresh TTL for checkpoint key and related keys
413+
if self.ttl_config and self.ttl_config.get("refresh_on_read"):
414+
thread_id = getattr(doc, "thread_id", "")
415+
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
416+
417+
# Get the checkpoint key
418+
checkpoint_key = AsyncShallowRedisSaver._make_shallow_redis_checkpoint_key(
419+
thread_id, checkpoint_ns
420+
)
421+
422+
# Get all blob keys related to this checkpoint
423+
blob_key_pattern = (
424+
AsyncShallowRedisSaver._make_shallow_redis_checkpoint_blob_key_pattern(
425+
thread_id, checkpoint_ns
426+
)
427+
)
428+
blob_keys = await self._redis.keys(blob_key_pattern)
429+
blob_keys = [key.decode() for key in blob_keys]
430+
431+
# Apply TTL
432+
ttl_minutes = self.ttl_config.get("default_ttl")
433+
if ttl_minutes is not None:
434+
ttl_seconds = int(ttl_minutes * 60)
435+
pipeline = self._redis.pipeline()
436+
pipeline.expire(checkpoint_key, ttl_seconds)
437+
for key in blob_keys:
438+
pipeline.expire(key, ttl_seconds)
439+
await pipeline.execute()
440+
392441
checkpoint = json.loads(doc["$.checkpoint"])
393442

394443
# Fetch channel_values

0 commit comments

Comments (0)