Initial commit
This commit is contained in:
commit
0cc501eec0
14 changed files with 535 additions and 0 deletions
48
.gitignore
vendored
Normal file
48
.gitignore
vendored
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
**/logs/
|
||||||
|
|
||||||
|
# OS files
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.iml
|
||||||
|
|
||||||
|
# Build & dependencies (Java)
|
||||||
|
target/
|
||||||
|
.mvn/
|
||||||
|
mvnw
|
||||||
|
mvnw.cmd
|
||||||
|
|
||||||
|
# Java cache
|
||||||
|
.classpath
|
||||||
|
.project
|
||||||
|
.settings/
|
||||||
|
.gradle/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# Angular
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
.angular/
|
||||||
|
.output/
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|
||||||
|
# Angular cache
|
||||||
|
.turbo/
|
||||||
|
.npm/
|
||||||
|
.cache/
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
*.pid
|
||||||
|
docker-compose.override.yml
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
*.bak
|
||||||
|
*.swp
|
||||||
|
*.tmp
|
61
pom.xml
Normal file
61
pom.xml
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<parent>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-parent</artifactId>
|
||||||
|
<version>3.2.5</version>
|
||||||
|
<relativePath/> <!-- lookup parent from repository -->
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<groupId>com.example</groupId>
|
||||||
|
<artifactId>ai-chatbot</artifactId>
|
||||||
|
<version>0.0.1-SNAPSHOT</version>
|
||||||
|
|
||||||
|
<name>AI Chatbot</name>
|
||||||
|
<description>Simple Spring Boot AI chatbot integration</description>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<java.version>17</java.version>
|
||||||
|
<spring-boot.version>3.2.5</spring-boot.version>
|
||||||
|
</properties>
|
||||||
|
<dependencies>
|
||||||
|
<!-- Spring Boot Starter Web for the REST API -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<!-- OpenAI Java service -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||||
|
<artifactId>service</artifactId>
|
||||||
|
<version>0.12.0</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<!-- Spring Boot Starter Test for JUnit and other test dependencies -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<!-- AssertJ dependency for fluent assertions -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.assertj</groupId>
|
||||||
|
<artifactId>assertj-core</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
</dependencies>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</project>
|
14
src/main/embedder/dockerfile
Normal file
14
src/main/embedder/dockerfile
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# Dockerfile
|
||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy requirements first to leverage Docker layer caching
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# Install dependencies including weaviate-client
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY server.py .
|
||||||
|
|
||||||
|
CMD ["python", "server.py"]
|
4
src/main/embedder/requirements.txt
Normal file
4
src/main/embedder/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
flask
|
||||||
|
sentence-transformers
|
||||||
|
weaviate-client
|
||||||
|
openai
|
86
src/main/embedder/server.py
Normal file
86
src/main/embedder/server.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import weaviate
|
||||||
|
import openai
|
||||||
|
import os
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# Load local model
|
||||||
|
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||||
|
|
||||||
|
# Configure Weaviate client (assumes running on localhost)
|
||||||
|
weaviate_client = weaviate.Client("http://localhost:8083")
|
||||||
|
|
||||||
|
# OpenAI setup
|
||||||
|
openai.api_key = os.getenv("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")
|
||||||
|
|
||||||
|
# Create Weaviate schema (if not exists)
|
||||||
|
CLASS_NAME = "Document"
|
||||||
|
|
||||||
|
if not weaviate_client.schema.contains({"class": CLASS_NAME}):
|
||||||
|
weaviate_client.schema.create_class({
|
||||||
|
"class": CLASS_NAME,
|
||||||
|
"vectorizer": "none", # Because we use external embeddings
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"name": "content",
|
||||||
|
"dataType": ["text"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/ingest", methods=["POST"])
|
||||||
|
def ingest():
|
||||||
|
content = request.json.get("content")
|
||||||
|
vector = model.encode([content])[0].tolist()
|
||||||
|
|
||||||
|
weaviate_client.data_object.create(
|
||||||
|
data_object={"content": content},
|
||||||
|
class_name=CLASS_NAME,
|
||||||
|
vector=vector
|
||||||
|
)
|
||||||
|
return jsonify({"status": "saved"})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/query", methods=["POST"])
|
||||||
|
def query():
|
||||||
|
question = request.json.get("question")
|
||||||
|
q_vector = model.encode([question])[0].tolist()
|
||||||
|
|
||||||
|
response = weaviate_client.query.get(CLASS_NAME, ["content"])\
|
||||||
|
.with_near_vector({"vector": q_vector})\
|
||||||
|
.with_limit(3).do()
|
||||||
|
|
||||||
|
relevant_chunks = [d["content"] for d in response["data"]["Get"][CLASS_NAME]]
|
||||||
|
context = "\n".join(relevant_chunks)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
Answer the following question based on the context below.
|
||||||
|
|
||||||
|
Context:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
"""
|
||||||
|
|
||||||
|
completion = openai.ChatCompletion.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = completion['choices'][0]['message']['content'].strip()
|
||||||
|
return jsonify({"answer": answer})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def root():
|
||||||
|
return "Embedding/QA API is running."
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(host="0.0.0.0", port=5000)
|
15
src/main/embedder/server.py.v1
Normal file
15
src/main/embedder/server.py.v1
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||||
|
|
||||||
|
@app.route('/embed', methods=['POST'])
|
||||||
|
def embed():
|
||||||
|
data = request.json
|
||||||
|
texts = data.get("texts", [])
|
||||||
|
embeddings = model.encode(texts).tolist()
|
||||||
|
return jsonify({"embeddings": embeddings})
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(host="0.0.0.0", port=5000)
|
51
src/main/java/com/example/chatbot/ChatController.java
Normal file
51
src/main/java/com/example/chatbot/ChatController.java
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package com.example.chatbot;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
|
import com.example.chatbot.service.EmbedderClient;
|
||||||
|
import com.example.chatbot.service.WeaviateService;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/chat")
|
||||||
|
public class ChatController {
|
||||||
|
|
||||||
|
private final EmbedderClient embedder;
|
||||||
|
private final WeaviateService weaviate;
|
||||||
|
private final OpenAiServiceWrapper openai;
|
||||||
|
|
||||||
|
public ChatController(EmbedderClient embedder, WeaviateService weaviate, OpenAiServiceWrapper openai) {
|
||||||
|
this.embedder = embedder;
|
||||||
|
this.weaviate = weaviate;
|
||||||
|
this.openai = openai;
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/ask")
|
||||||
|
public ResponseEntity<String> ask(@RequestBody Map<String, String> payload) throws JsonProcessingException {
|
||||||
|
String question = payload.get("question");
|
||||||
|
List<Double> vector = embedder.embedText(question);
|
||||||
|
List<String> context = weaviate.searchRelevant(question, vector);
|
||||||
|
String answer = openai.processQuestion(question, context);
|
||||||
|
return ResponseEntity.ok(answer);
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/ingest")
|
||||||
|
public ResponseEntity<String> ingest(@RequestBody Map<String, String> payload) {
|
||||||
|
String content = payload.get("content");
|
||||||
|
List<Double> vector = embedder.embedText(content);
|
||||||
|
weaviate.ingest(content, vector);
|
||||||
|
return ResponseEntity.ok("Saved");
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/askSimple")
|
||||||
|
public String ask(@RequestBody String question) {
|
||||||
|
return openai.processQuestion2(question);
|
||||||
|
}
|
||||||
|
}
|
11
src/main/java/com/example/chatbot/ChatbotApplication.java
Normal file
11
src/main/java/com/example/chatbot/ChatbotApplication.java
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package com.example.chatbot;
|
||||||
|
|
||||||
|
import org.springframework.boot.SpringApplication;
|
||||||
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
|
|
||||||
|
@SpringBootApplication
|
||||||
|
public class ChatbotApplication {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
SpringApplication.run(ChatbotApplication.class, args);
|
||||||
|
}
|
||||||
|
}
|
48
src/main/java/com/example/chatbot/OpenAiServiceWrapper.java
Normal file
48
src/main/java/com/example/chatbot/OpenAiServiceWrapper.java
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package com.example.chatbot;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import com.theokanning.openai.completion.CompletionRequest;
|
||||||
|
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||||
|
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||||
|
import com.theokanning.openai.service.OpenAiService;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class OpenAiServiceWrapper {
|
||||||
|
|
||||||
|
private final OpenAiService service;
|
||||||
|
|
||||||
|
@Value("${openai.model}")
|
||||||
|
private String openAiModel;
|
||||||
|
|
||||||
|
@Value("${weaviate.url}")
|
||||||
|
private String weaviateUrl;
|
||||||
|
|
||||||
|
public OpenAiServiceWrapper(@Value("${openai.api-key}") String openAiApiKey) {
|
||||||
|
this.service = new OpenAiService(openAiApiKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String processQuestion2(String question) {
|
||||||
|
CompletionRequest request = CompletionRequest.builder().model(this.openAiModel).prompt(
|
||||||
|
"You are a helpful assistant for interpreting database questions.\nQuestion: " + question).maxTokens(
|
||||||
|
150).temperature(0.7).build();
|
||||||
|
|
||||||
|
return service.createCompletion(request).getChoices().get(0).getText().trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String processQuestion(String question, List<String> context) {
|
||||||
|
String fullPrompt = "Answer the question based on the context below.\n\n" + "Context:\n"
|
||||||
|
+ String.join("\n", context) + "\n\n" + "Question: " + question;
|
||||||
|
|
||||||
|
ChatMessage system = new ChatMessage("system", "You are a helpful assistant.");
|
||||||
|
ChatMessage user = new ChatMessage("user", fullPrompt);
|
||||||
|
|
||||||
|
ChatCompletionRequest request = ChatCompletionRequest.builder().model(openAiModel).messages(
|
||||||
|
List.of(system, user)).temperature(0.7).build();
|
||||||
|
|
||||||
|
return service.createChatCompletion(request).getChoices().get(0).getMessage().getContent().trim();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package com.example.chatbot.service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class EmbedderClient {
|
||||||
|
|
||||||
|
@Value("${embedder.url}")
|
||||||
|
private String embedderUrl;
|
||||||
|
|
||||||
|
private final RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
|
public List<Double> embedText(String text) {
|
||||||
|
Map<String, Object> payload = Map.of("texts", List.of(text));
|
||||||
|
Map<String, List<List<Double>>> response = restTemplate.postForObject(embedderUrl, payload, Map.class);
|
||||||
|
return response.get("embeddings").get(0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package com.example.chatbot.service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class WeaviateService {
|
||||||
|
|
||||||
|
@Value("${weaviate.url}")
|
||||||
|
private String weaviateUrl;
|
||||||
|
|
||||||
|
private final RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
|
public void ingest(String content, List<Double> vector) {
|
||||||
|
Map<String, Object> data = Map.of("content", content);
|
||||||
|
Map<String, Object> request = Map.of("class", "Document", "properties", data, "vector", vector);
|
||||||
|
|
||||||
|
restTemplate.postForObject(weaviateUrl + "/v1/objects", request, String.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> searchRelevant(String question, List<Double> qVector) throws JsonProcessingException {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
Map<String, Object> nearVector = Map.of("vector", qVector);
|
||||||
|
String gql = """
|
||||||
|
{
|
||||||
|
Get {
|
||||||
|
Document(nearVector: %s, limit: 3) {
|
||||||
|
content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""".formatted(new ObjectMapper().writeValueAsString(nearVector));
|
||||||
|
|
||||||
|
Map<String, String> query = Map.of("query", gql);
|
||||||
|
HttpEntity<Map<String, String>> entity = new HttpEntity<>(query, headers);
|
||||||
|
|
||||||
|
ResponseEntity<JsonNode> resp = restTemplate.postForEntity(weaviateUrl + "/v1/graphql", entity, JsonNode.class);
|
||||||
|
|
||||||
|
List<String> results = new ArrayList<>();
|
||||||
|
resp.getBody().at("/data/Get/Document").forEach(item -> results.add(item.get("content").asText()));
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
}
|
12
src/main/resources/application.yml
Normal file
12
src/main/resources/application.yml
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
server:
|
||||||
|
port: 8082
|
||||||
|
|
||||||
|
weaviate:
|
||||||
|
url: http://localhost:8083
|
||||||
|
|
||||||
|
embedding:
|
||||||
|
url: http://localhost:8000/embed
|
||||||
|
|
||||||
|
openai:
|
||||||
|
api-key: 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
|
||||||
|
model: gpt-3.5-turbo-instruct
|
22
src/main/weaviate/docker-compose.yml
Normal file
22
src/main/weaviate/docker-compose.yml
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
version: "3.8"
|
||||||
|
services:
|
||||||
|
weaviate:
|
||||||
|
image: semitechnologies/weaviate:latest
|
||||||
|
environment:
|
||||||
|
- QUERY_DEFAULTS=100
|
||||||
|
- AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true
|
||||||
|
- CLUSTER_HOSTNAME=node1
|
||||||
|
- CLUSTER_MODE=OFF
|
||||||
|
- PERSISTENCE_DATA_PATH=/var/lib/weaviate
|
||||||
|
- WEAVIATE_ENABLE_MODULES=transformers,text2vec-openai
|
||||||
|
- OPENAI_API_KEY=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
|
||||||
|
ports:
|
||||||
|
- "8083:8083"
|
||||||
|
volumes:
|
||||||
|
- weaviate_data:/var/lib/weaviate
|
||||||
|
networks:
|
||||||
|
- weaviate-net
|
||||||
|
volumes:
|
||||||
|
weaviate_data:
|
||||||
|
networks:
|
||||||
|
weaviate-net:
|
83
src/test/java/com/example/chatbot/ChatControllerTest.java
Normal file
83
src/test/java/com/example/chatbot/ChatControllerTest.java
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package com.example.chatbot;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
import org.yaml.snakeyaml.Yaml;
|
||||||
|
|
||||||
|
public class ChatControllerTest {
|
||||||
|
|
||||||
|
private static final String API_URL = "http://localhost:8080/chat/askSimple"; // Change this if needed to match your server's address
|
||||||
|
|
||||||
|
private static String embeddingUrl;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
public static void loadYaml() {
|
||||||
|
// This loads application.yml config. Can't use SpringBoot (magic) infraestruture because I can't setup the TestCase as @SpringBootTest
|
||||||
|
Yaml yaml = new Yaml();
|
||||||
|
try (InputStream in = ChatControllerTest.class.getClassLoader().getResourceAsStream("application.yml")) {
|
||||||
|
Map<String, Object> obj = yaml.load(in);
|
||||||
|
|
||||||
|
// Property embedding.url
|
||||||
|
Map<String, Object> embedding = (Map<String, Object>) obj.get("embedding");
|
||||||
|
embeddingUrl = (String) embedding.get("url");
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to load YAML config", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEmbeddingEndpoint() {
|
||||||
|
String textToEmbedd = "ABC spent 5000 on gasoline in 2023";
|
||||||
|
|
||||||
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
|
||||||
|
Map<String, Object> request = Map.of("texts", List.of(textToEmbedd));
|
||||||
|
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(request, headers);
|
||||||
|
|
||||||
|
Map response = restTemplate.postForObject(embeddingUrl, entity, Map.class);
|
||||||
|
List<Double> embedding = (List<Double>) ((List<?>) response.get("embeddings")).get(0);
|
||||||
|
|
||||||
|
System.out.println(embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testChatEndpoint() {
|
||||||
|
String question = "How much did company ABC spend on Gasoline in 2023?";
|
||||||
|
|
||||||
|
// Set the headers for the request
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.TEXT_PLAIN);
|
||||||
|
|
||||||
|
// Create the HTTP request entity with the question body and headers
|
||||||
|
HttpEntity<String> request = new HttpEntity<>(question, headers);
|
||||||
|
|
||||||
|
// Use RestTemplate to make the HTTP POST request
|
||||||
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
|
ResponseEntity<String> response = restTemplate.exchange(API_URL, HttpMethod.POST, request, String.class);
|
||||||
|
|
||||||
|
// Assert that the response status code is OK
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
|
||||||
|
|
||||||
|
// Assert that the response body is not blank (contains a valid AI response)
|
||||||
|
assertThat(response.getBody()).isNotBlank();
|
||||||
|
System.out.println("AI Response: " + response.getBody());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue