Commit 39cf978d authored by linpeiqin's avatar linpeiqin

调整预览命令,增加日志,状态明细,状态,同步状态接口灯

parent 830cbcf6
...@@ -26,6 +26,10 @@ public class PythonConfig { ...@@ -26,6 +26,10 @@ public class PythonConfig {
* 模型训练基础目录 * 模型训练基础目录
*/ */
private String modelOutputFileBaseDir; private String modelOutputFileBaseDir;
/**
* 模型训练基础目录
*/
private String modelEstimateFileBaseDir;
/** /**
* 数据集配置文件 * 数据集配置文件
*/ */
......
...@@ -3,9 +3,11 @@ package com.yice.webadmin.app.controller; ...@@ -3,9 +3,11 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.pagehelper.page.PageMethod; import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport; import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody; import com.yice.common.core.annotation.MyRequestBody;
import com.yice.common.core.annotation.NoAuthInterface;
import com.yice.common.core.constant.ErrorCodeEnum; import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.*; import com.yice.common.core.object.*;
import com.yice.common.core.util.MyCommonUtil; import com.yice.common.core.util.MyCommonUtil;
...@@ -18,12 +20,17 @@ import com.yice.webadmin.app.dto.ModelEstimateDto; ...@@ -18,12 +20,17 @@ import com.yice.webadmin.app.dto.ModelEstimateDto;
import com.yice.webadmin.app.dto.ModelTaskDto; import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.model.*; import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*; import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.util.FileUtil;
import com.yice.webadmin.app.vo.ModelEstimateVo; import com.yice.webadmin.app.vo.ModelEstimateVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.util.List; import java.util.List;
/** /**
...@@ -47,9 +54,7 @@ public class ModelEstimateController { ...@@ -47,9 +54,7 @@ public class ModelEstimateController {
@Autowired @Autowired
private DatasetVersionService datasetVersionService; private DatasetVersionService datasetVersionService;
@Autowired @Autowired
private DatasetManageService datasetManageService; private TuningRunService tuningRunService;
@Autowired
private ModelManageService modelManageService;
@Autowired @Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
...@@ -95,6 +100,7 @@ public class ModelEstimateController { ...@@ -95,6 +100,7 @@ public class ModelEstimateController {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId); ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId()); DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId());
TuningRun tuningRun = this.tuningRunService.getById(modelVersion.getRunId());
JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration()); JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration());
JSONArray datasetVersionNames = new JSONArray(); JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName()); datasetVersionNames.add(datasetVersion.getVersionName());
...@@ -102,7 +108,7 @@ public class ModelEstimateController { ...@@ -102,7 +108,7 @@ public class ModelEstimateController {
array.add("zh"); array.add("zh");
array.add(modelVersion.getVersionName()); array.add(modelVersion.getVersionName());
array.add(modelVersion.getModelUrl()); array.add(modelVersion.getModelUrl());
array.add(""); array.add(tuningRun.getTrainMethod());
array.add(new JSONArray()); array.add(new JSONArray());
array.add(jsonObject.get("quantizationLevel")); array.add(jsonObject.get("quantizationLevel"));
array.add(jsonObject.get("promptTemplate")); array.add(jsonObject.get("promptTemplate"));
...@@ -119,9 +125,81 @@ public class ModelEstimateController { ...@@ -119,9 +125,81 @@ public class ModelEstimateController {
array.add(jsonObject.get("maximumGeneratingLength")); array.add(jsonObject.get("maximumGeneratingLength"));
array.add(jsonObject.get("ToppSamplingValue")); array.add(jsonObject.get("ToppSamplingValue"));
array.add(jsonObject.get("temperatureCoefficient")); array.add(jsonObject.get("temperatureCoefficient"));
array.add(this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId);
System.out.println(array.toJSONString()); System.out.println(array.toJSONString());
return ResponseResult.success(array.toJSONString()); return ResponseResult.success(array.toJSONString());
} }
/**
* 提交任务状态信息。
*
* @return 应答结果对象,包含查询结果集。
*/
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postTaskStatus")
public ResponseResult<Void> postTaskStatus(@MyRequestBody Long taskId,
@MyRequestBody Long runTime,
@MyRequestBody Integer taskStatus) {
String errorMessage;
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
if (modelEstimate == null) {
errorMessage = "数据验证失败,当前评估任务并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
modelEstimate.setTaskStatus(taskStatus);
if (!modelEstimateService.updateById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/**
* 获取评估任务的日志。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getLog")
public ResponseEntity<Resource> getLog(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "run_eval.log";
return FileUtil.getFileByUrl(url);
}
/**
* 获取任务状态。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getStatus")
public ResponseResult<String> getStatus(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "all_results.json";
File file = new File(url); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
return ResponseResult.success(objectMapper.readTree(file).toString());
}
/**
* 获取任务状态明细。
*
* @param taskId 指定对象主键Id。
* @return 应答结果对象,包含对象详情。
*/
@GetMapping("/getStatusDetail")
public ResponseResult<String> getStatusDetail(@RequestParam Long taskId) throws IOException {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
String url = this.pythonConfig.getModelEstimateFileBaseDir() + File.separator + modelVersion.getModelUrl() + File.separator + "evl_" + taskId + File.separator + "generated_predictions.jsonl";
File file = new File(url); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
return ResponseResult.success(objectMapper.readTree(file).toString());
}
/** /**
* 修改模型评估数据,及其关联的从表数据。 * 修改模型评估数据,及其关联的从表数据。
* *
......
...@@ -20,6 +20,7 @@ import com.yice.webadmin.app.dto.RunPublishDto; ...@@ -20,6 +20,7 @@ import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.dto.TuningRunDto; import com.yice.webadmin.app.dto.TuningRunDto;
import com.yice.webadmin.app.model.*; import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*; import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.util.FileUtil;
import com.yice.webadmin.app.vo.TuningRunVo; import com.yice.webadmin.app.vo.TuningRunVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
...@@ -82,13 +83,6 @@ public class TuningRunController { ...@@ -82,13 +83,6 @@ public class TuningRunController {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, "请填写任务ID"); return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, "请填写任务ID");
} }
TuningRun tuningRun = MyModelUtil.copyTo(tuningRunDto, TuningRun.class); TuningRun tuningRun = MyModelUtil.copyTo(tuningRunDto, TuningRun.class);
TuningRun tuningRunFilter = new TuningRun();
tuningRunFilter.setTaskId(tuningRun.getTaskId());
List<TuningRun> reTuningRunList = this.tuningRunService.getTuningRunList(tuningRunFilter, "run_version");
TuningTask tuningTask = this.tuningTaskService.getById(tuningRun.getTaskId());
Integer lastRunVersion = reTuningRunList.get(reTuningRunList.size() - 1).getRunVersion();
tuningRun.setRunVersion(lastRunVersion + 1);
tuningRun.setRunName(tuningTask.getTaskName() + " V" + (lastRunVersion + 1));
tuningRun = tuningRunService.saveNew(tuningRun); tuningRun = tuningRunService.saveNew(tuningRun);
return ResponseResult.success(tuningRun.getRunId()); return ResponseResult.success(tuningRun.getRunId());
} }
...@@ -298,7 +292,7 @@ public class TuningRunController { ...@@ -298,7 +292,7 @@ public class TuningRunController {
TuningRun tuningRun = this.tuningRunService.getById(runId); TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
String url = this.pythonConfig.getModelTuningFileBaseDir() + modelVersion.getVersionName() + File.separator + tuningRun.getTrainMethod() + File.separator + "train_" + tuningRun.getRunId() + File.separator + "run_train.log"; String url = this.pythonConfig.getModelTuningFileBaseDir() + modelVersion.getVersionName() + File.separator + tuningRun.getTrainMethod() + File.separator + "train_" + tuningRun.getRunId() + File.separator + "run_train.log";
return this.getFileByUrl(url); return FileUtil.getFileByUrl(url);
} }
/** /**
...@@ -317,21 +311,6 @@ public class TuningRunController { ...@@ -317,21 +311,6 @@ public class TuningRunController {
return ResponseResult.success(objectMapper.readTree(file).toString()); return ResponseResult.success(objectMapper.readTree(file).toString());
} }
private ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Resource resource = new UrlResource(Paths.get(url).toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + resource.getFilename() + "\""); // 设置内容类型为附件,并指定文件名
headers.setContentType(MediaType.TEXT_PLAIN); // 设置内容类型为JSON
return ResponseEntity.ok() // 返回200 OK状态码
.headers(headers) // 设置响应头
.contentLength(resource.contentLength()) // 设置响应内容长度(可选)
.body(resource); // 将Resource对象作为响应的主体返回
} else {
return ResponseEntity.notFound().build(); // 如果文件不存在,返回404 Not Found状态码
}
}
/** /**
* 查看指定精调任务运行新版本对象详情。 * 查看指定精调任务运行新版本对象详情。
......
...@@ -19,6 +19,7 @@ import com.yice.webadmin.app.dto.RunPublishDto; ...@@ -19,6 +19,7 @@ import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.ModelManage; import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelVersion; import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.TuningRun; import com.yice.webadmin.app.model.TuningRun;
import com.yice.webadmin.app.model.TuningTask;
import com.yice.webadmin.app.service.ModelManageService; import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService; import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService; import com.yice.webadmin.app.service.TuningRunService;
...@@ -29,12 +30,10 @@ import org.java_websocket.client.WebSocketClient; ...@@ -29,12 +30,10 @@ import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455; import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake; import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException;
import java.util.List; import java.util.List;
/** /**
...@@ -81,6 +80,16 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -81,6 +80,16 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
@Override @Override
public TuningRun saveNew(TuningRun tuningRun) { public TuningRun saveNew(TuningRun tuningRun) {
TuningRun tuningRunFilter = new TuningRun();
tuningRunFilter.setTaskId(tuningRun.getTaskId());
List<TuningRun> reTuningRunList = this.getTuningRunList(tuningRunFilter, "run_version");
TuningTask tuningTask = this.tuningTaskService.getById(tuningRun.getTaskId());
Integer version = 1;
if (reTuningRunList != null && reTuningRunList.size() != 0) {
version = reTuningRunList.get(reTuningRunList.size() - 1).getRunVersion() + 1;
}
tuningRun.setRunVersion(version);
tuningRun.setRunName(tuningTask.getTaskName() + "_V" + tuningRun.getRunVersion());
tuningRunMapper.insert(this.buildDefaultValue(tuningRun)); tuningRunMapper.insert(this.buildDefaultValue(tuningRun));
return tuningRun; return tuningRun;
} }
...@@ -202,6 +211,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -202,6 +211,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
public void onOpen(ServerHandshake serverHandshake) { public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------"); log.info("-------------与大模型建立连接-------------");
} }
@Override @Override
public void onMessage(String message) { public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message); log.info("收到来自服务端的消息:" + message);
...@@ -227,12 +237,11 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -227,12 +237,11 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod()); array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate")); array.add(jsonObject.get("promptTemplate"));
array.add(2); array.add(2);
//路径需要修改,暂时不能确定,后面一条线解决
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + modelName; String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + modelName;
array.add(newModelUrl); array.add(newModelUrl);
array.add("none"); array.add("none");
sendJson.put("data",array); sendJson.put("data", array);
sendJson.put("event_data","null"); sendJson.put("event_data", "null");
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString()); sendJson.put("session_hash", runPublishDto.getRunId().toString());
System.out.println(array.toJSONString()); System.out.println(array.toJSONString());
...@@ -243,10 +252,12 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -243,10 +252,12 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
this.close(); this.close();
} }
} }
@Override @Override
public void onClose(int i, String s, boolean b) { public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b); log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
} }
@Override @Override
public void onError(Exception e) { public void onError(Exception e) {
log.error("报错了:::" + e.getMessage()); log.error("报错了:::" + e.getMessage());
......
...@@ -145,8 +145,6 @@ public class TuningTaskServiceImpl extends BaseService<TuningTask, Long> impleme ...@@ -145,8 +145,6 @@ public class TuningTaskServiceImpl extends BaseService<TuningTask, Long> impleme
public TuningTask saveAndCreateVersion(TuningTask tuningTask, TuningRun tuningRun) { public TuningTask saveAndCreateVersion(TuningTask tuningTask, TuningRun tuningRun) {
TuningTask reTuningTask = this.saveNew(tuningTask); TuningTask reTuningTask = this.saveNew(tuningTask);
tuningRun.setTaskId(reTuningTask.getTaskId()); tuningRun.setTaskId(reTuningTask.getTaskId());
tuningRun.setRunVersion(1);
tuningRun.setRunName(tuningTask.getTaskName() + " V" + tuningRun.getRunVersion());
this.tuningRunService.saveNew(tuningRun); this.tuningRunService.saveNew(tuningRun);
return reTuningTask; return reTuningTask;
} }
......
package com.yice.webadmin.app.util;
import org.springframework.core.io.Resource;
import org.springframework.core.io.UrlResource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Optional;
public class FileUtil {
public static ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Resource resource = new UrlResource(Paths.get(url).toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + resource.getFilename() + "\""); // 设置内容类型为附件,并指定文件名
headers.setContentType(MediaType.TEXT_PLAIN); // 设置内容类型为JSON
return ResponseEntity.ok() // 返回200 OK状态码
.headers(headers) // 设置响应头
.contentLength(resource.contentLength()) // 设置响应内容长度(可选)
.body(resource); // 将Resource对象作为响应的主体返回
} else {
return ResponseEntity.notFound().build(); // 如果文件不存在,返回404 Not Found状态码
}
}
}
...@@ -67,6 +67,8 @@ python: ...@@ -67,6 +67,8 @@ python:
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/ modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#模型训练文件合并后路径 #模型训练文件合并后路径
modelOutputFileBaseDir: /home/linking/llms/models/ modelOutputFileBaseDir: /home/linking/llms/models/
#模型评估文件基础路径
modelEstimateFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#数据集配置信息 #数据集配置信息
datasetInfo: dataset_info.json datasetInfo: dataset_info.json
#数据集配置目录 #数据集配置目录
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment