-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path fastest_api.py
78 lines (61 loc) · 1.94 KB
/
fastest_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
This module defines the FastAPI application for the RAG (Retrieval-Augmented Generation) system.
It provides endpoints for document upload, querying, database operations, and content retrieval.
"""
import os
import tempfile
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from rag_system import RAGSystem
app = FastAPI()

# Initialize the RAG system with PostgreSQL connection details.
# The defaults ('postgres', 5432) match the previously hard-coded values.
# Set POSTGRES_HOST / POSTGRES_PORT to override them, e.g. when the database
# is not reachable under the hostname 'postgres'.
rag_system = RAGSystem(
    postgres_host=os.environ.get("POSTGRES_HOST", "postgres"),
    postgres_port=int(os.environ.get("POSTGRES_PORT", "5432")),
)
class QueryModel(BaseModel):
    """Request body accepted by the ``/query`` endpoint.

    Attributes:
        text: The query string that is passed to ``rag_system.query``.
    """

    text: str
@app.post("/upsert")
async def upsert_document(file: UploadFile = File(...)):
    """
    Endpoint to upload and process a document.

    The upload is written to a temporary file on disk, because
    ``rag_system.upsert_document`` is given a filesystem path. The temporary
    file is removed afterwards, including when reading or writing the upload
    fails or processing it raises.

    Args:
        file (UploadFile): The PDF file to be uploaded and processed.

    Returns:
        dict: A message indicating the success of the operation (whatever
        ``rag_system.upsert_document`` returns).
    """
    # mkstemp creates the file up front, so its path is known before any
    # write happens. Because everything after this line is inside the
    # try/finally, the file is always cleaned up, even if reading the upload
    # or writing it to disk raises.
    fd, tmp_path = tempfile.mkstemp()
    try:
        # Close the handle before passing the path on. This ensures the data
        # is flushed and the file can be reopened by the consumer (closing
        # first is required on Windows).
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(await file.read())
        # NOTE(review): this is a synchronous call inside an async endpoint,
        # so it blocks the event loop while the document is processed.
        # Consider running it in a worker thread if RAGSystem is safe to use
        # that way.
        return rag_system.upsert_document(tmp_path)
    finally:
        os.unlink(tmp_path)
@app.post("/query")
async def query(query: QueryModel):
    """
    Answer a query using the RAG system.

    Args:
        query (QueryModel): Request body whose ``text`` field holds the query.

    Returns:
        dict: The answer to the query and related information, exactly as
        produced by ``rag_system.query``.
    """
    question = query.text
    answer = rag_system.query(question)
    return answer
@app.post("/clear_db")
async def clear_db():
    """
    Clear the database through ``rag_system.clear_db``. Intended for testing.

    NOTE(review): this destructive route has no authentication or other
    dependency attached. Consider disabling it outside test environments.

    Returns:
        dict: A message indicating the success of the operation.
    """
    outcome = rag_system.clear_db()
    return outcome
@app.get("/print_db")
async def print_db():
    """
    Return the database contents produced by
    ``rag_system.print_db_contents``.

    Utility endpoint for testing.

    Returns:
        dict: The contents of the database.
    """
    contents = rag_system.print_db_contents()
    return contents
if __name__ == "__main__":
    # Imported here rather than at module top, so that merely importing this
    # module does not import uvicorn.
    from uvicorn import run as serve

    # Serve the app on all interfaces, port 8000.
    serve(app, host="0.0.0.0", port=8000)