design2garmentcode-impl/lmm_utils/fintuned_qwen2vl_model.py

import torch
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model

# Qwen2VLForConditionalGeneration comes from a locally modified copy of the
# transformers modeling file, not from transformers itself.
from lmm_utils.Qwen.qwen2vl_lora_mlp.qwen2vl_modify_modeling_qwen2_vl import Qwen2VLForConditionalGeneration

class LoRAWithMLP(nn.Module):
    # NOTE: base_model_name is currently unused; the checkpoint path is hard-coded below.
    def __init__(self, base_model_name, mlp_hidden_size=512, num_mlp_layers=2, device='cuda:0'):
        super().__init__()
        self.device = device
        self.base_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "./lmm_utils/Qwen/Qwen2-VL-2B-Instruct/",
            device_map=device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Required when gradient checkpointing is enabled: inputs must keep
        # requires_grad so gradients can flow back into the LoRA adapters.
        self.base_model.enable_input_require_grads()
        config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # Attention projections only; add "gate_proj", "up_proj", "down_proj"
            # to also adapt the feed-forward blocks.
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            inference_mode=False,  # training mode
            r=64,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
        )
        # Wrap the base model with LoRA adapters; all other weights stay frozen
        self.lora_model = get_peft_model(self.base_model, config)
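        # Optional sanity check (not in the original code): PEFT can report how
        # many parameters remain trainable after wrapping; only the LoRA
        # matrices should appear here.
        # self.lora_model.print_trainable_parameters()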
        # Prediction head: num_mlp_layers hidden layers followed by a
        # projection to the 123-dimensional target.
        mlp_layers = []
        input_dim = self.lora_model.config.hidden_size  # match the base model's hidden size
        for _ in range(num_mlp_layers):
            mlp_layers.append(nn.Linear(input_dim, mlp_hidden_size, dtype=torch.bfloat16))
            mlp_layers.append(nn.ReLU())
            input_dim = mlp_hidden_size
        mlp_layers.append(nn.Linear(mlp_hidden_size, 123, dtype=torch.bfloat16))  # output size = 123, the task-specific target dimension
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task_ids=None,
                **kwargs):
        # Run the LoRA-adapted base model. The MLP head consumes the hidden
        # states, so they are requested unconditionally; labels are
        # intentionally not forwarded (the LM loss is unused here).
        lora_output = self.lora_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            task_ids=task_ids,
        )
        # hidden_states follows the Hugging Face convention of one tensor per
        # layer; take the last layer, then the final token position.
        mlp_output = self.mlp(lora_output.hidden_states[-1][:, -1])
        return mlp_output  # the MLP head maps the LoRA-adapted features to the target
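
    # Note: only the final-token hidden state feeds the head, so forward()
    # yields one prediction per sequence, of shape [batch_size, 123].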

    def save_checkpoint(self, path, epoch, optimizer, scheduler, best_valid_loss, avg_train_loss):
        # Keep only the MLP head and the LoRA adapter weights; the frozen
        # Qwen2-VL backbone is excluded to keep the checkpoint small.
        filtered_dict = {name: param for name, param in self.state_dict().items()
                         if 'base_model' not in name or 'lora' in name}
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': filtered_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'best_valid_loss': best_valid_loss,
            'avg_train_loss': avg_train_loss,
            'scheduler_state_dict': scheduler.state_dict(),
        }
        torch.save(checkpoint_dict, path)

    def load_checkpoint(self, path, optimizer, scheduler, device):
        """
        Load a checkpoint and restore the model, optimizer, and scheduler state.
        :param path: checkpoint file path
        :param optimizer: the optimizer to restore
        :param scheduler: the LR scheduler to restore
        :param device: device to map the checkpoint onto
        :return: epoch of training, best validation loss, average training loss
        """
        checkpoint = torch.load(path, map_location=device)  # load the checkpoint
        self.load_state_dict(checkpoint['model_state_dict'], strict=False)  # restore LoRA + MLP parameters
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore optimizer state
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch = checkpoint.get('epoch', 0)
        best_valid_loss = checkpoint.get('best_valid_loss', float('inf'))
        avg_train_loss = checkpoint.get('avg_train_loss', float('inf'))
        return epoch, best_valid_loss, avg_train_loss
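
    # Typical resume pattern (illustrative, not part of the original module;
    # `model`, `optimizer`, and `scheduler` must already be constructed):
    #     start_epoch, best_valid_loss, _ = model.load_checkpoint(
    #         path, optimizer, scheduler, device)
    #     for epoch in range(start_epoch + 1, num_epochs):
    #         ...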

    def save_weights(self, path):
        """Save only the LoRA + MLP weights; the frozen Qwen2-VL backbone is not included."""
        filtered_dict = {name: param for name, param in self.state_dict().items()
                         if 'base_model' not in name or 'lora' in name}
        torch.save(filtered_dict, path)

    def load_weights(self, path):
        """Load LoRA + MLP weights (the model must be initialized first)."""
        state_dict = torch.load(path, map_location=self.device)
        # strict=False: entries missing from the filtered checkpoint (the
        # frozen backbone) are simply left at their initialized values.
        self.load_state_dict(state_dict, strict=False)
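

# A minimal usage sketch, not part of the original module: it assumes the
# local Qwen2-VL weights exist at the path hard-coded in __init__, and the
# optimizer, scheduler, and file paths below are illustrative placeholders.
if __name__ == "__main__":
    model = LoRAWithMLP(base_model_name=None, device='cuda:0')  # base_model_name is unused by __init__
    # Optimize only the trainable (LoRA + MLP) parameters.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Persist and reload only the LoRA + MLP weights (hypothetical path).
    model.save_weights("lora_mlp_weights.pt")
    model.load_weights("lora_mlp_weights.pt")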