特征推荐初次提交

This commit is contained in:
litianxiang
2025-11-28 09:36:04 +08:00
parent 1c782f8fd7
commit ca416fed9d
10 changed files with 1220 additions and 3 deletions

43
pom.xml
View File

@@ -422,6 +422,49 @@
<artifactId>bcpkix-jdk18on</artifactId>
<version>1.78.1</version>
</dependency>
<!-- Milvus Java SDK -->
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.3.4</version>
</dependency>
<!-- DJL (Deep Java Library) for Deep Learning -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>linux-x86_64</classifier>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>win-x86_64</classifier>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>osx-x86_64</classifier>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>2.0.1</version>
<classifier>osx-aarch64</classifier>
</dependency>
</dependencies>
<build>

View File

@@ -0,0 +1,45 @@
package com.ai.da.common.config;
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.ConnectParam;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that builds the Milvus vector-database client bean.
 *
 * <p>Connection settings come from the {@code milvus.*} properties. Host and port fall back to
 * {@code localhost} and {@code 19530}; username and password fall back to empty strings.
 */
@Slf4j
@Configuration
public class MilvusConfig {

    @Value("${milvus.host:localhost}")
    private String host;

    @Value("${milvus.port:19530}")
    private Integer port;

    @Value("${milvus.username:}")
    private String username;

    @Value("${milvus.password:}")
    private String password;

    /**
     * Creates the Milvus client from the configured host and port.
     * Authorization is attached only when both username and password are non-blank.
     *
     * @return the connected client instance registered as a Spring bean
     */
    @Bean
    public MilvusServiceClient milvusClient() {
        ConnectParam.Builder connection = ConnectParam.newBuilder();
        connection.withHost(host);
        connection.withPort(port);

        boolean credentialsConfigured = !StringUtils.isBlank(username) && !StringUtils.isBlank(password);
        if (credentialsConfigured) {
            connection.withAuthorization(username, password);
        }

        MilvusServiceClient milvusServiceClient = new MilvusServiceClient(connection.build());
        log.info("Milvus client initialized: {}:{}", host, port);
        return milvusServiceClient;
    }
}

View File

