Skip to content

Commit 579a27f

Browse files
authored
Merge pull request #126 from VinciGit00/bedrock_support
feat: bedrock support
2 parents 2ccb608 + db41905 commit 579a27f

File tree

7 files changed

+69
-7
lines changed

7 files changed

+69
-7
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ minify-html = "0.15.0"
4141
free-proxy = "1.1.1"
4242
langchain-groq = "0.1.3"
4343
playwright = "^1.43.0"
44+
langchain-aws = "^0.1.2"
45+
4446

4547
[tool.poetry.dev-dependencies]
4648
pytest = "8.0.0"

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ google==3.0.0
1313
minify-html==0.15.0
1414
free-proxy==1.1.1
1515
langchain-groq==0.1.3
16-
playwright==1.43.0
16+
playwright==1.43.0
17+
langchain-aws==0.1.2

scrapegraphai/graphs/abstract_graph.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from abc import ABC, abstractmethod
66
from typing import Optional
7-
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq
7+
8+
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
89
from ..helpers import models_tokens
910

1011

@@ -47,7 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4748

4849
# Set common configuration parameters
4950
self.verbose = True if config is None else config.get("verbose", False)
50-
self.headless = True if config is None else config.get("headless", True)
51+
self.headless = True if config is None else config.get(
52+
"headless", True)
5153

5254
# Create the graph
5355
self.graph = self._create_graph()
@@ -140,12 +142,26 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
140142
return HuggingFace(llm_params)
141143
elif "groq" in llm_params["model"]:
142144
llm_params["model"] = llm_params["model"].split("/")[-1]
143-
145+
144146
try:
145147
self.model_token = models_tokens["groq"][llm_params["model"]]
146148
except KeyError:
147149
raise KeyError("Model not supported")
148150
return Groq(llm_params)
151+
elif "bedrock" in llm_params["model"]:
152+
llm_params["model"] = llm_params["model"].split("/")[-1]
153+
model_id = llm_params["model"]
154+
155+
try:
156+
self.model_token = models_tokens["bedrock"][llm_params["model"]]
157+
except KeyError:
158+
raise KeyError("Model not supported")
159+
return Bedrock({
160+
"model_id": model_id,
161+
"model_kwargs": {
162+
"temperature": llm_params["temperature"],
163+
}
164+
})
149165
else:
150166
raise ValueError(
151167
"Model provided by the configuration not supported")

scrapegraphai/helpers/models_tokens.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@
4646
"claude2": 9000,
4747
"claude2.1": 200000,
4848
"claude3": 200000
49+
},
50+
"bedrock": {
51+
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
52+
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
53+
"anthropic.claude-3-opus-20240229-v1:0": 200000,
54+
"anthropic.claude-v2:1": 200000,
55+
"anthropic.claude-v2": 100000,
56+
"anthropic.claude-instant-v1": 100000,
57+
"meta.llama3-8b-instruct-v1:0": 8192,
58+
"meta.llama3-70b-instruct-v1:0": 8192,
59+
"meta.llama2-13b-chat-v1": 4096,
60+
"meta.llama2-70b-chat-v1": 4096,
61+
"mistral.mistral-7b-instruct-v0:2": 32768,
62+
"mistral.mixtral-8x7b-instruct-v0:1": 32768,
63+
"mistral.mistral-large-2402-v1:0": 32768,
64+
"cohere.embed-english-v3": 512,
65+
"cohere.embed-multilingual-v3": 512
4966
}
5067
}
5168

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .ollama import Ollama
1111
from .hugging_face import HuggingFace
1212
from .groq import Groq
13+
from .bedrock import Bedrock

scrapegraphai/models/bedrock.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
bedrock configuration wrapper
3+
"""
4+
from langchain_aws import ChatBedrock
5+
6+
7+
class Bedrock(ChatBedrock):
8+
"""Class for wrapping bedrock module"""
9+
10+
def __init__(self, llm_config: dict):
11+
"""
12+
A wrapper for the ChatBedrock class that provides default configuration
13+
and could be extended with additional methods if needed.
14+
15+
Args:
16+
llm_config (dict): Configuration parameters for the language model.
17+
"""
18+
# Initialize the superclass (ChatBedrock) with provided config parameters
19+
super().__init__(**llm_config)

scrapegraphai/nodes/rag_node.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from langchain.docstore.document import Document
77
from langchain.retrievers import ContextualCompressionRetriever
88
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
9+
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
910
from langchain_community.document_transformers import EmbeddingsRedundantFilter
1011
from langchain_community.embeddings import HuggingFaceHubEmbeddings
1112
from langchain_community.vectorstores import FAISS
1213
from langchain_community.embeddings import OllamaEmbeddings
1314
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
14-
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace
15+
16+
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
1517
from .base_node import BaseNode
1618

1719

@@ -39,7 +41,8 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name:
3941

4042
self.llm_model = node_config["llm"]
4143
self.embedder_model = node_config.get("embedder_model", None)
42-
self.verbose = True if node_config is None else node_config.get("verbose", False)
44+
self.verbose = True if node_config is None else node_config.get(
45+
"verbose", False)
4346

4447
def execute(self, state: dict) -> dict:
4548
"""
@@ -80,7 +83,7 @@ def execute(self, state: dict) -> dict:
8083
},
8184
)
8285
chunked_docs.append(doc)
83-
86+
8487
if self.verbose:
8588
print("--- (updated chunks metadata) ---")
8689

@@ -104,6 +107,9 @@ def execute(self, state: dict) -> dict:
104107
embeddings = OllamaEmbeddings(**params)
105108
elif isinstance(embedding_model, HuggingFace):
106109
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
110+
elif isinstance(embedding_model, Bedrock):
111+
embeddings = BedrockEmbeddings(
112+
client=None, model_id=embedding_model.model_id)
107113
else:
108114
raise ValueError("Embedding Model missing or not supported")
109115

0 commit comments

Comments
 (0)