Commit 39cf978d authored by linpeiqin's avatar linpeiqin

调整预览命令,增加日志,状态明细,状态,同步状态接口灯

parent 830cbcf6
......@@ -26,6 +26,10 @@ public class PythonConfig {
* 模型训练基础目录
*/
private String modelOutputFileBaseDir;
/**
* 模型训练基础目录
*/
private String modelEstimateFileBaseDir;
/**
* 数据集配置文件
*/
......
......@@ -3,9 +3,11 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody;
import com.yice.common.core.annotation.NoAuthInterface;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.*;
import com.yice.common.core.util.MyCommonUtil;
......@@ -18,12 +20,17 @@ import com.yice.webadmin.app.dto.ModelEstimateDto;
import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.util.FileUtil;
import com.yice.webadmin.app.vo.ModelEstimateVo;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.util.List;
/**
......@@ -47,9 +54,7 @@ public class ModelEstimateController {
@Autowired
private DatasetVersionService datasetVersionService;
@Autowired
private DatasetManageService datasetManageService;
@Autowired
private ModelManageService modelManageService;
private TuningRunService tuningRunService;
@Autowired
private PythonConfig pythonConfig;
......@@ -95,6 +100,7 @@ public class ModelEstimateController {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId());
TuningRun tuningRun = this.tuningRunService.getById(modelVersion.getRunId());
JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration());
JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName());
......@@ -102,7 +108,7 @@ public class ModelEstimateController {
array.add("zh");
array.add(modelVersion.getVersionName());
array.add(modelVersion.getModelUrl());
array.add("");
array.add(tuningRun.getTrainMethod());
array.add(new JSONArray());
array.add(jsonObject.get("quantizationLevel"));
array.add(jsonObject.get("promptTemplate"));
......@@ -119,9 +125,81 @@ public class ModelEstimateController {
array.add(jsonObject.get("maximumGeneratingLength"));
array.add(jsonObject.get("ToppSamplingValue"));
array.add(jsonObject.get("temperatureCoefficient"));
array.add(this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId);
System.out.println(array.toJSONString());
return ResponseResult.success(array.toJSONString());
}
/**
* 提交任务状态信息。
*
* @return 应答结果对象,包含查询结果集。
*/
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postTaskStatus")
public ResponseResult<Void> postTaskStatus(@MyRequestBody Long taskId,
@MyRequestBody Long runTime,
@MyRequestBody Integer taskStatus) {
String errorMessage;
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
if (modelEstimate == null) {
errorMessage = "数据验证失败,当前评估任务并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
modelEstimate.setTaskStatus(taskStatus);
if (!modelEstimateService.updateById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/**
* 获取评估任务的日志。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getLog")
public ResponseEntity<Resource> getLog(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "run_eval.log";
return FileUtil.getFileByUrl(url);
}
/**
* 获取任务状态。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getStatus")
public ResponseResult<String> getStatus(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "all_results.json";
File file = new File(url); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
return ResponseResult.success(objectMapper.readTree(file).toString());
}
/**
* 获取任务状态明细。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getStatusDetail")
public ResponseResult<String> getStatusDetail(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "generated_predictions.jsonl";
File file = new File(url); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
return ResponseResult.success(objectMapper.readTree(file).toString());
}
/**
* 修改模型评估数据,及其关联的从表数据。
*
......
......@@ -20,6 +20,7 @@ import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.dto.TuningRunDto;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.util.FileUtil;
import com.yice.webadmin.app.vo.TuningRunVo;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
......@@ -82,13 +83,6 @@ public class TuningRunController {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, "请填写任务ID");
}
TuningRun tuningRun = MyModelUtil.copyTo(tuningRunDto, TuningRun.class);
TuningRun tuningRunFilter = new TuningRun();
tuningRunFilter.setTaskId(tuningRun.getTaskId());
List<TuningRun> reTuningRunList = this.tuningRunService.getTuningRunList(tuningRunFilter, "run_version");
TuningTask tuningTask = this.tuningTaskService.getById(tuningRun.getTaskId());
Integer lastRunVersion = reTuningRunList.get(reTuningRunList.size() - 1).getRunVersion();
tuningRun.setRunVersion(lastRunVersion + 1);
tuningRun.setRunName(tuningTask.getTaskName() + " V" + (lastRunVersion + 1));
tuningRun = tuningRunService.saveNew(tuningRun);
return ResponseResult.success(tuningRun.getRunId());
}
......@@ -298,7 +292,7 @@ public class TuningRunController {
TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
String url = this.pythonConfig.getModelTuningFileBaseDir() + modelVersion.getVersionName() + File.separator + tuningRun.getTrainMethod() + File.separator + "train_" + tuningRun.getRunId() + File.separator + "run_train.log";
return this.getFileByUrl(url);
return FileUtil.getFileByUrl(url);
}
/**
......@@ -317,21 +311,6 @@ public class TuningRunController {
return ResponseResult.success(objectMapper.readTree(file).toString());
}
private ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Resource resource = new UrlResource(Paths.get(url).toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + resource.getFilename() + "\""); // 设置内容类型为附件,并指定文件名
headers.setContentType(MediaType.TEXT_PLAIN); // 设置内容类型为JSON
return ResponseEntity.ok() // 返回200 OK状态码
.headers(headers) // 设置响应头
.contentLength(resource.contentLength()) // 设置响应内容长度(可选)
.body(resource); // 将Resource对象作为响应的主体返回
} else {
return ResponseEntity.notFound().build(); // 如果文件不存在,返回404 Not Found状态码
}
}
/**
* 查看指定精调任务运行新版本对象详情。
......
......@@ -19,6 +19,7 @@ import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.TuningRun;
import com.yice.webadmin.app.model.TuningTask;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService;
......@@ -29,12 +30,10 @@ import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
/**
......@@ -81,6 +80,16 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Transactional(rollbackFor = Exception.class)
@Override
public TuningRun saveNew(TuningRun tuningRun) {
TuningRun tuningRunFilter = new TuningRun();
tuningRunFilter.setTaskId(tuningRun.getTaskId());
List<TuningRun> reTuningRunList = this.getTuningRunList(tuningRunFilter, "run_version");
TuningTask tuningTask = this.tuningTaskService.getById(tuningRun.getTaskId());
Integer version = 1;
if (reTuningRunList != null && reTuningRunList.size() != 0) {
version = reTuningRunList.get(reTuningRunList.size() - 1).getRunVersion() + 1;
}
tuningRun.setRunVersion(version);
tuningRun.setRunName(tuningTask.getTaskName() + "_V" + tuningRun.getRunVersion());
tuningRunMapper.insert(this.buildDefaultValue(tuningRun));
return tuningRun;
}
......@@ -202,6 +211,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------");
}
@Override
public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message);
......@@ -227,12 +237,11 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate"));
array.add(2);
//路径需要修改,暂时不能确定,后面一条线解决
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + modelName;
array.add(newModelUrl);
array.add("none");
sendJson.put("data",array);
sendJson.put("event_data","null");
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
System.out.println(array.toJSONString());
......@@ -243,10 +252,12 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
this.close();
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
......
......@@ -145,8 +145,6 @@ public class TuningTaskServiceImpl extends BaseService<TuningTask, Long> impleme
public TuningTask saveAndCreateVersion(TuningTask tuningTask, TuningRun tuningRun) {
TuningTask reTuningTask = this.saveNew(tuningTask);
tuningRun.setTaskId(reTuningTask.getTaskId());
tuningRun.setRunVersion(1);
tuningRun.setRunName(tuningTask.getTaskName() + " V" + tuningRun.getRunVersion());
this.tuningRunService.saveNew(tuningRun);
return reTuningTask;
}
......
package com.yice.webadmin.app.util;
import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Optional;
public class FileUtil {
public static ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Resource resource = new UrlResource(Paths.get(url).toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + resource.getFilename() + "\""); // 设置内容类型为附件,并指定文件名
headers.setContentType(MediaType.TEXT_PLAIN); // 设置内容类型为JSON
return ResponseEntity.ok() // 返回200 OK状态码
.headers(headers) // 设置响应头
.contentLength(resource.contentLength()) // 设置响应内容长度(可选)
.body(resource); // 将Resource对象作为响应的主体返回
} else {
return ResponseEntity.notFound().build(); // 如果文件不存在,返回404 Not Found状态码
}
}
}
......@@ -67,6 +67,8 @@ python:
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#模型训练文件合并后路径
modelOutputFileBaseDir: /home/linking/llms/models/
#模型评估文件基础路径
modelEstimateFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#数据集配置信息
datasetInfo: dataset_info.json
#数据集配置目录
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment