Fine-tuning Qwen2.5-VL for Medical Image Analysis

Vision-Language Models (VLMs) have emerged as powerful tools for understanding and generating content that involves both images and text. While these models demonstrate impressive capabilities on general-domain tasks, specialized fields like medical imaging require domain-specific adaptation to achieve optimal performance.
In this comprehensive guide, we'll walk through the process of fine-tuning the Qwen2.5-VL (Vision-Language) model for medical imaging analysis. Qwen2.5-VL is a powerful multimodal model that can process both images and text, making it ideal for medical applications where visual analysis and textual explanations are equally important.
By the end of this guide, you'll have a comprehensive understanding of how to adapt large vision-language models for specialized medical imaging tasks, even with limited computational resources.
Before diving into fine-tuning, let's explore the architecture of the Qwen2.5-VL model and why it's suitable for medical imaging tasks.
Qwen2.5-VL is a multimodal model from the Qwen family developed by Alibaba Cloud. The 7B version we're using contains approximately 7 billion parameters and combines three main components: a vision transformer encoder that processes images, a lightweight merger that projects visual features into the language model's embedding space, and the Qwen2.5 language decoder that generates text.
Qwen2.5-VL offers several advantages for medical imaging applications: strong fine-grained visual understanding, robust instruction following for question answering over images, support for flexible input resolutions, and openly available weights that can be fine-tuned and deployed locally.
Why Qwen2.5-VL for Medical Imaging? While there are dedicated medical imaging models, adapting general-purpose VLMs like Qwen2.5-VL offers the advantage of retaining broader world knowledge while gaining domain-specific capabilities. This enables more comprehensive analysis and explanation of medical images in context.
The first step in our fine-tuning journey is to set up the proper environment with all necessary dependencies. This includes installing the latest versions of transformer libraries, quantization tools, and utilities specific to the Qwen2.5-VL model.
!pip install -U -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git datasets bitsandbytes peft qwen-vl-utils wandb accelerate
# Tested with transformers==4.47.0.dev0, trl==0.12.0.dev0, datasets==3.0.2, bitsandbytes==0.44.1, peft==0.13.2, qwen-vl-utils==0.0.8, wandb==0.18.5, accelerate==1.0.1
!pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
from huggingface_hub import notebook_login
notebook_login()
Each component in our installation serves a specific purpose:
- transformers and trl (installed from source for the latest Qwen2.5-VL and supervised fine-tuning support)
- datasets for loading and manipulating training data
- bitsandbytes for 4-bit quantization
- peft for parameter-efficient fine-tuning with LoRA
- qwen-vl-utils for preparing image inputs for Qwen VL models
- wandb for experiment tracking
- accelerate for device placement and mixed-precision training
We pin PyTorch 2.4.1 with CUDA 12.1 support to ensure compatibility with modern GPUs and with the library versions tested above.
To fine-tune our model for medical applications, we need appropriate medical imaging datasets. In this section, we'll download and prepare a specialized dataset containing medical images paired with reports.
data_url="https://www.kaggle.com/datasets/ahmedsta/qa-vlm-med"
import os
os.makedirs('/root/.kaggle', exist_ok=True) # Make the .kaggle directory
os.rename('kaggle.json', '/root/.kaggle/kaggle.json') # Move the file to the correct path
os.chmod('/root/.kaggle/kaggle.json', 0o600) # Set the correct file permissions (note the octal literal)
!kaggle datasets download ahmedsta/report-ray
!unzip report-ray.zip
Medical images often come in various sizes and formats. To ensure consistency, we resize all images to a standard dimension (320x320 pixels) that's appropriate for our model's input requirements:
import json
from tqdm import tqdm
from PIL import Image
with open("/content/Report_json.json", "r", encoding="utf-8") as file:
    med = json.load(file)

def resize(image_path):
    try:
        if os.path.exists(image_path):
            image = Image.open(image_path)
            resized_image = image.resize((320, 320))
            resized_image.save(image_path)
    except Exception as e:
        print(f"Error processing image: {e}")

med_folder = "/content/Report_Med/Report_Med"
for e in tqdm(med):
    for i in range(len(e[1]['content']) - 1):
        image = f"{med_folder}/{e[1]['content'][i]['image']}"
        resize(image)
        e[1]['content'][i]['image'] = image

print(f"length of med: {len(med)}")
We divide our dataset into training and evaluation sets to properly measure our model's performance:
import random

# Shuffle once, then take non-overlapping 90/10 and 80/20 splits
random.shuffle(med)
split = int(len(med) * 0.9)
Data_model = med[:split]   # 90% for modeling
Data_test = med[split:]    # 10% held out for final testing

# Creating train/eval split
train_split = int(len(Data_model) * 0.8)
train_data = Data_model[:train_split]   # 80% of Data_model for training
eval_data = Data_model[train_split:]    # 20% of Data_model for evaluation
Dataset Structure: The medical dataset consists of radiology images paired with expert reports. Each example contains an image path and corresponding medical interpretation or analysis. This paired data is ideal for teaching the model to associate visual patterns in medical images with appropriate clinical descriptions.
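To make the structure concrete, here is a sketch of what one training sample looks like in the Qwen chat-message format that the preprocessing loop above assumes (index 1 holds the user turn with the image entries). The report text and file name are invented placeholders, not actual dataset contents:

```python
# Illustrative sample in Qwen2.5-VL message format (contents are placeholders)
sample = [
    {"role": "system",
     "content": [{"type": "text", "text": "You are a radiology assistant."}]},
    {"role": "user",
     "content": [
         {"type": "image", "image": "/content/Report_Med/Report_Med/img_0001.png"},
         {"type": "text", "text": "Describe the findings in this chest X-ray."},
     ]},
    {"role": "assistant",
     "content": [{"type": "text", "text": "The lungs are clear. No acute abnormality."}]},
]

# The resize loop above walks sample[1]['content'] to find and rewrite image paths
image_entries = [c for c in sample[1]["content"] if c.get("type") == "image"]
```

The final text entry in the user turn is skipped by the loop's `len(...) - 1` bound, since only image entries need their paths rewritten.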
Before proceeding with fine-tuning, it's important to load and test the base model to understand its initial capabilities on medical images. This gives us a baseline for comparison after fine-tuning.
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the model on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)
To test the model, we need a function that can process our medical images and generate descriptive text:
from qwen_vl_utils import process_vision_info
def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        sample[1:2],  # Use the sample without the system message
        tokenize=False,
        add_generation_prompt=True,
    )

    # Process the visual input from the sample
    image_inputs, _ = process_vision_info(sample)

    # Prepare the inputs for the model
    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors="pt",
    ).to(device)  # Move inputs to the specified device

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text[0]  # Return the first decoded output text
Working with large multimodal models requires careful memory management, especially when fine-tuning. We implement a utility function to clear memory between different stages:
import gc
import time
import torch

def clear_memory():
    # Delete variables if they exist in the current global scope
    for name in ('inputs', 'model', 'processor', 'trainer', 'peft_model', 'bnb_config'):
        if name in globals():
            del globals()[name]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

clear_memory()
GPU Memory Challenge: The Qwen2.5-VL-7B model with 7 billion parameters requires significant GPU memory, especially when handling both images and text. Without quantization, the weights alone need roughly 28GB of VRAM in full precision (FP32) or around 14GB in half precision (BF16), before accounting for activations, gradients, and optimizer state. The memory management strategy is crucial for successfully fine-tuning on consumer-grade hardware.
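The arithmetic behind these numbers is a simple back-of-the-envelope estimate (weights only, ignoring activations and optimizer state):

```python
def weight_memory_gib(n_params: float, bytes_per_param: float) -> float:
    """Estimate the memory needed to store model weights, in GiB."""
    return n_params * bytes_per_param / 1024**3

N = 7e9  # ~7 billion parameters
print(f"FP32: {weight_memory_gib(N, 4):.1f} GiB")    # ~26.1 GiB
print(f"BF16: {weight_memory_gib(N, 2):.1f} GiB")    # ~13.0 GiB
print(f"INT4: {weight_memory_gib(N, 0.5):.1f} GiB")  # ~3.3 GiB
```

Real-world usage is higher than the INT4 figure because activations, the KV cache, and quantization metadata also consume VRAM.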
To fine-tune our large vision-language model efficiently, we need to implement memory optimization techniques like quantization and parameter-efficient fine-tuning methods.
import torch
from transformers import BitsAndBytesConfig
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
# BitsAndBytesConfig int-4 config
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(model_id)
Quantization reduces the precision of model weights from 16- or 32-bit floating point down to just 4 bits, significantly decreasing memory requirements. Our quantization configuration uses:
- load_in_4bit=True to store the base model weights in 4-bit precision
- bnb_4bit_quant_type="nf4" for the NormalFloat4 data type, designed for normally distributed weights
- bnb_4bit_use_double_quant=True to also quantize the quantization constants for extra savings
- bnb_4bit_compute_dtype=torch.bfloat16 so matrix multiplications run in BF16
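To build intuition for what 4-bit quantization does, here is a toy sketch of the simpler symmetric absmax scheme: weights are rescaled so the largest magnitude maps to the int4 range [-8, 7], stored as integers, and dequantized at compute time. (NF4 refines this idea with levels spaced for normally distributed weights rather than evenly; this demo is illustrative, not bitsandbytes' actual implementation.)

```python
def quantize_int4_absmax(weights):
    # Symmetric absmax quantization: map weights into 4-bit integers [-8, 7]
    scale = max(abs(w) for w in weights) / 7
    q = [max(-8, min(7, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    # Recover approximate weights from the stored integers
    return [v * scale for v in q]

w = [0.12, -0.5, 0.33, 0.07]
q, s = quantize_int4_absmax(w)
w_hat = dequantize(q, s)
print(q)      # [2, -7, 5, 1]
print(w_hat)  # approximate reconstruction of w
```

Each weight now occupies 4 bits instead of 16 or 32, at the cost of a small reconstruction error visible in `w_hat`.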
Memory Reduction: 4-bit quantization reduces the weight footprint of the Qwen2.5-VL-7B model from approximately 14GB in BF16 (about 28GB in FP32) to around 4GB, making it possible to fine-tune on GPUs with limited VRAM. This is particularly important for vision-language models that need to process both large images and text sequences simultaneously.
Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning technique that adds small trainable matrices to the model while keeping most weights frozen. When combined with quantization, this approach is known as QLoRA:
from peft import LoraConfig, get_peft_model

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()
Our LoRA configuration is optimized for vision-language models with the following parameters:
- r=8: the rank of the low-rank update matrices; higher ranks add capacity at the cost of more trainable parameters
- lora_alpha=16: the scaling factor applied to the LoRA updates (effective scale alpha/r = 2)
- lora_dropout=0.05: light dropout on the LoRA layers for regularization
- target_modules=["q_proj", "v_proj"]: only the attention query and value projections receive adapters
- task_type="CAUSAL_LM": the adapters are configured for causal language modeling
Vision-language models require special handling for processing batches that contain both images and text. We implement a custom data collator function:
from qwen_vl_utils import process_vision_info
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        if isinstance(example, dict) and "image" in example:
            try:
                image_inputs.append(process_vision_info([example])[0])
            except KeyError:
                # Handle an example missing the 'content' key
                print(f"Warning: Example missing 'content' key: {example}")
                image_inputs.append(None)  # or your preferred default
        else:
            image_inputs.append(process_vision_info(example)[0])

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    # The labels are the input_ids; mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens

    # Mask image token IDs in the labels
    image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    batch["labels"] = labels
    return batch
Our custom collator performs several key functions for handling multimodal data: it renders each conversation into a prompt string with the chat template, extracts and preprocesses the images with process_vision_info, pads every sequence in the batch to a uniform length, and builds the labels tensor used for loss computation.
Handling Special Tokens: The collator sets labels for padding tokens and image tokens to -100, which tells the model to ignore these positions when computing the loss. This is crucial for vision-language models because we don't want the model to predict the raw image tokens—only the text that should come after or around the images.
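The masking logic can be illustrated without the full processor. In PyTorch, -100 is the default ignore_index of the cross-entropy loss, so any position labeled -100 contributes nothing to the gradient. The token IDs below are made-up placeholders; the real values come from the tokenizer:

```python
PAD_ID, IMAGE_ID = 0, 151655  # illustrative IDs, not the actual tokenizer values

def mask_labels(input_ids, pad_id=PAD_ID, image_id=IMAGE_ID):
    """Copy input_ids to labels, masking padding and image tokens with -100."""
    return [tok if tok not in (pad_id, image_id) else -100 for tok in input_ids]

# Image placeholders at the start and trailing padding are masked;
# the text tokens in between are kept as prediction targets.
labels = mask_labels([151655, 151655, 9906, 11, 1917, 0, 0])
print(labels)  # [-100, -100, 9906, 11, 1917, -100, -100]
```

The collator above does the same thing with tensor indexing (`labels[labels == token_id] = -100`) instead of a Python loop.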
With all components in place, we can now set up the SFTTrainer and begin the fine-tuning process:
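The trainer below references a training_args object whose definition is not shown in this guide. A plausible configuration might look like the following — every value here is illustrative, not the author's actual setting, and should be tuned to your hardware and dataset:

```python
from trl import SFTConfig

# Illustrative hyperparameters; adjust to your GPU memory and dataset size
training_args = SFTConfig(
    output_dir="qwen2.5-7b-med",         # where checkpoints and the adapter are saved
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,        # effective batch size of 8
    learning_rate=2e-4,
    gradient_checkpointing=True,          # trade compute for memory
    bf16=True,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="epoch",
    report_to="wandb",
    dataset_kwargs={"skip_prepare_dataset": True},  # we supply our own collator
    remove_unused_columns=False,          # keep image fields for the collator
)
```

The last two settings matter for vision-language training: they stop the trainer from pre-tokenizing the dataset and from dropping the image columns our collator needs.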
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    data_collator=collate_fn,
    peft_config=peft_config,
    # tokenizer=processor.tokenizer,
)

trainer.train()
trainer.save_model(training_args.output_dir)
During training, the SFTTrainer will: apply our collator to build multimodal batches, run forward passes through the quantized base model, compute the loss only on unmasked text tokens, update just the LoRA adapter weights, and periodically evaluate on the held-out eval set.
Training Efficiency: With our QLoRA setup, we're training only a small fraction of a percent of the model's weights — a few million trainable parameters out of roughly 7 billion — dramatically reducing memory requirements while still achieving significant adaptation to the medical domain. The exact count is reported by print_trainable_parameters() above.
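You can sanity-check the count reported by print_trainable_parameters() by hand. Assuming Qwen2.5-7B's published dimensions (hidden size 3584, 28 layers, and 512-dimensional value projections due to grouped-query attention — verify these against the model config, as they are assumptions here), r=8 adapters on q_proj and v_proj give:

```python
def lora_params(r, in_dim, out_dim):
    # Each adapted linear layer gains two matrices: A (r x in_dim) and B (out_dim x r)
    return r * in_dim + out_dim * r

hidden, kv_dim, layers, r = 3584, 512, 28, 8  # assumed Qwen2.5-7B dimensions
per_layer = lora_params(r, hidden, hidden) + lora_params(r, hidden, kv_dim)
total = per_layer * layers
print(f"{total:,} trainable parameters (~{100 * total / 7e9:.3f}% of 7B)")
```

Under these assumed dimensions the estimate comes out to a few million parameters, consistent with the tiny-fraction figure above.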
After training, we can clear memory and load our fine-tuned model to evaluate its performance:
clear_memory()
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Instantiate the base model with the correct configuration
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_id)

# Attach the fine-tuned LoRA adapter
adapter_path = "StaAhmed/qwen2.5-7b-med"
model.load_adapter(adapter_path)

# Generate a report for a held-out test example
output = generate_text_from_sample(model, processor, Data_test[0])
output
The fine-tuned model should now demonstrate improved capabilities specific to medical imaging:
Model Integration: One key advantage of our approach is that the fine-tuned adapter is only about 7MB in size, compared to the full model which is several gigabytes. This small adapter can be easily distributed and applied to the base model, making it practical to deploy specialized medical capabilities without duplicating the entire model.