Commit 35842abb authored by linpeiqin's avatar linpeiqin

组装参数

parent 30ded748
package com.yice.webadmin.app.controller; package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.page.PageMethod; import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport; import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody; import com.yice.common.core.annotation.MyRequestBody;
...@@ -12,16 +15,16 @@ import com.yice.common.log.annotation.OperationLog; ...@@ -12,16 +15,16 @@ import com.yice.common.log.annotation.OperationLog;
import com.yice.common.log.model.constant.SysOperationLogType; import com.yice.common.log.model.constant.SysOperationLogType;
import com.yice.webadmin.app.dto.RunPublishDto; import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.dto.TuningRunDto; import com.yice.webadmin.app.dto.TuningRunDto;
import com.yice.webadmin.app.model.TuningRun; import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.model.TuningTask; import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.service.TuningTaskService;
import com.yice.webadmin.app.vo.TuningRunVo; import com.yice.webadmin.app.vo.TuningRunVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List; import java.util.List;
/** /**
...@@ -40,6 +43,14 @@ public class TuningRunController { ...@@ -40,6 +43,14 @@ public class TuningRunController {
private TuningRunService tuningRunService; private TuningRunService tuningRunService;
@Autowired @Autowired
private TuningTaskService tuningTaskService; private TuningTaskService tuningTaskService;
@Autowired
private ModelVersionService modelVersionService;
@Autowired
private ModelManageService modelManageService;
@Autowired
private DatasetVersionService datasetVersionService;
@Autowired
private DatasetManageService datasetManageService;
/** /**
* 新增精调任务运行数据。 * 新增精调任务运行数据。
...@@ -93,6 +104,63 @@ public class TuningRunController { ...@@ -93,6 +104,63 @@ public class TuningRunController {
return ResponseResult.success(); return ResponseResult.success();
} }
@GetMapping("/getPreviewCommand")
public ResponseResult<String> getPreviewCommand(@RequestParam Long runId) {
TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(tuningRun.getDatasetVersionId());
DatasetManage datasetManage = this.datasetManageService.getById(datasetVersion.getDatasetId());
ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId());
JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration());
JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetManage.getDatasetName() + "_V" + datasetVersion.getDatasetVersion());
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd-hh-mm-ss");
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelManage.getModelName());
array.add(modelVersion.getModelUrl());
array.add(tuningRun.getTrainMethod());
array.add(new JSONArray());
array.add(jsonObject.get("quantizationLevel"));
array.add(jsonObject.get("promptTemplate"));
array.add("");
array.add(false);
array.add(false);
array.add("none");
array.add("Supervised Fine-Tuning");
array.add("lmp_data");
array.add(datasetVersionNames);
array.add(jsonObject.get("truncationLength"));
array.add(jsonObject.get("learningRate"));
array.add(jsonObject.get("iterationRound"));
array.add(jsonObject.get("maximumSampleSize"));
array.add(jsonObject.get("CalculationType"));
array.add(jsonObject.get("batchSize"));
array.add(jsonObject.get("gradientAccumulation"));
array.add(jsonObject.get("LearningRateRegulator"));
array.add(jsonObject.get("maximumGradientNorm"));
array.add(jsonObject.get("verificationRatio"));
array.add(jsonObject.get("LogInterval"));
array.add(jsonObject.get("saveInterval"));
array.add(jsonObject.get("warmUpSteps"));
array.add(jsonObject.get("nEFTuneNoiseParameter"));
array.add(false);
array.add(false);
array.add(jsonObject.get("loRaRank"));
array.add(jsonObject.get("loRaRandomDiscard"));
array.add("");
array.add("");
array.add(true);
array.add(jsonObject.get("dpoBetaData"));
array.add(new JSONArray());
array.add(sdf.format(new Date()));
JSONObject jsonObject1 = new JSONObject();
jsonObject1.put("data", array);
System.out.println(jsonObject1.toJSONString());
// String test2 = "[\"zh\",\"ChatGLM3-6B-Chat\",\"/home/linking/llms/models/chatglm3-6b\",\"lora\",[\"2023-11-30-17-10-39\"],\"none\",\"chatglm3\",\"11111\",false,false,\"none\",\"Supervised Fine-Tuning\",\"data\",[\"alpaca_zh\"],1024,\"5e-5\",\"3.0\",\"100000\",\"fp16\",4,4,\"cosine\",\"1.0\",0,5,100,0,0,false,false,8,0.1,\"\",\"\",true,0.1,[],\"2023-11-30-18-46-42\"]";
return ResponseResult.success(array.toJSONString());
}
/** /**
* 发布运行任务。 * 发布运行任务。
* *
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment