commit 0cc501eec0b66f0519362fb0ed48f3c7f12ee2b5 Author: Orlando M Guerreiro Date: Tue May 20 14:59:21 2025 +0100 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c44e2a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Logs +*.log +logs/ +**/logs/ + +# OS files +.DS_Store +Thumbs.db + +# IDEs +.idea/ +.vscode/ +*.iml + +# Build & dependencies (Java) +target/ +.mvn/ +mvnw +mvnw.cmd + +# Java cache +.classpath +.project +.settings/ +.gradle/ +build/ + +# Angular +node_modules/ +dist/ +.angular/ +.output/ +.env +.env.* + +# Angular cache +.turbo/ +.npm/ +.cache/ + +# Docker +*.pid +docker-compose.override.yml + +# Misc +*.bak +*.swp +*.tmp diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..fc683ef --- /dev/null +++ b/pom.xml @@ -0,0 +1,61 @@ + + 4.0.0 + + + org.springframework.boot + spring-boot-starter-parent + 3.2.5 + + + + com.example + ai-chatbot + 0.0.1-SNAPSHOT + + AI Chatbot + Simple Spring Boot AI chatbot integration + + + 17 + 3.2.5 + + + + + org.springframework.boot + spring-boot-starter-web + + + + + com.theokanning.openai-gpt3-java + service + 0.12.0 + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + org.assertj + assertj-core + test + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + diff --git a/src/main/embedder/dockerfile b/src/main/embedder/dockerfile new file mode 100644 index 0000000..d0b6bec --- /dev/null +++ b/src/main/embedder/dockerfile @@ -0,0 +1,14 @@ +# Dockerfile +FROM python:3.10-slim + +WORKDIR /app + +# Copy requirements first to leverage Docker layer caching +COPY requirements.txt . + +# Install dependencies including weaviate-client +RUN pip install --no-cache-dir -r requirements.txt + +COPY server.py . + +CMD ["python", "server.py"] diff --git a/src/main/embedder/requirements.txt b/src/main/embedder/requirements.txt new file mode 100644 index 0000000..1bac6d6 --- /dev/null +++ b/src/main/embedder/requirements.txt @@ -0,0 +1,4 @@ +flask +sentence-transformers +weaviate-client +openai diff --git a/src/main/embedder/server.py b/src/main/embedder/server.py new file mode 100644 index 0000000..df3731c --- /dev/null +++ b/src/main/embedder/server.py @@ -0,0 +1,86 @@ +from flask import Flask, request, jsonify +from sentence_transformers import SentenceTransformer +import weaviate +import openai +import os + +app = Flask(__name__) + +# Load local model +model = SentenceTransformer('all-MiniLM-L6-v2') + +# Configure Weaviate client (assumes running on localhost) +weaviate_client = weaviate.Client("http://localhost:8083") + +# OpenAI setup +openai.api_key = os.getenv("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") + +# Create Weaviate schema (if not exists) +CLASS_NAME = "Document" + +if not weaviate_client.schema.contains({"class": CLASS_NAME}): + weaviate_client.schema.create_class({ + "class": CLASS_NAME, + "vectorizer": "none", # Because we use external embeddings + "properties": [ + { + "name": "content", + "dataType": ["text"] + } + ] + }) + + +@app.route("/ingest", methods=["POST"]) +def ingest(): + content = request.json.get("content") + vector = model.encode([content])[0].tolist() + + weaviate_client.data_object.create( + data_object={"content": content}, + class_name=CLASS_NAME, + vector=vector + ) + return jsonify({"status": "saved"}) + + +@app.route("/query", methods=["POST"]) +def query(): + question = request.json.get("question") + q_vector = model.encode([question])[0].tolist() + + response = weaviate_client.query.get(CLASS_NAME, ["content"])\ + .with_near_vector({"vector": q_vector})\ + .with_limit(3).do() + + relevant_chunks = [d["content"] for d in response["data"]["Get"][CLASS_NAME]] + context = "\n".join(relevant_chunks) + + prompt = f""" +Answer the following question based on the context below. + +Context: +{context} + +Question: {question} +""" + + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ] + ) + + answer = completion['choices'][0]['message']['content'].strip() + return jsonify({"answer": answer}) + + +@app.route("/") +def root(): + return "Embedding/QA API is running." + + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5000) diff --git a/src/main/embedder/server.py.v1 b/src/main/embedder/server.py.v1 new file mode 100644 index 0000000..7559c72 --- /dev/null +++ b/src/main/embedder/server.py.v1 @@ -0,0 +1,15 @@ +from flask import Flask, request, jsonify +from sentence_transformers import SentenceTransformer + +app = Flask(__name__) +model = SentenceTransformer('all-MiniLM-L6-v2') + +@app.route('/embed', methods=['POST']) +def embed(): + data = request.json + texts = data.get("texts", []) + embeddings = model.encode(texts).tolist() + return jsonify({"embeddings": embeddings}) + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5000) diff --git a/src/main/java/com/example/chatbot/ChatController.java b/src/main/java/com/example/chatbot/ChatController.java new file mode 100644 index 0000000..50e0a47 --- /dev/null +++ b/src/main/java/com/example/chatbot/ChatController.java @@ -0,0 +1,51 @@ +package com.example.chatbot; + +import java.util.List; +import java.util.Map; + +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import com.example.chatbot.service.EmbedderClient; +import com.example.chatbot.service.WeaviateService; +import com.fasterxml.jackson.core.JsonProcessingException; + +@RestController +@RequestMapping("/chat") +public class ChatController { + + private final EmbedderClient embedder; + private final WeaviateService weaviate; + private final OpenAiServiceWrapper openai; + + public ChatController(EmbedderClient embedder, WeaviateService weaviate, OpenAiServiceWrapper openai) { + this.embedder = embedder; + this.weaviate = weaviate; + this.openai = openai; + } + + @PostMapping("/ask") + public ResponseEntity ask(@RequestBody Map payload) throws JsonProcessingException { + String question = payload.get("question"); + List vector = embedder.embedText(question); + List context = weaviate.searchRelevant(question, vector); + String answer = openai.processQuestion(question, context); + return ResponseEntity.ok(answer); + } + + @PostMapping("/ingest") + public ResponseEntity ingest(@RequestBody Map payload) { + String content = payload.get("content"); + List vector = embedder.embedText(content); + weaviate.ingest(content, vector); + return ResponseEntity.ok("Saved"); + } + + @PostMapping("/askSimple") + public String ask(@RequestBody String question) { + return openai.processQuestion2(question); + } +} diff --git a/src/main/java/com/example/chatbot/ChatbotApplication.java b/src/main/java/com/example/chatbot/ChatbotApplication.java new file mode 100644 index 0000000..22fb87d --- /dev/null +++ b/src/main/java/com/example/chatbot/ChatbotApplication.java @@ -0,0 +1,11 @@ +package com.example.chatbot; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class ChatbotApplication { + public static void main(String[] args) { + SpringApplication.run(ChatbotApplication.class, args); + } +} diff --git a/src/main/java/com/example/chatbot/OpenAiServiceWrapper.java b/src/main/java/com/example/chatbot/OpenAiServiceWrapper.java new file mode 100644 index 0000000..5a41786 --- /dev/null +++ b/src/main/java/com/example/chatbot/OpenAiServiceWrapper.java @@ -0,0 +1,48 @@ +package com.example.chatbot; + +import java.util.List; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.service.OpenAiService; + +@Service +public class OpenAiServiceWrapper { + + private final OpenAiService service; + + @Value("${openai.model}") + private String openAiModel; + + @Value("${weaviate.url}") + private String weaviateUrl; + + public OpenAiServiceWrapper(@Value("${openai.api-key}") String openAiApiKey) { + this.service = new OpenAiService(openAiApiKey); + } + + public String processQuestion2(String question) { + CompletionRequest request = CompletionRequest.builder().model(this.openAiModel).prompt( + "You are a helpful assistant for interpreting database questions.\nQuestion: " + question).maxTokens( + 150).temperature(0.7).build(); + + return service.createCompletion(request).getChoices().get(0).getText().trim(); + } + + public String processQuestion(String question, List context) { + String fullPrompt = "Answer the question based on the context below.\n\n" + "Context:\n" + + String.join("\n", context) + "\n\n" + "Question: " + question; + + ChatMessage system = new ChatMessage("system", "You are a helpful assistant."); + ChatMessage user = new ChatMessage("user", fullPrompt); + + ChatCompletionRequest request = ChatCompletionRequest.builder().model(openAiModel).messages( + List.of(system, user)).temperature(0.7).build(); + + return service.createChatCompletion(request).getChoices().get(0).getMessage().getContent().trim(); + } +} diff --git a/src/main/java/com/example/chatbot/service/EmbedderClient.java b/src/main/java/com/example/chatbot/service/EmbedderClient.java new file mode 100644 index 0000000..cb7a857 --- /dev/null +++ b/src/main/java/com/example/chatbot/service/EmbedderClient.java @@ -0,0 +1,23 @@ +package com.example.chatbot.service; + +import java.util.List; +import java.util.Map; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +@Service +public class EmbedderClient { + + @Value("${embedder.url}") + private String embedderUrl; + + private final RestTemplate restTemplate = new RestTemplate(); + + public List embedText(String text) { + Map payload = Map.of("texts", List.of(text)); + Map>> response = restTemplate.postForObject(embedderUrl, payload, Map.class); + return response.get("embeddings").get(0); + } +} diff --git a/src/main/java/com/example/chatbot/service/WeaviateService.java b/src/main/java/com/example/chatbot/service/WeaviateService.java new file mode 100644 index 0000000..3457f7a --- /dev/null +++ b/src/main/java/com/example/chatbot/service/WeaviateService.java @@ -0,0 +1,57 @@ +package com.example.chatbot.service; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +@Service +public class WeaviateService { + + @Value("${weaviate.url}") + private String weaviateUrl; + + private final RestTemplate restTemplate = new RestTemplate(); + + public void ingest(String content, List vector) { + Map data = Map.of("content", content); + Map request = Map.of("class", "Document", "properties", data, "vector", vector); + + restTemplate.postForObject(weaviateUrl + "/v1/objects", request, String.class); + } + + public List searchRelevant(String question, List qVector) throws JsonProcessingException { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + Map nearVector = Map.of("vector", qVector); + String gql = """ + { + Get { + Document(nearVector: %s, limit: 3) { + content + } + } + } + """.formatted(new ObjectMapper().writeValueAsString(nearVector)); + + Map query = Map.of("query", gql); + HttpEntity> entity = new HttpEntity<>(query, headers); + + ResponseEntity resp = restTemplate.postForEntity(weaviateUrl + "/v1/graphql", entity, JsonNode.class); + + List results = new ArrayList<>(); + resp.getBody().at("/data/Get/Document").forEach(item -> results.add(item.get("content").asText())); + return results; + } +} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 0000000..495a2a3 --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,12 @@ +server: + port: 8082 + +weaviate: + url: http://localhost:8083 + +embedding: + url: http://localhost:8000/embed + +openai: + api-key: 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 + model: gpt-3.5-turbo-instruct diff --git a/src/main/weaviate/docker-compose.yml b/src/main/weaviate/docker-compose.yml new file mode 100644 index 0000000..600f4e7 --- /dev/null +++ b/src/main/weaviate/docker-compose.yml @@ -0,0 +1,22 @@ +version: "3.8" +services: + weaviate: + image: semitechnologies/weaviate:latest + environment: + - QUERY_DEFAULTS=100 + - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true + - CLUSTER_HOSTNAME=node1 + - CLUSTER_MODE=OFF + - PERSISTENCE_DATA_PATH=/var/lib/weaviate + - WEAVIATE_ENABLE_MODULES=transformers,text2vec-openai + - OPENAI_API_KEY=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 + ports: + - "8083:8083" + volumes: + - weaviate_data:/var/lib/weaviate + networks: + - weaviate-net +volumes: + weaviate_data: +networks: + weaviate-net: diff --git a/src/test/java/com/example/chatbot/ChatControllerTest.java b/src/test/java/com/example/chatbot/ChatControllerTest.java new file mode 100644 index 0000000..ef620af --- /dev/null +++ b/src/test/java/com/example/chatbot/ChatControllerTest.java @@ -0,0 +1,83 @@ +package com.example.chatbot; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.InputStream; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; +import org.yaml.snakeyaml.Yaml; + +public class ChatControllerTest { + + private static final String API_URL = "http://localhost:8080/chat/askSimple"; // Change this if needed to match your server's address + + private static String embeddingUrl; + + @BeforeAll + public static void loadYaml() { + // This loads application.yml config. Can't use SpringBoot (magic) infraestruture because I can't setup the TestCase as @SpringBootTest + Yaml yaml = new Yaml(); + try (InputStream in = ChatControllerTest.class.getClassLoader().getResourceAsStream("application.yml")) { + Map obj = yaml.load(in); + + // Property embedding.url + Map embedding = (Map) obj.get("embedding"); + embeddingUrl = (String) embedding.get("url"); + + } catch (Exception e) { + throw new RuntimeException("Failed to load YAML config", e); + } + } + + @Test + public void testEmbeddingEndpoint() { + String textToEmbedd = "ABC spent 5000 on gasoline in 2023"; + + RestTemplate restTemplate = new RestTemplate(); + + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + + Map request = Map.of("texts", List.of(textToEmbedd)); + HttpEntity> entity = new HttpEntity<>(request, headers); + + Map response = restTemplate.postForObject(embeddingUrl, entity, Map.class); + List embedding = (List) ((List) response.get("embeddings")).get(0); + + System.out.println(embedding); + } + + @Test + public void testChatEndpoint() { + String question = "How much did company ABC spend on Gasoline in 2023?"; + + // Set the headers for the request + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + // Create the HTTP request entity with the question body and headers + HttpEntity request = new HttpEntity<>(question, headers); + + // Use RestTemplate to make the HTTP POST request + RestTemplate restTemplate = new RestTemplate(); + ResponseEntity response = restTemplate.exchange(API_URL, HttpMethod.POST, request, String.class); + + // Assert that the response status code is OK + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + // Assert that the response body is not blank (contains a valid AI response) + assertThat(response.getBody()).isNotBlank(); + System.out.println("AI Response: " + response.getBody()); + } + +}