feat generate sketch 新增 服装类别识别
This commit is contained in:
@@ -7,3 +7,4 @@ class GenerateImageModel(BaseModel):
|
|||||||
image_url: str
|
image_url: str
|
||||||
mode: str
|
mode: str
|
||||||
category: str
|
category: str
|
||||||
|
gender: str
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class GenerateImage:
|
|||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.category = request_data.category
|
self.category = request_data.category
|
||||||
self.index = 0
|
self.index = 0
|
||||||
|
self.gender = request_data.gender
|
||||||
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
|
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
self.redis_client.expire(self.tasks_id, 600)
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
@@ -84,11 +85,11 @@ class GenerateImage:
|
|||||||
# 污点检测
|
# 污点检测
|
||||||
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
|
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
|
||||||
# 类型识别
|
# 类型识别
|
||||||
category, scores, not_smudge_image = generate_category_recognition(image_result)
|
category, scores, not_smudge_image = generate_category_recognition(image=image_result, gender=self.gender)
|
||||||
self.generate_data['category'] = str(category)
|
self.generate_data['category'] = str(category)
|
||||||
image_result = not_smudge_image
|
image_result = not_smudge_image
|
||||||
if is_smudge: # 无污点
|
if is_smudge: # 无污点
|
||||||
image_result = adjust_contrast(image_result)
|
# image_result = adjust_contrast(image_result)
|
||||||
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||||
self.generate_data['status'] = "SUCCESS"
|
self.generate_data['status'] = "SUCCESS"
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ def stain_detection(image, spot_size=200):
|
|||||||
return True, image
|
return True, image
|
||||||
|
|
||||||
|
|
||||||
def generate_category_recognition(image):
|
def generate_category_recognition(image, gender):
|
||||||
def preprocess(img):
|
def preprocess(img):
|
||||||
img = mmcv.imread(img)
|
img = mmcv.imread(img)
|
||||||
# ori_shape = img.shape[:2]
|
# ori_shape = img.shape[:2]
|
||||||
@@ -196,5 +196,12 @@ def generate_category_recognition(image):
|
|||||||
maxsc = np.max(scores[0][:5])
|
maxsc = np.max(scores[0][:5])
|
||||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||||
category = colattr[indexs[0]]
|
category = colattr[indexs[0]]
|
||||||
return category, scores, image
|
|
||||||
|
|
||||||
|
if gender == "Male":
|
||||||
|
if category == 'Trousers' or category == 'Skirt':
|
||||||
|
category = 'Bottoms'
|
||||||
|
elif category == 'Blouse' or category == 'Dress':
|
||||||
|
category = 'Tops'
|
||||||
|
else:
|
||||||
|
category = 'Outwear'
|
||||||
|
return category, scores, image
|
||||||
|
|||||||
Reference in New Issue
Block a user