diff --git a/ai-rag-api/src/main/java/com/storm/dev/api/IAiService.java b/ai-rag-api/src/main/java/com/storm/dev/api/IAiService.java index 021606b..dec691c 100644 --- a/ai-rag-api/src/main/java/com/storm/dev/api/IAiService.java +++ b/ai-rag-api/src/main/java/com/storm/dev/api/IAiService.java @@ -23,4 +23,6 @@ public interface IAiService { * @return */ Flux generateStream(String model, String message); + + Flux generateStreamRag(String model, String ragTag, String message); } diff --git a/ai-rag-app/pom.xml b/ai-rag-app/pom.xml index 7b22c7a..273eda9 100644 --- a/ai-rag-app/pom.xml +++ b/ai-rag-app/pom.xml @@ -24,10 +24,10 @@ test - - - - + + org.springframework.ai + spring-ai-openai-spring-boot-starter + org.springframework.ai diff --git a/ai-rag-app/src/main/java/com/storm/dev/config/OllamaConfig.java b/ai-rag-app/src/main/java/com/storm/dev/config/OllamaConfig.java index db581be..acdafe0 100644 --- a/ai-rag-app/src/main/java/com/storm/dev/config/OllamaConfig.java +++ b/ai-rag-app/src/main/java/com/storm/dev/config/OllamaConfig.java @@ -4,6 +4,8 @@ import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaEmbeddingClient; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiEmbeddingClient; +import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.transformer.splitter.TokenTextSplitter; import org.springframework.ai.vectorstore.PgVectorStore; import org.springframework.ai.vectorstore.SimpleVectorStore; @@ -25,6 +27,11 @@ public class OllamaConfig { return new OllamaApi(baseUrl); } + @Bean + public OpenAiApi openAiApi(@Value("${spring.ai.openai.base-url}") String baseUrl, @Value("${spring.ai.openai.api-key}") String apikey) { + return new OpenAiApi(baseUrl, apikey); + } + @Bean public OllamaChatClient ollamaChatClient(OllamaApi ollamaApi) { return new OllamaChatClient(ollamaApi); @@ -36,17 +43,27 @@ 
public class OllamaConfig {
     }
 
     @Bean
-    public SimpleVectorStore simpleVectorStore(OllamaApi ollamaApi) {
-        OllamaEmbeddingClient embeddingClient = new OllamaEmbeddingClient(ollamaApi);
-        embeddingClient.withDefaultOptions(OllamaOptions.create().withModel("nomic-embed-text"));
-        return new SimpleVectorStore(embeddingClient);
+    public SimpleVectorStore vectorStore(@Value("${spring.ai.rag.embed}") String model, OllamaApi ollamaApi, OpenAiApi openAiApi) {
+        // Any value other than nomic-embed-text selects the OpenAI embedding client.
+        if (!"nomic-embed-text".equalsIgnoreCase(model)) {
+            return new SimpleVectorStore(new OpenAiEmbeddingClient(openAiApi));
+        }
+        // Ollama path: embed with the nomic-embed-text model.
+        OllamaEmbeddingClient ollamaEmbedding = new OllamaEmbeddingClient(ollamaApi);
+        ollamaEmbedding.withDefaultOptions(OllamaOptions.create().withModel("nomic-embed-text"));
+        return new SimpleVectorStore(ollamaEmbedding);
     }
 
     @Bean
-    public PgVectorStore pgVectorStore(OllamaApi ollamaApi, JdbcTemplate jdbcTemplate) {
-        OllamaEmbeddingClient embeddingClient = new OllamaEmbeddingClient(ollamaApi);
-        embeddingClient.withDefaultOptions(OllamaOptions.create().withModel("nomic-embed-text"));
-        return new PgVectorStore(jdbcTemplate, embeddingClient);
+    public PgVectorStore pgVectorStore(@Value("${spring.ai.rag.embed}") String model, OllamaApi ollamaApi, OpenAiApi openAiApi, JdbcTemplate jdbcTemplate) {
+        // Any value other than nomic-embed-text selects the OpenAI embedding client.
+        if (!"nomic-embed-text".equalsIgnoreCase(model)) {
+            return new PgVectorStore(jdbcTemplate, new OpenAiEmbeddingClient(openAiApi));
+        }
+        // Ollama path: embed with the nomic-embed-text model.
+        OllamaEmbeddingClient ollamaEmbedding = new OllamaEmbeddingClient(ollamaApi);
+        ollamaEmbedding.withDefaultOptions(OllamaOptions.create().withModel("nomic-embed-text"));
+        return new PgVectorStore(jdbcTemplate, ollamaEmbedding);
     }
 
 
diff --git a/ai-rag-app/src/main/resources/application-dev.yml b/ai-rag-app/src/main/resources/application-dev.yml
index f674300..ea2418d 100644
--- a/ai-rag-app/src/main/resources/application-dev.yml
+++
b/ai-rag-app/src/main/resources/application-dev.yml @@ -15,6 +15,12 @@ spring: options: num-batch: 512 model: nomic-embed-text + openai: + base-url: xxx + api-key: xxx + embedding-model: text-embedding-ada-002 + rag: + embed: nomic-embed-text #nomic-embed-text、text-embedding-ada-002 # Redis redis: sdk: diff --git a/ai-rag-trigger/pom.xml b/ai-rag-trigger/pom.xml index f7a9809..3abc8f7 100644 --- a/ai-rag-trigger/pom.xml +++ b/ai-rag-trigger/pom.xml @@ -23,10 +23,10 @@ spring-boot-starter-web - - - - + + org.springframework.ai + spring-ai-openai-spring-boot-starter + org.springframework.ai spring-ai-tika-document-reader diff --git a/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OllamaController.java b/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OllamaController.java index 25e13be..5f54be0 100644 --- a/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OllamaController.java +++ b/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OllamaController.java @@ -2,17 +2,30 @@ package com.storm.dev.trigger.http; import com.storm.dev.api.IAiService; import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.document.Document; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.vectorstore.PgVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + /** * @author: lyd * @date: 2025/6/7 22:18 */ +@Slf4j @RestController() 
@CrossOrigin("*") @RequestMapping("/api/v1/ollama/")
@@ -20,6 +33,8 @@ public class OllamaController implements IAiService {
 
     @Resource
     private OllamaChatClient chatClient;
+    @Resource
+    private PgVectorStore pgVectorStore;
 
     /**
      * http://localhost:8090/api/v1/ollama/generate?model=deepseek-r1:7b&message=1+1
@@ -38,4 +53,45 @@ public class OllamaController implements IAiService {
     public Flux<ChatResponse> generateStream(@RequestParam String model, @RequestParam String message) {
         return chatClient.stream(new Prompt(message, OllamaOptions.create().withModel(model)));
     }
+
+    /**
+     * Streams a chat completion grounded in documents retrieved from the vector store (RAG).
+     *
+     * @param model   name of the Ollama chat model to use
+     * @param ragTag  knowledge-base tag; only documents whose {@code knowledge} metadata equals it are searched
+     * @param message the user's question
+     * @return the streamed chat responses
+     */
+    @Override
+    @RequestMapping(value = "generate_stream_rag", method = RequestMethod.GET)
+    public Flux<ChatResponse> generateStreamRag(@RequestParam String model, @RequestParam String ragTag, @RequestParam String message) {
+        log.info("用户选择模型:{},知识库:{},提问问题:{}", model, ragTag, message);
+        // System prompt template; the documents placeholder receives the retrieved context.
+        String systemPrompt = """
+                Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
+                If unsure, simply state that you don't know.
+                Another thing you need to note is that your reply must be in Chinese!
+                DOCUMENTS:
+                    {documents}
+                """;
+        // Build the metadata filter with the expression builder instead of concatenating the
+        // request parameter into filter text, so a quote in ragTag cannot rewrite the filter.
+        SearchRequest request = SearchRequest.query(message)
+                .withTopK(5)
+                .withFilterExpression(new org.springframework.ai.vectorstore.filter.FilterExpressionBuilder()
+                        .eq("knowledge", ragTag)
+                        .build());
+        // Similarity search; chunks are newline-separated so adjacent ones do not run together.
+        List<Document> documents = pgVectorStore.similaritySearch(request);
+        String context = documents.stream().map(Document::getContent).collect(Collectors.joining("\n"));
+
+        // RAG: the system message carrying the context goes first, then the user's question.
+        Message ragMessage = new SystemPromptTemplate(systemPrompt).createMessage(Map.of("documents", context));
+        List<Message> messages = new ArrayList<>();
+        messages.add(ragMessage);
+        messages.add(new UserMessage(message));
+
+        // Ask the model and stream its answer back.
+        return chatClient.stream(new Prompt(messages, OllamaOptions.create().withModel(model)));
+    }
 }
diff --git a/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OpenAiController.java b/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OpenAiController.java
new file mode 100644
index 0000000..4753ab1
--- /dev/null
+++ b/ai-rag-trigger/src/main/java/com/storm/dev/trigger/http/OpenAiController.java
@@ -0,0 +1,113 @@
+package com.storm.dev.trigger.http;
+
+import com.storm.dev.api.IAiService;
+import jakarta.annotation.Resource;
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.chat.prompt.SystemPromptTemplate;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.openai.OpenAiChatClient;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.vectorstore.PgVectorStore;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
+import org.springframework.web.bind.annotation.CrossOrigin;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestMethod;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * REST endpoints under {@code /api/v1/openai/} that serve {@link IAiService} through the OpenAI chat client.
+ *
+ * @author: lyd
+ * @date: 2026/1/18 17:08
+ */
+@RestController()
+@CrossOrigin("*")
+@RequestMapping("/api/v1/openai/")
+public class OpenAiController implements IAiService {
+
+    /** System prompt template; the documents placeholder receives the retrieved context. */
+    private static final String SYSTEM_PROMPT = """
+            Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
+            If unsure, simply state that you don't know.
+            Another thing you need to note is that your reply must be in Chinese!
+            DOCUMENTS:
+                {documents}
+            """;
+
+    @Resource
+    private OpenAiChatClient chatClient;
+    @Resource
+    private PgVectorStore pgVectorStore;
+
+    /**
+     * Runs a single (non-streaming) chat completion.
+     *
+     * @param model   name of the chat model to use
+     * @param message the user's prompt
+     * @return the chat response
+     */
+    @Override
+    @RequestMapping(value = "generate", method = RequestMethod.GET)
+    public ChatResponse generate(@RequestParam String model, @RequestParam String message) {
+        return chatClient.call(new Prompt(message, chatOptions(model)));
+    }
+
+    /**
+     * Streams a chat completion.
+     *
+     * @param model   name of the chat model to use
+     * @param message the user's prompt
+     * @return the streamed chat responses
+     */
+    @Override
+    @RequestMapping(value = "generate_stream", method = RequestMethod.GET)
+    public Flux<ChatResponse> generateStream(@RequestParam String model, @RequestParam String message) {
+        return chatClient.stream(new Prompt(message, chatOptions(model)));
+    }
+
+    /**
+     * Streams a chat completion grounded in documents retrieved from the vector store (RAG).
+     *
+     * @param model   name of the chat model to use
+     * @param ragTag  knowledge-base tag; only documents whose {@code knowledge} metadata equals it are searched
+     * @param message the user's question
+     * @return the streamed chat responses
+     */
+    @Override
+    @RequestMapping(value = "generate_stream_rag", method = RequestMethod.GET)
+    public Flux<ChatResponse> generateStreamRag(@RequestParam String model, @RequestParam String ragTag, @RequestParam String message) {
+        // Build the metadata filter with the expression builder instead of concatenating the
+        // request parameter into filter text, so a quote in ragTag cannot rewrite the filter.
+        SearchRequest request = SearchRequest.query(message)
+                .withTopK(5)
+                .withFilterExpression(new FilterExpressionBuilder().eq("knowledge", ragTag).build());
+
+        // Chunks are newline-separated so adjacent ones do not run together in the prompt.
+        List<Document> documents = pgVectorStore.similaritySearch(request);
+        String context = documents.stream().map(Document::getContent).collect(Collectors.joining("\n"));
+        Message ragMessage = new SystemPromptTemplate(SYSTEM_PROMPT).createMessage(Map.of("documents", context));
+
+        // The system message carrying the context goes first, then the user's question.
+        List<Message> messages = new ArrayList<>();
+        messages.add(ragMessage);
+        messages.add(new UserMessage(message));
+
+        return chatClient.stream(new Prompt(messages, chatOptions(model)));
+    }
+
+    /** Builds per-request chat options that select the given model. */
+    private static OpenAiChatOptions chatOptions(String model) {
+        return OpenAiChatOptions.builder().withModel(model).build();
+    }
+}