Commit bad23ce0 authored by linpeiqin's avatar linpeiqin

增加多数据集训练及评估,简化配置文件

parent 51746da1
...@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller; ...@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ArrayNode;
...@@ -34,9 +35,8 @@ import java.io.BufferedReader; ...@@ -34,9 +35,8 @@ import java.io.BufferedReader;
import java.io.File; import java.io.File;
import java.io.FileReader; import java.io.FileReader;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.*;
import java.util.Date; import java.util.stream.Collectors;
import java.util.List;
/** /**
* 模型评估操作控制器类。 * 模型评估操作控制器类。
...@@ -94,11 +94,19 @@ public class ModelEstimateController { ...@@ -94,11 +94,19 @@ public class ModelEstimateController {
public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long taskId) { public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long taskId) {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId); ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId());
TuningRun tuningRun = this.tuningRunService.getById(modelVersion.getRunId()); TuningRun tuningRun = this.tuningRunService.getById(modelVersion.getRunId());
JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration()); JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration());
JSONArray datasetVersionNames = new JSONArray(); JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName()); Set<Long> datasetVersionIds = Arrays.stream(modelEstimate.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
datasetVersionService.getInList(datasetVersionIds).stream()
.map(DatasetVersion::getVersionName)
.forEach(versionName -> {
try {
datasetVersionNames.add(versionName);
} catch (JSONException e) {
e.printStackTrace();
}
});
JSONArray array = new JSONArray(); JSONArray array = new JSONArray();
array.add("zh"); array.add("zh");
array.add(modelVersion.getVersionName()); array.add(modelVersion.getVersionName());
...@@ -195,19 +203,21 @@ public class ModelEstimateController { ...@@ -195,19 +203,21 @@ public class ModelEstimateController {
} }
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId); ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId()); Set<Long> datasetVersionIds = Arrays.stream(modelEstimate.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
List<DatasetVersion> datasetVersions = datasetVersionService.getInList(datasetVersionIds);
List<String> dataList = new ArrayList<>();
int i = 0;
String url = this.pythonConfig.getModelEstimateFileBaseDir() + modelVersion.getVersionName() + File.separator + "evl_" + taskId + File.separator + "generated_predictions.jsonl"; String url = this.pythonConfig.getModelEstimateFileBaseDir() + modelVersion.getVersionName() + File.separator + "evl_" + taskId + File.separator + "generated_predictions.jsonl";
//获取评估输出详情 //获取评估输出详情
List<JSONObject> jsonObjects = this.getFileJsonArray(url); List<JSONObject> jsonObjects = this.getFileJsonArray(url);
File file = new File(datasetVersion.getFileUrl()); // 指定文件路径 for (DatasetVersion datasetVersion : datasetVersions) {
ObjectMapper objectMapper = new ObjectMapper(); ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(new File(datasetVersion.getFileUrl())); // 读取JSON文件内容并转换为ArrayNode对象
ArrayNode arrayNode = (ArrayNode) objectMapper.readTree(file); // 读取JSON文件内容并转换为ArrayNode对象 for (JSONObject jsonNode : jsonObjects) { // 遍历JSON数组并取出每个元素(ObjectNode)中的数据
List<String> dataList = new ArrayList<>(); jsonNode.put("datasetName",datasetVersion.getVersionName());
int i = 0; jsonNode.put("input", arrayNode.get(i).get("instruction").textValue() + arrayNode.get(i).get("input").textValue());
for (JSONObject jsonNode : jsonObjects) { // 遍历JSON数组并取出每个元素(ObjectNode)中的数据 dataList.add(jsonNode.toJSONString());
jsonNode.put("input", arrayNode.get(i).get("instruction").textValue() + arrayNode.get(i).get("input").textValue()); i++;
dataList.add(jsonNode.toJSONString()); }
i++;
} }
int total = dataList.size(); // 获取总数据量 int total = dataList.size(); // 获取总数据量
int page = pageParam.getPageNum(); int page = pageParam.getPageNum();
......
...@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller; ...@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.pagehelper.page.PageMethod; import com.github.pagehelper.page.PageMethod;
...@@ -33,8 +34,11 @@ import org.springframework.web.bind.annotation.*; ...@@ -33,8 +34,11 @@ import org.springframework.web.bind.annotation.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/** /**
* 精调任务运行操作控制器类。 * 精调任务运行操作控制器类。
...@@ -145,10 +149,18 @@ public class TuningRunController { ...@@ -145,10 +149,18 @@ public class TuningRunController {
public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long runId) { public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long runId) {
TuningRun tuningRun = this.tuningRunService.getById(runId); TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(tuningRun.getDatasetVersionId());
JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration()); JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration());
JSONArray datasetVersionNames = new JSONArray(); JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName()); Set<Long> datasetVersionIds = Arrays.stream(tuningRun.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
datasetVersionService.getInList(datasetVersionIds).stream()
.map(DatasetVersion::getVersionName)
.forEach(versionName -> {
try {
datasetVersionNames.add(versionName);
} catch (JSONException e) {
e.printStackTrace();
}
});
JSONArray array = new JSONArray(); JSONArray array = new JSONArray();
array.add("zh"); array.add("zh");
array.add(modelVersion.getVersionName()); array.add(modelVersion.getVersionName());
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
<result column="task_name" jdbcType="VARCHAR" property="taskName"/> <result column="task_name" jdbcType="VARCHAR" property="taskName"/>
<result column="task_describe" jdbcType="VARCHAR" property="taskDescribe"/> <result column="task_describe" jdbcType="VARCHAR" property="taskDescribe"/>
<result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/> <result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/>
<result column="dataset_version_id" jdbcType="BIGINT" property="datasetVersionId"/> <result column="dataset_version_ids" jdbcType="VARCHAR" property="datasetVersionIds"/>
<result column="scoring_mode" jdbcType="TINYINT" property="scoringMode"/> <result column="scoring_mode" jdbcType="TINYINT" property="scoringMode"/>
<result column="configuration" jdbcType="LONGVARCHAR" property="configuration"/> <result column="configuration" jdbcType="LONGVARCHAR" property="configuration"/>
<result column="task_status" jdbcType="TINYINT" property="taskStatus"/> <result column="task_status" jdbcType="TINYINT" property="taskStatus"/>
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
task_name, task_name,
task_describe, task_describe,
model_version_id, model_version_id,
dataset_version_id, dataset_version_ids,
scoring_mode, scoring_mode,
configuration, configuration,
task_status) task_status)
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#{item.taskName}, #{item.taskName},
#{item.taskDescribe}, #{item.taskDescribe},
#{item.modelVersionId}, #{item.modelVersionId},
#{item.datasetVersionId}, #{item.datasetVersionIds},
#{item.scoringMode}, #{item.scoringMode},
#{item.configuration}, #{item.configuration},
#{item.taskStatus}) #{item.taskStatus})
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
<result column="model_id" jdbcType="BIGINT" property="modelId"/> <result column="model_id" jdbcType="BIGINT" property="modelId"/>
<result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/> <result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/>
<result column="run_time" jdbcType="BIGINT" property="runTime"/> <result column="run_time" jdbcType="BIGINT" property="runTime"/>
<result column="dataset_version_id" jdbcType="BIGINT" property="datasetVersionId"/> <result column="dataset_version_ids" jdbcType="VARCHAR" property="datasetVersionIds"/>
<result column="split_ratio" jdbcType="TINYINT" property="splitRatio"/> <result column="split_ratio" jdbcType="TINYINT" property="splitRatio"/>
<result column="train_mode" jdbcType="VARCHAR" property="trainMode"/> <result column="train_mode" jdbcType="VARCHAR" property="trainMode"/>
<result column="train_method" jdbcType="VARCHAR" property="trainMethod"/> <result column="train_method" jdbcType="VARCHAR" property="trainMethod"/>
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
model_id, model_id,
model_version_id, model_version_id,
run_time, run_time,
dataset_version_id, dataset_version_ids,
split_ratio, split_ratio,
train_mode, train_mode,
train_method, train_method,
...@@ -52,7 +52,7 @@ ...@@ -52,7 +52,7 @@
#{item.modelId}, #{item.modelId},
#{item.modelVersionId}, #{item.modelVersionId},
#{item.runTime}, #{item.runTime},
#{item.datasetVersionId}, #{item.datasetVersionIds},
#{item.splitRatio}, #{item.splitRatio},
#{item.trainMode}, #{item.trainMode},
#{item.trainMethod}, #{item.trainMethod},
......
...@@ -43,10 +43,10 @@ public class ModelEstimateDto { ...@@ -43,10 +43,10 @@ public class ModelEstimateDto {
private Long modelVersionId; private Long modelVersionId;
/** /**
* 评估数据集版本ID。 * 评估数据集版本IDs
*/ */
@ApiModelProperty(value = "评估数据集版本ID") @ApiModelProperty(value = "评估数据集版本IDs")
private Long datasetVersionId; private String datasetVersionIds;
/** /**
* 打分模式。 * 打分模式。
......
...@@ -67,10 +67,10 @@ public class TuningRunDto { ...@@ -67,10 +67,10 @@ public class TuningRunDto {
private Long runTime; private Long runTime;
/** /**
* 数据集版本ID。 * 数据集版本IDs
*/ */
@ApiModelProperty(value = "数据集版本ID") @ApiModelProperty(value = "数据集版本IDs")
private Long datasetVersionId; private String datasetVersionIds;
/** /**
* 拆分比例。 * 拆分比例。
......
...@@ -61,20 +61,9 @@ public class ModelEstimate extends BaseModel { ...@@ -61,20 +61,9 @@ public class ModelEstimate extends BaseModel {
private Map<String, Object> modelVersionIdDictMap; private Map<String, Object> modelVersionIdDictMap;
/** /**
* 评估数据集版本ID。 * 评估数据集版本IDs
*/ */
private Long datasetVersionId; private String datasetVersionIds;
/**
* 评估数据集版本字典。
*/
@RelationDict(
masterIdField = "datasetVersionId",
slaveModelClass = DatasetVersion.class,
slaveIdField = "versionId",
slaveNameField = "versionName")
@TableField(exist = false)
private Map<String, Object> datasetVersionIdDictMap;
/** /**
* 打分模式。 * 打分模式。
......
...@@ -68,9 +68,9 @@ public class TuningRun extends BaseModel { ...@@ -68,9 +68,9 @@ public class TuningRun extends BaseModel {
private Long runTime; private Long runTime;
/** /**
* 数据集版本ID。 * 数据集版本IDs
*/ */
private Long datasetVersionId; private String datasetVersionIds;
/** /**
* 拆分比例。 * 拆分比例。
......
...@@ -47,7 +47,7 @@ public class ModelEstimateVo extends BaseVo { ...@@ -47,7 +47,7 @@ public class ModelEstimateVo extends BaseVo {
* 评估数据集版本ID。 * 评估数据集版本ID。
*/ */
@ApiModelProperty(value = "评估数据集版本ID") @ApiModelProperty(value = "评估数据集版本ID")
private Long datasetVersionId; private String datasetVersionIds;
/** /**
* 打分模式。 * 打分模式。
*/ */
......
...@@ -68,10 +68,10 @@ public class TuningRunVo extends BaseVo { ...@@ -68,10 +68,10 @@ public class TuningRunVo extends BaseVo {
private Long runTime; private Long runTime;
/** /**
* 数据集版本ID。 * 数据集版本IDs
*/ */
@ApiModelProperty(value = "数据集版本ID") @ApiModelProperty(value = "数据集版本IDs")
private Long datasetVersionId; private String datasetVersionIds;
/** /**
* 拆分比例。 * 拆分比例。
......
...@@ -51,7 +51,7 @@ application: ...@@ -51,7 +51,7 @@ application:
# 初始化密码。 # 初始化密码。
defaultUserPassword: 123456 defaultUserPassword: 123456
# 缺省的文件上传根目录。 # 缺省的文件上传根目录。
uploadFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/ uploadFileBaseDir: ${python.baseDir}/code/lmp_training_api/lmp_data/
# 跨域的IP(http://192.168.10.10:8086)白名单列表,多个IP之间逗号分隔(* 表示全部信任,空白表示禁用跨域信任)。 # 跨域的IP(http://192.168.10.10:8086)白名单列表,多个IP之间逗号分隔(* 表示全部信任,空白表示禁用跨域信任)。
credentialIpList: "*" credentialIpList: "*"
...@@ -61,29 +61,35 @@ application: ...@@ -61,29 +61,35 @@ application:
excludeLogin: false excludeLogin: false
python: python:
#基础路径
baseDir: z:/home/linking/llms
#训练服务地址
trainAddress: 192.168.0.36:7860
#推理服务地址
predictAddress: 192.168.0.36:7861
#数据集文件基础路径 #数据集文件基础路径
datasetFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/ datasetFileBaseDir: ${python.baseDir}/code/lmp_training_api/lmp_data/
#模型训练文件基础路径 #模型训练文件基础路径
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/ modelTuningFileBaseDir: ${python.baseDir}/code/lmp_training_api/saves/
#模型训练文件合并后路径 #模型训练文件合并后路径
modelOutputFileBaseDir: /home/linking/llms/models/ modelOutputFileBaseDir: ${python.baseDir}/models/
#模型评估文件基础路径 #模型评估文件基础路径
modelEstimateFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/ modelEstimateFileBaseDir: ${python.baseDir}/code/lmp_training_api/saves/
#数据集配置信息 #数据集配置信息
datasetInfo: dataset_info.json datasetInfo: dataset_info.json
#数据集配置目录 #数据集配置目录
datasetFileMenu: lmp_data datasetFileMenu: lmp_data
#python平台通用接口地址 #python平台通用接口地址
factoryInterface: http://192.168.0.36:7860/run/predict factoryInterface: http://${python.trainAddress}/run/predict
#python websocket 服务地址 #python websocket 服务地址
pythonWebsocketUri: ws://192.168.0.36:7860/queue/join pythonWebsocketUri: ws://${python.trainAddress}/queue/join
#输出控制地址 #输出控制地址
controllerAddress: http://127.0.0.1:20001 controllerAddress: http://127.0.0.1:20001
#对话基础路径 #对话基础路径
chatAddress: http://192.168.0.36:8000/ chatAddress: http://192.168.0.36:8000/
llm-model: llm-model:
#模型管理基础路径 #模型管理基础路径
llmModelInterface: http://192.168.0.36:7861/llm_model/ llmModelInterface: http://${python.predictAddress}/llm_model/
#模型停止 #模型停止
stop: stop stop: stop
#模型模型部署 #模型模型部署
...@@ -94,12 +100,12 @@ llm-model: ...@@ -94,12 +100,12 @@ llm-model:
listRunningModels: list_running_models listRunningModels: list_running_models
other: other:
#其他管理接口 #其他管理接口
otherInterface: http://192.168.0.36:7861/other/ otherInterface: http://${python.predictAddress}/other/
#获取gpu信息 #获取gpu信息
getGpuInfo: get_gpu_info getGpuInfo: get_gpu_info
knowledge: knowledge:
#知识库通用接口地址 #知识库通用接口地址
knowledgeInterface: http://192.168.0.36:7861/knowledge_base/ knowledgeInterface: http://${python.predictAddress}/knowledge_base/
#创建知识库 #创建知识库
create: create_knowledge_base create: create_knowledge_base
#获取知识库列表 #获取知识库列表
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment