From 13e3f8ac3d1bd67f4352e85dea5b30b741aeb203 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 8 Nov 2024 14:05:09 +0800 Subject: [PATCH 01/52] =?UTF-8?q?feat=20=20=20=20design=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=BF=81=E7=A7=BB4090=E6=B5=8B=E8=AF=95=20fi?= =?UTF-8?q?x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 6 +----- app/service/attribute/service_att_recognition.py | 2 +- app/service/attribute/service_category_recognition.py | 2 +- app/service/generate_image/utils/image_processing.py | 4 ++-- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 35c12b7..5909a3a 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -93,9 +93,6 @@ OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-4-32k-0613", } -# attribute service config -ATT_TRITON_URL = "10.1.1.240:10000" - # SR service config SR_MODEL_NAME = "super_resolution" SR_TRITON_URL = "10.1.1.240:10031" @@ -132,7 +129,6 @@ GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' # SEG service config -SEG_MODEL_URL = '10.1.1.240:10000' SEGMENTATION = { "new_model_name": "seg_knet", "name": "seg_ocrnet_hr18", @@ -141,7 +137,7 @@ SEGMENTATION = { } # DESIGN config -DESIGN_MODEL_URL = '10.1.1.240:10000' +DESIGN_MODEL_URL = '10.1.1.243:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') diff --git a/app/service/attribute/service_att_recognition.py b/app/service/attribute/service_att_recognition.py index 1251891..f93146e 100644 --- a/app/service/attribute/service_att_recognition.py +++ b/app/service/attribute/service_att_recognition.py @@ -28,7 +28,7 @@ class AttributeRecognition: } ) self.const = const - self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") + self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}") def get_result(self): for sketch in self.request_data: diff --git a/app/service/attribute/service_category_recognition.py b/app/service/attribute/service_category_recognition.py index f917af2..7c277c9 100644 --- a/app/service/attribute/service_category_recognition.py +++ b/app/service/attribute/service_category_recognition.py @@ -26,7 +26,7 @@ class CategoryRecognition: self.attr_type = pd.read_csv(CATEGORY_PATH) # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.request_data = [] - self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) + self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) for sketch in request_data: self.request_data.append( { diff --git a/app/service/generate_image/utils/image_processing.py b/app/service/generate_image/utils/image_processing.py index af36188..02d8bee 100644 --- a/app/service/generate_image/utils/image_processing.py +++ b/app/service/generate_image/utils/image_processing.py @@ -81,7 +81,7 @@ def get_contours(image): def seg_infer_image(image_obj): image, ori_shape = seg_preprocess(image_obj) - client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}") + client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}") transformed_img = image.astype(np.float32) # 输入集 inputs = [ @@ -250,7 +250,7 @@ def 
generate_category_recognition(image, gender): return preprocessed_img preprocessed_img = preprocess(image) - triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) + triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) inputs = [ httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32") From 6e621038f6faa2678500bf37cce68a92679c03c7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 19:52:33 +0800 Subject: [PATCH 02/52] =?UTF-8?q?feat=20=20flux=20=E5=8F=96=E6=B6=88?= =?UTF-8?q?=E6=B1=A1=E7=82=B9=E6=A3=80=E6=B5=8B=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E7=B1=BB=E5=88=AB=E5=88=A4=E6=96=AD=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_image.py | 61 ++++++++++++------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index dac211c..86912f8 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -35,7 +35,12 @@ class GenerateImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.version = request_data.version + if request_data.version == "fast": + self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL) + else: + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB @@ -87,23 +92,28 @@ class GenerateImage: image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) is_smudge = True if self.category == "sketch": - # 色阶调整 - cutoff = 1 - levels_img = autoLevels(image_result, cutoff) - # 亮度调整 - luminance = luminance_adjust(0.3, levels_img) - # 去背景 - remove_bg_image = remove_background(luminance) - # 人脸检测 - # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: - # is_smudge = False - # else: - # 污点/ - is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) - # 类型识别 - category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) - self.generate_data['category'] = str(category) - image_result = not_smudge_image + if self.version == "fast": + # 色阶调整 + cutoff = 1 + levels_img = autoLevels(image_result, cutoff) + # 亮度调整 + luminance = luminance_adjust(0.3, levels_img) + # 去背景 + remove_bg_image = remove_background(luminance) + # 人脸检测 + # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: + # is_smudge = False + # else: + # 污点/ + is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image + else: + category, scores, not_smudge_image = generate_category_recognition(image=image_result, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = 
not_smudge_image if is_smudge: # 无污点 # image_result = adjust_contrast(image_result) image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -134,15 +144,19 @@ class GenerateImage: image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) input_mode.set_data_from_numpy(mode_obj) inputs = [input_text, input_image, input_mode] - ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + if self.version == "fast": + ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback) + else: + ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) + time_out = 600 generate_data = None while time_out > 0: @@ -181,11 +195,12 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background', image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", - gender="male" + gender="male", + version="high" ) server = GenerateImage(rd) print(server.get_result()) From 926593826a32e37fb85d9c7699e0b5ffb4b35c15 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 18:51:53 +0800 Subject: [PATCH 03/52] =?UTF-8?q?feat=20=20product=20=E5=90=8E=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=9B=BE=E7=89=87size=E6=94=B9=E4=B8=BA320*700=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 69 ++++++++++--------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 5ea6f83..1c20a13 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -15,7 +15,7 @@ import cv2 import numpy as np import redis import tritonclient.grpc as grpcclient -from PIL import Image, ImageOps +from PIL import Image from tritonclient.utils import np_to_triton_dtype from app.core.config import * @@ -41,7 +41,7 @@ class GenerateProductImage: self.batch_size = 1 self.product_type = request_data.product_type self.prompt = request_data.prompt - self.image, self.image_size = pre_processing_image(request_data.image_url) + self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -55,12 +55,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, 
json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - if self.product_type == "single": - image = result.as_numpy("generated_cnet_image") - else: - image = result.as_numpy("generated_inpaint_image") + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + cropped_image = post_processing_image(image_result, self.left, self.top) + image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -74,16 +72,16 @@ class GenerateProductImage: try: prompts = [self.prompt] * self.batch_size self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - self.image = cv2.resize(self.image, (512, 768)) + self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) else: - text_obj = np.array(prompts, dtype="object").reshape(1) - image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -94,11 +92,12 @@ class GenerateProductImage: input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image, input_image_strength] input_image_strength.set_data_from_numpy(image_strength_obj) + inputs = [input_text, input_image, input_image_strength] + if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) else: ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) @@ -136,22 +135,13 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + # resize 原图至1024*1024 + image = image.resize((int(1024 / image.height * image.width), 1024)) + # 原始图片的尺寸 width, height = image.size - # 计算长宽比为 3:2 的新尺寸 - desired_ratio = 2 / 3 - current_ratio = width / height - - if current_ratio > desired_ratio: - # 原始图片更宽,需要在上下添加 padding - new_width = width - new_height = int(width / desired_ratio) - else: - # 原始图片更高或者长宽比已经为 3:2 - new_height = height - new_width = int(height * desired_ratio) - + new_height, new_width = 1024, 1024 # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) @@ -160,9 +150,9 @@ def pre_processing_image(image_url): top = (new_height - height) // 2 pad_image.paste(image, (left, top)) - # 将画布 resize 成宽度 500,长度 750 - resized_image = pad_image.resize((500, 750)) - 
image_size = (512, 768) + # 将画布 resize 成宽度 1024,长度 1024 + resized_image = pad_image.resize((1024, 1024)) + image_size = (1024, 1024) if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): # 创建白色背景 @@ -171,16 +161,29 @@ def pre_processing_image(image_url): background.paste(resized_image, mask=resized_image.split()[3]) image = np.array(background) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image, image_size + return image, image_size, left, top + + +def post_processing_image(image, left, top): + resized_image = image.resize((int(image.width * (768 / image.height)), 768)) + # 计算裁剪的坐标 + left = (resized_image.width - 512) // 2 + upper = 0 + right = left + 512 + lower = 768 + + # 进行裁剪 + cropped_image = resized_image.crop((left, upper, right, lower)) + return cropped_image if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", # prompt="", - image_strength=0.9, - prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", - image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", + image_strength=0.7, + prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", + image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", product_type="overall" ) server = GenerateProductImage(rd) From 58bf9341def24c4622d6acb2ab99de8785366339 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 17:54:42 +0800 Subject: [PATCH 04/52] =?UTF-8?q?feat=20=20design=20triton=20=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2=E4=B8=BAA6000=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index 5909a3a..0bbe79e 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -137,7 +137,7 @@ SEGMENTATION = { } # DESIGN config -DESIGN_MODEL_URL = '10.1.1.243:10000' +DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') From 135828b40607d5bb8ebdb81abe102da8f64aa88e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 14:24:48 +0800 Subject: [PATCH 05/52] =?UTF-8?q?feat=20generate=20img=20schemas=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9Eversion=E5=AD=97=E6=AE=B5=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 3dd7cf8..11e295f 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -8,6 +8,7 @@ class GenerateImageModel(BaseModel): mode: str category: str gender: str + version: str class GenerateSingleLogoImageModel(BaseModel): From 8f0e5919c5f199ad93e2e0b7d02491e8e52191c7 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 20:44:45 +0800 Subject: [PATCH 06/52] =?UTF-8?q?feat=20=20generate=20img=20api=20?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=BF=AE=E6=94=B9=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
app/api/api_generate_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 3dee667..b021158 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -34,7 +34,8 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", "mode": "img2img", "category": "sketch", - "gender": "male" + "gender": "male", + "version": "fast" } """ try: From 3a59010d947d8560ebfafb7c21f93ce7ef1cc244 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 17:54:42 +0800 Subject: [PATCH 07/52] feat fix --- app/core/config.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 0bbe79e..dd56258 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -100,8 +100,12 @@ SR_MINIO_BUCKET = "aida-users" SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}") # GenerateImage service config -GI_MODEL_NAME = 'stable_diffusion_xl' -GI_MODEL_URL = '10.1.1.240:10041' +FAST_GI_MODEL_URL = '10.1.1.243:10011' +FAST_GI_MODEL_NAME = 'stable_diffusion_xl' + +GI_MODEL_URL = '10.1.1.240:10061' +GI_MODEL_NAME = 'flux' + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" @@ -110,17 +114,15 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") # Generate Single Logo service config -GSL_MODEL_URL = '10.1.1.240:10041' +GSL_MODEL_URL = '10.1.1.243:10041' GSL_MINIO_BUCKET = "aida-users" GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' -GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' - -GPI_MODEL_URL = '10.1.1.240:10041' +GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all' +GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") @@ -135,9 +137,10 @@ SEGMENTATION = { "input": "seg_input__0", "output": "seg_output__0", } - +# ollama config +OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" # DESIGN config -DESIGN_MODEL_URL = '10.1.1.240:10000' +DESIGN_MODEL_URL = '10.1.1.243:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') From dbea3bc975241e46cf70739a2cb5402d3be4e81b Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Sun, 1 Dec 2024 20:51:44 +0800 Subject: [PATCH 08/52] feat fix --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index dd56258..d369ff2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -140,7 +140,7 @@ SEGMENTATION = { # ollama config OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" # DESIGN config -DESIGN_MODEL_URL = 
'10.1.1.243:10000' +DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right') From bf8b3b417bb874c6fb4a4facae53a5fb391a203e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 20:31:46 +0800 Subject: [PATCH 09/52] =?UTF-8?q?feat=20=20translator=20=E5=88=87=E6=8D=A2?= =?UTF-8?q?ollama=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_prompt_generation.py | 4 +- .../chatgpt_for_translation.py | 94 ++++++++++++++----- 2 files changed, 75 insertions(+), 23 deletions(-) diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 59e5779..11733e8 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, HTTPException from app.schemas.prompt_generation import PromptGenerationImageModel from app.schemas.response_template import ResponseModel -from app.service.prompt_generation.chatgpt_for_translation import translate_to_en +from app.service.prompt_generation.chatgpt_for_translation import translate_to_en, get_translation_from_llama3 router = APIRouter() logger = logging.getLogger() @@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): """ try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") - data = translate_to_en("[" + request_data.text + "]") + data = get_translation_from_llama3("[" + request_data.text + "]") logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 193bcfc..e541781 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,11 +1,16 @@ +import json import logging +import time +import requests from dashscope import Generation from requests import RequestException from retry import retry from app.core.config import QWEN_API_KEY +logger = logging.getLogger(__name__) + # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" @@ -15,26 +20,35 @@ from app.core.config import QWEN_API_KEY # openai_api_key=OPENAI_API_KEY, # temperature=0) +# prefix_for_llama = ( +# """ +# Translate everything within the brackets [] into English. +# Never translate or modify any English input. +# The input must be fully translated into coherent English sentences. +# Please only output the translated result.\n +# """ +# ) + def translate_to_en(text): - template = ( - """You are a translation expert, proficient in various languages. - And can translate various languages into English. - Please translate to grammatically correct English regardless of the input language. - If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", - output the input text exactly as it is without any modifications or additions. - If there are grammatical errors, correct them and then output the sentence.""" - ) - - prefix = ( - """ - Translate everything within the brackets [] into English. - Never translate or modify any English input. 
- The input must be fully translated into coherent English sentences. - Never present the translation results in the format - "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". - """ - ) + # template = ( + # """You are a translation expert, proficient in various languages. + # And can translate various languages into English. + # Please translate to grammatically correct English regardless of the input language. + # If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", + # output the input text exactly as it is without any modifications or additions. + # If there are grammatical errors, correct them and then output the sentence.""" + # ) + # + # prefix = ( + # """ + # Translate everything within the brackets [] into English. + # Never translate or modify any English input. + # The input must be fully translated into coherent English sentences. + # Never present the translation results in the format + # "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". + # """ + # ) messages = [ # { # Translate the entire text and ensure the output is a complete and coherent sentence in English. @@ -43,7 +57,7 @@ def translate_to_en(text): # }, { # "content": input('请输入:'), # 用户message - "content": prefix + text, # 用户message + "content": text, # 用户message "role": "user" } ] @@ -52,12 +66,18 @@ def translate_to_en(text): print("input : {}, translate result : {}".format(text, assistant_output.content)) return assistant_output.content + # llama3专用 + # data = get_translation_from_llama3(text) + # translation = data + # # print("Response from llama3 : " + translation) + # return translation + @retry(exceptions=RequestException, tries=3, delay=1) def get_response(messages): response = Generation.call( model='qwen-turbo', - api_key= QWEN_API_KEY, + api_key=QWEN_API_KEY, messages=messages, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 result_format='message', # 将输出设置为message形式 @@ -65,9 +85,41 @@ def get_response(messages): ) return response + +def get_translation_from_llama3(text): + start_time = time.time() + url = "http://10.1.1.240:11434/api/generate" + # url = "http://10.1.1.240:1143/api/generate" + + # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" + + # 创建请求的负载 + payload = { + "model": "translator", + "prompt": f"[{text}]", + "stream": False + } + + # 将负载转换为 JSON 格式 + headers = {'Content-Type': 'application/json'} + response = requests.post(url, data=json.dumps(payload), headers=headers) + # 处理响应 + if response.status_code == 200: + # print("Response from server:") + # print(response.json()) + resp = json.loads(response.content).get("response") + logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}") + print("input : {}, translate result : {}".format(text, resp)) + return resp + else: + logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}") + print(f"Request failed with status code {response.status_code}") + print(response.text) + + def main(): """Main function""" - text = translate_to_en("fire") + text = get_translation_from_llama3("[火焰]") print(text) From 9d11e995dd7230917b61f93d0d6ca8bac2929315 Mon Sep 17 00:00:00 2001 From: xupei Date: Tue, 29 Oct 2024 16:50:46 +0800 Subject: [PATCH 10/52] =?UTF-8?q?=E4=BB=8E=E5=90=91=E9=87=8F=E6=95=B0?= 
=?UTF-8?q?=E6=8D=AE=E5=BA=93=E4=B8=AD=E6=A3=80=E7=B4=A2=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=B9=B6=E9=9B=86=E6=88=90=E5=88=B0chat-robot?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_query_image.py | 36 ++++++++ app/api/api_route.py | 3 +- app/schemas/query_image.py | 6 ++ app/service/chat_robot/script/main.py | 2 +- app/service/chat_robot/script/prompt.py | 49 ++++++---- .../chat_robot/script/service/CallQWen.py | 57 ++++++++++-- app/service/search_image_with_text/service.py | 89 +++++++++++++++++++ 7 files changed, 217 insertions(+), 25 deletions(-) create mode 100644 app/api/api_query_image.py create mode 100644 app/schemas/query_image.py create mode 100644 app/service/search_image_with_text/service.py diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py new file mode 100644 index 0000000..d27c67b --- /dev/null +++ b/app/api/api_query_image.py @@ -0,0 +1,36 @@ +import json +import logging +from http.client import HTTPException + +from fastapi import APIRouter + +from app.schemas.query_image import QueryImageModel +from app.schemas.response_template import ResponseModel +from app.service.search_image_with_text.service import query + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/query_image") +def query_image(request_data: QueryImageModel): + """ + 对话机器人 + 创建一个具有以下参数的请求体: + - **gender**: 性别 + - **content**: 用户输入的内容 + + 示例参数: + { + "gender": "male", + "content": "give me a long sleeve blouse", + } + """ + try: + logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}") + data = query(request_data.gender, request_data.content) + logger.info(f"query_image response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"query_image Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_route.py b/app/api/api_route.py index 7ee774d..0da3a66 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api import api_attribute_retrieve +from app.api import api_attribute_retrieve, api_query_image from app.api import api_brighten from app.api import api_chat_robot from app.api import api_design @@ -23,3 +23,4 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") +router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") \ No newline at end of file diff --git a/app/schemas/query_image.py b/app/schemas/query_image.py new file mode 100644 index 0000000..147603f --- /dev/null +++ b/app/schemas/query_image.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class QueryImageModel(BaseModel): + gender: str + content: str diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py index cabe372..3342a5c 100644 --- a/app/service/chat_robot/script/main.py +++ b/app/service/chat_robot/script/main.py @@ -100,7 +100,7 @@ def chat(post_data): # session_key=f"buffer:{user_id}:{session_id}", # ) - final_outputs = CallQWen.call_with_messages(input_message) + final_outputs = CallQWen.call_with_messages(input_message, gender) # api_response = { # 'user_id': user_id, # 
'session_id': session_id, diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py index a88044d..ad6ac9e 100644 --- a/app/service/chat_robot/script/prompt.py +++ b/app/service/chat_robot/script/prompt.py @@ -1,16 +1,31 @@ +# FASHION_CHAT_BOT_PREFIX = """ +# You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. +# The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. +# Remember your answer should be very precise and the final output answer should not exceed 20 words. +# +# You may encounter the following types of questions: +# 1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. +# Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. +# Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. +# Never query for all the columns from a specific table, only ask for the relevant columns given the question. +# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. +# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +# If the question does not seem related to the database, just return "I don't know" as the answer. +# +# 2) If the query related to current events, you should use internet_search to seek help from the internet. +# +# 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant. +# +# Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential. +# """ + FASHION_CHAT_BOT_PREFIX = """ You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. Remember your answer should be very precise and the final output answer should not exceed 20 words. You may encounter the following types of questions: -1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. -Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. -Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. -If the question does not seem related to the database, just return "I don't know" as the answer. +1) If you need to query information related to clothing retrieval, please use the get_image_from_vector_db tool. 2) If the query related to current events, you should use internet_search to seek help from the internet. @@ -37,15 +52,19 @@ ANSWER_FORMAT_SUFFIX = """ My final answer are limited to 20 words and be as much precise as possible. """ +# TOOLS_FUNCTIONS_SUFFIX = ( +# "If the input involves clothing queries," +# "I should look at the tables in the database to see what I can query. 
Then I should query the schema of the most relevant tables." +# "All SQL statements must use 'ORDER BY RAND()', for example:" +# "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" +# "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" +# "If the input does not involve clothing queries, " +# "I should engage in conversation as an assistant or search from internet with internet_search tool." +# "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" +# "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " +# ) TOOLS_FUNCTIONS_SUFFIX = ( - "If the input involves clothing queries," - "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." - "All SQL statements must use 'ORDER BY RAND()', for example:" - "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" - "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" - "If the input does not involve clothing queries, " - "I should engage in conversation as an assistant or search from internet with internet_search tool." - "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" + "If the input involves clothing queries,please use the get_image_from_vector_db tool." "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " ) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index d2e2c06..33dcd04 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -8,6 +8,7 @@ from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN +from app.service.search_image_with_text.service import query get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." @@ -32,6 +33,12 @@ query_database_description = ( "order by rand() LIMIT 2'" ) +query_vector_db_description = ( + "Use this tool to find the clothing images that users need. " + "If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool. " + "The input for the tool is the string provided by the user." +) + tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." 
"Input is an empty string") @@ -105,15 +112,37 @@ tools = [ "function": { "name": "tutorial_tool", "description": tutorial_description, + # "parameters": { + # "type": "object", + # "properties": { + # "sql_string": { + # "type": "string", + # "description": "由模型生成的sql语句" + # } + # } + # }, + } + }, + { + "type": "function", + "function": { + "name": "get_image_from_vector_db", + "description": query_vector_db_description, "parameters": { - "type": "object", - "properties": { - "sql_string": { - "type": "string", - "description": "由模型生成的sql语句" + "parameters": { + "type": "object", + "properties": { + "gender": { + "type": "string", + "description": "性别" + }, + "content": { + "type": "string", + "description": "用户描述" + } } - } - }, + }, + } } } ] @@ -150,6 +179,10 @@ def query_database(sql_string): return CustomDatabase.run(db, sql_string) +def get_image_from_vector_db(gender, content): + return query(gender, content) + + @retry(exceptions=NewConnectionError, tries=3, delay=1) def get_response(messages): response = Generation.call( @@ -164,7 +197,8 @@ def get_response(messages): return response -def call_with_messages(message): +def call_with_messages(message, gender): + user_input = message print('\n') # messages = [ # { @@ -235,6 +269,12 @@ def call_with_messages(message): tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} flag = False result_content = tool_info['content'] + elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db': + tool_info = {"name": "get_image_from_vector_db", "role": "tool", + 'content': get_image_from_vector_db(gender, user_input)} + flag = False + result_content = tool_info['content'] + response_type = "image" print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) @@ -257,5 +297,6 @@ def call_with_messages(message): def tutorial_tool(): return TUTORIAL_TOOL_RETURN + if __name__ == '__main__': call_with_messages() diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py new file mode 100644 index 0000000..98f6ac4 --- /dev/null +++ b/app/service/search_image_with_text/service.py @@ -0,0 +1,89 @@ +import chromadb +import hashlib + +import pandas as pd +from chromadb.config import Settings +from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction +from tqdm import tqdm + +# 读取 csv 文件 +csv_file_path = r'D:/Files/csv/output/output.csv' +image_path = r'D:/images-clean' + +df = pd.read_csv(csv_file_path, encoding='Windows-1252') + +# 创建 Chroma 客户端 +client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) +# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) +# 创建集合 +embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") + + +def create_collection(): + collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) + + # 存储数据,包括自定义属性 + images_description = [] + images_metadata = [] + ids = [] + batch_size = 41666 # 最大批量大小 + for index, row in tqdm(df.iterrows()): + # 将图片的md5作为id + with open(image_path + row['path'], 'rb') as f: + image_data = f.read() + md5_value = hashlib.md5(image_data).hexdigest() + ids.append(md5_value) + images_description.append(row['description']) + images_metadata.append({ + "gender": 
row['gender'], + "path": row['path'] + }) + + # 将数据添加到集合 + # 每达到 batch_size 就执行一次 upsert + if len(ids) >= batch_size: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + # 清空列表以准备下一批数据 + ids.clear() + images_description.clear() + images_metadata.clear() + + if ids: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + + print("Data successfully stored in the vector database.") + + +def query(gender, content): + collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) + # 6. 查询相似内容 + user_gender = gender # 用户输入的性别 + user_content = content # 用户输入的内容 + + results = collection.query( + query_texts=user_content, + n_results=5, # 返回前 5 个结果 + where={"gender": user_gender} # 根据性别过滤 + ) + + # 输出结果 + resp = [] + for document, result in zip(results['documents'][0], results['metadatas'][0]): + # print("Path:", result['path']) + # print("Content:", document) + resp.append(result['path']) + return resp + + +if __name__ == '__main__': + # create_collection() + query("female", "I need a long sleeve dress") From 6b8b24de896e236617bda6f989fb5ee85458f906 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 22:42:03 +0800 Subject: [PATCH 11/52] =?UTF-8?q?feat=20=20OLLAMA=5FURL=20=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E5=88=B0A6000=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/core/config.py b/app/core/config.py index d369ff2..7629429 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -138,7 +138,7 @@ SEGMENTATION = { "output": "seg_output__0", } # ollama config -OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" +OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings" # DESIGN config DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" From 5812c3eaaba61e8c1f4514a968c9c294b3b1867e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 17:44:10 +0800 Subject: [PATCH 12/52] feat fix 1 --- app/service/search_image_with_text/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 98f6ac4..5ac9cef 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,7 +17,8 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") def create_collection(): From 6a62d0844694e4cee921fccc5526e4cd77a8ca85 Mon Sep 17 00:00:00 2001 From: xupei Date: Tue, 29 Oct 2024 16:50:46 +0800 Subject: [PATCH 13/52] =?UTF-8?q?=E4=BB=8E=E5=90=91=E9=87=8F=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E4=B8=AD=E6=A3=80=E7=B4=A2=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=B9=B6=E9=9B=86=E6=88=90=E5=88=B0chat-robot?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/service/search_image_with_text/service.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py
index 5ac9cef..98f6ac4 100644
--- a/app/service/search_image_with_text/service.py
+++ b/app/service/search_image_with_text/service.py
@@ -17,8 +17,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector
 # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
 # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
 # 创建集合
-# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
-embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large")
+embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")

From 31d7f55402d789a884bc7d0094fc31631869f6a2 Mon Sep 17 00:00:00 2001
From: zhouchengrong
Date: Tue, 29 Oct 2024 16:58:37 +0800
Subject: [PATCH 14/52] =?UTF-8?q?feat=20=20=20dockerfile=20=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9=20fix?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 requirements.txt | Bin 1828 -> 1860 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 6c9e38f1ded86de71e2126d5c357903ea0d08a05..73507145f0a1adf6986bae737597a22a911f640e 100644
GIT binary patch
delta 44
ycmZ3&cZ6@lELQnsh75)xhJ1!xhD3%Gh9rhM23rOL20aE-AU0$$-8_@En-Ku{m

From 1ba67d0bf72c400a405437ee67591ab5750a4d9a Mon Sep 17 00:00:00 2001
From: zhouchengrong
Date: Tue, 29 Oct 2024 17:17:30 +0800
Subject: [PATCH 15/52] =?UTF-8?q?feat=20=20=20dockerfile=20=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9=20fix?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/service/search_image_with_text/service.py | 88 +++++++++----------
 1 file changed, 44 insertions(+), 44 deletions(-)

diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py
index 98f6ac4..47a9dde 100644
--- a/app/service/search_image_with_text/service.py
+++ b/app/service/search_image_with_text/service.py
@@ -7,10 +7,10 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE
 from tqdm import tqdm

 # 读取 csv 文件
-csv_file_path = r'D:/Files/csv/output/output.csv'
-image_path = r'D:/images-clean'
+# csv_file_path = r'D:/Files/csv/output/output.csv'
+# image_path = r'D:/images-clean'

-df = pd.read_csv(csv_file_path, encoding='Windows-1252')
+# df = pd.read_csv(csv_file_path, encoding='Windows-1252')

 # 创建 Chroma 客户端
 client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
@@ -20,47 +20,47 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector
 embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")


-def create_collection():
-    collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn)
-
-    # 存储数据,包括自定义属性
-    images_description = []
-    images_metadata = []
-    ids = []
-    batch_size = 41666  # 最大批量大小
-    for index, row in tqdm(df.iterrows()):
-        # 将图片的md5作为id
-        with open(image_path + row['path'], 'rb') as f:
-            image_data = f.read()
-            md5_value = hashlib.md5(image_data).hexdigest()
-            
ids.append(md5_value) - images_description.append(row['description']) - images_metadata.append({ - "gender": row['gender'], - "path": row['path'] - }) - - # 将数据添加到集合 - # 每达到 batch_size 就执行一次 upsert - if len(ids) >= batch_size: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - # 清空列表以准备下一批数据 - ids.clear() - images_description.clear() - images_metadata.clear() - - if ids: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - - print("Data successfully stored in the vector database.") +# def create_collection(): +# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) +# +# # 存储数据,包括自定义属性 +# images_description = [] +# images_metadata = [] +# ids = [] +# batch_size = 41666 # 最大批量大小 +# for index, row in tqdm(df.iterrows()): +# # 将图片的md5作为id +# with open(image_path + row['path'], 'rb') as f: +# image_data = f.read() +# md5_value = hashlib.md5(image_data).hexdigest() +# ids.append(md5_value) +# images_description.append(row['description']) +# images_metadata.append({ +# "gender": row['gender'], +# "path": row['path'] +# }) +# +# # 将数据添加到集合 +# # 每达到 batch_size 就执行一次 upsert +# if len(ids) >= batch_size: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# # 清空列表以准备下一批数据 +# ids.clear() +# images_description.clear() +# images_metadata.clear() +# +# if ids: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# +# print("Data successfully stored in the vector database.") def query(gender, content): From 55cd1b27bdb6d1e39b46a882243eb5c65cc0c299 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 17:27:05 +0800 Subject: [PATCH 16/52] feat fix 1 --- app/api/api_query_image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py index d27c67b..ca0dbe6 100644 --- a/app/api/api_query_image.py +++ b/app/api/api_query_image.py @@ -1,8 +1,7 @@ import json import logging -from http.client import HTTPException -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.schemas.query_image import QueryImageModel from app.schemas.response_template import ResponseModel From be7e12103395166b80eb5f56977b3283e9673846 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 29 Oct 2024 17:44:10 +0800 Subject: [PATCH 17/52] feat fix 1 --- app/service/search_image_with_text/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 47a9dde..36a86a8 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,7 +17,8 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = 
OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") # def create_collection(): From aa0db5006d7dab9be273f3db6943d1fc39846053 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 8 Nov 2024 14:35:23 +0800 Subject: [PATCH 18/52] =?UTF-8?q?feat=20=20OLLAMA=5FURL=20=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E5=88=B0A6000=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/search_image_with_text/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 36a86a8..edd4d93 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -6,6 +6,8 @@ from chromadb.config import Settings from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from tqdm import tqdm +from app.core.config import OLLAMA_URL + # 读取 csv 文件 # csv_file_path = r'D:/Files/csv/output/output.csv' # image_path = r'D:/images-clean' @@ -18,7 +20,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 # embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-large") # def create_collection(): @@ -67,7 +69,7 @@ embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddin def query(gender, content): collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) # 6. 查询相似内容 - user_gender = gender # 用户输入的性别 + user_gender = gender.lower() # 用户输入的性别 user_content = content # 用户输入的内容 results = collection.query( From 1a12ed2201eaeebed4d8fdf8190bb8deee7ec7c9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:17:59 +0800 Subject: [PATCH 19/52] =?UTF-8?q?feat=20=20rabbitt=20env=20=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E5=88=B0=E7=94=9F=E4=BA=A7=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 7629429..2575c97 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -30,8 +30,8 @@ else: CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" SEG_CACHE_PATH = "/seg_cache/" -# RABBITMQ_ENV = "" # 生产环境 -RABBITMQ_ENV = "-dev" # 开发环境 +RABBITMQ_ENV = "" # 生产环境 +# RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-local" # 本地测试环境 settings = Settings() From f6c166b24d02747f2a19f092d9a6869affe898b8 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:22 +0800 Subject: [PATCH 20/52] =?UTF-8?q?Revert=20"feat=20=20rabbitt=20env=20?= =?UTF-8?q?=E5=88=87=E6=8D=A2=E5=88=B0=E7=94=9F=E4=BA=A7"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 1a12ed2201eaeebed4d8fdf8190bb8deee7ec7c9. 
--- app/core/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 2575c97..7629429 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -30,8 +30,8 @@ else: CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv" SEG_CACHE_PATH = "/seg_cache/" -RABBITMQ_ENV = "" # 生产环境 -# RABBITMQ_ENV = "-dev" # 开发环境 +# RABBITMQ_ENV = "" # 生产环境 +RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-local" # 本地测试环境 settings = Settings() From 7b7dad636c39feb137e687332ce16c2d4964a919 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:30 +0800 Subject: [PATCH 21/52] =?UTF-8?q?Revert=20"feat=20=20OLLAMA=5FURL=20?= =?UTF-8?q?=E5=88=87=E6=8D=A2=E5=88=B0A6000"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit aa0db5006d7dab9be273f3db6943d1fc39846053. --- app/service/search_image_with_text/service.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index edd4d93..36a86a8 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -6,8 +6,6 @@ from chromadb.config import Settings from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from tqdm import tqdm -from app.core.config import OLLAMA_URL - # 读取 csv 文件 # csv_file_path = r'D:/Files/csv/output/output.csv' # image_path = r'D:/images-clean' @@ -20,7 +18,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 # embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") # def create_collection(): @@ -69,7 +67,7 @@ embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-l def query(gender, content): collection = client.get_collection("sub_sketches_description", embedding_function=embedding_fn) # 6. 查询相似内容 - user_gender = gender.lower() # 用户输入的性别 + user_gender = gender # 用户输入的性别 user_content = content # 用户输入的内容 results = collection.query( From bb0ac9046b8bf228114b0fc94d91ef5a43c7a456 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:45 +0800 Subject: [PATCH 22/52] Revert "feat" This reverts commit be7e12103395166b80eb5f56977b3283e9673846. 
--- app/service/search_image_with_text/service.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 36a86a8..47a9dde 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,8 +17,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") # def create_collection(): From 85145ba2c9a54cc4f8899aaedfb853f5978d2249 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:46 +0800 Subject: [PATCH 23/52] Revert "feat" This reverts commit 55cd1b27bdb6d1e39b46a882243eb5c65cc0c299. --- app/api/api_query_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/api/api_query_image.py b/app/api/api_query_image.py index ca0dbe6..d27c67b 100644 --- a/app/api/api_query_image.py +++ b/app/api/api_query_image.py @@ -1,7 +1,8 @@ import json import logging +from http.client import HTTPException -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter from app.schemas.query_image import QueryImageModel from app.schemas.response_template import ResponseModel From e134453976d519b9cc443865e8b1cdc9a3200816 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:48 +0800 Subject: [PATCH 24/52] =?UTF-8?q?Revert=20"feat=20=20=20dockerfile=20?= =?UTF-8?q?=E4=BF=AE=E6=94=B9"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 1ba67d0bf72c400a405437ee67591ab5750a4d9a. 
--- app/service/search_image_with_text/service.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 47a9dde..98f6ac4 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -7,10 +7,10 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE from tqdm import tqdm # 读取 csv 文件 -# csv_file_path = r'D:/Files/csv/output/output.csv' -# image_path = r'D:/images-clean' +csv_file_path = r'D:/Files/csv/output/output.csv' +image_path = r'D:/images-clean' -# df = pd.read_csv(csv_file_path, encoding='Windows-1252') +df = pd.read_csv(csv_file_path, encoding='Windows-1252') # 创建 Chroma 客户端 client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) @@ -20,47 +20,47 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -# def create_collection(): -# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) -# -# # 存储数据,包括自定义属性 -# images_description = [] -# images_metadata = [] -# ids = [] -# batch_size = 41666 # 最大批量大小 -# for index, row in tqdm(df.iterrows()): -# # 将图片的md5作为id -# with open(image_path + row['path'], 'rb') as f: -# image_data = f.read() -# md5_value = hashlib.md5(image_data).hexdigest() -# ids.append(md5_value) -# images_description.append(row['description']) -# images_metadata.append({ -# "gender": row['gender'], -# "path": row['path'] -# }) -# -# # 将数据添加到集合 -# # 每达到 batch_size 就执行一次 upsert -# if len(ids) >= batch_size: -# collection.upsert( -# ids=list(ids), -# documents=images_description, -# metadatas=images_metadata # 添加自定义属性 -# ) -# # 清空列表以准备下一批数据 -# ids.clear() -# images_description.clear() -# images_metadata.clear() -# -# if ids: -# collection.upsert( -# ids=list(ids), -# documents=images_description, -# metadatas=images_metadata # 添加自定义属性 -# ) -# -# print("Data successfully stored in the vector database.") +def create_collection(): + collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) + + # 存储数据,包括自定义属性 + images_description = [] + images_metadata = [] + ids = [] + batch_size = 41666 # 最大批量大小 + for index, row in tqdm(df.iterrows()): + # 将图片的md5作为id + with open(image_path + row['path'], 'rb') as f: + image_data = f.read() + md5_value = hashlib.md5(image_data).hexdigest() + ids.append(md5_value) + images_description.append(row['description']) + images_metadata.append({ + "gender": row['gender'], + "path": row['path'] + }) + + # 将数据添加到集合 + # 每达到 batch_size 就执行一次 upsert + if len(ids) >= batch_size: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + # 清空列表以准备下一批数据 + ids.clear() + images_description.clear() + images_metadata.clear() + + if ids: + collection.upsert( + ids=list(ids), + documents=images_description, + metadatas=images_metadata # 添加自定义属性 + ) + + print("Data successfully stored in the vector database.") def query(gender, content): From e097e485b82771f6948b8e36abb072c8b81b2b09 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:49 +0800 Subject: [PATCH 25/52] =?UTF-8?q?Revert=20"feat=20=20=20dockerfile=20?= =?UTF-8?q?=E4=BF=AE=E6=94=B9"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit This reverts commit 31d7f55402d789a884bc7d0094fc31631869f6a2. --- requirements.txt | Bin 1860 -> 1828 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 73507145f0a1adf6986bae737597a22a911f640e..6c9e38f1ded86de71e2126d5c357903ea0d08a05 100644 GIT binary patch delta 12 TcmX@Yw}fxQEY{7tSi2YjA{zwJ delta 44 ycmZ3&cZ6@lELQnsh75)xhJ1!xhD3%Gh9rhM23rOL20aE-AU0$$-8_@En-Ku{m Date: Mon, 2 Dec 2024 23:21:51 +0800 Subject: [PATCH 26/52] =?UTF-8?q?Revert=20"=E4=BB=8E=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E4=B8=AD=E6=A3=80=E7=B4=A2=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E5=B9=B6=E9=9B=86=E6=88=90=E5=88=B0chat-robot"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6a62d0844694e4cee921fccc5526e4cd77a8ca85. --- app/service/search_image_with_text/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 98f6ac4..5ac9cef 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,7 +17,8 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") def create_collection(): From 68c95eec0c0480415982f77e59310ee578354bd2 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:21:53 +0800 Subject: [PATCH 27/52] Revert "feat" This reverts commit 5812c3eaaba61e8c1f4514a968c9c294b3b1867e. 
--- app/service/search_image_with_text/service.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 5ac9cef..98f6ac4 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -17,8 +17,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") -embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") def create_collection(): From 2102b712300f89bb4b2820ae415b08da0aeecc20 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:26:19 +0800 Subject: [PATCH 28/52] =?UTF-8?q?feat=20=20=E4=BF=AE=E5=A4=8Dchatroboot=20?= =?UTF-8?q?fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/search_image_with_text/service.py | 90 +++++++++---------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/app/service/search_image_with_text/service.py b/app/service/search_image_with_text/service.py index 98f6ac4..712050f 100644 --- a/app/service/search_image_with_text/service.py +++ b/app/service/search_image_with_text/service.py @@ -7,60 +7,60 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaE from tqdm import tqdm # 读取 csv 文件 -csv_file_path = r'D:/Files/csv/output/output.csv' -image_path = r'D:/images-clean' +# csv_file_path = r'D:/Files/csv/output/output.csv' +# image_path = r'D:/images-clean' -df = pd.read_csv(csv_file_path, encoding='Windows-1252') +# df = pd.read_csv(csv_file_path, encoding='Windows-1252') # 创建 Chroma 客户端 client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # 创建集合 -embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") +embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") -def create_collection(): - collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) - - # 存储数据,包括自定义属性 - images_description = [] - images_metadata = [] - ids = [] - batch_size = 41666 # 最大批量大小 - for index, row in tqdm(df.iterrows()): - # 将图片的md5作为id - with open(image_path + row['path'], 'rb') as f: - image_data = f.read() - md5_value = hashlib.md5(image_data).hexdigest() - ids.append(md5_value) - images_description.append(row['description']) - images_metadata.append({ - "gender": row['gender'], - "path": row['path'] - }) - - # 将数据添加到集合 - # 每达到 batch_size 就执行一次 upsert - if len(ids) >= batch_size: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - # 清空列表以准备下一批数据 - ids.clear() - 
images_description.clear() - images_metadata.clear() - - if ids: - collection.upsert( - ids=list(ids), - documents=images_description, - metadatas=images_metadata # 添加自定义属性 - ) - - print("Data successfully stored in the vector database.") +# def create_collection(): +# collection = client.get_or_create_collection("sub_sketches_description", embedding_function=embedding_fn) +# +# # 存储数据,包括自定义属性 +# images_description = [] +# images_metadata = [] +# ids = [] +# batch_size = 41666 # 最大批量大小 +# for index, row in tqdm(df.iterrows()): +# # 将图片的md5作为id +# with open(image_path + row['path'], 'rb') as f: +# image_data = f.read() +# md5_value = hashlib.md5(image_data).hexdigest() +# ids.append(md5_value) +# images_description.append(row['description']) +# images_metadata.append({ +# "gender": row['gender'], +# "path": row['path'] +# }) +# +# # 将数据添加到集合 +# # 每达到 batch_size 就执行一次 upsert +# if len(ids) >= batch_size: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# # 清空列表以准备下一批数据 +# ids.clear() +# images_description.clear() +# images_metadata.clear() +# +# if ids: +# collection.upsert( +# ids=list(ids), +# documents=images_description, +# metadatas=images_metadata # 添加自定义属性 +# ) +# +# print("Data successfully stored in the vector database.") def query(gender, content): From 7ff1603583d653cde043fc49425b46546d39a596 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 23:38:06 +0800 Subject: [PATCH 29/52] =?UTF-8?q?feat=20=20=E4=BF=AE=E5=A4=8Dchatroboot=20?= =?UTF-8?q?fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_route.py | 3 +- app/core/config.py | 2 +- app/service/chat_robot/script/main.py | 2 +- app/service/chat_robot/script/prompt.py | 49 +++++++++++----- .../chat_robot/script/service/CallQWen.py | 57 ++++++++++++++++--- 5 files changed, 87 insertions(+), 26 deletions(-) diff --git a/app/api/api_route.py b/app/api/api_route.py index 7ee774d..0da3a66 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api import api_attribute_retrieve +from app.api import api_attribute_retrieve, api_query_image from app.api import api_brighten from app.api import api_chat_robot from app.api import api_design @@ -23,3 +23,4 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") +router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index d369ff2..7629429 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -138,7 +138,7 @@ SEGMENTATION = { "output": "seg_output__0", } # ollama config -OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings" +OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings" # DESIGN config DESIGN_MODEL_URL = '10.1.1.240:10000' AIDA_CLOTHING = "aida-clothing" diff --git a/app/service/chat_robot/script/main.py b/app/service/chat_robot/script/main.py index cabe372..3342a5c 100644 --- a/app/service/chat_robot/script/main.py +++ b/app/service/chat_robot/script/main.py @@ -100,7 +100,7 @@ def chat(post_data): # session_key=f"buffer:{user_id}:{session_id}", # ) - final_outputs = 
CallQWen.call_with_messages(input_message) + final_outputs = CallQWen.call_with_messages(input_message, gender) # api_response = { # 'user_id': user_id, # 'session_id': session_id, diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py index a88044d..ad6ac9e 100644 --- a/app/service/chat_robot/script/prompt.py +++ b/app/service/chat_robot/script/prompt.py @@ -1,16 +1,31 @@ +# FASHION_CHAT_BOT_PREFIX = """ +# You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. +# The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. +# Remember your answer should be very precise and the final output answer should not exceed 20 words. +# +# You may encounter the following types of questions: +# 1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. +# Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. +# Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. +# Never query for all the columns from a specific table, only ask for the relevant columns given the question. +# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. +# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +# If the question does not seem related to the database, just return "I don't know" as the answer. +# +# 2) If the query related to current events, you should use internet_search to seek help from the internet. +# +# 3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant. +# +# Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential. +# """ + FASHION_CHAT_BOT_PREFIX = """ You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can. The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database. Remember your answer should be very precise and the final output answer should not exceed 20 words. You may encounter the following types of questions: -1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database. -Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables. -Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results. -Never query for all the columns from a specific table, only ask for the relevant columns given the question. -You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. -If the question does not seem related to the database, just return "I don't know" as the answer. +1) If you need to query information related to clothing retrieval, please use the get_image_from_vector_db tool. 2) If the query related to current events, you should use internet_search to seek help from the internet. @@ -37,15 +52,19 @@ ANSWER_FORMAT_SUFFIX = """ My final answer are limited to 20 words and be as much precise as possible. 
""" +# TOOLS_FUNCTIONS_SUFFIX = ( +# "If the input involves clothing queries," +# "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." +# "All SQL statements must use 'ORDER BY RAND()', for example:" +# "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" +# "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" +# "If the input does not involve clothing queries, " +# "I should engage in conversation as an assistant or search from internet with internet_search tool." +# "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" +# "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " +# ) TOOLS_FUNCTIONS_SUFFIX = ( - "If the input involves clothing queries," - "I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables." - "All SQL statements must use 'ORDER BY RAND()', for example:" - "Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'" - "Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'" - "If the input does not involve clothing queries, " - "I should engage in conversation as an assistant or search from internet with internet_search tool." - "If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'" + "If the input involves clothing queries,please use the get_image_from_vector_db tool." "Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool " ) diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index d2e2c06..33dcd04 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -8,6 +8,7 @@ from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN +from app.service.search_image_with_text.service import query get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." @@ -32,6 +33,12 @@ query_database_description = ( "order by rand() LIMIT 2'" ) +query_vector_db_description = ( + "Use this tool to find the clothing images that users need. " + "If the user's input includes clothing types such as blouse, skirt, dress, outerwear, pants, or trousers, please use this tool. " + "The input for the tool is the string provided by the user." +) + tutorial_description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials." 
"Input is an empty string") @@ -105,15 +112,37 @@ tools = [ "function": { "name": "tutorial_tool", "description": tutorial_description, + # "parameters": { + # "type": "object", + # "properties": { + # "sql_string": { + # "type": "string", + # "description": "由模型生成的sql语句" + # } + # } + # }, + } + }, + { + "type": "function", + "function": { + "name": "get_image_from_vector_db", + "description": query_vector_db_description, "parameters": { - "type": "object", - "properties": { - "sql_string": { - "type": "string", - "description": "由模型生成的sql语句" + "parameters": { + "type": "object", + "properties": { + "gender": { + "type": "string", + "description": "性别" + }, + "content": { + "type": "string", + "description": "用户描述" + } } - } - }, + }, + } } } ] @@ -150,6 +179,10 @@ def query_database(sql_string): return CustomDatabase.run(db, sql_string) +def get_image_from_vector_db(gender, content): + return query(gender, content) + + @retry(exceptions=NewConnectionError, tries=3, delay=1) def get_response(messages): response = Generation.call( @@ -164,7 +197,8 @@ def get_response(messages): return response -def call_with_messages(message): +def call_with_messages(message, gender): + user_input = message print('\n') # messages = [ # { @@ -235,6 +269,12 @@ def call_with_messages(message): tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} flag = False result_content = tool_info['content'] + elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db': + tool_info = {"name": "get_image_from_vector_db", "role": "tool", + 'content': get_image_from_vector_db(gender, user_input)} + flag = False + result_content = tool_info['content'] + response_type = "image" print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) @@ -257,5 +297,6 @@ def call_with_messages(message): def tutorial_tool(): return TUTORIAL_TOOL_RETURN + if __name__ == '__main__': call_with_messages() From 600ca77d9092ecae0e2a13685c49b99b41a901b7 Mon Sep 17 00:00:00 2001 From: xupei Date: Mon, 2 Dec 2024 18:20:45 +0800 Subject: [PATCH 30/52] =?UTF-8?q?=E7=BF=BB=E8=AF=91=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=BD=BF=E7=94=A8llama3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit bc9aa034458be245ad7cc8d7be6258da14e83e60) --- .../chatgpt_for_translation.py | 75 +++++++++---------- 1 file changed, 35 insertions(+), 40 deletions(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index e541781..05d85fb 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,6 +1,5 @@ import json -import logging -import time + import requests from dashscope import Generation @@ -9,8 +8,6 @@ from retry import retry from app.core.config import QWEN_API_KEY -logger = logging.getLogger(__name__) - # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" @@ -20,35 +17,35 @@ logger = logging.getLogger(__name__) # openai_api_key=OPENAI_API_KEY, # temperature=0) -# prefix_for_llama = ( -# """ -# Translate everything within the brackets [] into English. -# Never translate or modify any English input. -# The input must be fully translated into coherent English sentences. -# Please only output the translated result.\n -# """ -# ) +prefix_for_llama = ( + """ + Translate everything within the brackets [] into English. 
+ Never translate or modify any English input. + The input must be fully translated into coherent English sentences. + Please only output the translated result.\n + """ + ) def translate_to_en(text): - # template = ( - # """You are a translation expert, proficient in various languages. - # And can translate various languages into English. - # Please translate to grammatically correct English regardless of the input language. - # If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", - # output the input text exactly as it is without any modifications or additions. - # If there are grammatical errors, correct them and then output the sentence.""" - # ) - # - # prefix = ( - # """ - # Translate everything within the brackets [] into English. - # Never translate or modify any English input. - # The input must be fully translated into coherent English sentences. - # Never present the translation results in the format - # "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". - # """ - # ) + template = ( + """You are a translation expert, proficient in various languages. + And can translate various languages into English. + Please translate to grammatically correct English regardless of the input language. + If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", + output the input text exactly as it is without any modifications or additions. + If there are grammatical errors, correct them and then output the sentence.""" + ) + + prefix = ( + """ + Translate everything within the brackets [] into English. + Never translate or modify any English input. + The input must be fully translated into coherent English sentences. + Never present the translation results in the format + "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". + """ + ) messages = [ # { # Translate the entire text and ensure the output is a complete and coherent sentence in English. 
@@ -57,7 +54,7 @@ def translate_to_en(text): # }, { # "content": input('请输入:'), # 用户message - "content": text, # 用户message + "content": prefix + text, # 用户message "role": "user" } ] @@ -77,7 +74,7 @@ def translate_to_en(text): def get_response(messages): response = Generation.call( model='qwen-turbo', - api_key=QWEN_API_KEY, + api_key= QWEN_API_KEY, messages=messages, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 result_format='message', # 将输出设置为message形式 @@ -87,39 +84,37 @@ def get_response(messages): def get_translation_from_llama3(text): - start_time = time.time() - url = "http://10.1.1.240:11434/api/generate" + url = "http://localhost:11434/api/generate" # url = "http://10.1.1.240:1143/api/generate" - # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" + prompt = f"System: {prefix_for_llama}\nUser:[{text}]" # 创建请求的负载 payload = { - "model": "translator", - "prompt": f"[{text}]", + "model": "llama3.2", + "prompt": prompt, "stream": False } # 将负载转换为 JSON 格式 headers = {'Content-Type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) + # 处理响应 if response.status_code == 200: # print("Response from server:") # print(response.json()) resp = json.loads(response.content).get("response") - logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}") print("input : {}, translate result : {}".format(text, resp)) return resp else: - logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}") print(f"Request failed with status code {response.status_code}") print(response.text) def main(): """Main function""" - text = get_translation_from_llama3("[火焰]") + text = translate_to_en("fire") print(text) From a5ceabba047b338a71499896ecac8d47c908ad27 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 20:13:33 +0800 Subject: [PATCH 31/52] =?UTF-8?q?feat=20=20translator=20=E5=88=87=E6=8D=A2?= =?UTF-8?q?ollama=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit ea3e6667a051c2003da5055032ea486f0b338b6a) --- .../chatgpt_for_translation.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 05d85fb..5d720b9 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,6 +1,5 @@ import json - import requests from dashscope import Generation from requests import RequestException @@ -17,35 +16,35 @@ from app.core.config import QWEN_API_KEY # openai_api_key=OPENAI_API_KEY, # temperature=0) -prefix_for_llama = ( - """ - Translate everything within the brackets [] into English. - Never translate or modify any English input. - The input must be fully translated into coherent English sentences. - Please only output the translated result.\n - """ - ) +# prefix_for_llama = ( +# """ +# Translate everything within the brackets [] into English. +# Never translate or modify any English input. +# The input must be fully translated into coherent English sentences. +# Please only output the translated result.\n +# """ +# ) def translate_to_en(text): - template = ( - """You are a translation expert, proficient in various languages. - And can translate various languages into English. - Please translate to grammatically correct English regardless of the input language. 
- If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", - output the input text exactly as it is without any modifications or additions. - If there are grammatical errors, correct them and then output the sentence.""" - ) - - prefix = ( - """ - Translate everything within the brackets [] into English. - Never translate or modify any English input. - The input must be fully translated into coherent English sentences. - Never present the translation results in the format - "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". - """ - ) + # template = ( + # """You are a translation expert, proficient in various languages. + # And can translate various languages into English. + # Please translate to grammatically correct English regardless of the input language. + # If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1", + # output the input text exactly as it is without any modifications or additions. + # If there are grammatical errors, correct them and then output the sentence.""" + # ) + # + # prefix = ( + # """ + # Translate everything within the brackets [] into English. + # Never translate or modify any English input. + # The input must be fully translated into coherent English sentences. + # Never present the translation results in the format + # "The translation of \"Material suave\" into English would be \"Smooth material.\"". Instead, directly output "Smooth material". + # """ + # ) messages = [ # { # Translate the entire text and ensure the output is a complete and coherent sentence in English. @@ -54,7 +53,7 @@ def translate_to_en(text): # }, { # "content": input('请输入:'), # 用户message - "content": prefix + text, # 用户message + "content": text, # 用户message "role": "user" } ] @@ -74,7 +73,7 @@ def translate_to_en(text): def get_response(messages): response = Generation.call( model='qwen-turbo', - api_key= QWEN_API_KEY, + api_key=QWEN_API_KEY, messages=messages, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 result_format='message', # 将输出设置为message形式 @@ -84,15 +83,15 @@ def get_response(messages): def get_translation_from_llama3(text): - url = "http://localhost:11434/api/generate" + url = "http://10.1.1.240:11434/api/generate" # url = "http://10.1.1.240:1143/api/generate" - prompt = f"System: {prefix_for_llama}\nUser:[{text}]" + # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" # 创建请求的负载 payload = { - "model": "llama3.2", - "prompt": prompt, + "model": "translator", + "prompt": f"[{text}]", "stream": False } @@ -114,7 +113,7 @@ def get_translation_from_llama3(text): def main(): """Main function""" - text = translate_to_en("fire") + text = get_translation_from_llama3("[火焰]") print(text) From f07d1b0822e65522f135fadef40f3bb0e601ac02 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 2 Dec 2024 20:31:46 +0800 Subject: [PATCH 32/52] =?UTF-8?q?feat=20=20translator=20=E5=88=87=E6=8D=A2?= =?UTF-8?q?ollama=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 5491c54bda681448fe93632e4898de6af82c58d4) --- app/service/prompt_generation/chatgpt_for_translation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index 5d720b9..e541781 100644 --- 
a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -1,4 +1,6 @@ import json +import logging +import time import requests from dashscope import Generation @@ -7,6 +9,8 @@ from retry import retry from app.core.config import QWEN_API_KEY +logger = logging.getLogger(__name__) + # os.environ["http_proxy"] = "http://127.0.0.1:7890" # os.environ["https_proxy"] = "http://127.0.0.1:7890" @@ -83,6 +87,7 @@ def get_response(messages): def get_translation_from_llama3(text): + start_time = time.time() url = "http://10.1.1.240:11434/api/generate" # url = "http://10.1.1.240:1143/api/generate" @@ -98,15 +103,16 @@ def get_translation_from_llama3(text): # 将负载转换为 JSON 格式 headers = {'Content-Type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) - # 处理响应 if response.status_code == 200: # print("Response from server:") # print(response.json()) resp = json.loads(response.content).get("response") + logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}") print("input : {}, translate result : {}".format(text, resp)) return resp else: + logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}") print(f"Request failed with status code {response.status_code}") print(response.text) From 6a12dcba578f8869780137f222d14f163f5246d1 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:41:37 +0800 Subject: [PATCH 33/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/segmentation.py | 10 +++++----- app/service/design_fast/utils/design_ensemble.py | 10 ++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index ebf02b4..9cc53a3 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,11 +36,11 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: @@ -49,14 +49,14 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) result['seg_result'] = 
seg_result # 处理前片后片 - temp_front = seg_result == 1.0 + temp_front = seg_result == 1 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) - temp_back = seg_result == 2.0 + temp_back = seg_result == 2 result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) result['mask'] = result['front_mask'] + result['back_mask'] return result diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 267ea00..9f30d0c 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -13,7 +13,6 @@ import cv2 import mmcv import numpy as np import torch -import torch.nn.functional as F import tritonclient.http as httpclient from app.core.config import * @@ -85,6 +84,9 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) + + # 扩充25的白边 + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape @@ -114,9 +116,9 @@ def get_seg_result(image_id, image): # no cache def seg_postprocess(image_id, output, ori_shape): - seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) - seg_pred = seg_logit.cpu().numpy() - return seg_pred[0] + seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50)) + seg_logit = seg_logit[25: - 25, 25: - 25] + return seg_logit def key_point_show(image_path, key_point_result=None): From 238b3dc7af4d87416cbbea27d8169ce3d75ed3fe Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 13 Jan 2025 15:58:58 +0800 Subject: [PATCH 34/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20desi?= =?UTF-8?q?gn=20=E5=88=86=E5=89=B2=E9=A2=84=E5=A4=84=E7=90=86=E6=96=B0?= =?UTF-8?q?=E5=A2=9E25padding=EF=BC=8C=E5=90=8E=E5=A4=84=E7=90=86=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=8F=92=E5=80=BC=E5=A4=84=E7=90=86=20docs=EF=BC=88?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor?= =?UTF-8?q?=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 9f30d0c..bfc50c6 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -87,7 +87,7 @@ def seg_preprocess(img_path): # 扩充25的白边 img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) - # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape From f21b7203bb1a7e7c9a6ca2280413a75f3dd8b756 Mon Sep 17 00:00:00 2001 From: 
zhouchengrong Date: Mon, 20 Jan 2025 11:26:42 +0800 Subject: [PATCH 35/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20?= =?UTF-8?q?=E5=85=B3=E9=94=AE=E7=82=B9=E6=A8=A1=E5=9E=8B=E9=A2=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=90=8E=E5=A4=84=E7=90=86=E5=A2=9E=E5=8A=A0=E7=99=BD?= =?UTF-8?q?=E8=BE=B9=E6=A1=86=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/utils/design_ensemble.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index bfc50c6..8eef4f2 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -25,6 +25,7 @@ from app.core.config import * def keypoint_preprocess(img_path): img = mmcv.imread(img_path) + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) img_scale = (256, 256) h, w = img.shape[:2] img = cv2.resize(img, img_scale) @@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor): scale_matrix = np.diag(scale_factor) nan = np.isinf(scale_matrix) scale_matrix[nan] = 0 - return np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 应用缩放因子 + scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 补偿边框偏移 + compensated_result = scaled_result - 25 + return compensated_result """ From 07e09dc99d58ce9d8c816561da9d166f39e5237e Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 14:33:06 +0800 Subject: [PATCH 36/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 246 +++++++++++++++--- 1 file changed, 215 insertions(+), 31 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 1c20a13..681e2b5 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -1,4 +1,196 @@ -#!/usr/bin/env python +# #!/usr/bin/env python +# # -*- coding: UTF-8 -*- +# """ +# @Project :trinity_client +# @File :service_att_recognition.py +# @Author :周成融 +# @Date :2023/7/26 12:01:05 +# @detail : +# """ +# import json +# import logging +# import time +# +# import cv2 +# import numpy as np +# import redis +# import tritonclient.grpc as grpcclient +# from PIL import Image +# from tritonclient.utils import np_to_triton_dtype +# +# from app.core.config import * +# from app.schemas.generate_image import GenerateProductImageModel +# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +# from app.service.utils.oss_client import oss_get_image +# +# logger = logging.getLogger() +# +# +# class GenerateProductImage: +# def __init__(self, request_data): +# if DEBUG is False: +# 
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# self.channel = self.connection.channel() +# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# # self.channel = self.connection.channel() +# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) +# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) +# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# self.category = "product_image" +# self.image_strength = request_data.image_strength +# self.batch_size = 1 +# self.product_type = request_data.product_type +# self.prompt = request_data.prompt +# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) +# self.tasks_id = request_data.tasks_id +# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] +# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# self.redis_client.expire(self.tasks_id, 600) +# +# def callback(self, result, error): +# if error: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(error) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# else: +# # pil图像转成numpy数组 +# image = result.as_numpy("generated_inpaint_image") +# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) +# cropped_image = post_processing_image(image_result, self.left, self.top) +# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") +# self.gen_product_data['status'] = "SUCCESS" +# self.gen_product_data['message'] = "success" +# self.gen_product_data['image_url'] = str(image_url) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# +# def read_tasks_status(self): +# status_data = self.redis_client.get(self.tasks_id) +# return json.loads(status_data), status_data +# +# def get_result(self): +# try: +# prompts = [self.prompt] * self.batch_size +# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) +# self.image = cv2.resize(self.image, (1024, 1024)) +# images = [self.image.astype(np.uint8)] * self.batch_size +# +# if self.product_type == "single": +# text_obj = np.array(prompts, dtype="object").reshape(-1, 1) +# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) +# else: +# text_obj = np.array(prompts, dtype="object").reshape((1)) +# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) +# +# # 假设 prompts、images 和 self.image_strength 已经定义 +# +# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) +# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") +# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) +# +# input_text.set_data_from_numpy(text_obj) +# input_image.set_data_from_numpy(image_obj) +# input_image_strength.set_data_from_numpy(image_strength_obj) +# +# inputs = [input_text, input_image, input_image_strength] +# +# if self.product_type == 
"single": +# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) +# else: +# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) +# +# time_out = 600 +# while time_out > 0: +# gen_product_data, _ = self.read_tasks_status() +# if gen_product_data['status'] in ["REVOKED", "FAILURE"]: +# ctx.cancel() +# break +# elif gen_product_data['status'] == "SUCCESS": +# break +# time_out -= 1 +# time.sleep(0.1) +# gen_product_data, _ = self.read_tasks_status() +# return gen_product_data +# except Exception as e: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(e) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# raise Exception(str(e)) +# finally: +# dict_gen_product_data, str_gen_product_data = self.read_tasks_status() +# if DEBUG is False: +# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) +# logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") +# +# +# def infer_cancel(tasks_id): +# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} +# gen_product_data = json.dumps(data) +# redis_client.set(tasks_id, gen_product_data) +# return data +# +# +# def pre_processing_image(image_url): +# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") +# # resize 原图至1024*1024 +# image = image.resize((int(1024 / image.height * image.width), 1024)) +# +# # 原始图片的尺寸 +# width, height = image.size +# +# new_height, new_width = 1024, 1024 +# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 +# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) +# +# # 将原始图片粘贴到新的画布中心 +# left = (new_width - width) // 2 +# top = (new_height - height) // 2 +# pad_image.paste(image, (left, top)) +# +# # 将画布 resize 成宽度 1024,长度 1024 +# resized_image = pad_image.resize((1024, 1024)) +# image_size = (1024, 1024) +# +# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): +# # 创建白色背景 +# background = Image.new("RGB", image_size, (255, 255, 255)) +# # 将图片粘贴到白色背景上 +# background.paste(resized_image, mask=resized_image.split()[3]) +# image = np.array(background) +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +# return image, image_size, left, top +# +# +# def post_processing_image(image, left, top): +# resized_image = image.resize((int(image.width * (768 / image.height)), 768)) +# # 计算裁剪的坐标 +# left = (resized_image.width - 512) // 2 +# upper = 0 +# right = left + 512 +# lower = 768 +# +# # 进行裁剪 +# cropped_image = resized_image.crop((left, upper, right, lower)) +# return cropped_image +# +# +# if __name__ == '__main__': +# rd = GenerateProductImageModel( +# tasks_id="123-89", +# # prompt="", +# image_strength=0.7, +# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", +# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", +# product_type="overall" +# ) +# server = GenerateProductImage(rd) +# print(server.get_result()) + +# 旧版product +# !/usr/bin/env python # -*- coding: UTF-8 -*- """ @Project :trinity_client @@ -34,14 +226,14 @@ class GenerateProductImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # 
self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001") self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.image_strength = request_data.image_strength self.batch_size = 1 self.product_type = request_data.product_type self.prompt = request_data.prompt - self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) + self.image = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -56,9 +248,9 @@ class GenerateProductImage: else: # pil图像转成numpy数组 image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - cropped_image = post_processing_image(image_result, self.left, self.top) - image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + # cropped_image = post_processing_image(image_result, self.left, self.top) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -72,7 +264,7 @@ class GenerateProductImage: try: prompts = [self.prompt] * self.batch_size self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - self.image = cv2.resize(self.image, (1024, 1024)) + # self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": @@ -81,7 +273,7 @@ class GenerateProductImage: image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) else: text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -99,7 +291,7 @@ class GenerateProductImage: if self.product_type == "single": ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: @@ -135,33 +327,25 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") - # resize 原图至1024*1024 - image = image.resize((int(1024 / image.height * image.width), 1024)) - - # 原始图片的尺寸 + # 调整图片高度为768像素,保持宽高比 width, height = image.size + new_height = 768 + new_width = int(width * (new_height / height)) + resized_image = 
image.resize((new_width, new_height)) - new_height, new_width = 1024, 1024 - # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 - pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) + # 创建一个512x768的透明图片 + result_image = Image.new("RGBA", (512, 768), (0, 0, 0, 0)) - # 将原始图片粘贴到新的画布中心 - left = (new_width - width) // 2 - top = (new_height - height) // 2 - pad_image.paste(image, (left, top)) + # 计算需要粘贴的位置,使图片居中 + x_offset = (512 - new_width) // 2 + y_offset = 0 - # 将画布 resize 成宽度 1024,长度 1024 - resized_image = pad_image.resize((1024, 1024)) - image_size = (1024, 1024) + # 将调整大小后的图片粘贴到透明图片上 + result_image.paste(resized_image, (x_offset, y_offset)) - if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): - # 创建白色背景 - background = Image.new("RGB", image_size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(resized_image, mask=resized_image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image, image_size, left, top + image = np.array(result_image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return image def post_processing_image(image, left, top): From e1c00bbd669a00028003e0f0388786714399ae64 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:06:49 +0800 Subject: [PATCH 37/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_image/service_generate_product_image.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 681e2b5..0507953 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -334,17 +334,18 @@ def pre_processing_image(image_url): resized_image = image.resize((new_width, new_height)) # 创建一个512x768的透明图片 - result_image = Image.new("RGBA", (512, 768), (0, 0, 0, 0)) + result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255)) # 计算需要粘贴的位置,使图片居中 x_offset = (512 - new_width) // 2 y_offset = 0 # 将调整大小后的图片粘贴到透明图片上 - result_image.paste(resized_image, (x_offset, y_offset)) + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) image = np.array(result_image) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) return image @@ -366,8 +367,8 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="", image_strength=0.7, - prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", - image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", + prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", + image_url="aida-users/11633/toProductImageElement/46166c36-c584-4e0f-b9fe-50615ec03ef3.png", product_type="overall" ) server = GenerateProductImage(rd) From 74008c4586d6d486584d4f409f86672470bbb970 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:13:48 +0800 Subject: [PATCH 38/52] 
=?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service_generate_product_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 0507953..22f7306 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -263,7 +263,8 @@ class GenerateProductImage: def get_result(self): try: prompts = [self.prompt] * self.batch_size - self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + + self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB) # self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size From 48ae1cfb75ec5329dc22c0aafc46c2fcf08f799c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 3 Feb 2025 19:36:56 +0800 Subject: [PATCH 39/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 22f7306..287a983 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -247,7 +247,10 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) # cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -269,9 +272,9 @@ class GenerateProductImage: images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": - text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) else: text_obj = np.array(prompts, dtype="object").reshape((1)) image_obj = 
np.array(images, dtype=np.uint8).reshape((768, 512, 3)) @@ -290,7 +293,7 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback) else: ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) @@ -369,8 +372,8 @@ if __name__ == '__main__': # prompt="", image_strength=0.7, prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", - image_url="aida-users/11633/toProductImageElement/46166c36-c584-4e0f-b9fe-50615ec03ef3.png", - product_type="overall" + image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png", + product_type="single" ) server = GenerateProductImage(rd) print(server.get_result()) From 092de8c43d02ef4580eab861869a55856a0bfc05 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Feb 2025 09:54:38 +0800 Subject: [PATCH 40/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/config.py | 8 +++++++- .../generate_image/service_generate_product_image.py | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 7456912..6a4ad23 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -129,8 +129,14 @@ GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") # Generate Product service config +# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") +# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all' +# GPI_MODEL_URL = '10.1.1.243:10051' + +# Generate Product service config 旧版product img 模型 GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all' +GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' +GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 287a983..3663643 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -226,7 +226,7 @@ class GenerateProductImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001") + self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) 
self.category = "product_image" self.image_strength = request_data.image_strength @@ -293,9 +293,9 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) else: - ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: From 3a0c730e9cfb7987431562153379c57459be2e1f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Tue, 4 Feb 2025 10:05:15 +0800 Subject: [PATCH 41/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20=E6=97=A7=E7=89=88product=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 3663643..a575f07 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -331,18 +331,33 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") - # 调整图片高度为768像素,保持宽高比 - width, height = image.size - new_height = 768 - new_width = int(width * (new_height / height)) + # 目标图片的尺寸 + target_width = 512 + target_height = 768 + + # 原始图片的尺寸 + original_width, original_height = image.size + + # 计算宽度和高度的缩放比例 + width_ratio = target_width / original_width + height_ratio = target_height / original_height + + # 选择较小的缩放比例,确保图片能完整放入目标图片中 + scale_ratio = min(width_ratio, height_ratio) + + # 计算调整后的尺寸 + new_width = int(original_width * scale_ratio) + new_height = int(original_height * scale_ratio) + + # 调整图片大小 resized_image = image.resize((new_width, new_height)) - # 创建一个512x768的透明图片 - result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255)) + # 创建一个 512x768 的透明图片 + result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0)) # 计算需要粘贴的位置,使图片居中 - x_offset = (512 - new_width) // 2 - y_offset = 0 + x_offset = (target_width - new_width) // 2 + y_offset = (target_height - new_height) // 2 # 将调整大小后的图片粘贴到透明图片上 result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) @@ -371,8 +386,8 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="", image_strength=0.7, - prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", - image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png", + prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR", + image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png", product_type="single" ) server = GenerateProductImage(rd) 
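The pre_processing_image rewrite in PATCH 41 above fits the source image into a fixed 512x768 canvas without distortion: it scales by the smaller of the width and height ratios and centres the result on a transparent background. A minimal standalone sketch of that letterboxing step, assuming only PIL and NumPy are available (the function name fit_to_canvas and its default canvas size are illustrative, not part of the service code):

import numpy as np
from PIL import Image

def fit_to_canvas(image, target_width=512, target_height=768):
    # Scale the image to fit inside target_width x target_height, keep aspect ratio, centre it.
    original_width, original_height = image.size

    # Use the smaller ratio so the whole image fits inside the canvas.
    scale_ratio = min(target_width / original_width, target_height / original_height)
    new_width = int(original_width * scale_ratio)
    new_height = int(original_height * scale_ratio)
    resized = image.convert("RGBA").resize((new_width, new_height))

    # Transparent canvas; paste the resized image centred, using its alpha channel as the mask.
    canvas = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    x_offset = (target_width - new_width) // 2
    y_offset = (target_height - new_height) // 2
    canvas.paste(resized, (x_offset, y_offset), mask=resized.split()[3])

    return np.array(canvas)  # H x W x 4, RGBA

For a 1024x1024 input this gives a 512x512 garment centred vertically in the 768-high canvas, which matches the offsets computed in the patch.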
From b3ba9c13eeb6818a09b7e0ecc5212827205434ba Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 5 Feb 2025 17:12:28 +0800 Subject: [PATCH 42/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20p?= =?UTF-8?q?rint=20overall=20=E6=97=8B=E8=BD=AC=E6=8A=A5=E9=94=99=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/print_painting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 6fe40d8..f03cdb1 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -442,8 +442,11 @@ class PrintPainting: angle: 旋转的角度 crop: 是否需要进行裁剪,布尔向量 """ + if not isinstance(crop, bool): + raise ValueError("The 'crop' parameter must be a boolean.") + crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] - w, h = img.shape[:2] + h, w = img.shape[:2] # 旋转角度的周期是360° angle %= 360 # 计算仿射变换矩阵 @@ -455,7 +458,7 @@ class PrintPainting: if crop: # 裁剪角度的等效周期是180° angle_crop = angle % 180 - if angle > 90: + if angle_crop > 90: angle_crop = 180 - angle_crop # 转化角度为弧度 theta = angle_crop * np.pi / 180 From 9261e2cde694540ea3e4f9af07e9f42da4d62de9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 16:22:43 +0800 Subject: [PATCH 43/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20p?= =?UTF-8?q?rint=20overall=20=E6=97=8B=E8=BD=AC=E6=8A=A5=E9=94=99=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20te?= =?UTF-8?q?st(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/print_painting.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index f03cdb1..6fe40d8 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -442,11 +442,8 @@ class PrintPainting: angle: 旋转的角度 crop: 是否需要进行裁剪,布尔向量 """ - if not isinstance(crop, bool): - raise ValueError("The 'crop' parameter must be a boolean.") - crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] - h, w = img.shape[:2] + w, h = img.shape[:2] # 旋转角度的周期是360° angle %= 360 # 计算仿射变换矩阵 @@ -458,7 +455,7 @@ class PrintPainting: if crop: # 裁剪角度的等效周期是180° angle_crop = angle % 180 - if angle_crop > 90: + if angle > 90: angle_crop = 180 - angle_crop # 转化角度为弧度 theta = angle_crop * np.pi / 180 From bba91b46719b8ebd39680584912d46cd3e2019a9 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 17:42:02 +0800 Subject: [PATCH 44/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20p?= =?UTF-8?q?rint=20=E4=BA=8E=20sketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84print=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE?= 
=?UTF-8?q?=E9=97=AE=E9=A2=98=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88=E9=87=8D=E6=9E=84?= =?UTF-8?q?=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/loading.py | 15 +++++++-- .../design_fast/pipeline/segmentation.py | 33 ++++++++++++++++--- app/service/design_fast/pipeline/split.py | 4 +-- app/service/utils/new_oss_client.py | 2 +- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/app/service/design_fast/pipeline/loading.py b/app/service/design_fast/pipeline/loading.py index 5a55d9d..85d1fb1 100644 --- a/app/service/design_fast/pipeline/loading.py +++ b/app/service/design_fast/pipeline/loading.py @@ -1,9 +1,6 @@ -import io import logging import cv2 -import numpy as np -from PIL import Image from app.service.utils.new_oss_client import oss_get_image @@ -38,6 +35,18 @@ class LoadImage: def __call__(self, result): result['image'], result['pre_mask'] = self.read_image(result['path']) + + # 判断是否resize sketch 保留ori image 用于模型输入 + result['ori_image'] = result['image'] + if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: + height, width = result['image'].shape[:2] + # 计算新的宽度和高度 + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + # 使用cv2.resize()函数进行缩放 + result['image'] = cv2.resize(result['image'], (new_width, new_height)) + if result['pre_mask'] is not None: + result['pre_mask'] = cv2.resize(result['pre_mask'], (new_width, new_height)) result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY) result['keypoint'] = self.get_keypoint(result['name']) result['img_shape'] = result['image'].shape diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 0c9c51e..2ad1a57 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,12 +36,27 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image']) + seg_result = get_seg_result(result["image_id"], result['ori_image']) + if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: + height, width = seg_result.shape[:2] + # 计算新的宽度和高度 + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + # 使用cv2.resize()函数进行缩放 + seg_result = cv2.resize(seg_result, (new_width, new_height)) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image']) - self.save_seg_result(seg_result, result['image_id']) + seg_result = get_seg_result(result["image_id"], result['ori_image']) + seg_result_save = seg_result + if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: + height, width = seg_result.shape[:2] + # 计算新的宽度和高度 + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + # 使用cv2.resize()函数进行缩放 + seg_result = cv2.resize(seg_result, (new_width, new_height)) + self.save_seg_result(seg_result_save, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: # 本地查询seg 缓存是否存在 @@ -49,8 +64,16 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 
结果 - seg_result = get_seg_result(result["image_id"], result['image']) - self.save_seg_result(seg_result, result['image_id']) + seg_result = get_seg_result(result["image_id"], result['ori_image']) + seg_result_save = seg_result + if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: + height, width = seg_result.shape[:2] + # 计算新的宽度和高度 + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + # 使用cv2.resize()函数进行缩放 + seg_result = cv2.resize(seg_result, (new_width, new_height)) + self.save_seg_result(seg_result_save, result['image_id']) result['seg_result'] = seg_result # 处理前片后片 diff --git a/app/service/design_fast/pipeline/split.py b/app/service/design_fast/pipeline/split.py index 344c5c5..f3374a1 100644 --- a/app/service/design_fast/pipeline/split.py +++ b/app/service/design_fast/pipeline/split.py @@ -21,11 +21,11 @@ class Split(object): def __call__(self, result): try: - if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'): + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'): front_mask = result['front_mask'] back_mask = result['back_mask'] rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) - new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1])) + new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"])) rgba_image = cv2.resize(rgba_image, new_size) result_front_image = np.zeros_like(rgba_image) front_mask = cv2.resize(front_mask, new_size) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 7939333..cf6f861 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/123-89.png" + url = "aida-results/result_40e527bf-e46d-11ef-813d-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From eed0ac13a1547f25c74b44a9a213e26b3c39a107 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 17:42:24 +0800 Subject: [PATCH 45/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20s?= =?UTF-8?q?ketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4=E7=9A=84print?= =?UTF-8?q?=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:?= =?UTF-8?q?=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/print_painting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 6fe40d8..f03cdb1 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -442,8 +442,11 @@ class PrintPainting: angle: 旋转的角度 crop: 是否需要进行裁剪,布尔向量 """ + if not 
isinstance(crop, bool): + raise ValueError("The 'crop' parameter must be a boolean.") + crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] - w, h = img.shape[:2] + h, w = img.shape[:2] # 旋转角度的周期是360° angle %= 360 # 计算仿射变换矩阵 @@ -455,7 +458,7 @@ class PrintPainting: if crop: # 裁剪角度的等效周期是180° angle_crop = angle % 180 - if angle > 90: + if angle_crop > 90: angle_crop = 180 - angle_crop # 转化角度为弧度 theta = angle_crop * np.pi / 180 From f4225bcb2c78558f2a9ee6b9ff737cfe9a3ef422 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 18:37:16 +0800 Subject: [PATCH 46/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20s?= =?UTF-8?q?ketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4=E7=9A=84print?= =?UTF-8?q?=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:?= =?UTF-8?q?=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design_fast/pipeline/print_painting.py | 18 ++++++++++++++++++ app/service/design_fast/pipeline/split.py | 18 ++++++++++++++---- app/service/utils/new_oss_client.py | 2 +- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 6fe40d8..7120e69 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -17,6 +17,24 @@ class PrintPainting: element_print = result['print']['element'] result['single_image'] = None result['print_image'] = None + # TODO 给result['pattern_image'] resize 到resize_scale的大小 + # TODO 给result['mask'] resize 到resize_scale的大小 + + if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0: + pass + else: + height, width = result['pattern_image'].shape[:2] + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + + result['pattern_image'] = cv2.resize(result['pattern_image'], (new_width, new_height)) + result['mask'] = cv2.resize(result['mask'], (new_width, new_height)) + result['gray'] = cv2.resize(result['gray'], (new_width, new_height)) + + + + + print(1) if overall_print['print_path_list']: painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} result['print_image'] = result['pattern_image'] diff --git a/app/service/design_fast/pipeline/split.py b/app/service/design_fast/pipeline/split.py index 344c5c5..115f814 100644 --- a/app/service/design_fast/pipeline/split.py +++ b/app/service/design_fast/pipeline/split.py @@ -21,11 +21,21 @@ class Split(object): def __call__(self, result): try: - if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'): - front_mask = result['front_mask'] - back_mask = result['back_mask'] + if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'): + + if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0: + front_mask = result['front_mask'] + back_mask = result['back_mask'] + else: + height, width = result['front_mask'].shape[:2] + new_width = int(width * result['resize_scale'][0]) + new_height = int(height * result['resize_scale'][1]) + + front_mask = cv2.resize(result['front_mask'], 
(new_width, new_height)) + back_mask = cv2.resize(result['back_mask'], (new_width, new_height)) + rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) - new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1])) + new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"])) rgba_image = cv2.resize(rgba_image, new_size) result_front_image = np.zeros_like(rgba_image) front_mask = cv2.resize(front_mask, new_size) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 7939333..92b41fa 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/123-89.png" + url = "aida-results/result_4185cc1c-e476-11ef-b8e1-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2" From 0fe0ca41838f0c74fa2168b40aedd52cac09818c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 18:39:05 +0800 Subject: [PATCH 47/52] =?UTF-8?q?Revert=20"feat=EF=BC=88=E6=96=B0=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug?= =?UTF-8?q?=EF=BC=89:=20=20print=20=E4=BA=8E=20sketch=E6=8B=89=E4=BC=B8?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E7=9A=84print=E6=AF=94=E4=BE=8B=E4=B8=8D?= =?UTF-8?q?=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98=20docs=EF=BC=88=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refactor=EF=BC=88?= =?UTF-8?q?=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E=E5=8A=A0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95):"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit bba91b46 --- app/service/design_fast/pipeline/loading.py | 15 ++------- .../design_fast/pipeline/segmentation.py | 33 +++---------------- 2 files changed, 8 insertions(+), 40 deletions(-) diff --git a/app/service/design_fast/pipeline/loading.py b/app/service/design_fast/pipeline/loading.py index 85d1fb1..5a55d9d 100644 --- a/app/service/design_fast/pipeline/loading.py +++ b/app/service/design_fast/pipeline/loading.py @@ -1,6 +1,9 @@ +import io import logging import cv2 +import numpy as np +from PIL import Image from app.service.utils.new_oss_client import oss_get_image @@ -35,18 +38,6 @@ class LoadImage: def __call__(self, result): result['image'], result['pre_mask'] = self.read_image(result['path']) - - # 判断是否resize sketch 保留ori image 用于模型输入 - result['ori_image'] = result['image'] - if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: - height, width = result['image'].shape[:2] - # 计算新的宽度和高度 - new_width = int(width * result['resize_scale'][0]) - new_height = int(height * result['resize_scale'][1]) - # 使用cv2.resize()函数进行缩放 - result['image'] = cv2.resize(result['image'], (new_width, new_height)) - if result['pre_mask'] is not None: - result['pre_mask'] = cv2.resize(result['pre_mask'], (new_width, new_height)) result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY) result['keypoint'] = self.get_keypoint(result['name']) result['img_shape'] = result['image'].shape diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 
2ad1a57..0c9c51e 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,27 +36,12 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['ori_image']) - if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: - height, width = seg_result.shape[:2] - # 计算新的宽度和高度 - new_width = int(width * result['resize_scale'][0]) - new_height = int(height * result['resize_scale'][1]) - # 使用cv2.resize()函数进行缩放 - seg_result = cv2.resize(seg_result, (new_width, new_height)) + seg_result = get_seg_result(result["image_id"], result['image']) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['ori_image']) - seg_result_save = seg_result - if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: - height, width = seg_result.shape[:2] - # 计算新的宽度和高度 - new_width = int(width * result['resize_scale'][0]) - new_height = int(height * result['resize_scale'][1]) - # 使用cv2.resize()函数进行缩放 - seg_result = cv2.resize(seg_result, (new_width, new_height)) - self.save_seg_result(seg_result_save, result['image_id']) + seg_result = get_seg_result(result["image_id"], result['image']) + self.save_seg_result(seg_result, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: # 本地查询seg 缓存是否存在 @@ -64,16 +49,8 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['ori_image']) - seg_result_save = seg_result - if result['resize_scale'][0] != 0 and result['resize_scale'][1] != 0: - height, width = seg_result.shape[:2] - # 计算新的宽度和高度 - new_width = int(width * result['resize_scale'][0]) - new_height = int(height * result['resize_scale'][1]) - # 使用cv2.resize()函数进行缩放 - seg_result = cv2.resize(seg_result, (new_width, new_height)) - self.save_seg_result(seg_result_save, result['image_id']) + seg_result = get_seg_result(result["image_id"], result['image']) + self.save_seg_result(seg_result, result['image_id']) result['seg_result'] = seg_result # 处理前片后片 From 29d9e02d3e4718a05ab2f62b5b73987f1fc5595f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 18:39:46 +0800 Subject: [PATCH 48/52] =?UTF-8?q?Revert=20"feat=EF=BC=88=E6=96=B0=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=89:"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b3ba9c13eeb6818a09b7e0ecc5212827205434ba. 
--- app/service/design_fast/pipeline/print_painting.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 9e8e1dc..7120e69 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -460,11 +460,8 @@ class PrintPainting: angle: 旋转的角度 crop: 是否需要进行裁剪,布尔向量 """ - if not isinstance(crop, bool): - raise ValueError("The 'crop' parameter must be a boolean.") - crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w] - h, w = img.shape[:2] + w, h = img.shape[:2] # 旋转角度的周期是360° angle %= 360 # 计算仿射变换矩阵 @@ -476,7 +473,7 @@ class PrintPainting: if crop: # 裁剪角度的等效周期是180° angle_crop = angle % 180 - if angle_crop > 90: + if angle > 90: angle_crop = 180 - angle_crop # 转化角度为弧度 theta = angle_crop * np.pi / 180 From dd9f091a1c7403c8ca8187d8b7bb683fe8482cea Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 6 Feb 2025 18:47:00 +0800 Subject: [PATCH 49/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20s?= =?UTF-8?q?ketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4=E7=9A=84print?= =?UTF-8?q?=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:?= =?UTF-8?q?=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/design_fast/pipeline/print_painting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 7120e69..417572f 100644 --- a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -28,12 +28,10 @@ class PrintPainting: new_height = int(height * result['resize_scale'][1]) result['pattern_image'] = cv2.resize(result['pattern_image'], (new_width, new_height)) + result['final_image'] = cv2.resize(result['final_image'], (new_width, new_height)) result['mask'] = cv2.resize(result['mask'], (new_width, new_height)) result['gray'] = cv2.resize(result['gray'], (new_width, new_height)) - - - print(1) if overall_print['print_path_list']: painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} From 5e9101e77dcd3fd3d983b768e2771ec9a90fbfaf Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 7 Feb 2025 14:21:46 +0800 Subject: [PATCH 50/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20s?= =?UTF-8?q?ketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4=E7=9A=84print?= =?UTF-8?q?=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:?= =?UTF-8?q?=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../design_fast/pipeline/print_painting.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/app/service/design_fast/pipeline/print_painting.py b/app/service/design_fast/pipeline/print_painting.py index 417572f..fd5f910 100644 --- 
a/app/service/design_fast/pipeline/print_painting.py +++ b/app/service/design_fast/pipeline/print_painting.py @@ -53,9 +53,12 @@ class PrintPainting: print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) for i in range(len(single_print['print_path_list'])): + if not (result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0): + single_print['location'][i] = (int(single_print['location'][i][0] * result['resize_scale'][0]), int(single_print['location'][i][1] * result['resize_scale'][1])) + image, image_mode = self.read_image(single_print['print_path_list'][i]) if image_mode == "RGBA": - new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i])) + new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1])) mask = image.split()[3] resized_source = image.resize(new_size) @@ -78,9 +81,12 @@ class PrintPainting: mask = np.expand_dims(mask, axis=2) mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) mask = cv2.bitwise_not(mask) + + mask = cv2.resize(mask, (int(result['final_image'].shape[0] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[1] * single_print['print_scale_list'][i][1]))) + image = cv2.resize(image, (int(result['final_image'].shape[0] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[1] * single_print['print_scale_list'][i][1]))) # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i]) + rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i]) # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1]) @@ -157,9 +163,11 @@ class PrintPainting: print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) for i in range(len(element_print['element_path_list'])): + if not (result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0): + element_print['location'][i] = (int(element_print['location'][i][0] * result['resize_scale'][0]), int(element_print['location'][i][1] * result['resize_scale'][1])) image, image_mode = self.read_image(element_print['element_path_list'][i]) if image_mode == "RGBA": - new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i])) + new_size = (int(result['final_image'].shape[1] * element_print['element_scale_list'][i][0]), int(result['final_image'].shape[0] * element_print['element_scale_list'][i][1])) mask = image.split()[3] resized_source = image.resize(new_size) @@ -181,9 +189,11 @@ class PrintPainting: mask = 
np.expand_dims(mask, axis=2) mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) mask = cv2.bitwise_not(mask) + mask = cv2.resize(mask, (int(result['final_image'].shape[0] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[1] * single_print['print_scale_list'][i][1]))) + image = cv2.resize(image, (int(result['final_image'].shape[0] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[1] * single_print['print_scale_list'][i][1]))) # 旋转后的坐标需要重新算 - rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) - rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i]) + rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i]) + rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i]) # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2) x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1]) @@ -425,7 +435,7 @@ class PrintPainting: return high, low @staticmethod - def img_rotate(image, angel, scale): + def img_rotate(image, angel): """顺时针旋转图像任意角度 Args: @@ -440,7 +450,7 @@ class PrintPainting: center = (w // 2, h // 2) # if type(angel) is not int: # angel = 0 - M = cv2.getRotationMatrix2D(center, -angel, scale) + M = cv2.getRotationMatrix2D(center, -angel, 1) # 调整旋转后的图像长宽 rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0])))) rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0])))) @@ -449,7 +459,7 @@ class PrintPainting: # 旋转图像 rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h)) - return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2) + return rotated_img, ((rotated_img.shape[1] - image.shape[1]) // 2, (rotated_img.shape[0] - image.shape[0]) // 2) # return rotated_img, (0, 0) @staticmethod From 89e39e37ae3b0a5e05856ba51235515dfdd23c96 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 7 Feb 2025 14:46:07 +0800 Subject: [PATCH 51/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20s?= =?UTF-8?q?ketch=E6=8B=89=E4=BC=B8=E5=AF=BC=E8=87=B4=E7=9A=84print?= =?UTF-8?q?=E6=AF=94=E4=BE=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20docs=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:?= =?UTF-8?q?=20refactor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/prompt_generation/chatgpt_for_translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index e668500..e86a6bf 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -106,7 +106,7 @@ def get_translation_from_llama3(text): "prompt": f"[{text}]", "stream": False } - + logger.info(f"translation start ********************* {text}") # 将负载转换为 JSON 格式 headers = {'Content-Type': 'application/json'} response = 
requests.post(url, data=json.dumps(payload), headers=headers) From 2e4c43178d74c3551e56082b28ed86de5d73b768 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 7 Feb 2025 15:04:59 +0800 Subject: [PATCH 52/52] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20=20?= =?UTF-8?q?=E8=AF=AD=E8=A8=80=E5=88=A4=E6=96=AD=E9=83=A8=E5=88=86=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=20=E5=85=B3=E5=81=9C=E8=AF=A5=E9=80=BB=E8=BE=91=20doc?= =?UTF-8?q?s=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refa?= =?UTF-8?q?ctor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/prompt_generation/chatgpt_for_translation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index e86a6bf..833eda0 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -95,10 +95,10 @@ def get_translation_from_llama3(text): # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" # 先获取用户输入文本的语言 - language = get_language(text) + # language = get_language(text) - if 'English' in language: - return text + # if 'English' in language: + # return text # 创建请求的负载 translator是自定义的翻译模型 payload = { @@ -106,7 +106,6 @@ def get_translation_from_llama3(text): "prompt": f"[{text}]", "stream": False } - logger.info(f"translation start ********************* {text}") # 将负载转换为 JSON 格式 headers = {'Content-Type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers)
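Across PATCH 42, 45, 48 and 50 the recurring change is PrintPainting.img_rotate: rotate about the image centre while enlarging the output canvas so no corners are clipped, with the (height, width) ordering of numpy's shape[:2] being the detail the patches flip back and forth. A minimal standalone sketch of that expand-and-rotate step, assuming OpenCV and NumPy (rotate_expand is an illustrative name, not the service API):

import cv2
import numpy as np

def rotate_expand(image, angle):
    # Rotate clockwise by `angle` degrees, growing the canvas so nothing is cropped.
    h, w = image.shape[:2]            # numpy order: (height, width)
    center = (w // 2, h // 2)
    # A negative angle gives a clockwise rotation in OpenCV's convention.
    M = cv2.getRotationMatrix2D(center, -angle, 1.0)
    cos, sin = np.abs(M[0, 0]), np.abs(M[0, 1])
    rotated_w = int(h * sin + w * cos)
    rotated_h = int(w * sin + h * cos)
    # Shift the affine transform so the rotated image stays centred in the larger canvas.
    M[0, 2] += (rotated_w - w) / 2
    M[1, 2] += (rotated_h - h) / 2
    return cv2.warpAffine(image, M, (rotated_w, rotated_h))

The bounding-box terms rotated_w and rotated_h are the same expressions the patches compute before calling cv2.warpAffine.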