 """
 AbstractGraph Module
 """
+
 from abc import ABC, abstractmethod
 from typing import Optional
 import uuid
 from langchain.chat_models import init_chat_model
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from ..helpers import models_tokens
-from ..models import (
-    OneApi,
-    DeepSeek
-)
+from ..models import OneApi, DeepSeek
 from ..utils.logging import set_verbosity_warning, set_verbosity_info
 
+
 class AbstractGraph(ABC):
     """
     Scaffolding class for creating a graph representation and executing it.
@@ -39,14 +38,18 @@ class AbstractGraph(ABC):
         ... # Implementation of graph creation here
         ... return graph
         ...
-        >>> my_graph = MyGraph("Example Graph", 
+        >>> my_graph = MyGraph("Example Graph",
         {"llm": {"model": "gpt-3.5-turbo"}}, "example_source")
         >>> result = my_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict,
-                 source: Optional[str] = None, schema: Optional[BaseModel] = None):
-
+    def __init__(
+        self,
+        prompt: str,
+        config: dict,
+        source: Optional[str] = None,
+        schema: Optional[BaseModel] = None,
+    ):
         if config.get("llm").get("temperature") is None:
             config["llm"]["temperature"] = 0
 
@@ -55,14 +58,13 @@ def __init__(self, prompt: str, config: dict,
         self.config = config
         self.schema = schema
         self.llm_model = self._create_llm(config["llm"])
-        self.verbose = False if config is None else config.get(
-            "verbose", False)
-        self.headless = True if self.config is None else config.get(
-            "headless", True)
+        self.verbose = False if config is None else config.get("verbose", False)
+        self.headless = True if self.config is None else config.get("headless", True)
         self.loader_kwargs = self.config.get("loader_kwargs", {})
         self.cache_path = self.config.get("cache_path", False)
         self.browser_base = self.config.get("browser_base")
         self.scrape_do = self.config.get("scrape_do")
+        self.storage_state = self.config.get("storage_state")
 
         self.graph = self._create_graph()
         self.final_state = None
@@ -81,7 +83,7 @@ def __init__(self, prompt: str, config: dict,
             "loader_kwargs": self.loader_kwargs,
             "llm_model": self.llm_model,
             "cache_path": self.cache_path,
-            }
+        }
 
         self.set_common_params(common_params, overwrite=True)
 
@@ -129,7 +131,8 @@ def _create_llm(self, llm_config: dict) -> object:
                 with warnings.catch_warnings():
                     warnings.simplefilter("ignore")
                     llm_params["rate_limiter"] = InMemoryRateLimiter(
-                        requests_per_second=requests_per_second)
+                        requests_per_second=requests_per_second
+                    )
 
             if max_retries is not None:
                 llm_params["max_retries"] = max_retries
@@ -140,30 +143,55 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
-                           "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
-                           "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
-
-        if '/' in llm_params["model"]:
-            split_model_provider = llm_params["model"].split("/", 1)
-            llm_params["model_provider"] = split_model_provider[0]
-            llm_params["model"] = split_model_provider[1]
+        known_providers = {
+            "openai",
+            "azure_openai",
+            "google_genai",
+            "google_vertexai",
+            "ollama",
+            "oneapi",
+            "nvidia",
+            "groq",
+            "anthropic",
+            "bedrock",
+            "mistralai",
+            "hugging_face",
+            "deepseek",
+            "ernie",
+            "fireworks",
+            "togetherai",
+        }
+
+        if "/" in llm_params["model"]:
+            split_model_provider = llm_params["model"].split("/", 1)
+            llm_params["model_provider"] = split_model_provider[0]
+            llm_params["model"] = split_model_provider[1]
         else:
-            possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
+            possible_providers = [
+                provider
+                for provider, models_d in models_tokens.items()
+                if llm_params["model"] in models_d
+            ]
             if len(possible_providers) <= 0:
                 raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                                  If possible, try to use a model instance instead.""")
             llm_params["model_provider"] = possible_providers[0]
-            print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
-                   "If it was not intended please specify the model provider in the graph configuration"))
+            print(
+                (
+                    f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
+                    "If it was not intended please specify the model provider in the graph configuration"
+                )
+            )
 
         if llm_params["model_provider"] not in known_providers:
             raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                              If possible, try to use a model instance instead.""")
 
         if "model_tokens" not in llm_params:
             try:
-                self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
+                self.model_token = models_tokens[llm_params["model_provider"]][
+                    llm_params["model"]
+                ]
             except KeyError:
                 print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
                       using default token size (8192)""")
@@ -172,10 +200,17 @@ def _create_llm(self, llm_config: dict) -> object:
             self.model_token = llm_params["model_tokens"]
 
         try:
-            if llm_params["model_provider"] not in \
-                {"oneapi","nvidia","ernie","deepseek","togetherai"}:
+            if llm_params["model_provider"] not in {
+                "oneapi",
+                "nvidia",
+                "ernie",
+                "deepseek",
+                "togetherai",
+            }:
                 if llm_params["model_provider"] == "bedrock":
-                    llm_params["model_kwargs"] = { "temperature": llm_params.pop("temperature") }
+                    llm_params["model_kwargs"] = {
+                        "temperature": llm_params.pop("temperature")
+                    }
                 with warnings.catch_warnings():
                     warnings.simplefilter("ignore")
                     return init_chat_model(**llm_params)
@@ -187,6 +222,7 @@ def _create_llm(self, llm_config: dict) -> object:
 
                 if model_provider == "ernie":
                     from langchain_community.chat_models import ErnieBotChat
+
                     return ErnieBotChat(**llm_params)
 
                 elif model_provider == "oneapi":
@@ -211,7 +247,6 @@ def _create_llm(self, llm_config: dict) -> object:
         except Exception as e:
             raise Exception(f"Error instancing model: {e}")
 
-
     def get_state(self, key=None) -> dict:
         """ ""
         Get the final state of the graph.
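
Aside from the new self.storage_state attribute, the changes above are formatting-only. For readers skimming the diff, here is a minimal, self-contained sketch of the provider-resolution behavior that the reformatted _create_llm hunks preserve: a model string such as "openai/gpt-4o" is split on the first "/", otherwise the provider is inferred from the models_tokens mapping, and 8192 is the fallback context size. The models_tokens dict and resolve_model_provider helper below are illustrative stand-ins, not part of the scrapegraphai API.

# Illustrative sketch only: toy stand-in for the library's models_tokens mapping.
models_tokens = {
    "openai": {"gpt-3.5-turbo": 16385, "gpt-4o": 128000},
    "ollama": {"llama3": 8192},
}

def resolve_model_provider(model: str) -> tuple[str, str, int]:
    """Return (provider, model_name, max_tokens) the way the diff's logic does.

    Hypothetical helper, not a scrapegraphai API: "provider/model" strings are
    split on the first "/"; bare model names are looked up in models_tokens.
    """
    if "/" in model:
        provider, model_name = model.split("/", 1)
    else:
        candidates = [p for p, models in models_tokens.items() if model in models]
        if not candidates:
            raise ValueError(f"No known provider for model {model!r}")
        provider, model_name = candidates[0], model
    # Fall back to a default context size when the model is unknown,
    # mirroring the 8192-token default in the original code.
    max_tokens = models_tokens.get(provider, {}).get(model_name, 8192)
    return provider, model_name, max_tokens

print(resolve_model_provider("openai/gpt-4o"))  # ('openai', 'gpt-4o', 128000)
print(resolve_model_provider("llama3"))         # ('ollama', 'llama3', 8192)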