瀏覽代碼

支持上传多份文件

lamphua 1 周之前
父節點
當前提交
f1fd168fc3

+ 2
- 1
llm-back/ruoyi-llm/src/main/java/com/ruoyi/web/llm/controller/CmcChatController.java 查看文件

@@ -74,7 +74,8 @@ public class CmcChatController extends BaseController
74 74
     @PostMapping
75 75
     public AjaxResult add(@RequestBody CmcChat cmcChat)
76 76
     {
77
-        cmcChat.setChatId(new SnowFlake().generateId());
77
+        if (cmcChat.getChatId() == null)
78
+            cmcChat.setChatId(new SnowFlake().generateId());
78 79
         return toAjax(cmcChatService.insertCmcChat(cmcChat));
79 80
     }
80 81
 

+ 35
- 1
llm-back/ruoyi-llm/src/main/java/com/ruoyi/web/llm/controller/CmcDocumentController.java 查看文件

@@ -1,8 +1,13 @@
1 1
 package com.ruoyi.web.llm.controller;
2 2
 
3
+import java.io.File;
4
+import java.io.IOException;
5
+import java.util.ArrayList;
3 6
 import java.util.List;
4 7
 import javax.servlet.http.HttpServletResponse;
5 8
 
9
+import com.alibaba.fastjson2.JSONObject;
10
+import com.ruoyi.common.config.RuoYiConfig;
6 11
 import com.ruoyi.common.utils.SnowFlake;
7 12
 import org.springframework.beans.factory.annotation.Autowired;
8 13
 import org.springframework.web.bind.annotation.GetMapping;
@@ -21,6 +26,7 @@ import com.ruoyi.llm.domain.CmcDocument;
21 26
 import com.ruoyi.llm.service.ICmcDocumentService;
22 27
 import com.ruoyi.common.utils.poi.ExcelUtil;
23 28
 import com.ruoyi.common.core.page.TableDataInfo;
29
+import org.springframework.web.multipart.MultipartFile;
24 30
 
25 31
 /**
26 32
  * cmc聊天附件Controller
@@ -67,6 +73,35 @@ public class CmcDocumentController extends BaseController
67 73
         return success(cmcDocumentService.selectCmcDocumentByDocumentId(documentId));
68 74
     }
69 75
 
76
+    /**
77
+     * 上传外部文件
78
+     * @return
79
+     */
80
+    @GetMapping("/upload")
81
+    public JSONObject upload(MultipartFile[] fileList) throws IOException {
82
+        File profilePath = new File( RuoYiConfig.getProfile() + "/upload/rag/document" );
83
+        if (!profilePath.exists())
84
+            profilePath.mkdirs();
85
+        String chatId = new SnowFlake().generateId();
86
+        JSONObject jsonObject = new JSONObject();
87
+        jsonObject.put("chatId", chatId);
88
+        List<String> filenames = new ArrayList<>();
89
+        for (MultipartFile file : fileList) {
90
+            File transferFile = new File(profilePath + File.separator + file.getOriginalFilename());
91
+            if (!transferFile.exists()) {
92
+                file.transferTo(transferFile);
93
+            }
94
+            CmcDocument cmcDocument = new CmcDocument();
95
+            cmcDocument.setDocumentId(new SnowFlake().generateId());
96
+            cmcDocument.setChatId(chatId);
97
+            cmcDocument.setPath(file.getOriginalFilename());
98
+            cmcDocumentService.insertCmcDocument(cmcDocument);
99
+            filenames.add(file.getOriginalFilename());
100
+        }
101
+        jsonObject.put("filenames", filenames);
102
+        return jsonObject;
103
+    }
104
+
70 105
     /**
71 106
      * 新增cmc聊天附件
72 107
      */
@@ -74,7 +109,6 @@ public class CmcDocumentController extends BaseController
74 109
     @PostMapping
75 110
     public AjaxResult add(@RequestBody CmcDocument cmcDocument)
76 111
     {
77
-        cmcDocument.setDocumentId(new SnowFlake().generateId());
78 112
         return toAjax(cmcDocumentService.insertCmcDocument(cmcDocument));
79 113
     }
80 114
 

+ 2
- 30
llm-back/ruoyi-llm/src/main/java/com/ruoyi/web/llm/controller/SessionController.java 查看文件

@@ -1,9 +1,6 @@
1 1
 package com.ruoyi.web.llm.controller;
2 2
 
