Commit 0c70924d authored by linpeiqin's avatar linpeiqin

修改断点目录

parent 574942c1
......@@ -18,6 +18,10 @@ public class PythonConfig {
* 数据集基础目录
*/
private String datasetFileBaseDir;
/**
* 模型训练基础目录
*/
private String modelTuningFileBaseDir;
/**
* 数据集配置文件
*/
......
......@@ -21,7 +21,6 @@ import com.yice.webadmin.app.service.KnowledgeManageService;
import com.yice.webadmin.app.service.ProxyPythonService;
import com.yice.webadmin.app.vo.KnowledgeManageVo;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiParam;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
......@@ -64,11 +63,7 @@ public class KnowledgeManageController {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
}
KnowledgeManage knowledgeManage = MyModelUtil.copyTo(knowledgeManageDto, KnowledgeManage.class);
String requestBody = "{\n" +
" \"knowledge_base_name\": \"" + knowledgeManage.getKnowledgeName() + "\",\n" +
" \"vector_store_type\": \"faiss\",\n" +
" \"embed_model\": \"m3e-base\"\n" +
"}";
String requestBody = "{\n" + " \"knowledge_base_name\": \"" + knowledgeManage.getKnowledgeName() + "\",\n" + " \"vector_store_type\": \"faiss\",\n" + " \"embed_model\": \"m3e-base\"\n" + "}";
try {
String result = proxyPythonService.predictPost(knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getCreate(), requestBody);
JSONObject jo = (JSONObject) JSON.parse(result);
......@@ -156,7 +151,7 @@ public class KnowledgeManageController {
}
String data = null;
try {
data = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getSearchDocs(),requestBody);
data = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getSearchDocs(), requestBody);
} catch (IOException e) {
throw new RuntimeException(e);
}
......@@ -176,7 +171,7 @@ public class KnowledgeManageController {
}
String result = null;
try {
result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getDeleteDocs(),requestBody);
result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getDeleteDocs(), requestBody);
JSONObject jo = (JSONObject) JSON.parse(result);
Integer code = jo.getIntValue("code");
String msg = jo.getString("msg");
......@@ -190,6 +185,7 @@ public class KnowledgeManageController {
throw new RuntimeException(e);
}
}
/**
* 更新现有文件到知识库。
*
......@@ -203,7 +199,7 @@ public class KnowledgeManageController {
}
String result = null;
try {
result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getUpdateDocs(),requestBody);
result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getUpdateDocs(), requestBody);
JSONObject jo = (JSONObject) JSON.parse(result);
Integer code = jo.getIntValue("code");
String msg = jo.getString("msg");
......@@ -217,6 +213,7 @@ public class KnowledgeManageController {
throw new RuntimeException(e);
}
}
/**
* 上传文件到知识库,并进行量化。
*
......@@ -230,13 +227,13 @@ public class KnowledgeManageController {
}
String result = null;
try {
result = this.proxyPythonService.predictPostForFile(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getUploadDocs(),requestBody);
result = this.proxyPythonService.predictPostForFile(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getUploadDocs(), requestBody);
JSONObject jo = (JSONObject) JSON.parse(result);
Integer code = jo.getIntValue("code");
String msg = jo.getString("msg");
String data = jo.getString("data");
if (code != null && code == 200) {
return ResponseResult.create(ErrorCodeEnum.NO_ERROR,msg,data);
return ResponseResult.create(ErrorCodeEnum.NO_ERROR, msg, data);
} else {
return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data);
}
......@@ -244,6 +241,7 @@ public class KnowledgeManageController {
throw new RuntimeException(e);
}
}
/**
* 根据content中文档重建向量库,流式输出处理进度。
*
......@@ -256,7 +254,7 @@ public class KnowledgeManageController {
return ResponseResult.error(ErrorCodeEnum.ARGUMENT_NULL_EXIST);
}
try {
String result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getRecreate(),requestBody);
String result = this.proxyPythonService.predictPost(this.knowledgeConfig.getKnowledgeInterface() + knowledgeConfig.getRecreate(), requestBody);
return ResponseResult.success(result);
} catch (IOException e) {
throw new RuntimeException(e);
......@@ -290,10 +288,7 @@ public class KnowledgeManageController {
* @return 应答结果对象,包含查询结果集。
*/
@PostMapping("/list")
public ResponseResult<MyPageData<KnowledgeManageVo>> list(
@MyRequestBody KnowledgeManageDto knowledgeManageDtoFilter,
@MyRequestBody MyOrderParam orderParam,
@MyRequestBody MyPageParam pageParam) {
public ResponseResult<MyPageData<KnowledgeManageVo>> list(@MyRequestBody KnowledgeManageDto knowledgeManageDtoFilter, @MyRequestBody MyOrderParam orderParam, @MyRequestBody MyPageParam pageParam) {
if (pageParam != null) {
PageMethod.startPage(pageParam.getPageNum(), pageParam.getPageSize());
}
......@@ -321,6 +316,37 @@ public class KnowledgeManageController {
return ResponseResult.success(MyPageUtil.makeResponseData(reKnowledgeManageList, KnowledgeManage.INSTANCE));
}
/**
* 列出符合过滤条件的知识库管理列表。
*
* @param knowledgeManageDtoFilter 过滤对象。
* @return 应答结果对象,包含查询结果集。
*/
@PostMapping("/listForTree")
public ResponseResult<List<KnowledgeManageVo>> listForTree(@MyRequestBody KnowledgeManageDto knowledgeManageDtoFilter) {
KnowledgeManage knowledgeManageFilter = MyModelUtil.copyTo(knowledgeManageDtoFilter, KnowledgeManage.class);
List<KnowledgeManage> knowledgeManageList = knowledgeManageService.getKnowledgeManageListWithRelation(knowledgeManageFilter, "");
List<KnowledgeManage> reKnowledgeManageList = new ArrayList<>();
try {
String result = this.proxyPythonService.predictGet(this.knowledgeConfig.getKnowledgeInterface() + this.knowledgeConfig.getList(), "");
JSONObject jo = (JSONObject) JSON.parse(result);
Integer code = jo.getIntValue("code");
JSONArray jsonArray = jo.getJSONArray("data");
if (code != null && code == 200) {
for (KnowledgeManage knowledgeManage : knowledgeManageList) {
for (Object jsonObject : jsonArray) {
if (knowledgeManage.getKnowledgeName().equals(String.valueOf(jsonObject))) {
reKnowledgeManageList.add(knowledgeManage);
}
}
}
}
} catch (IOException e) {
throw new RuntimeException(e);
}
return ResponseResult.success(KnowledgeManage.INSTANCE.fromModelList(reKnowledgeManageList));
}
/**
* 查看指定知识库管理对象详情。
*
......
......@@ -9,14 +9,6 @@ import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.service.ProxyPythonService;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
......@@ -31,15 +23,15 @@ import java.io.IOException;
@Api(tags = "python代理接口")
@Slf4j
@RestController
@RequestMapping("/admin/app/python")
public class ProxyController {
@RequestMapping("/admin/app/proxy")
public class ProxyPythonController {
@Autowired
private PythonConfig pythonConfig;
@Autowired
private ProxyPythonService proxyPythonService;
private final WebClient webClient;
public ProxyController(WebClient webClient) {
public ProxyPythonController(WebClient webClient) {
this.webClient = webClient;
}
......
package com.yice.webadmin.app.controller;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yice.common.core.object.ResponseResult;
import com.yice.common.log.annotation.OperationLog;
import com.yice.common.log.model.constant.SysOperationLogType;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.service.ProxyPythonService;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import java.io.IOException;
@Api(tags = "接收python消息接口")
@Slf4j
@RestController
@RequestMapping("/admin/app/receive")
public class ReceivePythonController {
@Autowired
private PythonConfig pythonConfig;
@Autowired
private ProxyPythonService proxyPythonService;
private final WebClient webClient;
public ReceivePythonController(WebClient webClient) {
this.webClient = webClient;
}
@OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/originalPredict")
public Mono<String> originalPredict(@RequestBody String requestBody) throws JsonProcessingException {
return webClient.post()
.uri(this.pythonConfig.getFactoryInterface())
.contentType(MediaType.APPLICATION_JSON)
.bodyValue(new ObjectMapper().readTree(requestBody))
.retrieve()
.bodyToMono(String.class);
}
@OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/predict")
public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException {
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(),requestBody));
}
}
......@@ -3,6 +3,8 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody;
......@@ -22,12 +24,23 @@ import com.yice.webadmin.app.vo.TuningRunVo;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.Optional;
/**
* 精调任务运行操作控制器类。
......@@ -247,6 +260,52 @@ public class TuningRunController {
return ResponseResult.success(tuningRunVo);
}
/**
* 获取运行的日志。
*
* @param runId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getLog")
public ResponseEntity<Resource> getLog(@RequestParam Long runId) throws IOException {
TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
String url = this.pythonConfig.getModelTuningFileBaseDir() + modelVersion.getVersionName() + File.separator + tuningRun.getTrainMethod() + File.separator +"train_" + tuningRun.getRunId() + File.separator + "run_train.log";
return this.getFileByUrl(url);
}
/**
* 获取运行状态。
*
* @param runId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getStatus")
public ResponseResult<ArrayNode> getStatus(@RequestParam Long runId) throws IOException {
TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
String url = this.pythonConfig.getModelTuningFileBaseDir() + modelVersion.getVersionName() + File.separator + tuningRun.getTrainMethod() + File.separator +"train_" + tuningRun.getRunId() + File.separator + "trainer_state.json";
File file = new File(url); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
ArrayNode arrayNode = (ArrayNode) objectMapper.readTree(file);
return ResponseResult.success(arrayNode);
}
private ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Path file = Paths.get(url);
Resource resource = new UrlResource(file.toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + resource.getFilename() + "\""); // 设置内容类型为附件,并指定文件名
headers.setContentType(MediaType.TEXT_PLAIN); // 设置内容类型为JSON
return ResponseEntity.ok() // 返回200 OK状态码
.headers(headers) // 设置响应头
.contentLength(resource.contentLength()) // 设置响应内容长度(可选)
.body(resource); // 将Resource对象作为响应的主体返回
} else {
return ResponseEntity.notFound().build(); // 如果文件不存在,返回404 Not Found状态码
}
}
/**
* 查看指定精调任务运行新版本对象详情。
*
......
......@@ -63,6 +63,8 @@ application:
python:
#数据集文件基础路径
datasetFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/
#模型训练文件基础路径
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#数据集配置信息
datasetInfo: dataset_info.json
#数据集配置目录
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment