import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from peft import PeftModel, PeftConfig
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    # Qwen2VLForConditionalGeneration,
    AutoProcessor,
    PreTrainedModel
)
from lmm_utils.Qwen.qwen2vl_lora_mlp.qwen2vl_modify_modeling_qwen2_vl import Qwen2VLForConditionalGeneration
import json


# Qwen2VLForConditionalGeneration
class LoRAWithMLP(nn.Module):
    def __init__(self, base_model_name, mlp_hidden_size=512, num_mlp_layers=2, device='cuda:0'):
        super().__init__()
        self.device = device
        # NOTE: the checkpoint path is hard-coded here, so base_model_name is currently unused.
        self.base_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "./lmm_utils/Qwen/Qwen2-VL-2B-Instruct/",
            device_map=device,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.base_model.enable_input_require_grads()  # Needed so inputs require grads when gradient checkpointing is enabled
        config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            inference_mode=False,  # Training mode
            r=64,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
        )
        # Wrap the base model with the LoRA adapters
        self.lora_model = get_peft_model(self.base_model, config)

        # MLP head stacked on top of the LoRA-adapted model's last hidden state
        mlp_layers = []
        input_dim = self.lora_model.config.hidden_size  # Inherit the base model's hidden_size
        for _ in range(num_mlp_layers):
            mlp_layers.append(nn.Linear(input_dim, mlp_hidden_size, dtype=torch.bfloat16))
            mlp_layers.append(nn.ReLU())
            input_dim = mlp_hidden_size
        mlp_layers.append(nn.Linear(mlp_hidden_size, 123, dtype=torch.bfloat16))  # Final projection; output dimension is hard-coded to 123
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                task_ids=None,
                **kwargs,):
        # Run the LoRA-adapted base model. Labels are deliberately not passed (loss is computed
        # outside), and extra **kwargs (e.g. pixel_values) are not forwarded to the base model.
        lora_output = self.lora_model(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      inputs_embeds=inputs_embeds,
                                      labels=None,
                                      output_attentions=output_attentions,
                                      output_hidden_states=output_hidden_states,
                                      return_dict=return_dict,
                                      task_ids=task_ids,
                                      )

        # Feed the hidden state at the last position into the MLP head
        # (assumes the modified Qwen2VL modeling returns hidden_states as a [batch, seq_len, hidden] tensor)
        mlp_output = self.mlp(lora_output.hidden_states[:, -1])

        return mlp_output  # Let the MLP head adjust the output of the LoRA model

    def save_checkpoint(self, path, epoch, optimizer, scheduler, best_valid_loss, avg_train_loss):
        # Keep only the LoRA adapter and MLP head weights; the frozen base model is not saved
        filtered_dict = {name: param for name, param in self.state_dict().items() if
                         'base_model' not in name or 'lora' in name}
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': filtered_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'best_valid_loss': best_valid_loss,
            'avg_train_loss': avg_train_loss,
            'scheduler_state_dict': scheduler.state_dict(),
        }
        torch.save(checkpoint_dict, path)

    def load_checkpoint(self, path, optimizer, scheduler, device):
        """
        Load a checkpoint and restore the model, optimizer and scheduler state
        :param path: checkpoint file path
        :param optimizer: the optimizer whose state should be restored
        :param scheduler: the LR scheduler whose state should be restored
        :param device: device to map the checkpoint onto
        :return: epoch reached in training, best validation loss
        """
        checkpoint = torch.load(path, map_location=device)  # Load the checkpoint
        self.load_state_dict(checkpoint['model_state_dict'], strict=False)  # Load the model parameters
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # Load the optimizer state
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])  # Load the scheduler state
        epoch = checkpoint.get('epoch', 0)  # Epoch reached
        best_valid_loss = checkpoint.get('best_valid_loss', float('inf'))  # Best validation loss
        avg_train_loss = checkpoint.get('avg_train_loss', float('inf'))  # Average training loss (not returned)

        return epoch, best_valid_loss

    def save_weights(self, path):
        """ Only LoRA + MLP weights are saved; the original Qwen2VL model is not included """
        # for name, param in self.state_dict().items():
        #     print(name, param.requires_grad)
        filtered_dict = {name: param for name, param in self.state_dict().items() if 'base_model' not in name or 'lora' in name}
        torch.save(filtered_dict, path)
        # torch.save(self.state_dict(), path)

    def load_weights(self, path):
        """ Load LoRA + MLP weights (the model must be initialized first) """
        state_dict = torch.load(path, map_location=self.device)
        self.load_state_dict(state_dict, strict=False)  # strict=False allows layers missing from the file
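

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original training pipeline): a minimal, hedged
# example of how LoRAWithMLP could be exercised. It assumes the local
# Qwen2-VL-2B-Instruct checkpoint and the modified modeling file referenced in
# __init__ are available, and that the 123-dim MLP output is trained against
# some task-specific target; the MSE loss, learning rate and file names below
# are illustrative assumptions, not the author's settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    run_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = LoRAWithMLP(base_model_name=None, device=run_device)  # base_model_name is unused (path is hard-coded)

    # Only the LoRA adapters and the MLP head require gradients, so the
    # optimizer is built over trainable parameters only.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )

    # Text-only toy batch built with the checkpoint's processor. Note that
    # forward() does not pass pixel_values through, so image inputs would not
    # reach the base model in this sketch anyway.
    processor = AutoProcessor.from_pretrained("./lmm_utils/Qwen/Qwen2-VL-2B-Instruct/")
    inputs = processor(text=["Describe the scene."], return_tensors="pt").to(run_device)

    pred = model(input_ids=inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 output_hidden_states=True,
                 return_dict=True)

    # Hypothetical 123-dim regression target and MSE loss; the real training
    # script may use a different objective.
    target = torch.zeros_like(pred, dtype=torch.float32)
    loss = nn.functional.mse_loss(pred.float(), target)
    loss.backward()
    optimizer.step()

    # Persist / restore only the LoRA + MLP weights (paths are illustrative).
    model.save_weights("lora_mlp_weights.pt")
    model.load_weights("lora_mlp_weights.pt")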