Files
design2garmentcode-impl/lmm_utils/fintuned_qwen2vl_model.py

117 lines
5.4 KiB
Python
Raw Permalink Normal View History

2025-07-03 17:03:00 +08:00
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from peft import PeftModel, PeftConfig
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
# Qwen2VLForConditionalGeneration,
AutoProcessor,
PreTrainedModel
)
from lmm_utils.Qwen.qwen2vl_lora_mlp.qwen2vl_modify_modeling_qwen2_vl import Qwen2VLForConditionalGeneration
import json
# Qwen2VLForConditionalGeneration
class LoRAWithMLP(nn.Module):
    """Qwen2-VL backbone with LoRA adapters and an MLP head on the last position.

    The backbone is the project's modified ``Qwen2VLForConditionalGeneration``,
    loaded in bfloat16 and wrapped with PEFT LoRA adapters on the attention
    projections (``q_proj``/``k_proj``/``v_proj``/``o_proj``). ``forward``
    feeds the hidden state at the last sequence position through an MLP that
    returns ``output_dim`` values per sample.
    """

    # Qwen2-VL vision inputs. Callers pass them through ``**kwargs`` of
    # ``forward``; they must reach the backbone, otherwise only the
    # image/video placeholder tokens are embedded and the pixels are ignored.
    _VISION_INPUT_KEYS = ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw")

    def __init__(self, base_model_name, mlp_hidden_size=512, num_mlp_layers=2, device='cuda:0',
                 model_path="./lmm_utils/Qwen/Qwen2-VL-2B-Instruct/", output_dim=123):
        """Load the backbone, attach LoRA adapters and build the MLP head.

        Args:
            base_model_name: Not used; the backbone is loaded from ``model_path``.
                Kept so existing callers keep working.
            mlp_hidden_size: Width of the hidden MLP layers.
            num_mlp_layers: Number of ``Linear -> ReLU`` hidden layers (may be 0).
            device: Passed as ``device_map`` to ``from_pretrained`` and used as
                ``map_location`` by ``load_weights``. The MLP head is created on
                the default torch device (no ``device`` is given to its layers),
                so move the module afterwards if they differ.
            model_path: Directory of the Qwen2-VL checkpoint to load.
            output_dim: Number of values produced by the MLP head.
        """
        super().__init__()
        self.device = device
        self.base_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Makes the input-embedding outputs require grad so adapter weights can
        # be trained with the backbone frozen; needed with gradient checkpointing.
        self.base_model.enable_input_require_grads()
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            inference_mode=False,  # adapters are trainable
            r=64,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
        )
        self.lora_model = get_peft_model(self.base_model, lora_config)
        # The head consumes vectors of the backbone's hidden size.
        self.mlp = self._build_mlp(
            self.lora_model.config.hidden_size, mlp_hidden_size, num_mlp_layers, output_dim
        )

    @staticmethod
    def _build_mlp(input_dim, hidden_size, num_layers, output_dim, dtype=torch.bfloat16):
        """Build ``num_layers`` x (Linear -> ReLU) followed by a Linear to ``output_dim``.

        Module indices match the original layout (``0, 2, 4, ...`` are Linear),
        so previously saved ``mlp.*`` weights still load.
        """
        layers = []
        in_features = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, hidden_size, dtype=dtype))
            layers.append(nn.ReLU())
            in_features = hidden_size
        # Use the running width so that num_layers == 0 maps input_dim directly.
        layers.append(nn.Linear(in_features, output_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def forward(self, input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task_ids=None,
                **kwargs,):
        """Run the LoRA-adapted backbone and apply the MLP head to the last position.

        ``labels`` is accepted for interface compatibility but not forwarded
        (``labels=None`` is passed to the backbone), so no LM loss is computed
        here. Of ``**kwargs``, only the Qwen2-VL vision inputs listed in
        ``_VISION_INPUT_KEYS`` are forwarded; other extra keys are ignored.

        Returns:
            The MLP output for the last sequence position of every sample.
        """
        vision_inputs = {key: kwargs[key] for key in self._VISION_INPUT_KEYS if key in kwargs}
        lora_output = self.lora_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            task_ids=task_ids,
            **vision_inputs,
        )
        # NOTE(review): ``hidden_states`` is indexed as one tensor, which assumes the
        # modified Qwen2-VL model returns the final hidden states as
        # (batch, seq, hidden) rather than HF's per-layer tuple — confirm.
        # Position -1 is the last real token only if padding is on the left — confirm.
        last_hidden = lora_output.hidden_states[:, -1]
        return self.mlp(last_hidden)

    def _trainable_state_dict(self):
        """Return the state-dict entries worth persisting: the MLP head and LoRA weights.

        Entries whose key contains ``base_model`` belong to the backbone
        (``self.base_model`` and, presumably, the PEFT wrapper's own
        ``base_model`` submodule) and are dropped unless they are LoRA
        parameters; the backbone is reloaded from disk on construction.
        """
        return {name: tensor for name, tensor in self.state_dict().items()
                if 'base_model' not in name or 'lora' in name}

    def save_checkpoint(self, path, epoch, optimizer, scheduler, best_valid_loss, avg_train_loss):
        """Save LoRA + MLP weights together with optimizer/scheduler state and metrics.

        Args:
            path: Destination file for ``torch.save``.
            epoch: Current epoch number.
            optimizer: Optimizer whose ``state_dict`` is stored.
            scheduler: LR scheduler whose ``state_dict`` is stored.
            best_valid_loss: Best validation loss so far.
            avg_train_loss: Average training loss of the epoch.
        """
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': self._trainable_state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_valid_loss': best_valid_loss,
            'avg_train_loss': avg_train_loss,
            "scheduler_state_dict": scheduler.state_dict(),
        }
        torch.save(checkpoint_dict, path)

    def load_checkpoint(self, path, optimizer, scheduler, device):
        """Restore model, optimizer and scheduler state from ``save_checkpoint`` output.

        Model weights are loaded with ``strict=False`` because the checkpoint only
        contains the LoRA + MLP entries.

        Args:
            path: Checkpoint file path.
            optimizer: Optimizer whose state is restored in place.
            scheduler: LR scheduler whose state is restored in place.
            device: ``map_location`` for ``torch.load``.

        Returns:
            ``(epoch, best_valid_loss)``; missing entries default to ``0`` and
            ``float('inf')``. The stored ``avg_train_loss`` is not returned.

        Raises:
            KeyError: if the checkpoint lacks model, optimizer or scheduler state.
        """
        checkpoint = torch.load(path, map_location=device)
        self.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch = checkpoint.get('epoch', 0)
        best_valid_loss = checkpoint.get('best_valid_loss', float('inf'))
        return epoch, best_valid_loss

    def save_weights(self, path):
        """Save only the LoRA + MLP weights (the original Qwen2-VL weights are excluded)."""
        torch.save(self._trainable_state_dict(), path)

    def load_weights(self, path):
        """Load LoRA + MLP weights into an already constructed model.

        Uses ``self.device`` as ``map_location``; ``strict=False`` because the
        file deliberately omits the backbone weights.
        """
        state_dict = torch.load(path, map_location=self.device)
        self.load_state_dict(state_dict, strict=False)