Commit ad693b2

fix: ollama tokenizer limited to 1024 tokens + ollama structured output + fix browser backend
Parent: 1a01912

12 files changed (+55, -69 lines)

examples/local_models/smart_scraper_ollama.py (+2, -2)
@@ -15,7 +15,7 @@
         "temperature": 0,
         "format": "json",  # Ollama needs the format to be specified explicitly
         # "base_url": "http://localhost:11434",  # set ollama URL arbitrarily
-        "model_tokens": 1024,
+        "model_tokens": 4096,
     },
     "verbose": True,
     "headless": False,
@@ -25,7 +25,7 @@
 # Create the SmartScraperGraph instance and run it
 # ************************************************
 smart_scraper_graph = SmartScraperGraph(
-    prompt="Find some information about what does the company do, the name and a contact email.",
+    prompt="Find some information about what does the company do and the list of founders.",
     source="https://scrapegraphai.com/",
     config=graph_config,
 )
SmartScraper schema example (file path not rendered)

@@ -1,31 +1,31 @@
-"""
+"""
 Basic example of scraping pipeline using SmartScraper with schema
 """
+
 import json
-from typing import List
+
 from pydantic import BaseModel, Field
+
 from scrapegraphai.graphs import SmartScraperGraph
 from scrapegraphai.utils import prettify_exec_info
 
+
 # ************************************************
 # Define the configuration for the graph
 # ************************************************
 class Project(BaseModel):
     title: str = Field(description="The title of the project")
     description: str = Field(description="The description of the project")
 
+
 class Projects(BaseModel):
-    projects: List[Project]
+    projects: list[Project]
+
 
 graph_config = {
-    "llm": {
-        "model": "ollama/llama3.1",
-        "temperature": 0,
-        "format": "json",  # Ollama needs the format to be specified explicitly
-        # "base_url": "http://localhost:11434",  # set ollama URL arbitrarily
-    },
+    "llm": {"model": "ollama/llama3.2", "temperature": 0, "model_tokens": 4096},
     "verbose": True,
-    "headless": False
+    "headless": False,
 }
@@ -36,8 +36,15 @@ class Projects(BaseModel):
     prompt="List me all the projects with their description",
     source="https://perinim.github.io/projects/",
     schema=Projects,
-    config=graph_config
+    config=graph_config,
 )
 
 result = smart_scraper_graph.run()
 print(json.dumps(result, indent=4))
+
+# ************************************************
+# Get graph execution info
+# ************************************************
+
+graph_exec_info = smart_scraper_graph.get_execution_info()
+print(prettify_exec_info(graph_exec_info))
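
Because this example passes both an Ollama model and schema=Projects, the GenerateAnswerNode change later in this commit forwards Projects.model_json_schema() to the model. A hedged illustration of the shape that constrains the result to (placeholder values, not actual program output):

# Illustrative only: run() is expected to return data matching Projects, e.g.
{
    "projects": [
        {"title": "...", "description": "..."}
    ]
}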

pyproject.toml (+1, -2)
@@ -30,8 +30,7 @@ dependencies = [
     "googlesearch-python>=1.2.5",
     "async-timeout>=4.0.3",
     "simpleeval>=1.0.0",
-    "jsonschema>=4.23.0",
-    "transformers>=4.46.3",
+    "jsonschema>=4.23.0"
 ]
 
 readme = "README.md"

scrapegraphai/docloaders/chromium.py (+2, -2)
@@ -61,15 +61,15 @@ def __init__(
 
         dynamic_import(backend, message)
 
-        self.backend = backend
         self.browser_config = kwargs
         self.headless = headless
         self.proxy = parse_or_search_proxy(proxy) if proxy else None
         self.urls = urls
         self.load_state = load_state
         self.requires_js_support = requires_js_support
         self.storage_state = storage_state
-        self.browser_name = browser_name
+        self.backend = kwargs.get("backend", backend)
+        self.browser_name = kwargs.get("browser_name", browser_name)
         self.retry_limit = kwargs.get("retry_limit", retry_limit)
         self.timeout = kwargs.get("timeout", timeout)
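
The two reassignments above make values passed through the loader's extra kwargs take precedence over the constructor defaults, so a caller-supplied backend or browser_name is no longer silently overwritten. A minimal standalone sketch of that precedence rule (plain dict.get, no scrapegraphai imports; names mirror the constructor arguments):

# kwargs wins; the positional argument is only the fallback.
def resolve(kwargs: dict, backend: str = "playwright", browser_name: str = "chromium"):
    return kwargs.get("backend", backend), kwargs.get("browser_name", browser_name)

print(resolve({}))                           # ('playwright', 'chromium')
print(resolve({"browser_name": "firefox"}))  # ('playwright', 'firefox')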

scrapegraphai/graphs/abstract_graph.py (+3, -2)
@@ -203,8 +203,9 @@ def _create_llm(self, llm_config: dict) -> object:
                 ]
             except KeyError:
                 print(
-                    f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
-                    using default token size (8192)"""
+                    f"""Max input tokens for model {llm_params['model_provider']}/{llm_params['model']} not found,
+                    please specify the model_tokens parameter in the llm section of the graph configuration.
+                    Using default token size: 8192"""
                 )
                 self.model_token = 8192
         else:
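
The reworded message also points at the remedy: when the lookup table has no entry for the model, set model_tokens explicitly instead of relying on the 8192-token fallback. A hedged sketch of the relevant configuration (keys as used by the examples in this commit):

# Hedged sketch: an explicit model_tokens avoids the 8192 fallback for
# models whose max input tokens are not in the lookup table.
graph_config = {
    "llm": {
        "model": "ollama/llama3.2",
        "temperature": 0,
        "model_tokens": 4096,  # max input tokens for the local model
    },
}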

scrapegraphai/nodes/generate_answer_node.py (+6, -4)
@@ -10,7 +10,7 @@
 from langchain_community.chat_models import ChatOllama
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
-from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_openai import ChatOpenAI
 from requests.exceptions import Timeout
 from tqdm import tqdm
 
@@ -59,7 +59,10 @@ def __init__(
         self.llm_model = node_config["llm_model"]
 
         if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format = "json"
+            if node_config.get("schema", None) is None:
+                self.llm_model.format = "json"
+            else:
+                self.llm_model.format = self.node_config["schema"].model_json_schema()
 
         self.verbose = node_config.get("verbose", False)
         self.force = node_config.get("force", False)
@@ -123,8 +126,7 @@ def execute(self, state: dict) -> dict:
         format_instructions = ""
 
         if (
-            isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
-            and not self.script_creator
+            not self.script_creator
             or self.force
             and not self.script_creator
             or self.is_md_scraper
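
This is the structured-output half of the commit: when the node config carries a schema, ChatOllama's format is set to the Pydantic model's JSON schema rather than the bare "json" mode (the k-level node below gets the same wiring). A standalone sketch of the idea, assuming a local llama3.2 model is available:

# Hedged sketch: `format` accepts "json" or a JSON schema dict; passing the
# schema constrains the model's output to that shape.
from langchain_community.chat_models import ChatOllama
from pydantic import BaseModel


class Projects(BaseModel):
    projects: list[str]


llm = ChatOllama(model="llama3.2", temperature=0)
llm.format = Projects.model_json_schema()  # instead of the blanket "json"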

scrapegraphai/nodes/generate_answer_node_k_level.py (+10, -3)
@@ -6,10 +6,11 @@
 
 from langchain.prompts import PromptTemplate
 from langchain_aws import ChatBedrock
+from langchain_community.chat_models import ChatOllama
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
 from langchain_mistralai import ChatMistralAI
-from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_openai import ChatOpenAI
 from tqdm import tqdm
 
 from ..prompts import (
@@ -55,6 +56,13 @@ def __init__(
         super().__init__(node_name, "node", input, output, 2, node_config)
 
         self.llm_model = node_config["llm_model"]
+
+        if isinstance(node_config["llm_model"], ChatOllama):
+            if node_config.get("schema", None) is None:
+                self.llm_model.format = "json"
+            else:
+                self.llm_model.format = self.node_config["schema"].model_json_schema()
+
         self.embedder_model = node_config.get("embedder_model", None)
         self.verbose = node_config.get("verbose", False)
         self.force = node_config.get("force", False)
@@ -92,8 +100,7 @@ def execute(self, state: dict) -> dict:
         format_instructions = ""
 
         if (
-            isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
-            and not self.script_creator
+            not self.script_creator
             or self.force
             and not self.script_creator
             or self.is_md_scraper

scrapegraphai/nodes/parse_node.py (+1, -3)
@@ -96,7 +96,6 @@ def execute(self, state: dict) -> dict:
             chunks = split_text_into_chunks(
                 text=docs_transformed.page_content,
                 chunk_size=self.chunk_size - 250,
-                model=self.llm_model,
             )
         else:
             docs_transformed = docs_transformed[0]
@@ -115,11 +114,10 @@ def execute(self, state: dict) -> dict:
             chunks = split_text_into_chunks(
                 text=docs_transformed.page_content,
                 chunk_size=chunk_size,
-                model=self.llm_model,
             )
         else:
             chunks = split_text_into_chunks(
-                text=docs_transformed, chunk_size=chunk_size, model=self.llm_model
+                text=docs_transformed, chunk_size=chunk_size
             )
 
         state.update({self.output[0]: chunks})

scrapegraphai/utils/split_text_into_chunks.py (+5, -9)
@@ -4,14 +4,10 @@
 
 from typing import List
 
-from langchain_core.language_models.chat_models import BaseChatModel
-
 from .tokenizer import num_tokens_calculus
 
 
-def split_text_into_chunks(
-    text: str, chunk_size: int, model: BaseChatModel, use_semchunk=True
-) -> List[str]:
+def split_text_into_chunks(text: str, chunk_size: int, use_semchunk=True) -> List[str]:
     """
     Splits the text into chunks based on the number of tokens.
 
@@ -27,17 +23,17 @@ def split_text_into_chunks(
         from semchunk import chunk
 
         def count_tokens(text):
-            return num_tokens_calculus(text, model)
+            return num_tokens_calculus(text)
 
-        chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
+        chunk_size = min(chunk_size, int(chunk_size * 0.9))
 
         chunks = chunk(
             text=text, chunk_size=chunk_size, token_counter=count_tokens, memoize=False
         )
         return chunks
 
     else:
-        tokens = num_tokens_calculus(text, model)
+        tokens = num_tokens_calculus(text)
 
         if tokens <= chunk_size:
             return [text]
@@ -48,7 +44,7 @@ def count_tokens(text):
 
         words = text.split()
         for word in words:
-            word_tokens = num_tokens_calculus(word, model)
+            word_tokens = num_tokens_calculus(word)
             if current_length + word_tokens > chunk_size:
                 chunks.append(" ".join(current_chunk))
                 current_chunk = [word]
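
With the model argument gone here (and from the call sites in parse_node.py above), chunking depends only on the text and the token budget. A hedged usage sketch of the new signature; note that the min(...) above caps the effective budget at 90% of chunk_size:

# Hedged usage sketch of the model-free signature.
from scrapegraphai.utils.split_text_into_chunks import split_text_into_chunks

chunks = split_text_into_chunks(text="lorem ipsum " * 2000, chunk_size=1024)
print(len(chunks))  # several chunks, each within ~90% of the token budget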

scrapegraphai/utils/tokenizer.py (+4, -24)
@@ -2,35 +2,15 @@
 Module for counting tokens and splitting text into chunks
 """
 
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_mistralai import ChatMistralAI
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
+from .tokenizers.tokenizer_openai import num_tokens_openai
 
 
-def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
+def num_tokens_calculus(string: str) -> int:
     """
     Returns the number of tokens in a text string.
     """
-    if isinstance(llm_model, ChatOpenAI):
-        from .tokenizers.tokenizer_openai import num_tokens_openai
 
-        num_tokens_fn = num_tokens_openai
+    num_tokens_fn = num_tokens_openai
 
-    elif isinstance(llm_model, ChatMistralAI):
-        from .tokenizers.tokenizer_mistral import num_tokens_mistral
-
-        num_tokens_fn = num_tokens_mistral
-
-    elif isinstance(llm_model, ChatOllama):
-        from .tokenizers.tokenizer_ollama import num_tokens_ollama
-
-        num_tokens_fn = num_tokens_ollama
-
-    else:
-        from .tokenizers.tokenizer_openai import num_tokens_openai
-
-        num_tokens_fn = num_tokens_openai
-
-    num_tokens = num_tokens_fn(string, llm_model)
+    num_tokens = num_tokens_fn(string)
     return num_tokens

scrapegraphai/utils/tokenizers/tokenizer_openai.py (+2, -4)
@@ -3,19 +3,17 @@
 """
 
 import tiktoken
-from langchain_core.language_models.chat_models import BaseChatModel
 
 from ..logging import get_logger
 
 
-def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:
+def num_tokens_openai(text: str) -> int:
     """
     Estimate the number of tokens in a given text using OpenAI's tokenization method,
     adjusted for different OpenAI models.
 
     Args:
         text (str): The text to be tokenized and counted.
-        llm_model (BaseChatModel): The specific OpenAI model to adjust tokenization.
 
     Returns:
         int: The number of tokens in the text.
@@ -25,7 +23,7 @@ def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:
 
     logger.debug(f"Counting tokens for text of {len(text)} characters")
 
-    encoding = tiktoken.encoding_for_model("gpt-4")
+    encoding = tiktoken.encoding_for_model("gpt-4o")
 
     num_tokens = len(encoding.encode(text))
     return num_tokens
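
Together with the tokenizer.py change above, every model's input is now counted with tiktoken's gpt-4o encoding: an approximation for non-OpenAI models, but a single shared code path. A minimal sketch of the counting step:

# Minimal sketch: "gpt-4o" maps to tiktoken's o200k_base encoding.
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4o")
print(len(encoding.encode("Counting tokens with one shared encoder.")))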

uv.lock (+1, -3)

Generated lockfile; diff not rendered by default.
