-
Notifications
You must be signed in to change notification settings - Fork 3
/
db.py
105 lines (88 loc) · 3.2 KB
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from dotenv import load_dotenv
load_dotenv()
import psycopg2
from mistral_api import MistralAPI
class Database:
    """Thin wrapper around one psycopg2 connection used for the ``lechaton`` table.

    The connection string is read from the ``DATABASE_URL`` environment
    variable. One cursor is created up front and shared by every method.
    """

    def __init__(self):
        """Connect to the Postgres database named by ``DATABASE_URL``."""
        self.connection_string = os.getenv('DATABASE_URL')
        self.conn = psycopg2.connect(self.connection_string)
        self.cur = self.conn.cursor()

    def closeConn(self):
        """Close the cursor and then the connection that owns it."""
        # Release the cursor first; the connection is closed even if closing
        # the cursor raises, so the socket is never leaked.
        try:
            self.cur.close()
        finally:
            self.conn.close()

    def testConnection(self):
        """Print the server's current time and its PostgreSQL version string."""
        self.cur.execute('SELECT NOW();')
        current_time = self.cur.fetchone()[0]
        self.cur.execute('SELECT version();')
        version = self.cur.fetchone()[0]
        print('Current time:', current_time)
        print('PostgreSQL version:', version)

    def insertImage(self, caption, image_filename, video_filename, embedding):
        """Insert one row into ``lechaton`` and commit it.

        Args:
            caption: Caption text stored in the ``caption`` column.
            image_filename: Value stored in the ``image_filename`` column.
            video_filename: Value stored in the ``video_filename`` column.
            embedding: Value stored in the ``embedding`` column; it is handed
                to the driver unchanged.

        Raises:
            Exception: Whatever the driver raises while executing or
                committing; the transaction is rolled back before re-raising.
        """
        # Plain INSERT: an existing row is never updated by this statement.
        query = """
            INSERT INTO lechaton (video_filename, image_filename, caption, embedding)
            VALUES (%s, %s, %s, %s)
        """
        try:
            self.cur.execute(query, (video_filename, image_filename, caption, embedding))
            self.conn.commit()
        except Exception:
            # Without a rollback the shared connection stays in a failed
            # transaction and every later statement on it is rejected.
            self.conn.rollback()
            raise

    def search(self, embedding, k):
        """Return the ``k`` rows of ``lechaton`` ordered first by ``embedding <=> query``.

        Args:
            embedding: Iterable of numbers; it is serialised as ``[a,b,c]``
                text before being bound to the query.
            k: Value bound to the ``LIMIT`` clause.

        Returns:
            The list of rows produced by ``cursor.fetchall()``.

        Raises:
            Exception: Whatever the driver raises; the transaction is rolled
                back before re-raising.
        """
        # The query compares against a bracketed text literal rather than a
        # Python list (presumably the pgvector text format -- confirm against
        # the column type).
        embedding_str = '[' + ','.join(map(str, embedding)) + ']'
        query = """
            SELECT * FROM lechaton ORDER BY embedding <=> (%s) LIMIT (%s);
        """
        try:
            self.cur.execute(query, (embedding_str, k))
            rows = self.cur.fetchall()
        except Exception:
            # Keep the shared connection usable after a failed query.
            self.conn.rollback()
            raise
        # Nothing was modified; the commit only ends the transaction that the
        # SELECT implicitly opened.
        self.conn.commit()
        return rows
class VectorStore:
    """Combine a ``Database`` with a ``MistralAPI`` client.

    Text is turned into an embedding by the client and then handed to the
    database for insertion or search.
    """

    def __init__(self):
        """Create the database wrapper and the embedding client."""
        self.db = Database()
        self.mistral_client = MistralAPI()

    def insert(self, caption, image_filename, video_filename):
        """Embed ``caption`` and store it together with both filenames."""
        caption_vector = self.mistral_client.embed(caption)
        self.db.insertImage(caption, image_filename, video_filename, caption_vector)

    def search(self, text, k):
        """Embed ``text`` and return what the database search yields for ``k``."""
        query_vector = self.mistral_client.embed(text)
        return self.db.search(query_vector, k)

    def close(self):
        """Close the underlying database connection."""
        self.db.closeConn()
def test_add_to_vs():
    """Insert demo rows for ``.jpg`` files among the first five entries of ``Mistral``.

    Each of the first five directory entries is paired by position with one
    of five fixed captions; entries that do not end in ``.jpg`` are skipped
    (their caption is simply not used). Every inserted row uses the video
    filename ``'demo'``. The store is closed when the function finishes,
    whether or not an insert failed.
    """
    mistral_folder = 'Mistral'
    captions = [
        "A serene landscape",
        "A bustling cityscape",
        "A quiet moment",
        "The thrill of adventure",
        "A night under the stars",
    ]
    # Only the first five directory entries are considered, one per caption.
    files = os.listdir(mistral_folder)[:5]

    # A single store is enough; each VectorStore opens its own DB connection.
    vs = VectorStore()
    try:
        # zip pairs entry i with caption i and stops at the shorter sequence,
        # so no explicit index bound check is needed.
        for file, caption in zip(files, captions):
            if not file.endswith('.jpg'):
                continue
            image_filename = os.path.join(mistral_folder, file)
            video_filename = 'demo'
            vs.insert(caption, image_filename, video_filename)
    finally:
        vs.close()
def test_search_vs():
    """Run a sample search for two rows and print the result.

    The store is closed afterwards, even if the search raises, so the
    database connection it opened is not left dangling.
    """
    vs = VectorStore()
    try:
        rows = vs.search("A small group of humans", 2)
        print(rows)
    finally:
        vs.close()
if __name__ == "__main__":
    # Manual smoke test: only the search demo runs. The insertion demo,
    # test_add_to_vs(), is intentionally not invoked here.
    test_search_vs()