vllm.model_executor.models.colmodernvbert

ColModernVBERT: multimodal late-interaction retrieval model.

Combines a SigLIP vision encoder and a ModernBERT text encoder via a pixel shuffle connector, producing ColBERT-style 128-dim per-token embeddings.

Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged

ColModernVBertConnector

Bases: Module

Pixel shuffle spatial reduction followed by a linear projection.

Reduces the vision encoder's token count by factor^2 via pixel-shuffle spatial rearrangement, then projects the concatenated channels to the text encoder's hidden size with a single bias-free linear layer.
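As a concrete illustration with hypothetical sizes (a 32x32 patch grid, hidden size 768, factor 2): the 1024 vision tokens become 256 tokens of 3072 channels each, which the linear projection then maps to the text encoder's hidden size. A minimal sketch of the rearrangement using plain tensor ops:

import torch

# Hypothetical sizes: 32x32 patch grid (1024 tokens), hidden size 768, factor 2.
B, H, W, C, f = 1, 32, 32, 768, 2
features = torch.randn(B, H * W, C)

x = features.view(B, H, W, C)                # (B, H, W, C)
x = x.view(B, H // f, f, W // f, f, C)       # split each spatial axis by f
x = x.permute(0, 1, 3, 2, 4, 5)              # gather each f x f block together
x = x.reshape(B, H // f, W // f, C * f * f)  # fold the block into channels

print(x.shape)  # torch.Size([1, 16, 16, 3072]): 4x fewer tokens, 4x the channels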

Source code in vllm/model_executor/models/colmodernvbert.py
class ColModernVBertConnector(nn.Module):
    """Pixel shuffle spatial reduction followed by a linear projection.

    Reduces the vision encoder's token count by ``factor^2`` via pixel-shuffle
    spatial rearrangement, then projects the concatenated channels to the text
    encoder's hidden size with a single bias-free linear layer.
    """

    def __init__(self, config: ColModernVBertConfig):
        super().__init__()
        self.pixel_shuffle_factor = config.pixel_shuffle_factor
        vision_hidden_size = config.vision_config.hidden_size
        input_size = vision_hidden_size * (self.pixel_shuffle_factor**2)
        output_size = config.hidden_size
        self.proj = nn.Linear(input_size, output_size, bias=False)

    def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
        """Spatial rearrangement that reduces seq length by factor^2."""
        batch_size, seq_length, hidden_size = features.shape
        height = width = int(seq_length**0.5)
        factor = self.pixel_shuffle_factor

        # Reshape to (B, H, W, C)
        features = features.view(batch_size, height, width, hidden_size)

        # Reshape to (B, H/f, f, W/f, f, C)
        features = features.view(
            batch_size, height // factor, factor, width // factor, factor, hidden_size
        )

        # Permute to (B, H/f, W/f, f, f, C)
        features = features.permute(0, 1, 3, 2, 4, 5)

        # Reshape to (B, H/f, W/f, C * f^2)
        new_hidden_size = hidden_size * (factor**2)
        features = features.reshape(
            batch_size, height // factor, width // factor, new_hidden_size
        )

        return features

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.pixel_shuffle(features)
        batch_size = features.shape[0]
        features = features.reshape(batch_size, -1, features.shape[-1])
        return self.proj(features)

pixel_shuffle

pixel_shuffle(features: Tensor) -> Tensor

Spatial rearrangement that reduces seq length by factor^2.

Source code in vllm/model_executor/models/colmodernvbert.py
def pixel_shuffle(self, features: torch.Tensor) -> torch.Tensor:
    """Spatial rearrangement that reduces seq length by factor^2."""
    batch_size, seq_length, hidden_size = features.shape
    height = width = int(seq_length**0.5)
    factor = self.pixel_shuffle_factor

    # Reshape to (B, H, W, C)
    features = features.view(batch_size, height, width, hidden_size)

    # Reshape to (B, H/f, f, W/f, f, C)
    features = features.view(
        batch_size, height // factor, factor, width // factor, factor, hidden_size
    )

    # Permute to (B, H/f, W/f, f, f, C)
    features = features.permute(0, 1, 3, 2, 4, 5)

    # Reshape to (B, H/f, W/f, C * f^2)
    new_hidden_size = hidden_size * (factor**2)
    features = features.reshape(
        batch_size, height // factor, width // factor, new_hidden_size
    )

    return features

ColModernVBertForRetrieval

Bases: Module, SupportsMultiModal

ColModernVBERT multimodal late-interaction retrieval model.

Architecture

Image -> SiglipVisionModel -> ColModernVBertConnector
                                    ↓
Text  -> ModernBertEmbeddings -> [merge] -> ModernBertLayers -> norm
                                                                 ↓
                                         custom_text_proj -> L2 norm
                                                                 ↓
                                          per-token 128-d embeddings
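The MaxSim scoring that consumes these embeddings happens outside this module, but it motivates the per-token output: a ColBERT-style scorer matches every query token against every document token. A minimal sketch (hypothetical helper, not part of this file), assuming both sides are already L2-normalized 128-d token embeddings:

import torch

def maxsim_score(q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # q: (num_query_tokens, 128), d: (num_doc_tokens, 128); both L2-normalized,
    # so the matmul yields cosine similarities.
    sim = q @ d.T
    # For each query token, keep its best-matching document token, then sum.
    return sim.max(dim=-1).values.sum()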

Source code in vllm/model_executor/models/colmodernvbert.py
@MULTIMODAL_REGISTRY.register_processor(
    ColModernVBertMultiModalProcessor,
    info=ColModernVBertProcessingInfo,
    dummy_inputs=ColModernVBertDummyInputsBuilder,
)
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
    """ColModernVBERT multimodal late-interaction retrieval model.

    Architecture:
        Image -> SiglipVisionModel -> ColModernVBertConnector
                                            ↓
        Text  -> ModernBertEmbeddings -> [merge] -> ModernBertLayers -> norm
                                                                         ↓
                                                 custom_text_proj -> L2 norm
                                                                         ↓
                                                  per-token 128-d embeddings
    """

    is_pooling_model = True
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: ColModernVBertConfig = vllm_config.model_config.hf_config
        self.config = config
        text_config = config.text_config
        quant_config = vllm_config.quant_config

        # --- Vision encoder (reuses SiglipVisionModel from siglip.py) ---
        self.vision_model = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        # --- Connector (pixel shuffle + linear projection) ---
        self.connector = ColModernVBertConnector(config)

        # --- Text encoder (built from ModernBERT components directly) ---
        # We build the components individually rather than wrapping
        # ``ModernBertModel`` because ``ModernBertEncoderLayer`` reads
        # ``vllm_config.model_config.hf_config`` which would be
        # ``ColModernVBertConfig``, not ``ModernBertConfig``.
        self.text_embeddings = ModernBertEmbeddings(text_config)
        self.text_layers = nn.ModuleList(
            [
                ModernBertLayer(
                    config=text_config,
                    layer_id=i,
                    prefix=f"{prefix}.text_layers.{i}",
                )
                for i in range(text_config.num_hidden_layers)
            ]
        )
        self.text_final_norm = nn.LayerNorm(
            text_config.hidden_size,
            eps=text_config.norm_eps,
            bias=text_config.norm_bias,
        )

        # --- ColBERT projection (768 -> 128, with bias) ---
        self.custom_text_proj = nn.Linear(
            text_config.hidden_size,
            config.embedding_dim,
            bias=True,
            dtype=vllm_config.model_config.head_dtype,
        )

        # --- Pooler (applies projection + L2 normalize) ---
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = pooler_for_token_embed(
            pooler_config,
            projector=self.custom_text_proj,
        )

    # ---- multimodal ---------------------------------------------------------

    def _get_image_features(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Idefics3ImageProcessor may return (batch, tiles, C, H, W);
        # flatten to (batch*tiles, C, H, W) for SiglipVisionModel.
        if pixel_values.dim() == 5:
            b, t, c, h, w = pixel_values.shape
            pixel_values = pixel_values.reshape(b * t, c, h, w)
        vision_outputs = self.vision_model(
            pixel_values.to(dtype=self.vision_model.dtype),
        )
        return self.connector(vision_outputs)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return []
        assert isinstance(pixel_values, torch.Tensor)
        image_features = self._get_image_features(pixel_values)
        return list(image_features)

    # ---- forward ------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        hidden_states = self.text_embeddings(input_ids, inputs_embeds=inputs_embeds)

        for layer in self.text_layers:
            hidden_states = layer(hidden_states, positions)

        return self.text_final_norm(hidden_states)

    # ---- weight loading -----------------------------------------------------

    # Checkpoint prefix → vLLM param prefix.
    # More-specific prefixes must appear before shorter ones.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.text_model.layers.": "text_layers.",
            "model.text_model.embeddings.": "text_embeddings.",
            "model.text_model.final_norm.": "text_final_norm.",
            "model.connector.modality_projection.": "connector.",
            "model.custom_text_proj.": "custom_text_proj.",
            "model.vision_model.": "vision_model.vision_model.",
            "model.": "",
        },
    )

    # Checkpoint names for DecoupledEmbedding parts
    _BASE_EMB = "model.text_model.embeddings.tok_embeddings.weight"
    _EXTRA_EMB = (
        "model.text_model.embeddings.tok_embeddings.additional_embedding.weight"
    )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        # DecoupledEmbedding requires concatenating base + additional
        # embedding tensors before loading, so we extract them first.
        base_embedding_weight: torch.Tensor | None = None
        additional_embedding_weight: torch.Tensor | None = None
        remaining: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            if name == self._BASE_EMB:
                base_embedding_weight = tensor
            elif name == self._EXTRA_EMB:
                additional_embedding_weight = tensor
            else:
                remaining.append((name, tensor))

        # Load all non-embedding weights via AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(
            remaining,
            mapper=self.hf_to_vllm_mapper,
        )

        # Concatenate and load DecoupledEmbedding weights
        if base_embedding_weight is not None:
            combined = base_embedding_weight
            if additional_embedding_weight is not None:
                combined = torch.cat(
                    [base_embedding_weight, additional_embedding_weight],
                    dim=0,
                )
            param_name = "text_embeddings.tok_embeddings.weight"
            params_dict = dict(self.named_parameters())
            if param_name in params_dict:
                param = params_dict[param_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, combined)
                loaded_params.add(param_name)
        elif additional_embedding_weight is not None:
            raise ValueError(
                "Found 'text_model.embeddings.tok_embeddings"
                ".additional_embedding.weight' but not "
                "'text_model.embeddings.tok_embeddings.weight'"
            )

        # The pooler wraps ``custom_text_proj`` as its head projector.
        # Mark those params as loaded under the pooler path too.
        if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
            head = self.pooler.head
            projector = getattr(head, "projector", None)
            if projector is not None and isinstance(projector, nn.Module):
                for pname, _ in projector.named_parameters():
                    loaded_params.add(f"pooler.head.projector.{pname}")

        return loaded_params
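The ordering note on hf_to_vllm_mapper matters because remapping stops at the first matching prefix; otherwise the trailing "model." catch-all would swallow every name. An illustrative first-match sketch with hypothetical checkpoint names (the real logic lives in WeightsMapper, and only a subset of the table is shown):

ORIG_TO_NEW = {
    "model.text_model.layers.": "text_layers.",
    "model.vision_model.": "vision_model.vision_model.",
    "model.": "",
}

def remap(name: str) -> str:
    for old, new in ORIG_TO_NEW.items():  # dicts preserve insertion order
        if name.startswith(old):
            return new + name[len(old):]
    return name

# The specific text-layer rule fires before the "model." catch-all:
assert remap("model.text_model.layers.3.mlp.Wi.weight") == "text_layers.3.mlp.Wi.weight"
# Vision weights gain the extra "vision_model." nesting used by SiglipVisionModel:
assert remap("model.vision_model.embeddings.patch_embedding.weight") == (
    "vision_model.vision_model.embeddings.patch_embedding.weight"
)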