大语言模型

vllm_server.py 3.6KB

from vllm import LLM, SamplingParams
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
import time
import json


def create_vllm_server(
    model: str,
    served_model_name: str,
    host: str,
    port: int,
    tensor_parallel_size: int,
    top_p: float,
    temperature: float,
    max_tokens: int,
    gpu_memory_utilization: float,
    dtype: str,
) -> FastAPI:
    # Initialize the offline LLM engine only; sampling parameters are shared by all requests.
    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        dtype=dtype,
    )
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        try:
            data = await request.json()
            messages = data["messages"]
            tools = data.get("tools")  # pass through the optional OpenAI-style tools parameter
            created_time = time.time()
            request_id = f"chatcmpl-{int(time.time())}"

            # Call llm.chat() with the tools. Note that llm.chat() is a blocking
            # offline call that returns finished outputs; it does not stream tokens.
            outputs = llm.chat(
                messages=messages,
                sampling_params=sampling_params,
                tools=tools,
            )

            if data.get("stream"):
                # Emulate SSE streaming: each RequestOutput is already complete,
                # so the whole completion is emitted as a single delta chunk.
                def generate():
                    full_text = ""
                    for output in outputs:
                        new_text = output.outputs[0].text[len(full_text):]
                        full_text = output.outputs[0].text
                        response_data = {
                            "id": request_id,
                            "model": served_model_name,
                            "created": created_time,
                            "choices": [{
                                "index": 0,
                                "delta": {"content": new_text},
                                "finish_reason": output.outputs[0].finish_reason,
                            }],
                        }
                        yield f"data: {json.dumps(response_data)}\n\n"
                    yield "data: [DONE]\n\n"

                return StreamingResponse(generate(), media_type="text/event-stream")
            else:
                return {
                    "id": request_id,
                    "model": served_model_name,
                    "created": created_time,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": outputs[0].outputs[0].text,
                        },
                        "finish_reason": outputs[0].outputs[0].finish_reason,
                    }],
                }
        except Exception as e:
            # Returning a (dict, status) tuple does not set the HTTP status in FastAPI;
            # return a JSONResponse with an explicit status code instead.
            return JSONResponse(content={"error": str(e)}, status_code=400)

    return app


if __name__ == "__main__":
    # Server configuration
    CONFIG = {
        "model": "/mnt/d/Qwen/Qwen2.5-1.5B-Instruct",
        "served_model_name": "Qwen2.5-1.5B-Instruct",
        # "model": "/mnt/d/Deepseek/DeepSeek-R1-Distill-Qwen-1.5B",
        # "served_model_name": "DeepSeek-R1-Distill-Qwen-1.5B",
        "host": "172.25.231.226",
        "port": 8000,
        "tensor_parallel_size": 1,
        "top_p": 0.9,
        "temperature": 0.7,
        "max_tokens": 8192,
        "gpu_memory_utilization": 0.9,
        "dtype": "float16",
    }

    # Create the application
    app = create_vllm_server(**CONFIG)

    # Start the server (single worker, since the LLM engine lives in this process)
    uvicorn.run(
        app,
        host=CONFIG["host"],
        port=CONFIG["port"],
        workers=1,
    )
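For a quick smoke test of the endpoint, a plain HTTP client is enough. The sketch below is not part of vllm_server.py; it assumes the requests library is installed and that the server is running at the host/port from CONFIG above, and it shows one non-streaming call and one "streaming" call against /v1/chat/completions.

# Minimal client sketch (assumption: server reachable at the CONFIG host/port).
import json
import requests

BASE_URL = "http://172.25.231.226:8000/v1/chat/completions"
messages = [{"role": "user", "content": "Give me a one-sentence summary of vLLM."}]

# Non-streaming request: the reply arrives as a single OpenAI-style JSON object.
resp = requests.post(BASE_URL, json={"messages": messages})
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# "Streaming" request: the server sends SSE lines ("data: {...}"), although with
# the blocking llm.chat() backend the full completion arrives in one chunk.
with requests.post(BASE_URL, json={"messages": messages, "stream": True}, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
    print()

Because the server exposes the OpenAI chat-completions path, an OpenAI-compatible client pointed at http://172.25.231.226:8000/v1 should also work for the non-streaming case.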