Skip to content

fix: handle both bytes and string Redis keys when decode_responses=True #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions langgraph/checkpoint/redis/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
EMPTY_ID_SENTINEL,
from_storage_safe_id,
from_storage_safe_str,
safely_decode,
to_storage_safe_id,
to_storage_safe_str,
)
Expand Down Expand Up @@ -212,12 +213,14 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
# Get the blob keys
blob_key_pattern = f"{CHECKPOINT_BLOB_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:*"
blob_keys = await self._redis.keys(blob_key_pattern)
blob_keys = [key.decode() for key in blob_keys]
# Use safely_decode to handle both string and bytes responses
blob_keys = [safely_decode(key) for key in blob_keys]

# Also get checkpoint write keys that should have the same TTL
write_key_pattern = f"{CHECKPOINT_WRITE_PREFIX}:{to_storage_safe_id(doc_thread_id)}:{to_storage_safe_str(doc_checkpoint_ns)}:{to_storage_safe_id(doc_checkpoint_id)}:*"
write_keys = await self._redis.keys(write_key_pattern)
write_keys = [key.decode() for key in write_keys]
# Use safely_decode to handle both string and bytes responses
write_keys = [safely_decode(key) for key in write_keys]

# Apply TTL to checkpoint, blob keys, and write keys
ttl_minutes = self.ttl_config.get("default_ttl")
Expand Down Expand Up @@ -895,9 +898,11 @@ async def _aload_pending_writes(
None,
)
matching_keys = await self._redis.keys(pattern=writes_key)
# Use safely_decode to handle both string and bytes responses
decoded_keys = [safely_decode(key) for key in matching_keys]
parsed_keys = [
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
for key in matching_keys
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
for key in decoded_keys
]
pending_writes = BaseRedisSaver._load_writes(
self.serde,
Expand Down
18 changes: 13 additions & 5 deletions langgraph/checkpoint/redis/ashallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import safely_decode

SCHEMAS = [
{
Expand Down Expand Up @@ -252,7 +253,9 @@ async def aput(
# Process each existing blob key to determine if it should be kept or deleted
if existing_blob_keys:
for blob_key in existing_blob_keys:
key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
# Use safely_decode to handle both string and bytes responses
decoded_key = safely_decode(blob_key)
key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
# The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
if len(key_parts) >= 5:
channel = key_parts[3]
Expand Down Expand Up @@ -428,7 +431,8 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
)
)
blob_keys = await self._redis.keys(blob_key_pattern)
blob_keys = [key.decode() for key in blob_keys]
# Use safely_decode to handle both string and bytes responses
blob_keys = [safely_decode(key) for key in blob_keys]

# Apply TTL
ttl_minutes = self.ttl_config.get("default_ttl")
Expand Down Expand Up @@ -554,7 +558,9 @@ async def aput_writes(
# Process each existing writes key to determine if it should be kept or deleted
if existing_writes_keys:
for write_key in existing_writes_keys:
key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
# Use safely_decode to handle both string and bytes responses
decoded_key = safely_decode(write_key)
key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
# The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
if len(key_parts) >= 5:
key_checkpoint_id = key_parts[3]
Expand Down Expand Up @@ -700,9 +706,11 @@ async def _aload_pending_writes(
thread_id, checkpoint_ns, checkpoint_id, "*", None
)
matching_keys = await self._redis.keys(pattern=writes_key)
# Use safely_decode to handle both string and bytes responses
decoded_keys = [safely_decode(key) for key in matching_keys]
parsed_keys = [
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
for key in matching_keys
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
for key in decoded_keys
]
pending_writes = BaseRedisSaver._load_writes(
self.serde,
Expand Down
11 changes: 9 additions & 2 deletions langgraph/checkpoint/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from langgraph.checkpoint.serde.types import ChannelProtocol

from langgraph.checkpoint.redis.util import (
safely_decode,
to_storage_safe_id,
to_storage_safe_str,
)
Expand Down Expand Up @@ -509,9 +510,12 @@ def _load_pending_writes(
# Cast the result to List[bytes] to help type checker
matching_keys: List[bytes] = self._redis.keys(pattern=writes_key) # type: ignore[assignment]

# Use safely_decode to handle both string and bytes responses
decoded_keys = [safely_decode(key) for key in matching_keys]

parsed_keys = [
BaseRedisSaver._parse_redis_checkpoint_writes_key(key.decode())
for key in matching_keys
BaseRedisSaver._parse_redis_checkpoint_writes_key(key)
for key in decoded_keys
]
pending_writes = BaseRedisSaver._load_writes(
self.serde,
Expand Down Expand Up @@ -541,6 +545,9 @@ def _load_writes(

@staticmethod
def _parse_redis_checkpoint_writes_key(redis_key: str) -> dict:
# Ensure redis_key is a string
redis_key = safely_decode(redis_key)

parts = redis_key.split(REDIS_KEY_SEPARATOR)
# Ensure we have at least 6 parts
if len(parts) < 6:
Expand Down
14 changes: 11 additions & 3 deletions langgraph/checkpoint/redis/shallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
REDIS_KEY_SEPARATOR,
BaseRedisSaver,
)
from langgraph.checkpoint.redis.util import safely_decode

SCHEMAS = [
{
Expand Down Expand Up @@ -179,7 +180,9 @@ def put(
# Process each existing blob key to determine if it should be kept or deleted
if existing_blob_keys:
for blob_key in existing_blob_keys:
key_parts = blob_key.decode().split(REDIS_KEY_SEPARATOR)
# Use safely_decode to handle both string and bytes responses
decoded_key = safely_decode(blob_key)
key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
# The key format is checkpoint_blob:thread_id:checkpoint_ns:channel:version
if len(key_parts) >= 5:
channel = key_parts[3]
Expand Down Expand Up @@ -387,7 +390,10 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
thread_id, checkpoint_ns
)
)
blob_keys = [key.decode() for key in self._redis.keys(blob_key_pattern)]
# Use safely_decode to handle both string and bytes responses
blob_keys = [
safely_decode(key) for key in self._redis.keys(blob_key_pattern)
]

# Apply TTL
self._apply_ttl_to_keys(checkpoint_key, blob_keys)
Expand Down Expand Up @@ -524,7 +530,9 @@ def put_writes(
# Process each existing writes key to determine if it should be kept or deleted
if existing_writes_keys:
for write_key in existing_writes_keys:
key_parts = write_key.decode().split(REDIS_KEY_SEPARATOR)
# Use safely_decode to handle both string and bytes responses
decoded_key = safely_decode(write_key)
key_parts = decoded_key.split(REDIS_KEY_SEPARATOR)
# The key format is checkpoint_write:thread_id:checkpoint_ns:checkpoint_id:task_id:idx
if len(key_parts) >= 5:
key_checkpoint_id = key_parts[3]
Expand Down
49 changes: 49 additions & 0 deletions langgraph/checkpoint/redis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
that is lexicographically sortable. Typically, checkpoints that need
sentinel values are from the first run of the graph, so this should
generally be correct.

This module also includes utility functions for safely handling Redis responses,
including handling bytes vs string responses depending on how the Redis client is
configured with decode_responses.
"""

from typing import Any, Dict, List, Optional, Set, Tuple, Union

EMPTY_STRING_SENTINEL = "__empty__"
EMPTY_ID_SENTINEL = "00000000-0000-0000-0000-000000000000"

Expand Down Expand Up @@ -81,3 +87,46 @@ def from_storage_safe_id(value: str) -> str:
return ""
else:
return value


def safely_decode(obj: Any) -> Any:
    """
    Safely decode Redis responses, handling both string and bytes types.

    This is especially useful when working with Redis clients configured with
    different decode_responses settings. It recursively processes nested
    data structures (dicts, lists, tuples, sets).

    Based on RedisVL's convert_bytes function (redisvl.redis.utils.convert_bytes)
    but implemented directly to avoid runtime import issues and ensure consistent
    behavior with sets and other data structures. See PR #34 and referenced
    implementation: https://github.com/redis/redis-vl-python/blob/9f22a9ad4c2166af6462b007833b456448714dd9/redisvl/redis/utils.py#L20

    Args:
        obj: The object to decode. Can be a string, bytes, or a nested structure
            containing strings/bytes (dict, list, tuple, set).

    Returns:
        The decoded object with all bytes converted to strings.
    """
    if obj is None:
        return None

    if isinstance(obj, bytes):
        # Clients created with decode_responses=False hand back raw bytes.
        try:
            return obj.decode("utf-8")
        except UnicodeDecodeError:
            # Non-UTF-8 payloads are passed through untouched rather than raising.
            return obj

    # Containers are rebuilt with every element (and dict key) decoded.
    if isinstance(obj, dict):
        return {safely_decode(key): safely_decode(val) for key, val in obj.items()}

    if isinstance(obj, list):
        return [safely_decode(element) for element in obj]

    if isinstance(obj, tuple):
        return tuple(map(safely_decode, obj))

    if isinstance(obj, set):
        return set(map(safely_decode, obj))

    # str, int, float, bool, and any other type are returned unchanged.
    return obj
148 changes: 148 additions & 0 deletions tests/test_decode_responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Tests for Redis key decoding functionality."""

import os
import time
import uuid
from typing import Any, Dict, Optional

import pytest
from redis import Redis

from langgraph.checkpoint.redis.util import safely_decode


def test_safely_decode_basic_types():
    """Test safely_decode function with basic type inputs."""
    # Bytes are decoded to str; str passes through unchanged.
    assert safely_decode(b"test") == "test"
    assert safely_decode("test") == "test"

    # None is preserved as-is.
    assert safely_decode(None) is None

    # Non-text scalars are returned untouched.
    for scalar in (123, 1.23):
        assert safely_decode(scalar) == scalar
    assert safely_decode(True) is True


def test_safely_decode_nested_structures():
    """Test safely_decode function with nested data structures."""
    # Dicts: both keys and values are decoded; non-bytes entries untouched.
    assert safely_decode({b"key": b"value"}) == {"key": "value"}
    assert safely_decode({b"key1": b"value1", "key2": 123}) == {
        "key1": "value1",
        "key2": 123,
    }

    # Nested dicts are decoded recursively.
    assert safely_decode({b"outer": {b"inner": b"value"}}) == {
        "outer": {"inner": "value"}
    }

    # Sequences keep their container type.
    assert safely_decode([b"item1", b"item2"]) == ["item1", "item2"]
    assert safely_decode((b"item1", b"item2")) == ("item1", "item2")

    # Sets stay sets with decoded members.
    decoded_set = safely_decode({b"item1", b"item2"})
    assert isinstance(decoded_set, set)
    assert {"item1", "item2"} <= decoded_set

    # A structure mixing all container kinds decodes element-by-element.
    complex_struct = {
        b"key1": [b"list_item1", {b"nested_key": b"nested_value"}],
        b"key2": (b"tuple_item", 123),
        b"key3": {b"set_item1", b"set_item2"},
    }
    decoded = safely_decode(complex_struct)
    assert decoded["key1"][0] == "list_item1"
    assert decoded["key1"][1]["nested_key"] == "nested_value"
    assert decoded["key2"] == ("tuple_item", 123)
    assert isinstance(decoded["key3"], set)
    assert {"set_item1", "set_item2"} <= decoded["key3"]


@pytest.mark.parametrize("decode_responses", [True, False])
def test_safely_decode_with_redis(decode_responses: bool, redis_url):
    """Test safely_decode function with actual Redis responses using TestContainers."""
    client = Redis.from_url(redis_url, decode_responses=decode_responses)
    test_keys = ("test:string", "test:hash", "test:list", "test:set")

    try:
        # Start from a clean slate so earlier runs can't leak state in.
        for key in test_keys:
            client.delete(key)

        # Populate one key of each basic Redis data type.
        client.set("test:string", "value")
        client.hset("test:hash", mapping={"field1": "value1", "field2": "value2"})
        client.rpush("test:list", "item1", "item2", "item3")
        client.sadd("test:set", "member1", "member2")

        # String value decodes to a plain str either way.
        assert safely_decode(client.get("test:string")) == "value"

        # Hash: both field names and values are decoded.
        assert safely_decode(client.hgetall("test:hash")) == {
            "field1": "value1",
            "field2": "value2",
        }

        # List order is preserved through decoding.
        assert safely_decode(client.lrange("test:list", 0, -1)) == [
            "item1",
            "item2",
            "item3",
        ]

        # Set members decode and the result stays a set.
        decoded_set = safely_decode(client.smembers("test:set"))
        assert isinstance(decoded_set, set)
        assert "member1" in decoded_set
        assert "member2" in decoded_set

        # KEYS responses (bytes or str depending on the client) decode uniformly.
        decoded_keys = safely_decode(client.keys("test:*"))
        assert sorted(decoded_keys) == sorted(test_keys)

    finally:
        # Always remove the fixtures and release the connection.
        for key in test_keys:
            client.delete(key)
        client.close()


def test_safely_decode_unicode_error_handling():
    """Test safely_decode function with invalid UTF-8 bytes."""
    # Bytes that cannot be decoded as UTF-8.
    invalid_utf8 = b"\xff\xfe\xfd"

    # Undecodable bytes are returned unchanged rather than raising.
    assert safely_decode(invalid_utf8) == invalid_utf8

    # Valid and invalid byte strings can coexist inside one structure:
    # valid ones decode, invalid ones survive as bytes.
    mixed = {
        b"valid": b"This is valid UTF-8",
        b"invalid": invalid_utf8,
        b"nested": [b"valid", invalid_utf8],
    }
    decoded = safely_decode(mixed)
    assert decoded["valid"] == "This is valid UTF-8"
    assert decoded["invalid"] == invalid_utf8
    assert decoded["nested"] == ["valid", invalid_utf8]