转产品图模型换为nanobanana

This commit is contained in:
litianxiang
2025-10-06 22:00:34 +08:00
parent 6772610129
commit 24d54ea9ea
3 changed files with 609 additions and 249 deletions

View File

@@ -98,4 +98,6 @@ public interface GenerateService extends IService<Generate> {
byte[] downloadVideoOrImage(String url);
String createGoogleAsyncTask(GenerateThroughImageTextDTO generateDTO, String useModel, String prompt);
String toProductAsyncTask(String imagePath, String useModel, String prompt);
}

View File

@@ -614,6 +614,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
}
String modelName = generateDTO.getModelName();
if (StringUtil.isNullOrEmpty(modelName)){
return handleStandardGeneration(generateDTO);
}
HashMap<String, String> modelAndPromptMap = chooseModelAndPrompt(generateDTO, modelName);
String useModel = modelAndPromptMap.get(ModelConstants.USE_MODEL);
@@ -666,8 +669,162 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// processCreditDeduction(generateDTO.getUserId(), taskId, CreditsEventsEnum.WX_TEXT2IMG);
return new PrepareForGenerateVO(Collections.singletonList(taskId), 200);
}
public String toProductAsyncTask(String imagePath, String useModel, String prompt) {
if (StringUtil.isNullOrEmpty(imagePath)||StringUtil.isNullOrEmpty(useModel)||StringUtil.isNullOrEmpty(prompt)){
throw new BusinessException("Parameter Exception");
}
AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long userId = userHolder.getId();
public String createGoogleAsyncTask(GenerateThroughImageTextDTO generateDTO, String useModel, String prompt) {
String uuid = UUID.randomUUID().toString();
// 生成唯一的任务ID
String taskId = uuid + "-" + userId;
String finalImagePath = null;
try {
finalImagePath = addWhiteBackground(imagePath);
//去掉"data:image/png;base64,"
finalImagePath = finalImagePath.replace("data:image/png;base64,", "");
// 如果白色背景处理失败或不需要直接获取原图的base64编码
if (StringUtil.isNullOrEmpty(finalImagePath)) {
finalImagePath = minioUtil.getImageAsBase64(imagePath);
}
} catch (IOException e) {
log.error("Error getting image as base64 taskId: {} ", taskId, e);
throw new BusinessException("Parameter Exception");
}
// 初始化Redis中的任务状态为"Executing"
String key = generateResultKey + ":" + taskId;
// 异步处理token获取和API调用
String projectId = "aida-461108";
String location = "global";
String endpoint = String.format(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent",
projectId, location, ModelConstants.NANO_BANANA
);
JSONObject requestBody = new JSONObject();
// 使用 gemini-2.5-flash-image-preview 模型时的请求体格式
// 创建图片部分
JSONObject imagePart = new JSONObject();
JSONObject inlineData = new JSONObject();
inlineData.set("mimeType", "image/png");
inlineData.set("data", finalImagePath);
imagePart.set("inlineData", inlineData);
// 创建文本部分
JSONObject textPart = new JSONObject();
textPart.set("text", prompt);
// 创建内容对象
JSONObject content = new JSONObject();
content.set("role", "user");
content.set("parts", Arrays.asList(imagePart, textPart));
// 设置 contents 数组
requestBody.set("contents", Arrays.asList(content));
// 设置 generationConfig
JSONObject generationConfig = new JSONObject();
// generationConfig.set("temperature", 1);
generationConfig.set("maxOutputTokens", 8192);
generationConfig.set("responseModalities", Arrays.asList("TEXT", "IMAGE"));
// generationConfig.set("topP", 0.95);
JSONObject imageConfig = new JSONObject();
imageConfig.set("aspectRatio", "9:16");
generationConfig.set("imageConfig", imageConfig);
requestBody.set("generationConfig", generationConfig);
String jsonBody = requestBody.toString();
log.info("Google 请求入参:{}", jsonBody);
GenerateResultVO resultVO = new GenerateResultVO(taskId, null, null, "Pending");
redisUtil.addToString(key, new Gson().toJson(resultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
CompletableFuture.runAsync(() -> {
try {
// 异步获取token
String tokenValue = null;
try (InputStream inputStream = GenerateServiceImpl.class.getClassLoader()
.getResourceAsStream("aida-461108-b4afaabebb84.json")) {
GoogleCredentials credentials = GoogleCredentials
.fromStream(inputStream)
.createScoped(Collections.singletonList("https://www.googleapis.com/auth/cloud-platform"));
credentials.refreshIfExpired();
tokenValue = credentials.getAccessToken().getTokenValue();
}
if (tokenValue == null) {
throw new RuntimeException("google token error");
}
// 异步发送API请求
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
.readTimeout(60, TimeUnit.SECONDS) // 读取超时时间
.writeTimeout(60, TimeUnit.SECONDS) // 写入超时时间
.build();
Request request = new Request.Builder()
.url(endpoint)
.addHeader("Authorization", "Bearer " + tokenValue)
.addHeader("Content-Type", "application/json")
.post(RequestBody.create(MediaType.parse("application/json"), jsonBody))
.build();
try (Response response = client.newCall(request).execute()) {
String result = response.body().string();
log.info("Google 响应结果:{}", result);
com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result);
String base64Data = null;
//根据模型类别按照api取出结果
if (ModelConstants.NANO_BANANA.equals(useModel)) {
JSONArray candidates = jsonResponse.getJSONArray("candidates");
if (candidates != null && !candidates.isEmpty()) {
com.alibaba.fastjson.JSONObject candidate = candidates.getJSONObject(0);
com.alibaba.fastjson.JSONObject contentResult = candidate.getJSONObject("content");
JSONArray parts = contentResult.getJSONArray("parts");
// 遍历parts数组找到包含inlineData的对象
for (int i = 0; i < parts.size(); i++) {
com.alibaba.fastjson.JSONObject part = parts.getJSONObject(i);
if (part.containsKey("inlineData")) {
com.alibaba.fastjson.JSONObject inlineDataResult= part.getJSONObject("inlineData");
base64Data = inlineDataResult.getString("data");
break;
}
}
}
}
if (base64Data != null && !base64Data.isEmpty()) {
String resultPath = userId + "/product_image" + "/" + uuid;
String minioPath = minioUtil.base64UploadToPath("data:image/png;base64," + base64Data, userBucket, resultPath);
// 生成成功更新Redis状态和URL
GenerateResultVO successResultVO = new GenerateResultVO(taskId, null, minioPath, "Success");
redisUtil.addToString(key, new Gson().toJson(successResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
} else {
// 没有找到图像数据或数据为空,标记为失败
log.warn("Google generation response does not contain valid image data for taskId: {}", taskId);
GenerateResultVO failResultVO = new GenerateResultVO(taskId, null, null, "Fail");
redisUtil.addToString(key, new Gson().toJson(failResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
}
}
} catch (Exception e) {
log.error("Google generation failed for taskId: {}", taskId, e);
// 生成失败更新Redis状态
GenerateResultVO failResultVO = new GenerateResultVO(taskId, null, null, "Fail");
redisUtil.addToString(key, new Gson().toJson(failResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
}
}, asyncTaskExecutor);
return taskId;
}
public String createGoogleAsyncTask(GenerateThroughImageTextDTO generateDTO, String useModel, String prompt) {
// 从 resources 加载 JSON 文件
System.setProperty("https.proxyHost", "127.0.0.1");
System.setProperty("https.proxyPort", "10809");
@@ -975,7 +1132,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
&& StringUtil.isNullOrEmpty(generateDTO.getLevel2Type())) {
throw new BusinessException("level2Type.cannot.be.empty");
}
if (generateDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) {
if (generateDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())&&generateDTO.getLevel2Type().equals(CreditsEventsEnum.PATTERN.getName())) {
int firstCommaIndex = generateDTO.getText().indexOf(",");
String style = generateDTO.getText().substring(0, firstCommaIndex).trim();
//如果style不等于painting styleillustration stylereal style中的一种
@@ -1057,7 +1214,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
throw new RuntimeException(e.getMessage());
}
String taskId = result.getOutput().getTaskId();
log.info("wx text2image 请求生成:{}, taskId{}", JsonUtils.toJson(result), taskId);
log.info("qwen text2image 请求生成:{}, taskId{}", JsonUtils.toJson(result), taskId);
Generate generate = new Generate(userId, taskId, level1Type, level2Type, prompt, "text(" + gender + ")", "qwen-image", new Date());
save(generate);
@@ -3482,59 +3639,85 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
/**
* 接入flux模型用于imageToSketch(sketch extract) || relighting || to product image
*
* @param func 功能枚举名
* @param prompt 用户输入
* @param imagePath 图片minio路径
* @param func 功能枚举名指定使用flux模型的具体功能类型
* @param prompt 用户输入的提示词,如果为空则使用默认提示词
* @param imagePath 图片minio路径作为输入图像的base64编码源
* @param childStyle 是否为儿童风格,影响提示词的构建
* @return 返回taskId用于异步获取结果
*/
public String flux(CreditsEventsEnum func, String prompt, String imagePath, boolean childStyle) {
// Flux API的请求地址
String fluxRequestUrl = "https://api.bfl.ai/v1/flux-kontext-pro";
// 如果用户没有提供提示词,根据功能类型设置默认提示词
if (StringUtil.isNullOrEmpty(prompt)) {
switch (func) {
case RELIGHT_FLUX:
// 重新打光功能的默认提示词
prompt = "a model standing on the beautiful beach, ultra high quality, 8k";
break;
case IMAGE_TO_SKETCH_FLUX:
// 图片转线稿功能的默认提示词
prompt = "generate the sketch of the image, simple line, ultra high quality";
break;
case TO_PRODUCT_IMAGE_ADVANCED:
// 转产品图功能的默认提示词
prompt = "change the image to real style, ultra high quality, 8k";
// 如果是儿童风格,添加儿童面部特征描述
if (childStyle) prompt = prompt + ", Children's face";
break;
}
} else {
// 如果用户提供了提示词,先进行提示词优化处理
prompt = modifyPrompt(prompt, null, func.getName(), null);
// 根据不同功能类型,为提示词添加特定前缀
switch (func) {
case PATTERN:
// 图案生成功能,添加图案前缀
prompt = "pattern image, " + prompt;
break;
case SKETCH_BOARD:
// 线稿板功能,添加线稿描述
prompt = "a single item of sketch of " + prompt + ", clean white background, simple lines";
break;
}
}
// 构建Flux API请求体
JSONObject requestBody = new JSONObject();
// 设置生成提示词
requestBody.set("prompt", prompt);
// 设置随机种子,确保结果的可重现性
requestBody.set("seed", 42);
// 根据功能类型设置图片宽高比
if (func.equals(PATTERN)) {
// 图案生成使用正方形比例
requestBody.set("aspect_ratio", "1:1");
} else {
// 其他功能使用竖屏比例
requestBody.set("aspect_ratio", "9:16");
}
// 设置输出格式为PNG
requestBody.set("output_format", "png");
log.info("flux 请求入参:{}", requestBody);
// 提示词不能为空的校验
if (prompt.isEmpty()) throw new BusinessException("test");
// 如果提供了输入图片路径需要将图片转换为base64格式
if (!StringUtil.isNullOrEmpty(imagePath)) {
try {
String imageAsBase64 = null;
// 对于转产品图功能,先添加白色背景处理
if (func.equals(TO_PRODUCT_IMAGE_ADVANCED)) {
imageAsBase64 = addWhiteBackground(imagePath);
}
// 如果白色背景处理失败或不需要直接获取原图的base64编码
if (StringUtil.isNullOrEmpty(imageAsBase64)) {
imageAsBase64 = minioUtil.getImageAsBase64(imagePath);
}
// 将base64编码的图片添加到请求体中
requestBody.set("input_image", imageAsBase64);
} catch (IOException e) {
log.error("获取图片的base64格式失败{}", String.valueOf(e));
@@ -3542,22 +3725,30 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
}
}
// 发送POST请求到Flux API
String resp = sendRequestUtil.sendFluxPost(fluxRequestUrl, requestBody.toString());
JSONObject respObj = JSONUtil.parseObj(resp);
log.info("flux 发起生成请求返回结果: {}", respObj);
// 从响应中提取任务ID
String taskId = respObj.getStr("id");
if (StringUtil.isNullOrEmpty(taskId)) {
// 任务创建失败,记录错误信息并抛出异常
requestBody.set("input_image", imagePath);
log.error("flux生成任务创建失败func {} requestBody:{}", func.getName(), requestBody);
throw new BusinessException("Failed to generate task. Please retry later.");
}
// 获取轮询URL用于后续查询任务状态
String pollingUrl = respObj.getStr("polling_url");
String key = RedisUtil.FLUX_POLLING_URL + taskId;
// 将轮询URL存储到Redis中设置过期时间
redisUtil.addToString(key, pollingUrl, CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
// 添加到api_generate表中以便之后对结果查询做补偿
// 添加到api_generate表中以便之后对结果查询做补偿机制
apiGenerateService.addAPIGenerateRecordAsync(UserContext.getUserHolder().getId(), taskId, func.getName(), "flux", "Pending");
// 返回任务ID用于异步查询结果
return taskId;
}