Skip to content

Commit

Permalink
Adds loading check to CSV and GSheets so we don't load twice the same…
Browse files Browse the repository at this point in the history
… content
  • Loading branch information
vmesel committed Jul 11, 2024
1 parent 1a1b775 commit 5b07ec6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
12 changes: 10 additions & 2 deletions dialog_lib/loaders/csv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dialog_lib.db import get_session
from dialog_lib.db.models import CompanyContent
from dialog_lib.embeddings.generate import generate_embedding
Expand All @@ -6,6 +7,9 @@
from langchain_community.document_loaders.csv_loader import CSVLoader


logger = logging.getLogger(__name__)


def load_csv(
file_path, dbsession=get_session, embeddings_model_instance=None,
embedding_llm_model=None, embedding_llm_api_key=None, company_id=None
Expand All @@ -28,6 +32,10 @@ def load_csv(
values = line.split(": ")
content[values[0]] = values[1]


if not dbsession.query(CompanyContent).filter(
CompanyContent.question == content["question"], CompanyContent.content == content["content"]
).first():
company_content = CompanyContent(
category="csv",
subcategory="csv-content",
Expand All @@ -37,5 +45,5 @@ def load_csv(
embedding=generate_embedding(csv_content.page_content, embeddings_model_instance)
)
session.add(company_content)

return company_content
else:
logger.warning(f"Question: {content['question']} already exists in the database. Skipping.")
29 changes: 18 additions & 11 deletions dialog_lib/loaders/gsheets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gspread
import logging

from dialog_lib.db.models import CompanyContent
from dialog_lib.embeddings.generate import generate_embedding
Expand All @@ -11,6 +12,8 @@
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union


logger = logging.getLogger(__name__)

class GoogleSheetsLoader(BaseLoader):
def __init__(self, credentials_path: Union[str, Path], spreadsheet_url: str, sheet_name: str):
self.sheet_name = sheet_name
Expand Down Expand Up @@ -59,15 +62,19 @@ def load_google_sheets(
values = line.split(": ")
content[values[0]] = values[1]

company_content = CompanyContent(
category="csv",
subcategory="csv-content",
question=content["question"],
content=content["content"],
dataset=company_id,
embedding=generate_embedding(csv_content.page_content, embeddings_model_instance)
)
dbsession.add(company_content)
if not dbsession.query(CompanyContent).filter(
CompanyContent.question == content["question"], CompanyContent.content == content["content"]
).first():
company_content = CompanyContent(
category="csv",
subcategory="csv-content",
question=content["question"],
content=content["content"],
dataset=company_id,
embedding=generate_embedding(csv_content.page_content, embeddings_model_instance)
)
dbsession.add(company_content)
else:
logger.warning(f"Question: {content['question']} already exists in the database. Skipping.")

dbsession.commit()
return company_content
dbsession.commit()

0 comments on commit 5b07ec6

Please sign in to comment.