Commit 830cbcf6 authored by linpeiqin's avatar linpeiqin

修改一些小BUG

parent b834a96a
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
<result column="is_compress" jdbcType="TINYINT" property="isCompress"/> <result column="is_compress" jdbcType="TINYINT" property="isCompress"/>
<result column="task_id" jdbcType="BIGINT" property="taskId"/> <result column="task_id" jdbcType="BIGINT" property="taskId"/>
<result column="run_id" jdbcType="BIGINT" property="runId"/> <result column="run_id" jdbcType="BIGINT" property="runId"/>
<result column="base_prompt_template" jdbcType="VARCHAR" property="basePromptTemplate"/>
</resultMap> </resultMap>
<insert id="insertList"> <insert id="insertList">
...@@ -55,7 +56,8 @@ ...@@ -55,7 +56,8 @@
model_url, model_url,
is_compress, is_compress,
task_id, task_id,
run_id run_id,
base_prompt_template
) )
VALUES VALUES
<foreach collection="list" index="index" item="item" separator=","> <foreach collection="list" index="index" item="item" separator=",">
...@@ -83,7 +85,9 @@ ...@@ -83,7 +85,9 @@
#{item.modelUrl}, #{item.modelUrl},
#{item.isCompress}, #{item.isCompress},
#{item.taskId}, #{item.taskId},
#{item.runId}) #{item.runId},
#{item.basePromptTemplate}
)
</foreach> </foreach>
</insert> </insert>
......
...@@ -144,4 +144,10 @@ public class ModelVersionDto { ...@@ -144,4 +144,10 @@ public class ModelVersionDto {
*/ */
@ApiModelProperty(value = "训练任务ID") @ApiModelProperty(value = "训练任务ID")
private Long taskId; private Long taskId;
/**
* 默认的提示词模板。
*/
@ApiModelProperty(value = "默认的提示词模板")
private String basePromptTemplate;
} }
...@@ -129,6 +129,12 @@ public class ModelVersion extends BaseModel { ...@@ -129,6 +129,12 @@ public class ModelVersion extends BaseModel {
*/ */
private Integer isCompress; private Integer isCompress;
/**
* 默认的提示词模板。
*/
private String basePromptTemplate;
@RelationOneToOne( @RelationOneToOne(
masterIdField = "modelId", masterIdField = "modelId",
slaveModelClass = ModelManage.class, slaveModelClass = ModelManage.class,
......
...@@ -60,6 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -60,6 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
private IdGeneratorWrapper idGenerator; private IdGeneratorWrapper idGenerator;
@Autowired @Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
private String modelName;
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
...@@ -179,6 +180,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -179,6 +180,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
TuningRun tuningRun = this.getById(runPublishDto.getRunId()); TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId()); ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
modelName = runPublishDto.getModelName() + "_V1";
} else {
modelVersionS.setModelId(runPublishDto.getModelId());
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
modelName = this.modelVersionService.saveNew(modelVersionS).getVersionName();
}
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) { new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override @Override
public void onOpen(ServerHandshake serverHandshake) { public void onOpen(ServerHandshake serverHandshake) {
...@@ -210,7 +228,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -210,7 +228,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(jsonObject.get("promptTemplate")); array.add(jsonObject.get("promptTemplate"));
array.add(2); array.add(2);
//路径需要修改,暂时不能确定,后面一条线解决 //路径需要修改,暂时不能确定,后面一条线解决
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + (runPublishDto.getPublishWay() == 0 ? runPublishDto.getModelName() : modelManage.getModelName()); String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + modelName;
array.add(newModelUrl); array.add(newModelUrl);
array.add("none"); array.add("none");
sendJson.put("data",array); sendJson.put("data",array);
...@@ -234,24 +252,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -234,24 +252,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
log.error("报错了:::" + e.getMessage()); log.error("报错了:::" + e.getMessage());
} }
}.connect(); }.connect();
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManageS, new ModelVersion());
} else {
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setModelId(runPublishDto.getModelId());
List<ModelVersion> modelVersionList = this.modelVersionService.getModelVersionList(modelVersionS, "model_version");
int lastModelVersion = modelVersionList.get(modelVersionList.size() - 1).getModelVersion();
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionS.setModelVersion(lastModelVersion + 1);
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
this.modelVersionService.saveNew(modelVersionS);
}
tuningRun.setPublishStatus(1); tuningRun.setPublishStatus(1);
return this.updateById(tuningRun); return this.updateById(tuningRun);
} }
......
...@@ -151,4 +151,10 @@ public class ModelVersionVo extends BaseVo { ...@@ -151,4 +151,10 @@ public class ModelVersionVo extends BaseVo {
*/ */
@ApiModelProperty(value = "训练任务ID") @ApiModelProperty(value = "训练任务ID")
private Long taskId; private Long taskId;
/**
* 默认的提示词模板。
*/
@ApiModelProperty(value = "默认的提示词模板")
private String basePromptTemplate;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment