-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostgres_ops.py
200 lines (178 loc) · 6.88 KB
/
postgres_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
This module provides a class for interacting with a PostgreSQL database.
It includes methods for storing, retrieving, and searching documents and their embeddings.
"""
import json
from typing import Any, Dict, List

import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2.extras import execute_values
class PostgresOperations:
    """Persistence layer for documents and their pgvector embeddings.

    Maintains two tables:
      * ``documents``  — full text plus the chunk list stored as JSONB.
      * ``embeddings`` — one vector per chunk, FK back to ``documents``.

    Similarity search uses pgvector's ``<->`` (Euclidean distance) operator.
    """

    # Must match the dimensionality of the embedding model in use
    # (1536 is the OpenAI text-embedding-ada-002 / 3-small size).
    EMBEDDING_DIM = 1536

    def __init__(self, host='postgres', port=5432, dbname='ragdb', user='raguser', password='ragpass'):
        """
        Open a connection and ensure the schema exists.

        Args:
            host (str): The database host.
            port (int): The database port.
            dbname (str): The name of the database.
            user (str): The database user.
            password (str): The database password.
        """
        self.conn = psycopg2.connect(
            dbname=dbname,
            user=user,
            password=password,
            host=host,
            port=port
        )
        self.create_tables()

    @staticmethod
    def _decode_chunks(chunks: Any) -> List[str]:
        """Normalize a JSONB ``chunks`` value to a Python list.

        psycopg2 auto-decodes JSONB columns into Python objects, so the
        value is usually already a list; fall back to ``json.loads`` only
        when a raw string is returned (e.g. with a text cast or a custom
        type adapter).
        """
        return chunks if isinstance(chunks, list) else json.loads(chunks)

    def create_tables(self):
        """Create the vector extension and the necessary tables if they don't exist."""
        with self.conn.cursor() as cur:
            # The extension must exist before the vector type can be used
            # or registered with the driver.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            register_vector(cur)
            # Create documents table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    id SERIAL PRIMARY KEY,
                    filename TEXT UNIQUE,
                    content TEXT,
                    chunks JSONB
                )
            """)
            # Create embeddings table
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id SERIAL PRIMARY KEY,
                    document_id INTEGER REFERENCES documents(id),
                    chunk_index INTEGER,
                    embedding vector({self.EMBEDDING_DIM})
                )
            """)
        self.conn.commit()

    def store_document(self, filename: str, content: str, chunks: List[str], embeddings: List[List[float]]):
        """
        Store a document and its embeddings in the database atomically.

        Rolls back on any failure so the connection is not left in an
        aborted-transaction state (which would break subsequent queries).

        Args:
            filename (str): The name of the document file. Must be unique;
                a duplicate raises ``psycopg2.IntegrityError``.
            content (str): The full content of the document.
            chunks (List[str]): The document split into chunks.
            embeddings (List[List[float]]): The embeddings for each chunk,
                parallel to ``chunks``.
        """
        try:
            with self.conn.cursor() as cur:
                register_vector(cur)
                # Store document and chunks
                cur.execute(
                    "INSERT INTO documents (filename, content, chunks) VALUES (%s, %s, %s) RETURNING id",
                    (filename, content, json.dumps(chunks))
                )
                document_id = cur.fetchone()[0]
                # Store embeddings in one round trip
                execute_values(cur,
                    "INSERT INTO embeddings (document_id, chunk_index, embedding) VALUES %s",
                    [(document_id, i, embedding) for i, embedding in enumerate(embeddings)],
                    template="(%s, %s, %s::vector)"
                )
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

    def get_document(self, filename: str) -> Dict[str, Any]:
        """
        Retrieve a document and its embeddings from the database.

        Args:
            filename (str): The name of the document file.

        Returns:
            Dict[str, Any]: Keys ``filename``, ``content``, ``chunks``,
            ``embeddings``; or ``None`` if not found (or if the document
            has no embedding rows, since an inner join is used).
        """
        with self.conn.cursor() as cur:
            cur.execute("""
                SELECT d.content, d.chunks, array_agg(e.embedding ORDER BY e.chunk_index) as embeddings
                FROM documents d
                JOIN embeddings e ON d.id = e.document_id
                WHERE d.filename = %s
                GROUP BY d.id
            """, (filename,))
            result = cur.fetchone()
            if result:
                content, chunks, embeddings = result
                return {
                    "filename": filename,
                    "content": content,
                    "chunks": self._decode_chunks(chunks),
                    "embeddings": embeddings
                }
            return None

    def get_all_documents(self) -> List[Dict[str, Any]]:
        """
        Retrieve all documents and their embeddings from the database.

        Returns:
            List[Dict[str, Any]]: One dict per document with keys
            ``filename``, ``content``, ``chunks``, ``embeddings``.
            Documents without embedding rows are omitted (inner join).
        """
        with self.conn.cursor() as cur:
            cur.execute("""
                SELECT d.filename, d.content, d.chunks, array_agg(e.embedding ORDER BY e.chunk_index) as embeddings
                FROM documents d
                JOIN embeddings e ON d.id = e.document_id
                GROUP BY d.id
            """)
            results = cur.fetchall()
            return [
                {
                    "filename": filename,
                    "content": content,
                    "chunks": self._decode_chunks(chunks),
                    "embeddings": embeddings
                }
                for filename, content, chunks, embeddings in results
            ]

    def search_similar_chunks(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for chunks similar to the query embedding.

        Args:
            query_embedding (List[float]): The embedding of the query;
                must have ``EMBEDDING_DIM`` components.
            top_k (int): The number of results to return.

        Returns:
            List[Dict[str, Any]]: The ``top_k`` nearest chunks, each with
            ``filename``, ``chunk`` and ``distance`` (L2, smaller = closer).
        """
        with self.conn.cursor() as cur:
            register_vector(cur)
            # jsonb ``->`` with an integer key indexes into the chunks array.
            cur.execute("""
                SELECT d.filename, d.chunks->e.chunk_index as chunk, e.embedding <-> %s::vector as distance
                FROM embeddings e
                JOIN documents d ON e.document_id = d.id
                ORDER BY distance
                LIMIT %s
            """, (query_embedding, top_k))
            results = cur.fetchall()
            return [
                {
                    "filename": filename,
                    "chunk": chunk,
                    "distance": distance
                }
                for filename, chunk, distance in results
            ]

    def clear_db(self):
        """Clear all data from the database."""
        with self.conn.cursor() as cur:
            cur.execute("TRUNCATE TABLE embeddings")
            # CASCADE also truncates tables with FK references to documents.
            cur.execute("TRUNCATE TABLE documents CASCADE")
        self.conn.commit()
        print("Database cleared.")

    def print_db_contents(self) -> List[Dict[str, Any]]:
        """
        Retrieve a summary of all documents in the database.

        Returns:
            List[Dict[str, Any]]: Per-document ``filename``, a content
            preview capped at 1000 chars, and chunk/embedding counts.
        """
        docs = self.get_all_documents()
        return [
            {
                "filename": doc["filename"],
                "content_preview": doc["content"][:1000] + "..." if len(doc["content"]) > 1000 else doc["content"],
                "chunks_count": len(doc["chunks"]),
                "embeddings_count": len(doc["embeddings"])
            }
            for doc in docs
        ]

    def __del__(self):
        """Close the database connection when the object is destroyed.

        Uses ``getattr`` because ``self.conn`` never exists if
        ``psycopg2.connect`` raised inside ``__init__``; also skips
        already-closed connections.
        """
        conn = getattr(self, "conn", None)
        if conn is not None and not conn.closed:
            conn.close()