app/litserve_app/app.py (new executable file)
@@ -0,0 +1,110 @@
import torch
import litserve as ls
from PIL import Image
import io
import base64
from diffusers import Flux2KleinPipeline


# Keep the original helper function
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    w_str, h_str = aspect_ratio.split(":")
    w, h = float(w_str), float(h_str)
    if w >= h:
        width = base_long_edge
        height = int(round(base_long_edge * (h / w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (w / h)))
    # Snap both edges to multiples of 8, as Flux-style pipelines expect
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height
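
# Worked examples: aspect_to_wh("4:3", 1024) -> (1024, 768);
# aspect_to_wh("16:9", 1024) -> (1024, 576);
# aspect_to_wh("3:2", 1024) -> (1024, 680) (682.67 rounds to 683, then snaps down to 680).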


class FluxKleinAPI(ls.LitAPI):
    def setup(self, device):
        # 1. Model initialization
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16

        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)

        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        # 1. Get the base64 string
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")

        # 2. If a "data:image/png;base64," prefix is present, strip it
        if "," in image_data:
            image_data = image_data.split(",")[1]

        # 3. Decode and convert to a PIL object (the fix)
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024)
        }
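
    # An example request body this method accepts (base64 payload truncated):
    #   {"image": "iVBORw0KGgo...", "prompt": "caramel leather sofa",
    #    "steps": 28, "guidance": 4.0, "seed": 42,
    #    "aspect_ratio": "4:3", "base_long_edge": 1024}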

    @torch.inference_mode()
    def predict(self, payload):
        # 3. Run the inference logic
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])

        if img is not None:
            # Simplified resize logic; the original resize_pad_to_target
            # could be called here instead.
            img = img.convert("RGB").resize((W, H), Image.Resampling.LANCZOS)

        output = self.pipe(
            image=img,
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]
        return output

    def encode_response(self, output):
        # 4. Convert the PIL image back to base64 for the response
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}


if __name__ == "__main__":
    # Start the server
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)

app/litserve_app/cli.py (new executable file)
@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).

Prereqs:
    pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
    # and make sure CUDA is set up for torch.

Examples:
    # Single image
    python cli.py \
        --image test_sketch/sofa_2.png \
        --user_input "caramel brown leather, matte black legs" \
        --out outputs/sofa_2_to_real.png

    # Batch: one prompt applied to all images
    python cli.py \
        --image_dir test_sketch \
        --user_input "pure white background, photorealistic materials" \
        --out_dir outputs/batch

    # Batch: a prompt per image using a prompt file (mapping or list)
    python cli.py \
        --image_dir test_sketch \
        --prompt_file prompt.txt \
        --out_dir outputs/batch

    # Optional knobs
    python cli.py \
        --image test_sketch/a.png \
        --user_input "..." \
        --out outputs/a.png \
        --aspect_ratio 4:3 \
        --base_long_edge 1024 \
        --steps 28 \
        --guidance 4.0 \
        --seed 0 \
        --cpu_offload
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import torch
from PIL import Image, ImageOps
from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image

DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
DEFAULT_DEVICE = "cuda"
DEFAULT_TORCH_DTYPE = "bfloat16"

DEFAULT_SYSTEM_PROMPT = (
    "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
    "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
    "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
    "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
    "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
    "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)


# -----------------------------
# Prompt parsing (optional)
# -----------------------------
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """
    Supports:
      A) "sofa_2.png: caramel brown leather ..."  (filename: user_input)
      B) "caramel brown leather ..."              (user_input only; matched by sorted image filenames)
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")

    lines: List[str] = []
    for raw in prompt_txt.read_text(encoding="utf-8").splitlines():
        s = raw.strip()
        if not s or s.startswith("#"):
            continue
        lines.append(s)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []

    looks_like_mapping = any(
        (":" in ln and ln.split(":", 1)[0].strip().lower().endswith((".png", ".jpg", ".jpeg", ".webp")))
        for ln in lines
    )

    if looks_like_mapping:
        for ln in lines:
            if ":" not in ln:
                continue
            k, v = ln.split(":", 1)
            key = k.strip()
            val = v.strip()
            if key and val:
                mapping[key] = val
    else:
        list_only = lines

    return mapping, list_only
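
# An example prompt.txt in mapping form (filenames illustrative):
#   sofa_2.png: caramel brown leather, matte black legs
#   chair_1.png: light oak frame, cream boucle fabric
# and in list form (one prompt per line, matched to sorted image filenames):
#   caramel brown leather, matte black legs
#   light oak frame, cream boucle fabric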


def list_images(image_dir: Path) -> List[Path]:
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    imgs = [p for p in sorted(image_dir.iterdir()) if p.is_file() and p.suffix.lower() in exts]
    if not imgs:
        raise RuntimeError(f"No images found in: {image_dir}")
    return imgs


def build_final_prompt(system_prompt: str, user_input: str) -> str:
    user_input = (user_input or "").strip()
    if not user_input:
        return system_prompt
    return f"{system_prompt} User requirements: {user_input}"


# -----------------------------
# Aspect ratio utils (resize+pad)
# -----------------------------
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    w_str, h_str = aspect_ratio.split(":")
    w, h = float(w_str), float(h_str)
    if w >= h:
        width = base_long_edge
        height = int(round(base_long_edge * (h / w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (w / h)))

    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height


def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    img = img.convert("RGB")
    contained = ImageOps.contain(img, (width, height), method=Image.Resampling.LANCZOS)
    pad_w = width - contained.size[0]
    pad_h = height - contained.size[1]
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    return ImageOps.expand(contained, border=padding, fill=fill)
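
# For instance, a 600x800 input fitted to a 1024x768 target is scaled by
# ImageOps.contain to 576x768, then padded with 224 px of white on each side.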


def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    img = img.convert("RGB")
    w, h = img.size
    if max(w, h) == base_long_edge:
        return img
    if w >= h:
        nw, nh = base_long_edge, int(round(base_long_edge * h / w))
    else:
        nw, nh = int(round(base_long_edge * w / h)), base_long_edge
    nw = max(64, (nw // 8) * 8)
    nh = max(64, (nh // 8) * 8)
    return img.resize((nw, nh), Image.Resampling.LANCZOS)


# -----------------------------
# Run one
# -----------------------------
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    ref = load_image(str(image_path))  # PIL.Image

    if aspect_ratio:
        W, H = aspect_to_wh(aspect_ratio, base_long_edge)
        ref = resize_pad_to_target(ref, W, H)
    else:
        ref = resize_to_long_edge(ref, base_long_edge)
        W, H = ref.size

    gen = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        image=ref,
        prompt=final_prompt,
        height=H,
        width=W,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=gen,
    ).images[0]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(out_path)
    print(f"[OK] saved: {out_path}")


def parse_args() -> argparse.Namespace:
    ap = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")

    io = ap.add_argument_group("I/O")
    io.add_argument("--image", type=str, default="", help="Single input image path.")
    io.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    io.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    io.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")

    prompt = ap.add_argument_group("Prompt")
    prompt.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    prompt.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    prompt.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")

    model = ap.add_argument_group("Model/Runtime")
    model.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    model.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    model.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    model.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")

    gen = ap.add_argument_group("Generation")
    gen.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    gen.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    gen.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    gen.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    gen.add_argument("--seed", type=int, default=0, help="Random seed.")

    args = ap.parse_args()

    # Validate mode: exactly one of --image / --image_dir
    if bool(args.image) == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")

    if args.image:
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
    else:
        if not args.out_dir:
            raise SystemExit("In batch mode, --out_dir is required.")
        if not args.prompt_file and not args.user_input:
            raise SystemExit("Batch mode requires either --prompt_file or --user_input.")

    return args


def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32


def main() -> None:
    args = parse_args()

    # Normalize aspect_ratio
    aspect_ratio = (args.aspect_ratio or "").strip()
    if aspect_ratio == "":
        aspect_ratio = None

    torch_dtype = resolve_torch_dtype(args.torch_dtype)

    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(args.device)

    system_prompt = args.system_prompt

    if args.image:
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        out_path = Path(args.out)

        final_prompt = build_final_prompt(system_prompt, args.user_input)
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=img_path,
            final_prompt=final_prompt,
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )
        return

    # Batch mode
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")

    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []

    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))

    if mapping:
        # Mapping mode: filename -> prompt
        for img_path in images:
            user_input = mapping.get(img_path.name, "")
            final_prompt = build_final_prompt(system_prompt, user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )
    elif list_only:
        # List mode: one prompt per line, matched by sorted filenames
        n = min(len(list_only), len(images))
        if len(list_only) != len(images):
            print(
                f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
                "Will match by min length and leave the rest empty."
            )
        for i, img_path in enumerate(images):
            user_input = list_only[i] if i < n else ""
            final_prompt = build_final_prompt(system_prompt, user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )
    else:
        # No usable prompt file (or none given): apply --user_input to all images
        for img_path in images:
            final_prompt = build_final_prompt(system_prompt, args.user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )


if __name__ == "__main__":
    main()

app/litserve_app/client.py (new executable file)
@@ -0,0 +1,16 @@
import asyncio
import base64

import httpx


async def main():
    # The server's decode_request requires a base64 "image" field;
    # "input.png" here is just an example path.
    with open("input.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Diffusion can take a while, so disable httpx's default 5 s timeout.
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.post(
            "http://localhost:8451/predict",
            json={
                "image": image_b64,
                "prompt": "purple solid-wood curtains",
            },
        )
        print(response.json())


asyncio.run(main())

app/litserve_app/model.py (new executable file)
@@ -0,0 +1,84 @@
# model.py
import torch
# Assuming Flux from diffusers; the img2img variant accepts a reference image,
# while plain FluxPipeline is text-to-image only.
from diffusers import FluxImg2ImgPipeline
from PIL import Image
import os
from typing import List


class Flux2KleinModel:
    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16
        )
        self.pipe.enable_model_cpu_offload()  # or .to(device), depending on available VRAM
        # self.pipe.vae.enable_slicing()  # optional, saves VRAM
        self.device = device

    def preprocess_image(self, image_path: str) -> Image.Image:
        img = Image.open(image_path).convert("RGB")
        return img

    def run_inference(
        self,
        image: Image.Image,
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs
    ) -> Image.Image:
        # Compute width/height from aspect_ratio (example implementation; adjust as needed)
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            ratio = w / h
        else:
            ratio = 1.0

        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)

        # Round down to multiples of 8 (a common Flux requirement)
        width = (width // 8) * 8
        height = (height // 8) * 8

        generator = torch.Generator(device=self.device).manual_seed(seed)

        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs
        ).images[0]

        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs
    ):
        assert len(image_paths) == len(prompts) == len(out_paths)

        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # Guard against a bare filename, where dirname would be ""
                out_parent = os.path.dirname(out_path)
                if out_parent:
                    os.makedirs(out_parent, exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                print(f"Failed {img_path}: {e}")