feat
fix 超分异常
This commit is contained in:
@@ -85,39 +85,6 @@ class Painting(object):
|
|||||||
pattern[0, 0, 2] = int(R)
|
pattern[0, 0, 2] = int(R)
|
||||||
return pattern
|
return pattern
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def gradient(image, angle_degrees, start_color, end_color):
|
|
||||||
height, width = image.shape[0], image.shape[1]
|
|
||||||
|
|
||||||
# 创建一个空白的图像
|
|
||||||
gradient_image = np.zeros((height, width, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
# 将角度限制在 0 到 360 度之间
|
|
||||||
angle_degrees = np.clip(angle_degrees, 0, 360)
|
|
||||||
|
|
||||||
# 将角度转换为弧度
|
|
||||||
angle_radians = np.radians(angle_degrees)
|
|
||||||
|
|
||||||
# 计算渐变的方向
|
|
||||||
dx = np.cos(angle_radians)
|
|
||||||
dy = np.sin(angle_radians)
|
|
||||||
|
|
||||||
# 创建网格
|
|
||||||
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
|
|
||||||
|
|
||||||
# 计算每个像素在渐变方向上的位置
|
|
||||||
distance_along_gradient = (x_grid * dx + y_grid * dy) / np.sqrt(dx ** 2 + dy ** 2)
|
|
||||||
|
|
||||||
# 计算渐变的权重
|
|
||||||
weight = np.clip(distance_along_gradient / max(width, height), 0, 1)
|
|
||||||
|
|
||||||
# 计算渐变的颜色
|
|
||||||
gradient_image[:, :, 0] = (1 - weight) * start_color[0] + weight * end_color[0]
|
|
||||||
gradient_image[:, :, 1] = (1 - weight) * start_color[1] + weight * end_color[1]
|
|
||||||
gradient_image[:, :, 2] = (1 - weight) * start_color[2] + weight * end_color[2]
|
|
||||||
|
|
||||||
return gradient_image
|
|
||||||
|
|
||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class PrintPainting(object):
|
class PrintPainting(object):
|
||||||
@@ -147,8 +114,8 @@ class PrintPainting(object):
|
|||||||
resized_source = image.resize(new_size)
|
resized_source = image.resize(new_size)
|
||||||
resized_source_mask = mask.resize(new_size)
|
resized_source_mask = mask.resize(new_size)
|
||||||
|
|
||||||
rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i])
|
rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i])
|
||||||
rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i])
|
rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i])
|
||||||
|
|
||||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||||
@@ -268,8 +235,8 @@ class PrintPainting(object):
|
|||||||
resized_source = image.resize(new_size)
|
resized_source = image.resize(new_size)
|
||||||
resized_source_mask = mask.resize(new_size)
|
resized_source_mask = mask.resize(new_size)
|
||||||
|
|
||||||
rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i])
|
rotated_resized_source = resized_source.rotate(-result['element']['element_angle_list'][i])
|
||||||
rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i])
|
rotated_resized_source_mask = resized_source_mask.rotate(-result['element']['element_angle_list'][i])
|
||||||
|
|
||||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||||
|
|||||||
@@ -58,14 +58,6 @@ class SuperResolution:
|
|||||||
logging.info(f"{self.tasks_id} ===> {status_data}")
|
logging.info(f"{self.tasks_id} ===> {status_data}")
|
||||||
return status_data
|
return status_data
|
||||||
|
|
||||||
# @RunTime
|
|
||||||
def infer(self, inputs):
|
|
||||||
return self.triton_client.async_infer(
|
|
||||||
model_name=SR_MODEL_NAME,
|
|
||||||
inputs=inputs,
|
|
||||||
callback=self.callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# @RunTime
|
# @RunTime
|
||||||
def sr_result(self):
|
def sr_result(self):
|
||||||
sample = self.read_image()
|
sample = self.read_image()
|
||||||
@@ -82,13 +74,16 @@ class SuperResolution:
|
|||||||
# , binary_data=True
|
# , binary_data=True
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = self.infer(inputs)
|
ctx = self.triton_client.async_infer(
|
||||||
|
model_name=SR_MODEL_NAME,
|
||||||
|
inputs=inputs,
|
||||||
|
callback=self.callback
|
||||||
|
)
|
||||||
time_out = 60
|
time_out = 60
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
generate_data = self.read_tasks_status()
|
generate_data = self.read_tasks_status()
|
||||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
ctx.cancel()
|
ctx.cancel()
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
|
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
|
||||||
logger.info(f" [x] Sent {generate_data}")
|
logger.info(f" [x] Sent {generate_data}")
|
||||||
break
|
break
|
||||||
@@ -98,16 +93,6 @@ class SuperResolution:
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return self.read_tasks_status()
|
return self.read_tasks_status()
|
||||||
|
|
||||||
# results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
|
|
||||||
|
|
||||||
# sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
|
||||||
# output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
# if output.ndim == 3:
|
|
||||||
# output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
|
||||||
# output = (output * 255.0).round().astype(np.uint8)
|
|
||||||
# output_url = self.upload_img_sr(output, generate_uuid())
|
|
||||||
# return output_url
|
|
||||||
|
|
||||||
def upload_img_sr(self, image):
|
def upload_img_sr(self, image):
|
||||||
try:
|
try:
|
||||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||||
@@ -121,7 +106,6 @@ class SuperResolution:
|
|||||||
|
|
||||||
def callback(self, result, error):
|
def callback(self, result, error):
|
||||||
if error:
|
if error:
|
||||||
print(error)
|
|
||||||
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
|
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
|
||||||
self.redis_client.set(self.tasks_id, sr_info_data)
|
self.redis_client.set(self.tasks_id, sr_info_data)
|
||||||
else:
|
else:
|
||||||
@@ -147,6 +131,6 @@ def infer_cancel(tasks_id):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="12341556")
|
request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556")
|
||||||
service = SuperResolution(request_data)
|
service = SuperResolution(request_data)
|
||||||
result_url = service.sr_result()
|
result_url = service.sr_result()
|
||||||
|
|||||||
Reference in New Issue
Block a user