FLUX.2 [klein]

This commit is contained in:
timudk
2026-01-15 15:12:38 +01:00
parent ab7cca6801
commit b56ac61450
12 changed files with 530 additions and 119 deletions

View File

@@ -4,7 +4,13 @@ import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, pipeline
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
Mistral3ForConditionalGeneration,
pipeline,
)
from .sampling import cap_pixels, concatenate_images
from .system_messages import (
@@ -17,7 +23,8 @@ from .system_messages import (
SYSTEM_PROMPT_CONTENT_FILTER,
)
OUTPUT_LAYERS = [10, 20, 30]
OUTPUT_LAYERS_MISTRAL = [10, 20, 30]
OUTPUT_LAYERS_QWEN3 = [9, 18, 27]
MAX_LENGTH = 512
NSFW_THRESHOLD = 0.85
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
@@ -237,7 +244,7 @@ class Mistral3SmallEmbedder(nn.Module):
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_MISTRAL], dim=1)
return rearrange(out, "b c l d -> b l (c d)")
def yes_no_logit_processor(
@@ -354,3 +361,76 @@ class Mistral3SmallEmbedder(nn.Module):
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
class Qwen3Embedder(nn.Module):
    """Text embedder backed by a Qwen3 causal LM.

    Prompts are wrapped in the Qwen chat template (thinking disabled), padded
    to a fixed length, and run through the model once; hidden states from the
    layers in ``OUTPUT_LAYERS_QWEN3`` are concatenated along the feature axis.
    Unlike the Mistral embedder in this module, no filtering/upsampling
    helpers are supported.
    """

    def __init__(
        self,
        model_spec: str,
        device: str | torch.device = "cuda",
    ):
        super().__init__()
        # torch_dtype=None keeps the checkpoint's own dtype (library default).
        self.model = AutoModelForCausalLM.from_pretrained(
            model_spec,
            torch_dtype=None,
            device_map=str(device),
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_spec)
        self.max_length = MAX_LENGTH

    def _encode_prompt(self, prompt: str):
        """Chat-template and tokenize one prompt, padded/truncated to max_length."""
        rendered = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return self.tokenizer(
            rendered,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    @torch.no_grad()
    def forward(self, txt: list[str]):
        """Embed a batch of prompts.

        Returns a tensor of shape (batch, max_length, n_layers * hidden_dim),
        built from the hidden states of the layers in OUTPUT_LAYERS_QWEN3.
        """
        encodings = [self._encode_prompt(p) for p in txt]
        device = self.model.device
        ids = torch.cat([e["input_ids"] for e in encodings], dim=0).to(device)
        mask = torch.cat([e["attention_mask"] for e in encodings], dim=0).to(device)

        output = self.model(
            input_ids=ids,
            attention_mask=mask,
            output_hidden_states=True,
            use_cache=False,
        )
        stacked = torch.stack(
            [output.hidden_states[layer] for layer in OUTPUT_LAYERS_QWEN3], dim=1
        )
        # (batch, n_layers, seq, dim) -> (batch, seq, n_layers * dim)
        return rearrange(stacked, "b c l d -> b l (c d)")

    def test_txt(self, txt: str) -> bool:
        raise NotImplementedError("Qwen3Embedder does not support text testing")

    def test_image(self, image) -> bool:
        raise NotImplementedError("Qwen3Embedder does not support image testing")

    def upsample_prompt(self, txt: list[str], img=None, **kwargs) -> list[str]:
        raise NotImplementedError("Qwen3Embedder does not support upsampling")
def load_mistral_small_embedder(device: str | torch.device = "cuda") -> Mistral3SmallEmbedder:
    """Build a Mistral3SmallEmbedder and move it to *device* (default CUDA)."""
    embedder = Mistral3SmallEmbedder()
    return embedder.to(device)
def load_qwen3_embedder(variant: str, device: str | torch.device = "cuda"):
    """Build a Qwen3Embedder for the FP8 checkpoint of the given size *variant*.

    *variant* is interpolated into the Hub model id ``Qwen/Qwen3-{variant}-FP8``.
    """
    spec = f"Qwen/Qwen3-{variant}-FP8"
    return Qwen3Embedder(model_spec=spec, device=device)