|
1 | 1 | """
|
2 | 2 | AbstractGraph Module
|
3 | 3 | """
|
4 |
| - |
5 | 4 | from abc import ABC, abstractmethod
|
6 | 5 | from typing import Optional
|
7 |
| - |
8 |
| -from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
9 |
| -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
10 | 6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
11 |
| - |
| 7 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
12 | 8 | from ..helpers import models_tokens
|
13 | 9 | from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
|
| 10 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
| 11 | +from langchain_google_genai import GoogleGenerativeAIEmbeddings |
14 | 12 |
|
15 | 13 |
|
16 | 14 | class AbstractGraph(ABC):
|
@@ -86,7 +84,7 @@ def _set_model_token(self, llm):
|
86 | 84 | self.model_token = models_tokens["azure"][llm.model_name]
|
87 | 85 | except KeyError:
|
88 | 86 | raise KeyError("Model not supported")
|
89 |
| - |
| 87 | + |
90 | 88 | elif 'HuggingFaceEndpoint' in str(type(llm)):
|
91 | 89 | if 'mistral' in llm.repo_id:
|
92 | 90 | try:
|
@@ -246,29 +244,30 @@ def _create_embedder(self, embedder_config: dict) -> object:
|
246 | 244 |
|
247 | 245 | if 'model_instance' in embedder_config:
|
248 | 246 | return embedder_config['model_instance']
|
249 |
| - |
250 | 247 | # Instantiate the embedding model based on the model name
|
251 | 248 | if "openai" in embedder_config["model"]:
|
252 | 249 | return OpenAIEmbeddings(api_key=embedder_config["api_key"])
|
253 |
| - |
254 | 250 | elif "azure" in embedder_config["model"]:
|
255 | 251 | return AzureOpenAIEmbeddings()
|
256 |
| - |
257 | 252 | elif "ollama" in embedder_config["model"]:
|
258 | 253 | embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
259 | 254 | try:
|
260 | 255 | models_tokens["ollama"][embedder_config["model"]]
|
261 | 256 | except KeyError as exc:
|
262 | 257 | raise KeyError("Model not supported") from exc
|
263 | 258 | return OllamaEmbeddings(**embedder_config)
|
264 |
| - |
265 | 259 | elif "hugging_face" in embedder_config["model"]:
|
266 | 260 | try:
|
267 | 261 | models_tokens["hugging_face"][embedder_config["model"]]
|
268 | 262 | except KeyError as exc:
|
269 | 263 | raise KeyError("Model not supported")from exc
|
270 | 264 | return HuggingFaceHubEmbeddings(model=embedder_config["model"])
|
271 |
| - |
| 265 | + elif "gemini" in embedder_config["model"]: |
| 266 | + try: |
| 267 | + models_tokens["gemini"][embedder_config["model"]] |
| 268 | + except KeyError as exc: |
| 269 | + raise KeyError("Model not supported")from exc |
| 270 | + return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) |
272 | 271 | elif "bedrock" in embedder_config["model"]:
|
273 | 272 | embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
274 | 273 | try:
|
|
0 commit comments