Skip to content

Commit 36a1522

Browse files
authored
Merge pull request #153 from VinciGit00/google_embeddings
feat: add gemini embeddings
2 parents 3bef9bb + 79daa4c commit 36a1522

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""
22
AbstractGraph Module
33
"""
4-
54
from abc import ABC, abstractmethod
65
from typing import Optional
7-
8-
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
9-
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
106
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
11-
7+
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
128
from ..helpers import models_tokens
139
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
10+
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
11+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
1412

1513

1614
class AbstractGraph(ABC):
@@ -86,7 +84,7 @@ def _set_model_token(self, llm):
8684
self.model_token = models_tokens["azure"][llm.model_name]
8785
except KeyError:
8886
raise KeyError("Model not supported")
89-
87+
9088
elif 'HuggingFaceEndpoint' in str(type(llm)):
9189
if 'mistral' in llm.repo_id:
9290
try:
@@ -246,29 +244,30 @@ def _create_embedder(self, embedder_config: dict) -> object:
246244

247245
if 'model_instance' in embedder_config:
248246
return embedder_config['model_instance']
249-
250247
# Instantiate the embedding model based on the model name
251248
if "openai" in embedder_config["model"]:
252249
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
253-
254250
elif "azure" in embedder_config["model"]:
255251
return AzureOpenAIEmbeddings()
256-
257252
elif "ollama" in embedder_config["model"]:
258253
embedder_config["model"] = embedder_config["model"].split("/")[-1]
259254
try:
260255
models_tokens["ollama"][embedder_config["model"]]
261256
except KeyError as exc:
262257
raise KeyError("Model not supported") from exc
263258
return OllamaEmbeddings(**embedder_config)
264-
265259
elif "hugging_face" in embedder_config["model"]:
266260
try:
267261
models_tokens["hugging_face"][embedder_config["model"]]
268262
except KeyError as exc:
269263
raise KeyError("Model not supported")from exc
270264
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
271-
265+
elif "gemini" in embedder_config["model"]:
266+
try:
267+
models_tokens["gemini"][embedder_config["model"]]
268+
except KeyError as exc:
269+
raise KeyError("Model not supported")from exc
270+
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
272271
elif "bedrock" in embedder_config["model"]:
273272
embedder_config["model"] = embedder_config["model"].split("/")[-1]
274273
try:

0 commit comments

Comments (0)