Commit 7297ac91 authored by linpeiqin's avatar linpeiqin

模型压缩逻辑初步提交

parent f477a48f
......@@ -5,6 +5,7 @@ import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelTask;
import java.net.URISyntaxException;
import java.util.List;
/**
......@@ -37,7 +38,7 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @param relationData 全部关联从表数据。
* @return 返回新增主表对象。
*/
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData);
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException;
/**
* 更新数据对象。
......@@ -87,4 +88,5 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @return 查询结果集。
*/
List<ModelCompress> getModelCompressListWithRelation(ModelCompress filter, ModelTask modelTaskFilter, String orderBy);
}
package com.yice.webadmin.app.service.impl;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page;
......@@ -9,20 +11,23 @@ import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.dao.ModelCompressMapper;
import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
/**
......@@ -45,6 +50,10 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
private ModelVersionService modelVersionService;
@Autowired
private ModelManageService modelManageService;
@Autowired
private PythonConfig pythonConfig;
private String modelVersionURl;
/**
* 返回当前Service的主表Mapper对象。
......@@ -83,29 +92,89 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
}
}
@Transactional(rollbackFor = Exception.class)
@Override
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) {
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException {
this.saveNew(modelCompress);
this.saveOrUpdateOneToOneRelationData(modelCompress, relationData);
ModelVersion modelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
if (modelCompress.getCreateMethod() == 0) {
ModelVersion modelVersion = new ModelVersion();
modelVersion.setModelId(modelCompress.getTargetModelId());
List<ModelVersion> modelVersionList = this.modelVersionService.getModelVersionList(modelVersion, "model_version");
int lastModelVersion = modelVersionList.get(modelVersionList.size() - 1).getModelVersion();
modelVersion.setModelId(modelCompress.getTargetModelId());
modelVersion.setIsCompress(1);
modelVersion.setModelVersion(lastModelVersion + 1);
this.modelVersionService.saveNew(modelVersion);
modelVersionS.setModelId(modelCompress.getTargetModelId());
ModelManage modelManageS = this.modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
this.modelManageService.updateById(modelManageS);
ModelVersion modelVersionR = this.modelVersionService.saveNew(modelVersionS);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelVersionR.getVersionName();
modelVersionR.setModelUrl(modelVersionURl);
this.modelVersionService.updateById(modelVersionR);
} else {
ModelManage modelManage = new ModelManage();
ModelVersion modelVersion = new ModelVersion();
modelManage.setModelName(modelCompress.getTargetModelName());
modelManage.setModelType(0);
modelVersion.setIsCompress(1);
modelManage.setModelDescribe("模型压缩");
this.modelManageService.saveAndCreateVersion(modelManage, modelVersion);
ModelManage modelManageS = new ModelManage();
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelCompress.getTargetModelName() + "_V1";
modelVersionS.setModelUrl(modelVersionURl);
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
}
this.saveNew(modelCompress);
this.saveOrUpdateOneToOneRelationData(modelCompress, relationData);
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------");
}
@Override
public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message);
JSONObject receiveJson = (JSONObject) JSON.parse(message);
String receiveMsg = receiveJson.getString("msg");
JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString());
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
array.add(modelVersion.getModelUrl());
array.add(new JSONArray());
array.add("");
array.add("");
array.add(2);
array.add(modelVersionURl);
array.add("8");
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString());
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("process_completed")) {
this.close();
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
return modelCompress;
}
......@@ -198,6 +267,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
return resultList;
}
private ModelCompress buildDefaultValue(ModelCompress modelCompress) {
if (modelCompress.getTaskId() == null) {
modelCompress.setTaskId(idGenerator.nextLongId());
......
......@@ -189,7 +189,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
public boolean publishToModelVersion(RunPublishDto runPublishDto) {
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
......@@ -243,8 +242,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate"));
array.add(2);
String newModelUrl = modelVersionURl;
array.add(newModelUrl);
array.add(modelVersionURl);
array.add("none");
sendJson.put("data", array);
sendJson.put("event_data", "null");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment