Initial commit
This commit is contained in:
commit
0cc501eec0
14 changed files with 535 additions and 0 deletions
86
src/main/embedder/server.py
Normal file
86
src/main/embedder/server.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
from flask import Flask, request, jsonify
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import weaviate
|
||||
import openai
|
||||
import os
|
||||
|
||||
app = Flask(__name__)

# Local sentence-embedding model (MiniLM); loaded once at process start so
# each request only pays for encode(), not model load.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Weaviate client (assumes an instance running on localhost:8083).
weaviate_client = weaviate.Client("http://localhost:8083")

# OpenAI setup.
# SECURITY FIX: the previous code hard-coded a live secret key AND passed it
# as the environment-variable *name* to os.getenv(), which (a) leaked the
# credential in source control and (b) left api_key set to None at runtime.
# The leaked key must be revoked/rotated; the real key is now read from the
# OPENAI_API_KEY environment variable.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Create the Weaviate schema on startup if it does not exist yet.
CLASS_NAME = "Document"

if not weaviate_client.schema.contains({"class": CLASS_NAME}):
    weaviate_client.schema.create_class({
        "class": CLASS_NAME,
        "vectorizer": "none",  # embeddings are computed externally (above)
        "properties": [
            {
                "name": "content",
                "dataType": ["text"]
            }
        ]
    })
|
||||
|
||||
|
||||
@app.route("/ingest", methods=["POST"])
def ingest():
    """Embed the posted text locally and store it in Weaviate.

    Expects a JSON body of the form {"content": "<text>"}.
    Returns {"status": "saved"} on success, or a 400 error when the
    body is missing or has no usable "content" field.
    """
    payload = request.get_json(silent=True) or {}
    content = payload.get("content")
    if not content:
        # Guard: model.encode([None]) would otherwise raise deep inside the
        # encoder and surface as an opaque 500 to the caller.
        return jsonify({"error": "'content' is required"}), 400

    # encode() takes a batch and returns a batch; we embed one document
    # and take element 0, converting to a plain list for the JSON API.
    vector = model.encode([content])[0].tolist()

    weaviate_client.data_object.create(
        data_object={"content": content},
        class_name=CLASS_NAME,
        vector=vector
    )
    return jsonify({"status": "saved"})
|
||||
|
||||
|
||||
@app.route("/query", methods=["POST"])
def query():
    """Answer a question via retrieval-augmented generation.

    Expects a JSON body {"question": "<text>"}. Embeds the question with
    the local model, retrieves the 3 nearest documents from Weaviate, and
    asks gpt-3.5-turbo to answer using those documents as context.
    Returns {"answer": "<text>"}, or 400 when "question" is missing.
    """
    payload = request.get_json(silent=True) or {}
    question = payload.get("question")
    if not question:
        # Guard: encoding None would crash with an opaque 500.
        return jsonify({"error": "'question' is required"}), 400

    q_vector = model.encode([question])[0].tolist()

    response = weaviate_client.query.get(CLASS_NAME, ["content"])\
        .with_near_vector({"vector": q_vector})\
        .with_limit(3).do()

    # An empty store (or an error payload without "data") previously raised
    # KeyError here; degrade to an empty context instead.
    hits = response.get("data", {}).get("Get", {}).get(CLASS_NAME) or []
    relevant_chunks = [d["content"] for d in hits]
    context = "\n".join(relevant_chunks)

    prompt = f"""
Answer the following question based on the context below.

Context:
{context}

Question: {question}
"""

    # NOTE(review): openai.ChatCompletion is the legacy (<1.0) SDK surface;
    # kept as-is since the rest of the file targets that API version.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
    )

    answer = completion['choices'][0]['message']['content'].strip()
    return jsonify({"answer": answer})
|
||||
|
||||
|
||||
@app.route("/")
def root():
    """Liveness endpoint: confirms the service process is up."""
    status_message = "Embedding/QA API is running."
    return status_message
|
||||
|
||||
|
||||
# Script entry point: run Flask's built-in dev server.
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces so the API is reachable from outside the
    # host/container (e.g. by the other services in this repo) on port 5000.
    app.run(host="0.0.0.0", port=5000)
|
Loading…
Add table
Add a link
Reference in a new issue