Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 70 additions & 119 deletions janus/models/modeling_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from janus.models.projector import MlpProjector


class vision_head(torch.nn.Module):
def __init__(self, params):
class VisionHead(torch.nn.Module):
def __init__(self, params: AttrDict):
super().__init__()
self.output_mlp_projector = torch.nn.Linear(
params.n_embed, params.image_token_embed
Expand All @@ -44,135 +44,82 @@ def __init__(self, params):
params.image_token_embed, params.image_token_size
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.output_mlp_projector(x)
x = self.vision_activation(x)
x = self.vision_head(x)
return x


def model_name_to_cls(cls_name):
def model_name_to_cls(cls_name: str) -> type:
if "MlpProjector" in cls_name:
cls = MlpProjector

elif "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower

elif "VQ" in cls_name:
from janus.models.vq_model import VQ_models

cls = VQ_models[cls_name]
elif "vision_head" in cls_name:
cls = vision_head
# Maintain backward compatibility with existing configs using "vision_head"
cls = VisionHead
else:
raise ValueError(f"class_name {cls_name} is invalid.")
raise ValueError(f"Invalid class name: {cls_name}")

return cls


class VisionConfig(PretrainedConfig):
model_type = "vision"
class BaseSubConfig(PretrainedConfig):
model_type: str = "base"
cls: str = ""
params: AttrDict = {}
params: AttrDict = AttrDict()

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__

self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
model_type = "aligner"
cls: str = ""
params: AttrDict = {}

def __init__(self, **kwargs):
super().__init__(**kwargs)
class VisionConfig(BaseSubConfig):
model_type = "vision"

self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__

self.params = AttrDict(kwargs.get("params", {}))
class AlignerConfig(BaseSubConfig):
model_type = "aligner"


class GenVisionConfig(PretrainedConfig):
class GenVisionConfig(BaseSubConfig):
model_type = "gen_vision"
cls: str = ""
params: AttrDict = {}

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__

self.params = AttrDict(kwargs.get("params", {}))


class GenAlignerConfig(PretrainedConfig):
class GenAlignerConfig(BaseSubConfig):
model_type = "gen_aligner"
cls: str = ""
params: AttrDict = {}

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__

self.params = AttrDict(kwargs.get("params", {}))


class GenHeadConfig(PretrainedConfig):
class GenHeadConfig(BaseSubConfig):
model_type = "gen_head"
cls: str = ""
params: AttrDict = {}

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__

self.params = AttrDict(kwargs.get("params", {}))


class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig

gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig

language_config: LlamaConfig

def __init__(self, **kwargs):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionConfig(**vision_config)

aligner_config = kwargs.get("aligner_config", {})
self.aligner_config = AlignerConfig(**aligner_config)

gen_vision_config = kwargs.get("gen_vision_config", {})
self.gen_vision_config = GenVisionConfig(**gen_vision_config)

gen_aligner_config = kwargs.get("gen_aligner_config", {})
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)

gen_head_config = kwargs.get("gen_head_config", {})
self.gen_head_config = GenHeadConfig(**gen_head_config)

self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
self.gen_aligner_config = GenAlignerConfig(
**kwargs.get("gen_aligner_config", {})
)
self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig):
self.language_config = language_config
Expand All @@ -191,32 +138,31 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(self, config: MultiModalityConfig):
super().__init__(config)

vision_config = config.vision_config
vision_cls = model_name_to_cls(vision_config.cls)
self.vision_model = vision_cls(**vision_config.params)

aligner_config = config.aligner_config
aligner_cls = model_name_to_cls(aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params)

gen_vision_config = config.gen_vision_config
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls()

gen_aligner_config = config.gen_aligner_config
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
# Initialize vision components
self.vision_model = model_name_to_cls(config.vision_config.cls)(
**config.vision_config.params
)
self.aligner = model_name_to_cls(config.aligner_config.cls)(
config.aligner_config.params
)

gen_head_config = config.gen_head_config
gen_head_cls = model_name_to_cls(gen_head_config.cls)
self.gen_head = gen_head_cls(gen_head_config.params)
# Initialize generation components
self.gen_vision_model = model_name_to_cls(config.gen_vision_config.cls)(
**config.gen_vision_config.params
)
self.gen_aligner = model_name_to_cls(config.gen_aligner_config.cls)(
config.gen_aligner_config.params
)
self.gen_head = model_name_to_cls(config.gen_head_config.cls)(
config.gen_head_config.params
)

# Initialize embeddings and language model
self.gen_embed = torch.nn.Embedding(
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
config.gen_vision_config.params.image_token_size,
config.gen_vision_config.params.n_embed,
)

language_config = config.language_config
self.language_model = LlamaForCausalLM(language_config)
self.language_model = LlamaForCausalLM(config.language_config)

def prepare_inputs_embeds(
self,
Expand All @@ -225,44 +171,49 @@ def prepare_inputs_embeds(
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.LongTensor,
**kwargs,
):
) -> torch.Tensor:
"""
Prepares combined text and image embeddings for the language model.

Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
input_ids: Token IDs for text inputs, shape [batch_size, seq_len]
pixel_values: Image tensors, shape [batch_size, num_images, channels, height, width]
images_seq_mask: Boolean mask indicating image positions in the text sequence,
shape [batch_size, seq_len]
images_emb_mask: Boolean mask for valid image tokens per image,
shape [batch_size, num_images, tokens_per_image]

Returns:
input_embeds (torch.Tensor): [b, T, D]
Combined embeddings tensor of shape [batch_size, seq_len, embedding_dim]
"""

bs, n = pixel_values.shape[0:2]
batch_size, num_images = pixel_values.shape[:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
images_embeds = self.aligner(self.vision_model(images)) # [(b n), tokens, dim]

# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
# Reshape embeddings and masks
images_embeds = rearrange(
images_embeds, "(b n) t d -> b (n t) d", b=batch_size, n=num_images
)
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

# [b, T, D]
input_ids[input_ids < 0] = 0 # ignore the image embeddings
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
# Validate mask compatibility
assert torch.all(
images_seq_mask.sum(dim=1) == images_emb_mask.sum(dim=1)
), "Masks must have matching number of image tokens"

# replace with the image embeddings
# Get text embeddings and replace image positions
input_ids = input_ids.masked_fill(input_ids < 0, 0) # Replace negatives
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

return inputs_embeds

def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor) -> torch.Tensor:
"""Generates image embeddings from token IDs for generation."""
return self.gen_aligner(self.gen_embed(image_ids))


# Configuration registration
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
Expand Down