1、接入超分功能

2、添加积分系统
3、新增订单查询,积分详细查询
This commit is contained in:
2024-03-15 15:38:56 +08:00
parent bf05f88c00
commit 305324fe1a
35 changed files with 798 additions and 55 deletions

View File

@@ -22,7 +22,7 @@ import java.util.Objects;
@Slf4j
@Component
public class MQConsumer {
public class GenerateConsumer {
@Resource
private GenerateService generateService;
@@ -30,13 +30,13 @@ public class MQConsumer {
@Resource
private RedisUtil redisUtil;
@Value("${redis.key.consumptionOrder}")
@Value("${redis.key.orderForGenerate}")
private String consumptionOrderKey;
@Value("${redis.key.cancelSet}")
@Value("${redis.key.generateCancelSet}")
private String cancelSetKey;
@Value("${redis.key.exceptionMap}")
@Value("${redis.key.generateExceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")

View File

@@ -1,16 +1,8 @@
package com.ai.da.common.RabbitMQ;
import org.springframework.amqp.core.Binding;
import org.springframework.amqp.core.BindingBuilder;
import org.springframework.amqp.core.FanoutExchange;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.core.RabbitAdmin;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.beans.factory.annotation.Value;
@Configuration
public class MQConfig {
@@ -18,7 +10,10 @@ public class MQConfig {
public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
// public static final String GENERATE_QUEUE = "generate-queue-prod";
// public static final String GENERATE_QUEUE = "generate-queue-test";
public static final String GENERATE_QUEUE = "generate-queue-dev";
// public static final String GENERATE_QUEUE = "generate-queue-dev";
public static final String GENERATE_QUEUE = "generate-queue";
public static final String SR_QUEUE = "SR-queue-dev";
public MQConfig() {
}
@@ -32,10 +27,15 @@ public class MQConfig {
* 创建队列,使用工作模式,不用定义交换机
*/
@Bean
public Queue queueRasa() {
public Queue generateQueue() {
return new Queue(GENERATE_QUEUE);
}
@Bean
public Queue SRQueue() {
return new Queue(SR_QUEUE);
}
/**
* 将队列绑定到交换机上【队列订阅交换机】
*/

View File

@@ -18,7 +18,11 @@ public class MQPublisher {
public void sendGenerateMessage(String mm) {
log.info("send message:" + mm);
amqpTemplate.convertAndSend(MQConfig.GENERATE_QUEUE, mm);
}
public void sendSRMessage(String mm) {
log.info("send message:" + mm);
amqpTemplate.convertAndSend(MQConfig.SR_QUEUE, mm);
}
}

View File

@@ -0,0 +1,111 @@
package com.ai.da.common.RabbitMQ;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.dto.SuperResolutionDTO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.service.SuperResolutionService;
import com.alibaba.fastjson.JSONObject;
import com.rabbitmq.client.Channel;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.rabbit.annotation.RabbitHandler;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.HashMap;
@Slf4j
@Component
public class SRConsumer {
@Resource
private RedisUtil redisUtil;
@Value("${redis.key.orderForSR}")
private String consumptionOrderKey;
@Value("${redis.key.SRCancelSet}")
private String cancelSetKey;
@Value("${redis.key.SRExceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Resource
private SuperResolutionService superResolutionService;
public void superResolution(Message msg, Channel channel, String consumerName){
log.info("============start listening==========");
long start = System.currentTimeMillis();
SuperResolutionDTO superResolutionDTO = JSONObject.parseObject(msg.getBody(), SuperResolutionDTO.class);
String uniqueId = superResolutionDTO.getUniqueId();
log.info("From " + consumerName + " : " + uniqueId);
try {
// 2、判断当前消息是否在取消列表中
Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, uniqueId);
if (isMember) {
try {
// 2.1 手动确认该消息
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
} catch (IOException ex) {
log.error("手动确认,不返回队列重新消费");
}
} else {
/*try {
Thread.sleep(15000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}*/
String srOutput = superResolutionService.SR(superResolutionDTO);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
if (!StringUtil.isNullOrEmpty(srOutput)) {
HashMap<String, String> generateResult = new HashMap<>();
generateResult.put(uniqueId, srOutput);
// 将结果存在redis中 ,为空时不要存
redisUtil.addToMap(resultMapKey, generateResult);
}
}
} catch (BusinessException e) {
log.error(e.getMsg());
// channel.basicNack() 为不确认deliveryTag对应的消息第二个参数是否应用于多消息第三个参数是否requeue
try {
// 第二个参数是否批量确认消息当传false时只确认当前 deliveryTag对应的消息;当传true时会确认当前及之前所有未确认的消息。
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
} catch (IOException exception) {
log.error("手动确认,取消返回队列,不再重新消费");
}
// 将入参和错误信息存入数据库
String exceptionMessage = JSONObject.toJSONString(superResolutionDTO) +
" Exception message " + e.getMsg();
HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage);
// 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
}
long end = System.currentTimeMillis();
log.info(" task_id " + uniqueId + "----------" + consumerName + " 执行时长:" + (end - start) + "毫秒");
log.info("=============end listening===========");
}
@RabbitListener(queues = MQConfig.SR_QUEUE)
@RabbitHandler
public void SRConsumer1(Message msg, Channel channel) {
superResolution(msg, channel, "consumer 1");
}
}

View File

@@ -14,8 +14,8 @@ public enum CreditsEventsEnum {
DAILY_CHECKIN("Daily Check-In", "50"),
SOCIAL_MEDIA_SHARING("Social Media Sharing","50"),
BUY_CREDITS("Buy Credits","2000"),
// 6USD -> 1000 points ==> 10HKD -> 215 points ==> 2HKD -> 43points
BUY_CREDITS("Buy Credits","43"),
SUPER_RESOLUTION("Super Resolution","300"),

View File

@@ -2,6 +2,7 @@ package com.ai.da.common.utils;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.model.dto.GenerateToPythonDTO;
import com.ai.da.model.dto.SuperResolutionDTO;
import com.ai.da.python.PythonService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@@ -14,6 +15,7 @@ import java.util.concurrent.*;
@Component
public class AsyncCallerUtil {
// 存放状态 表示当前任务是否需要继续等待,默认持续等待
public static Map<String, Boolean> waitingStatus = new HashMap<>();
private static PythonService pythonService;
@@ -70,4 +72,52 @@ public class AsyncCallerUtil {
}
}
public CompletableFuture<String> callSRAsync(SuperResolutionDTO superResolutionDTO) {
return CompletableFuture.supplyAsync(() -> pythonService.superResolution(superResolutionDTO));
}
public String SR(SuperResolutionDTO superResolutionDTO) {
ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(5);
String taskId = superResolutionDTO.getUniqueId();
ScheduledFuture<?> timeoutTask = null;
if (!waitingStatus.containsKey(taskId)) waitingStatus.put(taskId, true);
try {
CompletableFuture<String> generateResult = callSRAsync(superResolutionDTO);
// 5秒后第一次确认之后每隔10秒确认一次用户选择结果
timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> {
// 调用另一个接口获取用户的选择
if (!waitingStatus.get(taskId)) {
// 如果用户选择取消则取消对generate的调用
generateResult.cancel(true);
waitingStatus.remove(taskId);
} else log.info("===============持续等待===============");
}, 5, 10, TimeUnit.SECONDS);
log.info("阻塞等待结果...");
// 阻塞,等待结果
String result = generateResult.get();
// 取消定时任务
timeoutTask.cancel(true);
waitingStatus.remove(taskId);
return result;
} catch (CancellationException e) {
// generateResult.cancel(true);通过抛出异常取消该任务
log.info("==========成功取消generate任务==========");
return null;
} catch (InterruptedException | ExecutionException | BusinessException e) {
// 处理异常
log.error("发生错误 " + e, e);
// 取消定时任务
assert timeoutTask != null;
timeoutTask.cancel(true);
throw new BusinessException(e.getMessage());
} finally {
// 关闭线程池
// executorService.shutdown();
// scheduledExecutorService.shutdown();
}
}
}

View File

@@ -66,7 +66,7 @@ public class AliPayController {
}
/**
* 申请退款
* 不在页面提供申请退款接口
* @param orderNo
* @param reason
* @return

View File

@@ -3,9 +3,9 @@ package com.ai.da.controller;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.response.Response;
import com.ai.da.mapper.DesignMapper;
import com.ai.da.mapper.TrialOrderMapper;
import com.ai.da.mapper.entity.TrialOrder;
import com.ai.da.mapper.primary.DesignMapper;
import com.ai.da.mapper.primary.TrialOrderMapper;
import com.ai.da.mapper.primary.entity.TrialOrder;
import com.ai.da.model.dto.UserDesignStatisticDTO;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;

View File

@@ -1,6 +1,9 @@
package com.ai.da.controller;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.response.Response;
import com.ai.da.mapper.primary.entity.CreditsDetail;
import com.ai.da.model.dto.QueryIncomeOrExpenditureDTO;
import com.ai.da.service.CreditsService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
@@ -8,6 +11,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import javax.validation.Valid;
@CrossOrigin
@RestController
@@ -19,10 +23,19 @@ public class CreditsController {
@Resource
private CreditsService creditsService;
@ApiOperation("获取积分")
@ApiOperation("获取当前积分")
@GetMapping("/getCredits")
public Response<String> getCredits(){
String credits = creditsService.getCredits();
return Response.success(credits);
}
@ApiOperation("获取积分详细")
@PostMapping("/getCreditsDetail")
public Response<PageBaseResponse<CreditsDetail>> getCreditsDetail(@Valid @RequestBody QueryIncomeOrExpenditureDTO queryPageByTimeDTO){
PageBaseResponse<CreditsDetail> credits = creditsService.queryCreditsDetailsPage(queryPageByTimeDTO);
return Response.success(credits);
}
}

View File

@@ -1,16 +1,17 @@
package com.ai.da.controller;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.OrderStatusEnum;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.response.Response;
import com.ai.da.mapper.primary.entity.OrderInfo;
import com.ai.da.model.dto.QueryPageByTimeDTO;
import com.ai.da.service.OrderInfoService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import java.util.List;
import javax.validation.Valid;
@CrossOrigin //开放前端的跨域访问
@Api(tags = "商品订单管理")
@@ -22,9 +23,9 @@ public class OrderInfoController {
private OrderInfoService orderInfoService;
@ApiOperation("订单列表")
@GetMapping("/list")
public Response<List<OrderInfo>> list(){
List<OrderInfo> orderByAccountId = orderInfoService.getOrderByAccountId(UserContext.getUserHolder().getId());
@PostMapping("/list")
public Response<PageBaseResponse<OrderInfo>> list(@Valid @RequestBody QueryPageByTimeDTO queryPageByTimeDTO){
PageBaseResponse<OrderInfo> orderByAccountId = orderInfoService.getOrderByPage(queryPageByTimeDTO);
// List<OrderInfo> list = orderInfoService.listOrderByCreateTimeDesc();
return Response.success(orderByAccountId);
}

View File

@@ -26,9 +26,6 @@ public class PayPalCheckoutController {
@Resource
private PayPalCheckoutService payPalCheckoutService;
@Resource
private CallBackService callBackService;
@ApiOperation(value = "创建订单")
@PostMapping(value = "/trade/{amount}")
public Response<HashMap<String, String>> createOrder(@PathVariable Integer amount, @RequestParam String returnUrl) throws SerializeException {
@@ -39,13 +36,12 @@ public class PayPalCheckoutController {
@ApiOperation(value = "ipn异步回调")
@PostMapping(value = "/ipn/back")
public Response<String> callback(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
Boolean result = callBackService.doGet(request, response);
Boolean result = payPalCheckoutService.doPost(request, response);
if (result){
return Response.success(200,"success");
}else {
return Response.fail(500,"failure");
}
// return payPalCheckoutService.callback(RequestToMapUtil.getParameterMap(request));
}
@ApiOperation(value = "查询指定订单")
@@ -55,6 +51,7 @@ public class PayPalCheckoutController {
return Response.success(s);
}
/** 不提供退款接口 */
@ApiOperation("申请退款")
@PostMapping("/trade/refund/{orderNo}/{reason}")
public Response<HttpResponse<Refund>> refund(@PathVariable String orderNo, @PathVariable String reason) throws IOException {

View File

@@ -7,12 +7,14 @@ import com.ai.da.mapper.primary.entity.Library;
import com.ai.da.model.dto.ChatFlushDTO;
import com.ai.da.model.dto.ChatRobotLibraryDTO;
import com.ai.da.model.dto.ChatSendDTO;
import com.ai.da.model.dto.SuperResolutionDTO;
import com.ai.da.model.vo.ChatRobotVO;
import com.ai.da.model.vo.PythonLibraryVo;
import com.ai.da.model.vo.SysFileVO;
import com.ai.da.python.PythonService;
import com.ai.da.service.ChatRobotService;
import com.ai.da.service.LibraryService;
import com.ai.da.service.SuperResolutionService;
import com.ai.da.service.SysFileService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
@@ -41,10 +43,12 @@ public class PythonController {
private SysFileService sysFileService;
@Resource
private LibraryService libraryService;
@Resource
private ChatRobotService chatRobotService;
@Resource
private SuperResolutionService superResolutionService;
@ApiOperation(value = "python服务保存图片到java服务")
@PostMapping("/saveGeneratePicture")
public Response<String> upload(@RequestParam("file") MultipartFile file,
@@ -109,4 +113,10 @@ public class PythonController {
return Response.success(chatRobotService.chatBufferFlush(chatFlushDTO));
}
@ApiOperation(value = "超分辨率")
@PostMapping("/prepareForSR")
public Response<String> superResolution(@RequestBody SuperResolutionDTO superResolutionDTO){
return Response.success(superResolutionService.prepareForSR(superResolutionDTO));
}
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper.primary;
import com.ai.da.mapper.primary.entity.CreditsDetail;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
public interface CreditsDetailMapper extends BaseMapper<CreditsDetail> {
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper.primary;
import com.ai.da.mapper.primary.entity.SuperResolution;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
public interface SuperResolutionMapper extends BaseMapper<SuperResolution> {
}

View File

@@ -6,6 +6,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.Serializable;
import java.time.LocalDateTime;
import java.util.Date;
@Data
@@ -23,12 +24,12 @@ public class BaseEntity implements Serializable {
/**
* 创建时间
*/
private Date createTime;
private LocalDateTime createTime;
/**
* 更新时间
*/
private Date updateTime;
private LocalDateTime updateTime;
/**
* 是否已删除

View File

@@ -0,0 +1,21 @@
package com.ai.da.mapper.primary.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.math.BigDecimal;
@EqualsAndHashCode(callSuper = true)
@Data
@TableName("t_credits_detail")
public class CreditsDetail extends BaseEntity {
/** 用户id */
private Long accountId;
/** 积分变更事件 */
private String changeEvent;
/** 变更积分 ( + 表示加,- 表示减) */
private String changedCredits;
/** 当前积分 */
private BigDecimal credits;
}

View File

@@ -0,0 +1,20 @@
package com.ai.da.mapper.primary.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.List;
@EqualsAndHashCode(callSuper = true)
@Data
@TableName("t_super_resolution")
public class SuperResolution extends BaseEntity{
private String input_url;
private Integer scale;
private String output_url;
}

View File

@@ -0,0 +1,13 @@
package com.ai.da.model.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
@ApiModel("查积分的收支详情")
public class QueryIncomeOrExpenditureDTO extends QueryPageByTimeDTO{
private Boolean isIncome;
}

View File

@@ -0,0 +1,19 @@
package com.ai.da.model.dto;
import com.ai.da.model.vo.PageQueryBaseVo;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
@ApiModel("分页查询,限制时间区间")
public class QueryPageByTimeDTO extends PageQueryBaseVo {
@ApiModelProperty("开始时间 yyyy-mm-dd hh:mm:ss 可以不要时分秒")
private String startTime;
@ApiModelProperty("结束时间 yyyy-mm-dd hh:mm:ss 可以不要时分秒")
private String endTime;
}

View File

@@ -0,0 +1,27 @@
package com.ai.da.model.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import javax.validation.constraints.NotBlank;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class SuperResolutionDTO {
@NotBlank(message = "You have to select at least one image")
@ApiModelProperty("图片")
private String images;
@NotBlank(message = "You must choose the magnification")
@ApiModelProperty("放大倍数")
private Integer scale;
@ApiModelProperty("唯一id用于保持消息唯一性")
String uniqueId;
}

View File

@@ -3082,4 +3082,64 @@ public class PythonService {
return Boolean.TRUE;
}
public String superResolution(SuperResolutionDTO superResolutionDTO){
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
HashMap<String, String> content = new HashMap<>();
content.put("image_url", superResolutionDTO.getImages());
content.put("sr_xn", superResolutionDTO.getScale().toString());
content.put("task_id", superResolutionDTO.getUniqueId());
String jsonString = JSON.toJSONString(content, SerializerFeature.WriteNullStringAsEmpty);
RequestBody body = RequestBody.create(mediaType, jsonString);
Request request = new Request.Builder()
.url(accessPythonIp + ":" + 9991 + "/super-resolution/")
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
String bodyString;
try {
log.info("superResolution请求入参content###{}", JSON.toJSONString(superResolutionDTO, SerializerFeature.WriteMapNullValue));
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##superResolution异常###{}", ExceptionUtil.getThrowableList(ioException));
throw new BusinessException(ioException.getMessage());
}
// 判断是否生成失败
if (Objects.isNull(response.body())) {
log.error("PythonService##superResolution异常###{}", "response or body is empty!");
throw new BusinessException("PythonService##superResolution异常###: response or body is empty!");
} else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##superResolution异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("PythonService##superResolution异常### Response error!Response code ## " + response.code() + " ##");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
log.error(e.getMessage());
throw new BusinessException(e.getMessage());
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful");
// todo 返回数据的结构没对接好
if (result && jsonObject.get("code").equals(200)) {
log.info("superResolution##responseObject###{}", jsonObject);
return jsonObject.getJSONObject("data").get("image").toString();
}
log.info("superResolution失败###{}", jsonObject);
log.info("superResolution Exception! Code : " + jsonObject.get("code"));
//生成失败
throw new BusinessException("sr.interface.error");
}
}

View File

@@ -1,6 +1,12 @@
package com.ai.da.service;
public interface CreditsService {
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.mapper.primary.entity.CreditsDetail;
import com.ai.da.model.dto.QueryIncomeOrExpenditureDTO;
import com.baomidou.mybatisplus.extension.service.IService;
public interface CreditsService extends IService<CreditsDetail> {
void initCredits();
@@ -13,4 +19,8 @@ public interface CreditsService {
String getCredits();
void creditsRefund(Long accountId, Integer quantity);
void insertToCreditsDetail(Long accountId, String changeEvent, String credits, String changeType);
PageBaseResponse<CreditsDetail> queryCreditsDetailsPage(QueryIncomeOrExpenditureDTO queryPageByTimeDTO);
}

View File

@@ -2,7 +2,9 @@ package com.ai.da.service;
import com.ai.da.common.enums.OrderStatusEnum;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.mapper.primary.entity.OrderInfo;
import com.ai.da.model.dto.QueryPageByTimeDTO;
import com.baomidou.mybatisplus.extension.service.IService;
import java.util.List;
@@ -23,7 +25,7 @@ public interface OrderInfoService extends IService<OrderInfo> {
OrderInfo getOrderByOrderNo(String orderNo);
List<OrderInfo> getOrderByAccountId(Long accountId);
PageBaseResponse<OrderInfo> getOrderByPage(QueryPageByTimeDTO queryPageByTimeDTO);
void updateOrderNoById(Long id, String orderNo);
}

View File

@@ -3,6 +3,9 @@ package com.ai.da.service;
import com.paypal.http.exceptions.SerializeException;
import com.paypal.orders.Order;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
@@ -16,6 +19,9 @@ public interface PayPalCheckoutService {
*/
String callback(@SuppressWarnings("rawtypes") Map map);
Boolean doPost(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException;
String queryOrder(String orderNo) throws SerializeException;
Order captureOrder(String orderId) throws IOException;

View File

@@ -5,7 +5,9 @@ import org.springframework.stereotype.Service;
@Service
public interface RabbitMQService {
void publishMessage(String message);
void publishMessageToGenerate(String message);
void publishMessageToSR(String message);
Integer getMessageCount(String queueUrl);
}

View File

@@ -0,0 +1,12 @@
package com.ai.da.service;
import com.ai.da.mapper.primary.entity.SuperResolution;
import com.ai.da.model.dto.SuperResolutionDTO;
import com.baomidou.mybatisplus.extension.service.IService;
public interface SuperResolutionService extends IService<SuperResolution> {
String prepareForSR(SuperResolutionDTO superResolutionDTO);
String SR(SuperResolutionDTO superResolutionDTO);
}

View File

@@ -210,6 +210,11 @@ public class AliPayServiceImpl implements AliPayService {
orderInfoService.updateStatusByOrderNo(orderNo, OrderStatusEnum.SUCCESS);
//记录支付日志
paymentInfoService.createPaymentInfoForAliPay(params);
// 添加积分变更记录
creditsService.insertToCreditsDetail(orderByOrderNo.getAccountId(),
CreditsEventsEnum.BUY_CREDITS.getName() + "--Alipay",
CreditsEventsEnum.BUY_CREDITS.getValue(),
"positive");
// 更新积分
creditsService.buyCredits(orderByOrderNo.getAccountId(),Integer.parseInt(totalAmount) / Integer.parseInt(CreditsEventsEnum.PRICE.getValue()));
} finally {
@@ -307,6 +312,11 @@ public class AliPayServiceImpl implements AliPayService {
orderInfoService.updateStatusByOrderNo(orderNo, OrderStatusEnum.SUCCESS);
//并记录支付日志
paymentInfoService.createPaymentInfoForAliPay(alipayTradeQueryResponse);
// 添加积分变更记录
creditsService.insertToCreditsDetail(orderByOrderNo.getAccountId(),
CreditsEventsEnum.BUY_CREDITS.getName() + "--Alipay",
CreditsEventsEnum.BUY_CREDITS.getValue(),
"positive");
// 更新积分
creditsService.buyCredits(orderByOrderNo.getAccountId(),orderByOrderNo.getTotalFee() / Integer.parseInt(CreditsEventsEnum.PRICE.getValue()));
}

View File

@@ -3,17 +3,29 @@ package com.ai.da.service.impl;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.CreditsEventsEnum;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.mapper.primary.AccountMapper;
import com.ai.da.mapper.primary.CreditsDetailMapper;
import com.ai.da.mapper.primary.entity.Account;
import com.ai.da.mapper.primary.entity.CreditsDetail;
import com.ai.da.model.dto.QueryIncomeOrExpenditureDTO;
import com.ai.da.service.AccountService;
import com.ai.da.service.CreditsService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.netty.util.internal.StringUtil;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Objects;
@Service
public class CreditsServiceImpl implements CreditsService {
public class CreditsServiceImpl extends ServiceImpl<CreditsDetailMapper, CreditsDetail> implements CreditsService {
@Resource
private AccountService accountService;
@@ -88,4 +100,66 @@ public class CreditsServiceImpl implements CreditsService {
BigDecimal subtracted = existingCredits.subtract(newCredits);
accountService.updateCredits(accountId, subtracted.toString());
}
/**
* 向积分变更详细表添加记录
* @param changeEvent 导致积分变更的事件
* @param credits 变更的积分
* @param changeType 变更类型 positive->增 negative->减
*/
@Override
public void insertToCreditsDetail(Long accountId, String changeEvent, String credits, String changeType){
CreditsDetail creditsDetail = new CreditsDetail();
Account account = accountMapper.selectById(accountId);
BigDecimal finalCredits;
String changeCredits;
if ("positive".equals(changeType)){
finalCredits = account.getCredits().add(new BigDecimal(credits));
changeCredits = "+" + credits;
}else {
finalCredits = account.getCredits().subtract(new BigDecimal(credits));
changeCredits = "-" + credits;
}
creditsDetail.setAccountId(accountId);
creditsDetail.setChangeEvent(changeEvent);
creditsDetail.setChangedCredits(changeCredits);
creditsDetail.setCredits(finalCredits);
creditsDetail.setCreateTime(LocalDateTime.now());
baseMapper.insert(creditsDetail);
}
@Override
public PageBaseResponse<CreditsDetail> queryCreditsDetailsPage(QueryIncomeOrExpenditureDTO queryPageByTimeDTO){
QueryWrapper<CreditsDetail> qw = new QueryWrapper<>();
qw.eq("account_id",UserContext.getUserHolder().getId());
String startTime = queryPageByTimeDTO.getStartTime();
String endTime = queryPageByTimeDTO.getEndTime();
if (StringUtil.isNullOrEmpty(startTime)){
startTime = "2024-03-01 00:00:00";
}
if (StringUtil.isNullOrEmpty(endTime)){
LocalDateTime now = LocalDateTime.now();
DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
endTime = now.format(dateTimeFormatter);
}
if (!Objects.isNull(queryPageByTimeDTO.getIsIncome())){
if (queryPageByTimeDTO.getIsIncome()){
qw.likeRight("changed_credits","+");
}else {
qw.likeRight("changed_credits","-");
}
}
qw.between("create_time", startTime, endTime);
qw.orderByDesc("create_time");
Page<CreditsDetail> pageInfo = new Page<>(queryPageByTimeDTO.getPage(), queryPageByTimeDTO.getSize());
Page<CreditsDetail> orderInfo = baseMapper.selectPage(pageInfo, qw);
if (CollectionUtils.isEmpty(orderInfo.getRecords())) {
return PageBaseResponse.success(new Page<>());
}
return PageBaseResponse.success(orderInfo);
}
}

View File

@@ -68,13 +68,13 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Resource
private GenerateCancelMapper generateCancelMapper;
@Value("${redis.key.consumptionOrder}")
@Value("${redis.key.orderForGenerate}")
private String consumptionOrderKey;
@Value("${redis.key.cancelSet}")
@Value("${redis.key.generateCancelSet}")
private String cancelSetKey;
@Value("${redis.key.exceptionMap}")
@Value("${redis.key.generateExceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")
@@ -368,7 +368,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessage(jsonString);
rabbitMQService.publishMessageToGenerate(jsonString);
// 5、返回唯一id
return new PrepareForGenerateVO(uuid, 2 - trialsCount);

View File

@@ -4,20 +4,27 @@ package com.ai.da.service.impl;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.CreditsEventsEnum;
import com.ai.da.common.enums.OrderStatusEnum;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.utils.OrderNoUtils;
import com.ai.da.mapper.primary.OrderInfoMapper;
import com.ai.da.mapper.primary.ProductMapper;
import com.ai.da.mapper.primary.entity.OrderInfo;
import com.ai.da.model.dto.QueryPageByTimeDTO;
import com.ai.da.model.vo.AuthPrincipalVo;
import com.ai.da.service.OrderInfoService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
@Service
@@ -172,12 +179,30 @@ public class OrderInfoServiceImpl extends ServiceImpl<OrderInfoMapper, OrderInfo
return orderInfo;
}
public List<OrderInfo> getOrderByAccountId(Long accountId){
@Override
public PageBaseResponse<OrderInfo> getOrderByPage(QueryPageByTimeDTO queryPageByTimeDTO){
QueryWrapper<OrderInfo> qw = new QueryWrapper<>();
qw.eq("account_id",accountId);
qw.orderByDesc("create_time");
qw.eq("account_id",UserContext.getUserHolder().getId());
return baseMapper.selectList(qw);
String startTime = queryPageByTimeDTO.getStartTime();
String endTime = queryPageByTimeDTO.getEndTime();
if (StringUtil.isNullOrEmpty(startTime)){
startTime = "2024-02-01 00:00:00";
}
if (StringUtil.isNullOrEmpty(endTime)){
LocalDateTime now = LocalDateTime.now();
DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
endTime = now.format(dateTimeFormatter);
}
qw.between("create_time", startTime, endTime);
qw.orderByDesc("create_time");
Page<OrderInfo> pageInfo = new Page<>(queryPageByTimeDTO.getPage(), queryPageByTimeDTO.getSize());
Page<OrderInfo> orderInfo = baseMapper.selectPage(pageInfo, qw);
if (CollectionUtils.isEmpty(orderInfo.getRecords())) {
return PageBaseResponse.success(new Page<>());
}
return PageBaseResponse.success(orderInfo);
}
public void updateOrderNoById(Long id, String orderNo){

View File

@@ -7,10 +7,17 @@ import com.ai.da.common.constant.PayPalCheckoutConstant;
import com.ai.da.common.enums.*;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.common.utils.paypalRequest.AuthenticationRequest;
import com.ai.da.common.utils.paypalRequest.WebhookVerifyRequest;
import com.ai.da.mapper.primary.entity.OrderInfo;
import com.ai.da.mapper.primary.entity.RefundInfo;
import com.ai.da.service.*;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.Gson;
import com.paypal.api.payments.Event;
import com.paypal.base.Constants;
import com.paypal.base.SDKUtil;
import com.paypal.base.rest.APIContext;
import com.paypal.base.rest.PayPalRESTException;
import com.paypal.http.HttpResponse;
import com.paypal.http.exceptions.SerializeException;
import com.paypal.http.serializer.Json;
@@ -24,9 +31,19 @@ import org.apache.commons.lang3.StringUtils;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import javax.annotation.Resource;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SignatureException;
import java.util.*;
import static com.ai.da.common.constant.PayPalCheckoutConstant.*;
@@ -47,19 +64,14 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
@Resource
private PayPalClient payPalClient;
@Resource
private OrderInfoService orderInfoService;
@Resource
private PaymentInfoService paymentInfoService;
@Resource
private RefundInfoService refundsInfoService;
@Resource
private CreditsService creditsService;
@Resource
private RedisUtil redisUtil;
@@ -67,6 +79,7 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
* 创建订单的方法
*/
@Override
@Transactional(rollbackFor = Exception.class)
public HashMap<String, String> createOrder(Integer amount, String returnUrl) throws SerializeException {
// 生成订单
log.info("生成订单");
@@ -102,7 +115,6 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
String orderId = response.result().id();
orderInfoService.updateOrderNoById(orderInfo.getId(), orderId);
HashMap<String, String> returnData = new HashMap<>();
returnData.put("approve",approve);
returnData.put("orderNo",orderId);
@@ -111,6 +123,99 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
return returnData;
}
// ##Validate Webhook
public Boolean doPost(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException {
try {
String body = getBody(req);
Map webhookEvent = new ObjectMapper().readValue(body, Map.class);
HashMap<String, Object> webhookRequest = new HashMap<>();
webhookRequest.put("auth_algo", SDKUtil.validateAndGet(getHeadersInfo(req), "PAYPAL-AUTH-ALGO"));
webhookRequest.put("cert_url",SDKUtil.validateAndGet(getHeadersInfo(req), "PAYPAL-CERT-URL"));
webhookRequest.put("transmission_id",SDKUtil.validateAndGet(getHeadersInfo(req), "PAYPAL-TRANSMISSION-ID"));
webhookRequest.put("transmission_sig",SDKUtil.validateAndGet(getHeadersInfo(req), "PAYPAL-TRANSMISSION-SIG"));
webhookRequest.put("transmission_time",SDKUtil.validateAndGet(getHeadersInfo(req), "PAYPAL-TRANSMISSION-TIME"));
webhookRequest.put("webhook_id",PayPalCheckoutConstant.WEBHOOK_ID);
webhookRequest.put("webhook_event",webhookEvent);
WebhookVerifyRequest webhookVerifyRequest = new WebhookVerifyRequest();
webhookVerifyRequest.authorization(getOAuth());
webhookVerifyRequest.requestBody(webhookRequest);
// 验签
HttpResponse<HashMap> verified = payPalClient.client(MODE, clientId, clientSecret).execute(webhookVerifyRequest);
boolean verifyResult = verified.result().get("verification_status").toString().equals("SUCCESS");
if (verifyResult){
// ### Api Context
APIContext apiContext = new APIContext(clientId, clientSecret, PayPalCheckoutConstant.MODE);
// Set the webhookId that you received when you created this webhook.
apiContext.addConfiguration(Constants.PAYPAL_WEBHOOK_ID, PayPalCheckoutConstant.WEBHOOK_ID);
Boolean result = Event.validateReceivedEvent(apiContext, getHeadersInfo(
req), body);
log.info("Webhook Validated: " + result);
if (result){
// 处理订单数据
LinkedHashMap<String,LinkedHashMap<String,String>> webhookEventMap = (LinkedHashMap<String,LinkedHashMap<String,String>>) webhookEvent;
String orderId = webhookEventMap.get("resource").get("id");
processOrder(orderId);
return Boolean.TRUE;
}
}
} catch (PayPalRESTException | InvalidKeyException | NoSuchAlgorithmException | SignatureException e) {
log.error(e.getMessage());
}
return Boolean.FALSE;
}
// Simple helper method to help you extract the headers from HttpServletRequest object.
private static Map<String, String> getHeadersInfo(HttpServletRequest request) {
Map<String, String> map = new HashMap<String, String>();
@SuppressWarnings("rawtypes")
Enumeration headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String key = (String) headerNames.nextElement();
String value = request.getHeader(key);
map.put(key, value);
}
return map;
}
// Simple helper method to fetch request data as a string from HttpServletRequest object.
private static String getBody(HttpServletRequest request) throws IOException {
String body;
StringBuilder stringBuilder = new StringBuilder();
BufferedReader bufferedReader = null;
try {
InputStream inputStream = request.getInputStream();
if (inputStream != null) {
bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
char[] charBuffer = new char[128];
int bytesRead = -1;
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
stringBuilder.append(charBuffer, 0, bytesRead);
}
} else {
stringBuilder.append("");
}
} catch (IOException ex) {
throw ex;
} finally {
if (bufferedReader != null) {
try {
bufferedReader.close();
} catch (IOException ex) {
throw ex;
}
}
}
body = stringBuilder.toString();
log.info("回调参数 ===> {}", body);
return body;
}
/**
* 生成订单主体信息
*/
@@ -256,6 +361,7 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
/**
* 用户授权支付成功,进行扣款操作
*/
@Transactional(rollbackFor = Exception.class)
public Order captureOrder(String orderId) {
OrdersCaptureRequest request = new OrdersCaptureRequest(orderId);
request.requestBody(new OrderRequest());
@@ -330,6 +436,7 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
/**
* 申请退款
*/
@Transactional(rollbackFor = Exception.class)
public Boolean refundOrder(String orderId,String reason) throws IOException {
RefundInfo refundByOrderNo = refundsInfoService.createRefundByOrderNo(orderId, reason);
@@ -458,6 +565,7 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
}
// 处理当前订单
@Transactional(rollbackFor = Exception.class)
public void processOrder(String orderId){
// 1、确定当前订单是否已经被扣款
OrderInfo orderInfo = orderInfoService.getOrderByOrderNo(orderId);
@@ -473,6 +581,11 @@ public class PayPalCheckoutServiceImpl implements PayPalCheckoutService {
orderInfoService.updateStatusByOrderNo(orderId, OrderStatusEnum.SUCCESS);
//记录支付日志
paymentInfoService.createPaymentInfoForPayPal(capturedOrder);
// 添加积分变更记录
creditsService.insertToCreditsDetail(orderInfo.getAccountId(),
CreditsEventsEnum.BUY_CREDITS.getName() + "--PayPal",
CreditsEventsEnum.BUY_CREDITS.getValue(),
"positive");
// 更新积分
creditsService.buyCredits(orderInfo.getAccountId(), orderInfo.getTotalFee() / Integer.parseInt(CreditsEventsEnum.PRICE.getValue()));
}

View File

@@ -27,10 +27,14 @@ public class RabbitMQServiceImpl implements RabbitMQService {
private MQPublisher mqPublisher;
@Override
public void publishMessage(String message) {
public void publishMessageToGenerate(String message) {
mqPublisher.sendGenerateMessage(message);
}
@Override
public void publishMessageToSR(String message) {
mqPublisher.sendSRMessage(message);
}
@Override
public Integer getMessageCount(String queueUrl) {

View File

@@ -0,0 +1,122 @@
package com.ai.da.service.impl;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.CreditsEventsEnum;
import com.ai.da.common.utils.AsyncCallerUtil;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.mapper.primary.SuperResolutionMapper;
import com.ai.da.mapper.primary.entity.SuperResolution;
import com.ai.da.model.dto.SuperResolutionDTO;
import com.ai.da.python.PythonService;
import com.ai.da.service.CreditsService;
import com.ai.da.service.RabbitMQService;
import com.ai.da.service.SuperResolutionService;
import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import javax.annotation.Resource;
import java.math.BigDecimal;
import java.time.LocalDateTime;
import java.util.UUID;
@Service
public class SuperResolutionServiceImpl extends ServiceImpl<SuperResolutionMapper, SuperResolution> implements SuperResolutionService {
public static final Integer creditsConsumption = 100;
@Resource
private CreditsService creditsService;
@Resource
private RabbitMQService rabbitMQService;
@Resource
private AsyncCallerUtil asyncCallerUtil;
@Resource
private PythonService pythonService;
@Resource
private RedisUtil redisUtil;
@Value("${redis.key.orderForSR}")
private String orderForSR;
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Override
public String prepareForSR(SuperResolutionDTO superResolutionDTO) {
// 异步处理
// 判断用户当前积分是否够本次超分消耗
String credits = creditsService.getCredits();
if (new BigDecimal(credits).subtract(new BigDecimal(creditsConsumption)).compareTo(BigDecimal.ZERO) < 0){
return "Not enough Credits";
}
// 2、生成唯一id 使用uuid
String uuid = UUID.randomUUID().toString();
int num = 1;
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
redisUtil.isElementExistsInZSet(orderForSR, uuid))
&& num < 10) {
uuid = UUID.randomUUID().toString();
num++;
}
// 无依据确定的数字
if (num > 10) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
uuid = UUID.randomUUID().toString();
}
superResolutionDTO.setUniqueId(uuid);
String jsonString = JSON.toJSONString(superResolutionDTO);
// 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(orderForSR);
redisUtil.addToZSet(orderForSR, uuid, maxScore);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessageToSR(jsonString);
// 5、返回唯一id
return uuid;
}
@Transactional(rollbackFor = Exception.class)
@Override
public String SR(SuperResolutionDTO superResolutionDTO){
// 1、向模型发起请求
// String srResult = asyncCallerUtil.SR(superResolutionDTO);
String srResult = pythonService.superResolution(superResolutionDTO);
// 2、向数据库插入数据
SuperResolution superResolution = new SuperResolution();
superResolution.setInput_url(superResolutionDTO.getImages());
superResolution.setScale(superResolutionDTO.getScale());
superResolution.setOutput_url(srResult);
superResolution.setCreateTime(LocalDateTime.now());
baseMapper.insert(superResolution);
// 3、记录积分变更
creditsService.insertToCreditsDetail(UserContext.getUserHolder().getId(),
CreditsEventsEnum.SUPER_RESOLUTION.getName(),
CreditsEventsEnum.SUPER_RESOLUTION.getValue(),
"negative");
// 4、扣除积分
creditsService.creditsDecrease(UserContext.getUserHolder().getId(), CreditsEventsEnum.SUPER_RESOLUTION.getValue());
// 4、将数据存在数据库
return srResult;
}
}