3
-import com.ruoyi.common.config.RuoYiConfig;
4 3
 import com.ruoyi.common.core.controller.BaseController;
5
-import com.ruoyi.llm.domain.CmcDocument;
6
-import com.ruoyi.llm.service.ICmcDocumentService;
7 4
 import com.ruoyi.web.llm.service.ILangChainMilvusService;
8 5
 import dev.langchain4j.model.embedding.EmbeddingModel;
9 6
 import dev.langchain4j.model.embedding.onnx.bgesmallzhv15.BgeSmallZhV15EmbeddingModel;
@@ -12,10 +9,8 @@ import org.springframework.beans.factory.annotation.Autowired;
12 9
 import org.springframework.web.bind.annotation.GetMapping;
13 10
 import org.springframework.web.bind.annotation.RequestMapping;
14 11
 import org.springframework.web.bind.annotation.RestController;
15
-import org.springframework.web.multipart.MultipartFile;
16 12
 import reactor.core.publisher.Flux;
17 13
 
18
-import java.io.File;
19 14
 import java.io.IOException;
20 15
 
21 16
 /**
@@ -31,9 +26,6 @@ public class SessionController extends BaseController
31 26
     @Autowired
32 27
     private ILangChainMilvusService langChainMilvusService;
33 28
 
34
-    @Autowired
35
-    private ICmcDocumentService cmcDocumentService;
36
-
37 29
     private static final EmbeddingModel embeddingModel = new BgeSmallZhV15EmbeddingModel();
38 30
 
39 31
     /**
@@ -44,33 +36,13 @@ public class SessionController extends BaseController
44 36
         return langChainMilvusService.generateAnswer(topicId, question, "http://192.168.28.188:8000/v1/chat/completions");
45 37
     }
46 38
 
47
-    /**
48
-     * 上传外部文件
49
-     * @return
50
-     */
51
-    @GetMapping("/upload")
52
-    public String upload(MultipartFile file) throws IOException {
53
-        File profilePath = new File( RuoYiConfig.getProfile() + "/upload/rag/document" );
54
-        if (!profilePath.exists())
55
-            profilePath.mkdirs();
56
-        File transferFile = new File( profilePath + File.separator + file.getOriginalFilename());
57
-        if (!transferFile.exists()) {
58
-            file.transferTo(transferFile);
59
-        }
60
-        return file.getOriginalFilename();
61
-    }
62
-
63 39
     /**
64 40
      * 调用LLM+RAG(外部文件)生成回答
65 41
      */
66 42
     @GetMapping("/answerWithDocument")
67
-    public Flux<AssistantMessage> answerWithDocument(String filename, String chatId, String topicId, String question) throws IOException
43
+    public Flux<AssistantMessage> answerWithDocument(String chatId, String question) throws IOException
68 44
     {
69
-        CmcDocument cmcDocument = new CmcDocument();
70
-        cmcDocument.setChatId(chatId);
71
-        cmcDocument.setPath(filename);
72
-        cmcDocumentService.insertCmcDocument(cmcDocument);
73
-        return langChainMilvusService.generateAnswerWithDocument(embeddingModel, filename, topicId, question, "http://192.168.28.188:8000/v1/chat/completions");
45
+        return langChainMilvusService.generateAnswerWithDocument(embeddingModel, chatId, question, "http://192.168.28.188:8000/v1/chat/completions");
74 46
     }
75 47
 
76 48
 }

+ 1
- 1
llm-back/ruoyi-llm/src/main/java/com/ruoyi/web/llm/service/ILangChainMilvusService.java 查看文件

@@ -47,6 +47,6 @@ public interface ILangChainMilvusService {
47 47
      * 调用LLM+RAG(外部文件)生成回答
48 48
      * @return
49 49
      */
50
-    public Flux<AssistantMessage> generateAnswerWithDocument(EmbeddingModel embeddingModel, String filename, String topicId, String question, String llmServiceUrl) throws IOException;
50
+    public Flux<AssistantMessage> generateAnswerWithDocument(EmbeddingModel embeddingModel, String chatId, String question, String llmServiceUrl) throws IOException;
51 51
 
52 52
 }

+ 37
- 30
llm-back/ruoyi-llm/src/main/java/com/ruoyi/web/llm/service/impl/LangChainMilvusServiceImpl.java 查看文件

@@ -3,7 +3,9 @@ package com.ruoyi.web.llm.service.impl;
3 3
 import com.alibaba.fastjson2.JSONObject;
4 4
 import com.ruoyi.common.config.RuoYiConfig;
