from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Mistral3ForConditionalGeneration,
    pipeline,
)

from .sampling import cap_pixels, concatenate_images
from .system_messages import (
    PROMPT_IMAGE_INTEGRITY,
    PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
    PROMPT_TEXT_INTEGRITY,
    SYSTEM_MESSAGE,
    SYSTEM_MESSAGE_UPSAMPLING_I2I,
    SYSTEM_MESSAGE_UPSAMPLING_T2I,
    SYSTEM_PROMPT_CONTENT_FILTER,
)

OUTPUT_LAYERS_MISTRAL = [10, 20, 30]  # hidden-state layers stacked into the Mistral text embedding
OUTPUT_LAYERS_QWEN3 = [9, 18, 27]  # hidden-state layers stacked into the Qwen3 text embedding
MAX_LENGTH = 512  # padding/truncation length for embedding inputs
NSFW_THRESHOLD = 0.85  # score above which the NSFW classifier flags an image
UPSAMPLING_MAX_IMAGE_SIZE = 768**2  # pixel cap applied to images before prompt upsampling


class Mistral3SmallEmbedder(nn.Module):
    def __init__(
        self,
        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        torch_dtype: str = "bfloat16",
    ):
        super().__init__()

        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
            model_spec,
            torch_dtype=getattr(torch, torch_dtype),
        )
        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
        self.yes_token, self.no_token = self.processor.tokenizer.encode(
            ["yes", "no"], add_special_tokens=False
        )

        self.max_length = MAX_LENGTH
        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

        self.nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

    def _validate_and_process_images(
        self, img: list[list[Image.Image]] | list[Image.Image]
    ) -> list[list[Image.Image]]:
        # Simple validation: ensure it's a list of PIL images or a list of lists of PIL images.
        if not img:
            return []

        # Check whether it's a list of lists or a flat list of images.
        if isinstance(img[0], Image.Image):
            # It's a flat list of images; convert it to a list of lists.
            img = [[im] for im in img]

        # Potentially concatenate multiple images per element to reduce the size.
        img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]

        # Cap the number of pixels per image.
        img = [[cap_pixels(im, self.upsampling_max_image_size) for im in img_i] for img_i in img]
        return img

    def format_input(
        self,
        txt: list[str],
        system_message: str = SYSTEM_MESSAGE,
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
    ) -> list[list[dict]]:
        """
        Format a batch of text prompts into the conversation format expected by apply_chat_template.
        Optionally, add images to the input.

        Args:
            txt: List of text prompts.
            system_message: System message to use (default: SYSTEM_MESSAGE).
            img: List of images to add to the input.

        Returns:
            List of conversations, where each conversation is a list of message dicts
            (see the sketch after this method).
        """
        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
        # when truncation is enabled. The processor counts [IMG] tokens and fails
        # if the count changes after truncation.
        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

        if img is None or len(img) == 0:
            return [
                [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_message}],
                    },
                    {"role": "user", "content": [{"type": "text", "text": prompt}]},
                ]
                for prompt in cleaned_txt
            ]
        else:
            assert len(img) == len(txt), "Number of images must match number of prompts"
            img = self._validate_and_process_images(img)

            messages = [
                [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_message}],
                    },
                ]
                for _ in cleaned_txt
            ]

            for i, (el, images) in enumerate(zip(messages, img)):
                # Optionally add the images for this batch element.
                if images is not None:
                    el.append(
                        {
                            "role": "user",
                            "content": [{"type": "image", "image": image_obj} for image_obj in images],
                        }
                    )
                # Add the text.
                el.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": cleaned_txt[i]}],
                    }
                )

            return messages

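    # A minimal sketch of the conversation structure format_input builds for a single
    # text-only prompt with the default system message (the prompt string below is
    # illustrative only):
    #
    #     [
    #         [
    #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
    #             {"role": "user", "content": [{"type": "text", "text": "a red bicycle"}]},
    #         ]
    #     ]
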
    @torch.no_grad()
    def upsample_prompt(
        self,
        txt: list[str],
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
        temperature: float = 0.15,
    ) -> list[str]:
        """
        Upsample prompts using the model's generate method.

        Args:
            txt: List of input prompts to upsample.
            img: Optional list of images or list of lists of images. If None or all None,
                uses t2i mode; otherwise uses i2i mode (see the sketch after this method).
            temperature: Sampling temperature used during generation.

        Returns:
            List of upsampled prompts.
        """
        # Choose the system message based on whether images are provided.
        if img is None or len(img) == 0 or img[0] is None:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
        else:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I

        # Format input messages.
        messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)

        # Process all messages at once. With image inputs, a max_length that is too
        # short can raise an error here.
        try:
            inputs = self.processor.apply_chat_template(
                messages_batch,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=2048,
            )
        except ValueError as e:
            print(
                f"Error processing input: {e}. The max_length is probably too short when images are part of the input."
            )
            raise e

        # Move tensors to the model device.
        inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)

        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)

        # Generate text using the model's generate method.
        try:
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=temperature,
                use_cache=True,
            )

            # Decode only the newly generated tokens (skip the input tokens).
            input_length = inputs["input_ids"].shape[1]
            generated_tokens = generated_ids[:, input_length:]

            raw_txt = self.processor.tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            return raw_txt
        except Exception as e:
            print(f"Error generating upsampled prompts: {e}, returning original prompts")
            return txt

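    # Minimal i2i sketch, given an instance `embedder` of this class: pair each prompt
    # with a reference image before upsampling. The file name is illustrative and
    # assumes a local image exists:
    #
    #     ref = Image.open("reference.png")
    #     new_prompts = embedder.upsample_prompt(["make it look like autumn"], img=[ref])
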
    @torch.no_grad()
    def forward(self, txt: list[str]):
        # Format input messages.
        messages_batch = self.format_input(txt=txt)

        # Process all messages at once. With image inputs, a max_length that is too
        # short can raise an error here.
        inputs = self.processor.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Move tensors to the model device.
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # Forward pass through the model.
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # Stack the selected intermediate hidden states and concatenate them per token.
        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")

    def yes_no_logit_processor(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Constrain generation to "yes"/"no" by pushing the scores of all other tokens
        below the current minimum.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
        if isinstance(image, torch.Tensor):
            # Convert the first element of a batched [-1, 1] tensor to a PIL image.
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, (str, Path)):
            image = Image.open(image)

        # Run the fast NSFW classifier first.
        classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
        if classification["score"] > NSFW_THRESHOLD:
            return True

        # 512^2 pixels are enough for checking.
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        image = image.resize((int(f * w), int(f * h)))

        # Ask the VLM a constrained yes/no question about the image.
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )

        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token


class Qwen3Embedder(nn.Module):
    def __init__(
        self,
        model_spec: str,
        device: str | torch.device = "cuda",
    ):
        super().__init__()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_spec,
            torch_dtype=None,
            device_map=str(device),
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_spec)
        self.max_length = MAX_LENGTH

    @torch.no_grad()
    def forward(self, txt: list[str]):
        all_input_ids = []
        all_attention_masks = []

        for prompt in txt:
            messages = [{"role": "user", "content": prompt}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )

            model_inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )

            all_input_ids.append(model_inputs["input_ids"])
            all_attention_masks.append(model_inputs["attention_mask"])

        input_ids = torch.cat(all_input_ids, dim=0).to(self.model.device)
        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.model.device)

        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")

    def test_txt(self, txt: str) -> bool:
        raise NotImplementedError("Qwen3Embedder does not support text testing")

    def test_image(self, image) -> bool:
        raise NotImplementedError("Qwen3Embedder does not support image testing")

    def upsample_prompt(self, txt: list[str], img=None, **kwargs) -> list[str]:
        raise NotImplementedError("Qwen3Embedder does not support upsampling")


def load_mistral_small_embedder(device: str | torch.device = "cuda") -> Mistral3SmallEmbedder:
    return Mistral3SmallEmbedder().to(device)


def load_qwen3_embedder(variant: str, device: str | torch.device = "cuda"):
    return Qwen3Embedder(model_spec=f"Qwen/Qwen3-{variant}-FP8", device=device)
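

# Minimal usage sketch: load the Mistral embedder, upsample a prompt, and compute
# embeddings. Assumes a CUDA device with enough memory for the 24B model, and that this
# file is run as a module (python -m ...) so the relative imports above resolve; the
# prompt string is illustrative.
if __name__ == "__main__":
    embedder = load_mistral_small_embedder()

    prompts = ["a watercolor painting of a lighthouse at dusk"]

    # Rewrite/expand the prompts before embedding (t2i mode, since no images are given).
    upsampled = embedder.upsample_prompt(prompts)

    # Stacked hidden states from OUTPUT_LAYERS_MISTRAL: shape (batch, MAX_LENGTH, 3 * hidden_dim).
    embeddings = embedder(upsampled)
    print(embeddings.shape)

    # Constrained yes/no content check on the original prompt text.
    print(embedder.test_txt(prompts[0]))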