特征推荐初次提交
This commit is contained in:
43
pom.xml
43
pom.xml
@@ -422,6 +422,49 @@
|
|||||||
<artifactId>bcpkix-jdk18on</artifactId>
|
<artifactId>bcpkix-jdk18on</artifactId>
|
||||||
<version>1.78.1</version>
|
<version>1.78.1</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<!-- Milvus Java SDK -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>io.milvus</groupId>
|
||||||
|
<artifactId>milvus-sdk-java</artifactId>
|
||||||
|
<version>2.3.4</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<!-- DJL (Deep Java Library) for Deep Learning -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl</groupId>
|
||||||
|
<artifactId>api</artifactId>
|
||||||
|
<version>0.24.0</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl.pytorch</groupId>
|
||||||
|
<artifactId>pytorch-engine</artifactId>
|
||||||
|
<version>0.24.0</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl.pytorch</groupId>
|
||||||
|
<artifactId>pytorch-native-cpu</artifactId>
|
||||||
|
<version>2.0.1</version>
|
||||||
|
<classifier>linux-x86_64</classifier>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl.pytorch</groupId>
|
||||||
|
<artifactId>pytorch-native-cpu</artifactId>
|
||||||
|
<version>2.0.1</version>
|
||||||
|
<classifier>win-x86_64</classifier>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl.pytorch</groupId>
|
||||||
|
<artifactId>pytorch-native-cpu</artifactId>
|
||||||
|
<version>2.0.1</version>
|
||||||
|
<classifier>osx-x86_64</classifier>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ai.djl.pytorch</groupId>
|
||||||
|
<artifactId>pytorch-native-cpu</artifactId>
|
||||||
|
<version>2.0.1</version>
|
||||||
|
<classifier>osx-aarch64</classifier>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
|
|||||||
45
src/main/java/com/ai/da/common/config/MilvusConfig.java
Normal file
45
src/main/java/com/ai/da/common/config/MilvusConfig.java
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package com.ai.da.common.config;
|
||||||
|
|
||||||
|
import io.milvus.client.MilvusServiceClient;
|
||||||
|
import io.milvus.param.ConnectParam;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Milvus 向量数据库配置
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Configuration
|
||||||
|
public class MilvusConfig {
|
||||||
|
|
||||||
|
@Value("${milvus.host:localhost}")
|
||||||
|
private String host;
|
||||||
|
|
||||||
|
@Value("${milvus.port:19530}")
|
||||||
|
private Integer port;
|
||||||
|
|
||||||
|
@Value("${milvus.username:}")
|
||||||
|
private String username;
|
||||||
|
|
||||||
|
@Value("${milvus.password:}")
|
||||||
|
private String password;
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public MilvusServiceClient milvusClient() {
|
||||||
|
ConnectParam.Builder builder = ConnectParam.newBuilder()
|
||||||
|
.withHost(host)
|
||||||
|
.withPort(port);
|
||||||
|
|
||||||
|
if (StringUtils.isNotBlank(username) && StringUtils.isNotBlank(password)) {
|
||||||
|
builder.withAuthorization(username, password);
|
||||||
|
}
|
||||||
|
|
||||||
|
MilvusServiceClient client = new MilvusServiceClient(builder.build());
|
||||||
|
log.info("Milvus client initialized: {}:{}", host, port);
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -72,6 +72,12 @@ public class SysFile implements Serializable {
|
|||||||
*/
|
*/
|
||||||
private String style;
|
private String style;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否废弃 0-否 1-是
|
||||||
|
*/
|
||||||
|
private Integer deprecated;
|
||||||
|
|
||||||
|
|
||||||
public SysFile() {
|
public SysFile() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
46
src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
Normal file
46
src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package com.ai.da.model.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐请求 DTO
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class RecommendRequestDTO {
|
||||||
|
/**
|
||||||
|
* 用户ID
|
||||||
|
*/
|
||||||
|
private Long userId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类别(如 female_skirt, male_tops)- 可选,用于辅助筛选
|
||||||
|
*/
|
||||||
|
private String category;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 风格样式(可选)- 用于辅助筛选
|
||||||
|
*/
|
||||||
|
private String style;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回数量
|
||||||
|
*/
|
||||||
|
private Integer topK = 10;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否只返回未淘汰的图片
|
||||||
|
*/
|
||||||
|
private Boolean onlyActive = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 向量搜索的候选数量(第一阶段)
|
||||||
|
* 建议值:50-100,用于后续筛选和多样性优化
|
||||||
|
*/
|
||||||
|
private Integer candidateSize = 50;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用多样性优化(避免推荐过于相似的结果)
|
||||||
|
*/
|
||||||
|
private Boolean enableDiversity = true;
|
||||||
|
}
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ import com.ai.da.model.vo.*;
|
|||||||
import com.ai.da.python.vo.*;
|
import com.ai.da.python.vo.*;
|
||||||
import com.ai.da.service.DesignHistoryService;
|
import com.ai.da.service.DesignHistoryService;
|
||||||
import com.ai.da.service.PythonTAllInfoService;
|
import com.ai.da.service.PythonTAllInfoService;
|
||||||
|
import com.ai.da.service.RecommendationService;
|
||||||
import com.ai.da.service.SysFileService;
|
import com.ai.da.service.SysFileService;
|
||||||
import com.alibaba.fastjson.*;
|
import com.alibaba.fastjson.*;
|
||||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||||
@@ -722,6 +723,9 @@ public class PythonService {
|
|||||||
@Resource
|
@Resource
|
||||||
private AttributeRetrievalMapper attributeRetrievalMapper;
|
private AttributeRetrievalMapper attributeRetrievalMapper;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private RecommendationService recommendationService;
|
||||||
|
|
||||||
private List<CollectionElement> getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) {
|
private List<CollectionElement> getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) {
|
||||||
/**
|
/**
|
||||||
* female trousers->female_pants
|
* female trousers->female_pants
|
||||||
@@ -3955,6 +3959,27 @@ public class PythonService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<String> getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) {
|
public List<String> getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) {
|
||||||
|
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
||||||
|
|
||||||
|
// 优先使用基于 Milvus 的推荐系统
|
||||||
|
try {
|
||||||
|
com.ai.da.model.dto.RecommendRequestDTO request = new com.ai.da.model.dto.RecommendRequestDTO();
|
||||||
|
request.setUserId(userHolder.getId());
|
||||||
|
request.setCategory(category);
|
||||||
|
request.setStyle(style);
|
||||||
|
request.setTopK(1);
|
||||||
|
request.setOnlyActive(true);
|
||||||
|
|
||||||
|
List<String> recommendedUrls = recommendationService.recommend(request);
|
||||||
|
if (!CollectionUtils.isEmpty(recommendedUrls)) {
|
||||||
|
log.info("使用 Milvus 推荐系统返回 {} 个结果", recommendedUrls.size());
|
||||||
|
return recommendedUrls;
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Milvus 推荐失败,降级到备用方案: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级方案1: 使用 attribute_retrieval_style 表
|
||||||
//******3.1.2版本临时使用java推荐方案去解决style未使用的问题**********
|
//******3.1.2版本临时使用java推荐方案去解决style未使用的问题**********
|
||||||
try {
|
try {
|
||||||
//使用新库attribute_retrieval_style,表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致
|
//使用新库attribute_retrieval_style,表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致
|
||||||
@@ -3974,11 +3999,11 @@ public class PythonService {
|
|||||||
return Arrays.asList(path);
|
return Arrays.asList(path);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.info("推荐失败:{}",e.getMessage());
|
log.info("attribute_retrieval_style 推荐失败:{}",e.getMessage());
|
||||||
throw new BusinessException("system.error");
|
|
||||||
}
|
}
|
||||||
//**********************end***********************************
|
//**********************end***********************************
|
||||||
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
|
||||||
|
// 降级方案2: 调用 Python 服务(原有方案)
|
||||||
|
|
||||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||||
.connectTimeout(30, TimeUnit.SECONDS)
|
.connectTimeout(30, TimeUnit.SECONDS)
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package com.ai.da.service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 特征提取服务接口
|
||||||
|
* 用于从图片 URL 提取 2048 维特征向量
|
||||||
|
*/
|
||||||
|
public interface FeatureExtractionService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从图片 URL 提取特征向量
|
||||||
|
*
|
||||||
|
* @param imageUrl 图片 URL(MinIO 路径)
|
||||||
|
* @return 2048 维特征向量,如果提取失败返回 null
|
||||||
|
*/
|
||||||
|
List<Float> extractFeature(String imageUrl);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量提取特征向量
|
||||||
|
*
|
||||||
|
* @param imageUrls 图片 URL 列表
|
||||||
|
* @return 特征向量列表,与输入顺序对应
|
||||||
|
*/
|
||||||
|
List<List<Float>> extractFeatures(List<String> imageUrls);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 计算用户偏好向量(基于用户点赞的图片向量)
|
||||||
|
* 使用平均向量作为用户偏好
|
||||||
|
*
|
||||||
|
* @param userLikedVectors 用户点赞的图片向量列表
|
||||||
|
* @return 用户偏好向量(2048维)
|
||||||
|
*/
|
||||||
|
List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors);
|
||||||
|
}
|
||||||
|
|
||||||
34
src/main/java/com/ai/da/service/RecommendationService.java
Normal file
34
src/main/java/com/ai/da/service/RecommendationService.java
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package com.ai.da.service;
|
||||||
|
|
||||||
|
import com.ai.da.model.dto.RecommendRequestDTO;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐服务接口
|
||||||
|
*/
|
||||||
|
public interface RecommendationService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据用户偏好推荐系统 sketch
|
||||||
|
*
|
||||||
|
* @param request 推荐请求
|
||||||
|
* @return 推荐的 URL 列表
|
||||||
|
*/
|
||||||
|
List<String> recommend(RecommendRequestDTO request);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 同步 t_sys_file 数据到 Milvus
|
||||||
|
* 从 t_sys_file 表读取所有系统 sketch,提取特征向量并存储到 Milvus
|
||||||
|
*/
|
||||||
|
void syncSysFileToMilvus();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 更新单个文件的向量(当文件更新时调用)
|
||||||
|
*
|
||||||
|
* @param sysFileId 系统文件ID
|
||||||
|
* @param url 文件URL
|
||||||
|
*/
|
||||||
|
void updateVector(Long sysFileId, String url);
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,220 @@
|
|||||||
|
package com.ai.da.service.impl;
|
||||||
|
|
||||||
|
import ai.djl.Application;
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.repository.zoo.ModelZoo;
|
||||||
|
import ai.djl.repository.zoo.ZooModel;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import com.ai.da.common.utils.MinioUtil;
|
||||||
|
import com.ai.da.service.FeatureExtractionService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import jakarta.annotation.PostConstruct;
|
||||||
|
import jakarta.annotation.PreDestroy;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 特征提取服务实现类
|
||||||
|
* 使用 DJL (Deep Java Library) 框架加载 ResNet50 模型提取特征向量
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class FeatureExtractionServiceImpl implements FeatureExtractionService {
|
||||||
|
|
||||||
|
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MinioUtil minioUtil;
|
||||||
|
|
||||||
|
private ZooModel<Image, float[]> model;
|
||||||
|
private Predictor<Image, float[]> predictor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化 ResNet50 模型
|
||||||
|
* 加载预训练模型并移除最后的全连接层,以获取2048维特征向量
|
||||||
|
*/
|
||||||
|
@PostConstruct
|
||||||
|
public void initModel() {
|
||||||
|
try {
|
||||||
|
log.info("开始加载 ResNet50 模型(特征提取模式,2048维)...");
|
||||||
|
|
||||||
|
// 1. 加载完整的预训练 ResNet50 模型
|
||||||
|
Criteria<Image, float[]> criteria = Criteria.builder()
|
||||||
|
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
|
||||||
|
.setTypes(Image.class, float[].class)
|
||||||
|
.optModelName("resnet50")
|
||||||
|
.optEngine("PyTorch")
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// 2. 创建特征提取 Translator
|
||||||
|
Translator<Image, float[]> translator = new ResNetFeatureTranslator();
|
||||||
|
|
||||||
|
// 3. 加载模型
|
||||||
|
model = ModelZoo.loadModel(criteria);
|
||||||
|
|
||||||
|
// 4. 创建 Predictor
|
||||||
|
predictor = model.newPredictor(translator);
|
||||||
|
|
||||||
|
log.info("ResNet50 模型加载完成(特征提取模式,2048维)");
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("加载 ResNet50 模型失败", e);
|
||||||
|
throw new RuntimeException("模型加载失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理资源
|
||||||
|
*/
|
||||||
|
@PreDestroy
|
||||||
|
public void cleanup() {
|
||||||
|
if (predictor != null) {
|
||||||
|
predictor.close();
|
||||||
|
}
|
||||||
|
if (model != null) {
|
||||||
|
model.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从图片 URL 提取特征向量
|
||||||
|
*
|
||||||
|
* @param imageUrl MinIO 路径,格式为 "bucket_name/path/to/image.jpg"
|
||||||
|
* 例如: "aida-sys-image/images/female/blouse/0825001474.jpg"
|
||||||
|
* @return 2048 维特征向量,如果提取失败返回 null
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<Float> extractFeature(String imageUrl) {
|
||||||
|
if (predictor == null) {
|
||||||
|
log.error("模型未初始化,无法提取特征");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 1. 从 MinIO 获取图片
|
||||||
|
Image image = getImageFromMinio(imageUrl);
|
||||||
|
if (image == null) {
|
||||||
|
log.warn("无法从 MinIO 获取图片: {}", imageUrl);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 使用模型提取特征
|
||||||
|
float[] featureArray = predictor.predict(image);
|
||||||
|
|
||||||
|
// 3. 转换为 List<Float>
|
||||||
|
List<Float> featureVector = new ArrayList<>(DIMENSION);
|
||||||
|
for (float value : featureArray) {
|
||||||
|
featureVector.add(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("成功提取特征向量: {}, 维度: {}", imageUrl, featureVector.size());
|
||||||
|
return featureVector;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("提取特征失败: {}", imageUrl, e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 MinIO 获取图片并转换为 DJL Image 对象
|
||||||
|
*/
|
||||||
|
private Image getImageFromMinio(String imageUrl) {
|
||||||
|
InputStream imageStream = null;
|
||||||
|
try {
|
||||||
|
// 解析路径:bucket_name/path/to/image.jpg
|
||||||
|
int index = imageUrl.indexOf("/");
|
||||||
|
if (index == -1 || index == imageUrl.length() - 1) {
|
||||||
|
log.warn("无效的图片路径格式: {}", imageUrl);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
String bucketName = imageUrl.substring(0, index);
|
||||||
|
String objectName = imageUrl.substring(index + 1);
|
||||||
|
|
||||||
|
// 从 MinIO 获取图片流
|
||||||
|
// 使用 MinioUtil 的 download 方法获取 InputStream
|
||||||
|
// download 方法可能抛出 MinioException 和 IOException
|
||||||
|
imageStream = minioUtil.download(bucketName, objectName);
|
||||||
|
if (imageStream == null) {
|
||||||
|
log.warn("无法从 MinIO 获取图片流: bucket={}, object={}", bucketName, objectName);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 DJL Image 对象
|
||||||
|
Image image = ImageFactory.getInstance().fromInputStream(imageStream);
|
||||||
|
|
||||||
|
return image;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("从 MinIO 获取图片失败: {}", imageUrl, e);
|
||||||
|
return null;
|
||||||
|
} finally {
|
||||||
|
if (imageStream != null) {
|
||||||
|
try {
|
||||||
|
imageStream.close();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("关闭图片流失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<List<Float>> extractFeatures(List<String> imageUrls) {
|
||||||
|
if (CollectionUtils.isEmpty(imageUrls)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return imageUrls.stream()
|
||||||
|
.map(this::extractFeature)
|
||||||
|
.filter(java.util.Objects::nonNull)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors) {
|
||||||
|
if (CollectionUtils.isEmpty(userLikedVectors)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
int dimension = userLikedVectors.get(0).size();
|
||||||
|
List<Float> preferenceVector = new ArrayList<>(dimension);
|
||||||
|
|
||||||
|
// 初始化为零向量
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
preferenceVector.add(0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算平均向量
|
||||||
|
for (List<Float> vector : userLikedVectors) {
|
||||||
|
if (vector.size() != dimension) {
|
||||||
|
log.warn("向量维度不匹配: 期望 {}, 实际 {}", dimension, vector.size());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
preferenceVector.set(i, preferenceVector.get(i) + vector.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 求平均值
|
||||||
|
int count = userLikedVectors.size();
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
preferenceVector.set(i, preferenceVector.get(i) / count);
|
||||||
|
}
|
||||||
|
|
||||||
|
return preferenceVector;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,679 @@
|
|||||||
|
package com.ai.da.service.impl;
|
||||||
|
|
||||||
|
import com.ai.da.common.config.MilvusConfig;
|
||||||
|
import com.ai.da.mapper.primary.SysFileMapper;
|
||||||
|
import com.ai.da.mapper.primary.UserPreferenceLogMapper;
|
||||||
|
import com.ai.da.mapper.primary.entity.SysFile;
|
||||||
|
import com.ai.da.mapper.primary.entity.UserPreferenceLogTest;
|
||||||
|
import com.ai.da.model.dto.RecommendRequestDTO;
|
||||||
|
import com.ai.da.service.FeatureExtractionService;
|
||||||
|
import com.ai.da.service.RecommendationService;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
|
import io.milvus.client.MilvusServiceClient;
|
||||||
|
import io.milvus.grpc.DataType;
|
||||||
|
import io.milvus.grpc.SearchResults;
|
||||||
|
import io.milvus.param.IndexType;
|
||||||
|
import io.milvus.param.MetricType;
|
||||||
|
import io.milvus.param.R;
|
||||||
|
import io.milvus.param.RpcStatus;
|
||||||
|
import io.milvus.param.collection.*;
|
||||||
|
import io.milvus.param.dml.InsertParam;
|
||||||
|
import io.milvus.param.dml.QueryParam;
|
||||||
|
import io.milvus.param.dml.SearchParam;
|
||||||
|
import io.milvus.param.index.CreateIndexParam;
|
||||||
|
import io.milvus.grpc.MutationResult;
|
||||||
|
import io.milvus.grpc.QueryResults;
|
||||||
|
import io.milvus.response.FieldDataWrapper;
|
||||||
|
import io.milvus.response.QueryResultsWrapper;
|
||||||
|
import io.milvus.response.SearchResultsWrapper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import jakarta.annotation.PostConstruct;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐服务实现类
|
||||||
|
* 基于 Milvus 向量数据库的推荐系统
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class RecommendationServiceImpl implements RecommendationService {
|
||||||
|
|
||||||
|
private static final String COLLECTION_NAME = "sketch_recommendation";
|
||||||
|
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
|
||||||
|
private static final String VECTOR_FIELD = "feature_vector";
|
||||||
|
private static final String URL_FIELD = "url";
|
||||||
|
private static final String STYLE_FIELD = "style";
|
||||||
|
private static final String CATEGORY_FIELD = "category";
|
||||||
|
private static final String DEPRECATED_FIELD = "deprecated";
|
||||||
|
private static final String SYS_FILE_ID_FIELD = "sys_file_id";
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private MilvusServiceClient milvusClient;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private SysFileMapper sysFileMapper;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private UserPreferenceLogMapper userPreferenceLogMapper;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private FeatureExtractionService featureExtractionService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化 Collection(如果不存在则创建)
|
||||||
|
*/
|
||||||
|
@PostConstruct
|
||||||
|
public void initCollection() {
|
||||||
|
try {
|
||||||
|
// 检查 Collection 是否存在
|
||||||
|
R<Boolean> hasCollectionR = milvusClient.hasCollection(
|
||||||
|
HasCollectionParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.build()
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!hasCollectionR.getData()) {
|
||||||
|
log.info("Collection {} 不存在,开始创建...", COLLECTION_NAME);
|
||||||
|
createCollection();
|
||||||
|
createIndex();
|
||||||
|
log.info("Collection {} 创建完成", COLLECTION_NAME);
|
||||||
|
} else {
|
||||||
|
log.info("Collection {} 已存在", COLLECTION_NAME);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("初始化 Collection 失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Collection
|
||||||
|
*/
|
||||||
|
private void createCollection() {
|
||||||
|
// 定义字段
|
||||||
|
List<FieldType> fields = Arrays.asList(
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(SYS_FILE_ID_FIELD)
|
||||||
|
.withDataType(DataType.Int64)
|
||||||
|
.withPrimaryKey(true)
|
||||||
|
.withAutoID(false)
|
||||||
|
.build(),
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(URL_FIELD)
|
||||||
|
.withDataType(DataType.VarChar)
|
||||||
|
.withMaxLength(500)
|
||||||
|
.build(),
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(STYLE_FIELD)
|
||||||
|
.withDataType(DataType.VarChar)
|
||||||
|
.withMaxLength(100)
|
||||||
|
.build(),
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(CATEGORY_FIELD)
|
||||||
|
.withDataType(DataType.VarChar)
|
||||||
|
.withMaxLength(100)
|
||||||
|
.build(),
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(DEPRECATED_FIELD)
|
||||||
|
.withDataType(DataType.Int8)
|
||||||
|
.build(),
|
||||||
|
FieldType.newBuilder()
|
||||||
|
.withName(VECTOR_FIELD)
|
||||||
|
.withDataType(DataType.FloatVector)
|
||||||
|
.withDimension(DIMENSION)
|
||||||
|
.build()
|
||||||
|
);
|
||||||
|
|
||||||
|
// 创建 Collection
|
||||||
|
CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.withDescription("系统 sketch 推荐向量库")
|
||||||
|
.withShardsNum(2)
|
||||||
|
.withFieldTypes(fields)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
R<RpcStatus> createR = milvusClient.createCollection(createParam);
|
||||||
|
if (createR.getStatus() != R.Status.Success.getCode()) {
|
||||||
|
throw new RuntimeException("创建 Collection 失败: " + createR.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建索引
|
||||||
|
*/
|
||||||
|
private void createIndex() {
|
||||||
|
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.withFieldName(VECTOR_FIELD)
|
||||||
|
.withIndexType(IndexType.IVF_FLAT)
|
||||||
|
.withMetricType(MetricType.L2)
|
||||||
|
.withExtraParam("{\"nlist\":1024}")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
R<RpcStatus> indexR = milvusClient.createIndex(indexParam);
|
||||||
|
if (indexR.getStatus() != R.Status.Success.getCode()) {
|
||||||
|
throw new RuntimeException("创建索引失败: " + indexR.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> recommend(RecommendRequestDTO request) {
|
||||||
|
try {
|
||||||
|
// 1. 从 user_preference_log_test 获取用户点赞的 path 列表
|
||||||
|
QueryWrapper<UserPreferenceLogTest> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda().eq(UserPreferenceLogTest::getAccountId, request.getUserId());
|
||||||
|
List<UserPreferenceLogTest> userLikes = userPreferenceLogMapper.selectList(queryWrapper);
|
||||||
|
|
||||||
|
// 新用户处理:如果没有点赞记录,使用降级推荐策略
|
||||||
|
if (CollectionUtils.isEmpty(userLikes)) {
|
||||||
|
log.info("用户 {} 没有点赞记录,使用新用户推荐策略(基于 style 和 category)", request.getUserId());
|
||||||
|
return recommendForNewUser(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 提取用户点赞图片的特征向量
|
||||||
|
List<String> likedPaths = userLikes.stream()
|
||||||
|
.map(UserPreferenceLogTest::getPath)
|
||||||
|
.filter(StringUtils::isNotBlank)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<List<Float>> likedVectors = featureExtractionService.extractFeatures(likedPaths);
|
||||||
|
likedVectors = likedVectors.stream()
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(likedVectors)) {
|
||||||
|
log.warn("用户 {} 无法提取特征向量", request.getUserId());
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 计算用户偏好向量(平均向量)
|
||||||
|
List<Float> userPreferenceVector = featureExtractionService.calculateUserPreferenceVector(likedVectors);
|
||||||
|
|
||||||
|
// 4. 构建搜索参数(混合推荐策略)
|
||||||
|
// 策略:向量相似度搜索为主 + 分类筛选为辅
|
||||||
|
|
||||||
|
List<String> searchOutputFields = Arrays.asList(URL_FIELD, STYLE_FIELD, CATEGORY_FIELD);
|
||||||
|
List<List<Float>> searchVectors = Collections.singletonList(userPreferenceVector);
|
||||||
|
|
||||||
|
// 构建过滤表达式(基础筛选)
|
||||||
|
StringBuilder exprBuilder = new StringBuilder();
|
||||||
|
List<String> conditions = new ArrayList<>();
|
||||||
|
|
||||||
|
// 基础筛选:只返回未淘汰的图片
|
||||||
|
if (request.getOnlyActive() != null && request.getOnlyActive()) {
|
||||||
|
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 可选筛选:风格(如果指定)
|
||||||
|
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||||
|
conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
|
||||||
|
}
|
||||||
|
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||||
|
conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!conditions.isEmpty()) {
|
||||||
|
exprBuilder.append(String.join(" && ", conditions));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一阶段:向量搜索获取大量候选(用于后续筛选和多样性优化)
|
||||||
|
int candidateSize = request.getCandidateSize() != null ?
|
||||||
|
Math.max(request.getCandidateSize(), request.getTopK() * 3) :
|
||||||
|
Math.max(50, request.getTopK() * 5);
|
||||||
|
|
||||||
|
SearchParam searchParam = SearchParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.withMetricType(MetricType.L2) // L2距离,越小越相似
|
||||||
|
.withOutFields(searchOutputFields)
|
||||||
|
.withTopK(candidateSize) // 获取更多候选用于后续处理
|
||||||
|
.withVectors(searchVectors)
|
||||||
|
.withVectorFieldName(VECTOR_FIELD)
|
||||||
|
.withExpr(exprBuilder.length() > 0 ? exprBuilder.toString() : null)
|
||||||
|
.withParams("{\"nprobe\":10}")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// 5. 执行搜索
|
||||||
|
R<SearchResults> searchR = milvusClient.search(searchParam);
|
||||||
|
if (searchR.getStatus() != R.Status.Success.getCode()) {
|
||||||
|
log.error("搜索失败: {}", searchR.getMessage());
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 解析结果并进行后处理
|
||||||
|
SearchResultsWrapper wrapper = new SearchResultsWrapper(searchR.getData().getResults());
|
||||||
|
|
||||||
|
// 6.1 提取候选结果(包含相似度分数)
|
||||||
|
List<RecommendationCandidate> candidates = new ArrayList<>();
|
||||||
|
FieldDataWrapper urlFieldWrapper = wrapper.getFieldWrapper(URL_FIELD);
|
||||||
|
FieldDataWrapper styleFieldWrapper = wrapper.getFieldWrapper(STYLE_FIELD);
|
||||||
|
List<SearchResultsWrapper.IDScore> idScores = wrapper.getIDScore(0);
|
||||||
|
|
||||||
|
for (int i = 0; i < idScores.size(); i++) {
|
||||||
|
String url = (String) urlFieldWrapper.getFieldData().get(i);
|
||||||
|
String style = (String) styleFieldWrapper.getFieldData().get(i);
|
||||||
|
Float score = idScores.get(i).getScore(); // L2距离,越小越相似
|
||||||
|
|
||||||
|
if (StringUtils.isNotBlank(url)) {
|
||||||
|
candidates.add(new RecommendationCandidate(url, style, score));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// // 6.2 分类筛选(如果指定了 category)
|
||||||
|
// if (StringUtils.isNotBlank(request.getCategory())) {
|
||||||
|
// candidates = filterByCategory(candidates, request.getCategory());
|
||||||
|
// }
|
||||||
|
|
||||||
|
// 6.3 多样性优化(避免推荐过于相似的结果)
|
||||||
|
if (request.getEnableDiversity() != null && request.getEnableDiversity()) {
|
||||||
|
candidates = diversifyResults(candidates, request.getTopK() != null ? request.getTopK() : 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6.4 取 Top-K
|
||||||
|
int topK = request.getTopK() != null ? request.getTopK() : 1;
|
||||||
|
List<String> recommendedUrls = candidates.stream()
|
||||||
|
.limit(topK)
|
||||||
|
.map(RecommendationCandidate::getUrl)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
log.info("为用户 {} 推荐了 {} 个结果(从 {} 个候选中筛选)",
|
||||||
|
request.getUserId(), recommendedUrls.size(), idScores.size());
|
||||||
|
return recommendedUrls;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("推荐失败", e);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void syncSysFileToMilvus() {
|
||||||
|
log.info("开始同步 t_sys_file 数据到 Milvus");
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 1. 加载 Collection
|
||||||
|
LoadCollectionParam loadParam = LoadCollectionParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.build();
|
||||||
|
milvusClient.loadCollection(loadParam);
|
||||||
|
|
||||||
|
// 2. 查询所有系统 sketch
|
||||||
|
QueryWrapper<SysFile> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda()
|
||||||
|
.eq(SysFile::getLevel1Type, "Images").eq(SysFile::getDeprecated, 0)
|
||||||
|
.isNotNull(SysFile::getUrl)
|
||||||
|
.isNotNull(SysFile::getStyle)
|
||||||
|
.ne(SysFile::getUrl, "");
|
||||||
|
List<SysFile> sysFiles = sysFileMapper.selectList(queryWrapper);
|
||||||
|
|
||||||
|
log.info("找到 {} 个系统文件需要同步", sysFiles.size());
|
||||||
|
|
||||||
|
int successCount = 0;
|
||||||
|
int failCount = 0;
|
||||||
|
int batchSize = 100;
|
||||||
|
|
||||||
|
// 3. 批量处理
|
||||||
|
for (int i = 0; i < sysFiles.size(); i += batchSize) {
|
||||||
|
int end = Math.min(i + batchSize, sysFiles.size());
|
||||||
|
List<SysFile> batch = sysFiles.subList(i, end);
|
||||||
|
|
||||||
|
List<Long> ids = new ArrayList<>();
|
||||||
|
List<String> urls = new ArrayList<>();
|
||||||
|
List<String> styles = new ArrayList<>();
|
||||||
|
List<String> categories = new ArrayList<>();
|
||||||
|
List<Byte> deprecatedList = new ArrayList<>();
|
||||||
|
List<List<Float>> vectors = new ArrayList<>();
|
||||||
|
|
||||||
|
for (SysFile sysFile : batch) {
|
||||||
|
try {
|
||||||
|
// 提取特征向量
|
||||||
|
List<Float> vector = featureExtractionService.extractFeature(sysFile.getUrl());
|
||||||
|
if (vector == null || vector.size() != DIMENSION) {
|
||||||
|
log.warn("文件 {} 特征提取失败,跳过", sysFile.getUrl());
|
||||||
|
failCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ids.add(sysFile.getId());
|
||||||
|
urls.add(sysFile.getUrl());
|
||||||
|
styles.add(sysFile.getStyle() != null ? sysFile.getStyle() : "");
|
||||||
|
categories.add(resolveCategory(sysFile));
|
||||||
|
deprecatedList.add((byte) (sysFile.getDeprecated() != null ? sysFile.getDeprecated() : 0));
|
||||||
|
vectors.add(vector);
|
||||||
|
successCount++;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("处理文件 {} 失败", sysFile.getUrl(), e);
|
||||||
|
failCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量插入
|
||||||
|
if (!ids.isEmpty()) {
|
||||||
|
List<InsertParam.Field> fields = Arrays.asList(
|
||||||
|
new InsertParam.Field(SYS_FILE_ID_FIELD, ids),
|
||||||
|
new InsertParam.Field(URL_FIELD, urls),
|
||||||
|
new InsertParam.Field(STYLE_FIELD, styles),
|
||||||
|
new InsertParam.Field(DEPRECATED_FIELD, deprecatedList),
|
||||||
|
new InsertParam.Field(CATEGORY_FIELD, categories),
|
||||||
|
new InsertParam.Field(VECTOR_FIELD, vectors)
|
||||||
|
);
|
||||||
|
|
||||||
|
InsertParam insertParam = InsertParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.withFields(fields)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
R<MutationResult> insertR = milvusClient.insert(insertParam);
|
||||||
|
if (insertR.getStatus() == R.Status.Success.getCode()) {
|
||||||
|
log.info("批量插入成功: {}/{}", end, sysFiles.size());
|
||||||
|
} else {
|
||||||
|
log.error("批量插入失败: {}", insertR.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 刷新数据
|
||||||
|
FlushParam flushParam = FlushParam.newBuilder()
|
||||||
|
.withCollectionNames(Collections.singletonList(COLLECTION_NAME))
|
||||||
|
.build();
|
||||||
|
milvusClient.flush(flushParam);
|
||||||
|
|
||||||
|
log.info("同步完成: 成功 {}, 失败 {}", successCount, failCount);
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("同步数据到 Milvus 失败", e);
|
||||||
|
throw new RuntimeException("同步失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateVector(Long sysFileId, String url) {
|
||||||
|
// TODO: 实现单个向量更新逻辑
|
||||||
|
// 需要先删除旧向量,再插入新向量
|
||||||
|
log.info("更新向量: sysFileId={}, url={}", sysFileId, url);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 新用户推荐策略:当用户没有点赞记录时,根据 style 和 category 随机推荐
|
||||||
|
*
|
||||||
|
* @param request 推荐请求
|
||||||
|
* @return 推荐的 URL 列表
|
||||||
|
*/
|
||||||
|
private List<String> recommendForNewUser(RecommendRequestDTO request) {
|
||||||
|
try {
|
||||||
|
int topK = request.getTopK() != null ? request.getTopK() : 10;
|
||||||
|
|
||||||
|
List<String> milvusResults = recommendFromMilvusForNewUser(request, topK);
|
||||||
|
if (CollectionUtils.isEmpty(milvusResults)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
Collections.shuffle(milvusResults);
|
||||||
|
return milvusResults.stream().limit(topK).collect(Collectors.toList());
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("新用户推荐失败", e);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 Milvus 中为新用户推荐(根据 style 和 category 查询)
|
||||||
|
*/
|
||||||
|
private List<String> recommendFromMilvusForNewUser(RecommendRequestDTO request, int limit) {
|
||||||
|
try {
|
||||||
|
// 构建过滤表达式
|
||||||
|
StringBuilder exprBuilder = new StringBuilder();
|
||||||
|
List<String> conditions = new ArrayList<>();
|
||||||
|
|
||||||
|
// 基础筛选:只返回未淘汰的图片
|
||||||
|
if (request.getOnlyActive() != null && request.getOnlyActive()) {
|
||||||
|
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 风格筛选(如果指定)
|
||||||
|
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||||
|
conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
|
||||||
|
}
|
||||||
|
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||||
|
conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (conditions.isEmpty()) {
|
||||||
|
// 如果没有筛选条件,查询所有未淘汰的
|
||||||
|
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
exprBuilder.append(String.join(" && ", conditions));
|
||||||
|
|
||||||
|
// 从 Milvus 查询符合条件的记录
|
||||||
|
QueryParam queryParam = QueryParam.newBuilder()
|
||||||
|
.withCollectionName(COLLECTION_NAME)
|
||||||
|
.withExpr(exprBuilder.toString())
|
||||||
|
.withOutFields(Arrays.asList(URL_FIELD))
|
||||||
|
.withLimit((long) limit)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
R<QueryResults> queryR = milvusClient.query(queryParam);
|
||||||
|
if (queryR.getStatus() != R.Status.Success.getCode()) {
|
||||||
|
log.warn("从 Milvus 查询失败: {}", queryR.getMessage());
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析结果
|
||||||
|
List<String> urls = new ArrayList<>();
|
||||||
|
QueryResultsWrapper wrapper = new QueryResultsWrapper(queryR.getData());
|
||||||
|
for (int i = 0; i < wrapper.getRowRecords().size(); i++) {
|
||||||
|
String url = (String) wrapper.getFieldWrapper(URL_FIELD).getFieldData().get(i);
|
||||||
|
if (StringUtils.isNotBlank(url)) {
|
||||||
|
urls.add(url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// // 如果指定了 category,进行筛选
|
||||||
|
// if (StringUtils.isNotBlank(request.getCategory()) && !urls.isEmpty()) {
|
||||||
|
// urls = urls.stream()
|
||||||
|
// .filter(url -> {
|
||||||
|
// // 从 URL 中提取类别信息
|
||||||
|
// // URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||||
|
// String[] parts = url.split("/");
|
||||||
|
// if (parts.length >= 4) {
|
||||||
|
// String sex = parts[parts.length - 3];
|
||||||
|
// String cat = parts[parts.length - 2];
|
||||||
|
// String expectedCategory = sex + "_" + cat;
|
||||||
|
// return expectedCategory.equals(request.getCategory());
|
||||||
|
// }
|
||||||
|
// return false;
|
||||||
|
// })
|
||||||
|
// .collect(Collectors.toList());
|
||||||
|
// }
|
||||||
|
|
||||||
|
log.info("从 Milvus 为新用户查询到 {} 条记录", urls.size());
|
||||||
|
return urls;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("从 Milvus 查询新用户推荐失败", e);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 MySQL 中为新用户推荐(根据 style 和 category 查询)
|
||||||
|
*/
|
||||||
|
private List<String> recommendFromMySQLForNewUser(RecommendRequestDTO request, int topK) {
|
||||||
|
try {
|
||||||
|
QueryWrapper<SysFile> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda()
|
||||||
|
.eq(SysFile::getLevel1Type, "Images")
|
||||||
|
.isNotNull(SysFile::getUrl)
|
||||||
|
.ne(SysFile::getUrl, "");
|
||||||
|
|
||||||
|
// 风格筛选(如果指定)
|
||||||
|
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||||
|
queryWrapper.lambda().eq(SysFile::getStyle, request.getStyle());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 类别筛选(如果指定)
|
||||||
|
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||||
|
// category 格式: female_skirt, male_tops 等
|
||||||
|
String[] parts = request.getCategory().split("_");
|
||||||
|
if (parts.length == 2) {
|
||||||
|
String sex = parts[0]; // female 或 male
|
||||||
|
String category = parts[1]; // skirt, tops 等
|
||||||
|
|
||||||
|
// 从 URL 中筛选类别
|
||||||
|
// URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||||
|
queryWrapper.lambda().like(SysFile::getUrl, "/" + sex + "/" + category + "/");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询所有符合条件的记录
|
||||||
|
List<SysFile> sysFiles = sysFileMapper.selectList(queryWrapper);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(sysFiles)) {
|
||||||
|
log.warn("MySQL 中未找到符合条件的记录");
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 随机选择
|
||||||
|
Collections.shuffle(sysFiles);
|
||||||
|
|
||||||
|
List<String> urls = sysFiles.stream()
|
||||||
|
.map(SysFile::getUrl)
|
||||||
|
.filter(StringUtils::isNotBlank)
|
||||||
|
.limit(topK)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
log.info("从 MySQL 为新用户查询到 {} 条记录", urls.size());
|
||||||
|
return urls;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("从 MySQL 查询新用户推荐失败", e);
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据类别筛选候选结果
|
||||||
|
*/
|
||||||
|
private List<RecommendationCandidate> filterByCategory(
|
||||||
|
List<RecommendationCandidate> candidates, String category) {
|
||||||
|
// 从 URL 中提取类别信息
|
||||||
|
// URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||||
|
return candidates.stream()
|
||||||
|
.filter(candidate -> {
|
||||||
|
String url = candidate.getUrl();
|
||||||
|
if (StringUtils.isBlank(url) || !url.contains("/")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// 提取类别部分
|
||||||
|
String[] parts = url.split("/");
|
||||||
|
if (parts.length < 4) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
String sex = parts[parts.length - 3]; // female 或 male
|
||||||
|
String cat = parts[parts.length - 2]; // category
|
||||||
|
String expectedCategory = sex + "_" + cat;
|
||||||
|
return expectedCategory.equals(category);
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 多样性优化:避免推荐过于相似的结果
|
||||||
|
* 使用聚类或相似度阈值来增加多样性
|
||||||
|
*/
|
||||||
|
private List<RecommendationCandidate> diversifyResults(
|
||||||
|
List<RecommendationCandidate> candidates, int targetSize) {
|
||||||
|
if (candidates.size() <= targetSize) {
|
||||||
|
return candidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单策略:按相似度分数排序,然后每隔几个取一个,增加多样性
|
||||||
|
// 更复杂的策略可以使用聚类算法
|
||||||
|
|
||||||
|
List<RecommendationCandidate> diversified = new ArrayList<>();
|
||||||
|
int step = Math.max(1, candidates.size() / targetSize);
|
||||||
|
|
||||||
|
for (int i = 0; i < candidates.size() && diversified.size() < targetSize; i += step) {
|
||||||
|
diversified.add(candidates.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果还不够,补充剩余的结果
|
||||||
|
if (diversified.size() < targetSize) {
|
||||||
|
for (RecommendationCandidate candidate : candidates) {
|
||||||
|
if (diversified.size() >= targetSize) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!diversified.contains(candidate)) {
|
||||||
|
diversified.add(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return diversified;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐候选结果内部类
|
||||||
|
*/
|
||||||
|
private static class RecommendationCandidate {
|
||||||
|
private final String url;
|
||||||
|
private final String style;
|
||||||
|
private final Float similarityScore; // L2距离,越小越相似
|
||||||
|
|
||||||
|
public RecommendationCandidate(String url, String style, Float similarityScore) {
|
||||||
|
this.url = url;
|
||||||
|
this.style = style;
|
||||||
|
this.similarityScore = similarityScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getUrl() {
|
||||||
|
return url;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getStyle() {
|
||||||
|
return style;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Float getSimilarityScore() {
|
||||||
|
return similarityScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
RecommendationCandidate that = (RecommendationCandidate) o;
|
||||||
|
return Objects.equals(url, that.url);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveCategory(SysFile sysFile) {
|
||||||
|
String level3 = sysFile.getLevel3Type();
|
||||||
|
String level2 = sysFile.getLevel2Type();
|
||||||
|
if (StringUtils.isNotBlank(level3) && StringUtils.isNotBlank(level2)) {
|
||||||
|
return (level3 + "_" + level2).toLowerCase();
|
||||||
|
}
|
||||||
|
String url = sysFile.getUrl();
|
||||||
|
if (StringUtils.isNotBlank(url)) {
|
||||||
|
String[] parts = url.split("/");
|
||||||
|
if (parts.length >= 4) {
|
||||||
|
String sex = parts[parts.length - 3];
|
||||||
|
String category = parts[parts.length - 2];
|
||||||
|
return (sex + "_" + category).toLowerCase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package com.ai.da.service.impl;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.transform.Normalize;
|
||||||
|
import ai.djl.modality.cv.transform.Resize;
|
||||||
|
import ai.djl.modality.cv.transform.ToTensor;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.NDManager;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Pipeline;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ResNet 特征提取 Translator
|
||||||
|
* 移除最后的全连接层,只保留特征提取部分(2048维)
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class ResNetFeatureTranslator implements Translator<Image, float[]> {
|
||||||
|
|
||||||
|
private static final int IMAGE_SIZE = 224;
|
||||||
|
private static final Pipeline PIPELINE = new Pipeline()
|
||||||
|
.add(new Resize(IMAGE_SIZE, IMAGE_SIZE))
|
||||||
|
.add(new ToTensor())
|
||||||
|
.add(new Normalize(
|
||||||
|
new float[]{0.485f, 0.456f, 0.406f},
|
||||||
|
new float[]{0.229f, 0.224f, 0.225f}
|
||||||
|
));
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||||
|
NDManager manager = ctx.getNDManager();
|
||||||
|
NDArray array = input.toNDArray(manager);
|
||||||
|
NDList ndList = new NDList(array);
|
||||||
|
return PIPELINE.transform(ndList);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray output = list.singletonOrThrow();
|
||||||
|
|
||||||
|
// ResNet50 特征层输出形状应该是:
|
||||||
|
// - [1, 2048, 1, 1] - 全局平均池化后的特征(需要 flatten)
|
||||||
|
// - [1, 2048] - 已 flatten 的特征层输出
|
||||||
|
|
||||||
|
long[] shape = output.getShape().getShape();
|
||||||
|
|
||||||
|
// 如果是 4 维,说明是 [1, 2048, 1, 1],需要 flatten
|
||||||
|
if (shape.length == 4) {
|
||||||
|
output = output.flatten();
|
||||||
|
}
|
||||||
|
// 如果是 2 维,检查维度是否正确
|
||||||
|
else if (shape.length == 2) {
|
||||||
|
if (shape[1] == 1000) {
|
||||||
|
// 如果仍然是分类层输出(1000维),说明模型加载有问题
|
||||||
|
log.error("模型输出仍然是分类层(1000维),模型加载可能失败");
|
||||||
|
throw new IllegalStateException("模型输出维度错误:期望 2048 维,实际 1000 维。请检查模型是否正确移除了全连接层。");
|
||||||
|
}
|
||||||
|
// 如果已经是 [1, 2048],直接使用
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为 float 数组
|
||||||
|
float[] result = output.toFloatArray();
|
||||||
|
|
||||||
|
// 验证维度必须是 2048
|
||||||
|
if (result.length != 2048) {
|
||||||
|
log.error("特征向量维度错误: 期望 2048, 实际 {}", result.length);
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format("特征向量维度错误: 期望 2048, 实际 %d", result.length)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return Batchifier.STACK;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user