diff --git a/pom.xml b/pom.xml
index 5b4d83c3..32887616 100644
--- a/pom.xml
+++ b/pom.xml
@@ -422,6 +422,49 @@
bcpkix-jdk18on
1.78.1
+
+
+
+ io.milvus
+ milvus-sdk-java
+ 2.3.4
+
+
+
+
+ ai.djl
+ api
+ 0.24.0
+
+
+ ai.djl.pytorch
+ pytorch-engine
+ 0.24.0
+
+
+ ai.djl.pytorch
+ pytorch-native-cpu
+ 2.0.1
+ linux-x86_64
+
+
+ ai.djl.pytorch
+ pytorch-native-cpu
+ 2.0.1
+ win-x86_64
+
+
+ ai.djl.pytorch
+ pytorch-native-cpu
+ 2.0.1
+ osx-x86_64
+
+
+ ai.djl.pytorch
+ pytorch-native-cpu
+ 2.0.1
+ osx-aarch64
+
diff --git a/src/main/java/com/ai/da/common/config/MilvusConfig.java b/src/main/java/com/ai/da/common/config/MilvusConfig.java
new file mode 100644
index 00000000..415c0bae
--- /dev/null
+++ b/src/main/java/com/ai/da/common/config/MilvusConfig.java
@@ -0,0 +1,45 @@
+package com.ai.da.common.config;
+
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.param.ConnectParam;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+/**
+ * Milvus 向量数据库配置
+ */
+@Slf4j
+@Configuration
+public class MilvusConfig {
+
+ @Value("${milvus.host:localhost}")
+ private String host;
+
+ @Value("${milvus.port:19530}")
+ private Integer port;
+
+ @Value("${milvus.username:}")
+ private String username;
+
+ @Value("${milvus.password:}")
+ private String password;
+
+ @Bean
+ public MilvusServiceClient milvusClient() {
+ ConnectParam.Builder builder = ConnectParam.newBuilder()
+ .withHost(host)
+ .withPort(port);
+
+ if (StringUtils.isNotBlank(username) && StringUtils.isNotBlank(password)) {
+ builder.withAuthorization(username, password);
+ }
+
+ MilvusServiceClient client = new MilvusServiceClient(builder.build());
+ log.info("Milvus client initialized: {}:{}", host, port);
+ return client;
+ }
+}
+
diff --git a/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java b/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java
index a8c36e4e..0880689b 100644
--- a/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java
+++ b/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java
@@ -72,6 +72,12 @@ public class SysFile implements Serializable {
*/
private String style;
+ /**
+ * 是否废弃 0-否 1-是
+ */
+ private Integer deprecated;
+
+
public SysFile() {
}
diff --git a/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java b/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
new file mode 100644
index 00000000..ce95b2c7
--- /dev/null
+++ b/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java
@@ -0,0 +1,46 @@
+package com.ai.da.model.dto;
+
+import lombok.Data;
+
+/**
+ * 推荐请求 DTO
+ */
+@Data
+public class RecommendRequestDTO {
+ /**
+ * 用户ID
+ */
+ private Long userId;
+
+ /**
+ * 类别(如 female_skirt, male_tops)- 可选,用于辅助筛选
+ */
+ private String category;
+
+ /**
+ * 风格样式(可选)- 用于辅助筛选
+ */
+ private String style;
+
+ /**
+ * 返回数量
+ */
+ private Integer topK = 10;
+
+ /**
+ * 是否只返回未淘汰的图片
+ */
+ private Boolean onlyActive = true;
+
+ /**
+ * 向量搜索的候选数量(第一阶段)
+ * 建议值:50-100,用于后续筛选和多样性优化
+ */
+ private Integer candidateSize = 50;
+
+ /**
+ * 是否启用多样性优化(避免推荐过于相似的结果)
+ */
+ private Boolean enableDiversity = true;
+}
+
diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java
index a97875f5..55811b3c 100644
--- a/src/main/java/com/ai/da/python/PythonService.java
+++ b/src/main/java/com/ai/da/python/PythonService.java
@@ -20,6 +20,7 @@ import com.ai.da.model.vo.*;
import com.ai.da.python.vo.*;
import com.ai.da.service.DesignHistoryService;
import com.ai.da.service.PythonTAllInfoService;
+import com.ai.da.service.RecommendationService;
import com.ai.da.service.SysFileService;
import com.alibaba.fastjson.*;
import com.alibaba.fastjson.serializer.SerializerFeature;
@@ -722,6 +723,9 @@ public class PythonService {
@Resource
private AttributeRetrievalMapper attributeRetrievalMapper;
+ @Resource
+ private RecommendationService recommendationService;
+
private List getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) {
/**
* female trousers->female_pants
@@ -3955,6 +3959,27 @@ public class PythonService {
}
public List getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) {
+ AuthPrincipalVo userHolder = UserContext.getUserHolder();
+
+ // 优先使用基于 Milvus 的推荐系统
+ try {
+ com.ai.da.model.dto.RecommendRequestDTO request = new com.ai.da.model.dto.RecommendRequestDTO();
+ request.setUserId(userHolder.getId());
+ request.setCategory(category);
+ request.setStyle(style);
+ request.setTopK(1);
+ request.setOnlyActive(true);
+
+ List recommendedUrls = recommendationService.recommend(request);
+ if (!CollectionUtils.isEmpty(recommendedUrls)) {
+ log.info("使用 Milvus 推荐系统返回 {} 个结果", recommendedUrls.size());
+ return recommendedUrls;
+ }
+ } catch (Exception e) {
+ log.warn("Milvus 推荐失败,降级到备用方案: {}", e.getMessage());
+ }
+
+ // 降级方案1: 使用 attribute_retrieval_style 表
//******3.1.2版本临时使用java推荐方案去解决style未使用的问题**********
try {
//使用新库attribute_retrieval_style,表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致
@@ -3974,11 +3999,11 @@ public class PythonService {
return Arrays.asList(path);
}
} catch (Exception e) {
- log.info("推荐失败:{}",e.getMessage());
- throw new BusinessException("system.error");
+ log.info("attribute_retrieval_style 推荐失败:{}",e.getMessage());
}
//**********************end***********************************
- AuthPrincipalVo userHolder = UserContext.getUserHolder();
+
+ // 降级方案2: 调用 Python 服务(原有方案)
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
diff --git a/src/main/java/com/ai/da/service/FeatureExtractionService.java b/src/main/java/com/ai/da/service/FeatureExtractionService.java
new file mode 100644
index 00000000..19ba5b35
--- /dev/null
+++ b/src/main/java/com/ai/da/service/FeatureExtractionService.java
@@ -0,0 +1,36 @@
+package com.ai.da.service;
+
+import java.util.List;
+
+/**
+ * 特征提取服务接口
+ * 用于从图片 URL 提取 2048 维特征向量
+ */
+public interface FeatureExtractionService {
+
+ /**
+ * 从图片 URL 提取特征向量
+ *
+ * @param imageUrl 图片 URL(MinIO 路径)
+ * @return 2048 维特征向量,如果提取失败返回 null
+ */
+ List extractFeature(String imageUrl);
+
+ /**
+ * 批量提取特征向量
+ *
+ * @param imageUrls 图片 URL 列表
+ * @return 特征向量列表,与输入顺序对应
+ */
+ List> extractFeatures(List imageUrls);
+
+ /**
+ * 计算用户偏好向量(基于用户点赞的图片向量)
+ * 使用平均向量作为用户偏好
+ *
+ * @param userLikedVectors 用户点赞的图片向量列表
+ * @return 用户偏好向量(2048维)
+ */
+ List calculateUserPreferenceVector(List> userLikedVectors);
+}
+
diff --git a/src/main/java/com/ai/da/service/RecommendationService.java b/src/main/java/com/ai/da/service/RecommendationService.java
new file mode 100644
index 00000000..bea65a40
--- /dev/null
+++ b/src/main/java/com/ai/da/service/RecommendationService.java
@@ -0,0 +1,34 @@
+package com.ai.da.service;
+
+import com.ai.da.model.dto.RecommendRequestDTO;
+
+import java.util.List;
+
+/**
+ * 推荐服务接口
+ */
+public interface RecommendationService {
+
+ /**
+ * 根据用户偏好推荐系统 sketch
+ *
+ * @param request 推荐请求
+ * @return 推荐的 URL 列表
+ */
+ List recommend(RecommendRequestDTO request);
+
+ /**
+ * 同步 t_sys_file 数据到 Milvus
+ * 从 t_sys_file 表读取所有系统 sketch,提取特征向量并存储到 Milvus
+ */
+ void syncSysFileToMilvus();
+
+ /**
+ * 更新单个文件的向量(当文件更新时调用)
+ *
+ * @param sysFileId 系统文件ID
+ * @param url 文件URL
+ */
+ void updateVector(Long sysFileId, String url);
+}
+
diff --git a/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java b/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java
new file mode 100644
index 00000000..565f3cc2
--- /dev/null
+++ b/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java
@@ -0,0 +1,220 @@
+package com.ai.da.service.impl;
+
+import ai.djl.Application;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import ai.djl.translate.Translator;
+import com.ai.da.common.utils.MinioUtil;
+import com.ai.da.service.FeatureExtractionService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import org.springframework.util.CollectionUtils;
+
+import jakarta.annotation.PostConstruct;
+import jakarta.annotation.PreDestroy;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * 特征提取服务实现类
+ * 使用 DJL (Deep Java Library) 框架加载 ResNet50 模型提取特征向量
+ */
+@Slf4j
+@Service
+public class FeatureExtractionServiceImpl implements FeatureExtractionService {
+
+ private static final int DIMENSION = 2048; // ResNet50 特征向量维度
+
+ @Autowired
+ private MinioUtil minioUtil;
+
+ private ZooModel model;
+ private Predictor predictor;
+
+ /**
+ * 初始化 ResNet50 模型
+ * 加载预训练模型并移除最后的全连接层,以获取2048维特征向量
+ */
+ @PostConstruct
+ public void initModel() {
+ try {
+ log.info("开始加载 ResNet50 模型(特征提取模式,2048维)...");
+
+ // 1. 加载完整的预训练 ResNet50 模型
+ Criteria criteria = Criteria.builder()
+ .optApplication(Application.CV.IMAGE_CLASSIFICATION)
+ .setTypes(Image.class, float[].class)
+ .optModelName("resnet50")
+ .optEngine("PyTorch")
+ .optProgress(new ProgressBar())
+ .build();
+
+ // 2. 创建特征提取 Translator
+ Translator translator = new ResNetFeatureTranslator();
+
+ // 3. 加载模型
+ model = ModelZoo.loadModel(criteria);
+
+ // 4. 创建 Predictor
+ predictor = model.newPredictor(translator);
+
+ log.info("ResNet50 模型加载完成(特征提取模式,2048维)");
+ } catch (Exception e) {
+ log.error("加载 ResNet50 模型失败", e);
+ throw new RuntimeException("模型加载失败", e);
+ }
+ }
+
+ /**
+ * 清理资源
+ */
+ @PreDestroy
+ public void cleanup() {
+ if (predictor != null) {
+ predictor.close();
+ }
+ if (model != null) {
+ model.close();
+ }
+ }
+
+ /**
+ * 从图片 URL 提取特征向量
+ *
+ * @param imageUrl MinIO 路径,格式为 "bucket_name/path/to/image.jpg"
+ * 例如: "aida-sys-image/images/female/blouse/0825001474.jpg"
+ * @return 2048 维特征向量,如果提取失败返回 null
+ */
+ @Override
+ public List extractFeature(String imageUrl) {
+ if (predictor == null) {
+ log.error("模型未初始化,无法提取特征");
+ return null;
+ }
+
+ try {
+ // 1. 从 MinIO 获取图片
+ Image image = getImageFromMinio(imageUrl);
+ if (image == null) {
+ log.warn("无法从 MinIO 获取图片: {}", imageUrl);
+ return null;
+ }
+
+ // 2. 使用模型提取特征
+ float[] featureArray = predictor.predict(image);
+
+ // 3. 转换为 List
+ List featureVector = new ArrayList<>(DIMENSION);
+ for (float value : featureArray) {
+ featureVector.add(value);
+ }
+
+ log.debug("成功提取特征向量: {}, 维度: {}", imageUrl, featureVector.size());
+ return featureVector;
+
+ } catch (Exception e) {
+ log.error("提取特征失败: {}", imageUrl, e);
+ return null;
+ }
+ }
+
+ /**
+ * 从 MinIO 获取图片并转换为 DJL Image 对象
+ */
+ private Image getImageFromMinio(String imageUrl) {
+ InputStream imageStream = null;
+ try {
+ // 解析路径:bucket_name/path/to/image.jpg
+ int index = imageUrl.indexOf("/");
+ if (index == -1 || index == imageUrl.length() - 1) {
+ log.warn("无效的图片路径格式: {}", imageUrl);
+ return null;
+ }
+
+ String bucketName = imageUrl.substring(0, index);
+ String objectName = imageUrl.substring(index + 1);
+
+ // 从 MinIO 获取图片流
+ // 使用 MinioUtil 的 download 方法获取 InputStream
+ // download 方法可能抛出 MinioException 和 IOException
+ imageStream = minioUtil.download(bucketName, objectName);
+ if (imageStream == null) {
+ log.warn("无法从 MinIO 获取图片流: bucket={}, object={}", bucketName, objectName);
+ return null;
+ }
+
+ // 转换为 DJL Image 对象
+ Image image = ImageFactory.getInstance().fromInputStream(imageStream);
+
+ return image;
+
+ } catch (Exception e) {
+ log.error("从 MinIO 获取图片失败: {}", imageUrl, e);
+ return null;
+ } finally {
+ if (imageStream != null) {
+ try {
+ imageStream.close();
+ } catch (Exception e) {
+ log.warn("关闭图片流失败", e);
+ }
+ }
+ }
+ }
+
+ @Override
+ public List> extractFeatures(List imageUrls) {
+ if (CollectionUtils.isEmpty(imageUrls)) {
+ return new ArrayList<>();
+ }
+
+ return imageUrls.stream()
+ .map(this::extractFeature)
+ .filter(java.util.Objects::nonNull)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public List calculateUserPreferenceVector(List> userLikedVectors) {
+ if (CollectionUtils.isEmpty(userLikedVectors)) {
+ return new ArrayList<>();
+ }
+
+ int dimension = userLikedVectors.get(0).size();
+ List preferenceVector = new ArrayList<>(dimension);
+
+ // 初始化为零向量
+ for (int i = 0; i < dimension; i++) {
+ preferenceVector.add(0.0f);
+ }
+
+ // 计算平均向量
+ for (List vector : userLikedVectors) {
+ if (vector.size() != dimension) {
+ log.warn("向量维度不匹配: 期望 {}, 实际 {}", dimension, vector.size());
+ continue;
+ }
+
+ for (int i = 0; i < dimension; i++) {
+ preferenceVector.set(i, preferenceVector.get(i) + vector.get(i));
+ }
+ }
+
+ // 求平均值
+ int count = userLikedVectors.size();
+ for (int i = 0; i < dimension; i++) {
+ preferenceVector.set(i, preferenceVector.get(i) / count);
+ }
+
+ return preferenceVector;
+ }
+}
+
diff --git a/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java b/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java
new file mode 100644
index 00000000..2f0db8e3
--- /dev/null
+++ b/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java
@@ -0,0 +1,679 @@
+package com.ai.da.service.impl;
+
+import com.ai.da.common.config.MilvusConfig;
+import com.ai.da.mapper.primary.SysFileMapper;
+import com.ai.da.mapper.primary.UserPreferenceLogMapper;
+import com.ai.da.mapper.primary.entity.SysFile;
+import com.ai.da.mapper.primary.entity.UserPreferenceLogTest;
+import com.ai.da.model.dto.RecommendRequestDTO;
+import com.ai.da.service.FeatureExtractionService;
+import com.ai.da.service.RecommendationService;
+import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.SearchResults;
+import io.milvus.param.IndexType;
+import io.milvus.param.MetricType;
+import io.milvus.param.R;
+import io.milvus.param.RpcStatus;
+import io.milvus.param.collection.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.QueryParam;
+import io.milvus.param.dml.SearchParam;
+import io.milvus.param.index.CreateIndexParam;
+import io.milvus.grpc.MutationResult;
+import io.milvus.grpc.QueryResults;
+import io.milvus.response.FieldDataWrapper;
+import io.milvus.response.QueryResultsWrapper;
+import io.milvus.response.SearchResultsWrapper;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.stereotype.Service;
+import org.springframework.util.CollectionUtils;
+
+import jakarta.annotation.PostConstruct;
+import jakarta.annotation.Resource;
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * 推荐服务实现类
+ * 基于 Milvus 向量数据库的推荐系统
+ */
+@Slf4j
+@Service
+public class RecommendationServiceImpl implements RecommendationService {
+
+ private static final String COLLECTION_NAME = "sketch_recommendation";
+ private static final int DIMENSION = 2048; // ResNet50 特征向量维度
+ private static final String VECTOR_FIELD = "feature_vector";
+ private static final String URL_FIELD = "url";
+ private static final String STYLE_FIELD = "style";
+ private static final String CATEGORY_FIELD = "category";
+ private static final String DEPRECATED_FIELD = "deprecated";
+ private static final String SYS_FILE_ID_FIELD = "sys_file_id";
+
+ @Resource
+ private MilvusServiceClient milvusClient;
+
+ @Resource
+ private SysFileMapper sysFileMapper;
+
+ @Resource
+ private UserPreferenceLogMapper userPreferenceLogMapper;
+
+ @Resource
+ private FeatureExtractionService featureExtractionService;
+
+ /**
+ * 初始化 Collection(如果不存在则创建)
+ */
+ @PostConstruct
+ public void initCollection() {
+ try {
+ // 检查 Collection 是否存在
+ R hasCollectionR = milvusClient.hasCollection(
+ HasCollectionParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .build()
+ );
+
+ if (!hasCollectionR.getData()) {
+ log.info("Collection {} 不存在,开始创建...", COLLECTION_NAME);
+ createCollection();
+ createIndex();
+ log.info("Collection {} 创建完成", COLLECTION_NAME);
+ } else {
+ log.info("Collection {} 已存在", COLLECTION_NAME);
+ }
+ } catch (Exception e) {
+ log.error("初始化 Collection 失败", e);
+ }
+ }
+
+ /**
+ * 创建 Collection
+ */
+ private void createCollection() {
+ // 定义字段
+ List fields = Arrays.asList(
+ FieldType.newBuilder()
+ .withName(SYS_FILE_ID_FIELD)
+ .withDataType(DataType.Int64)
+ .withPrimaryKey(true)
+ .withAutoID(false)
+ .build(),
+ FieldType.newBuilder()
+ .withName(URL_FIELD)
+ .withDataType(DataType.VarChar)
+ .withMaxLength(500)
+ .build(),
+ FieldType.newBuilder()
+ .withName(STYLE_FIELD)
+ .withDataType(DataType.VarChar)
+ .withMaxLength(100)
+ .build(),
+ FieldType.newBuilder()
+ .withName(CATEGORY_FIELD)
+ .withDataType(DataType.VarChar)
+ .withMaxLength(100)
+ .build(),
+ FieldType.newBuilder()
+ .withName(DEPRECATED_FIELD)
+ .withDataType(DataType.Int8)
+ .build(),
+ FieldType.newBuilder()
+ .withName(VECTOR_FIELD)
+ .withDataType(DataType.FloatVector)
+ .withDimension(DIMENSION)
+ .build()
+ );
+
+ // 创建 Collection
+ CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .withDescription("系统 sketch 推荐向量库")
+ .withShardsNum(2)
+ .withFieldTypes(fields)
+ .build();
+
+ R createR = milvusClient.createCollection(createParam);
+ if (createR.getStatus() != R.Status.Success.getCode()) {
+ throw new RuntimeException("创建 Collection 失败: " + createR.getMessage());
+ }
+ }
+
+ /**
+ * 创建索引
+ */
+ private void createIndex() {
+ CreateIndexParam indexParam = CreateIndexParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .withFieldName(VECTOR_FIELD)
+ .withIndexType(IndexType.IVF_FLAT)
+ .withMetricType(MetricType.L2)
+ .withExtraParam("{\"nlist\":1024}")
+ .build();
+
+ R indexR = milvusClient.createIndex(indexParam);
+ if (indexR.getStatus() != R.Status.Success.getCode()) {
+ throw new RuntimeException("创建索引失败: " + indexR.getMessage());
+ }
+ }
+
+ @Override
+ public List recommend(RecommendRequestDTO request) {
+ try {
+ // 1. 从 user_preference_log_test 获取用户点赞的 path 列表
+ QueryWrapper queryWrapper = new QueryWrapper<>();
+ queryWrapper.lambda().eq(UserPreferenceLogTest::getAccountId, request.getUserId());
+ List userLikes = userPreferenceLogMapper.selectList(queryWrapper);
+
+ // 新用户处理:如果没有点赞记录,使用降级推荐策略
+ if (CollectionUtils.isEmpty(userLikes)) {
+ log.info("用户 {} 没有点赞记录,使用新用户推荐策略(基于 style 和 category)", request.getUserId());
+ return recommendForNewUser(request);
+ }
+
+ // 2. 提取用户点赞图片的特征向量
+ List likedPaths = userLikes.stream()
+ .map(UserPreferenceLogTest::getPath)
+ .filter(StringUtils::isNotBlank)
+ .collect(Collectors.toList());
+
+ List> likedVectors = featureExtractionService.extractFeatures(likedPaths);
+ likedVectors = likedVectors.stream()
+ .filter(Objects::nonNull)
+ .collect(Collectors.toList());
+
+ if (CollectionUtils.isEmpty(likedVectors)) {
+ log.warn("用户 {} 无法提取特征向量", request.getUserId());
+ return Collections.emptyList();
+ }
+
+ // 3. 计算用户偏好向量(平均向量)
+ List userPreferenceVector = featureExtractionService.calculateUserPreferenceVector(likedVectors);
+
+ // 4. 构建搜索参数(混合推荐策略)
+ // 策略:向量相似度搜索为主 + 分类筛选为辅
+
+ List searchOutputFields = Arrays.asList(URL_FIELD, STYLE_FIELD, CATEGORY_FIELD);
+ List> searchVectors = Collections.singletonList(userPreferenceVector);
+
+ // 构建过滤表达式(基础筛选)
+ StringBuilder exprBuilder = new StringBuilder();
+ List conditions = new ArrayList<>();
+
+ // 基础筛选:只返回未淘汰的图片
+ if (request.getOnlyActive() != null && request.getOnlyActive()) {
+ conditions.add(DEPRECATED_FIELD + " == 0");
+ }
+
+ // 可选筛选:风格(如果指定)
+ if (StringUtils.isNotBlank(request.getStyle())) {
+ conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
+ }
+ if (StringUtils.isNotBlank(request.getCategory())) {
+ conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
+ }
+
+ if (!conditions.isEmpty()) {
+ exprBuilder.append(String.join(" && ", conditions));
+ }
+
+ // 第一阶段:向量搜索获取大量候选(用于后续筛选和多样性优化)
+ int candidateSize = request.getCandidateSize() != null ?
+ Math.max(request.getCandidateSize(), request.getTopK() * 3) :
+ Math.max(50, request.getTopK() * 5);
+
+ SearchParam searchParam = SearchParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .withMetricType(MetricType.L2) // L2距离,越小越相似
+ .withOutFields(searchOutputFields)
+ .withTopK(candidateSize) // 获取更多候选用于后续处理
+ .withVectors(searchVectors)
+ .withVectorFieldName(VECTOR_FIELD)
+ .withExpr(exprBuilder.length() > 0 ? exprBuilder.toString() : null)
+ .withParams("{\"nprobe\":10}")
+ .build();
+
+ // 5. 执行搜索
+ R searchR = milvusClient.search(searchParam);
+ if (searchR.getStatus() != R.Status.Success.getCode()) {
+ log.error("搜索失败: {}", searchR.getMessage());
+ return Collections.emptyList();
+ }
+
+ // 6. 解析结果并进行后处理
+ SearchResultsWrapper wrapper = new SearchResultsWrapper(searchR.getData().getResults());
+
+ // 6.1 提取候选结果(包含相似度分数)
+ List candidates = new ArrayList<>();
+ FieldDataWrapper urlFieldWrapper = wrapper.getFieldWrapper(URL_FIELD);
+ FieldDataWrapper styleFieldWrapper = wrapper.getFieldWrapper(STYLE_FIELD);
+ List idScores = wrapper.getIDScore(0);
+
+ for (int i = 0; i < idScores.size(); i++) {
+ String url = (String) urlFieldWrapper.getFieldData().get(i);
+ String style = (String) styleFieldWrapper.getFieldData().get(i);
+ Float score = idScores.get(i).getScore(); // L2距离,越小越相似
+
+ if (StringUtils.isNotBlank(url)) {
+ candidates.add(new RecommendationCandidate(url, style, score));
+ }
+ }
+
+// // 6.2 分类筛选(如果指定了 category)
+// if (StringUtils.isNotBlank(request.getCategory())) {
+// candidates = filterByCategory(candidates, request.getCategory());
+// }
+
+ // 6.3 多样性优化(避免推荐过于相似的结果)
+ if (request.getEnableDiversity() != null && request.getEnableDiversity()) {
+ candidates = diversifyResults(candidates, request.getTopK() != null ? request.getTopK() : 10);
+ }
+
+ // 6.4 取 Top-K
+ int topK = request.getTopK() != null ? request.getTopK() : 1;
+ List recommendedUrls = candidates.stream()
+ .limit(topK)
+ .map(RecommendationCandidate::getUrl)
+ .collect(Collectors.toList());
+
+ log.info("为用户 {} 推荐了 {} 个结果(从 {} 个候选中筛选)",
+ request.getUserId(), recommendedUrls.size(), idScores.size());
+ return recommendedUrls;
+
+ } catch (Exception e) {
+ log.error("推荐失败", e);
+ return Collections.emptyList();
+ }
+ }
+
+ @Override
+ public void syncSysFileToMilvus() {
+ log.info("开始同步 t_sys_file 数据到 Milvus");
+
+ try {
+ // 1. 加载 Collection
+ LoadCollectionParam loadParam = LoadCollectionParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .build();
+ milvusClient.loadCollection(loadParam);
+
+ // 2. 查询所有系统 sketch
+ QueryWrapper queryWrapper = new QueryWrapper<>();
+ queryWrapper.lambda()
+ .eq(SysFile::getLevel1Type, "Images").eq(SysFile::getDeprecated, 0)
+ .isNotNull(SysFile::getUrl)
+ .isNotNull(SysFile::getStyle)
+ .ne(SysFile::getUrl, "");
+ List sysFiles = sysFileMapper.selectList(queryWrapper);
+
+ log.info("找到 {} 个系统文件需要同步", sysFiles.size());
+
+ int successCount = 0;
+ int failCount = 0;
+ int batchSize = 100;
+
+ // 3. 批量处理
+ for (int i = 0; i < sysFiles.size(); i += batchSize) {
+ int end = Math.min(i + batchSize, sysFiles.size());
+ List batch = sysFiles.subList(i, end);
+
+ List ids = new ArrayList<>();
+ List urls = new ArrayList<>();
+ List styles = new ArrayList<>();
+ List categories = new ArrayList<>();
+ List deprecatedList = new ArrayList<>();
+ List> vectors = new ArrayList<>();
+
+ for (SysFile sysFile : batch) {
+ try {
+ // 提取特征向量
+ List vector = featureExtractionService.extractFeature(sysFile.getUrl());
+ if (vector == null || vector.size() != DIMENSION) {
+ log.warn("文件 {} 特征提取失败,跳过", sysFile.getUrl());
+ failCount++;
+ continue;
+ }
+
+ ids.add(sysFile.getId());
+ urls.add(sysFile.getUrl());
+ styles.add(sysFile.getStyle() != null ? sysFile.getStyle() : "");
+ categories.add(resolveCategory(sysFile));
+ deprecatedList.add((byte) (sysFile.getDeprecated() != null ? sysFile.getDeprecated() : 0));
+ vectors.add(vector);
+ successCount++;
+
+ } catch (Exception e) {
+ log.error("处理文件 {} 失败", sysFile.getUrl(), e);
+ failCount++;
+ }
+ }
+
+ // 批量插入
+ if (!ids.isEmpty()) {
+ List fields = Arrays.asList(
+ new InsertParam.Field(SYS_FILE_ID_FIELD, ids),
+ new InsertParam.Field(URL_FIELD, urls),
+ new InsertParam.Field(STYLE_FIELD, styles),
+ new InsertParam.Field(DEPRECATED_FIELD, deprecatedList),
+ new InsertParam.Field(CATEGORY_FIELD, categories),
+ new InsertParam.Field(VECTOR_FIELD, vectors)
+ );
+
+ InsertParam insertParam = InsertParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .withFields(fields)
+ .build();
+
+ R insertR = milvusClient.insert(insertParam);
+ if (insertR.getStatus() == R.Status.Success.getCode()) {
+ log.info("批量插入成功: {}/{}", end, sysFiles.size());
+ } else {
+ log.error("批量插入失败: {}", insertR.getMessage());
+ }
+ }
+ }
+
+ // 4. 刷新数据
+ FlushParam flushParam = FlushParam.newBuilder()
+ .withCollectionNames(Collections.singletonList(COLLECTION_NAME))
+ .build();
+ milvusClient.flush(flushParam);
+
+ log.info("同步完成: 成功 {}, 失败 {}", successCount, failCount);
+
+ } catch (Exception e) {
+ log.error("同步数据到 Milvus 失败", e);
+ throw new RuntimeException("同步失败", e);
+ }
+ }
+
+ @Override
+ public void updateVector(Long sysFileId, String url) {
+ // TODO: 实现单个向量更新逻辑
+ // 需要先删除旧向量,再插入新向量
+ log.info("更新向量: sysFileId={}, url={}", sysFileId, url);
+ }
+
+ /**
+ * 新用户推荐策略:当用户没有点赞记录时,根据 style 和 category 随机推荐
+ *
+ * @param request 推荐请求
+ * @return 推荐的 URL 列表
+ */
+ private List recommendForNewUser(RecommendRequestDTO request) {
+ try {
+ int topK = request.getTopK() != null ? request.getTopK() : 10;
+
+ List milvusResults = recommendFromMilvusForNewUser(request, topK);
+ if (CollectionUtils.isEmpty(milvusResults)) {
+ return Collections.emptyList();
+ }
+ Collections.shuffle(milvusResults);
+ return milvusResults.stream().limit(topK).collect(Collectors.toList());
+
+ } catch (Exception e) {
+ log.error("新用户推荐失败", e);
+ return Collections.emptyList();
+ }
+ }
+
+ /**
+ * 从 Milvus 中为新用户推荐(根据 style 和 category 查询)
+ */
+ private List recommendFromMilvusForNewUser(RecommendRequestDTO request, int limit) {
+ try {
+ // 构建过滤表达式
+ StringBuilder exprBuilder = new StringBuilder();
+ List conditions = new ArrayList<>();
+
+ // 基础筛选:只返回未淘汰的图片
+ if (request.getOnlyActive() != null && request.getOnlyActive()) {
+ conditions.add(DEPRECATED_FIELD + " == 0");
+ }
+
+ // 风格筛选(如果指定)
+ if (StringUtils.isNotBlank(request.getStyle())) {
+ conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\"");
+ }
+ if (StringUtils.isNotBlank(request.getCategory())) {
+ conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\"");
+ }
+
+ if (conditions.isEmpty()) {
+ // 如果没有筛选条件,查询所有未淘汰的
+ conditions.add(DEPRECATED_FIELD + " == 0");
+ }
+
+ exprBuilder.append(String.join(" && ", conditions));
+
+ // 从 Milvus 查询符合条件的记录
+ QueryParam queryParam = QueryParam.newBuilder()
+ .withCollectionName(COLLECTION_NAME)
+ .withExpr(exprBuilder.toString())
+ .withOutFields(Arrays.asList(URL_FIELD))
+ .withLimit((long) limit)
+ .build();
+
+ R queryR = milvusClient.query(queryParam);
+ if (queryR.getStatus() != R.Status.Success.getCode()) {
+ log.warn("从 Milvus 查询失败: {}", queryR.getMessage());
+ return Collections.emptyList();
+ }
+
+ // 解析结果
+ List urls = new ArrayList<>();
+ QueryResultsWrapper wrapper = new QueryResultsWrapper(queryR.getData());
+ for (int i = 0; i < wrapper.getRowRecords().size(); i++) {
+ String url = (String) wrapper.getFieldWrapper(URL_FIELD).getFieldData().get(i);
+ if (StringUtils.isNotBlank(url)) {
+ urls.add(url);
+ }
+ }
+
+// // 如果指定了 category,进行筛选
+// if (StringUtils.isNotBlank(request.getCategory()) && !urls.isEmpty()) {
+// urls = urls.stream()
+// .filter(url -> {
+// // 从 URL 中提取类别信息
+// // URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
+// String[] parts = url.split("/");
+// if (parts.length >= 4) {
+// String sex = parts[parts.length - 3];
+// String cat = parts[parts.length - 2];
+// String expectedCategory = sex + "_" + cat;
+// return expectedCategory.equals(request.getCategory());
+// }
+// return false;
+// })
+// .collect(Collectors.toList());
+// }
+
+ log.info("从 Milvus 为新用户查询到 {} 条记录", urls.size());
+ return urls;
+
+ } catch (Exception e) {
+ log.error("从 Milvus 查询新用户推荐失败", e);
+ return Collections.emptyList();
+ }
+ }
+
+ /**
+ * 从 MySQL 中为新用户推荐(根据 style 和 category 查询)
+ */
+ private List recommendFromMySQLForNewUser(RecommendRequestDTO request, int topK) {
+ try {
+ QueryWrapper queryWrapper = new QueryWrapper<>();
+ queryWrapper.lambda()
+ .eq(SysFile::getLevel1Type, "Images")
+ .isNotNull(SysFile::getUrl)
+ .ne(SysFile::getUrl, "");
+
+ // 风格筛选(如果指定)
+ if (StringUtils.isNotBlank(request.getStyle())) {
+ queryWrapper.lambda().eq(SysFile::getStyle, request.getStyle());
+ }
+
+ // 类别筛选(如果指定)
+ if (StringUtils.isNotBlank(request.getCategory())) {
+ // category 格式: female_skirt, male_tops 等
+ String[] parts = request.getCategory().split("_");
+ if (parts.length == 2) {
+ String sex = parts[0]; // female 或 male
+ String category = parts[1]; // skirt, tops 等
+
+ // 从 URL 中筛选类别
+ // URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
+ queryWrapper.lambda().like(SysFile::getUrl, "/" + sex + "/" + category + "/");
+ }
+ }
+
+ // 查询所有符合条件的记录
+ List sysFiles = sysFileMapper.selectList(queryWrapper);
+
+ if (CollectionUtils.isEmpty(sysFiles)) {
+ log.warn("MySQL 中未找到符合条件的记录");
+ return Collections.emptyList();
+ }
+
+ // 随机选择
+ Collections.shuffle(sysFiles);
+
+ List urls = sysFiles.stream()
+ .map(SysFile::getUrl)
+ .filter(StringUtils::isNotBlank)
+ .limit(topK)
+ .collect(Collectors.toList());
+
+ log.info("从 MySQL 为新用户查询到 {} 条记录", urls.size());
+ return urls;
+
+ } catch (Exception e) {
+ log.error("从 MySQL 查询新用户推荐失败", e);
+ return Collections.emptyList();
+ }
+ }
+
+ /**
+ * 根据类别筛选候选结果
+ */
+ private List filterByCategory(
+ List candidates, String category) {
+ // 从 URL 中提取类别信息
+ // URL 格式: aida-sys-image/images/{sex}/{category}/{filename}
+ return candidates.stream()
+ .filter(candidate -> {
+ String url = candidate.getUrl();
+ if (StringUtils.isBlank(url) || !url.contains("/")) {
+ return false;
+ }
+ // 提取类别部分
+ String[] parts = url.split("/");
+ if (parts.length < 4) {
+ return false;
+ }
+ String sex = parts[parts.length - 3]; // female 或 male
+ String cat = parts[parts.length - 2]; // category
+ String expectedCategory = sex + "_" + cat;
+ return expectedCategory.equals(category);
+ })
+ .collect(Collectors.toList());
+ }
+
+ /**
+ * 多样性优化:避免推荐过于相似的结果
+ * 使用聚类或相似度阈值来增加多样性
+ */
+ private List diversifyResults(
+ List candidates, int targetSize) {
+ if (candidates.size() <= targetSize) {
+ return candidates;
+ }
+
+ // 简单策略:按相似度分数排序,然后每隔几个取一个,增加多样性
+ // 更复杂的策略可以使用聚类算法
+
+ List diversified = new ArrayList<>();
+ int step = Math.max(1, candidates.size() / targetSize);
+
+ for (int i = 0; i < candidates.size() && diversified.size() < targetSize; i += step) {
+ diversified.add(candidates.get(i));
+ }
+
+ // 如果还不够,补充剩余的结果
+ if (diversified.size() < targetSize) {
+ for (RecommendationCandidate candidate : candidates) {
+ if (diversified.size() >= targetSize) {
+ break;
+ }
+ if (!diversified.contains(candidate)) {
+ diversified.add(candidate);
+ }
+ }
+ }
+
+ return diversified;
+ }
+
+ /**
+ * 推荐候选结果内部类
+ */
+ private static class RecommendationCandidate {
+ private final String url;
+ private final String style;
+ private final Float similarityScore; // L2距离,越小越相似
+
+ public RecommendationCandidate(String url, String style, Float similarityScore) {
+ this.url = url;
+ this.style = style;
+ this.similarityScore = similarityScore;
+ }
+
+ public String getUrl() {
+ return url;
+ }
+
+ public String getStyle() {
+ return style;
+ }
+
+ public Float getSimilarityScore() {
+ return similarityScore;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RecommendationCandidate that = (RecommendationCandidate) o;
+ return Objects.equals(url, that.url);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(url);
+ }
+ }
+
+ private String resolveCategory(SysFile sysFile) {
+ String level3 = sysFile.getLevel3Type();
+ String level2 = sysFile.getLevel2Type();
+ if (StringUtils.isNotBlank(level3) && StringUtils.isNotBlank(level2)) {
+ return (level3 + "_" + level2).toLowerCase();
+ }
+ String url = sysFile.getUrl();
+ if (StringUtils.isNotBlank(url)) {
+ String[] parts = url.split("/");
+ if (parts.length >= 4) {
+ String sex = parts[parts.length - 3];
+ String category = parts[parts.length - 2];
+ return (sex + "_" + category).toLowerCase();
+ }
+ }
+ return "";
+ }
+}
+
diff --git a/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java b/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java
new file mode 100644
index 00000000..57fd4897
--- /dev/null
+++ b/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java
@@ -0,0 +1,83 @@
+package com.ai.da.service.impl;
+
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.transform.Normalize;
+import ai.djl.modality.cv.transform.Resize;
+import ai.djl.modality.cv.transform.ToTensor;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.Pipeline;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorContext;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * ResNet 特征提取 Translator
+ * 移除最后的全连接层,只保留特征提取部分(2048维)
+ */
+@Slf4j
+public class ResNetFeatureTranslator implements Translator {
+
+ private static final int IMAGE_SIZE = 224;
+ private static final Pipeline PIPELINE = new Pipeline()
+ .add(new Resize(IMAGE_SIZE, IMAGE_SIZE))
+ .add(new ToTensor())
+ .add(new Normalize(
+ new float[]{0.485f, 0.456f, 0.406f},
+ new float[]{0.229f, 0.224f, 0.225f}
+ ));
+
+ @Override
+ public NDList processInput(TranslatorContext ctx, Image input) {
+ NDManager manager = ctx.getNDManager();
+ NDArray array = input.toNDArray(manager);
+ NDList ndList = new NDList(array);
+ return PIPELINE.transform(ndList);
+ }
+
+ @Override
+ public float[] processOutput(TranslatorContext ctx, NDList list) {
+ NDArray output = list.singletonOrThrow();
+
+ // ResNet50 特征层输出形状应该是:
+ // - [1, 2048, 1, 1] - 全局平均池化后的特征(需要 flatten)
+ // - [1, 2048] - 已 flatten 的特征层输出
+
+ long[] shape = output.getShape().getShape();
+
+ // 如果是 4 维,说明是 [1, 2048, 1, 1],需要 flatten
+ if (shape.length == 4) {
+ output = output.flatten();
+ }
+ // 如果是 2 维,检查维度是否正确
+ else if (shape.length == 2) {
+ if (shape[1] == 1000) {
+ // 如果仍然是分类层输出(1000维),说明模型加载有问题
+ log.error("模型输出仍然是分类层(1000维),模型加载可能失败");
+ throw new IllegalStateException("模型输出维度错误:期望 2048 维,实际 1000 维。请检查模型是否正确移除了全连接层。");
+ }
+ // 如果已经是 [1, 2048],直接使用
+ }
+
+ // 转换为 float 数组
+ float[] result = output.toFloatArray();
+
+ // 验证维度必须是 2048
+ if (result.length != 2048) {
+ log.error("特征向量维度错误: 期望 2048, 实际 {}", result.length);
+ throw new IllegalArgumentException(
+ String.format("特征向量维度错误: 期望 2048, 实际 %d", result.length)
+ );
+ }
+
+ return result;
+ }
+
+ @Override
+ public Batchifier getBatchifier() {
+ return Batchifier.STACK;
+ }
+}
+