Commit 51bf29f

working refactor
1 parent 3d36102 commit 51bf29f

3 files changed: +80 -46 lines changed

graphiti_core/models/edges/edge_db_queries.py

+3 -3

@@ -23,8 +23,8 @@
 
 EPISODIC_EDGE_SAVE_BULK = """
     UNWIND $episodic_edges AS edge
-    MATCH (episode:Episodic {uuid: startNode(edge).uuid})
-    MATCH (node:Entity {uuid: endNode(edge).uuid})
+    MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+    MATCH (node:Entity {uuid: edge.target_node_uuid})
     MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
     SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
     RETURN r.uuid AS uuid
@@ -47,7 +47,7 @@
     SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
         created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
     WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
-    RETURN r.uuid AS uuid
+    RETURN edge.uuid AS uuid
 """
 
 COMMUNITY_EDGE_SAVE = """
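
These query changes pair with the bulk-save refactor in graphiti_core/utils/bulk_utils.py below: the UNWIND'd rows are now plain parameter dicts rather than edge objects, so the Cypher reads edge.source_node_uuid / edge.target_node_uuid directly instead of calling startNode()/endNode(). A minimal sketch of the parameter shape EPISODIC_EDGE_SAVE_BULK now expects (values are hypothetical, not taken from this commit):

    # Hypothetical payload; the keys mirror the properties the Cypher above reads.
    episodic_edges = [
        {
            'uuid': 'edge-uuid-1',
            'source_node_uuid': 'episode-uuid-1',   # Episodic node the MENTIONS edge starts from
            'target_node_uuid': 'entity-uuid-1',    # Entity node it points to
            'group_id': 'group-1',
            'created_at': '2024-01-01T00:00:00+00:00',
        }
    ]
    # await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=episodic_edges)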

graphiti_core/search/search_utils.py

+51 -19

@@ -341,10 +341,10 @@ async def node_fulltext_search(
 
     query = (
         """
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
-        WHERE n:Entity
-        """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -739,12 +739,25 @@ async def get_relevant_edges(
         """
         + filter_query
         + """
-        WITH n, m, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
-        WHERE score > $min_score"""
-        + ENTITY_EDGE_RETURN
-        + """
-        ORDER BY score DESC
-        LIMIT $limit
+        WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
+        WHERE score > $min_score
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at
+            })[..$limit] AS matches
         """
     )
 
@@ -792,15 +805,29 @@ async def get_edge_invalidation_candidates(
         """
         + filter_query
         + """
-        WITH e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
+        WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
         WHERE score > $min_score
-
-        ORDER BY score DESC
-        LIMIT $limit
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at
+            })[..$limit] AS matches
         """
     )
 
-    records_list, _, _ = await driver.execute_query(
+    results, _, _ = await driver.execute_query(
         query,
         query_params,
         edges=[edge.model_dump() for edge in edges],
@@ -809,11 +836,16 @@ async def get_edge_invalidation_candidates(
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
-    relevant_edges = [
-        [get_entity_edge_from_record(record) for record in records] for records in records_list
-    ]
+    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
+        result['search_edge_uuid']: [
+            get_entity_edge_from_record(record) for record in result['matches']
+        ]
+        for result in results
+    }
 
-    return relevant_edges
+    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]
+
+    return invalidation_edges
 
 
 # takes in a list of rankings of uuids
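
Both reworked search queries follow the same pattern: score and order the candidate rows, then aggregate them per input edge with collect(...), using the list slice [..$limit] to cap how many matches each search edge keeps, instead of a single global ORDER BY ... LIMIT. The Python side then folds the grouped rows back into the order of the input edges, so an edge with no candidates yields an empty list. A minimal sketch of that regrouping step, with hypothetical record values standing in for real query output:

    # Hypothetical rows in the shape the refactored queries now return:
    # one row per search edge, its top matches already collected and capped by $limit.
    results = [
        {'search_edge_uuid': 'edge-a', 'matches': [{'uuid': 'cand-1'}, {'uuid': 'cand-2'}]},
        {'search_edge_uuid': 'edge-b', 'matches': [{'uuid': 'cand-3'}]},
    ]
    input_edge_uuids = ['edge-a', 'edge-b', 'edge-c']  # hypothetical input order

    # Group matches by the edge that produced them, then read them back out in
    # input order; edges with no row (edge-c) fall back to an empty list.
    matches_by_edge = {row['search_edge_uuid']: row['matches'] for row in results}
    per_edge_matches = [matches_by_edge.get(uuid, []) for uuid in input_edge_uuids]
    # -> [[{'uuid': 'cand-1'}, {'uuid': 'cand-2'}], [{'uuid': 'cand-3'}], []]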

graphiti_core/utils/bulk_utils.py

+26 -24

@@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
 
 
 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -89,11 +89,11 @@ async def retrieve_previous_episodes_bulk(
 
 
 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
-    episodic_nodes: list[EpisodicNode],
-    episodic_edges: list[EpisodicEdge],
-    entity_nodes: list[EntityNode],
-    entity_edges: list[EntityEdge],
+    driver: AsyncDriver,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     async with driver.session(database=DEFAULT_DATABASE) as session:
         await session.execute_write(
@@ -102,11 +102,11 @@ async def add_nodes_and_edges_bulk(
 
 
 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
-    episodic_nodes: list[EpisodicNode],
-    episodic_edges: list[EpisodicEdge],
-    entity_nodes: list[EntityNode],
-    entity_edges: list[EntityEdge],
+    tx: AsyncManagedTransaction,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
@@ -128,12 +128,14 @@ async def add_nodes_and_edges_bulk_tx(
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
-    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
+    await tx.run(
+        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
+    )
+    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
 
 
 async def extract_nodes_and_edges_bulk(
-    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
     extracted_nodes_bulk = await semaphore_gather(
         *[
@@ -176,16 +178,16 @@ async def extract_nodes_and_edges_bulk(
 
 
 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
+    driver: AsyncDriver,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # Compress nodes
     nodes, uuid_map = node_name_match(extracted_nodes)
 
     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
 
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
         await semaphore_gather(
@@ -212,13 +214,13 @@ async def dedupe_nodes_bulk(
 
 
 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
 
     edge_chunks = [
-        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+        compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
@@ -254,7 +256,7 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
 
 
 async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
@@ -375,9 +377,9 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
 
 
 async def extract_edge_dates_bulk(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
 ) -> list[EntityEdge]:
     edges: list[EntityEdge] = []
     # confirm that all of our edges have at least one episode
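
A substantive change among the formatting-only ones here: the bulk edge saves now pass edge.model_dump() instead of dict(edge). One plausible reading (the commit message does not say) is that model_dump() yields fully plain, recursively serialized dicts whose keys the UNWIND-based queries can read directly, while dict(model) returns the raw top-level field values. A small standalone pydantic sketch of that general difference, using hypothetical models:

    from pydantic import BaseModel

    class Inner(BaseModel):
        value: int

    class Outer(BaseModel):
        name: str
        inner: Inner

    o = Outer(name='x', inner=Inner(value=1))
    print(dict(o))         # {'name': 'x', 'inner': Inner(value=1)}  (nested model object survives)
    print(o.model_dump())  # {'name': 'x', 'inner': {'value': 1}}    (plain dicts all the way down)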