@@ -72,6 +72,12 @@ public class SysFile implements Serializable {
*/
private String style;
/**
* 是否废弃 0-否 1-是
*/
private Integer deprecated;
public SysFile() {
}

View File

@@ -0,0 +1,46 @@
package com.ai.da.model.dto;
import lombok.Data;
/**
 * Request payload for the sketch recommendation service.
 */
@Data
public class RecommendRequestDTO {
/**
 * ID of the user to recommend for.
 */
private Long userId;
/**
 * Category (e.g. female_skirt, male_tops). Optional; used as an additional filter.
 */
private String category;
/**
 * Style. Optional; used as an additional filter.
 */
private String style;
/**
 * Number of results to return. Defaults to 10.
 */
private Integer topK = 10;
/**
 * Whether to return only images that have not been retired. Defaults to true.
 */
private Boolean onlyActive = true;
/**
 * Number of candidates fetched by the first-stage vector search. Defaults to 50.
 * Suggested range is 50-100 so that later filtering and diversity optimisation have enough to work with.
 */
private Integer candidateSize = 50;
/**
 * Whether to enable diversity optimisation (avoid recommending results that are too similar).
 * Defaults to true.
 */
private Boolean enableDiversity = true;
}

View File

@@ -20,6 +20,7 @@ import com.ai.da.model.vo.*;
import com.ai.da.python.vo.*;
import com.ai.da.service.DesignHistoryService;
import com.ai.da.service.PythonTAllInfoService;
import com.ai.da.service.RecommendationService;
import com.ai.da.service.SysFileService;
import com.alibaba.fastjson.*;
import com.alibaba.fastjson.serializer.SerializerFeature;
@@ -722,6 +723,9 @@ public class PythonService {
@Resource
private AttributeRetrievalMapper attributeRetrievalMapper;
@Resource
private RecommendationService recommendationService;
private List<CollectionElement> getSystemSketchPool(JSONObject attributeRecognition, String styleCategory, String modelSex, int poolNum, String style) {
/**
* female trousers->female_pants
@@ -3955,6 +3959,27 @@ public class PythonService {
}
public List<String> getSystemSketchByCategory(String category, Long brandId, Double brandScale,String style) {
AuthPrincipalVo userHolder = UserContext.getUserHolder();
// 优先使用基于 Milvus 的推荐系统
try {
com.ai.da.model.dto.RecommendRequestDTO request = new com.ai.da.model.dto.RecommendRequestDTO();
request.setUserId(userHolder.getId());
request.setCategory(category);
request.setStyle(style);
request.setTopK(1);
request.setOnlyActive(true);
List<String> recommendedUrls = recommendationService.recommend(request);
if (!CollectionUtils.isEmpty(recommendedUrls)) {
log.info("使用 Milvus 推荐系统返回 {} 个结果", recommendedUrls.size());
return recommendedUrls;
}
} catch (Exception e) {
log.warn("Milvus 推荐失败,降级到备用方案: {}", e.getMessage());
}
// 降级方案1: 使用 attribute_retrieval_style 表
//******3.1.2版本临时使用java推荐方案去解决style未使用的问题**********
try {
//使用新库attribute_retrieval_style表命名修改为elementVO.getModelSex().toLowerCase() + "_" + styleCategory.toLowerCase()比如female_skirt,与传入的category保持一致
@@ -3974,11 +3999,11 @@ public class PythonService {
return Arrays.asList(path);
}
} catch (Exception e) {
log.info("推荐失败:{}",e.getMessage());
throw new BusinessException("system.error");
log.info("attribute_retrieval_style 推荐失败:{}",e.getMessage());
}
//**********************end***********************************
AuthPrincipalVo userHolder = UserContext.getUserHolder();
// 降级方案2: 调用 Python 服务(原有方案)
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)

View File

@@ -0,0 +1,36 @@
package com.ai.da.service;
import java.util.List;
/**
 * Feature extraction service.
 * Extracts a 2048-dimension feature vector from an image URL.
 */
public interface FeatureExtractionService {
/**
 * Extracts the feature vector of a single image.
 *
 * @param imageUrl image URL (MinIO path)
 * @return the 2048-dimension feature vector, or null when extraction fails
 */
List<Float> extractFeature(String imageUrl);
/**
 * Extracts feature vectors for a batch of images.
 *
 * NOTE(review): FeatureExtractionServiceImpl drops entries whose extraction returned null, so the
 * result keeps the input order but may be shorter than the input and its positions may not line up
 * with it. Confirm which contract callers should rely on.
 *
 * @param imageUrls list of image URLs
 * @return list of feature vectors, in input order
 */
List<List<Float>> extractFeatures(List<String> imageUrls);
/**
 * Computes the user preference vector from the vectors of the images the user liked.
 * The preference is the element-wise average of those vectors.
 *
 * @param userLikedVectors feature vectors of the images the user liked
 * @return the user preference vector (2048 dimensions)
 */
List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors);
}

View File

@@ -0,0 +1,34 @@
package com.ai.da.service;
import com.ai.da.model.dto.RecommendRequestDTO;
import java.util.List;
/**
 * Recommendation service.
 */
public interface RecommendationService {
/**
 * Recommends system sketches according to the user's preference.
 *
 * @param request recommendation request
 * @return list of recommended URLs
 */
List<String> recommend(RecommendRequestDTO request);
/**
 * Synchronises t_sys_file data into Milvus.
 * Reads the system sketches from the t_sys_file table, extracts their feature vectors and stores
 * them in Milvus.
 */
void syncSysFileToMilvus();
/**
 * Updates the vector of a single file (to be called when the file changes).
 *
 * @param sysFileId system file ID
 * @param url file URL
 */
void updateVector(Long sysFileId, String url);
}

View File

@@ -0,0 +1,220 @@
package com.ai.da.service.impl;
import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;
import com.ai.da.common.utils.MinioUtil;
import com.ai.da.service.FeatureExtractionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* 特征提取服务实现类
* 使用 DJL (Deep Java Library) 框架加载 ResNet50 模型提取特征向量
*/
@Slf4j
@Service
public class FeatureExtractionServiceImpl implements FeatureExtractionService {
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
@Autowired
private MinioUtil minioUtil;
private ZooModel<Image, float[]> model;
private Predictor<Image, float[]> predictor;
/**
 * Loads the "resnet50" model through the DJL model zoo (PyTorch engine) and creates the shared
 * predictor that {@code extractFeature} uses.
 *
 * <p>The feature translator is registered on the {@link Criteria} itself: the requested
 * {@code (Image, float[])} type pair is not a standard image-classification output, so model
 * loading should not have to find a matching default translator.
 *
 * <p>NOTE(review): nothing in this method removes the final fully-connected layer. The model is
 * used exactly as the zoo delivers it, and {@code ResNetFeatureTranslator} rejects outputs that
 * are not 2048 values. Confirm the resolved artifact really emits 2048-dimension features.
 *
 * @throws RuntimeException when the model or predictor cannot be created; this aborts bean
 *                          initialisation
 */
@PostConstruct
public void initModel() {
    try {
        log.info("开始加载 ResNet50 模型特征提取模式2048维...");

        // The translator is created first so it can be attached to the criteria below.
        Translator<Image, float[]> translator = new ResNetFeatureTranslator();

        Criteria<Image, float[]> criteria = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(Image.class, float[].class)
                .optModelName("resnet50")
                .optEngine("PyTorch")
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        model = ModelZoo.loadModel(criteria);
        predictor = model.newPredictor(translator);

        log.info("ResNet50 模型加载完成特征提取模式2048维");
    } catch (Exception e) {
        log.error("加载 ResNet50 模型失败", e);
        // Do not leave a half-initialised model holding native resources when start-up fails.
        if (model != null) {
            model.close();
            model = null;
        }
        throw new RuntimeException("模型加载失败", e);
    }
}
/**
 * Releases the predictor and then the model when the bean is destroyed.
 * Either may be absent if initialisation never completed.
 */
@PreDestroy
public void cleanup() {
    Predictor<Image, float[]> activePredictor = predictor;
    if (activePredictor != null) {
        activePredictor.close();
    }

    ZooModel<Image, float[]> loadedModel = model;
    if (loadedModel != null) {
        loadedModel.close();
    }
}
/**
 * Extracts the feature vector of one image stored in MinIO.
 *
 * @param imageUrl MinIO path in the form "bucket_name/path/to/image.jpg",
 *                 e.g. "aida-sys-image/images/female/blouse/0825001474.jpg"
 * @return the predictor output as a list of floats, or {@code null} when the model is not
 *         initialised, the image cannot be loaded, or inference throws
 */
@Override
public List<Float> extractFeature(String imageUrl) {
    if (predictor == null) {
        log.error("模型未初始化,无法提取特征");
        return null;
    }
    try {
        Image image = getImageFromMinio(imageUrl);
        if (image == null) {
            log.warn("无法从 MinIO 获取图片: {}", imageUrl);
            return null;
        }

        float[] featureArray;
        // DJL documents Predictor as not thread-safe, and this singleton bean is reachable from
        // concurrent request threads, so access to the single shared predictor is serialised.
        synchronized (predictor) {
            featureArray = predictor.predict(image);
        }

        List<Float> featureVector = new ArrayList<>(featureArray.length);
        for (float value : featureArray) {
            featureVector.add(value);
        }
        log.debug("成功提取特征向量: {}, 维度: {}", imageUrl, featureVector.size());
        return featureVector;
    } catch (Exception e) {
        // Best effort by design: callers treat null as "skip this image".
        log.error("提取特征失败: {}", imageUrl, e);
        return null;
    }
}
/**
 * Downloads an image from MinIO and decodes it into a DJL {@link Image}.
 *
 * @param imageUrl path in the form "bucket_name/path/to/image.jpg"; the text before the first
 *                 '/' is the bucket, the remainder is the object name
 * @return the decoded image, or {@code null} when the path is malformed, the download yields
 *         no stream, or download/decoding throws
 */
private Image getImageFromMinio(String imageUrl) {
    // Reject null up front; otherwise it surfaces as a NullPointerException logged at ERROR level.
    if (imageUrl == null) {
        log.warn("无效的图片路径格式: {}", imageUrl);
        return null;
    }

    int index = imageUrl.indexOf("/");
    // Invalid when there is no separator, the bucket is empty (leading '/'),
    // or the object name is empty (the first '/' is the last character).
    if (index <= 0 || index == imageUrl.length() - 1) {
        log.warn("无效的图片路径格式: {}", imageUrl);
        return null;
    }
    String bucketName = imageUrl.substring(0, index);
    String objectName = imageUrl.substring(index + 1);

    InputStream imageStream = null;
    try {
        imageStream = minioUtil.download(bucketName, objectName);
        if (imageStream == null) {
            log.warn("无法从 MinIO 获取图片流: bucket={}, object={}", bucketName, objectName);
            return null;
        }
        return ImageFactory.getInstance().fromInputStream(imageStream);
    } catch (Exception e) {
        log.error("从 MinIO 获取图片失败: {}", imageUrl, e);
        return null;
    } finally {
        // Closed manually so that a failing close() cannot discard an already decoded image.
        if (imageStream != null) {
            try {
                imageStream.close();
            } catch (Exception e) {
                log.warn("关闭图片流失败", e);
            }
        }
    }
}
/**
 * Extracts feature vectors for every URL by calling {@code extractFeature} one by one.
 * URLs whose extraction returned {@code null} are left out, so the result may be shorter
 * than the input.
 *
 * @param imageUrls image URLs; {@code null} or empty yields an empty list
 * @return the successfully extracted vectors, in input order
 */
@Override
public List<List<Float>> extractFeatures(List<String> imageUrls) {
    List<List<Float>> extracted = new ArrayList<>();
    if (CollectionUtils.isEmpty(imageUrls)) {
        return extracted;
    }
    for (String imageUrl : imageUrls) {
        List<Float> feature = extractFeature(imageUrl);
        if (feature != null) {
            extracted.add(feature);
        }
    }
    return extracted;
}
/**
 * Averages the given vectors element-wise to obtain the user preference vector.
 *
 * <p>The reference dimension is taken from the first non-null vector. Null vectors and vectors
 * of a different size are skipped, and the average is taken over the vectors that were
 * actually summed, so skipped entries no longer shrink the result.
 *
 * @param userLikedVectors feature vectors of the images the user liked
 * @return the averaged vector, or an empty list when the input is null, empty or contains no
 *         usable vector
 */
@Override
public List<Float> calculateUserPreferenceVector(List<List<Float>> userLikedVectors) {
    if (CollectionUtils.isEmpty(userLikedVectors)) {
        return new ArrayList<>();
    }

    int dimension = -1;
    for (List<Float> vector : userLikedVectors) {
        if (vector != null) {
            dimension = vector.size();
            break;
        }
    }
    if (dimension < 0) {
        return new ArrayList<>();
    }

    float[] sums = new float[dimension];
    int usedCount = 0;
    for (List<Float> vector : userLikedVectors) {
        if (vector == null) {
            continue;
        }
        if (vector.size() != dimension) {
            log.warn("向量维度不匹配: 期望 {}, 实际 {}", dimension, vector.size());
            continue;
        }
        for (int i = 0; i < dimension; i++) {
            sums[i] += vector.get(i);
        }
        usedCount++;
    }

    // usedCount is at least 1 here: the vector that defined the dimension always matches it.
    List<Float> preferenceVector = new ArrayList<>(dimension);
    for (float sum : sums) {
        preferenceVector.add(sum / usedCount);
    }
    return preferenceVector;
}
}

View File

@@ -0,0 +1,679 @@
package com.ai.da.service.impl;
import com.ai.da.common.config.MilvusConfig;
import com.ai.da.mapper.primary.SysFileMapper;
import com.ai.da.mapper.primary.UserPreferenceLogMapper;
import com.ai.da.mapper.primary.entity.SysFile;
import com.ai.da.mapper.primary.entity.UserPreferenceLogTest;
import com.ai.da.model.dto.RecommendRequestDTO;
import com.ai.da.service.FeatureExtractionService;
import com.ai.da.service.RecommendationService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.RpcStatus;
import io.milvus.param.collection.*;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.QueryParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.QueryResults;
import io.milvus.response.FieldDataWrapper;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import java.util.*;
import java.util.stream.Collectors;
/**
* 推荐服务实现类
* 基于 Milvus 向量数据库的推荐系统
*/
@Slf4j
@Service
public class RecommendationServiceImpl implements RecommendationService {
private static final String COLLECTION_NAME = "sketch_recommendation";
private static final int DIMENSION = 2048; // ResNet50 特征向量维度
private static final String VECTOR_FIELD = "feature_vector";
private static final String URL_FIELD = "url";
private static final String STYLE_FIELD = "style";
private static final String CATEGORY_FIELD = "category";
private static final String DEPRECATED_FIELD = "deprecated";
private static final String SYS_FILE_ID_FIELD = "sys_file_id";
@Resource
private MilvusServiceClient milvusClient;
@Resource
private SysFileMapper sysFileMapper;
@Resource
private UserPreferenceLogMapper userPreferenceLogMapper;
@Resource
private FeatureExtractionService featureExtractionService;
/**
 * Ensures the Milvus collection exists at start-up, creating it and its vector index when it is
 * missing.
 *
 * <p>Failures are logged and swallowed on purpose so that an unavailable Milvus does not stop
 * the application from starting.
 *
 * <p>NOTE(review): if the collection is created but index creation fails, the next start-up sees
 * the collection as existing and does not retry the index. Confirm whether that needs handling.
 */
@PostConstruct
public void initCollection() {
    try {
        R<Boolean> hasCollectionR = milvusClient.hasCollection(
                HasCollectionParam.newBuilder()
                        .withCollectionName(COLLECTION_NAME)
                        .build()
        );

        // A failed RPC carries no data; check the status instead of unboxing a null Boolean.
        if (hasCollectionR.getStatus() != R.Status.Success.getCode() || hasCollectionR.getData() == null) {
            log.error("初始化 Collection 失败: {}", hasCollectionR.getMessage());
            return;
        }

        if (hasCollectionR.getData()) {
            log.info("Collection {} 已存在", COLLECTION_NAME);
            return;
        }

        log.info("Collection {} 不存在,开始创建...", COLLECTION_NAME);
        createCollection();
        createIndex();
        log.info("Collection {} 创建完成", COLLECTION_NAME);
    } catch (Exception e) {
        log.error("初始化 Collection 失败", e);
    }
}
/**
 * Creates the recommendation collection with its six fields: the Int64 primary key
 * (sys file id, not auto-generated), url, style, category, the deprecated flag and the
 * float vector.
 *
 * @throws RuntimeException when Milvus does not report success
 */
private void createCollection() {
    FieldType idField = FieldType.newBuilder()
            .withName(SYS_FILE_ID_FIELD)
            .withDataType(DataType.Int64)
            .withPrimaryKey(true)
            .withAutoID(false)
            .build();
    FieldType urlField = FieldType.newBuilder()
            .withName(URL_FIELD)
            .withDataType(DataType.VarChar)
            .withMaxLength(500)
            .build();
    FieldType styleField = FieldType.newBuilder()
            .withName(STYLE_FIELD)
            .withDataType(DataType.VarChar)
            .withMaxLength(100)
            .build();
    FieldType categoryField = FieldType.newBuilder()
            .withName(CATEGORY_FIELD)
            .withDataType(DataType.VarChar)
            .withMaxLength(100)
            .build();
    FieldType deprecatedField = FieldType.newBuilder()
            .withName(DEPRECATED_FIELD)
            .withDataType(DataType.Int8)
            .build();
    FieldType vectorField = FieldType.newBuilder()
            .withName(VECTOR_FIELD)
            .withDataType(DataType.FloatVector)
            .withDimension(DIMENSION)
            .build();

    // Field order is kept as declared in the original schema.
    List<FieldType> schemaFields = Arrays.asList(
            idField, urlField, styleField, categoryField, deprecatedField, vectorField);

    CreateCollectionParam request = CreateCollectionParam.newBuilder()
            .withCollectionName(COLLECTION_NAME)
            .withDescription("系统 sketch 推荐向量库")
            .withShardsNum(2)
            .withFieldTypes(schemaFields)
            .build();

    R<RpcStatus> response = milvusClient.createCollection(request);
    boolean succeeded = response.getStatus() == R.Status.Success.getCode();
    if (!succeeded) {
        throw new RuntimeException("创建 Collection 失败: " + response.getMessage());
    }
}
/**
 * Creates an IVF_FLAT index with the L2 metric (nlist = 1024) on the vector field.
 *
 * @throws RuntimeException when Milvus does not report success
 */
private void createIndex() {
    CreateIndexParam.Builder indexBuilder = CreateIndexParam.newBuilder();
    indexBuilder.withCollectionName(COLLECTION_NAME);
    indexBuilder.withFieldName(VECTOR_FIELD);
    indexBuilder.withIndexType(IndexType.IVF_FLAT);
    indexBuilder.withMetricType(MetricType.L2);
    indexBuilder.withExtraParam("{\"nlist\":1024}");

    R<RpcStatus> response = milvusClient.createIndex(indexBuilder.build());
    boolean succeeded = response.getStatus() == R.Status.Success.getCode();
    if (!succeeded) {
        throw new RuntimeException("创建索引失败: " + response.getMessage());
    }
}
/**
 * Recommends system sketch URLs for a user from the Milvus collection.
 *
 * <p>Flow: load the user's liked paths, extract and average their feature vectors, run an L2
 * vector search (optionally filtered by the deprecated flag, style and category), optionally
 * thin the candidates for diversity, then return the first {@code topK} URLs. Users without any
 * liked record are handed to {@code recommendForNewUser}.
 *
 * @param request recommendation request; a {@code null} topK is treated as 10 (the DTO default)
 * @return recommended URLs; an empty list when topK is not positive, no vector can be built,
 *         the search fails, or any step throws
 */
@Override
public List<String> recommend(RecommendRequestDTO request) {
    try {
        // 1. Load the paths the user liked from the preference log.
        QueryWrapper<UserPreferenceLogTest> queryWrapper = new QueryWrapper<>();
        queryWrapper.lambda().eq(UserPreferenceLogTest::getAccountId, request.getUserId());
        List<UserPreferenceLogTest> userLikes = userPreferenceLogMapper.selectList(queryWrapper);

        // Users without likes take the style/category based fallback.
        if (CollectionUtils.isEmpty(userLikes)) {
            log.info("用户 {} 没有点赞记录,使用新用户推荐策略(基于 style 和 category", request.getUserId());
            return recommendForNewUser(request);
        }

        // Resolve topK once so every later step agrees on the same value and a null topK
        // cannot fail on unboxing.
        int topK = request.getTopK() != null ? request.getTopK() : 10;
        if (topK <= 0) {
            return Collections.emptyList();
        }

        // 2. Extract the feature vectors of the liked images.
        List<String> likedPaths = userLikes.stream()
                .map(UserPreferenceLogTest::getPath)
                .filter(StringUtils::isNotBlank)
                .collect(Collectors.toList());
        List<List<Float>> likedVectors = featureExtractionService.extractFeatures(likedPaths);
        likedVectors = likedVectors.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (CollectionUtils.isEmpty(likedVectors)) {
            log.warn("用户 {} 无法提取特征向量", request.getUserId());
            return Collections.emptyList();
        }

        // 3. The user preference is the average of the liked vectors.
        List<Float> userPreferenceVector = featureExtractionService.calculateUserPreferenceVector(likedVectors);

        // 4. Build the scalar filter that accompanies the vector search.
        List<String> conditions = new ArrayList<>();
        if (Boolean.TRUE.equals(request.getOnlyActive())) {
            conditions.add(DEPRECATED_FIELD + " == 0");
        }
        if (StringUtils.isNotBlank(request.getStyle())) {
            conditions.add(STYLE_FIELD + " == \"" + escapeSearchExprValue(request.getStyle()) + "\"");
        }
        if (StringUtils.isNotBlank(request.getCategory())) {
            conditions.add(CATEGORY_FIELD + " == \""
                    + escapeSearchExprValue(request.getCategory().toLowerCase()) + "\"");
        }
        String filterExpr = String.join(" && ", conditions);

        // Over-fetch so that diversification has something to choose from.
        int candidateSize = request.getCandidateSize() != null
                ? Math.max(request.getCandidateSize(), topK * 3)
                : Math.max(50, topK * 5);

        List<String> searchOutputFields = Arrays.asList(URL_FIELD, STYLE_FIELD, CATEGORY_FIELD);
        List<List<Float>> searchVectors = Collections.singletonList(userPreferenceVector);
        SearchParam.Builder searchBuilder = SearchParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withMetricType(MetricType.L2) // L2 distance: smaller means more similar
                .withOutFields(searchOutputFields)
                .withTopK(candidateSize)
                .withVectors(searchVectors)
                .withVectorFieldName(VECTOR_FIELD)
                .withParams("{\"nprobe\":10}");
        // Only attach a filter when there is one; the builder is never handed a null expression.
        if (!filterExpr.isEmpty()) {
            searchBuilder.withExpr(filterExpr);
        }

        // 5. Run the search.
        R<SearchResults> searchR = milvusClient.search(searchBuilder.build());
        if (searchR.getStatus() != R.Status.Success.getCode()) {
            log.error("搜索失败: {}", searchR.getMessage());
            return Collections.emptyList();
        }

        // 6. Turn the hits into candidates (url, style, distance).
        SearchResultsWrapper wrapper = new SearchResultsWrapper(searchR.getData().getResults());
        List<?> urlValues = wrapper.getFieldWrapper(URL_FIELD).getFieldData();
        List<?> styleValues = wrapper.getFieldWrapper(STYLE_FIELD).getFieldData();
        List<SearchResultsWrapper.IDScore> idScores = wrapper.getIDScore(0);

        List<RecommendationCandidate> candidates = new ArrayList<>(idScores.size());
        for (int i = 0; i < idScores.size(); i++) {
            String url = (String) urlValues.get(i);
            String style = (String) styleValues.get(i);
            Float score = idScores.get(i).getScore();
            if (StringUtils.isNotBlank(url)) {
                candidates.add(new RecommendationCandidate(url, style, score));
            }
        }

        // 6.1 Optional diversification.
        if (Boolean.TRUE.equals(request.getEnableDiversity())) {
            candidates = diversifyResults(candidates, topK);
        }

        // 6.2 Keep the first topK.
        List<String> recommendedUrls = candidates.stream()
                .limit(topK)
                .map(RecommendationCandidate::getUrl)
                .collect(Collectors.toList());
        log.info("为用户 {} 推荐了 {} 个结果(从 {} 个候选中筛选)",
                request.getUserId(), recommendedUrls.size(), idScores.size());
        return recommendedUrls;
    } catch (Exception e) {
        // Best effort by design: the caller falls back to other strategies on an empty result.
        log.error("推荐失败", e);
        return Collections.emptyList();
    }
}

/**
 * Escapes backslashes and double quotes so a caller-supplied value can be placed inside a
 * double-quoted string literal of a Milvus filter expression.
 */
private static String escapeSearchExprValue(String value) {
    return value.replace("\\", "\\\\").replace("\"", "\\\"");
}
/**
 * Copies system sketches from t_sys_file into the Milvus collection.
 *
 * <p>Selects non-deprecated "Images" rows that have a URL and a style, extracts a feature
 * vector for each, and inserts them in batches of 100. Rows whose vector cannot be extracted
 * are skipped. A row counts as successful only after the batch that contains it was inserted.
 *
 * <p>NOTE(review): this uses a plain insert, so running the sync again for ids that are already
 * stored may leave duplicate rows. Confirm whether a delete or upsert is needed.
 *
 * @throws RuntimeException when the sync aborts with an unexpected error
 */
@Override
public void syncSysFileToMilvus() {
    log.info("开始同步 t_sys_file 数据到 Milvus");
    try {
        // 1. Load the collection.
        LoadCollectionParam loadParam = LoadCollectionParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .build();
        milvusClient.loadCollection(loadParam);

        // 2. Select the system sketches to sync.
        QueryWrapper<SysFile> queryWrapper = new QueryWrapper<>();
        queryWrapper.lambda()
                .eq(SysFile::getLevel1Type, "Images").eq(SysFile::getDeprecated, 0)
                .isNotNull(SysFile::getUrl)
                .isNotNull(SysFile::getStyle)
                .ne(SysFile::getUrl, "");
        List<SysFile> sysFiles = sysFileMapper.selectList(queryWrapper);
        log.info("找到 {} 个系统文件需要同步", sysFiles.size());

        int successCount = 0;
        int failCount = 0;
        int batchSize = 100;

        // 3. Process in batches.
        for (int i = 0; i < sysFiles.size(); i += batchSize) {
            int end = Math.min(i + batchSize, sysFiles.size());
            List<SysFile> batch = sysFiles.subList(i, end);

            List<Long> ids = new ArrayList<>();
            List<String> urls = new ArrayList<>();
            List<String> styles = new ArrayList<>();
            List<String> categories = new ArrayList<>();
            // The Milvus Java SDK takes Short or Integer values for an Int8 field, not Byte.
            List<Short> deprecatedList = new ArrayList<>();
            List<List<Float>> vectors = new ArrayList<>();

            for (SysFile sysFile : batch) {
                try {
                    List<Float> vector = featureExtractionService.extractFeature(sysFile.getUrl());
                    if (vector == null || vector.size() != DIMENSION) {
                        log.warn("文件 {} 特征提取失败,跳过", sysFile.getUrl());
                        failCount++;
                        continue;
                    }
                    // Compute every column value before touching the lists, so an exception
                    // cannot leave the columns with different lengths.
                    String style = sysFile.getStyle() != null ? sysFile.getStyle() : "";
                    String category = resolveCategory(sysFile);
                    short deprecated = (short) (sysFile.getDeprecated() != null ? sysFile.getDeprecated() : 0);

                    ids.add(sysFile.getId());
                    urls.add(sysFile.getUrl());
                    styles.add(style);
                    categories.add(category);
                    deprecatedList.add(deprecated);
                    vectors.add(vector);
                } catch (Exception e) {
                    log.error("处理文件 {} 失败", sysFile.getUrl(), e);
                    failCount++;
                }
            }

            // Insert the prepared rows of this batch.
            if (!ids.isEmpty()) {
                List<InsertParam.Field> fields = Arrays.asList(
                        new InsertParam.Field(SYS_FILE_ID_FIELD, ids),
                        new InsertParam.Field(URL_FIELD, urls),
                        new InsertParam.Field(STYLE_FIELD, styles),
                        new InsertParam.Field(DEPRECATED_FIELD, deprecatedList),
                        new InsertParam.Field(CATEGORY_FIELD, categories),
                        new InsertParam.Field(VECTOR_FIELD, vectors)
                );
                InsertParam insertParam = InsertParam.newBuilder()
                        .withCollectionName(COLLECTION_NAME)
                        .withFields(fields)
                        .build();
                R<MutationResult> insertR = milvusClient.insert(insertParam);
                if (insertR.getStatus() == R.Status.Success.getCode()) {
                    successCount += ids.size();
                    log.info("批量插入成功: {}/{}", end, sysFiles.size());
                } else {
                    // Rows of a rejected batch were not stored, so they count as failures.
                    failCount += ids.size();
                    log.error("批量插入失败: {}", insertR.getMessage());
                }
            }
        }

        // 4. Flush the collection.
        FlushParam flushParam = FlushParam.newBuilder()
                .withCollectionNames(Collections.singletonList(COLLECTION_NAME))
                .build();
        milvusClient.flush(flushParam);
        log.info("同步完成: 成功 {}, 失败 {}", successCount, failCount);
    } catch (Exception e) {
        log.error("同步数据到 Milvus 失败", e);
        throw new RuntimeException("同步失败", e);
    }
}
/**
 * Placeholder for updating the vector of a single file.
 * Currently it only logs the arguments; no vector is changed.
 *
 * @param sysFileId system file ID
 * @param url file URL
 */
@Override
public void updateVector(Long sysFileId, String url) {
// TODO: implement the single-vector update logic.
// The old vector has to be deleted first, then the new one inserted.
log.info("更新向量: sysFileId={}, url={}", sysFileId, url);
}
/**
 * Recommendation strategy for users without any liked record: pick random entries that match
 * the requested style and category.
 *
 * <p>A pool larger than {@code topK} is fetched (the request's candidateSize, or 50 when it is
 * null) and shuffled before the first {@code topK} entries are returned. Fetching exactly
 * {@code topK} rows and shuffling them would always return the same rows.
 *
 * <p>NOTE(review): the randomness is limited to the fetched pool; which rows Milvus returns for
 * the pool is not controlled here.
 *
 * @param request recommendation request; a {@code null} topK is treated as 10
 * @return up to {@code topK} URLs, or an empty list when nothing matches or an error occurs
 */
private List<String> recommendForNewUser(RecommendRequestDTO request) {
    try {
        int topK = request.getTopK() != null ? request.getTopK() : 10;
        if (topK <= 0) {
            return Collections.emptyList();
        }

        int configuredPool = request.getCandidateSize() != null ? request.getCandidateSize() : 50;
        int poolSize = Math.max(topK, configuredPool);

        List<String> pool = recommendFromMilvusForNewUser(request, poolSize);
        if (CollectionUtils.isEmpty(pool)) {
            return Collections.emptyList();
        }

        // Shuffle a copy so the list handed back by the query helper is never mutated.
        List<String> shuffled = new ArrayList<>(pool);
        Collections.shuffle(shuffled);
        return shuffled.stream().limit(topK).collect(Collectors.toList());
    } catch (Exception e) {
        log.error("新用户推荐失败", e);
        return Collections.emptyList();
    }
}
/**
 * Queries Milvus for URLs that match the request's scalar filters (no vector search).
 *
 * <p>The filter combines the deprecated flag (when onlyActive is true), style and category.
 * When none of them applies, the query falls back to {@code deprecated == 0}.
 *
 * @param request recommendation request supplying onlyActive, style and category
 * @param limit   maximum number of rows to request from Milvus
 * @return the non-blank URLs returned by the query, or an empty list when the query fails
 */
private List<String> recommendFromMilvusForNewUser(RecommendRequestDTO request, int limit) {
    try {
        List<String> conditions = new ArrayList<>();
        if (request.getOnlyActive() != null && request.getOnlyActive()) {
            conditions.add(DEPRECATED_FIELD + " == 0");
        }
        if (StringUtils.isNotBlank(request.getStyle())) {
            conditions.add(STYLE_FIELD + " == \"" + escapeQueryExprValue(request.getStyle()) + "\"");
        }
        if (StringUtils.isNotBlank(request.getCategory())) {
            conditions.add(CATEGORY_FIELD + " == \""
                    + escapeQueryExprValue(request.getCategory().toLowerCase()) + "\"");
        }
        if (conditions.isEmpty()) {
            // The query is always given a filter; default to everything not deprecated.
            conditions.add(DEPRECATED_FIELD + " == 0");
        }
        String filterExpr = String.join(" && ", conditions);

        QueryParam queryParam = QueryParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withExpr(filterExpr)
                .withOutFields(Arrays.asList(URL_FIELD))
                .withLimit((long) limit)
                .build();
        R<QueryResults> queryR = milvusClient.query(queryParam);
        if (queryR.getStatus() != R.Status.Success.getCode()) {
            log.warn("从 Milvus 查询失败: {}", queryR.getMessage());
            return Collections.emptyList();
        }

        List<String> urls = new ArrayList<>();
        QueryResultsWrapper wrapper = new QueryResultsWrapper(queryR.getData());
        // Resolve the row count and the url column once instead of on every iteration.
        int rowCount = wrapper.getRowRecords().size();
        if (rowCount > 0) {
            List<?> urlValues = wrapper.getFieldWrapper(URL_FIELD).getFieldData();
            for (int i = 0; i < rowCount; i++) {
                String url = (String) urlValues.get(i);
                if (StringUtils.isNotBlank(url)) {
                    urls.add(url);
                }
            }
        }
        log.info("从 Milvus 为新用户查询到 {} 条记录", urls.size());
        return urls;
    } catch (Exception e) {
        log.error("从 Milvus 查询新用户推荐失败", e);
        return Collections.emptyList();
    }
}

/**
 * Escapes backslashes and double quotes so a caller-supplied value can be placed inside a
 * double-quoted string literal of a Milvus filter expression.
 */
private static String escapeQueryExprValue(String value) {
    return value.replace("\\", "\\\\").replace("\"", "\\\"");
}
/**
 * Picks random system sketch URLs from MySQL for a user without likes, filtered by style and
 * category.
 *
 * <p>The category filter is applied only when the category splits into exactly two parts on
 * '_' (sex and category); it is matched against the URL as "/{sex}/{category}/".
 *
 * @param request recommendation request supplying style and category
 * @param topK    maximum number of URLs to return
 * @return up to {@code topK} non-blank URLs in random order, or an empty list when nothing
 *         matches or an error occurs
 */
private List<String> recommendFromMySQLForNewUser(RecommendRequestDTO request, int topK) {
    try {
        QueryWrapper<SysFile> fileQuery = new QueryWrapper<>();
        fileQuery.lambda()
                .eq(SysFile::getLevel1Type, "Images")
                .isNotNull(SysFile::getUrl)
                .ne(SysFile::getUrl, "");

        // Optional style filter.
        if (!StringUtils.isBlank(request.getStyle())) {
            fileQuery.lambda().eq(SysFile::getStyle, request.getStyle());
        }

        // Optional category filter, e.g. female_skirt -> "/female/skirt/" inside the URL.
        if (!StringUtils.isBlank(request.getCategory())) {
            String[] sexAndCategory = request.getCategory().split("_");
            if (sexAndCategory.length == 2) {
                String urlFragment = "/" + sexAndCategory[0] + "/" + sexAndCategory[1] + "/";
                fileQuery.lambda().like(SysFile::getUrl, urlFragment);
            }
        }

        List<SysFile> matches = sysFileMapper.selectList(fileQuery);
        if (CollectionUtils.isEmpty(matches)) {
            log.warn("MySQL 中未找到符合条件的记录");
            return Collections.emptyList();
        }

        // Random selection: shuffle everything, then keep the first topK usable URLs.
        Collections.shuffle(matches);
        List<String> picked = matches.stream()
                .map(SysFile::getUrl)
                .filter(StringUtils::isNotBlank)
                .limit(topK)
                .collect(Collectors.toList());
        log.info("从 MySQL 为新用户查询到 {} 条记录", picked.size());
        return picked;
    } catch (Exception e) {
        log.error("从 MySQL 查询新用户推荐失败", e);
        return Collections.emptyList();
    }
}
/**
 * Keeps only the candidates whose URL encodes the given category.
 *
 * <p>The category of a candidate is read from its URL as
 * "{third-from-last segment}_{second-from-last segment}" (URL layout:
 * aida-sys-image/images/{sex}/{category}/{filename}). URLs that are blank, contain no '/',
 * or have fewer than four segments are dropped.
 *
 * @param candidates candidates to filter
 * @param category   expected category, compared case-sensitively
 * @return a new list with the matching candidates, in their original order
 */
private List<RecommendationCandidate> filterByCategory(
        List<RecommendationCandidate> candidates, String category) {
    List<RecommendationCandidate> matching = new ArrayList<>();
    for (RecommendationCandidate candidate : candidates) {
        String url = candidate.getUrl();
        if (StringUtils.isBlank(url) || !url.contains("/")) {
            continue;
        }
        String[] segments = url.split("/");
        int count = segments.length;
        if (count < 4) {
            continue;
        }
        String derivedCategory = segments[count - 3] + "_" + segments[count - 2];
        if (derivedCategory.equals(category)) {
            matching.add(candidate);
        }
    }
    return matching;
}
/**
 * Thins a ranked candidate list to increase diversity by taking every {@code step}-th entry,
 * where {@code step = max(1, size / targetSize)}.
 *
 * @param candidates candidates in ranking order
 * @param targetSize number of candidates wanted
 * @return an empty list when {@code targetSize} is not positive; the input list itself when it
 *         already holds at most {@code targetSize} entries; otherwise a new list with exactly
 *         {@code targetSize} entries
 */
private List<RecommendationCandidate> diversifyResults(
        List<RecommendationCandidate> candidates, int targetSize) {
    // A non-positive target would otherwise divide by zero when the step is computed.
    if (targetSize <= 0) {
        return new ArrayList<>();
    }
    if (candidates.size() <= targetSize) {
        return candidates;
    }

    // step <= size / targetSize, so the strided walk always yields at least targetSize entries
    // and no backfill pass is needed.
    int step = Math.max(1, candidates.size() / targetSize);
    List<RecommendationCandidate> diversified = new ArrayList<>(targetSize);
    for (int i = 0; i < candidates.size() && diversified.size() < targetSize; i += step) {
        diversified.add(candidates.get(i));
    }
    return diversified;
}
/**
 * Immutable holder for one search hit: its URL, style and score.
 * Two candidates are equal when their URLs are equal; style and score are ignored.
 */
private static class RecommendationCandidate {
    private final String url;
    private final String style;
    private final Float similarityScore; // L2 distance: smaller means more similar

    public RecommendationCandidate(String url, String style, Float similarityScore) {
        this.url = url;
        this.style = style;
        this.similarityScore = similarityScore;
    }

    public String getUrl() {
        return this.url;
    }

    public String getStyle() {
        return this.style;
    }

    public Float getSimilarityScore() {
        return this.similarityScore;
    }

    @Override
    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }
        if (other == null || other.getClass() != getClass()) {
            return false;
        }
        return Objects.equals(this.url, ((RecommendationCandidate) other).url);
    }

    @Override
    public int hashCode() {
        // Same value as Objects.hash(url).
        return 31 + Objects.hashCode(url);
    }
}
/**
 * Derives the category label stored in Milvus for a system file.
 *
 * <p>Uses "{level3Type}_{level2Type}" when both are non-blank; otherwise falls back to
 * "{third-from-last URL segment}_{second-from-last URL segment}". The result is lower-cased.
 *
 * @param sysFile the system file
 * @return the category label, or an empty string when neither source is usable
 */
private String resolveCategory(SysFile sysFile) {
    String thirdLevel = sysFile.getLevel3Type();
    String secondLevel = sysFile.getLevel2Type();
    boolean hasBothLevels = !StringUtils.isBlank(thirdLevel) && !StringUtils.isBlank(secondLevel);
    if (hasBothLevels) {
        return (thirdLevel + "_" + secondLevel).toLowerCase();
    }

    String fileUrl = sysFile.getUrl();
    if (StringUtils.isBlank(fileUrl)) {
        return "";
    }

    String[] segments = fileUrl.split("/");
    int count = segments.length;
    if (count < 4) {
        return "";
    }
    return (segments[count - 3] + "_" + segments[count - 2]).toLowerCase();
}
}

