Commit 3bef9bb

Merge pull request #154 from epage480/pass-common-params-graph
Pass common params to nodes in graph
2 parents da8c72c + cc27b21 commit 3bef9bb

20 files changed: +93 -113 lines
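
The net effect of this commit is that the settings shared by every node (headless, verbose, llm_model, and embedder_model) are injected once by AbstractGraph after the graph is built, so the prebuilt graphs no longer repeat them in each node_config. A minimal usage sketch of how that looks from the caller's side, assuming the SmartScraperGraph(prompt, source, config) constructor from this repository (the config values below are illustrative, not taken from this commit):

    from scrapegraphai.graphs import SmartScraperGraph

    graph_config = {
        "llm": {"api_key": "YOUR_OPENAI_API_KEY", "model": "gpt-3.5-turbo"},  # illustrative
        "verbose": True,    # forwarded to every node via set_common_params
        "headless": True,   # forwarded the same way
    }

    smart_scraper_graph = SmartScraperGraph(
        prompt="List all the article titles on the page",
        source="https://example.com",  # illustrative source
        config=graph_config,
    )
    print(smart_scraper_graph.run())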

examples/openai/custom_graph_openai.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@
 robot_node = RobotsNode(
     input="url",
     output=["is_scrapable"],
-    node_config={"llm": llm_model}
+    node_config={"llm_model": llm_model}
 )

 fetch_node = FetchNode(
@@ -50,12 +50,12 @@
 rag_node = RAGNode(
     input="user_prompt & (parsed_doc | doc)",
     output=["relevant_chunks"],
-    node_config={"llm": llm_model},
+    node_config={"llm_model": llm_model},
 )
 generate_answer_node = GenerateAnswerNode(
     input="user_prompt & (relevant_chunks | parsed_doc | doc)",
     output=["answer"],
-    node_config={"llm": llm_model},
+    node_config={"llm_model": llm_model},
 )

 # ************************************************

examples/single_node/robot_node.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 robots_node = RobotsNode(
     input="url",
     output=["is_scrapable"],
-    node_config={"llm": llm_model,
+    node_config={"llm_model": llm_model,
                  "headless": False
                  }
 )

scrapegraphai/graphs/abstract_graph.py

Lines changed: 21 additions & 4 deletions
@@ -52,15 +52,32 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])

+        # Create the graph
+        self.graph = self._create_graph()
+        self.final_state = None
+        self.execution_info = None
+
         # Set common configuration parameters
         self.verbose = True if config is None else config.get("verbose", False)
         self.headless = True if config is None else config.get(
             "headless", True)
+        common_params = {"headless": self.headless,
+                         "verbose": self.verbose,
+                         "llm_model": self.llm_model,
+                         "embedder_model": self.embedder_model}
+        self.set_common_params(common_params, overwrite=False)

-        # Create the graph
-        self.graph = self._create_graph()
-        self.final_state = None
-        self.execution_info = None
+
+    def set_common_params(self, params: dict, overwrite=False):
+        """
+        Pass parameters to every node in the graph unless otherwise defined in the graph.
+
+        Args:
+            params (dict): Common parameters and their values.
+        """
+
+        for node in self.graph.nodes:
+            node.update_config(params, overwrite)

     def _set_model_token(self, llm):
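
set_common_params forwards the shared settings to node.update_config, which is not part of this diff. A minimal sketch of how such a method could behave, assuming each node keeps its settings in a node_config dict (hypothetical code for illustration, not the committed implementation):

    class BaseNode:
        def __init__(self, node_config: dict = None):
            self.node_config = node_config or {}

        def update_config(self, params: dict, overwrite: bool = False):
            # Copy each common parameter into this node's config, but keep any
            # value the graph already set explicitly unless overwrite=True.
            for key, value in params.items():
                if overwrite or key not in self.node_config:
                    self.node_config[key] = value

With overwrite=False (the default used above in __init__), a node_config passed explicitly when building a custom graph, as in the examples earlier in this diff, still takes precedence over the common parameters.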

scrapegraphai/graphs/csv_scraper_graph.py

Lines changed: 3 additions & 10 deletions
@@ -32,34 +32,27 @@ def _create_graph(self):
         fetch_node = FetchNode(
             input="csv_dir",
             output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={
                 "chunk_size": self.model_token,
-                "verbose": self.verbose
             }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
+                "llm_model": self.llm_model,
                 "embedder_model": self.embedder_model,
-                "verbose": self.verbose
             }
         )
         generate_answer_node = GenerateAnswerCSVNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model,
             }
         )

@@ -85,4 +78,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/json_scraper_graph.py

Lines changed: 5 additions & 12 deletions
@@ -56,34 +56,27 @@ def _create_graph(self) -> BaseGraph:
         fetch_node = FetchNode(
             input="json_dir",
             output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={
-                "chunk_size": self.model_token,
-                "verbose": self.verbose
+                "chunk_size": self.model_token
             }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
-                "embedder_model": self.embedder_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model,
+                "embedder_model": self.embedder_model
             }
         )
         generate_answer_node = GenerateAnswerNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model
             }
         )

@@ -113,4 +106,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 4 additions & 11 deletions
@@ -61,32 +61,25 @@ def _create_graph(self) -> BaseGraph:
         fetch_node = FetchNode(
             input="url | local_dir",
             output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={"chunk_size": self.model_token,
-                         "verbose": self.verbose
                          }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
-                "embedder_model": self.embedder_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model,
+                "embedder_model": self.embedder_model
             }
         )
         generate_scraper_node = GenerateScraperNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
-            node_config={"llm": self.llm_model,
-                         "verbose": self.verbose},
+            node_config={"llm_model": self.llm_model},
             library=self.library,
             website=self.source
         )
@@ -117,4 +110,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/search_graph.py

Lines changed: 7 additions & 15 deletions
@@ -50,41 +50,33 @@ def _create_graph(self) -> BaseGraph:
             input="user_prompt",
             output=["url"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model
             }
         )
         fetch_node = FetchNode(
             input="url | local_dir",
-            output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
+            output=["doc"]
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={
-                "chunk_size": self.model_token,
-                "verbose": self.verbose
+                "chunk_size": self.model_token
             }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
-                "embedder_model": self.embedder_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model,
+                "embedder_model": self.embedder_model
             }
         )
         generate_answer_node = GenerateAnswerNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model
             }
         )

@@ -116,4 +108,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 6 additions & 13 deletions
@@ -57,35 +57,28 @@ def _create_graph(self) -> BaseGraph:
         """
         fetch_node = FetchNode(
             input="url | local_dir",
-            output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
+            output=["doc"]
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={
-                "chunk_size": self.model_token,
-                "verbose": self.verbose
+                "chunk_size": self.model_token
             }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
-                "embedder_model": self.embedder_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model,
+                "embedder_model": self.embedder_model
             }
         )
         generate_answer_node = GenerateAnswerNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model
             }
         )

@@ -115,4 +108,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/speech_graph.py

Lines changed: 7 additions & 16 deletions
@@ -56,43 +56,34 @@ def _create_graph(self) -> BaseGraph:

         fetch_node = FetchNode(
             input="url | local_dir",
-            output=["doc"],
-            node_config={
-                "headless": self.headless,
-                "verbose": self.verbose
-            }
+            output=["doc"]
         )
         parse_node = ParseNode(
             input="doc",
             output=["parsed_doc"],
             node_config={
-                "chunk_size": self.model_token,
-                "verbose": self.verbose
+                "chunk_size": self.model_token
             }
         )
         rag_node = RAGNode(
             input="user_prompt & (parsed_doc | doc)",
             output=["relevant_chunks"],
             node_config={
-                "llm": self.llm_model,
-                "embedder_model": self.embedder_model,
-                "verbose": self.verbose
-            }
+                "llm_model": self.llm_model,
+                "embedder_model": self.embedder_model }
         )
         generate_answer_node = GenerateAnswerNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
-                "llm": self.llm_model,
-                "verbose": self.verbose
+                "llm_model": self.llm_model
             }
         )
         text_to_speech_node = TextToSpeechNode(
             input="answer",
             output=["audio"],
             node_config={
-                "tts_model": OpenAITextToSpeech(self.config["tts_model"]),
-                "verbose": self.verbose
+                "tts_model": OpenAITextToSpeech(self.config["tts_model"])
             }
         )

@@ -131,4 +122,4 @@ def run(self) -> str:
             "output_path", "output.mp3"))
         print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")

-        return self.final_state.get("answer", "No answer found.")
+        return self.final_state.get("answer", "No answer found.")
