feat: rag的api

This commit is contained in:
2026-01-17 23:54:47 +08:00
parent 64ba3b5767
commit e042e548f9
6 changed files with 1469 additions and 12 deletions

View File

@@ -0,0 +1,34 @@
package com.storm.dev.api;
import com.storm.dev.api.response.Response;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;
import java.util.List;
/**
* @author: lyd
* @date: 2026/1/14 23:41
*/
public interface IRAGService {
/**
* 获取标签列表
*
* @return
*/
Response<List<String>> queryRagTagList();
/**
* 上传知识库
*
* @param ragTag
* @param files
* @return
*/
Response<String> uploadFile(String ragTag, List<MultipartFile> files);
ChatResponse generateStreamRag(String model, String ragTag, String message);
}

View File

@@ -0,0 +1,20 @@
package com.storm.dev.api.response;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Response<T> implements Serializable {
private String code;
private String info;
private T data;
}

View File

@@ -21,6 +21,7 @@ import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
@@ -61,7 +62,7 @@ public class RAGApiTest {
@Test
public void chat() {
// 构建提问
String message = "李永德,哪年出生的";
String message = "拆装出库的操作流程是什么?";
// 构建推理模板
String SYSTEM_PROMPT = """
@@ -72,7 +73,7 @@ public class RAGApiTest {
{documents}
""";
// 读取向量库信息
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '德德'");
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '富士迈泰国项目软件方案'");
// 相似性搜索
List<Document> documents = pgVectorStore.similaritySearch(request);
String documentsCollectors = documents.stream().map(Document::getContent).collect(Collectors.joining());
@@ -84,7 +85,8 @@ public class RAGApiTest {
messages.add(ragMessage);
// 提问
ChatResponse chatResponse = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
log.info("测试结果:{}", JSON.toJSONString(chatResponse));
// ChatResponse chatResponse = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
Flux<ChatResponse> stream = ollamaChatClient.stream(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
log.info("测试结果:{}", JSON.toJSONString(stream));
}
}

View File

@@ -27,14 +27,14 @@
<!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-openai-spring-boot-starter</artifactId>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-tika-document-reader</artifactId>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-pgvector-store</artifactId>-->
<!-- </dependency>-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-tika-document-reader</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-pgvector-store</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama</artifactId>

View File

@@ -0,0 +1,116 @@
package com.storm.dev.trigger.http;
import com.alibaba.fastjson.JSON;
import com.storm.dev.api.IRAGService;
import com.storm.dev.api.response.Response;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.redisson.api.RList;
import org.redisson.api.RedissonClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* @author: lyd
* @date: 2026/1/14 23:43
*/
@Slf4j
@RestController()
@CrossOrigin("*")
@RequestMapping("/api/v1/rag/")
public class RAGController implements IRAGService {
@Resource
private RedissonClient redissonClient;
@Resource
private OllamaChatClient ollamaChatClient;
@Resource
private TokenTextSplitter tokenTextSplitter;
@Resource
private SimpleVectorStore simpleVectorStore;
@Resource
private PgVectorStore pgVectorStore;
@Override
@RequestMapping(value = "query_rag_tag_list", method = RequestMethod.GET)
public Response<List<String>> queryRagTagList() {
RList<String> ragTag = redissonClient.getList("ragTag");
return Response.<List<String>>builder()
.code("0000")
.info("调用成功")
.data(ragTag)
.build();
}
@Override
@RequestMapping(value = "file/upload", method = RequestMethod.POST, headers = "content-type=multipart/form-data")
public Response<String> uploadFile(@RequestParam String ragTag, @RequestParam("file") List<MultipartFile> files) {
log.info("上传知识库开始 {}", ragTag);
for (MultipartFile file : files) {
// 上传
TikaDocumentReader reader = new TikaDocumentReader(file.getResource());
List<Document> documents = reader.get();
List<Document> documentSplitterList = tokenTextSplitter.apply(documents);
// 打标
documents.forEach(document -> document.getMetadata().put("knowledge", ragTag));
documentSplitterList.forEach(document -> document.getMetadata().put("knowledge", ragTag));
pgVectorStore.accept(documentSplitterList);
// 可以用MySQL存储
RList<String> elements = redissonClient.getList("ragTag");
if (!elements.contains(ragTag)){
elements.add(ragTag);
}
log.info("上传完成!");
}
return Response.<String>builder().code("0000").info("调用成功").build();
}
@Override
@RequestMapping(value = "generate_stream_rag", method = RequestMethod.GET)
public ChatResponse generateStreamRag(@RequestParam String model, @RequestParam String ragTag, @RequestParam String message) {
log.info("用户选择模型:{},知识库:{},提问问题:{}", model, ragTag, message);
// 构建推理模板
String SYSTEM_PROMPT = """
Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
If unsure, simply state that you don't know.
Another thing you need to note is that your reply must be in Chinese!
DOCUMENTS:
{documents}
""";
// 读取向量库信息
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '" + ragTag + "'");
// 相似性搜索
List<Document> documents = pgVectorStore.similaritySearch(request);
String documentsCollectors = documents.stream().map(Document::getContent).collect(Collectors.joining());
// 推理RAG
Message ragMessage = new SystemPromptTemplate(SYSTEM_PROMPT).createMessage(Map.of("documents", documentsCollectors));
ArrayList<Message> messages = new ArrayList<>();
messages.add(new UserMessage(message));
messages.add(ragMessage);
// 提问
// Flux<ChatResponse> chatResponse = ollamaChatClient.stream(new Prompt(messages, OllamaOptions.create().withModel(model)));
ChatResponse call = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel(model)));
log.info("测试结果:{}", call);
return call;
}
}

1285
docs/nginx/html/rag-ai.html Normal file

File diff suppressed because it is too large Load Diff