Commit 79daa4c

feat: add gemini embeddings

1 parent da8c72c commit 79daa4c
File tree

1 file changed (+10 -11 lines)


scrapegraphai/graphs/abstract_graph.py

Lines changed: 10 additions & 11 deletions
@@ -1,16 +1,14 @@
 """
 AbstractGraph Module
 """
-
 from abc import ABC, abstractmethod
 from typing import Optional
-
-from langchain_aws.embeddings.bedrock import BedrockEmbeddings
-from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
-
+from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
 from ..helpers import models_tokens
 from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
+from langchain_aws.embeddings.bedrock import BedrockEmbeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings


 class AbstractGraph(ABC):
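
For reference, a minimal standalone sketch of the newly imported GoogleGenerativeAIEmbeddings class. The embedding model id below is an assumption for illustration and is not part of this commit; langchain_google_genai also expects a Google API key, e.g. via the GOOGLE_API_KEY environment variable.

from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Assumed embedding model id; use any model your Google API key can access.
embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
vector = embedder.embed_query("ScrapeGraphAI turns web pages into structured data")  # list of floats
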
@@ -69,7 +67,7 @@ def _set_model_token(self, llm):
                 self.model_token = models_tokens["azure"][llm.model_name]
             except KeyError:
                 raise KeyError("Model not supported")
-
+
         elif 'HuggingFaceEndpoint' in str(type(llm)):
             if 'mistral' in llm.repo_id:
                 try:
@@ -229,29 +227,30 @@ def _create_embedder(self, embedder_config: dict) -> object:

         if 'model_instance' in embedder_config:
             return embedder_config['model_instance']
-
         # Instantiate the embedding model based on the model name
         if "openai" in embedder_config["model"]:
             return OpenAIEmbeddings(api_key=embedder_config["api_key"])
-
         elif "azure" in embedder_config["model"]:
             return AzureOpenAIEmbeddings()
-
         elif "ollama" in embedder_config["model"]:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
                 models_tokens["ollama"][embedder_config["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return OllamaEmbeddings(**embedder_config)
-
         elif "hugging_face" in embedder_config["model"]:
             try:
                 models_tokens["hugging_face"][embedder_config["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported")from exc
             return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+        elif "gemini" in embedder_config["model"]:
+            try:
+                models_tokens["gemini"][embedder_config["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported")from exc
+            return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
         elif "bedrock" in embedder_config["model"]:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
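
With this change, _create_embedder routes any embedder config whose "model" string contains "gemini" (and is present in models_tokens["gemini"]) to GoogleGenerativeAIEmbeddings, passing the model name through unchanged; unlike the ollama and bedrock branches there is no split("/"). A hedged sketch of a graph configuration that would select this branch follows; the "embeddings" key and the model ids are assumptions for illustration, not taken from this diff.

# Hypothetical configuration; key names and model ids are illustrative only.
graph_config = {
    "llm": {
        "api_key": "<GOOGLE_API_KEY>",
        "model": "gemini-pro",
    },
    "embeddings": {
        # Contains "gemini", so _create_embedder would return
        # GoogleGenerativeAIEmbeddings(model="gemini-pro"),
        # provided "gemini-pro" exists in models_tokens["gemini"].
        "model": "gemini-pro",
    },
}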

0 commit comments