diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 1bc1c91..dac211c 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -40,16 +40,17 @@ class GenerateImage: if request_data.mode == "img2img": # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) - self.prompt = request_data.prompt else: self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - self.prompt = request_data.prompt + self.prompt = request_data.prompt self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.mode = request_data.mode self.batch_size = 1 self.category = request_data.category + if self.category == "sketch": + self.prompt = f"{self.category},{self.prompt}" self.index = 0 self.gender = request_data.gender self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}