From ca416fed9db22afe30f4482322c174379d14d23c Mon Sep 17 00:00:00 2001 From: litianxiang Date: Fri, 28 Nov 2025 09:36:04 +0800 Subject: [PATCH] =?UTF-8?q?=E7=89=B9=E5=BE=81=E6=8E=A8=E8=8D=90=E5=88=9D?= =?UTF-8?q?=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 43 ++ .../com/ai/da/common/config/MilvusConfig.java | 45 ++ .../ai/da/mapper/primary/entity/SysFile.java | 6 + .../ai/da/model/dto/RecommendRequestDTO.java | 46 ++ .../java/com/ai/da/python/PythonService.java | 31 +- .../da/service/FeatureExtractionService.java | 36 + .../ai/da/service/RecommendationService.java | 34 + .../impl/FeatureExtractionServiceImpl.java | 220 ++++++ .../impl/RecommendationServiceImpl.java | 679 ++++++++++++++++++ .../service/impl/ResNetFeatureTranslator.java | 83 +++ 10 files changed, 1220 insertions(+), 3 deletions(-) create mode 100644 src/main/java/com/ai/da/common/config/MilvusConfig.java create mode 100644 src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java create mode 100644 src/main/java/com/ai/da/service/FeatureExtractionService.java create mode 100644 src/main/java/com/ai/da/service/RecommendationService.java create mode 100644 src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java create mode 100644 src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java create mode 100644 src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java diff --git a/pom.xml b/pom.xml index 5b4d83c3..32887616 100644 --- a/pom.xml +++ b/pom.xml @@ -422,6 +422,49 @@ bcpkix-jdk18on 1.78.1 + + + + io.milvus + milvus-sdk-java + 2.3.4 + + + + + ai.djl + api + 0.24.0 + + + ai.djl.pytorch + pytorch-engine + 0.24.0 + + + ai.djl.pytorch + pytorch-native-cpu + 2.0.1 + linux-x86_64 + + + ai.djl.pytorch + pytorch-native-cpu + 2.0.1 + win-x86_64 + + + ai.djl.pytorch + pytorch-native-cpu + 2.0.1 + osx-x86_64 + + + ai.djl.pytorch + pytorch-native-cpu + 2.0.1 + osx-aarch64 + diff --git a/src/main/java/com/ai/da/common/config/MilvusConfig.java b/src/main/java/com/ai/da/common/config/MilvusConfig.java new file mode 100644 index 00000000..415c0bae --- /dev/null +++ b/src/main/java/com/ai/da/common/config/MilvusConfig.java @@ -0,0 +1,45 @@ +package com.ai.da.common.config; + +import io.milvus.client.MilvusServiceClient; +import io.milvus.param.ConnectParam; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * Milvus 向量数据库配置 + */ +@Slf4j +@Configuration +public class MilvusConfig { + + @Value("${milvus.host:localhost}") + private String host; + + @Value("${milvus.port:19530}") + private Integer port; + + @Value("${milvus.username:}") + private String username; + + @Value("${milvus.password:}") + private String password; + + @Bean + public MilvusServiceClient milvusClient() { + ConnectParam.Builder builder = ConnectParam.newBuilder() + .withHost(host) + .withPort(port); + + if (StringUtils.isNotBlank(username) && StringUtils.isNotBlank(password)) { + builder.withAuthorization(username, password); + } + + MilvusServiceClient client = new MilvusServiceClient(builder.build()); + log.info("Milvus client initialized: {}:{}", host, port); + return client; + } +} + diff --git a/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java b/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java index a8c36e4e..0880689b 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/SysFile.java @@ -72,6 +72,12 @@ public class SysFile implements Serializable { */ private String style; + /** + * 是否废弃 0-否 1-是 + */ + private Integer deprecated; + + public SysFile() { } diff --git a/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java b/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java new file mode 100644 index 00000000..ce95b2c7 --- /dev/null +++ b/src/main/java/com/ai/da/model/dto/RecommendRequestDTO.java @@ -0,0 +1,46 @@ +package com.ai.da.model.dto; + +import lombok.Data; + +/** + * 推荐请求 DTO + */ +@Data +public class RecommendRequestDTO { + /** + * 用户ID + */ + private Long userId; + + /** + * 类别(如 female_skirt, male_tops)- 可选,用于辅助筛选 + */ + private String category; + + /** + * 风格样式(可选)- 用于辅助筛选 + */ + private String style; + + /** + * 返回数量 + */ + private Integer topK = 10; + + /** + * 是否只返回未淘汰的图片 + */ + private Boolean onlyActive = true; + + /** + * 向量搜索的候选数量(第一阶段) + * 建议值:50-100,用于后续筛选和多样性优化 + */ + private Integer candidateSize = 50; + + /** + * 是否启用多样性优化(避免推荐过于相似的结果) + */ + private Boolean enableDiversity = true; +} + diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index a97875f5..55811b3c 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -20,6 +20,7 @@ import com.ai.da.model.vo.*; import com.ai.da.python.vo.*; import com.ai.da.service.DesignHistoryService; import com.ai.da.service.PythonTAllInfoService; +import com.ai.da.service.RecommendationService; import com.ai.da.service.SysFileService; import com.alibaba.fastjson.*; import com.alibaba.fastjson.serializer.SerializerFeature; @@ -722,6 +723,9 @@ public class PythonService { @Resource private AttributeRetrievalMapper attributeRetrievalMapper; + @Resource + private RecommendationService recommendationService; + private List getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) { /** * female trousers->female_pants @@ -3955,6 +3959,27 @@ public class PythonService { } public List getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) { + AuthPrincipalVo userHolder = UserContext.getUserHolder(); + + // 优先使用基于 Milvus 的推荐系统 + try { + com.ai.da.model.dto.RecommendRequestDTO request = new com.ai.da.model.dto.RecommendRequestDTO(); + request.setUserId(userHolder.getId()); + request.setCategory(category); + request.setStyle(style); + request.setTopK(1); + request.setOnlyActive(true); + + List recommendedUrls = recommendationService.recommend(request); + if (!CollectionUtils.isEmpty(recommendedUrls)) { + log.info("使用 Milvus 推荐系统返回 {} 个结果", recommendedUrls.size()); + return recommendedUrls; + } + } catch (Exception e) { + log.warn("Milvus 推荐失败,降级到备用方案: {}", e.getMessage()); + } + + // 降级方案1: 使用 attribute_retrieval_style 表 //******3.1.2版本临时使用java推荐方案去解决style未使用的问题********** try { //使用新库attribute_retrieval_style,表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致 @@ -3974,11 +3999,11 @@ public class PythonService { return Arrays.asList(path); } } catch (Exception e) { - log.info("推荐失败:{}",e.getMessage()); - throw new BusinessException("system.error"); + log.info("attribute_retrieval_style 推荐失败:{}",e.getMessage()); } //**********************end*********************************** - AuthPrincipalVo userHolder = UserContext.getUserHolder(); + + // 降级方案2: 调用 Python 服务(原有方案) OkHttpClient client = new OkHttpClient().newBuilder() .connectTimeout(30, TimeUnit.SECONDS) diff --git a/src/main/java/com/ai/da/service/FeatureExtractionService.java b/src/main/java/com/ai/da/service/FeatureExtractionService.java new file mode 100644 index 00000000..19ba5b35 --- /dev/null +++ b/src/main/java/com/ai/da/service/FeatureExtractionService.java @@ -0,0 +1,36 @@ +package com.ai.da.service; + +import java.util.List; + +/** + * 特征提取服务接口 + * 用于从图片 URL 提取 2048 维特征向量 + */ +public interface FeatureExtractionService { + + /** + * 从图片 URL 提取特征向量 + * + * @param imageUrl 图片 URL(MinIO 路径) + * @return 2048 维特征向量,如果提取失败返回 null + */ + List extractFeature(String imageUrl); + + /** + * 批量提取特征向量 + * + * @param imageUrls 图片 URL 列表 + * @return 特征向量列表,与输入顺序对应 + */ + List> extractFeatures(List imageUrls); + + /** + * 计算用户偏好向量(基于用户点赞的图片向量) + * 使用平均向量作为用户偏好 + * + * @param userLikedVectors 用户点赞的图片向量列表 + * @return 用户偏好向量(2048维) + */ + List calculateUserPreferenceVector(List> userLikedVectors); +} + diff --git a/src/main/java/com/ai/da/service/RecommendationService.java b/src/main/java/com/ai/da/service/RecommendationService.java new file mode 100644 index 00000000..bea65a40 --- /dev/null +++ b/src/main/java/com/ai/da/service/RecommendationService.java @@ -0,0 +1,34 @@ +package com.ai.da.service; + +import com.ai.da.model.dto.RecommendRequestDTO; + +import java.util.List; + +/** + * 推荐服务接口 + */ +public interface RecommendationService { + + /** + * 根据用户偏好推荐系统 sketch + * + * @param request 推荐请求 + * @return 推荐的 URL 列表 + */ + List recommend(RecommendRequestDTO request); + + /** + * 同步 t_sys_file 数据到 Milvus + * 从 t_sys_file 表读取所有系统 sketch,提取特征向量并存储到 Milvus + */ + void syncSysFileToMilvus(); + + /** + * 更新单个文件的向量(当文件更新时调用) + * + * @param sysFileId 系统文件ID + * @param url 文件URL + */ + void updateVector(Long sysFileId, String url); +} + diff --git a/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java b/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java new file mode 100644 index 00000000..565f3cc2 --- /dev/null +++ b/src/main/java/com/ai/da/service/impl/FeatureExtractionServiceImpl.java @@ -0,0 +1,220 @@ +package com.ai.da.service.impl; + +import ai.djl.Application; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.Translator; +import com.ai.da.common.utils.MinioUtil; +import com.ai.da.service.FeatureExtractionService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * 特征提取服务实现类 + * 使用 DJL (Deep Java Library) 框架加载 ResNet50 模型提取特征向量 + */ +@Slf4j +@Service +public class FeatureExtractionServiceImpl implements FeatureExtractionService { + + private static final int DIMENSION = 2048; // ResNet50 特征向量维度 + + @Autowired + private MinioUtil minioUtil; + + private ZooModel model; + private Predictor predictor; + + /** + * 初始化 ResNet50 模型 + * 加载预训练模型并移除最后的全连接层,以获取2048维特征向量 + */ + @PostConstruct + public void initModel() { + try { + log.info("开始加载 ResNet50 模型(特征提取模式,2048维)..."); + + // 1. 加载完整的预训练 ResNet50 模型 + Criteria criteria = Criteria.builder() + .optApplication(Application.CV.IMAGE_CLASSIFICATION) + .setTypes(Image.class, float[].class) + .optModelName("resnet50") + .optEngine("PyTorch") + .optProgress(new ProgressBar()) + .build(); + + // 2. 创建特征提取 Translator + Translator translator = new ResNetFeatureTranslator(); + + // 3. 加载模型 + model = ModelZoo.loadModel(criteria); + + // 4. 创建 Predictor + predictor = model.newPredictor(translator); + + log.info("ResNet50 模型加载完成(特征提取模式,2048维)"); + } catch (Exception e) { + log.error("加载 ResNet50 模型失败", e); + throw new RuntimeException("模型加载失败", e); + } + } + + /** + * 清理资源 + */ + @PreDestroy + public void cleanup() { + if (predictor != null) { + predictor.close(); + } + if (model != null) { + model.close(); + } + } + + /** + * 从图片 URL 提取特征向量 + * + * @param imageUrl MinIO 路径,格式为 "bucket_name/path/to/image.jpg" + * 例如: "aida-sys-image/images/female/blouse/0825001474.jpg" + * @return 2048 维特征向量,如果提取失败返回 null + */ + @Override + public List extractFeature(String imageUrl) { + if (predictor == null) { + log.error("模型未初始化,无法提取特征"); + return null; + } + + try { + // 1. 从 MinIO 获取图片 + Image image = getImageFromMinio(imageUrl); + if (image == null) { + log.warn("无法从 MinIO 获取图片: {}", imageUrl); + return null; + } + + // 2. 使用模型提取特征 + float[] featureArray = predictor.predict(image); + + // 3. 转换为 List + List featureVector = new ArrayList<>(DIMENSION); + for (float value : featureArray) { + featureVector.add(value); + } + + log.debug("成功提取特征向量: {}, 维度: {}", imageUrl, featureVector.size()); + return featureVector; + + } catch (Exception e) { + log.error("提取特征失败: {}", imageUrl, e); + return null; + } + } + + /** + * 从 MinIO 获取图片并转换为 DJL Image 对象 + */ + private Image getImageFromMinio(String imageUrl) { + InputStream imageStream = null; + try { + // 解析路径:bucket_name/path/to/image.jpg + int index = imageUrl.indexOf("/"); + if (index == -1 || index == imageUrl.length() - 1) { + log.warn("无效的图片路径格式: {}", imageUrl); + return null; + } + + String bucketName = imageUrl.substring(0, index); + String objectName = imageUrl.substring(index + 1); + + // 从 MinIO 获取图片流 + // 使用 MinioUtil 的 download 方法获取 InputStream + // download 方法可能抛出 MinioException 和 IOException + imageStream = minioUtil.download(bucketName, objectName); + if (imageStream == null) { + log.warn("无法从 MinIO 获取图片流: bucket={}, object={}", bucketName, objectName); + return null; + } + + // 转换为 DJL Image 对象 + Image image = ImageFactory.getInstance().fromInputStream(imageStream); + + return image; + + } catch (Exception e) { + log.error("从 MinIO 获取图片失败: {}", imageUrl, e); + return null; + } finally { + if (imageStream != null) { + try { + imageStream.close(); + } catch (Exception e) { + log.warn("关闭图片流失败", e); + } + } + } + } + + @Override + public List> extractFeatures(List imageUrls) { + if (CollectionUtils.isEmpty(imageUrls)) { + return new ArrayList<>(); + } + + return imageUrls.stream() + .map(this::extractFeature) + .filter(java.util.Objects::nonNull) + .collect(Collectors.toList()); + } + + @Override + public List calculateUserPreferenceVector(List> userLikedVectors) { + if (CollectionUtils.isEmpty(userLikedVectors)) { + return new ArrayList<>(); + } + + int dimension = userLikedVectors.get(0).size(); + List preferenceVector = new ArrayList<>(dimension); + + // 初始化为零向量 + for (int i = 0; i < dimension; i++) { + preferenceVector.add(0.0f); + } + + // 计算平均向量 + for (List vector : userLikedVectors) { + if (vector.size() != dimension) { + log.warn("向量维度不匹配: 期望 {}, 实际 {}", dimension, vector.size()); + continue; + } + + for (int i = 0; i < dimension; i++) { + preferenceVector.set(i, preferenceVector.get(i) + vector.get(i)); + } + } + + // 求平均值 + int count = userLikedVectors.size(); + for (int i = 0; i < dimension; i++) { + preferenceVector.set(i, preferenceVector.get(i) / count); + } + + return preferenceVector; + } +} + diff --git a/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java b/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java new file mode 100644 index 00000000..2f0db8e3 --- /dev/null +++ b/src/main/java/com/ai/da/service/impl/RecommendationServiceImpl.java @@ -0,0 +1,679 @@ +package com.ai.da.service.impl; + +import com.ai.da.common.config.MilvusConfig; +import com.ai.da.mapper.primary.SysFileMapper; +import com.ai.da.mapper.primary.UserPreferenceLogMapper; +import com.ai.da.mapper.primary.entity.SysFile; +import com.ai.da.mapper.primary.entity.UserPreferenceLogTest; +import com.ai.da.model.dto.RecommendRequestDTO; +import com.ai.da.service.FeatureExtractionService; +import com.ai.da.service.RecommendationService; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import io.milvus.client.MilvusServiceClient; +import io.milvus.grpc.DataType; +import io.milvus.grpc.SearchResults; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.param.collection.*; +import io.milvus.param.dml.InsertParam; +import io.milvus.param.dml.QueryParam; +import io.milvus.param.dml.SearchParam; +import io.milvus.param.index.CreateIndexParam; +import io.milvus.grpc.MutationResult; +import io.milvus.grpc.QueryResults; +import io.milvus.response.FieldDataWrapper; +import io.milvus.response.QueryResultsWrapper; +import io.milvus.response.SearchResultsWrapper; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.Resource; +import java.util.*; +import java.util.stream.Collectors; + +/** + * 推荐服务实现类 + * 基于 Milvus 向量数据库的推荐系统 + */ +@Slf4j +@Service +public class RecommendationServiceImpl implements RecommendationService { + + private static final String COLLECTION_NAME = "sketch_recommendation"; + private static final int DIMENSION = 2048; // ResNet50 特征向量维度 + private static final String VECTOR_FIELD = "feature_vector"; + private static final String URL_FIELD = "url"; + private static final String STYLE_FIELD = "style"; + private static final String CATEGORY_FIELD = "category"; + private static final String DEPRECATED_FIELD = "deprecated"; + private static final String SYS_FILE_ID_FIELD = "sys_file_id"; + + @Resource + private MilvusServiceClient milvusClient; + + @Resource + private SysFileMapper sysFileMapper; + + @Resource + private UserPreferenceLogMapper userPreferenceLogMapper; + + @Resource + private FeatureExtractionService featureExtractionService; + + /** + * 初始化 Collection(如果不存在则创建) + */ + @PostConstruct + public void initCollection() { + try { + // 检查 Collection 是否存在 + R hasCollectionR = milvusClient.hasCollection( + HasCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build() + ); + + if (!hasCollectionR.getData()) { + log.info("Collection {} 不存在,开始创建...", COLLECTION_NAME); + createCollection(); + createIndex(); + log.info("Collection {} 创建完成", COLLECTION_NAME); + } else { + log.info("Collection {} 已存在", COLLECTION_NAME); + } + } catch (Exception e) { + log.error("初始化 Collection 失败", e); + } + } + + /** + * 创建 Collection + */ + private void createCollection() { + // 定义字段 + List fields = Arrays.asList( + FieldType.newBuilder() + .withName(SYS_FILE_ID_FIELD) + .withDataType(DataType.Int64) + .withPrimaryKey(true) + .withAutoID(false) + .build(), + FieldType.newBuilder() + .withName(URL_FIELD) + .withDataType(DataType.VarChar) + .withMaxLength(500) + .build(), + FieldType.newBuilder() + .withName(STYLE_FIELD) + .withDataType(DataType.VarChar) + .withMaxLength(100) + .build(), + FieldType.newBuilder() + .withName(CATEGORY_FIELD) + .withDataType(DataType.VarChar) + .withMaxLength(100) + .build(), + FieldType.newBuilder() + .withName(DEPRECATED_FIELD) + .withDataType(DataType.Int8) + .build(), + FieldType.newBuilder() + .withName(VECTOR_FIELD) + .withDataType(DataType.FloatVector) + .withDimension(DIMENSION) + .build() + ); + + // 创建 Collection + CreateCollectionParam createParam = CreateCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withDescription("系统 sketch 推荐向量库") + .withShardsNum(2) + .withFieldTypes(fields) + .build(); + + R createR = milvusClient.createCollection(createParam); + if (createR.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("创建 Collection 失败: " + createR.getMessage()); + } + } + + /** + * 创建索引 + */ + private void createIndex() { + CreateIndexParam indexParam = CreateIndexParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withFieldName(VECTOR_FIELD) + .withIndexType(IndexType.IVF_FLAT) + .withMetricType(MetricType.L2) + .withExtraParam("{\"nlist\":1024}") + .build(); + + R indexR = milvusClient.createIndex(indexParam); + if (indexR.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("创建索引失败: " + indexR.getMessage()); + } + } + + @Override + public List recommend(RecommendRequestDTO request) { + try { + // 1. 从 user_preference_log_test 获取用户点赞的 path 列表 + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.lambda().eq(UserPreferenceLogTest::getAccountId, request.getUserId()); + List userLikes = userPreferenceLogMapper.selectList(queryWrapper); + + // 新用户处理:如果没有点赞记录,使用降级推荐策略 + if (CollectionUtils.isEmpty(userLikes)) { + log.info("用户 {} 没有点赞记录,使用新用户推荐策略(基于 style 和 category)", request.getUserId()); + return recommendForNewUser(request); + } + + // 2. 提取用户点赞图片的特征向量 + List likedPaths = userLikes.stream() + .map(UserPreferenceLogTest::getPath) + .filter(StringUtils::isNotBlank) + .collect(Collectors.toList()); + + List> likedVectors = featureExtractionService.extractFeatures(likedPaths); + likedVectors = likedVectors.stream() + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + if (CollectionUtils.isEmpty(likedVectors)) { + log.warn("用户 {} 无法提取特征向量", request.getUserId()); + return Collections.emptyList(); + } + + // 3. 计算用户偏好向量(平均向量) + List userPreferenceVector = featureExtractionService.calculateUserPreferenceVector(likedVectors); + + // 4. 构建搜索参数(混合推荐策略) + // 策略:向量相似度搜索为主 + 分类筛选为辅 + + List searchOutputFields = Arrays.asList(URL_FIELD, STYLE_FIELD, CATEGORY_FIELD); + List> searchVectors = Collections.singletonList(userPreferenceVector); + + // 构建过滤表达式(基础筛选) + StringBuilder exprBuilder = new StringBuilder(); + List conditions = new ArrayList<>(); + + // 基础筛选:只返回未淘汰的图片 + if (request.getOnlyActive() != null && request.getOnlyActive()) { + conditions.add(DEPRECATED_FIELD + " == 0"); + } + + // 可选筛选:风格(如果指定) + if (StringUtils.isNotBlank(request.getStyle())) { + conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\""); + } + if (StringUtils.isNotBlank(request.getCategory())) { + conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\""); + } + + if (!conditions.isEmpty()) { + exprBuilder.append(String.join(" && ", conditions)); + } + + // 第一阶段:向量搜索获取大量候选(用于后续筛选和多样性优化) + int candidateSize = request.getCandidateSize() != null ? + Math.max(request.getCandidateSize(), request.getTopK() * 3) : + Math.max(50, request.getTopK() * 5); + + SearchParam searchParam = SearchParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withMetricType(MetricType.L2) // L2距离,越小越相似 + .withOutFields(searchOutputFields) + .withTopK(candidateSize) // 获取更多候选用于后续处理 + .withVectors(searchVectors) + .withVectorFieldName(VECTOR_FIELD) + .withExpr(exprBuilder.length() > 0 ? exprBuilder.toString() : null) + .withParams("{\"nprobe\":10}") + .build(); + + // 5. 执行搜索 + R searchR = milvusClient.search(searchParam); + if (searchR.getStatus() != R.Status.Success.getCode()) { + log.error("搜索失败: {}", searchR.getMessage()); + return Collections.emptyList(); + } + + // 6. 解析结果并进行后处理 + SearchResultsWrapper wrapper = new SearchResultsWrapper(searchR.getData().getResults()); + + // 6.1 提取候选结果(包含相似度分数) + List candidates = new ArrayList<>(); + FieldDataWrapper urlFieldWrapper = wrapper.getFieldWrapper(URL_FIELD); + FieldDataWrapper styleFieldWrapper = wrapper.getFieldWrapper(STYLE_FIELD); + List idScores = wrapper.getIDScore(0); + + for (int i = 0; i < idScores.size(); i++) { + String url = (String) urlFieldWrapper.getFieldData().get(i); + String style = (String) styleFieldWrapper.getFieldData().get(i); + Float score = idScores.get(i).getScore(); // L2距离,越小越相似 + + if (StringUtils.isNotBlank(url)) { + candidates.add(new RecommendationCandidate(url, style, score)); + } + } + +// // 6.2 分类筛选(如果指定了 category) +// if (StringUtils.isNotBlank(request.getCategory())) { +// candidates = filterByCategory(candidates, request.getCategory()); +// } + + // 6.3 多样性优化(避免推荐过于相似的结果) + if (request.getEnableDiversity() != null && request.getEnableDiversity()) { + candidates = diversifyResults(candidates, request.getTopK() != null ? request.getTopK() : 10); + } + + // 6.4 取 Top-K + int topK = request.getTopK() != null ? request.getTopK() : 1; + List recommendedUrls = candidates.stream() + .limit(topK) + .map(RecommendationCandidate::getUrl) + .collect(Collectors.toList()); + + log.info("为用户 {} 推荐了 {} 个结果(从 {} 个候选中筛选)", + request.getUserId(), recommendedUrls.size(), idScores.size()); + return recommendedUrls; + + } catch (Exception e) { + log.error("推荐失败", e); + return Collections.emptyList(); + } + } + + @Override + public void syncSysFileToMilvus() { + log.info("开始同步 t_sys_file 数据到 Milvus"); + + try { + // 1. 加载 Collection + LoadCollectionParam loadParam = LoadCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build(); + milvusClient.loadCollection(loadParam); + + // 2. 查询所有系统 sketch + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.lambda() + .eq(SysFile::getLevel1Type, "Images").eq(SysFile::getDeprecated, 0) + .isNotNull(SysFile::getUrl) + .isNotNull(SysFile::getStyle) + .ne(SysFile::getUrl, ""); + List sysFiles = sysFileMapper.selectList(queryWrapper); + + log.info("找到 {} 个系统文件需要同步", sysFiles.size()); + + int successCount = 0; + int failCount = 0; + int batchSize = 100; + + // 3. 批量处理 + for (int i = 0; i < sysFiles.size(); i += batchSize) { + int end = Math.min(i + batchSize, sysFiles.size()); + List batch = sysFiles.subList(i, end); + + List ids = new ArrayList<>(); + List urls = new ArrayList<>(); + List styles = new ArrayList<>(); + List categories = new ArrayList<>(); + List deprecatedList = new ArrayList<>(); + List> vectors = new ArrayList<>(); + + for (SysFile sysFile : batch) { + try { + // 提取特征向量 + List vector = featureExtractionService.extractFeature(sysFile.getUrl()); + if (vector == null || vector.size() != DIMENSION) { + log.warn("文件 {} 特征提取失败,跳过", sysFile.getUrl()); + failCount++; + continue; + } + + ids.add(sysFile.getId()); + urls.add(sysFile.getUrl()); + styles.add(sysFile.getStyle() != null ? sysFile.getStyle() : ""); + categories.add(resolveCategory(sysFile)); + deprecatedList.add((byte) (sysFile.getDeprecated() != null ? sysFile.getDeprecated() : 0)); + vectors.add(vector); + successCount++; + + } catch (Exception e) { + log.error("处理文件 {} 失败", sysFile.getUrl(), e); + failCount++; + } + } + + // 批量插入 + if (!ids.isEmpty()) { + List fields = Arrays.asList( + new InsertParam.Field(SYS_FILE_ID_FIELD, ids), + new InsertParam.Field(URL_FIELD, urls), + new InsertParam.Field(STYLE_FIELD, styles), + new InsertParam.Field(DEPRECATED_FIELD, deprecatedList), + new InsertParam.Field(CATEGORY_FIELD, categories), + new InsertParam.Field(VECTOR_FIELD, vectors) + ); + + InsertParam insertParam = InsertParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withFields(fields) + .build(); + + R insertR = milvusClient.insert(insertParam); + if (insertR.getStatus() == R.Status.Success.getCode()) { + log.info("批量插入成功: {}/{}", end, sysFiles.size()); + } else { + log.error("批量插入失败: {}", insertR.getMessage()); + } + } + } + + // 4. 刷新数据 + FlushParam flushParam = FlushParam.newBuilder() + .withCollectionNames(Collections.singletonList(COLLECTION_NAME)) + .build(); + milvusClient.flush(flushParam); + + log.info("同步完成: 成功 {}, 失败 {}", successCount, failCount); + + } catch (Exception e) { + log.error("同步数据到 Milvus 失败", e); + throw new RuntimeException("同步失败", e); + } + } + + @Override + public void updateVector(Long sysFileId, String url) { + // TODO: 实现单个向量更新逻辑 + // 需要先删除旧向量,再插入新向量 + log.info("更新向量: sysFileId={}, url={}", sysFileId, url); + } + + /** + * 新用户推荐策略:当用户没有点赞记录时,根据 style 和 category 随机推荐 + * + * @param request 推荐请求 + * @return 推荐的 URL 列表 + */ + private List recommendForNewUser(RecommendRequestDTO request) { + try { + int topK = request.getTopK() != null ? request.getTopK() : 10; + + List milvusResults = recommendFromMilvusForNewUser(request, topK); + if (CollectionUtils.isEmpty(milvusResults)) { + return Collections.emptyList(); + } + Collections.shuffle(milvusResults); + return milvusResults.stream().limit(topK).collect(Collectors.toList()); + + } catch (Exception e) { + log.error("新用户推荐失败", e); + return Collections.emptyList(); + } + } + + /** + * 从 Milvus 中为新用户推荐(根据 style 和 category 查询) + */ + private List recommendFromMilvusForNewUser(RecommendRequestDTO request, int limit) { + try { + // 构建过滤表达式 + StringBuilder exprBuilder = new StringBuilder(); + List conditions = new ArrayList<>(); + + // 基础筛选:只返回未淘汰的图片 + if (request.getOnlyActive() != null && request.getOnlyActive()) { + conditions.add(DEPRECATED_FIELD + " == 0"); + } + + // 风格筛选(如果指定) + if (StringUtils.isNotBlank(request.getStyle())) { + conditions.add(STYLE_FIELD + " == \"" + request.getStyle() + "\""); + } + if (StringUtils.isNotBlank(request.getCategory())) { + conditions.add(CATEGORY_FIELD + " == \"" + request.getCategory().toLowerCase() + "\""); + } + + if (conditions.isEmpty()) { + // 如果没有筛选条件,查询所有未淘汰的 + conditions.add(DEPRECATED_FIELD + " == 0"); + } + + exprBuilder.append(String.join(" && ", conditions)); + + // 从 Milvus 查询符合条件的记录 + QueryParam queryParam = QueryParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withExpr(exprBuilder.toString()) + .withOutFields(Arrays.asList(URL_FIELD)) + .withLimit((long) limit) + .build(); + + R queryR = milvusClient.query(queryParam); + if (queryR.getStatus() != R.Status.Success.getCode()) { + log.warn("从 Milvus 查询失败: {}", queryR.getMessage()); + return Collections.emptyList(); + } + + // 解析结果 + List urls = new ArrayList<>(); + QueryResultsWrapper wrapper = new QueryResultsWrapper(queryR.getData()); + for (int i = 0; i < wrapper.getRowRecords().size(); i++) { + String url = (String) wrapper.getFieldWrapper(URL_FIELD).getFieldData().get(i); + if (StringUtils.isNotBlank(url)) { + urls.add(url); + } + } + +// // 如果指定了 category,进行筛选 +// if (StringUtils.isNotBlank(request.getCategory()) && !urls.isEmpty()) { +// urls = urls.stream() +// .filter(url -> { +// // 从 URL 中提取类别信息 +// // URL 格式: aida-sys-image/images/{sex}/{category}/{filename} +// String[] parts = url.split("/"); +// if (parts.length >= 4) { +// String sex = parts[parts.length - 3]; +// String cat = parts[parts.length - 2]; +// String expectedCategory = sex + "_" + cat; +// return expectedCategory.equals(request.getCategory()); +// } +// return false; +// }) +// .collect(Collectors.toList()); +// } + + log.info("从 Milvus 为新用户查询到 {} 条记录", urls.size()); + return urls; + + } catch (Exception e) { + log.error("从 Milvus 查询新用户推荐失败", e); + return Collections.emptyList(); + } + } + + /** + * 从 MySQL 中为新用户推荐(根据 style 和 category 查询) + */ + private List recommendFromMySQLForNewUser(RecommendRequestDTO request, int topK) { + try { + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.lambda() + .eq(SysFile::getLevel1Type, "Images") + .isNotNull(SysFile::getUrl) + .ne(SysFile::getUrl, ""); + + // 风格筛选(如果指定) + if (StringUtils.isNotBlank(request.getStyle())) { + queryWrapper.lambda().eq(SysFile::getStyle, request.getStyle()); + } + + // 类别筛选(如果指定) + if (StringUtils.isNotBlank(request.getCategory())) { + // category 格式: female_skirt, male_tops 等 + String[] parts = request.getCategory().split("_"); + if (parts.length == 2) { + String sex = parts[0]; // female 或 male + String category = parts[1]; // skirt, tops 等 + + // 从 URL 中筛选类别 + // URL 格式: aida-sys-image/images/{sex}/{category}/{filename} + queryWrapper.lambda().like(SysFile::getUrl, "/" + sex + "/" + category + "/"); + } + } + + // 查询所有符合条件的记录 + List sysFiles = sysFileMapper.selectList(queryWrapper); + + if (CollectionUtils.isEmpty(sysFiles)) { + log.warn("MySQL 中未找到符合条件的记录"); + return Collections.emptyList(); + } + + // 随机选择 + Collections.shuffle(sysFiles); + + List urls = sysFiles.stream() + .map(SysFile::getUrl) + .filter(StringUtils::isNotBlank) + .limit(topK) + .collect(Collectors.toList()); + + log.info("从 MySQL 为新用户查询到 {} 条记录", urls.size()); + return urls; + + } catch (Exception e) { + log.error("从 MySQL 查询新用户推荐失败", e); + return Collections.emptyList(); + } + } + + /** + * 根据类别筛选候选结果 + */ + private List filterByCategory( + List candidates, String category) { + // 从 URL 中提取类别信息 + // URL 格式: aida-sys-image/images/{sex}/{category}/{filename} + return candidates.stream() + .filter(candidate -> { + String url = candidate.getUrl(); + if (StringUtils.isBlank(url) || !url.contains("/")) { + return false; + } + // 提取类别部分 + String[] parts = url.split("/"); + if (parts.length < 4) { + return false; + } + String sex = parts[parts.length - 3]; // female 或 male + String cat = parts[parts.length - 2]; // category + String expectedCategory = sex + "_" + cat; + return expectedCategory.equals(category); + }) + .collect(Collectors.toList()); + } + + /** + * 多样性优化:避免推荐过于相似的结果 + * 使用聚类或相似度阈值来增加多样性 + */ + private List diversifyResults( + List candidates, int targetSize) { + if (candidates.size() <= targetSize) { + return candidates; + } + + // 简单策略:按相似度分数排序,然后每隔几个取一个,增加多样性 + // 更复杂的策略可以使用聚类算法 + + List diversified = new ArrayList<>(); + int step = Math.max(1, candidates.size() / targetSize); + + for (int i = 0; i < candidates.size() && diversified.size() < targetSize; i += step) { + diversified.add(candidates.get(i)); + } + + // 如果还不够,补充剩余的结果 + if (diversified.size() < targetSize) { + for (RecommendationCandidate candidate : candidates) { + if (diversified.size() >= targetSize) { + break; + } + if (!diversified.contains(candidate)) { + diversified.add(candidate); + } + } + } + + return diversified; + } + + /** + * 推荐候选结果内部类 + */ + private static class RecommendationCandidate { + private final String url; + private final String style; + private final Float similarityScore; // L2距离,越小越相似 + + public RecommendationCandidate(String url, String style, Float similarityScore) { + this.url = url; + this.style = style; + this.similarityScore = similarityScore; + } + + public String getUrl() { + return url; + } + + public String getStyle() { + return style; + } + + public Float getSimilarityScore() { + return similarityScore; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecommendationCandidate that = (RecommendationCandidate) o; + return Objects.equals(url, that.url); + } + + @Override + public int hashCode() { + return Objects.hash(url); + } + } + + private String resolveCategory(SysFile sysFile) { + String level3 = sysFile.getLevel3Type(); + String level2 = sysFile.getLevel2Type(); + if (StringUtils.isNotBlank(level3) && StringUtils.isNotBlank(level2)) { + return (level3 + "_" + level2).toLowerCase(); + } + String url = sysFile.getUrl(); + if (StringUtils.isNotBlank(url)) { + String[] parts = url.split("/"); + if (parts.length >= 4) { + String sex = parts[parts.length - 3]; + String category = parts[parts.length - 2]; + return (sex + "_" + category).toLowerCase(); + } + } + return ""; + } +} + diff --git a/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java b/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java new file mode 100644 index 00000000..57fd4897 --- /dev/null +++ b/src/main/java/com/ai/da/service/impl/ResNetFeatureTranslator.java @@ -0,0 +1,83 @@ +package com.ai.da.service.impl; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.transform.Normalize; +import ai.djl.modality.cv.transform.Resize; +import ai.djl.modality.cv.transform.ToTensor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.translate.Batchifier; +import ai.djl.translate.Pipeline; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import lombok.extern.slf4j.Slf4j; + +/** + * ResNet 特征提取 Translator + * 移除最后的全连接层,只保留特征提取部分(2048维) + */ +@Slf4j +public class ResNetFeatureTranslator implements Translator { + + private static final int IMAGE_SIZE = 224; + private static final Pipeline PIPELINE = new Pipeline() + .add(new Resize(IMAGE_SIZE, IMAGE_SIZE)) + .add(new ToTensor()) + .add(new Normalize( + new float[]{0.485f, 0.456f, 0.406f}, + new float[]{0.229f, 0.224f, 0.225f} + )); + + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + NDArray array = input.toNDArray(manager); + NDList ndList = new NDList(array); + return PIPELINE.transform(ndList); + } + + @Override + public float[] processOutput(TranslatorContext ctx, NDList list) { + NDArray output = list.singletonOrThrow(); + + // ResNet50 特征层输出形状应该是: + // - [1, 2048, 1, 1] - 全局平均池化后的特征(需要 flatten) + // - [1, 2048] - 已 flatten 的特征层输出 + + long[] shape = output.getShape().getShape(); + + // 如果是 4 维,说明是 [1, 2048, 1, 1],需要 flatten + if (shape.length == 4) { + output = output.flatten(); + } + // 如果是 2 维,检查维度是否正确 + else if (shape.length == 2) { + if (shape[1] == 1000) { + // 如果仍然是分类层输出(1000维),说明模型加载有问题 + log.error("模型输出仍然是分类层(1000维),模型加载可能失败"); + throw new IllegalStateException("模型输出维度错误:期望 2048 维,实际 1000 维。请检查模型是否正确移除了全连接层。"); + } + // 如果已经是 [1, 2048],直接使用 + } + + // 转换为 float 数组 + float[] result = output.toFloatArray(); + + // 验证维度必须是 2048 + if (result.length != 2048) { + log.error("特征向量维度错误: 期望 2048, 实际 {}", result.length); + throw new IllegalArgumentException( + String.format("特征向量维度错误: 期望 2048, 实际 %d", result.length) + ); + } + + return result; + } + + @Override + public Batchifier getBatchifier() { + return Batchifier.STACK; + } +} +