Skip to content

Commit 16de49f

Browse files
VinciGit00redrusty2
andcommitted
add integration for bedrock
Co-Authored-By: redrusty2 <[email protected]>
1 parent 40b2a34 commit 16de49f

File tree

6 files changed

+67
-6
lines changed

6 files changed

+67
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ minify-html = "0.15.0"
4040
free-proxy = "1.1.1"
4141
langchain-groq = "0.1.3"
4242
playwright = "^1.43.0"
43+
langchain-aws = "^0.1.2"
44+
4345

4446
[tool.poetry.dev-dependencies]
4547
pytest = "8.0.0"

scrapegraphai/graphs/abstract_graph.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
"""
44
from abc import ABC, abstractmethod
55
from typing import Optional
6-
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq
6+
7+
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
78
from ..helpers import models_tokens
89

910

@@ -25,7 +26,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
2526

2627
# Set common configuration parameters
2728
self.verbose = True if config is None else config.get("verbose", False)
28-
self.headless = True if config is None else config.get("headless", True)
29+
self.headless = True if config is None else config.get(
30+
"headless", True)
2931

3032
# Create the graph
3133
self.graph = self._create_graph()
@@ -92,12 +94,26 @@ def _create_llm(self, llm_config: dict):
9294
return HuggingFace(llm_params)
9395
elif "groq" in llm_params["model"]:
9496
llm_params["model"] = llm_params["model"].split("/")[-1]
95-
97+
9698
try:
9799
self.model_token = models_tokens["groq"][llm_params["model"]]
98100
except KeyError:
99101
raise KeyError("Model not supported")
100102
return Groq(llm_params)
103+
elif "bedrock" in llm_params["model"]:
104+
llm_params["model"] = llm_params["model"].split("/")[-1]
105+
model_id = llm_params["model"]
106+
107+
try:
108+
self.model_token = models_tokens["bedrock"][llm_params["model"]]
109+
except KeyError:
110+
raise KeyError("Model not supported")
111+
return Bedrock({
112+
"model_id": model_id,
113+
"model_kwargs": {
114+
"temperature": llm_params["temperature"],
115+
}
116+
})
101117
else:
102118
raise ValueError(
103119
"Model provided by the configuration not supported")

scrapegraphai/helpers/models_tokens.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,22 @@
4343
"claude2": 9000,
4444
"claude2.1": 200000,
4545
"claude3": 200000
46+
},
47+
"bedrock": {
48+
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
49+
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
50+
"anthropic.claude-3-opus-20240229-v1:0": 200000,
51+
"anthropic.claude-v2:1": 200000,
52+
"anthropic.claude-v2": 100000,
53+
"anthropic.claude-instant-v1": 100000,
54+
"meta.llama3-8b-instruct-v1:0": 8192,
55+
"meta.llama3-70b-instruct-v1:0": 8192,
56+
"meta.llama2-13b-chat-v1": 4096,
57+
"meta.llama2-70b-chat-v1": 4096,
58+
"mistral.mistral-7b-instruct-v0:2": 32768,
59+
"mistral.mixtral-8x7b-instruct-v0:1": 32768,
60+
"mistral.mistral-large-2402-v1:0": 32768,
61+
"cohere.embed-english-v3": 512,
62+
"cohere.embed-multilingual-v3": 512
4663
}
4764
}

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .ollama import Ollama
1111
from .hugging_face import HuggingFace
1212
from .groq import Groq
13+
from .bedrock import Bedrock

scrapegraphai/models/bedrock.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
bedrock configuration wrapper
3+
"""
4+
from langchain_aws import ChatBedrock
5+
6+
7+
class Bedrock(ChatBedrock):
8+
"""Class for wrapping bedrock module"""
9+
10+
def __init__(self, llm_config: dict):
11+
"""
12+
A wrapper for the ChatBedrock class that provides default configuration
13+
and could be extended with additional methods if needed.
14+
15+
Args:
16+
llm_config (dict): Configuration parameters for the language model.
17+
"""
18+
# Initialize the superclass (ChatBedrock) with provided config parameters
19+
super().__init__(**llm_config)

scrapegraphai/nodes/rag_node.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from langchain.docstore.document import Document
77
from langchain.retrievers import ContextualCompressionRetriever
88
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
9+
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
910
from langchain_community.document_transformers import EmbeddingsRedundantFilter
1011
from langchain_community.embeddings import HuggingFaceHubEmbeddings
1112
from langchain_community.vectorstores import FAISS
1213
from langchain_community.embeddings import OllamaEmbeddings
1314
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
14-
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace
15+
16+
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
1517
from .base_node import BaseNode
1618

1719

@@ -42,7 +44,8 @@ def __init__(self, input: str, output: List[str], node_config: dict, node_name:
4244
super().__init__(node_name, "node", input, output, 2, node_config)
4345
self.llm_model = node_config["llm"]
4446
self.embedder_model = node_config.get("embedder_model", None)
45-
self.verbose = True if node_config is None else node_config.get("verbose", False)
47+
self.verbose = True if node_config is None else node_config.get(
48+
"verbose", False)
4649

4750
def execute(self, state):
4851
"""
@@ -82,7 +85,7 @@ def execute(self, state):
8285
},
8386
)
8487
chunked_docs.append(doc)
85-
88+
8689
if self.verbose:
8790
print("--- (updated chunks metadata) ---")
8891

@@ -104,6 +107,9 @@ def execute(self, state):
104107
embeddings = OllamaEmbeddings(**params)
105108
elif isinstance(embedding_model, HuggingFace):
106109
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
110+
elif isinstance(embedding_model, Bedrock):
111+
embeddings = BedrockEmbeddings(
112+
client=None, model_id=embedding_model.model_id)
107113
else:
108114
raise ValueError("Embedding Model missing or not supported")
109115

0 commit comments

Comments
 (0)