Skip to content

Commit 323a485

Browse files
committed
fix(store): honor limit parameter in Redis search operations (#30)
The RedisStore and AsyncRedisStore search methods were hardcoded to return at most 10 results, regardless of the specified limit parameter. This fix properly respects the user-specified limit when searching, enabling retrieval of more than 10 results when needed.

- Modified _get_batch_search_queries to return limit and offset values
- Updated search query creation to use the Query.paging() method with the user-specified limit
- Added integration tests to verify the fix in both sync and async implementations
1 parent 5d3a689 commit 323a485

File tree

5 files changed

+161
-20
lines changed

5 files changed

+161
-20
lines changed

langgraph/store/redis/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -367,7 +367,7 @@ def _batch_search_ops(
367367
query_vectors = dict(zip([idx for idx, _ in embedding_requests], vectors))
368368

369369
# Process each search operation
370-
for (idx, op), (query_str, params) in zip(search_ops, queries):
370+
for (idx, op), (query_str, params, limit, offset) in zip(search_ops, queries):
371371
if op.query and idx in query_vectors:
372372
# Vector similarity search
373373
vector = query_vectors[idx]
@@ -376,7 +376,7 @@ def _batch_search_ops(
376376
vector_field_name="embedding",
377377
filter_expression=f"@prefix:{_namespace_to_text(op.namespace_prefix)}*",
378378
return_fields=["prefix", "key", "vector_distance"],
379-
num_results=op.limit,
379+
num_results=limit, # Use the user-specified limit
380380
)
381381
vector_results = self.vector_index.query(vector_query)
382382

@@ -469,8 +469,10 @@ def _batch_search_ops(
469469
results[idx] = items
470470
else:
471471
# Regular search
472-
query = Query(query_str)
473-
# Get all potential matches for filtering
472+
# Create a query with LIMIT and OFFSET parameters
473+
query = Query(query_str).paging(offset, limit)
474+
475+
# Execute search with limit and offset applied by Redis
474476
res = self.store_index.search(query)
475477
items = []
476478
refresh_keys = [] # Track keys that need TTL refreshed
@@ -505,10 +507,7 @@ def _batch_search_ops(
505507

506508
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
507509

508-
# Apply pagination after filtering
509-
if params:
510-
limit, offset = params
511-
items = items[offset : offset + limit]
510+
# Note: Pagination is now handled by Redis, no need to slice items manually
512511

513512
# Refresh TTL if requested
514513
if op.refresh_ttl and refresh_keys and self.ttl_config:

langgraph/store/redis/aio.py

Lines changed: 8 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -648,7 +648,7 @@ async def _batch_search_ops(
648648
query_vectors = dict(zip([idx for idx, _ in embedding_requests], vectors))
649649

650650
# Process each search operation
651-
for (idx, op), (query_str, params) in zip(search_ops, queries):
651+
for (idx, op), (query_str, params, limit, offset) in zip(search_ops, queries):
652652
if op.query and idx in query_vectors:
653653
# Vector similarity search
654654
vector = query_vectors[idx]
@@ -658,7 +658,7 @@ async def _batch_search_ops(
658658
vector_field_name="embedding",
659659
filter_expression=f"@prefix:{_namespace_to_text(op.namespace_prefix)}*",
660660
return_fields=["prefix", "key", "vector_distance"],
661-
num_results=op.limit,
661+
num_results=limit, # Use the user-specified limit
662662
)
663663
)
664664

@@ -722,8 +722,10 @@ async def _batch_search_ops(
722722
results[idx] = items
723723
else:
724724
# Regular search
725-
query = Query(query_str)
726-
# Get all potential matches for filtering
725+
# Create a query with LIMIT and OFFSET parameters
726+
query = Query(query_str).paging(offset, limit)
727+
728+
# Execute search with limit and offset applied by Redis
727729
res = await self.store_index.search(query)
728730
items = []
729731

@@ -746,12 +748,9 @@ async def _batch_search_ops(
746748
continue
747749
items.append(_row_to_search_item(_decode_ns(data["prefix"]), data))
748750

749-
# Apply pagination after filtering
750-
if params:
751-
limit, offset = params
752-
items = items[offset : offset + limit]
751+
# Note: Pagination is now handled by Redis, no need to slice items manually
753752

754-
results[idx] = items
753+
results[idx] = items
755754

756755
async def _batch_list_namespaces_ops(
757756
self,

langgraph/store/redis/base.py

Lines changed: 5 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -398,7 +398,7 @@ def _prepare_batch_PUT_queries(
398398
def _get_batch_search_queries(
399399
self,
400400
search_ops: Sequence[tuple[int, SearchOp]],
401-
) -> tuple[list[tuple[str, list]], list[tuple[int, str]]]:
401+
) -> tuple[list[tuple[str, list, int, int]], list[tuple[int, str]]]:
402402
"""Convert search operations into Redis queries."""
403403
queries = []
404404
embedding_requests = []
@@ -413,8 +413,10 @@ def _get_batch_search_queries(
413413
embedding_requests.append((idx, op.query))
414414

415415
query = " ".join(filter_conditions) if filter_conditions else "*"
416-
params = [op.limit, op.offset] if op.limit or op.offset else []
417-
queries.append((query, params))
416+
limit = op.limit if op.limit is not None else 10
417+
offset = op.offset if op.offset is not None else 0
418+
params = [limit, offset]
419+
queries.append((query, params, limit, offset))
418420

419421
return queries, embedding_requests
420422

tests/test_async_search_limit.py

Lines changed: 73 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,73 @@
"""Tests for AsyncRedisStore search limits.

These are integration tests (they require a live Redis reachable via the
``redis_url`` fixture) verifying that search respects a user-specified
``limit`` greater than the old hardcoded cap of 10.
"""

from __future__ import annotations

from collections.abc import AsyncIterator

import pytest
import pytest_asyncio

from langgraph.store.redis import AsyncRedisStore


@pytest_asyncio.fixture(scope="function")
async def async_store(redis_url) -> AsyncIterator[AsyncRedisStore]:
    """Yield an AsyncRedisStore with its indices initialized.

    Note: a generator fixture yields the store, so the annotation is
    ``AsyncIterator[AsyncRedisStore]`` rather than ``AsyncRedisStore``.
    """
    async with AsyncRedisStore(redis_url) as store:
        await store.setup()  # Initialize indices
        yield store


@pytest.mark.asyncio
async def test_async_search_with_larger_limit(async_store: AsyncRedisStore) -> None:
    """Test async search with limit > 10."""
    # Create 15 test documents — more than the previously hardcoded cap of 10.
    for i in range(15):
        await async_store.aput(
            ("test_namespace",), f"key{i}", {"data": f"value{i}", "index": i}
        )

    # Search with a limit of 15
    results = await async_store.asearch(("test_namespace",), limit=15)

    # Should return all 15 results
    assert len(results) == 15, f"Expected 15 results, got {len(results)}"

    # Verify we have all the items
    result_keys = {item.key for item in results}
    expected_keys = {f"key{i}" for i in range(15)}
    assert result_keys == expected_keys


@pytest.mark.asyncio
async def test_async_vector_search_with_larger_limit(redis_url) -> None:
    """Test async vector search with limit > 10."""
    # Imported here so the module can be collected without the test helper.
    from tests.embed_test_utils import CharacterEmbeddings

    # Create vector store with embeddings
    embeddings = CharacterEmbeddings(dims=4)
    index_config = {
        "dims": embeddings.dims,
        "embed": embeddings,
        "distance_type": "cosine",
        "fields": ["text"],
    }

    async with AsyncRedisStore(redis_url, index=index_config) as store:
        await store.setup()

        # Create 15 documents with slightly different texts so each embeds
        # to a distinct vector.
        for i in range(15):
            await store.aput(
                ("test_namespace",), f"key{i}", {"text": f"sample text {i}", "index": i}
            )

        # Search with a limit of 15
        results = await store.asearch(("test_namespace",), query="sample", limit=15)

        # Should return all 15 results
        assert len(results) == 15, f"Expected 15 results, got {len(results)}"

        # Verify we have all the items
        result_keys = {item.key for item in results}
        expected_keys = {f"key{i}" for i in range(15)}
        assert result_keys == expected_keys

tests/test_search_limit.py

Lines changed: 68 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,68 @@
"""Tests for RedisStore search limits.

These are integration tests (they require a live Redis reachable via the
``redis_url`` fixture) verifying that search respects a user-specified
``limit`` greater than the old hardcoded cap of 10.
"""

from __future__ import annotations

from collections.abc import Iterator

import pytest

from langgraph.store.redis import RedisStore


@pytest.fixture(scope="function")
def store(redis_url) -> Iterator[RedisStore]:
    """Yield a Redis store with its indices initialized.

    Note: a generator fixture yields the store, so the annotation is
    ``Iterator[RedisStore]`` rather than ``RedisStore``.
    """
    with RedisStore.from_conn_string(redis_url) as store:
        store.setup()  # Initialize indices
        yield store


def test_search_with_larger_limit(store: RedisStore) -> None:
    """Test search with limit > 10."""
    # Create 15 test documents — more than the previously hardcoded cap of 10.
    for i in range(15):
        store.put(("test_namespace",), f"key{i}", {"data": f"value{i}", "index": i})

    # Search with a limit of 15
    results = store.search(("test_namespace",), limit=15)

    # Should return all 15 results
    assert len(results) == 15, f"Expected 15 results, got {len(results)}"

    # Verify we have all the items
    result_keys = {item.key for item in results}
    expected_keys = {f"key{i}" for i in range(15)}
    assert result_keys == expected_keys


def test_vector_search_with_larger_limit(redis_url) -> None:
    """Test vector search with limit > 10."""
    # Imported here so the module can be collected without the test helper.
    from tests.embed_test_utils import CharacterEmbeddings

    # Create vector store with embeddings
    embeddings = CharacterEmbeddings(dims=4)
    index_config = {
        "dims": embeddings.dims,
        "embed": embeddings,
        "distance_type": "cosine",
        "fields": ["text"],
    }

    with RedisStore.from_conn_string(redis_url, index=index_config) as store:
        store.setup()

        # Create 15 documents with slightly different texts so each embeds
        # to a distinct vector.
        for i in range(15):
            store.put(
                ("test_namespace",), f"key{i}", {"text": f"sample text {i}", "index": i}
            )

        # Search with a limit of 15
        results = store.search(("test_namespace",), query="sample", limit=15)

        # Should return all 15 results
        assert len(results) == 15, f"Expected 15 results, got {len(results)}"

        # Verify we have all the items
        result_keys = {item.key for item in results}
        expected_keys = {f"key{i}" for i in range(15)}
        assert result_keys == expected_keys

0 commit comments

Comments
 (0)