5 5
 import com.ruoyi.llm.domain.CmcChat;
6
+import com.ruoyi.llm.domain.CmcDocument;
6 7
 import com.ruoyi.llm.service.ICmcChatService;
8
+import com.ruoyi.llm.service.ICmcDocumentService;
7 9
 import com.ruoyi.web.llm.service.ILangChainMilvusService;
8 10
 import dev.langchain4j.data.document.Document;
9 11
 import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
@@ -48,6 +50,9 @@ public class LangChainMilvusServiceImpl implements ILangChainMilvusService
48 50
     @Autowired
49 51
     private ICmcChatService cmcChatService;
50 52
 
53
+    @Autowired
54
+    private ICmcDocumentService cmcDocumentService;
55
+
51 56
     /**
52 57
      * 导入知识库文件
53 58
      */
@@ -146,12 +151,14 @@ public class LangChainMilvusServiceImpl implements ILangChainMilvusService
146 151
                 .apiKey("1")
147 152
                 .build();
148 153
 
149
-        CmcChat cmcChat = new CmcChat();
150
-        cmcChat.setTopicId(topicId);
151
-        List<CmcChat> cmcChatList = cmcChatService.selectCmcChatList(cmcChat);
152
-        for (CmcChat chat : cmcChatList) {
153
-            chatSession.addMessage(ChatMessage.ofUser(chat.getInput()));
154
-            chatSession.addMessage(ChatMessage.ofAssistant(chat.getOutput()));
154
+        if (topicId != null) {
155
+            CmcChat cmcChat = new CmcChat();
156
+            cmcChat.setTopicId(topicId);
157
+            List<CmcChat> cmcChatList = cmcChatService.selectCmcChatList(cmcChat);
158
+            for (CmcChat chat : cmcChatList) {
159
+                chatSession.addMessage(ChatMessage.ofUser(chat.getInput()));
160
+                chatSession.addMessage(ChatMessage.ofAssistant(chat.getOutput()));
161
+            }
155 162
         }
156 163
         chatSession.addMessage(ChatMessage.ofUser(prompt));
157 164
 
@@ -184,26 +191,30 @@ public class LangChainMilvusServiceImpl implements ILangChainMilvusService
184 191
      * 调用LLM生成回答
185 192
      */
186 193
     @Override
187
-    public Flux<AssistantMessage> generateAnswerWithDocument(EmbeddingModel embeddingModel, String filename, String topicId, String question, String llmServiceUrl) throws IOException {
188
-
189
-        File profilePath = new File( RuoYiConfig.getProfile() + "/upload/rag/document/" + filename);
190
-        List<TextSegment> segments = splitDocument(filename, profilePath);
191
-        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
192
-        InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
193
-        embeddingStore.addAll(embeddings, segments);
194
-        Embedding queryEmbedding = embeddingModel.embed(question).content();
195
-        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
196
-                .queryEmbedding(queryEmbedding)
197
-                .maxResults(1)
198
-                .build();
199
-        String contexts = embeddingStore.search(embeddingSearchRequest).matches().get(0).embedded().text();
200
-        String sb = "问题: " + question + "\n\n" +
201
-                "根据以下上下文回答问题:\n\n" +
202
-                "文件" + ": " +
203
-                filename + "\n\n" +
204
-                "上下文" + ": " +
205
-                contexts + "\n\n";
206
-        return generateAnswer(topicId, sb, llmServiceUrl);
194
+    public Flux<AssistantMessage> generateAnswerWithDocument(EmbeddingModel embeddingModel, String chatId, String question, String llmServiceUrl) throws IOException {
195
+        String topicId = cmcChatService.selectCmcChatByChatId(chatId).getTopicId();
196
+        CmcDocument cmcDocument = new CmcDocument();
197
+        cmcDocument.setChatId(chatId);
198
+        List<CmcDocument> documentList = cmcDocumentService.selectCmcDocumentList(cmcDocument);
199
+        StringBuilder sb = new StringBuilder("问题: " + question + "\n\n").append("根据以下上下文回答问题:\n\n");
200
+        for (CmcDocument document : documentList) {
201
+            File profilePath = new File(RuoYiConfig.getProfile() + "/upload/rag/document/" + document.getPath());
202
+            List<TextSegment> segments = splitDocument(document.getPath(), profilePath);
203
+            List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
204
+            InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
205
+            embeddingStore.addAll(embeddings, segments);
206
+            Embedding queryEmbedding = embeddingModel.embed(question).content();
207
+            EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
208
+                    .queryEmbedding(queryEmbedding)
209
+                    .maxResults(1)
210
+                    .build();
211
+            String contexts = embeddingStore.search(embeddingSearchRequest).matches().get(0).embedded().text();
212
+            sb.append("文件").append(": ")
213
+                    .append(document.getPath()).append("\n\n")
214
+                    .append("上下文").append(": ")
215
+                    .append(contexts).append("\n\n");
216
+        }
217
+        return generateAnswer(topicId, sb.toString(), llmServiceUrl);
207 218
     }
208 219
 
209 220
     /**
@@ -270,8 +281,4 @@ public class LangChainMilvusServiceImpl implements ILangChainMilvusService
270 281
         DocumentByParagraphSplitter splitter = new DocumentByParagraphSplitter(1000,200);
271 282
         return splitter.split(document);
272 283
     }
273
-    interface Assistant {
274
-
275
-        String chat(String message);
276
-    }
277 284
 }

+ 117
- 0
llm-back/vllm_server.py 查看文件

@@ -0,0 +1,117 @@
1
from vllm import LLM, SamplingParams
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
import time
import json


def create_vllm_server(
    model: str,
    served_model_name: str,
    host: str,  # unused inside; kept so CONFIG can be splatted as **kwargs
    port: int,  # unused inside; kept so CONFIG can be splatted as **kwargs
    tensor_parallel_size: int,
    top_p: float,
    temperature: float,
    max_tokens: int,
    gpu_memory_utilization: float,
    dtype: str,
) -> FastAPI:
    """Build a FastAPI app exposing an OpenAI-compatible /v1/chat/completions
    endpoint backed by an offline vLLM engine."""
    # Initialize the LLM engine once at startup.
    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        dtype=dtype,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        try:
            data = await request.json()
            messages = data["messages"]
            tools = data.get("tools")  # optional OpenAI-style tools parameter
            # OpenAI spec: "created" is an integer unix timestamp.
            created_time = int(time.time())
            request_id = f"chatcmpl-{created_time}"

            # NOTE(review): llm.chat() blocks until generation is complete, so the
            # "stream" branch below replays finished text rather than streaming
            # tokens as produced; true streaming needs vLLM's async engine.
            outputs = llm.chat(
                messages=messages,
                sampling_params=sampling_params,
                tools=tools,
            )
            if data.get("stream"):
                def generate():
                    full_text = ""
                    for output in outputs:
                        new_text = output.outputs[0].text[len(full_text):]
                        full_text = output.outputs[0].text
                        response_data = {
                            "id": request_id,
                            "model": served_model_name,
                            "created": created_time,
                            "choices": [{
                                "index": 0,
                                "delta": {"content": new_text},
                                "finish_reason": output.outputs[0].finish_reason,
                            }],
                        }
                        yield f"data: {json.dumps(response_data)}\n\n"
                    yield "data: [DONE]\n\n"

                return StreamingResponse(generate(), media_type="text/event-stream")
            return {
                "id": request_id,
                "model": served_model_name,
                "created": created_time,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": outputs[0].outputs[0].text,
                    },
                    "finish_reason": outputs[0].outputs[0].finish_reason,
                }],
            }
        except Exception as e:
            # FastAPI does not interpret a (body, status) tuple as a status
            # override — the old `return {...}, 400` replied with HTTP 200.
            # Use an explicit JSONResponse so clients actually see HTTP 400.
            return JSONResponse(status_code=400, content={"error": str(e)})

    return app


if __name__ == "__main__":
    # Server configuration.
    CONFIG = {
        "model": "/mnt/d/Qwen/Qwen2.5-1.5B-Instruct",
        "served_model_name": "Qwen2.5-1.5B-Instruct",
        # "model": "/mnt/d/Deepseek/DeepSeek-R1-Distill-Qwen-1.5B",
        # "served_model_name": "DeepSeek-R1-Distill-Qwen-1.5B",
        "host": "172.25.231.226",
        "port": 8000,
        "tensor_parallel_size": 1,
        "top_p": 0.9,
        "temperature": 0.7,
        "max_tokens": 8192,
        "gpu_memory_utilization": 0.9,
        "dtype": "float16",
    }

    # Build the application, then serve it.
    app = create_vllm_server(**CONFIG)
    uvicorn.run(
        app,
        host=CONFIG["host"],
        port=CONFIG["port"],
        workers=1,
    )

Loading…
取消
儲存