Skip to content

Commit df1645c

Browse files
author
Pedro Perez de Ayala
committed
fix: Fixes schema option not working
2 parents db3afad + 562a97c commit df1645c

File tree

4 files changed

+58
-29
lines changed

4 files changed

+58
-29
lines changed

pyproject.toml

+2-1
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,8 @@ dependencies = [
3131
"async-timeout>=4.0.3",
3232
"simpleeval>=1.0.0",
3333
"jsonschema>=4.23.0",
34-
"duckduckgo-search>=7.2.1"
34+
"duckduckgo-search>=7.2.1",
35+
"pydantic>=2.10.2",
3536
]
3637

3738
readme = "README.md"

scrapegraphai/nodes/generate_answer_node.py

+26-27
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22
GenerateAnswerNode Module
33
"""
44

5-
import time
65
import json
6+
import time
77
from typing import List, Optional
88

99
from langchain.prompts import PromptTemplate
@@ -105,10 +105,7 @@ def process(self, state: dict) -> dict:
105105
raise ValueError("No user prompt found in state")
106106

107107
# Create the chain input with both content and question keys
108-
chain_input = {
109-
"content": content,
110-
"question": user_prompt
111-
}
108+
chain_input = {"content": content, "question": user_prompt}
112109

113110
try:
114111
response = self.invoke_with_timeout(self.chain, chain_input, self.timeout)
@@ -167,25 +164,13 @@ def execute(self, state: dict) -> dict:
167164
and not self.script_creator
168165
or self.is_md_scraper
169166
):
170-
template_no_chunks_prompt = (
171-
TEMPLATE_NO_CHUNKS_MD + "\n\nIMPORTANT: " + format_instructions
172-
)
173-
template_chunks_prompt = (
174-
TEMPLATE_CHUNKS_MD + "\n\nIMPORTANT: " + format_instructions
175-
)
176-
template_merge_prompt = (
177-
TEMPLATE_MERGE_MD + "\n\nIMPORTANT: " + format_instructions
178-
)
167+
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
168+
template_chunks_prompt = TEMPLATE_CHUNKS_MD
169+
template_merge_prompt = TEMPLATE_MERGE_MD
179170
else:
180-
template_no_chunks_prompt = (
181-
TEMPLATE_NO_CHUNKS + "\n\nIMPORTANT: " + format_instructions
182-
)
183-
template_chunks_prompt = (
184-
TEMPLATE_CHUNKS + "\n\nIMPORTANT: " + format_instructions
185-
)
186-
template_merge_prompt = (
187-
TEMPLATE_MERGE + "\n\nIMPORTANT: " + format_instructions
188-
)
171+
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
172+
template_chunks_prompt = TEMPLATE_CHUNKS
173+
template_merge_prompt = TEMPLATE_MERGE
189174

190175
if self.additional_info is not None:
191176
template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
@@ -210,8 +195,14 @@ def execute(self, state: dict) -> dict:
210195
chain, {"question": user_prompt}, self.timeout
211196
)
212197
except (Timeout, json.JSONDecodeError) as e:
213-
error_msg = "Response timeout exceeded" if isinstance(e, Timeout) else "Invalid JSON response format"
214-
state.update({self.output[0]: {"error": error_msg, "raw_response": str(e)}})
198+
error_msg = (
199+
"Response timeout exceeded"
200+
if isinstance(e, Timeout)
201+
else "Invalid JSON response format"
202+
)
203+
state.update(
204+
{self.output[0]: {"error": error_msg, "raw_response": str(e)}}
205+
)
215206
return state
216207

217208
state.update({self.output[0]: answer})
@@ -241,7 +232,11 @@ def execute(self, state: dict) -> dict:
241232
async_runner, {"question": user_prompt}, self.timeout
242233
)
243234
except (Timeout, json.JSONDecodeError) as e:
244-
error_msg = "Response timeout exceeded during chunk processing" if isinstance(e, Timeout) else "Invalid JSON response format in chunk processing"
235+
error_msg = (
236+
"Response timeout exceeded during chunk processing"
237+
if isinstance(e, Timeout)
238+
else "Invalid JSON response format in chunk processing"
239+
)
245240
state.update({self.output[0]: {"error": error_msg, "raw_response": str(e)}})
246241
return state
247242

@@ -261,7 +256,11 @@ def execute(self, state: dict) -> dict:
261256
self.timeout,
262257
)
263258
except (Timeout, json.JSONDecodeError) as e:
264-
error_msg = "Response timeout exceeded during merge" if isinstance(e, Timeout) else "Invalid JSON response format during merge"
259+
error_msg = (
260+
"Response timeout exceeded during merge"
261+
if isinstance(e, Timeout)
262+
else "Invalid JSON response format during merge"
263+
)
265264
state.update({self.output[0]: {"error": error_msg, "raw_response": str(e)}})
266265
return state
267266

tests/graphs/smart_scraper_openai_test.py

+25
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
from dotenv import load_dotenv
9+
from pydantic import BaseModel
910

1011
from scrapegraphai.graphs import SmartScraperGraph
1112

@@ -53,3 +54,27 @@ def test_get_execution_info(graph_config):
5354
graph_exec_info = smart_scraper_graph.get_execution_info()
5455

5556
assert graph_exec_info is not None
57+
58+
59+
def test_get_execution_info_with_schema(graph_config):
60+
"""Get the execution info with schema"""
61+
62+
class ProjectSchema(BaseModel):
63+
title: str
64+
description: str
65+
66+
class ProjectListSchema(BaseModel):
67+
projects: list[ProjectSchema]
68+
69+
smart_scraper_graph = SmartScraperGraph(
70+
prompt="List me all the projects with their description.",
71+
source="https://perinim.github.io/projects/",
72+
config=graph_config,
73+
schema=ProjectListSchema,
74+
)
75+
76+
smart_scraper_graph.run()
77+
78+
graph_exec_info = smart_scraper_graph.get_execution_info()
79+
80+
assert graph_exec_info is not None

uv.lock

+5-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)