@@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
 
 
 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
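`semaphore_gather` fans out one retrieval coroutine per episode while capping how many run at once. The helper is defined elsewhere in `graphiti_core`; the following is a minimal sketch of the pattern it implements, with an illustrative concurrency limit rather than the project's actual default:

```python
import asyncio

async def semaphore_gather(*coroutines, max_coroutines: int = 20):
    # Bound concurrency: without this, one bulk call per episode could
    # open hundreds of simultaneous DB/LLM requests at the same time.
    semaphore = asyncio.Semaphore(max_coroutines)

    async def _bounded(coroutine):
        async with semaphore:
            return await coroutine

    # gather preserves input order, so results line up with the episodes.
    return await asyncio.gather(*(_bounded(c) for c in coroutines))
```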
@@ -89,11 +89,11 @@ async def retrieve_previous_episodes_bulk(
 
 
 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
-    episodic_nodes: list[EpisodicNode],
-    episodic_edges: list[EpisodicEdge],
-    entity_nodes: list[EntityNode],
-    entity_edges: list[EntityEdge],
+    driver: AsyncDriver,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     async with driver.session(database=DEFAULT_DATABASE) as session:
         await session.execute_write(
@@ -102,11 +102,11 @@ async def add_nodes_and_edges_bulk(
 
 
 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
-    episodic_nodes: list[EpisodicNode],
-    episodic_edges: list[EpisodicEdge],
-    entity_nodes: list[EntityNode],
-    entity_edges: list[EntityEdge],
+    tx: AsyncManagedTransaction,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
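`add_nodes_and_edges_bulk` hands all four bulk saves to a single managed transaction via `session.execute_write`, so the driver retries transient failures and the writes commit or roll back together. A self-contained sketch of the same pattern, using placeholder URI, credentials, and Cypher rather than graphiti's actual bulk queries:

```python
from neo4j import AsyncGraphDatabase

# Placeholder wiring; URI, auth, and database name are illustrative only.
driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

async def save_all(nodes: list[dict], edges: list[dict]) -> None:
    async with driver.session(database="neo4j") as session:
        # execute_write retries the transaction function on transient
        # errors; every tx.run inside commits atomically or not at all.
        await session.execute_write(_save_tx, nodes, edges)

async def _save_tx(tx, nodes: list[dict], edges: list[dict]) -> None:
    await tx.run("UNWIND $nodes AS n MERGE (m:Entity {uuid: n.uuid})", nodes=nodes)
    await tx.run(
        "UNWIND $edges AS e "
        "MATCH (a:Entity {uuid: e.source}), (b:Entity {uuid: e.target}) "
        "MERGE (a)-[:RELATES_TO {uuid: e.uuid}]->(b)",
        edges=edges,
    )
```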
@@ -128,12 +128,14 @@ async def add_nodes_and_edges_bulk_tx(
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
-    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
+    await tx.run(
+        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
+    )
+    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
 
 
 async def extract_nodes_and_edges_bulk(
-    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
     extracted_nodes_bulk = await semaphore_gather(
         *[
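The one behavioral change in this commit is the swap from `dict(edge)` to `edge.model_dump()` in the hunk above; the remaining hunks render identically on both sides and touch only whitespace. On a Pydantic v2 model, `dict(model)` merely iterates the top-level field/value pairs, while `model_dump()` serializes recursively and applies any field serializers. A toy demonstration with hypothetical models, not graphiti's actual `EntityEdge`:

```python
from datetime import datetime, timezone
from pydantic import BaseModel

class Episode(BaseModel):
    uuid: str
    valid_at: datetime

class Edge(BaseModel):
    uuid: str
    episode: Episode

edge = Edge(uuid="e1", episode=Episode(uuid="p1", valid_at=datetime.now(timezone.utc)))

print(dict(edge))
# {'uuid': 'e1', 'episode': Episode(...)} -- nested model left as an object
print(edge.model_dump())
# {'uuid': 'e1', 'episode': {'uuid': 'p1', 'valid_at': ...}} -- fully converted
```

The recursive form is what a database driver expects as query parameters, which is presumably why the edge payloads were switched while the flat `dict(episode)` call above was left alone.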
@@ -176,16 +178,16 @@ async def extract_nodes_and_edges_bulk(
 
 
 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
+    driver: AsyncDriver,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # Compress nodes
     nodes, uuid_map = node_name_match(extracted_nodes)
 
     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
 
-    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
         await semaphore_gather(
@@ -212,13 +214,13 @@ async def dedupe_nodes_bulk(
 
 
 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
 
     edge_chunks = [
-        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
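Both dedupe functions batch their inputs with the same slicing idiom, `seq[i : i + CHUNK_SIZE]` stepped by `CHUNK_SIZE`, so each LLM comparison prompt sees a bounded number of candidates. The final slice is simply shorter; no padding or special-casing is needed:

```python
CHUNK_SIZE = 10  # illustrative; the real constant is defined in this module

items = list(range(25))
chunks = [items[i : i + CHUNK_SIZE] for i in range(0, len(items), CHUNK_SIZE)]

# 25 items split into chunks of 10, 10, and 5.
assert [len(c) for c in chunks] == [10, 10, 5]
```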
@@ -254,7 +256,7 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
 
 
 async def compress_nodes(
-    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
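`compress_nodes` folds the LLM's duplicate verdicts into the `uuid_map` started by `node_name_match`. When mappings accumulate across rounds like this, they can chain (a → b from name matching, then b → c from compression), so a lookup should follow the chain to the surviving uuid. A hypothetical resolution helper illustrating the idea, not code from this module:

```python
def resolve_uuid(uuid: str, uuid_map: dict[str, str]) -> str:
    # Follow a -> b -> c chains until reaching a uuid with no further
    # mapping, i.e. the canonical node that survived deduplication.
    seen = {uuid}
    while uuid in uuid_map and uuid_map[uuid] != uuid:
        uuid = uuid_map[uuid]
        if uuid in seen:  # guard against an accidental cycle in the map
            break
        seen.add(uuid)
    return uuid
```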
@@ -375,9 +377,9 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
 
 
 async def extract_edge_dates_bulk(
-    llm_client: LLMClient,
-    extracted_edges: list[EntityEdge],
-    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
 ) -> list[EntityEdge]:
     edges: list[EntityEdge] = []
     # confirm that all of our edges have at least one episode
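The hunk header above references `resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str])`, the step that re-targets edges after deduplication collapses nodes. A minimal sketch of what such a remap does, with the endpoint field names assumed from graphiti's edge models:

```python
from typing import TypeVar

E = TypeVar("E")  # any edge type carrying source/target node uuids

def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]) -> list[E]:
    for edge in edges:
        # If an endpoint was deduplicated away, repoint the edge at the
        # surviving node; otherwise keep the original uuid unchanged.
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges
```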