FakeLLaVA: Building a Lightweight LLaVA-style VLM with Qwen2.5 and CLIP
After the last blog article, the next logical step is implementing a full-blown VLM using an image encoder. For this task I will implement the training procedure and model structure described in the LLaVA 1.5 paper, with some modifications. For simplicity's sake, and because I'm GPU poor, I will implement a simpler version of the paper without the Q&A post-training. In the paper the training takes around 6 hours on a cluster of 8xA100s, which I simply do not own, so we're going to use some lightweight techniques. First, we use a much smaller LLM, Qwen2.5-0.5B, instead of Vicuna-7B. Second, we use a 224x224 image encoder instead of the 336x336 one used in the paper.
So how does it work?
The model is structured as follows. Our image \(x_v\) is first resized to 336x336 and then embedded by the vision encoder, \(z_v = g(x_v)\), where \(g(\cdot)\) is the encoder function. The resulting embeddings are then projected into the LLM's input embedding space. This is done either through a linear projection layer, \(H_v = W \cdot z_v\) (as in the original LLaVA), or through an MLP, \(H_v = h(z_v) = h(g(x_v))\), where \(h(\cdot)\) is the MLP (as in LLaVA 1.5). Alongside the image \(x_v\), the textual input \(x_q\) is also passed to the model.
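For the MLP case, the formulas above can be written out explicitly for this post's 224x224 setup. This assumes a 2-layer MLP with a GELU in between, matching the `projection` module defined below. The MLP is applied independently to each of the \(N\) patch tokens \(z_v^{(i)}\). There are \(L\) text tokens, and the visual tokens are prepended to the text embeddings. The symbols \(W_1, W_2, b_1, b_2, N, L, H_q, X\) are notation introduced here, not taken from the paper:

```latex
\begin{aligned}
z_v &= g(x_v) = \big[z_v^{(1)}, \dots, z_v^{(N)}\big], \quad z_v^{(i)} \in \mathbb{R}^{768}, \quad N = (224/32)^2 = 49 \\
H_v^{(i)} &= h\big(z_v^{(i)}\big) = W_2\, \mathrm{GELU}\big(W_1 z_v^{(i)} + b_1\big) + b_2 \in \mathbb{R}^{896},
  \quad W_1 \in \mathbb{R}^{896 \times 768},\; W_2 \in \mathbb{R}^{896 \times 896} \\
H_q &= \mathrm{Embed}(x_q) \in \mathbb{R}^{L \times 896} \\
X &= \big[\, H_v \,;\, H_q \,\big] \in \mathbb{R}^{(N + L) \times 896}
\end{aligned}
```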
The way this model works is not that different from the ClipCap paper's implementation; even the projection layers are almost identical. The main difference is that LLaVA uses all the patch tokens generated by the vision encoder instead of a single image embedding (the CLS token). The LLaVA 1.5 paper also introduces a method for embedding images without resizing them to the vision encoder's input resolution. The image is split into 336x336 chunks, each chunk is embedded individually, and all the resulting embeddings are fed to the LLM sequentially.
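Here is a minimal sketch of the tiling arithmetic. It assumes a 336x336 tile size and a ViT-L/14 encoder (the one used in LLaVA 1.5), which yields \((336/14)^2 = 576\) tokens per tile. The helper names `tile_boxes` and `visual_token_count` are hypothetical. The sketch ignores details such as padding the edge chunks and the downsampled global view that the paper additionally concatenates. The boxes use the `(left, top, right, bottom)` convention that PIL's `Image.crop` accepts:

```python
import math


def tile_boxes(width, height, tile=336):
    """Return (left, top, right, bottom) boxes covering the image with tile x tile chunks.

    Boxes on the right/bottom edge are clipped to the image, so in practice
    those chunks would need padding before being fed to the encoder.
    """
    cols = math.ceil(width / tile)
    rows = math.ceil(height / tile)
    boxes = []
    for r in range(rows):
        for c in range(cols):
            left, top = c * tile, r * tile
            boxes.append((left, top, min(left + tile, width), min(top + tile, height)))
    return boxes


def visual_token_count(n_tiles, tile=336, patch=14):
    """Each tile is encoded independently into (tile // patch) ** 2 patch tokens."""
    return n_tiles * (tile // patch) ** 2


boxes = tile_boxes(672, 336)
print(boxes)                           # [(0, 0, 336, 336), (336, 0, 672, 336)]
print(visual_token_count(len(boxes)))  # 1152
```

The token count grows linearly with the number of tiles. This is why higher-resolution inputs quickly make the LLM's input sequence much longer.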
One of the major differences between ClipCap and LLaVA is the training data. The ClipCap paper trains on image-caption pairs, while LLaVA uses an instruction-tuning Q&A dataset generated with GPT-4 (so the OG LLaVA models are a kind of distilled GPT-4 model?!?!).
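To make the Q&A format concrete, here is a minimal sketch of one possible preprocessing step. It takes a conversation in the layout used by liuhaotian/LLaVA-Instruct-150K and turns it into the ChatML layout that `generate_caption` uses later on. In that layout, each conversation is a list of turns whose `"from"` field is `"human"` or `"gpt"`, with the image marked by an `<image>` placeholder. This is an assumption about how such data could be formatted, not the preprocessing actually used for the pre-computed Martingkc datasets. The `to_chatml` helper and the example conversation are hypothetical. The `<image>` tag is simply stripped because the projected visual tokens are prepended separately:

```python
SYSTEM_PROMPT = (
    "You are a helpful vision assistant. "
    "You are given an image and a question about it. "
    "Answer concisely and accurately."
)

ROLE_MAP = {"human": "user", "gpt": "assistant"}


def to_chatml(conversations, system_prompt=SYSTEM_PROMPT):
    """Convert LLaVA-style turns into (role, text) segments and one ChatML string.

    Hypothetical preprocessing: the <image> placeholder is removed because the
    projected visual tokens are prepended to the sequence instead.
    """
    segments = [("system", f"<|im_start|>system\n{system_prompt}<|im_end|>\n")]
    for turn in conversations:
        role = ROLE_MAP[turn["from"]]
        text = turn["value"].replace("<image>", "").strip()
        segments.append((role, f"<|im_start|>{role}\n{text}<|im_end|>\n"))
    return segments, "".join(text for _, text in segments)


example = [
    {"from": "human", "value": "<image>\nWhat is on the table?"},
    {"from": "gpt", "value": "A bottle of olive oil."},
]
segments, prompt = to_chatml(example)
print(prompt)
```

Keeping the per-role segments around makes it easy to mask every non-assistant segment out of the loss later.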
The training process
Training happens in two phases. In the first phase, the LLM and the vision encoder are frozen and only the MLP projector is pretrained, so that it learns to generate tokens “comprehensible” to the LLM. In the second phase, the LLM is finetuned alongside the MLP on a set of image-instruction data.
The dataset used in the first phase is a subset of the liuhaotian/LLaVA-CC3M-Pretrain-595K dataset. If you want to follow this tutorial and save some compute, you can use two pre-computed datasets. For the first stage of training, use Martingkc/LLaVa-CC3M-Pretrain-clip-vit-base-patch32, which contains the patch tokens extracted with CLIP ViT-B/32 (openai/clip-vit-base-patch32). For the second stage, use Martingkc/LLaVa-Instruct-150K-clip-vit-base-patch32, which was computed from liuhaotian/LLaVA-Instruct-150K.
I precomputed the patch tokens and stored them in a dataset for three reasons: the embedding process takes too long, I want this tutorial to be as easily reproducible as possible, and I do not have enough Colab credits left :(.
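The training loops below read four fields from each batch: `patch_tokens`, `input_ids`, `attention_mask` and `labels`. The collate step is not shown in this post. Here is a hypothetical pure-Python sketch of how variable-length examples could be right-padded into a batch. Padding positions get attention mask 0 and label -100, so the loss ignores them. The helper name `pad_batch`, the example token ids and the choice of right padding are assumptions. In practice you would convert the result to tensors (e.g. with `torch.tensor`):

```python
IGNORE_INDEX = -100  # label value skipped by F.cross_entropy(ignore_index=-100)


def pad_batch(examples, pad_id):
    """Right-pad input_ids/labels of a list of examples to the longest sequence.

    Each example is a dict with "patch_tokens" (49 x 768 nested list),
    "input_ids" and "labels" (lists of ints of equal length).
    """
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch = {"patch_tokens": [], "input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        n_pad = max_len - len(ex["input_ids"])
        batch["patch_tokens"].append(ex["patch_tokens"])  # fixed size, no padding needed
        batch["input_ids"].append(ex["input_ids"] + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
        batch["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
    return batch


examples = [
    {"patch_tokens": [[0.0] * 768] * 49, "input_ids": [11, 12, 13], "labels": [-100, 12, 13]},
    {"patch_tokens": [[0.0] * 768] * 49, "input_ids": [21, 22], "labels": [-100, 22]},
]
batch = pad_batch(examples, pad_id=0)
print(batch["input_ids"])       # [[11, 12, 13], [21, 22, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
print(batch["labels"])          # [[-100, 12, 13], [-100, 22, -100]]
```

Because every image yields exactly 49 patch tokens, only the text side needs padding. The model then prepends the visual positions itself.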
If you want to learn more about how vision encoders and CLIP models work, check out my article on CLIP models: Clip Models and Image Captioning using CLIP embeddings.
Defining the model
To save the model on the Hugging Face Hub, we need to define it as a torch model (an `nn.Module`). As stated before, our LLaVA model is composed of the following components:
- Qwen2.5-0.5B (the LLM)
- a 2-layer MLP (the projector)
- a CLIP vision encoder (ViT-B/32)
The input embedding dimension of the Qwen model is 896, while our CLIP patch tokens have dimension 768. The MLP therefore has to project the CLIP patch tokens into the LLM's input space, so that the transformer can process both modalities together.
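```python
import torch
import torch.nn as nn

# 2-layer MLP projector: CLIP patch tokens (768) -> Qwen2.5-0.5B embeddings (896)
projection = nn.Sequential(
    nn.Linear(768, 896),
    nn.GELU(),
    nn.Linear(896, 896),
).cuda().to(torch.float32)
```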
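To get a feel for how small the part trained in stage 1 is, here is the parameter count of this projector worked out by hand. Each `nn.Linear(d_in, d_out)` has a `d_out x d_in` weight plus a `d_out` bias. On the real module, `sum(p.numel() for p in projection.parameters())` should give the same number. The helper name `linear_params` is made up for this sketch:

```python
def linear_params(d_in, d_out):
    """Weight (d_out x d_in) plus bias (d_out) of an nn.Linear layer."""
    return d_in * d_out + d_out


# nn.Linear(768, 896) -> GELU (no parameters) -> nn.Linear(896, 896)
projector_params = linear_params(768, 896) + linear_params(896, 896)
print(projector_params)  # 1492736, i.e. ~1.5M vs ~0.5B parameters in the LLM
```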
Our full model is defined as follows. It specifies how the visual tokens are prepended to the text tokens in the LLM's input space. Furthermore, since our model doesn't follow the OG LLaVA structure, we also need to define the saving, loading and Hub-upload logic in this torch model.
A quick explanation of the parts that are not so explicit:
- The number of visual tokens is 49 because the CLIP model uses 32x32 patches over a 224x224 image, producing \(224^2 / 32^2 = 7 \times 7 = 49\) patch tokens (CLIP also outputs a CLS token, which we drop).
- The input embeddings are built as `inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)`, i.e. the visual tokens are prepended to the text embeddings, following the LLaVA paper.
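The same arithmetic applies to other CLIP variants. A smaller patch size or a larger input resolution means more visual tokens per image, and therefore a longer LLM input. The small helper below (hypothetical name `num_patch_tokens`) compares a few configurations. CLIP's `last_hidden_state` has one extra position for the leading CLS token:

```python
def num_patch_tokens(image_size, patch_size):
    """Number of patch tokens a ViT produces for a square image (CLS excluded)."""
    assert image_size % patch_size == 0, "image size must be a multiple of the patch size"
    return (image_size // patch_size) ** 2


configs = {
    "ViT-B/32 @ 224 (this post)": (224, 32),
    "ViT-B/16 @ 224": (224, 16),
    "ViT-L/14 @ 336 (LLaVA 1.5)": (336, 14),
}
for name, (size, patch) in configs.items():
    n = num_patch_tokens(size, patch)
    # last_hidden_state has n + 1 positions because of the leading CLS token
    print(f"{name}: {n} patch tokens, {n + 1} positions in last_hidden_state")
```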
```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional

SYSTEM_PROMPT = (
    "You are a helpful vision assistant. "
    "You are given an image and a question about it. "
    "Answer concisely and accurately."
)


@dataclass
class VLMOutput:
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None


class VLMConfig:
    def __init__(self, vision_hidden=768, llm_hidden=896, n_visual=49):
        self.vision_hidden = vision_hidden
        self.llm_hidden = llm_hidden
        self.n_visual = n_visual


class VLMModel(nn.Module):
    """
    LLaVA-style VLM.
    Visual tokens are projected and prepended before the ChatML text sequence.
    Loss is computed only on assistant turn tokens.
    """

    def __init__(self, config, llm, projection):
        super().__init__()
        self.config = config
        self.llm = llm
        self.projection = projection

    def forward(self, patch_tokens, input_ids, attention_mask, labels=None):
        # [B, n_visual, 768] -> [B, n_visual, 896]
        visual_tokens = self.projection(patch_tokens.cuda().float())
        # [B, seq_len, 896]
        text_embeds = self.llm.get_input_embeddings()(input_ids.cuda())
        # match LLM dtype
        visual_tokens = visual_tokens.to(text_embeds.dtype)
        # [B, n_visual + seq_len, 896]
        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
        n_visual = patch_tokens.shape[1]

        # extend the attention mask: visual positions are always attended
        visual_mask = torch.ones(
            patch_tokens.shape[0], n_visual,
            device="cuda", dtype=attention_mask.dtype
        )
        full_mask = torch.cat([visual_mask, attention_mask.cuda()], dim=1)

        # extend labels: visual positions are ignored by the loss (-100)
        if labels is not None:
            visual_labels = torch.full(
                (patch_tokens.shape[0], n_visual), -100,
                dtype=torch.long, device="cuda"
            )
            labels = torch.cat([visual_labels, labels.cuda()], dim=1)

        output = self.llm(inputs_embeds=inputs_embeds, attention_mask=full_mask)

        loss = None
        if labels is not None:
            # next-token prediction: logits at position t are scored against token t+1
            shift_logits = output.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        return VLMOutput(loss=loss, logits=output.logits)

    def generate_caption(self, patch_tokens, question, tokenizer, max_new_tokens=128):
        self.eval()
        with torch.no_grad():
            # single image: [n_visual, 768] -> [1, n_visual, 768]
            proj_dtype = next(self.projection.parameters()).dtype
            patch_tokens = patch_tokens.unsqueeze(0).cuda().to(proj_dtype)
            visual_tokens = self.projection(patch_tokens)

            # ChatML prompt ending with the assistant header, so the LLM answers next
            prompt_ids = tokenizer(
                f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
                f"<|im_start|>user\n{question}<|im_end|>\n"
                f"<|im_start|>assistant\n",
                return_tensors="pt"
            )["input_ids"].cuda()
            text_embeds = self.llm.get_input_embeddings()(prompt_ids)

            # cast visual tokens to match the LLM dtype
            visual_tokens = visual_tokens.to(text_embeds.dtype)
            inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
            attention_mask = torch.ones(
                1, inputs_embeds.shape[1], device="cuda", dtype=torch.long
            )

            # with only inputs_embeds, generate() returns just the new tokens
            output_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                repetition_penalty=1.3,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def save_pretrained(self, path):
        os.makedirs(path, exist_ok=True)
        torch.save({
            "projection": self.projection.state_dict(),
            "llm": self.llm.state_dict(),
            "config": self.config.__dict__
        }, os.path.join(path, "vlm.pt"))
        print(f"saved to {path}")

    @classmethod
    def load_pretrained(cls, path, llm, projection):
        checkpoint = torch.load(os.path.join(path, "vlm.pt"), map_location="cuda")
        config = VLMConfig(**checkpoint["config"])
        model = cls(config, llm, projection)
        model.llm.load_state_dict(checkpoint["llm"])
        model.projection.load_state_dict(checkpoint["projection"])
        return model.cuda()

    def push_to_hub(self, repo_id):
        from huggingface_hub import HfApi
        import tempfile

        api = HfApi()
        api.create_repo(repo_id, exist_ok=True, repo_type="model")
        with tempfile.TemporaryDirectory() as tmp:
            self.save_pretrained(tmp)
            api.upload_folder(folder_path=tmp, repo_id=repo_id)
        print(f"pushed to hub: {repo_id}")
```
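The docstring says the loss is computed only on assistant-turn tokens. Since the labels come pre-computed in the dataset, here is a hypothetical pure-Python sketch of the masking that `forward` relies on. Non-assistant tokens get the label -100, and `forward` then prepends one -100 per visual token. The shift then pairs each position with the token that follows it. The helper names and token ids are made up for illustration:

```python
IGNORE_INDEX = -100  # matches ignore_index=-100 in F.cross_entropy


def build_labels(segments):
    """segments: list of (role, token_ids). Only assistant tokens are supervised."""
    input_ids, labels = [], []
    for role, ids in segments:
        input_ids.extend(ids)
        labels.extend(ids if role == "assistant" else [IGNORE_INDEX] * len(ids))
    return input_ids, labels


def add_visual_positions(labels, n_visual):
    """Mirror of forward(): visual positions are never scored."""
    return [IGNORE_INDEX] * n_visual + labels


def next_token_pairs(labels):
    """After the shift, position t is trained to predict labels[t + 1]."""
    targets = labels[1:]
    return [(t, tgt) for t, tgt in enumerate(targets) if tgt != IGNORE_INDEX]


segments = [("system", [1, 2]), ("user", [3, 4, 5]), ("assistant", [6, 7])]
input_ids, labels = build_labels(segments)
full_labels = add_visual_positions(labels, n_visual=49)
pairs = next_token_pairs(full_labels)
print(len(full_labels))  # 56 = 49 visual + 7 text positions
print(pairs)             # [(53, 6), (54, 7)]
```

Only two positions contribute to the loss here. The last user token (position 53) learns to predict the first assistant token, and so on.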
Stage 1 Training
As stated before, in the first stage of training we freeze the LLM and train only the projection layer. This stage consists of a single epoch over our CC3M subset.
```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# vlm, train_dataset_s1 and tokenizer are assumed to be defined earlier in the notebook (not shown here)

# freeze the LLM, train only the projector
for param in vlm.llm.parameters():
    param.requires_grad = False
for param in vlm.projection.parameters():
    param.requires_grad = True

train_loader_s1 = DataLoader(
    train_dataset_s1,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

optimizer_s1 = AdamW(vlm.projection.parameters(), lr=1e-4)
# one scheduler step per batch over a single epoch
scheduler_s1 = CosineAnnealingLR(optimizer_s1, T_max=len(train_loader_s1))

EPOCHS_S1 = 1
LOG_EVERY = 100
SAVE_EVERY = 1000

vlm.train()
for epoch in range(EPOCHS_S1):
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader_s1)):
        optimizer_s1.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = vlm(
                patch_tokens=batch["patch_tokens"],
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )
        loss = output.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vlm.projection.parameters(), 1.0)
        optimizer_s1.step()
        scheduler_s1.step()
        total_loss += loss.item()

        if step % LOG_EVERY == 0:
            print(f"[S1] epoch {epoch} step {step} | loss {loss.item():.4f} | avg {total_loss/(step+1):.4f}")
        if step % SAVE_EVERY == 0 and step > 0:
            torch.save(vlm.projection.state_dict(), f"./projection_s1_step_{step}.pt")
            print(f"checkpoint saved at step {step}")

    print(f"[S1] epoch {epoch} done | avg loss {total_loss/len(train_loader_s1):.4f}")

torch.save(vlm.projection.state_dict(), "./projection_stage1_final.pt")
print("stage 1 complete")
```
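`CosineAnnealingLR` (with its default `eta_min=0`) decays the learning rate along a half cosine over `T_max` scheduler steps. Stage 1 calls `scheduler_s1.step()` once per batch with `T_max=len(train_loader_s1)`. The projector's learning rate therefore goes from 1e-4 at the first batch down to 0 at the end of the epoch. Below is a pure-Python sketch of the closed form given in the PyTorch documentation. The helper name `cosine_lr` and the value of `t_max` are illustrative only:

```python
import math


def cosine_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form learning rate of CosineAnnealingLR after `step` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


t_max = 1000  # stands in for len(train_loader_s1); the real value depends on the dataset size
for step in (0, 250, 500, 750, 1000):
    print(step, f"{cosine_lr(step, 1e-4, t_max):.2e}")
```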
With just this first stage of training, the model already starts to pick up on the image tokens, although incorrectly: the picture was actually of an olive oil bottle.
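```python
# caption the first training sample using the stage 1 projector
sample = train_dataset_s1[0]
response = vlm.generate_caption(
    sample["patch_tokens"],
    "Describe the image.",
    tokenizer
)
print(f"Stage 1 output: {response}")
```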
Stage 1 output: a picture of a large glass jar with an open lid .
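`generate_caption` decodes greedily (`do_sample=False`) with `repetition_penalty=1.3`, presumably to keep a small model like this from looping on the same phrase. In Hugging Face transformers, this penalty applies to tokens that already occur in the sequence: their positive logits are divided by the penalty and their negative logits are multiplied by it, making them less likely to be picked again. Here is a pure-Python sketch of that rule on a toy vocabulary. The function names and logit values are made up for illustration:

```python
def apply_repetition_penalty(logits, seen_token_ids, penalty=1.3):
    """Penalize tokens that were already generated (CTRL-style rule).

    Positive logits are divided by the penalty and negative ones multiplied,
    so a repeated token always becomes less likely.
    """
    out = list(logits)
    for tok in set(seen_token_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def greedy_pick(logits):
    """Index of the highest logit, i.e. what do_sample=False selects."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [2.0, 1.8, -0.5, 0.1]  # toy vocabulary of 4 tokens
print(greedy_pick(logits))       # 0
penalized = apply_repetition_penalty(logits, seen_token_ids=[0, 2])
print(greedy_pick(penalized))    # 1
```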
Stage 2 - Where the magic happens
In this stage we unfreeze both the LLM and the projection layer and train them together for 3 epochs.
```python
import os
import shutil
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# load stage 1 projection weights
vlm.projection.load_state_dict(torch.load("./projection_stage1_final.pt"))
print("stage 1 weights loaded")

# unfreeze everything: LLM + projector
for param in vlm.llm.parameters():
    param.requires_grad = True
for param in vlm.projection.parameters():
    param.requires_grad = True

EPOCHS_S2 = 3
ACCUMULATION_STEPS = 8  # effective batch size = 8 * 8 = 64
LOG_EVERY = 100
SAVE_EVERY = 1000

train_loader_s2 = DataLoader(
    train_dataset_s2,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# higher LR for the freshly trained projector, lower LR for the pretrained LLM
optimizer_s2 = AdamW([
    {"params": vlm.projection.parameters(), "lr": 1e-4},
    {"params": vlm.llm.parameters(), "lr": 2e-5}
])
# the scheduler steps once per optimizer step, not once per batch
scheduler_s2 = CosineAnnealingLR(
    optimizer_s2, T_max=(len(train_loader_s2) // ACCUMULATION_STEPS) * EPOCHS_S2
)

vlm.train()
for epoch in range(EPOCHS_S2):
    total_loss = 0
    optimizer_s2.zero_grad()
    for step, batch in enumerate(tqdm(train_loader_s2)):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = vlm(
                patch_tokens=batch["patch_tokens"],
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )
        # scale so the accumulated gradient averages over ACCUMULATION_STEPS batches
        loss = output.loss / ACCUMULATION_STEPS
        loss.backward()

        if (step + 1) % ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(vlm.parameters(), 1.0)
            optimizer_s2.step()
            scheduler_s2.step()
            optimizer_s2.zero_grad()

        total_loss += loss.item() * ACCUMULATION_STEPS

        if step % LOG_EVERY == 0:
            print(f"[S2] epoch {epoch} step {step} | loss {loss.item() * ACCUMULATION_STEPS:.4f} | avg {total_loss/(step+1):.4f}")
        if step % SAVE_EVERY == 0 and step > 0:
            vlm.save_pretrained(f"./vlm_s2_step_{step}")
            tokenizer.save_pretrained(f"./vlm_s2_step_{step}")
            print(f"checkpoint saved at step {step}")
            # keep only the latest checkpoint to save disk space
            prev = f"./vlm_s2_step_{step - SAVE_EVERY}"
            if os.path.exists(prev):
                shutil.rmtree(prev)
                print(f"deleted checkpoint at step {step - SAVE_EVERY}")

    # qualitative check at the end of every epoch
    sample = train_dataset_s2[0]
    response = vlm.generate_caption(
        sample["patch_tokens"],
        "What is in this image?",
        tokenizer
    )
    print(f"[S2] epoch {epoch} sample: {response}")
    print(f"[S2] epoch {epoch} done | avg loss {total_loss/len(train_loader_s2):.4f}")
    vlm.train()  # generate_caption switched the model to eval mode

print("stage 2 training complete")
```
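With gradient accumulation, the scheduler only advances once every `ACCUMULATION_STEPS` batches. The `T_max` of `CosineAnnealingLR` therefore has to be counted in optimizer steps. The small pure-Python check below shows what happens if `T_max` is counted in batches instead. The batch count is a made-up example, not the real dataset size, and `lr_fraction` is a hypothetical helper:

```python
import math


def lr_fraction(step, t_max):
    """Fraction of the base LR left after `step` CosineAnnealingLR steps (eta_min=0)."""
    return 0.5 * (1 + math.cos(math.pi * step / t_max))


n_batches = 8000          # hypothetical len(train_loader_s2)
accumulation_steps = 8    # effective batch size = 8 * 8 = 64
epochs = 3

optimizer_steps = (n_batches // accumulation_steps) * epochs  # number of scheduler.step() calls
t_max_batches = n_batches * epochs                            # T_max counted in batches instead

print(optimizer_steps)  # 3000
print(f"T_max in optimizer steps: {lr_fraction(optimizer_steps, optimizer_steps):.1%} of base LR left")
print(f"T_max in batches: {lr_fraction(optimizer_steps, t_max_batches):.1%} of base LR left")
```

With `T_max` counted in batches, the schedule only covers the first eighth of the cosine. The learning rate therefore barely decays by the end of training.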
Running the model
```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from transformers import CLIPVisionModel
from huggingface_hub import hf_hub_download
from PIL import Image
import requests
from io import BytesIO

# VLMConfig, VLMModel and SYSTEM_PROMPT are the definitions from "Defining the model"

HUB_REPO = "Martingkc/qwen-05B-vlm-LLaVa-v2"
# assumption: the exact base checkpoint id and dtype; the post only says "Qwen2.5-0.5B"
BASE_LLM = "Qwen/Qwen2.5-0.5B"

# build the modules that the checkpoint weights are loaded into
tokenizer = AutoTokenizer.from_pretrained(BASE_LLM)
llm = AutoModelForCausalLM.from_pretrained(BASE_LLM, torch_dtype=torch.bfloat16).cuda()
projection = nn.Sequential(
    nn.Linear(768, 896),
    nn.GELU(),
    nn.Linear(896, 896),
).cuda().to(torch.float32)

# download vlm.pt from the Hub and load the weights
ckpt_path = hf_hub_download(repo_id=HUB_REPO, filename="vlm.pt")
checkpoint = torch.load(ckpt_path, map_location="cuda")
llm.load_state_dict(checkpoint["llm"])
projection.load_state_dict(checkpoint["projection"])
print("weights loaded")

config = VLMConfig(vision_hidden=768, llm_hidden=896, n_visual=49)
vlm = VLMModel(config, llm, projection).cuda()
vlm.eval()

clip_vision = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").cuda()
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")


def encode_image(image: Image.Image):
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_vision(pixel_values=inputs["pixel_values"].cuda())
    # last_hidden_state is [1, 50, 768]: drop the CLS token -> [49, 768] patch tokens
    return outputs.last_hidden_state[0, 1:, :].cpu()


def ask(image: Image.Image, question: str):
    patch_tokens = encode_image(image)
    return vlm.generate_caption(patch_tokens, question, tokenizer)


url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRn_SnbOBt9UoNy8Yft_ZGSr6kBDugy0O0jXQ&s"
image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
print(ask(image, "What is in this image?"))
```
In this image, there is a group of people sitting at an outdoor dining table. They are engaged in conversation and enjoying each other's company while holding up their cell phones to take pictures or record videos together.
```python
print(ask(image, "How many people are there in this image?"))
```
There are three people in this image.
Eh, in this example it behaved OK, I guess. Definitely not SOTA-level performance, but what can you expect from 6 hours of training in a Colab notebook on a model that barely manages to generate coherent text…
```python
url2 = "https://cdn.prod.website-files.com/62d84e447b4f9e7263d31e94/637627ca9eebde45ae5f394c_Underwater-Nun.jpeg"
image = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
print(ask(image, "What is in this image?"))
```
In this image, there is a person wearing glasses. The scene also features fish swimming in the water behind them on top of some rocks or coral reefs.
Well, on this example it performed better than I expected.
```python
url3 = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTLxctgOX9u23xPYCAnWddKXUZZg0X5VwdXJA&s"
image = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
print(ask(image, "What is in this image?"))
```
In this image, there is a group of people standing together and holding wine glasses. They are posing for the camera while enjoying their time at Christmas dinner or holiday celebration with festive decorations like lights on trees in the background.
Same with this one: it managed to recognise the trees, the decorations and the glasses. It's quite impressive for a model that weighs only about 1GB!!!
Summary
The model definitely isn't SOTA, but given its size it exceeded my expectations. It would most likely have performed much better with more compute, a larger training dataset, or a more granular CLIP model (smaller patches, hence more visual tokens per image). So if anyone who isn't as GPU poor as me gets the chance to mess with it, LMK!!
| Component | Value |
|---|---|
| LLM | Qwen2.5-0.5B |
| Vision Encoder | CLIP ViT-B/32 |
| Resolution | 224×224 |
| Visual Tokens | 49 |
| Training Time | ~6h |
| GPU | Colab A100 |
| Params Trained Stage 1 | Projection only |
| Params Trained Stage 2 | Full model |
| Dataset Stage 1 | Martingkc/LLaVa-CC3M-Pretrain-clip-vit-base-patch32 |
| Dataset Stage 2 | Martingkc/LLaVa-Instruct-150K-clip-vit-base-patch32 |
Notebook