import io

import torch

import litserve as ls
from diffusers import Flux2KleinPipeline
from minio import Minio

from app.config.config import settings
from app.server.utils.minio_client import oss_get_image, oss_upload_image

# Shared object-storage client, configured once from application settings.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class Flux2KleinServer(ls.LitAPI):
    """LitServe API wrapper around the FLUX.2-klein image-to-image pipeline.

    Per request: fetch the input image from MinIO, run the diffusion
    pipeline, upload the generated PNG back to MinIO, and return the
    object path of the result.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline onto ``device`` once at startup.

        Args:
            device: Target device string supplied by LitServe
                (e.g. ``"cuda"`` or ``"cpu"``).
        """
        # bfloat16 weights halve memory versus float32.
        dtype = torch.bfloat16
        self.device = device
        self.model = Flux2KleinPipeline.from_pretrained(
            "black-forest-labs/FLUX.2-klein-4B",
            torch_dtype=dtype,
            is_distilled=False,
        )
        # Keep the whole pipeline resident on the serving device.
        # (The original comment claimed CPU offload, but .to(device)
        # moves the model to the accelerator — no offloading happens.)
        self.model.to(device)

    def decode_request(self, request):
        """Pass the parsed request payload straight through to ``predict``."""
        return request

    def predict(self, request_data):
        """Generate one image from a prompt + reference image.

        Args:
            request_data: Dict with keys ``image_path``, ``prompt``,
                ``height`` (default 768), ``width`` (default 512),
                ``infer_step`` (default 4) and ``tasks_id``.

        Returns:
            str: ``"aida-users/<user_id>/product_image/<tasks_id>.png"``,
            the bucket-qualified object path of the uploaded result.
        """
        image_path = request_data.get("image_path", "")
        prompt = request_data.get("prompt", "")
        height = request_data.get("height", 768)
        width = request_data.get("width", 512)
        infer_step = request_data.get("infer_step", 4)
        tasks_id = request_data.get("tasks_id", "test")

        # tasks_id is assumed to end with "-<user_id>". If no '-' is
        # present, rfind returns -1 and the whole tasks_id is used.
        user_id = tasks_id[tasks_id.rfind('-') + 1:]

        input_image = oss_get_image(oss_client=minio_client, path=image_path, data_type='pil')

        with torch.no_grad():
            images = self.model(
                image=input_image,
                prompt=prompt,
                height=height,
                width=width,
                guidance_scale=1.0,
                num_inference_steps=infer_step,
                # generator=torch.Generator(device='cuda').manual_seed(3)
            )[0]

        # Encode the first generated image and store it in MinIO.
        image_bytes = self._to_png_bytes(images[0])
        object_name = f'{user_id}/product_image/{tasks_id}.png'
        oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
        return f"aida-users/{object_name}"

    @staticmethod
    def _to_png_bytes(image):
        """Encode a PIL image as PNG and return the raw bytes."""
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        # getvalue() returns the full buffer without seek/read bookkeeping.
        return buffer.getvalue()

    def encode_response(self, image_url):
        """Wrap the generated object path in a JSON-serializable dict."""
        return {"image_url": image_url}
|
# Entry point: start the LitServe HTTP server.
if __name__ == "__main__":
    # LitServer selects an appropriate accelerator (e.g. 'cuda' or 'cpu')
    # and passes it to Flux2KleinServer.setup().
    server = ls.LitServer(Flux2KleinServer())
    server.run(port=8011)