Commit a86e7d6

enhancement: add support for Playwright's storage_state parameter (#832)

* add support for Playwright `storage_state`
* add storage_state param to node_config
* add sleep for testing
* add sleep in _with_js_support for testing
* remove asyncio.sleep() from tests
* fix typo in existing example filename; add auth example
* add example `authenticated_playwright`
* update source link in example to /feed
* add `storage_state` to missing graphs
1 parent fbb4252 · commit a86e7d6

16 files changed (+585, -329 lines)
New file (the `authenticated_playwright` example): 93 additions, 0 deletions

```python
"""
Example leveraging a state file containing session cookies, which can be
used to authenticate to a website and scrape protected content.
"""

import os
import random
from dotenv import load_dotenv

# import playwright so we can use it to create the state file
from playwright.async_api import async_playwright

from scrapegraphai.graphs import OmniScraperGraph
from scrapegraphai.utils import prettify_exec_info

load_dotenv()

# ************************************************
# Leveraging Playwright outside of the graph invocation to
# log in and create the state file
# ************************************************


# note this is just an example and probably won't actually work on LinkedIn;
# the implementation of the login is highly dependent on the website
async def do_login():
    async with async_playwright() as playwright:
        browser = await playwright.chromium.launch(
            timeout=30000,
            headless=False,
            slow_mo=random.uniform(500, 1500),
        )
        page = await browser.new_page()

        # very basic implementation of a login; in reality it may be trickier
        await page.goto("https://www.linkedin.com/login")
        await page.get_by_label("Email or phone").fill("some_bloke@some_domain.com")
        await page.get_by_label("Password").fill("test1234")
        await page.get_by_role("button", name="Sign in").click()
        await page.wait_for_timeout(3000)

        # assuming a successful login, we save the cookies to a file
        await page.context.storage_state(path="./state.json")


async def main():
    await do_login()

    # ************************************************
    # Define the configuration for the graph
    # ************************************************

    openai_api_key = os.getenv("OPENAI_APIKEY")

    graph_config = {
        "llm": {
            "api_key": openai_api_key,
            "model": "openai/gpt-4o",
        },
        "max_images": 10,
        "headless": False,
        # provide the path to the state file
        "storage_state": "./state.json",
    }

    # ************************************************
    # Create the OmniScraperGraph instance and run it
    # ************************************************

    omni_scraper_graph = OmniScraperGraph(
        prompt="List me all the projects with their description.",
        source="https://www.linkedin.com/feed/",
        config=graph_config,
    )

    # storage_state is used to load the cookies from the state file,
    # so we are authenticated and able to scrape protected content
    result = omni_scraper_graph.run()
    print(result)

    # ************************************************
    # Get graph execution info
    # ************************************************

    graph_exec_info = omni_scraper_graph.get_execution_info()
    print(prettify_exec_info(graph_exec_info))


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
```
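After `do_login()` completes, `state.json` holds the serialized session. Before handing the file to the graph it can be worth a quick sanity check; the sketch below assumes only the documented Playwright storage-state layout, with top-level `cookies` and `origins` keys:

```python
import json

# Quick sanity check on the saved session file. Playwright storage_state
# files contain top-level "cookies" and "origins" entries.
with open("./state.json") as f:
    state = json.load(f)

print(f"cookies saved: {len(state.get('cookies', []))}")
print(f"origins with stored local storage: {len(state.get('origins', []))}")
```

If the cookie count is zero, the login step most likely failed and the graph will scrape as an anonymous visitor.
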
scrapegraphai/docloaders/chromium.py (26 additions & 9 deletions)

```diff
@@ -8,6 +8,7 @@
 
 logger = get_logger("web-loader")
 
+
 class ChromiumLoader(BaseLoader):
     """Scrapes HTML pages from URLs using a (headless) instance of the
     Chromium web driver with proxy protection.
@@ -33,6 +34,7 @@ def __init__(
         proxy: Optional[Proxy] = None,
         load_state: str = "domcontentloaded",
         requires_js_support: bool = False,
+        storage_state: Optional[str] = None,
         **kwargs: Any,
     ):
         """Initialize the loader with a list of URL paths.
@@ -62,6 +64,7 @@ def __init__(
         self.urls = urls
         self.load_state = load_state
         self.requires_js_support = requires_js_support
+        self.storage_state = storage_state
 
     async def ascrape_undetected_chromedriver(self, url: str) -> str:
         """
@@ -91,7 +94,9 @@ async def ascrape_undetected_chromedriver(self, url: str) -> str:
                 attempt += 1
                 logger.error(f"Attempt {attempt} failed: {e}")
                 if attempt == self.RETRY_LIMIT:
-                    results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    results = (
+                        f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    )
             finally:
                 driver.quit()
 
@@ -113,7 +118,9 @@ async def ascrape_playwright(self, url: str) -> str:
             browser = await p.chromium.launch(
                 headless=self.headless, proxy=self.proxy, **self.browser_config
             )
-            context = await browser.new_context()
+            context = await browser.new_context(
+                storage_state=self.storage_state
+            )
             await Malenia.apply_stealth(context)
             page = await context.new_page()
             await page.goto(url, wait_until="domcontentloaded")
@@ -125,9 +132,11 @@
             attempt += 1
             logger.error(f"Attempt {attempt} failed: {e}")
             if attempt == self.RETRY_LIMIT:
-                raise RuntimeError(f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}")
+                raise RuntimeError(
+                    f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}"
+                )
         finally:
-            if 'browser' in locals():
+            if "browser" in locals():
                 await browser.close()
 
     async def ascrape_with_js_support(self, url: str) -> str:
@@ -138,7 +147,7 @@ async def ascrape_with_js_support(self, url: str) -> str:
             url (str): The URL to scrape.
 
         Returns:
-            str: The fully rendered HTML content after JavaScript execution, 
+            str: The fully rendered HTML content after JavaScript execution,
             or an error message if an exception occurs.
         """
         from playwright.async_api import async_playwright
@@ -153,7 +162,9 @@ async def ascrape_with_js_support(self, url: str) -> str:
                 browser = await p.chromium.launch(
                     headless=self.headless, proxy=self.proxy, **self.browser_config
                 )
-                context = await browser.new_context()
+                context = await browser.new_context(
+                    storage_state=self.storage_state
+                )
                 page = await context.new_page()
                 await page.goto(url, wait_until="networkidle")
                 results = await page.content()
@@ -163,7 +174,9 @@ async def ascrape_with_js_support(self, url: str) -> str:
                 attempt += 1
                 logger.error(f"Attempt {attempt} failed: {e}")
                 if attempt == self.RETRY_LIMIT:
-                    results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    results = (
+                        f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    )
             finally:
                 await browser.close()
 
@@ -180,7 +193,9 @@ def lazy_load(self) -> Iterator[Document]:
             Document: The scraped content encapsulated within a Document object.
         """
         scraping_fn = (
-            self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}")
+            self.ascrape_with_js_support
+            if self.requires_js_support
+            else getattr(self, f"ascrape_{self.backend}")
         )
 
         for url in self.urls:
@@ -202,7 +217,9 @@ async def alazy_load(self) -> AsyncIterator[Document]:
             source URL as metadata.
         """
         scraping_fn = (
-            self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}")
+            self.ascrape_with_js_support
+            if self.requires_js_support
+            else getattr(self, f"ascrape_{self.backend}")
        )
 
         tasks = [scraping_fn(url) for url in self.urls]
```
scrapegraphai/graphs/abstract_graph.py (65 additions & 30 deletions)

```diff
@@ -1,6 +1,7 @@
 """
 AbstractGraph Module
 """
+
 from abc import ABC, abstractmethod
 from typing import Optional
 import uuid
@@ -9,12 +10,10 @@
 from langchain.chat_models import init_chat_model
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from ..helpers import models_tokens
-from ..models import (
-    OneApi,
-    DeepSeek
-)
+from ..models import OneApi, DeepSeek
 from ..utils.logging import set_verbosity_warning, set_verbosity_info
 
+
 class AbstractGraph(ABC):
     """
     Scaffolding class for creating a graph representation and executing it.
@@ -39,14 +38,18 @@ class AbstractGraph(ABC):
     ...     # Implementation of graph creation here
     ...     return graph
     ...
-    >>> my_graph = MyGraph("Example Graph", 
+    >>> my_graph = MyGraph("Example Graph",
     {"llm": {"model": "gpt-3.5-turbo"}}, "example_source")
     >>> result = my_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict,
-                 source: Optional[str] = None, schema: Optional[BaseModel] = None):
-
+    def __init__(
+        self,
+        prompt: str,
+        config: dict,
+        source: Optional[str] = None,
+        schema: Optional[BaseModel] = None,
+    ):
         if config.get("llm").get("temperature") is None:
             config["llm"]["temperature"] = 0
 
@@ -55,14 +58,13 @@ def __init__(self, prompt: str, config: dict,
         self.config = config
         self.schema = schema
         self.llm_model = self._create_llm(config["llm"])
-        self.verbose = False if config is None else config.get(
-            "verbose", False)
-        self.headless = True if self.config is None else config.get(
-            "headless", True)
+        self.verbose = False if config is None else config.get("verbose", False)
+        self.headless = True if self.config is None else config.get("headless", True)
         self.loader_kwargs = self.config.get("loader_kwargs", {})
         self.cache_path = self.config.get("cache_path", False)
         self.browser_base = self.config.get("browser_base")
         self.scrape_do = self.config.get("scrape_do")
+        self.storage_state = self.config.get("storage_state")
 
         self.graph = self._create_graph()
         self.final_state = None
@@ -81,7 +83,7 @@ def __init__(self, prompt: str, config: dict,
             "loader_kwargs": self.loader_kwargs,
             "llm_model": self.llm_model,
             "cache_path": self.cache_path,
-            }
+        }
 
         self.set_common_params(common_params, overwrite=True)
 
@@ -129,7 +131,8 @@ def _create_llm(self, llm_config: dict) -> object:
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 llm_params["rate_limiter"] = InMemoryRateLimiter(
-                    requests_per_second=requests_per_second)
+                    requests_per_second=requests_per_second
+                )
             if max_retries is not None:
                 llm_params["max_retries"] = max_retries
 
@@ -140,30 +143,55 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
-                           "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
-                           "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
-
-        if '/' in llm_params["model"]:
-            split_model_provider = llm_params["model"].split("/", 1)
-            llm_params["model_provider"] = split_model_provider[0]
-            llm_params["model"] = split_model_provider[1]
+        known_providers = {
+            "openai",
+            "azure_openai",
+            "google_genai",
+            "google_vertexai",
+            "ollama",
+            "oneapi",
+            "nvidia",
+            "groq",
+            "anthropic",
+            "bedrock",
+            "mistralai",
+            "hugging_face",
+            "deepseek",
+            "ernie",
+            "fireworks",
+            "togetherai",
+        }
+
+        if "/" in llm_params["model"]:
+            split_model_provider = llm_params["model"].split("/", 1)
+            llm_params["model_provider"] = split_model_provider[0]
+            llm_params["model"] = split_model_provider[1]
         else:
-            possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
+            possible_providers = [
+                provider
+                for provider, models_d in models_tokens.items()
+                if llm_params["model"] in models_d
+            ]
             if len(possible_providers) <= 0:
                 raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                                  If possible, try to use a model instance instead.""")
             llm_params["model_provider"] = possible_providers[0]
-            print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
-                   "If it was not intended please specify the model provider in the graph configuration"))
+            print(
+                (
+                    f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
+                    "If it was not intended please specify the model provider in the graph configuration"
+                )
+            )
 
         if llm_params["model_provider"] not in known_providers:
             raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                              If possible, try to use a model instance instead.""")
 
         if "model_tokens" not in llm_params:
             try:
-                self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
+                self.model_token = models_tokens[llm_params["model_provider"]][
+                    llm_params["model"]
+                ]
             except KeyError:
                 print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
                       using default token size (8192)""")
@@ -172,10 +200,17 @@ def _create_llm(self, llm_config: dict) -> object:
             self.model_token = llm_params["model_tokens"]
 
         try:
-            if llm_params["model_provider"] not in \
-                {"oneapi","nvidia","ernie","deepseek","togetherai"}:
+            if llm_params["model_provider"] not in {
+                "oneapi",
+                "nvidia",
+                "ernie",
+                "deepseek",
+                "togetherai",
+            }:
                 if llm_params["model_provider"] == "bedrock":
-                    llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
+                    llm_params["model_kwargs"] = {
+                        "temperature": llm_params.pop("temperature")
+                    }
                 with warnings.catch_warnings():
                     warnings.simplefilter("ignore")
                     return init_chat_model(**llm_params)
@@ -187,6 +222,7 @@ def _create_llm(self, llm_config: dict) -> object:
 
                 if model_provider == "ernie":
                     from langchain_community.chat_models import ErnieBotChat
+
                     return ErnieBotChat(**llm_params)
 
                 elif model_provider == "oneapi":
@@ -211,7 +247,6 @@ def _create_llm(self, llm_config: dict) -> object:
         except Exception as e:
             raise Exception(f"Error instancing model: {e}")
 
-
     def get_state(self, key=None) -> dict:
         """ ""
         Get the final state of the graph.
```