Initial commit
This commit is contained in:
commit
0cc501eec0
14 changed files with 535 additions and 0 deletions
48
.gitignore
vendored
Normal file
48
.gitignore
vendored
Normal file
|
@ -0,0 +1,48 @@
|
|||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
**/logs/
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# IDEs
|
||||
.idea/
|
||||
.vscode/
|
||||
*.iml
|
||||
|
||||
# Build & dependencies (Java)
|
||||
target/
|
||||
.mvn/
|
||||
mvnw
|
||||
mvnw.cmd
|
||||
|
||||
# Java cache
|
||||
.classpath
|
||||
.project
|
||||
.settings/
|
||||
.gradle/
|
||||
build/
|
||||
|
||||
# Angular
|
||||
node_modules/
|
||||
dist/
|
||||
.angular/
|
||||
.output/
|
||||
.env
|
||||
.env.*
|
||||
|
||||
# Angular cache
|
||||
.turbo/
|
||||
.npm/
|
||||
.cache/
|
||||
|
||||
# Docker
|
||||
*.pid
|
||||
docker-compose.override.yml
|
||||
|
||||
# Misc
|
||||
*.bak
|
||||
*.swp
|
||||
*.tmp
|
61
pom.xml
Normal file
61
pom.xml
Normal file
|
@ -0,0 +1,61 @@
|
|||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-parent</artifactId>
|
||||
<version>3.2.5</version>
|
||||
<relativePath/> <!-- lookup parent from repository -->
|
||||
</parent>
|
||||
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>ai-chatbot</artifactId>
|
||||
<version>0.0.1-SNAPSHOT</version>
|
||||
|
||||
<name>AI Chatbot</name>
|
||||
<description>Simple Spring Boot AI chatbot integration</description>
|
||||
|
||||
<properties>
|
||||
<java.version>17</java.version>
|
||||
<spring-boot.version>3.2.5</spring-boot.version>
|
||||
</properties>
|
||||
<dependencies>
|
||||
<!-- Spring Boot Starter Web for the REST API -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- OpenAI Java service -->
|
||||
<dependency>
|
||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||
<artifactId>service</artifactId>
|
||||
<version>0.12.0</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Spring Boot Starter Test for JUnit and other test dependencies -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- AssertJ dependency for fluent assertions -->
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
14
src/main/embedder/dockerfile
Normal file
14
src/main/embedder/dockerfile
Normal file
|
@ -0,0 +1,14 @@
|
|||
# Dockerfile
|
||||
FROM python:3.10-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements first to leverage Docker layer caching
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install dependencies including weaviate-client
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY server.py .
|
||||
|
||||
CMD ["python", "server.py"]
|
4
src/main/embedder/requirements.txt
Normal file
4
src/main/embedder/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
flask
|
||||
sentence-transformers
|
||||
weaviate-client
|
||||
openai
|
86
src/main/embedder/server.py
Normal file
86
src/main/embedder/server.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
from flask import Flask, request, jsonify
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import weaviate
|
||||
import openai
|
||||
import os
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Load local model
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
# Configure Weaviate client (assumes running on localhost)
|
||||
weaviate_client = weaviate.Client("http://localhost:8083")
|
||||
|
||||
# OpenAI setup
|
||||
openai.api_key = os.getenv("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")
|
||||
|
||||
# Create Weaviate schema (if not exists)
|
||||
CLASS_NAME = "Document"
|
||||
|
||||
if not weaviate_client.schema.contains({"class": CLASS_NAME}):
|
||||
weaviate_client.schema.create_class({
|
||||
"class": CLASS_NAME,
|
||||
"vectorizer": "none", # Because we use external embeddings
|
||||
"properties": [
|
||||
{
|
||||
"name": "content",
|
||||
"dataType": ["text"]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
@app.route("/ingest", methods=["POST"])
|
||||
def ingest():
|
||||
content = request.json.get("content")
|
||||
vector = model.encode([content])[0].tolist()
|
||||
|
||||
weaviate_client.data_object.create(
|
||||
data_object={"content": content},
|
||||
class_name=CLASS_NAME,
|
||||
vector=vector
|
||||
)
|
||||
return jsonify({"status": "saved"})
|
||||
|
||||
|
||||
@app.route("/query", methods=["POST"])
|
||||
def query():
|
||||
question = request.json.get("question")
|
||||
q_vector = model.encode([question])[0].tolist()
|
||||
|
||||
response = weaviate_client.query.get(CLASS_NAME, ["content"])\
|
||||
.with_near_vector({"vector": q_vector})\
|
||||
.with_limit(3).do()
|
||||
|
||||
relevant_chunks = [d["content"] for d in response["data"]["Get"][CLASS_NAME]]
|
||||
context = "\n".join(relevant_chunks)
|
||||
|
||||
prompt = f"""
|
||||
Answer the following question based on the context below.
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
"""
|
||||
|
||||
completion = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
)
|
||||
|
||||
answer = completion['choices'][0]['message']['content'].strip()
|
||||
return jsonify({"answer": answer})
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def root():
|
||||
return "Embedding/QA API is running."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5000)
|
15
src/main/embedder/server.py.v1
Normal file
15
src/main/embedder/server.py.v1
Normal file
|
@ -0,0 +1,15 @@
|
|||
from flask import Flask, request, jsonify
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
app = Flask(__name__)
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
@app.route('/embed', methods=['POST'])
|
||||
def embed():
|
||||
data = request.json
|
||||
texts = data.get("texts", [])
|
||||
embeddings = model.encode(texts).tolist()
|
||||
return jsonify({"embeddings": embeddings})
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5000)
|
51
src/main/java/com/example/chatbot/ChatController.java
Normal file
51
src/main/java/com/example/chatbot/ChatController.java
Normal file
|
@ -0,0 +1,51 @@
|
|||
package com.example.chatbot;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import com.example.chatbot.service.EmbedderClient;
|
||||
import com.example.chatbot.service.WeaviateService;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/chat")
|
||||
public class ChatController {
|
||||
|
||||
private final EmbedderClient embedder;
|
||||
private final WeaviateService weaviate;
|
||||
private final OpenAiServiceWrapper openai;
|
||||
|
||||
public ChatController(EmbedderClient embedder, WeaviateService weaviate, OpenAiServiceWrapper openai) {
|
||||
this.embedder = embedder;
|
||||
this.weaviate = weaviate;
|
||||
this.openai = openai;
|
||||
}
|
||||
|
||||
@PostMapping("/ask")
|
||||
public ResponseEntity<String> ask(@RequestBody Map<String, String> payload) throws JsonProcessingException {
|
||||
String question = payload.get("question");
|
||||
List<Double> vector = embedder.embedText(question);
|
||||
List<String> context = weaviate.searchRelevant(question, vector);
|
||||
String answer = openai.processQuestion(question, context);
|
||||
return ResponseEntity.ok(answer);
|
||||
}
|
||||
|
||||
@PostMapping("/ingest")
|
||||
public ResponseEntity<String> ingest(@RequestBody Map<String, String> payload) {
|
||||
String content = payload.get("content");
|
||||
List<Double> vector = embedder.embedText(content);
|
||||
weaviate.ingest(content, vector);
|
||||
return ResponseEntity.ok("Saved");
|
||||
}
|
||||
|
||||
@PostMapping("/askSimple")
|
||||
public String ask(@RequestBody String question) {
|
||||
return openai.processQuestion2(question);
|
||||
}
|
||||
}
|
11
src/main/java/com/example/chatbot/ChatbotApplication.java
Normal file
11
src/main/java/com/example/chatbot/ChatbotApplication.java
Normal file
|
@ -0,0 +1,11 @@
|
|||
package com.example.chatbot;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class ChatbotApplication {
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(ChatbotApplication.class, args);
|
||||
}
|
||||
}
|
48
src/main/java/com/example/chatbot/OpenAiServiceWrapper.java
Normal file
48
src/main/java/com/example/chatbot/OpenAiServiceWrapper.java
Normal file
|
@ -0,0 +1,48 @@
|
|||
package com.example.chatbot;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
|
||||
@Service
|
||||
public class OpenAiServiceWrapper {
|
||||
|
||||
private final OpenAiService service;
|
||||
|
||||
@Value("${openai.model}")
|
||||
private String openAiModel;
|
||||
|
||||
@Value("${weaviate.url}")
|
||||
private String weaviateUrl;
|
||||
|
||||
public OpenAiServiceWrapper(@Value("${openai.api-key}") String openAiApiKey) {
|
||||
this.service = new OpenAiService(openAiApiKey);
|
||||
}
|
||||
|
||||
public String processQuestion2(String question) {
|
||||
CompletionRequest request = CompletionRequest.builder().model(this.openAiModel).prompt(
|
||||
"You are a helpful assistant for interpreting database questions.\nQuestion: " + question).maxTokens(
|
||||
150).temperature(0.7).build();
|
||||
|
||||
return service.createCompletion(request).getChoices().get(0).getText().trim();
|
||||
}
|
||||
|
||||
public String processQuestion(String question, List<String> context) {
|
||||
String fullPrompt = "Answer the question based on the context below.\n\n" + "Context:\n"
|
||||
+ String.join("\n", context) + "\n\n" + "Question: " + question;
|
||||
|
||||
ChatMessage system = new ChatMessage("system", "You are a helpful assistant.");
|
||||
ChatMessage user = new ChatMessage("user", fullPrompt);
|
||||
|
||||
ChatCompletionRequest request = ChatCompletionRequest.builder().model(openAiModel).messages(
|
||||
List.of(system, user)).temperature(0.7).build();
|
||||
|
||||
return service.createChatCompletion(request).getChoices().get(0).getMessage().getContent().trim();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package com.example.chatbot.service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Service
|
||||
public class EmbedderClient {
|
||||
|
||||
@Value("${embedder.url}")
|
||||
private String embedderUrl;
|
||||
|
||||
private final RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
public List<Double> embedText(String text) {
|
||||
Map<String, Object> payload = Map.of("texts", List.of(text));
|
||||
Map<String, List<List<Double>>> response = restTemplate.postForObject(embedderUrl, payload, Map.class);
|
||||
return response.get("embeddings").get(0);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package com.example.chatbot.service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
@Service
|
||||
public class WeaviateService {
|
||||
|
||||
@Value("${weaviate.url}")
|
||||
private String weaviateUrl;
|
||||
|
||||
private final RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
public void ingest(String content, List<Double> vector) {
|
||||
Map<String, Object> data = Map.of("content", content);
|
||||
Map<String, Object> request = Map.of("class", "Document", "properties", data, "vector", vector);
|
||||
|
||||
restTemplate.postForObject(weaviateUrl + "/v1/objects", request, String.class);
|
||||
}
|
||||
|
||||
public List<String> searchRelevant(String question, List<Double> qVector) throws JsonProcessingException {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
Map<String, Object> nearVector = Map.of("vector", qVector);
|
||||
String gql = """
|
||||
{
|
||||
Get {
|
||||
Document(nearVector: %s, limit: 3) {
|
||||
content
|
||||
}
|
||||
}
|
||||
}
|
||||
""".formatted(new ObjectMapper().writeValueAsString(nearVector));
|
||||
|
||||
Map<String, String> query = Map.of("query", gql);
|
||||
HttpEntity<Map<String, String>> entity = new HttpEntity<>(query, headers);
|
||||
|
||||
ResponseEntity<JsonNode> resp = restTemplate.postForEntity(weaviateUrl + "/v1/graphql", entity, JsonNode.class);
|
||||
|
||||
List<String> results = new ArrayList<>();
|
||||
resp.getBody().at("/data/Get/Document").forEach(item -> results.add(item.get("content").asText()));
|
||||
return results;
|
||||
}
|
||||
}
|
12
src/main/resources/application.yml
Normal file
12
src/main/resources/application.yml
Normal file
|
@ -0,0 +1,12 @@
|
|||
server:
|
||||
port: 8082
|
||||
|
||||
weaviate:
|
||||
url: http://localhost:8083
|
||||
|
||||
embedding:
|
||||
url: http://localhost:8000/embed
|
||||
|
||||
openai:
|
||||
api-key: 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
|
||||
model: gpt-3.5-turbo-instruct
|
22
src/main/weaviate/docker-compose.yml
Normal file
22
src/main/weaviate/docker-compose.yml
Normal file
|
@ -0,0 +1,22 @@
|
|||
version: "3.8"
|
||||
services:
|
||||
weaviate:
|
||||
image: semitechnologies/weaviate:latest
|
||||
environment:
|
||||
- QUERY_DEFAULTS=100
|
||||
- AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true
|
||||
- CLUSTER_HOSTNAME=node1
|
||||
- CLUSTER_MODE=OFF
|
||||
- PERSISTENCE_DATA_PATH=/var/lib/weaviate
|
||||
- WEAVIATE_ENABLE_MODULES=transformers,text2vec-openai
|
||||
- OPENAI_API_KEY=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
|
||||
ports:
|
||||
- "8083:8083"
|
||||
volumes:
|
||||
- weaviate_data:/var/lib/weaviate
|
||||
networks:
|
||||
- weaviate-net
|
||||
volumes:
|
||||
weaviate_data:
|
||||
networks:
|
||||
weaviate-net:
|
83
src/test/java/com/example/chatbot/ChatControllerTest.java
Normal file
83
src/test/java/com/example/chatbot/ChatControllerTest.java
Normal file
|
@ -0,0 +1,83 @@
|
|||
package com.example.chatbot;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.yaml.snakeyaml.Yaml;
|
||||
|
||||
public class ChatControllerTest {
|
||||
|
||||
private static final String API_URL = "http://localhost:8080/chat/askSimple"; // Change this if needed to match your server's address
|
||||
|
||||
private static String embeddingUrl;
|
||||
|
||||
@BeforeAll
|
||||
public static void loadYaml() {
|
||||
// This loads application.yml config. Can't use SpringBoot (magic) infraestruture because I can't setup the TestCase as @SpringBootTest
|
||||
Yaml yaml = new Yaml();
|
||||
try (InputStream in = ChatControllerTest.class.getClassLoader().getResourceAsStream("application.yml")) {
|
||||
Map<String, Object> obj = yaml.load(in);
|
||||
|
||||
// Property embedding.url
|
||||
Map<String, Object> embedding = (Map<String, Object>) obj.get("embedding");
|
||||
embeddingUrl = (String) embedding.get("url");
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to load YAML config", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmbeddingEndpoint() {
|
||||
String textToEmbedd = "ABC spent 5000 on gasoline in 2023";
|
||||
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
|
||||
Map<String, Object> request = Map.of("texts", List.of(textToEmbedd));
|
||||
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
||||
|
||||
Map response = restTemplate.postForObject(embeddingUrl, entity, Map.class);
|
||||
List<Double> embedding = (List<Double>) ((List<?>) response.get("embeddings")).get(0);
|
||||
|
||||
System.out.println(embedding);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChatEndpoint() {
|
||||
String question = "How much did company ABC spend on Gasoline in 2023?";
|
||||
|
||||
// Set the headers for the request
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.TEXT_PLAIN);
|
||||
|
||||
// Create the HTTP request entity with the question body and headers
|
||||
HttpEntity<String> request = new HttpEntity<>(question, headers);
|
||||
|
||||
// Use RestTemplate to make the HTTP POST request
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
ResponseEntity<String> response = restTemplate.exchange(API_URL, HttpMethod.POST, request, String.class);
|
||||
|
||||
// Assert that the response status code is OK
|
||||
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
|
||||
|
||||
// Assert that the response body is not blank (contains a valid AI response)
|
||||
assertThat(response.getBody()).isNotBlank();
|
||||
System.out.println("AI Response: " + response.getBody());
|
||||
}
|
||||
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue