feat generate sketch 新增 服装类别识别
This commit is contained in:
@@ -23,7 +23,7 @@ from tritonclient.utils import np_to_triton_dtype
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateImageModel
|
from app.schemas.generate_image import GenerateImageModel
|
||||||
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
|
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
|
||||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
|
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -52,7 +52,7 @@ class GenerateImage:
|
|||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.category = request_data.category
|
self.category = request_data.category
|
||||||
self.index = 0
|
self.index = 0
|
||||||
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
|
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': {"image_url": "", "category": ""}}
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
self.redis_client.expire(self.tasks_id, 600)
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ class GenerateImage:
|
|||||||
if error:
|
if error:
|
||||||
self.generate_data['status'] = "FAILURE"
|
self.generate_data['status'] = "FAILURE"
|
||||||
self.generate_data['message'] = str(error)
|
self.generate_data['message'] = str(error)
|
||||||
self.generate_data['data'] = str(error)
|
# self.generate_data['data'] = str(error)
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
else:
|
else:
|
||||||
image_result = result.as_numpy("generated_image")[0]
|
image_result = result.as_numpy("generated_image")[0]
|
||||||
@@ -83,6 +83,9 @@ class GenerateImage:
|
|||||||
remove_bg_image = remove_background(np.asarray(image_result))
|
remove_bg_image = remove_background(np.asarray(image_result))
|
||||||
# 污点检测
|
# 污点检测
|
||||||
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
|
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
|
||||||
|
# 类型识别
|
||||||
|
category, scores, not_smudge_image = generate_category_recognition(image_result)
|
||||||
|
self.generate_data['data']['category'] = str(category)
|
||||||
image_result = not_smudge_image
|
image_result = not_smudge_image
|
||||||
if is_smudge: # 无污点
|
if is_smudge: # 无污点
|
||||||
image_result = adjust_contrast(image_result)
|
image_result = adjust_contrast(image_result)
|
||||||
@@ -90,12 +93,12 @@ class GenerateImage:
|
|||||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||||
self.generate_data['status'] = "SUCCESS"
|
self.generate_data['status'] = "SUCCESS"
|
||||||
self.generate_data['message'] = "success"
|
self.generate_data['message'] = "success"
|
||||||
self.generate_data['data'] = str(image_url)
|
self.generate_data['data']['image_url'] = str(image_url)
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
else: # 有污点
|
else: # 有污点
|
||||||
self.generate_data['status'] = "SUCCESS"
|
self.generate_data['status'] = "SUCCESS"
|
||||||
self.generate_data['message'] = "success"
|
self.generate_data['message'] = "success"
|
||||||
self.generate_data['data'] = str(GI_SYS_IMAGE_URL)
|
self.generate_data['data']['image_url'] = str(GI_SYS_IMAGE_URL)
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
# logger.info(f"stain_detection result : {self.generate_data}")
|
# logger.info(f"stain_detection result : {self.generate_data}")
|
||||||
|
|
||||||
@@ -146,14 +149,14 @@ class GenerateImage:
|
|||||||
return generate_data
|
return generate_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.generate_data['status'] = "FAILURE"
|
self.generate_data['status'] = "FAILURE"
|
||||||
self.generate_data['message'] = "failure"
|
self.generate_data['message'] = str(e)
|
||||||
self.generate_data['data'] = str(e)
|
|
||||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
raise Exception(str(e))
|
raise Exception(str(e))
|
||||||
finally:
|
finally:
|
||||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
if DEBUG is False:
|
if DEBUG is False:
|
||||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
|
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -160,3 +160,41 @@ def stain_detection(image, spot_size=200):
|
|||||||
cv2.rectangle(image, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
|
cv2.rectangle(image, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
|
||||||
|
|
||||||
return True, image
|
return True, image
|
||||||
|
|
||||||
|
|
||||||
|
def generate_category_recognition(image):
|
||||||
|
def preprocess(img):
|
||||||
|
img = mmcv.imread(img)
|
||||||
|
# ori_shape = img.shape[:2]
|
||||||
|
img_scale = (224, 224)
|
||||||
|
scale_factor = []
|
||||||
|
img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
|
||||||
|
scale_factor.append(x)
|
||||||
|
scale_factor.append(y)
|
||||||
|
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||||
|
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||||
|
return preprocessed_img
|
||||||
|
|
||||||
|
preprocessed_img = preprocess(image)
|
||||||
|
triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32")
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(preprocessed_img, binary_data=True)
|
||||||
|
results = triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||||
|
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||||
|
|
||||||
|
scores = inference_output.detach().numpy()
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
attr_type = pd.read_csv(CATEGORY_PATH)
|
||||||
|
colattr = list(attr_type['labelName'])
|
||||||
|
|
||||||
|
task = attr_type['taskName'][0]
|
||||||
|
|
||||||
|
maxsc = np.max(scores[0][:5])
|
||||||
|
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||||
|
category = colattr[indexs[0]]
|
||||||
|
return category, scores, image
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user