From 123fecbf72ee25a64762232ef58fe518cfd5b74f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 24 Apr 2024 13:25:17 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20sketch=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=20=E6=9C=8D=E8=A3=85=E7=B1=BB=E5=88=AB=E8=AF=86=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 1 + app/service/generate_image/service.py | 5 +++-- app/service/generate_image/utils/image_processing.py | 11 +++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 1e7617c..b8f5441 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -7,3 +7,4 @@ class GenerateImageModel(BaseModel): image_url: str mode: str category: str + gender: str diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index 04c3734..0292223 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -52,6 +52,7 @@ class GenerateImage: self.batch_size = 1 self.category = request_data.category self.index = 0 + self.gender = request_data.gender self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.redis_client.expire(self.tasks_id, 600) @@ -84,11 +85,11 @@ class GenerateImage: # 污点检测 is_smudge, not_smudge_image = stain_detection(remove_bg_image) # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image_result) + category, scores, not_smudge_image = generate_category_recognition(image=image_result, gender=self.gender) self.generate_data['category'] = str(category) image_result = not_smudge_image if is_smudge: # 无污点 - image_result = adjust_contrast(image_result) + # image_result = adjust_contrast(image_result) image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") # logger.info(f"upload image SUCCESS : {image_url}") self.generate_data['status'] = "SUCCESS" diff --git a/app/service/generate_image/utils/image_processing.py b/app/service/generate_image/utils/image_processing.py index 808d9b5..f76c26c 100644 --- a/app/service/generate_image/utils/image_processing.py +++ b/app/service/generate_image/utils/image_processing.py @@ -162,7 +162,7 @@ def stain_detection(image, spot_size=200): return True, image -def generate_category_recognition(image): +def generate_category_recognition(image, gender): def preprocess(img): img = mmcv.imread(img) # ori_shape = img.shape[:2] @@ -196,5 +196,12 @@ def generate_category_recognition(image): maxsc = np.max(scores[0][:5]) indexs = np.argwhere(scores == maxsc)[:, 1] category = colattr[indexs[0]] - return category, scores, image + if gender == "Male": + if category == 'Trousers' or category == 'Skirt': + category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + category = 'Tops' + else: + category = 'Outwear' + return category, scores, image