|
1 | 1 | """
|
2 | 2 | AbstractGraph Module
|
3 | 3 | """
|
4 |
| - |
5 | 4 | from abc import ABC, abstractmethod
|
6 | 5 | from typing import Optional
|
7 |
| - |
8 |
| -from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
9 |
| -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
10 | 6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
11 |
| - |
| 7 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
12 | 8 | from ..helpers import models_tokens
|
13 | 9 | from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
|
| 10 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
| 11 | +from langchain_google_genai import GoogleGenerativeAIEmbeddings |
14 | 12 |
|
15 | 13 |
|
16 | 14 | class AbstractGraph(ABC):
|
@@ -69,7 +67,7 @@ def _set_model_token(self, llm):
|
69 | 67 | self.model_token = models_tokens["azure"][llm.model_name]
|
70 | 68 | except KeyError:
|
71 | 69 | raise KeyError("Model not supported")
|
72 |
| - |
| 70 | + |
73 | 71 | elif 'HuggingFaceEndpoint' in str(type(llm)):
|
74 | 72 | if 'mistral' in llm.repo_id:
|
75 | 73 | try:
|
@@ -229,29 +227,30 @@ def _create_embedder(self, embedder_config: dict) -> object:
|
229 | 227 |
|
230 | 228 | if 'model_instance' in embedder_config:
|
231 | 229 | return embedder_config['model_instance']
|
232 |
| - |
233 | 230 | # Instantiate the embedding model based on the model name
|
234 | 231 | if "openai" in embedder_config["model"]:
|
235 | 232 | return OpenAIEmbeddings(api_key=embedder_config["api_key"])
|
236 |
| - |
237 | 233 | elif "azure" in embedder_config["model"]:
|
238 | 234 | return AzureOpenAIEmbeddings()
|
239 |
| - |
240 | 235 | elif "ollama" in embedder_config["model"]:
|
241 | 236 | embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
242 | 237 | try:
|
243 | 238 | models_tokens["ollama"][embedder_config["model"]]
|
244 | 239 | except KeyError as exc:
|
245 | 240 | raise KeyError("Model not supported") from exc
|
246 | 241 | return OllamaEmbeddings(**embedder_config)
|
247 |
| - |
248 | 242 | elif "hugging_face" in embedder_config["model"]:
|
249 | 243 | try:
|
250 | 244 | models_tokens["hugging_face"][embedder_config["model"]]
|
251 | 245 | except KeyError as exc:
|
252 | 246 | raise KeyError("Model not supported")from exc
|
253 | 247 | return HuggingFaceHubEmbeddings(model=embedder_config["model"])
|
254 |
| - |
| 248 | + elif "gemini" in embedder_config["model"]: |
| 249 | + try: |
| 250 | + models_tokens["gemini"][embedder_config["model"]] |
| 251 | + except KeyError as exc: |
| 252 | + raise KeyError("Model not supported")from exc |
| 253 | + return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) |
255 | 254 | elif "bedrock" in embedder_config["model"]:
|
256 | 255 | embedder_config["model"] = embedder_config["model"].split("/")[-1]
|
257 | 256 | try:
|
|
0 commit comments