Commit b52fd131 authored by linpeiqin's avatar linpeiqin

增加发布到输出的内容

parent f229451e
......@@ -22,6 +22,10 @@ public class PythonConfig {
* 模型训练基础目录
*/
private String modelTuningFileBaseDir;
/**
* 模型训练基础目录
*/
private String modelOutputFileBaseDir;
/**
* 数据集配置文件
*/
......@@ -34,5 +38,10 @@ public class PythonConfig {
* python平台通用接口地址
*/
private String factoryInterface;
/**
* python websocket地址
*/
private String pythonWebsocketUri;
}
......@@ -33,9 +33,7 @@ import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.SimpleDateFormat;
import java.util.List;
import java.util.Optional;
......@@ -80,6 +78,9 @@ public class TuningRunController {
if (errorMessage != null) {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
}
if (tuningRunDto.getTaskId() == null) {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, "请填写任务ID");
}
TuningRun tuningRun = MyModelUtil.copyTo(tuningRunDto, TuningRun.class);
TuningRun tuningRunFilter = new TuningRun();
tuningRunFilter.setTaskId(tuningRun.getTaskId());
......@@ -118,6 +119,12 @@ public class TuningRunController {
return ResponseResult.success();
}
/**
* 获取预览命令。
*
* @param runId 运行ID。
* @return 应答结果对象。
*/
@GetMapping("/getPreviewCommand")
public ResponseResult<String> getPreviewCommand(@RequestParam Long runId) {
TuningRun tuningRun = this.tuningRunService.getById(runId);
......@@ -128,7 +135,6 @@ public class TuningRunController {
JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration());
JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetManage.getDatasetName() + "_V" + datasetVersion.getDatasetVersion());
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd-hh-mm-ss");
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelManage.getModelName());
......@@ -314,8 +320,7 @@ public class TuningRunController {
}
private ResponseEntity<Resource> getFileByUrl(String url) throws IOException {
Path file = Paths.get(url);
Resource resource = new UrlResource(file.toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Resource resource = new UrlResource(Paths.get(url).toUri()); // 使用UrlResource从文件路径构建一个Resource对象
Optional<Resource> optionalResource = Optional.ofNullable(resource);
if (optionalResource.isPresent() && optionalResource.get().exists()) { // 检查文件是否存在
HttpHeaders headers = new HttpHeaders(); // 构建HTTP响应头
......@@ -351,10 +356,7 @@ public class TuningRunController {
private ResponseResult<Void> doDelete(Long runId) {
String errorMessage;
// 验证关联Id的数据合法性
TuningRun originalTuningRun = tuningRunService.getById(runId);
if (originalTuningRun == null) {
// NOTE: 修改下面方括号中的话述
if (tuningRunService.getById(runId) == null) {
errorMessage = "数据验证失败,当前 [对象] 并不存在,请刷新后重试!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
......
......@@ -188,8 +188,6 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
if (modelVersion.getBusinessLabel() == null && modelManage.getBusinessLabel() != null) {
modelVersion.setBusinessLabel(modelManage.getBusinessLabel());
}
modelVersion.setModelVersion(1);
modelVersion.setIsCompress(0);
this.modelVersionService.saveNew(modelVersion);
return reModelManage;
}
......
......@@ -61,7 +61,19 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
@Transactional(rollbackFor = Exception.class)
@Override
public ModelVersion saveNew(ModelVersion modelVersion) {
String modelName = this.modelManageService.getById(modelVersion.getModelId()).getModelName();
ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelVersion.getModelId());
List<ModelVersion> modelVersionList = this.getModelVersionList(modelVersionFilter,"model_Version");
Integer version = 1;
if (modelVersionList != null && modelVersionList.size() == 0) {
version = modelVersionList.get(modelVersionList.size() - 1).getModelVersion() + 1;
}
modelVersion.setModelVersion(version);
modelVersion.setIsCompress(0);
modelVersion.setVersionName(modelName + "_V" + modelVersion.getModelVersion());
modelVersionMapper.insert(this.buildDefaultValue(modelVersion));
//此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!!
ModelTask modelTask = new ModelTask();
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setModelId(modelVersion.getModelId());
......
package com.yice.webadmin.app.service.impl;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page;
......@@ -10,6 +13,7 @@ import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.dao.TuningRunMapper;
import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.ModelManage;
......@@ -19,11 +23,18 @@ import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.service.TuningTaskService;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
/**
......@@ -47,6 +58,8 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Autowired
private IdGeneratorWrapper idGenerator;
@Autowired
private PythonConfig pythonConfig;
/**
* 返回当前Service的主表Mapper对象。
......@@ -159,28 +172,86 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
return resultList;
}
@SneakyThrows
@Transactional(rollbackFor = Exception.class)
@Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) {
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId());
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------");
}
@Override
public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message);
JSONObject receiveJson = (JSONObject) JSON.parse(message);
JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration());
String receiveMsg = receiveJson.getString("msg");
JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelManage.getModelName());
array.add(modelVersion.getModelUrl());
JSONArray ddJson = new JSONArray();
ddJson.add("train_" + tuningRun.getRunId());
array.add(ddJson);
array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate"));
array.add(2);
//路径需要修改,暂时不能确定,后面一条线解决
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + (runPublishDto.getPublishWay() == 0 ? runPublishDto.getModelName() : modelManage.getModelName());
array.add(newModelUrl);
array.add("none");
sendJson.put("data",array);
sendJson.put("event_data","null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("process_completed")) {
this.close();
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManage = new ModelManage();
modelManage.setModelDescribe(runPublishDto.getModelDescribe());
modelManage.setModelName(runPublishDto.getModelName());
modelManage.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManage, new ModelVersion());
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManageS, new ModelVersion());
} else {
ModelVersion modelVersion = new ModelVersion();
modelVersion.setModelId(runPublishDto.getModelId());
List<ModelVersion> modelVersionList = this.modelVersionService.getModelVersionList(modelVersion, "model_version");
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setModelId(runPublishDto.getModelId());
List<ModelVersion> modelVersionList = this.modelVersionService.getModelVersionList(modelVersionS, "model_version");
int lastModelVersion = modelVersionList.get(modelVersionList.size() - 1).getModelVersion();
modelVersion.setModelId(runPublishDto.getModelId());
modelVersion.setModelVersion(lastModelVersion + 1);
ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
modelManage.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManage);
this.modelVersionService.saveNew(modelVersion);
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionS.setModelVersion(lastModelVersion + 1);
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
this.modelVersionService.saveNew(modelVersionS);
}
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
tuningRun.setPublishStatus(1);
return this.updateById(tuningRun);
}
......
......@@ -65,12 +65,16 @@ python:
datasetFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/
#模型训练文件基础路径
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#模型训练文件合并后路径
modelOutputFileBaseDir: /home/linking/llms/models/
#数据集配置信息
datasetInfo: dataset_info.json
#数据集配置目录
datasetFileMenu: lmp_data
#python平台通用接口地址
factoryInterface: http://192.168.0.36:7860/run/predict
#python websocket 服务地址
pythonWebsocketUri: ws://192.168.0.36:7860/queue/join
knowledge:
#知识库通用接口地址
......
......@@ -59,6 +59,15 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>1.5.4</version>
</dependency>
<!-- freemarker 模板引擎模块 -->
<dependency>
<groupId>org.springframework.boot</groupId>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment