feat: rag的api
This commit is contained in:
34
ai-rag-api/src/main/java/com/storm/dev/api/IRAGService.java
Normal file
34
ai-rag-api/src/main/java/com/storm/dev/api/IRAGService.java
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package com.storm.dev.api;
|
||||||
|
|
||||||
|
import com.storm.dev.api.response.Response;
|
||||||
|
import org.springframework.ai.chat.ChatResponse;
|
||||||
|
import org.springframework.web.bind.annotation.RequestParam;
|
||||||
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author: lyd
|
||||||
|
* @date: 2026/1/14 23:41
|
||||||
|
*/
|
||||||
|
public interface IRAGService {
|
||||||
|
/**
|
||||||
|
* 获取标签列表
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Response<List<String>> queryRagTagList();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 上传知识库
|
||||||
|
*
|
||||||
|
* @param ragTag
|
||||||
|
* @param files
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Response<String> uploadFile(String ragTag, List<MultipartFile> files);
|
||||||
|
|
||||||
|
ChatResponse generateStreamRag(String model, String ragTag, String message);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.storm.dev.api.response;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class Response<T> implements Serializable {
|
||||||
|
|
||||||
|
private String code;
|
||||||
|
private String info;
|
||||||
|
private T data;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ import org.springframework.ai.vectorstore.SearchRequest;
|
|||||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||||
import org.springframework.boot.test.context.SpringBootTest;
|
import org.springframework.boot.test.context.SpringBootTest;
|
||||||
import org.springframework.test.context.junit4.SpringRunner;
|
import org.springframework.test.context.junit4.SpringRunner;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -61,7 +62,7 @@ public class RAGApiTest {
|
|||||||
@Test
|
@Test
|
||||||
public void chat() {
|
public void chat() {
|
||||||
// 构建提问
|
// 构建提问
|
||||||
String message = "李永德,哪年出生的";
|
String message = "拆装出库的操作流程是什么?";
|
||||||
|
|
||||||
// 构建推理模板
|
// 构建推理模板
|
||||||
String SYSTEM_PROMPT = """
|
String SYSTEM_PROMPT = """
|
||||||
@@ -72,7 +73,7 @@ public class RAGApiTest {
|
|||||||
{documents}
|
{documents}
|
||||||
""";
|
""";
|
||||||
// 读取向量库信息
|
// 读取向量库信息
|
||||||
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '德德'");
|
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '富士迈泰国项目软件方案'");
|
||||||
// 相似性搜索
|
// 相似性搜索
|
||||||
List<Document> documents = pgVectorStore.similaritySearch(request);
|
List<Document> documents = pgVectorStore.similaritySearch(request);
|
||||||
String documentsCollectors = documents.stream().map(Document::getContent).collect(Collectors.joining());
|
String documentsCollectors = documents.stream().map(Document::getContent).collect(Collectors.joining());
|
||||||
@@ -84,7 +85,8 @@ public class RAGApiTest {
|
|||||||
messages.add(ragMessage);
|
messages.add(ragMessage);
|
||||||
|
|
||||||
// 提问
|
// 提问
|
||||||
ChatResponse chatResponse = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
|
// ChatResponse chatResponse = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
|
||||||
log.info("测试结果:{}", JSON.toJSONString(chatResponse));
|
Flux<ChatResponse> stream = ollamaChatClient.stream(new Prompt(messages, OllamaOptions.create().withModel("deepseek-r1:7b")));
|
||||||
|
log.info("测试结果:{}", JSON.toJSONString(stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,14 +27,14 @@
|
|||||||
<!-- <groupId>org.springframework.ai</groupId>-->
|
<!-- <groupId>org.springframework.ai</groupId>-->
|
||||||
<!-- <artifactId>spring-ai-openai-spring-boot-starter</artifactId>-->
|
<!-- <artifactId>spring-ai-openai-spring-boot-starter</artifactId>-->
|
||||||
<!-- </dependency>-->
|
<!-- </dependency>-->
|
||||||
<!-- <dependency>-->
|
<dependency>
|
||||||
<!-- <groupId>org.springframework.ai</groupId>-->
|
<groupId>org.springframework.ai</groupId>
|
||||||
<!-- <artifactId>spring-ai-tika-document-reader</artifactId>-->
|
<artifactId>spring-ai-tika-document-reader</artifactId>
|
||||||
<!-- </dependency>-->
|
</dependency>
|
||||||
<!-- <dependency>-->
|
<dependency>
|
||||||
<!-- <groupId>org.springframework.ai</groupId>-->
|
<groupId>org.springframework.ai</groupId>
|
||||||
<!-- <artifactId>spring-ai-pgvector-store</artifactId>-->
|
<artifactId>spring-ai-pgvector-store</artifactId>
|
||||||
<!-- </dependency>-->
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.ai</groupId>
|
<groupId>org.springframework.ai</groupId>
|
||||||
<artifactId>spring-ai-ollama</artifactId>
|
<artifactId>spring-ai-ollama</artifactId>
|
||||||
|
|||||||
@@ -0,0 +1,116 @@
|
|||||||
|
package com.storm.dev.trigger.http;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.storm.dev.api.IRAGService;
|
||||||
|
import com.storm.dev.api.response.Response;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.redisson.api.RList;
|
||||||
|
import org.redisson.api.RedissonClient;
|
||||||
|
import org.springframework.ai.chat.ChatResponse;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||||
|
import org.springframework.ai.document.Document;
|
||||||
|
import org.springframework.ai.ollama.OllamaChatClient;
|
||||||
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
|
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
||||||
|
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||||
|
import org.springframework.ai.vectorstore.PgVectorStore;
|
||||||
|
import org.springframework.ai.vectorstore.SearchRequest;
|
||||||
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author: lyd
|
||||||
|
* @date: 2026/1/14 23:43
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@RestController()
|
||||||
|
@CrossOrigin("*")
|
||||||
|
@RequestMapping("/api/v1/rag/")
|
||||||
|
public class RAGController implements IRAGService {
|
||||||
|
@Resource
|
||||||
|
private RedissonClient redissonClient;
|
||||||
|
@Resource
|
||||||
|
private OllamaChatClient ollamaChatClient;
|
||||||
|
@Resource
|
||||||
|
private TokenTextSplitter tokenTextSplitter;
|
||||||
|
@Resource
|
||||||
|
private SimpleVectorStore simpleVectorStore;
|
||||||
|
@Resource
|
||||||
|
private PgVectorStore pgVectorStore;
|
||||||
|
@Override
|
||||||
|
@RequestMapping(value = "query_rag_tag_list", method = RequestMethod.GET)
|
||||||
|
public Response<List<String>> queryRagTagList() {
|
||||||
|
RList<String> ragTag = redissonClient.getList("ragTag");
|
||||||
|
return Response.<List<String>>builder()
|
||||||
|
.code("0000")
|
||||||
|
.info("调用成功")
|
||||||
|
.data(ragTag)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@RequestMapping(value = "file/upload", method = RequestMethod.POST, headers = "content-type=multipart/form-data")
|
||||||
|
public Response<String> uploadFile(@RequestParam String ragTag, @RequestParam("file") List<MultipartFile> files) {
|
||||||
|
log.info("上传知识库开始 {}", ragTag);
|
||||||
|
for (MultipartFile file : files) {
|
||||||
|
// 上传
|
||||||
|
TikaDocumentReader reader = new TikaDocumentReader(file.getResource());
|
||||||
|
List<Document> documents = reader.get();
|
||||||
|
List<Document> documentSplitterList = tokenTextSplitter.apply(documents);
|
||||||
|
// 打标
|
||||||
|
documents.forEach(document -> document.getMetadata().put("knowledge", ragTag));
|
||||||
|
documentSplitterList.forEach(document -> document.getMetadata().put("knowledge", ragTag));
|
||||||
|
|
||||||
|
pgVectorStore.accept(documentSplitterList);
|
||||||
|
// 可以用MySQL存储
|
||||||
|
RList<String> elements = redissonClient.getList("ragTag");
|
||||||
|
if (!elements.contains(ragTag)){
|
||||||
|
elements.add(ragTag);
|
||||||
|
}
|
||||||
|
log.info("上传完成!");
|
||||||
|
}
|
||||||
|
return Response.<String>builder().code("0000").info("调用成功").build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@RequestMapping(value = "generate_stream_rag", method = RequestMethod.GET)
|
||||||
|
public ChatResponse generateStreamRag(@RequestParam String model, @RequestParam String ragTag, @RequestParam String message) {
|
||||||
|
log.info("用户选择模型:{},知识库:{},提问问题:{}", model, ragTag, message);
|
||||||
|
// 构建推理模板
|
||||||
|
String SYSTEM_PROMPT = """
|
||||||
|
Use the information from the DOCUMENTS section to provide accurate answers but act as if you knew this information innately.
|
||||||
|
If unsure, simply state that you don't know.
|
||||||
|
Another thing you need to note is that your reply must be in Chinese!
|
||||||
|
DOCUMENTS:
|
||||||
|
{documents}
|
||||||
|
""";
|
||||||
|
// 读取向量库信息
|
||||||
|
SearchRequest request = SearchRequest.query(message).withTopK(5).withFilterExpression("knowledge == '" + ragTag + "'");
|
||||||
|
// 相似性搜索
|
||||||
|
List<Document> documents = pgVectorStore.similaritySearch(request);
|
||||||
|
String documentsCollectors = documents.stream().map(Document::getContent).collect(Collectors.joining());
|
||||||
|
|
||||||
|
// 推理:RAG
|
||||||
|
Message ragMessage = new SystemPromptTemplate(SYSTEM_PROMPT).createMessage(Map.of("documents", documentsCollectors));
|
||||||
|
ArrayList<Message> messages = new ArrayList<>();
|
||||||
|
messages.add(new UserMessage(message));
|
||||||
|
messages.add(ragMessage);
|
||||||
|
|
||||||
|
// 提问
|
||||||
|
// Flux<ChatResponse> chatResponse = ollamaChatClient.stream(new Prompt(messages, OllamaOptions.create().withModel(model)));
|
||||||
|
ChatResponse call = ollamaChatClient.call(new Prompt(messages, OllamaOptions.create().withModel(model)));
|
||||||
|
log.info("测试结果:{}", call);
|
||||||
|
return call;
|
||||||
|
}
|
||||||
|
}
|
||||||
1285
docs/nginx/html/rag-ai.html
Normal file
1285
docs/nginx/html/rag-ai.html
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user