diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index dacb92f..831c2ce 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -23,7 +23,7 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel from app.service.generate_image.utils.adjust_contrast import adjust_contrast -from app.service.generate_image.utils.image_processing import remove_background, stain_detection +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition from app.service.generate_image.utils.upload_sd_image import upload_png_sd logger = logging.getLogger() @@ -52,7 +52,7 @@ class GenerateImage: self.batch_size = 1 self.category = request_data.category self.index = 0 - self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''} + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': {"image_url": "", "category": ""}} self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.redis_client.expire(self.tasks_id, 600) @@ -73,7 +73,7 @@ class GenerateImage: if error: self.generate_data['status'] = "FAILURE" self.generate_data['message'] = str(error) - self.generate_data['data'] = str(error) + # self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) else: image_result = result.as_numpy("generated_image")[0] @@ -83,6 +83,9 @@ class GenerateImage: remove_bg_image = remove_background(np.asarray(image_result)) # 污点检测 is_smudge, not_smudge_image = stain_detection(remove_bg_image) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image_result) + self.generate_data['data']['category'] = str(category) image_result = not_smudge_image if is_smudge: # 无污点 image_result = 
def generate_category_recognition(image):
    """Classify the garment category of *image* via a Triton-served model.

    Preprocesses the image (resize to 224x224, ImageNet-style normalization,
    NCHW batch layout), runs the ``attr_retrieve_category`` model on the
    Triton server at ``ATT_TRITON_URL``, and maps the best score among the
    first five outputs to a label name read from the CSV at ``CATEGORY_PATH``.

    Args:
        image: Input image — anything ``mmcv.imread`` accepts (a file path or
            an ndarray).

    Returns:
        tuple: ``(category, scores, image)`` where ``category`` is the label
        string, ``scores`` is the raw score array from the model, and
        ``image`` is the untouched input, passed through for the caller.
    """
    # Local import kept deliberately: pandas is heavy and only needed here.
    import pandas as pd

    def _preprocess(img):
        """Resize, normalize, and reshape *img* into a (1, 3, 224, 224) batch."""
        img = mmcv.imread(img)
        # Scale factors from return_scale=True are not needed downstream.
        img, _, _ = mmcv.imresize(img, (224, 224), return_scale=True)
        img = mmcv.imnormalize(
            img,
            mean=np.array([123.675, 116.28, 103.53]),
            std=np.array([58.395, 57.12, 57.375]),
            to_rgb=True,
        )
        return np.expand_dims(img.transpose(2, 0, 1), axis=0)

    # Cast explicitly: InferInput is declared FP32 below.
    batch = _preprocess(image).astype(np.float32)

    triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
    infer_input = httpclient.InferInput("input__0", batch.shape, datatype="FP32")
    infer_input.set_data_from_numpy(batch, binary_data=True)
    results = triton_client.infer(
        model_name="attr_retrieve_category", inputs=[infer_input]
    )
    # BUGFIX: was f'output__0' — an f-string with no placeholders (F541).
    # Also dropped the pointless torch.from_numpy(...).detach().numpy()
    # round-trip; as_numpy already yields the ndarray we return.
    scores = results.as_numpy("output__0")

    labels = list(pd.read_csv(CATEGORY_PATH)["labelName"])

    # BUGFIX: the original took the max over scores[0][:5] but then searched
    # the WHOLE score array with np.argwhere, so a duplicate value outside
    # the first five classes could select the wrong label. Keep both steps
    # on the same window. NOTE(review): the five-class window is assumed
    # intentional — confirm against the model's output layout.
    best_index = int(np.argmax(scores[0][:5]))
    category = labels[best_index]
    return category, scores, image