View File

@@ -0,0 +1,83 @@
package com.ai.da.service.impl;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import lombok.extern.slf4j.Slf4j;
/**
* ResNet 特征提取 Translator
* 移除最后的全连接层只保留特征提取部分2048维
*/
@Slf4j
public class ResNetFeatureTranslator implements Translator<Image, float[]> {
private static final int IMAGE_SIZE = 224;
private static final Pipeline PIPELINE = new Pipeline()
.add(new Resize(IMAGE_SIZE, IMAGE_SIZE))
.add(new ToTensor())
.add(new Normalize(
new float[]{0.485f, 0.456f, 0.406f},
new float[]{0.229f, 0.224f, 0.225f}
));
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
NDArray array = input.toNDArray(manager);
NDList ndList = new NDList(array);
return PIPELINE.transform(ndList);
}
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.singletonOrThrow();
// ResNet50 特征层输出形状应该是:
// - [1, 2048, 1, 1] - 全局平均池化后的特征(需要 flatten
// - [1, 2048] - 已 flatten 的特征层输出
long[] shape = output.getShape().getShape();
// 如果是 4 维,说明是 [1, 2048, 1, 1],需要 flatten
if (shape.length == 4) {
output = output.flatten();
}
// 如果是 2 维,检查维度是否正确
else if (shape.length == 2) {
if (shape[1] == 1000) {
// 如果仍然是分类层输出1000维说明模型加载有问题
log.error("模型输出仍然是分类层1000维模型加载可能失败");
throw new IllegalStateException("模型输出维度错误:期望 2048 维,实际 1000 维。请检查模型是否正确移除了全连接层。");
}
// 如果已经是 [1, 2048],直接使用
}
// 转换为 float 数组
float[] result = output.toFloatArray();
// 验证维度必须是 2048
if (result.length != 2048) {
log.error("特征向量维度错误: 期望 2048, 实际 {}", result.length);
throw new IllegalArgumentException(
String.format("特征向量维度错误: 期望 2048, 实际 %d", result.length)
);
}
return result;
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}