特征推荐初次提交
This commit is contained in:
43
pom.xml
43
pom.xml
@@ -422,6 +422,49 @@
|
||||
<artifactId>bcpkix-jdk18on</artifactId>
|
||||
<version>1.78.1</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Milvus Java SDK -->
|
||||
<dependency>
|
||||
<groupId>io.milvus</groupId>
|
||||
<artifactId>milvus-sdk-java</artifactId>
|
||||
<version>2.3.4</version>
|
||||
</dependency>
|
||||
|
||||
<!-- DJL (Deep Java Library) for Deep Learning -->
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>api</artifactId>
|
||||
<version>0.24.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-engine</artifactId>
|
||||
<version>0.24.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-native-cpu</artifactId>
|
||||
<version>2.0.1</version>
|
||||
<classifier>linux-x86_64</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-native-cpu</artifactId>
|
||||
<version>2.0.1</version>
|
||||
<classifier>win-x86_64</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-native-cpu</artifactId>
|
||||
<version>2.0.1</version>
|
||||
<classifier>osx-x86_64</classifier>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-native-cpu</artifactId>
|
||||
<version>2.0.1</version>
|
||||
<classifier>osx-aarch64</classifier>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
||||
45
src/main/java/com/ai/da/common/config/MilvusConfig.java
Normal file
45
src/main/java/com/ai/da/common/config/MilvusConfig.java
Normal file
@@ -0,0 +1,45 @@
|
||||
package com.ai.da.common.config;
|
||||
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.param.ConnectParam;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
/**
|
||||
* Milvus 向量数据库配置
|
||||
*/
|
||||
@Slf4j
|
||||
@Configuration
|
||||
public class MilvusConfig {
|
||||
|
||||
@Value("${milvus.host:localhost}")
|
||||
private String host;
|
||||
|
||||
@Value("${milvus.port:19530}")
|
||||
private Integer port;
|
||||
|
||||
@Value("${milvus.username:}")
|
||||
private String username;
|
||||
|
||||
@Value("${milvus.password:}")
|
||||
private String password;
|
||||
|
||||
@Bean
|
||||
public MilvusServiceClient milvusClient() {
|
||||
ConnectParam.Builder builder = ConnectParam.newBuilder()
|
||||
.withHost(host)
|
||||
.withPort(port);
|
||||
|
||||
if (StringUtils.isNotBlank(username) && StringUtils.isNotBlank(password)) {
|
||||
builder.withAuthorization(username, password);
|
||||
}
|
||||
|
||||
MilvusServiceClient client = new MilvusServiceClient(builder.build());
|
||||
log.info("Milvus client initialized: {}:{}", host, port);
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,12 @@ public class SysFile implements Serializable {
|
||||
*/
|
||||
private String style;
|
||||
|
||||
/**
|
||||
* 是否废弃 0-否 1-是
|
||||
*/
|
||||
private Integer deprecated;
|
||||
|
||||
|
||||
public SysFile() {
|
||||
}
|
||||
|
||||
|
||||
46
src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
Normal file
46
src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
Normal file
@@ -0,0 +1,46 @@
|
||||
package com.ai.da.model.dto;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 推荐请求 DTO
|
||||
*/
|
||||
@Data
|
||||
public class RecommendRequestDTO {
|
||||
/**
|
||||
* 用户ID
|
||||
*/
|
||||
private Long userId;
|
||||
|
||||
/**
|
||||
* 类别(如 female_skirt, male_tops)- 可选,用于辅助筛选
|
||||
*/
|
||||
private String category;
|
||||
|
||||
/**
|
||||
* 风格样式(可选)- 用于辅助筛选
|
||||
*/
|
||||
private String style;
|
||||
|
||||
/**
|
||||
* 返回数量
|
||||
*/
|
||||
private Integer topK = 10;
|
||||
|
||||
/**
|
||||
* 是否只返回未淘汰的图片
|
||||
*/
|
||||
private Boolean onlyActive = true;
|
||||
|
||||
/**
|
||||
* 向量搜索的候选数量(第一阶段)
|
||||
* 建议值:50-100,用于后续筛选和多样性优化
|
||||
*/
|
||||
private Integer candidateSize = 50;
|
||||
|
||||
/**
|
||||
* 是否启用多样性优化(避免推荐过于相似的结果)
|
||||
*/
|
||||
private Boolean enableDiversity = true;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import com.ai.da.model.vo.*;
|
||||
import com.ai.da.python.vo.*;
|
||||
import com.ai.da.service.DesignHistoryService;
|
||||
import com.ai.da.service.PythonTAllInfoService;
|
||||
import com.ai.da.service.RecommendationService;
|
||||
import com.ai.da.service.SysFileService;
|
||||
import com.alibaba.fastjson.*;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
@@ -722,6 +723,9 @@ public class PythonService {
|
||||
@Resource
|
||||
private AttributeRetrievalMapper attributeRetrievalMapper;
|
||||
|
||||
@Resource
|
||||
private RecommendationService recommendationService;
|
||||
|
||||
private List<CollectionElement> getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) {
|
||||
/**
|
||||
* female trousers->female_pants
|
||||
@@ -3955,6 +3959,27 @@ public class PythonService {
|
||||
}
|
||||
|
||||
public List<String> getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) {
|
||||
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
||||
|
||||
// 优先使用基于 Milvus 的推荐系统
|
||||
try {
|
||||
com.ai.da.model.dto.RecommendRequestDTO request = new com.ai.da.model.dto.RecommendRequestDTO();
|
||||
request.setUserId(userHolder.getId());
|
||||
request.setCategory(category);
|
||||
request.setStyle(style);
|
||||
request.setTopK(1);
|
||||
request.setOnlyActive(true);
|
||||
|
||||
List<String> recommendedUrls = recommendationService.recommend(request);
|
||||
if (!CollectionUtils.isEmpty(recommendedUrls)) {
|
||||
log.info("使用 Milvus 推荐系统返回 {} 个结果", recommendedUrls.size());
|
||||
return recommendedUrls;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("Milvus 推荐失败,降级到备用方案: {}", e.getMessage());
|
||||
}
|
||||
|
||||
// 降级方案1: 使用 attribute_retrieval_style 表
|
||||
//******3.1.2版本临时使用java推荐方案去解决style未使用的问题**********
|
||||
try {
|
||||
//使用新库attribute_retrieval_style,表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致
|
||||
@@ -3974,11 +3999,11 @@ public class PythonService {
|
||||
return Arrays.asList(path);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.info("推荐失败:{}",e.getMessage());
|
||||
throw new BusinessException("system.error");
|
||||
log.info("attribute_retrieval_style 推荐失败:{}",e.getMessage());
|
||||
}
|
||||
//**********************end***********************************
|
||||
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
||||
|
||||
// 降级方案2: 调用 Python 服务(原有方案)
|
||||
|
||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.ai.da.service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 特征提取服务接口
|
||||
* 用于从图片 URL 提取 2048 维特征向量
|
||||
*/
|
||||
public interface FeatureExtractionService {
|
||||
|
||||
/**
|
||||
* 从图片 URL 提取特征向量
|
||||
*
|
||||
* @param imageUrl 图片 URL(MinIO 路径)
|
||||
* @return 2048 维特征向量,如果提取失败返回 null
|
||||
*/
|
||||
List<Float> extractFeature(String imageUrl);
|
||||
|
||||
/**
|
||||
* 批量提取特征向量
|
||||
*
|
||||
* @param imageUrls 图片 URL 列表
|
||||
* @return 特征向量列表,与输入顺序对应
|
||||
*/
|
||||
List<List<Float>> extractFeatures(List<String> imageUrls);
|
||||
|
||||
/**
|
||||
* 计算用户偏好向量(基于用户点赞的图片向量)
|
||||
* 使用平均向量作为用户偏好
|
||||
*
|
||||
* @param userLikedVectors 用户点赞的图片向量列表
|
||||
* @return 用户偏好向量(2048维)
|
||||
*/
|
||||
List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors);
|
||||
}
|
||||
|
||||
34
src/main/java/com/ai/da/service/RecommendationService.java
Normal file
34
src/main/java/com/ai/da/service/RecommendationService.java
Normal file
@@ -0,0 +1,34 @@
|
||||
package com.ai.da.service;
|
||||
|
||||
import com.ai.da.model.dto.RecommendRequestDTO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 推荐服务接口
|
||||
*/
|
||||
public interface RecommendationService {
|
||||
|
||||
/**
|
||||
* 根据用户偏好推荐系统 sketch
|
||||
*
|
||||
* @param request 推荐请求
|
||||
* @return 推荐的 URL 列表
|
||||
*/
|
||||
List<String> recommend(RecommendRequestDTO request);
|
||||
|
||||
/**
|
||||
* 同步 t_sys_file 数据到 Milvus
|
||||
* 从 t_sys_file 表读取所有系统 sketch,提取特征向量并存储到 Milvus
|
||||
*/
|
||||
void syncSysFileToMilvus();
|
||||
|
||||
/**
|
||||
* 更新单个文件的向量(当文件更新时调用)
|
||||
*
|
||||
* @param sysFileId 系统文件ID
|
||||
* @param url 文件URL
|
||||
*/
|
||||
void updateVector(Long sysFileId, String url);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import ai.djl.Application;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelZoo;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.training.util.ProgressBar;
|
||||
import ai.djl.translate.Translator;
|
||||
import com.ai.da.common.utils.MinioUtil;
|
||||
import com.ai.da.service.FeatureExtractionService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.PreDestroy;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 特征提取服务实现类
|
||||
* 使用 DJL (Deep Java Library) 框架加载 ResNet50 模型提取特征向量
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
public class FeatureExtractionServiceImpl implements FeatureExtractionService {
|
||||
|
||||
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
|
||||
|
||||
@Autowired
|
||||
private MinioUtil minioUtil;
|
||||
|
||||
private ZooModel<Image, float[]> model;
|
||||
private Predictor<Image, float[]> predictor;
|
||||
|
||||
/**
|
||||
* 初始化 ResNet50 模型
|
||||
* 加载预训练模型并移除最后的全连接层,以获取2048维特征向量
|
||||
*/
|
||||
@PostConstruct
|
||||
public void initModel() {
|
||||
try {
|
||||
log.info("开始加载 ResNet50 模型(特征提取模式,2048维)...");
|
||||
|
||||
// 1. 加载完整的预训练 ResNet50 模型
|
||||
Criteria<Image, float[]> criteria = Criteria.builder()
|
||||
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
|
||||
.setTypes(Image.class, float[].class)
|
||||
.optModelName("resnet50")
|
||||
.optEngine("PyTorch")
|
||||
.optProgress(new ProgressBar())
|
||||
.build();
|
||||
|
||||
// 2. 创建特征提取 Translator
|
||||
Translator<Image, float[]> translator = new ResNetFeatureTranslator();
|
||||
|
||||
// 3. 加载模型
|
||||
model = ModelZoo.loadModel(criteria);
|
||||
|
||||
// 4. 创建 Predictor
|
||||
predictor = model.newPredictor(translator);
|
||||
|
||||
log.info("ResNet50 模型加载完成(特征提取模式,2048维)");
|
||||
} catch (Exception e) {
|
||||
log.error("加载 ResNet50 模型失败", e);
|
||||
throw new RuntimeException("模型加载失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理资源
|
||||
*/
|
||||
@PreDestroy
|
||||
public void cleanup() {
|
||||
if (predictor != null) {
|
||||
predictor.close();
|
||||
}
|
||||
if (model != null) {
|
||||
model.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从图片 URL 提取特征向量
|
||||
*
|
||||
* @param imageUrl MinIO 路径,格式为 "bucket_name/path/to/image.jpg"
|
||||
* 例如: "aida-sys-image/images/female/blouse/0825001474.jpg"
|
||||
* @return 2048 维特征向量,如果提取失败返回 null
|
||||
*/
|
||||
@Override
|
||||
public List<Float> extractFeature(String imageUrl) {
|
||||
if (predictor == null) {
|
||||
log.error("模型未初始化,无法提取特征");
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. 从 MinIO 获取图片
|
||||
Image image = getImageFromMinio(imageUrl);
|
||||
if (image == null) {
|
||||
log.warn("无法从 MinIO 获取图片: {}", imageUrl);
|
||||
return null;
|
||||
}
|
||||
|
||||
// 2. 使用模型提取特征
|
||||
float[] featureArray = predictor.predict(image);
|
||||
|
||||
// 3. 转换为 List<Float>
|
||||
List<Float> featureVector = new ArrayList<>(DIMENSION);
|
||||
for (float value : featureArray) {
|
||||
featureVector.add(value);
|
||||
}
|
||||
|
||||
log.debug("成功提取特征向量: {}, 维度: {}", imageUrl, featureVector.size());
|
||||
return featureVector;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("提取特征失败: {}", imageUrl, e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 MinIO 获取图片并转换为 DJL Image 对象
|
||||
*/
|
||||
private Image getImageFromMinio(String imageUrl) {
|
||||
InputStream imageStream = null;
|
||||
try {
|
||||
// 解析路径:bucket_name/path/to/image.jpg
|
||||
int index = imageUrl.indexOf("/");
|
||||
if (index == -1 || index == imageUrl.length() - 1) {
|
||||
log.warn("无效的图片路径格式: {}", imageUrl);
|
||||
return null;
|
||||
}
|
||||
|
||||
String bucketName = imageUrl.substring(0, index);
|
||||
String objectName = imageUrl.substring(index + 1);
|
||||
|
||||
// 从 MinIO 获取图片流
|
||||
// 使用 MinioUtil 的 download 方法获取 InputStream
|
||||
// download 方法可能抛出 MinioException 和 IOException
|
||||
imageStream = minioUtil.download(bucketName, objectName);
|
||||
if (imageStream == null) {
|
||||
log.warn("无法从 MinIO 获取图片流: bucket={}, object={}", bucketName, objectName);
|
||||
return null;
|
||||
}
|
||||
|
||||
// 转换为 DJL Image 对象
|
||||
Image image = ImageFactory.getInstance().fromInputStream(imageStream);
|
||||
|
||||
return image;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("从 MinIO 获取图片失败: {}", imageUrl, e);
|
||||
return null;
|
||||
} finally {
|
||||
if (imageStream != null) {
|
||||
try {
|
||||
imageStream.close();
|
||||
} catch (Exception e) {
|
||||
log.warn("关闭图片流失败", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<Float>> extractFeatures(List<String> imageUrls) {
|
||||
if (CollectionUtils.isEmpty(imageUrls)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
return imageUrls.stream()
|
||||
.map(this::extractFeature)
|
||||
.filter(java.util.Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors) {
|
||||
if (CollectionUtils.isEmpty(userLikedVectors)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
int dimension = userLikedVectors.get(0).size();
|
||||
List<Float> preferenceVector = new ArrayList<>(dimension);
|
||||
|
||||
// 初始化为零向量
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
preferenceVector.add(0.0f);
|
||||
}
|
||||
|
||||
// 计算平均向量
|
||||
for (List<Float> vector : userLikedVectors) {
|
||||
if (vector.size() != dimension) {
|
||||
log.warn("向量维度不匹配: 期望 {}, 实际 {}", dimension, vector.size());
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
preferenceVector.set(i, preferenceVector.get(i) + vector.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
// 求平均值
|
||||
int count = userLikedVectors.size();
|
||||
for (int i = 0; i < dimension; i++) {
|
||||
preferenceVector.set(i, preferenceVector.get(i) / count);
|
||||
}
|
||||
|
||||
return preferenceVector;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,679 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import com.ai.da.common.config.MilvusConfig;
|
||||
import com.ai.da.mapper.primary.SysFileMapper;
|
||||
import com.ai.da.mapper.primary.UserPreferenceLogMapper;
|
||||
import com.ai.da.mapper.primary.entity.SysFile;
|
||||
import com.ai.da.mapper.primary.entity.UserPreferenceLogTest;
|
||||
import com.ai.da.model.dto.RecommendRequestDTO;
|
||||
import com.ai.da.service.FeatureExtractionService;
|
||||
import com.ai.da.service.RecommendationService;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.grpc.DataType;
|
||||
import io.milvus.grpc.SearchResults;
|
||||
import io.milvus.param.IndexType;
|
||||
import io.milvus.param.MetricType;
|
||||
import io.milvus.param.R;
|
||||
import io.milvus.param.RpcStatus;
|
||||
import io.milvus.param.collection.*;
|
||||
import io.milvus.param.dml.InsertParam;
|
||||
import io.milvus.param.dml.QueryParam;
|
||||
import io.milvus.param.dml.SearchParam;
|
||||
import io.milvus.param.index.CreateIndexParam;
|
||||
import io.milvus.grpc.MutationResult;
|
||||
import io.milvus.grpc.QueryResults;
|
||||
import io.milvus.response.FieldDataWrapper;
|
||||
import io.milvus.response.QueryResultsWrapper;
|
||||
import io.milvus.response.SearchResultsWrapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.Resource;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 推荐服务实现类
|
||||
* 基于 Milvus 向量数据库的推荐系统
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
public class RecommendationServiceImpl implements RecommendationService {
|
||||
|
||||
private static final String COLLECTION_NAME = "sketch_recommendation";
|
||||
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
|
||||
private static final String VECTOR_FIELD = "feature_vector";
|
||||
private static final String URL_FIELD = "url";
|
||||
private static final String STYLE_FIELD = "style";
|
||||
private static final String CATEGORY_FIELD = "category";
|
||||
private static final String DEPRECATED_FIELD = "deprecated";
|
||||
private static final String SYS_FILE_ID_FIELD = "sys_file_id";
|
||||
|
||||
@Resource
|
||||
private MilvusServiceClient milvusClient;
|
||||
|
||||
@Resource
|
||||
private SysFileMapper sysFileMapper;
|
||||
|
||||
@Resource
|
||||
private UserPreferenceLogMapper userPreferenceLogMapper;
|
||||
|
||||
@Resource
|
||||
private FeatureExtractionService featureExtractionService;
|
||||
|
||||
/**
|
||||
* 初始化 Collection(如果不存在则创建)
|
||||
*/
|
||||
@PostConstruct
|
||||
public void initCollection() {
|
||||
try {
|
||||
// 检查 Collection 是否存在
|
||||
R<Boolean> hasCollectionR = milvusClient.hasCollection(
|
||||
HasCollectionParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.build()
|
||||
);
|
||||
|
||||
if (!hasCollectionR.getData()) {
|
||||
log.info("Collection {} 不存在,开始创建...", COLLECTION_NAME);
|
||||
createCollection();
|
||||
createIndex();
|
||||
log.info("Collection {} 创建完成", COLLECTION_NAME);
|
||||
} else {
|
||||
log.info("Collection {} 已存在", COLLECTION_NAME);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("初始化 Collection 失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Collection
|
||||
*/
|
||||
private void createCollection() {
|
||||
// 定义字段
|
||||
List<FieldType> fields = Arrays.asList(
|
||||
FieldType.newBuilder()
|
||||
.withName(SYS_FILE_ID_FIELD)
|
||||
.withDataType(DataType.Int64)
|
||||
.withPrimaryKey(true)
|
||||
.withAutoID(false)
|
||||
.build(),
|
||||
FieldType.newBuilder()
|
||||
.withName(URL_FIELD)
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(500)
|
||||
.build(),
|
||||
FieldType.newBuilder()
|
||||
.withName(STYLE_FIELD)
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(100)
|
||||
.build(),
|
||||
FieldType.newBuilder()
|
||||
.withName(CATEGORY_FIELD)
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(100)
|
||||
.build(),
|
||||
FieldType.newBuilder()
|
||||
.withName(DEPRECATED_FIELD)
|
||||
.withDataType(DataType.Int8)
|
||||
.build(),
|
||||
FieldType.newBuilder()
|
||||
.withName(VECTOR_FIELD)
|
||||
.withDataType(DataType.FloatVector)
|
||||
.withDimension(DIMENSION)
|
||||
.build()
|
||||
);
|
||||
|
||||
// 创建 Collection
|
||||
CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.withDescription("系统 sketch 推荐向量库")
|
||||
.withShardsNum(2)
|
||||
.withFieldTypes(fields)
|
||||
.build();
|
||||
|
||||
R<RpcStatus> createR = milvusClient.createCollection(createParam);
|
||||
if (createR.getStatus() != R.Status.Success.getCode()) {
|
||||
throw new RuntimeException("创建 Collection 失败: " + createR.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建索引
|
||||
*/
|
||||
private void createIndex() {
|
||||
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.withFieldName(VECTOR_FIELD)
|
||||
.withIndexType(IndexType.IVF_FLAT)
|
||||
.withMetricType(MetricType.L2)
|
||||
.withExtraParam("{\"nlist\":1024}")
|
||||
.build();
|
||||
|
||||
R<RpcStatus> indexR = milvusClient.createIndex(indexParam);
|
||||
if (indexR.getStatus() != R.Status.Success.getCode()) {
|
||||
throw new RuntimeException("创建索引失败: " + indexR.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> recommend(RecommendRequestDTO request) {
|
||||
try {
|
||||
// 1. 从 user_preference_log_test 获取用户点赞的 path 列表
|
||||
QueryWrapper<UserPreferenceLogTest> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(UserPreferenceLogTest::getAccountId, request.getUserId());
|
||||
List<UserPreferenceLogTest> userLikes = userPreferenceLogMapper.selectList(queryWrapper);
|
||||
|
||||
// 新用户处理:如果没有点赞记录,使用降级推荐策略
|
||||
if (CollectionUtils.isEmpty(userLikes)) {
|
||||
log.info("用户 {} 没有点赞记录,使用新用户推荐策略(基于 style 和 category)", request.getUserId());
|
||||
return recommendForNewUser(request);
|
||||
}
|
||||
|
||||
// 2. 提取用户点赞图片的特征向量
|
||||
List<String> likedPaths = userLikes.stream()
|
||||
.map(UserPreferenceLogTest::getPath)
|
||||
.filter(StringUtils::isNotBlank)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<Float>> likedVectors = featureExtractionService.extractFeatures(likedPaths);
|
||||
likedVectors = likedVectors.stream()
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isEmpty(likedVectors)) {
|
||||
log.warn("用户 {} 无法提取特征向量", request.getUserId());
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 3. 计算用户偏好向量(平均向量)
|
||||
List<Float> userPreferenceVector = featureExtractionService.calculateUserPreferenceVector(likedVectors);
|
||||
|
||||
// 4. 构建搜索参数(混合推荐策略)
|
||||
// 策略:向量相似度搜索为主 + 分类筛选为辅
|
||||
|
||||
List<String> searchOutputFields = Arrays.asList(URL_FIELD, STYLE_FIELD, CATEGORY_FIELD);
|
||||
List<List<Float>> searchVectors = Collections.singletonList(userPreferenceVector);
|
||||
|
||||
// 构建过滤表达式(基础筛选)
|
||||
StringBuilder exprBuilder = new StringBuilder();
|
||||
List<String> conditions = new ArrayList<>();
|
||||
|
||||
// 基础筛选:只返回未淘汰的图片
|
||||
if (request.getOnlyActive() != null && request.getOnlyActive()) {
|
||||
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||
}
|
||||
|
||||
// 可选筛选:风格(如果指定)
|
||||
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||
conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
|
||||
}
|
||||
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||
conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
|
||||
}
|
||||
|
||||
if (!conditions.isEmpty()) {
|
||||
exprBuilder.append(String.join(" && ", conditions));
|
||||
}
|
||||
|
||||
// 第一阶段:向量搜索获取大量候选(用于后续筛选和多样性优化)
|
||||
int candidateSize = request.getCandidateSize() != null ?
|
||||
Math.max(request.getCandidateSize(), request.getTopK() * 3) :
|
||||
Math.max(50, request.getTopK() * 5);
|
||||
|
||||
SearchParam searchParam = SearchParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.withMetricType(MetricType.L2) // L2距离,越小越相似
|
||||
.withOutFields(searchOutputFields)
|
||||
.withTopK(candidateSize) // 获取更多候选用于后续处理
|
||||
.withVectors(searchVectors)
|
||||
.withVectorFieldName(VECTOR_FIELD)
|
||||
.withExpr(exprBuilder.length() > 0 ? exprBuilder.toString() : null)
|
||||
.withParams("{\"nprobe\":10}")
|
||||
.build();
|
||||
|
||||
// 5. 执行搜索
|
||||
R<SearchResults> searchR = milvusClient.search(searchParam);
|
||||
if (searchR.getStatus() != R.Status.Success.getCode()) {
|
||||
log.error("搜索失败: {}", searchR.getMessage());
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 6. 解析结果并进行后处理
|
||||
SearchResultsWrapper wrapper = new SearchResultsWrapper(searchR.getData().getResults());
|
||||
|
||||
// 6.1 提取候选结果(包含相似度分数)
|
||||
List<RecommendationCandidate> candidates = new ArrayList<>();
|
||||
FieldDataWrapper urlFieldWrapper = wrapper.getFieldWrapper(URL_FIELD);
|
||||
FieldDataWrapper styleFieldWrapper = wrapper.getFieldWrapper(STYLE_FIELD);
|
||||
List<SearchResultsWrapper.IDScore> idScores = wrapper.getIDScore(0);
|
||||
|
||||
for (int i = 0; i < idScores.size(); i++) {
|
||||
String url = (String) urlFieldWrapper.getFieldData().get(i);
|
||||
String style = (String) styleFieldWrapper.getFieldData().get(i);
|
||||
Float score = idScores.get(i).getScore(); // L2距离,越小越相似
|
||||
|
||||
if (StringUtils.isNotBlank(url)) {
|
||||
candidates.add(new RecommendationCandidate(url, style, score));
|
||||
}
|
||||
}
|
||||
|
||||
// // 6.2 分类筛选(如果指定了 category)
|
||||
// if (StringUtils.isNotBlank(request.getCategory())) {
|
||||
// candidates = filterByCategory(candidates, request.getCategory());
|
||||
// }
|
||||
|
||||
// 6.3 多样性优化(避免推荐过于相似的结果)
|
||||
if (request.getEnableDiversity() != null && request.getEnableDiversity()) {
|
||||
candidates = diversifyResults(candidates, request.getTopK() != null ? request.getTopK() : 10);
|
||||
}
|
||||
|
||||
// 6.4 取 Top-K
|
||||
int topK = request.getTopK() != null ? request.getTopK() : 1;
|
||||
List<String> recommendedUrls = candidates.stream()
|
||||
.limit(topK)
|
||||
.map(RecommendationCandidate::getUrl)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
log.info("为用户 {} 推荐了 {} 个结果(从 {} 个候选中筛选)",
|
||||
request.getUserId(), recommendedUrls.size(), idScores.size());
|
||||
return recommendedUrls;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("推荐失败", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void syncSysFileToMilvus() {
|
||||
log.info("开始同步 t_sys_file 数据到 Milvus");
|
||||
|
||||
try {
|
||||
// 1. 加载 Collection
|
||||
LoadCollectionParam loadParam = LoadCollectionParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.build();
|
||||
milvusClient.loadCollection(loadParam);
|
||||
|
||||
// 2. 查询所有系统 sketch
|
||||
QueryWrapper<SysFile> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda()
|
||||
.eq(SysFile::getLevel1Type, "Images").eq(SysFile::getDeprecated, 0)
|
||||
.isNotNull(SysFile::getUrl)
|
||||
.isNotNull(SysFile::getStyle)
|
||||
.ne(SysFile::getUrl, "");
|
||||
List<SysFile> sysFiles = sysFileMapper.selectList(queryWrapper);
|
||||
|
||||
log.info("找到 {} 个系统文件需要同步", sysFiles.size());
|
||||
|
||||
int successCount = 0;
|
||||
int failCount = 0;
|
||||
int batchSize = 100;
|
||||
|
||||
// 3. 批量处理
|
||||
for (int i = 0; i < sysFiles.size(); i += batchSize) {
|
||||
int end = Math.min(i + batchSize, sysFiles.size());
|
||||
List<SysFile> batch = sysFiles.subList(i, end);
|
||||
|
||||
List<Long> ids = new ArrayList<>();
|
||||
List<String> urls = new ArrayList<>();
|
||||
List<String> styles = new ArrayList<>();
|
||||
List<String> categories = new ArrayList<>();
|
||||
List<Byte> deprecatedList = new ArrayList<>();
|
||||
List<List<Float>> vectors = new ArrayList<>();
|
||||
|
||||
for (SysFile sysFile : batch) {
|
||||
try {
|
||||
// 提取特征向量
|
||||
List<Float> vector = featureExtractionService.extractFeature(sysFile.getUrl());
|
||||
if (vector == null || vector.size() != DIMENSION) {
|
||||
log.warn("文件 {} 特征提取失败,跳过", sysFile.getUrl());
|
||||
failCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
ids.add(sysFile.getId());
|
||||
urls.add(sysFile.getUrl());
|
||||
styles.add(sysFile.getStyle() != null ? sysFile.getStyle() : "");
|
||||
categories.add(resolveCategory(sysFile));
|
||||
deprecatedList.add((byte) (sysFile.getDeprecated() != null ? sysFile.getDeprecated() : 0));
|
||||
vectors.add(vector);
|
||||
successCount++;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("处理文件 {} 失败", sysFile.getUrl(), e);
|
||||
failCount++;
|
||||
}
|
||||
}
|
||||
|
||||
// 批量插入
|
||||
if (!ids.isEmpty()) {
|
||||
List<InsertParam.Field> fields = Arrays.asList(
|
||||
new InsertParam.Field(SYS_FILE_ID_FIELD, ids),
|
||||
new InsertParam.Field(URL_FIELD, urls),
|
||||
new InsertParam.Field(STYLE_FIELD, styles),
|
||||
new InsertParam.Field(DEPRECATED_FIELD, deprecatedList),
|
||||
new InsertParam.Field(CATEGORY_FIELD, categories),
|
||||
new InsertParam.Field(VECTOR_FIELD, vectors)
|
||||
);
|
||||
|
||||
InsertParam insertParam = InsertParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.withFields(fields)
|
||||
.build();
|
||||
|
||||
R<MutationResult> insertR = milvusClient.insert(insertParam);
|
||||
if (insertR.getStatus() == R.Status.Success.getCode()) {
|
||||
log.info("批量插入成功: {}/{}", end, sysFiles.size());
|
||||
} else {
|
||||
log.error("批量插入失败: {}", insertR.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 刷新数据
|
||||
FlushParam flushParam = FlushParam.newBuilder()
|
||||
.withCollectionNames(Collections.singletonList(COLLECTION_NAME))
|
||||
.build();
|
||||
milvusClient.flush(flushParam);
|
||||
|
||||
log.info("同步完成: 成功 {}, 失败 {}", successCount, failCount);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("同步数据到 Milvus 失败", e);
|
||||
throw new RuntimeException("同步失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateVector(Long sysFileId, String url) {
|
||||
// TODO: 实现单个向量更新逻辑
|
||||
// 需要先删除旧向量,再插入新向量
|
||||
log.info("更新向量: sysFileId={}, url={}", sysFileId, url);
|
||||
}
|
||||
|
||||
/**
|
||||
* 新用户推荐策略:当用户没有点赞记录时,根据 style 和 category 随机推荐
|
||||
*
|
||||
* @param request 推荐请求
|
||||
* @return 推荐的 URL 列表
|
||||
*/
|
||||
private List<String> recommendForNewUser(RecommendRequestDTO request) {
|
||||
try {
|
||||
int topK = request.getTopK() != null ? request.getTopK() : 10;
|
||||
|
||||
List<String> milvusResults = recommendFromMilvusForNewUser(request, topK);
|
||||
if (CollectionUtils.isEmpty(milvusResults)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
Collections.shuffle(milvusResults);
|
||||
return milvusResults.stream().limit(topK).collect(Collectors.toList());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("新用户推荐失败", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 Milvus 中为新用户推荐(根据 style 和 category 查询)
|
||||
*/
|
||||
private List<String> recommendFromMilvusForNewUser(RecommendRequestDTO request, int limit) {
|
||||
try {
|
||||
// 构建过滤表达式
|
||||
StringBuilder exprBuilder = new StringBuilder();
|
||||
List<String> conditions = new ArrayList<>();
|
||||
|
||||
// 基础筛选:只返回未淘汰的图片
|
||||
if (request.getOnlyActive() != null && request.getOnlyActive()) {
|
||||
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||
}
|
||||
|
||||
// 风格筛选(如果指定)
|
||||
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||
conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
|
||||
}
|
||||
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||
conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
|
||||
}
|
||||
|
||||
if (conditions.isEmpty()) {
|
||||
// 如果没有筛选条件,查询所有未淘汰的
|
||||
conditions.add(DEPRECATED_FIELD + " == 0");
|
||||
}
|
||||
|
||||
exprBuilder.append(String.join(" && ", conditions));
|
||||
|
||||
// 从 Milvus 查询符合条件的记录
|
||||
QueryParam queryParam = QueryParam.newBuilder()
|
||||
.withCollectionName(COLLECTION_NAME)
|
||||
.withExpr(exprBuilder.toString())
|
||||
.withOutFields(Arrays.asList(URL_FIELD))
|
||||
.withLimit((long) limit)
|
||||
.build();
|
||||
|
||||
R<QueryResults> queryR = milvusClient.query(queryParam);
|
||||
if (queryR.getStatus() != R.Status.Success.getCode()) {
|
||||
log.warn("从 Milvus 查询失败: {}", queryR.getMessage());
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
List<String> urls = new ArrayList<>();
|
||||
QueryResultsWrapper wrapper = new QueryResultsWrapper(queryR.getData());
|
||||
for (int i = 0; i < wrapper.getRowRecords().size(); i++) {
|
||||
String url = (String) wrapper.getFieldWrapper(URL_FIELD).getFieldData().get(i);
|
||||
if (StringUtils.isNotBlank(url)) {
|
||||
urls.add(url);
|
||||
}
|
||||
}
|
||||
|
||||
// // 如果指定了 category,进行筛选
|
||||
// if (StringUtils.isNotBlank(request.getCategory()) && !urls.isEmpty()) {
|
||||
// urls = urls.stream()
|
||||
// .filter(url -> {
|
||||
// // 从 URL 中提取类别信息
|
||||
// // URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||
// String[] parts = url.split("/");
|
||||
// if (parts.length >= 4) {
|
||||
// String sex = parts[parts.length - 3];
|
||||
// String cat = parts[parts.length - 2];
|
||||
// String expectedCategory = sex + "_" + cat;
|
||||
// return expectedCategory.equals(request.getCategory());
|
||||
// }
|
||||
// return false;
|
||||
// })
|
||||
// .collect(Collectors.toList());
|
||||
// }
|
||||
|
||||
log.info("从 Milvus 为新用户查询到 {} 条记录", urls.size());
|
||||
return urls;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("从 Milvus 查询新用户推荐失败", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 MySQL 中为新用户推荐(根据 style 和 category 查询)
|
||||
*/
|
||||
private List<String> recommendFromMySQLForNewUser(RecommendRequestDTO request, int topK) {
|
||||
try {
|
||||
QueryWrapper<SysFile> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda()
|
||||
.eq(SysFile::getLevel1Type, "Images")
|
||||
.isNotNull(SysFile::getUrl)
|
||||
.ne(SysFile::getUrl, "");
|
||||
|
||||
// 风格筛选(如果指定)
|
||||
if (StringUtils.isNotBlank(request.getStyle())) {
|
||||
queryWrapper.lambda().eq(SysFile::getStyle, request.getStyle());
|
||||
}
|
||||
|
||||
// 类别筛选(如果指定)
|
||||
if (StringUtils.isNotBlank(request.getCategory())) {
|
||||
// category 格式: female_skirt, male_tops 等
|
||||
String[] parts = request.getCategory().split("_");
|
||||
if (parts.length == 2) {
|
||||
String sex = parts[0]; // female 或 male
|
||||
String category = parts[1]; // skirt, tops 等
|
||||
|
||||
// 从 URL 中筛选类别
|
||||
// URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||
queryWrapper.lambda().like(SysFile::getUrl, "/" + sex + "/" + category + "/");
|
||||
}
|
||||
}
|
||||
|
||||
// 查询所有符合条件的记录
|
||||
List<SysFile> sysFiles = sysFileMapper.selectList(queryWrapper);
|
||||
|
||||
if (CollectionUtils.isEmpty(sysFiles)) {
|
||||
log.warn("MySQL 中未找到符合条件的记录");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 随机选择
|
||||
Collections.shuffle(sysFiles);
|
||||
|
||||
List<String> urls = sysFiles.stream()
|
||||
.map(SysFile::getUrl)
|
||||
.filter(StringUtils::isNotBlank)
|
||||
.limit(topK)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
log.info("从 MySQL 为新用户查询到 {} 条记录", urls.size());
|
||||
return urls;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("从 MySQL 查询新用户推荐失败", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据类别筛选候选结果
|
||||
*/
|
||||
private List<RecommendationCandidate> filterByCategory(
|
||||
List<RecommendationCandidate> candidates, String category) {
|
||||
// 从 URL 中提取类别信息
|
||||
// URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
|
||||
return candidates.stream()
|
||||
.filter(candidate -> {
|
||||
String url = candidate.getUrl();
|
||||
if (StringUtils.isBlank(url) || !url.contains("/")) {
|
||||
return false;
|
||||
}
|
||||
// 提取类别部分
|
||||
String[] parts = url.split("/");
|
||||
if (parts.length < 4) {
|
||||
return false;
|
||||
}
|
||||
String sex = parts[parts.length - 3]; // female 或 male
|
||||
String cat = parts[parts.length - 2]; // category
|
||||
String expectedCategory = sex + "_" + cat;
|
||||
return expectedCategory.equals(category);
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* 多样性优化:避免推荐过于相似的结果
|
||||
* 使用聚类或相似度阈值来增加多样性
|
||||
*/
|
||||
private List<RecommendationCandidate> diversifyResults(
|
||||
List<RecommendationCandidate> candidates, int targetSize) {
|
||||
if (candidates.size() <= targetSize) {
|
||||
return candidates;
|
||||
}
|
||||
|
||||
// 简单策略:按相似度分数排序,然后每隔几个取一个,增加多样性
|
||||
// 更复杂的策略可以使用聚类算法
|
||||
|
||||
List<RecommendationCandidate> diversified = new ArrayList<>();
|
||||
int step = Math.max(1, candidates.size() / targetSize);
|
||||
|
||||
for (int i = 0; i < candidates.size() && diversified.size() < targetSize; i += step) {
|
||||
diversified.add(candidates.get(i));
|
||||
}
|
||||
|
||||
// 如果还不够,补充剩余的结果
|
||||
if (diversified.size() < targetSize) {
|
||||
for (RecommendationCandidate candidate : candidates) {
|
||||
if (diversified.size() >= targetSize) {
|
||||
break;
|
||||
}
|
||||
if (!diversified.contains(candidate)) {
|
||||
diversified.add(candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return diversified;
|
||||
}
|
||||
|
||||
/**
|
||||
* 推荐候选结果内部类
|
||||
*/
|
||||
private static class RecommendationCandidate {
|
||||
private final String url;
|
||||
private final String style;
|
||||
private final Float similarityScore; // L2距离,越小越相似
|
||||
|
||||
public RecommendationCandidate(String url, String style, Float similarityScore) {
|
||||
this.url = url;
|
||||
this.style = style;
|
||||
this.similarityScore = similarityScore;
|
||||
}
|
||||
|
||||
public String getUrl() {
|
||||
return url;
|
||||
}
|
||||
|
||||
public String getStyle() {
|
||||
return style;
|
||||
}
|
||||
|
||||
public Float getSimilarityScore() {
|
||||
return similarityScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
RecommendationCandidate that = (RecommendationCandidate) o;
|
||||
return Objects.equals(url, that.url);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(url);
|
||||
}
|
||||
}
|
||||
|
||||
private String resolveCategory(SysFile sysFile) {
|
||||
String level3 = sysFile.getLevel3Type();
|
||||
String level2 = sysFile.getLevel2Type();
|
||||
if (StringUtils.isNotBlank(level3) && StringUtils.isNotBlank(level2)) {
|
||||
return (level3 + "_" + level2).toLowerCase();
|
||||
}
|
||||
String url = sysFile.getUrl();
|
||||
if (StringUtils.isNotBlank(url)) {
|
||||
String[] parts = url.split("/");
|
||||
if (parts.length >= 4) {
|
||||
String sex = parts[parts.length - 3];
|
||||
String category = parts[parts.length - 2];
|
||||
return (sex + "_" + category).toLowerCase();
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.transform.Normalize;
|
||||
import ai.djl.modality.cv.transform.Resize;
|
||||
import ai.djl.modality.cv.transform.ToTensor;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.Pipeline;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* ResNet 特征提取 Translator
|
||||
* 移除最后的全连接层,只保留特征提取部分(2048维)
|
||||
*/
|
||||
@Slf4j
|
||||
public class ResNetFeatureTranslator implements Translator<Image, float[]> {
|
||||
|
||||
private static final int IMAGE_SIZE = 224;
|
||||
private static final Pipeline PIPELINE = new Pipeline()
|
||||
.add(new Resize(IMAGE_SIZE, IMAGE_SIZE))
|
||||
.add(new ToTensor())
|
||||
.add(new Normalize(
|
||||
new float[]{0.485f, 0.456f, 0.406f},
|
||||
new float[]{0.229f, 0.224f, 0.225f}
|
||||
));
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
NDArray array = input.toNDArray(manager);
|
||||
NDList ndList = new NDList(array);
|
||||
return PIPELINE.transform(ndList);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] processOutput(TranslatorContext ctx, NDList list) {
|
||||
NDArray output = list.singletonOrThrow();
|
||||
|
||||
// ResNet50 特征层输出形状应该是:
|
||||
// - [1, 2048, 1, 1] - 全局平均池化后的特征(需要 flatten)
|
||||
// - [1, 2048] - 已 flatten 的特征层输出
|
||||
|
||||
long[] shape = output.getShape().getShape();
|
||||
|
||||
// 如果是 4 维,说明是 [1, 2048, 1, 1],需要 flatten
|
||||
if (shape.length == 4) {
|
||||
output = output.flatten();
|
||||
}
|
||||
// 如果是 2 维,检查维度是否正确
|
||||
else if (shape.length == 2) {
|
||||
if (shape[1] == 1000) {
|
||||
// 如果仍然是分类层输出(1000维),说明模型加载有问题
|
||||
log.error("模型输出仍然是分类层(1000维),模型加载可能失败");
|
||||
throw new IllegalStateException("模型输出维度错误:期望 2048 维,实际 1000 维。请检查模型是否正确移除了全连接层。");
|
||||
}
|
||||
// 如果已经是 [1, 2048],直接使用
|
||||
}
|
||||
|
||||
// 转换为 float 数组
|
||||
float[] result = output.toFloatArray();
|
||||
|
||||
// 验证维度必须是 2048
|
||||
if (result.length != 2048) {
|
||||
log.error("特征向量维度错误: 期望 2048, 实际 {}", result.length);
|
||||
throw new IllegalArgumentException(
|
||||
String.format("特征向量维度错误: 期望 2048, 实际 %d", result.length)
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return Batchifier.STACK;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user