Skip to content

vllm.model_executor.models.moondream3

Inference-only Moondream3 model implementation.

DetectPointState dataclass

Per-request state for detect/point generation.

Tracks the current step in the state machine, pending embedding information for the next decode step, the object currently being constructed, and the accumulated results.

Source code in vllm/model_executor/models/moondream3.py
@dataclass
class DetectPointState:
    """Per-request state for detect/point generation.

    Tracks the current step in the state machine, pending embedding
    information for the next decode step, the object currently being
    constructed, and the accumulated results.
    """

    # Task kind: "detect" yields bounding boxes, "point" yields (x, y) points.
    mode: Literal["detect", "point"]
    # Position in the decode state machine. Every object starts at
    # "decode_x_or_stop", where either a coordinate or EOS is produced next.
    step: Literal["decode_x_or_stop", "decode_y", "decode_size"] = "decode_x_or_stop"

    # Embedding info for the *next* decode step (set in compute_logits,
    # consumed in forward to replace the token embedding).
    pending_embed_type: str | None = None  # "coord" or "size"
    pending_embed_coord: float | None = None
    pending_embed_w: float | None = None
    pending_embed_h: float | None = None

    # Center of the object currently being constructed; None until the
    # corresponding decode step has produced a value.
    current_x: float | None = None
    current_y: float | None = None

    # Accumulated results: {"x", "y"} dicts for point mode, or
    # {"x_min", "y_min", "x_max", "y_max"} dicts for detect mode.
    objects: list[dict] = field(default_factory=list)
    # Set when EOS is sampled or max_objects is reached.
    finished: bool = False
    # Upper bound on the number of emitted objects.
    max_objects: int = _DEFAULT_MAX_OBJECTS

DetectPointStateManager

Manages per-request detect/point state for the model.

Source code in vllm/model_executor/models/moondream3.py
class DetectPointStateManager:
    """Manages per-request detect/point state for the model."""

    def __init__(
        self,
        coord_id: int = 5,
        size_id: int = 6,
        eos_id: int = 0,
    ) -> None:
        # Special token ids used by the detect/point decode loop.
        self.coord_id = coord_id
        self.size_id = size_id
        self.eos_id = eos_id
        # Active requests, keyed by request id.
        self._states: dict[str, DetectPointState] = {}

    def register_request(
        self,
        req_id: str,
        mode: Literal["detect", "point"],
        max_objects: int = _DEFAULT_MAX_OBJECTS,
    ) -> None:
        """Start tracking a new detect/point request."""
        self._states[req_id] = DetectPointState(mode=mode, max_objects=max_objects)

    def get_state(self, req_id: str) -> DetectPointState | None:
        """Return the state for *req_id*, or None if it is not tracked."""
        return self._states.get(req_id)

    def has_active_requests(self) -> bool:
        """True while any detect/point request is still being tracked."""
        return len(self._states) > 0

    def remove_request(self, req_id: str) -> None:
        """Forget a request; safe to call for unknown ids."""
        self._states.pop(req_id, None)

    def get_json_result(self, req_id: str) -> str | None:
        """Serialize the accumulated objects for a finished request."""
        state = self._states.get(req_id)
        if state is None:
            return None
        # Detect results are boxes ("objects"); point results are "points".
        key = "objects" if state.mode == "detect" else "points"
        return json.dumps({key: state.objects})

    def update_after_sample(self, req_id: str, sampled_token_id: int) -> None:
        """Transition state after a token is sampled."""
        state = self._states.get(req_id)
        if state is None:
            # Not a detect/point request; nothing to do.
            return

        if state.step == "decode_x_or_stop":
            # Either the request ends (EOS) or a coord token starts an object.
            if sampled_token_id == self.eos_id:
                state.finished = True
            else:
                state.step = "decode_y"
        elif state.step == "decode_y":
            if state.mode == "point":
                # A point is complete once (x, y) has been decoded.
                state.objects.append({"x": state.current_x, "y": state.current_y})
                state.step = "decode_x_or_stop"
            else:
                # Detection additionally decodes a size token next.
                state.step = "decode_size"
        elif state.step == "decode_size":
            # Convert center + size into corner coordinates.
            half_w = state.pending_embed_w / 2
            half_h = state.pending_embed_h / 2
            state.objects.append(
                {
                    "x_min": state.current_x - half_w,
                    "y_min": state.current_y - half_h,
                    "x_max": state.current_x + half_w,
                    "y_max": state.current_y + half_h,
                }
            )
            state.step = "decode_x_or_stop"

        # Stop generating once the object budget is exhausted.
        if len(state.objects) >= state.max_objects:
            state.finished = True

get_json_result

get_json_result(req_id: str) -> str | None

Serialize the accumulated objects for a finished request.

Source code in vllm/model_executor/models/moondream3.py
def get_json_result(self, req_id: str) -> str | None:
    """Serialize the accumulated objects for a finished request."""
    state = self._states.get(req_id)
    if state is None:
        return None
    if state.mode == "detect":
        return json.dumps({"objects": state.objects})
    return json.dumps({"points": state.objects})

update_after_sample

update_after_sample(
    req_id: str, sampled_token_id: int
) -> None

Transition state after a token is sampled.

Source code in vllm/model_executor/models/moondream3.py
def update_after_sample(self, req_id: str, sampled_token_id: int) -> None:
    """Transition state after a token is sampled.

    Advances the per-request state machine: an EOS token finishes the
    request, point mode emits an object after (x, y), and detect mode
    emits a bounding box after (x, y, size).
    """
    state = self._states.get(req_id)
    if state is None:
        # Not a tracked detect/point request; ignore.
        return

    if state.step == "decode_x_or_stop":
        # Either the request ends (EOS) or a coord token starts an object.
        if sampled_token_id == self.eos_id:
            state.finished = True
        else:
            state.step = "decode_y"
    elif state.step == "decode_y":
        if state.mode == "point":
            # A point is complete once (x, y) has been decoded.
            state.objects.append({"x": state.current_x, "y": state.current_y})
            state.step = "decode_x_or_stop"
        else:
            # Detection additionally decodes a size token next.
            state.step = "decode_size"
    elif state.step == "decode_size":
        # Convert center + size into corner coordinates.
        half_w = state.pending_embed_w / 2
        half_h = state.pending_embed_h / 2
        state.objects.append(
            {
                "x_min": state.current_x - half_w,
                "y_min": state.current_y - half_h,
                "x_max": state.current_x + half_w,
                "y_max": state.current_y + half_h,
            }
        )
        state.step = "decode_x_or_stop"

    # Stop generating once the object budget is exhausted.
    if len(state.objects) >= state.max_objects:
        state.finished = True

Moondream3Attention

Bases: Module

Decoder attention with RoPE and tau scaling.

Moondream3 uses a tau attention mechanism that scales Q and V based on both token content and position.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3Attention(nn.Module):
    """Decoder attention with RoPE and tau scaling.

    Moondream3 uses a tau attention mechanism that scales Q and V
    based on both token content and position.
    """

    def __init__(
        self,
        config: Moondream3TextConfig,
        layer_idx: int,
        cache_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.dim
        self.num_heads = config.n_heads
        self.num_kv_heads = config.n_kv_heads
        self.head_dim = config.dim // config.n_heads

        # Shard attention heads across tensor-parallel ranks. KV heads are
        # clamped to at least one per rank.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = self.num_heads // tp_size
        self.num_kv_heads_per_partition = max(1, self.num_kv_heads // tp_size)

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            total_num_kv_heads=self.num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.hidden_size,
            output_size=self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        # Moondream uses 32-dim rotation out of 64-dim head (partial_rotary_factor=0.5)
        # HF Moondream uses non-interleaved RoPE (split by half)
        # In vLLM, is_neox_style=True means split by half (GPT-NeoX style)
        rope_parameters = {
            "rope_theta": config.rope_theta,
            "partial_rotary_factor": 32 / self.head_dim,  # 32/64 = 0.5
        }
        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=config.max_context,
            rope_parameters=rope_parameters,
            is_neox_style=True,  # Moondream uses split-by-half (GPT-NeoX) style
        )

        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads_per_partition,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Tau scaling parameters for position-dependent attention
        # These are learned during training to modulate attention based on position
        # tau_wq and tau_wv need full qkv_dim for correct computation
        # Only heads are partitioned, qkv dimension is kept full for all-gather
        qkv_dim = self.hidden_size * 3  # Q + K + V dimension (full)
        self.tau_alpha = nn.Parameter(torch.zeros(self.num_heads_per_partition))
        self.tau_wq = nn.Parameter(torch.zeros(self.num_heads_per_partition, qkv_dim))
        self.tau_wv = nn.Parameter(torch.zeros(self.num_heads_per_partition, qkv_dim))
        self.tp_size = tp_size

        # Prefix-LM attention length: BOS (1) + vision patches (729) = 730
        self._prefix_attn_len = config.prefix_attn  # 730
        if self._prefix_attn_len > config.max_context:
            raise ValueError(
                "prefix_attn must be <= max_context, "
                f"got {self._prefix_attn_len} > {config.max_context}."
            )
        # Build once and slice per prefill call to avoid allocating an N x N
        # mask on every forward pass.
        prefill_mask = torch.tril(
            torch.ones(
                1,
                1,
                config.max_context,
                config.max_context,
                dtype=torch.bool,
            )
        )
        # Make the prefix region fully bidirectional (prefix-LM attention).
        prefill_mask[:, :, : self._prefix_attn_len, : self._prefix_attn_len] = True
        # Non-persistent: the mask is derived from config, not a checkpoint weight.
        self.register_buffer("_prefill_mask", prefill_mask, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run tau-scaled attention over *hidden_states*.

        Args:
            positions: per-token absolute position ids, shape [num_tokens].
            hidden_states: token activations, shape [num_tokens, hidden_size]
                (flattened across requests, per vLLM convention —
                presumably; confirm against the model runner).

        Returns:
            Attention output of shape [num_tokens, hidden_size].
        """
        qkv, _ = self.qkv_proj(hidden_states)

        # Split the fused projection into this rank's Q, K, V shards.
        q, k, v = qkv.split(
            [
                self.num_heads_per_partition * self.head_dim,
                self.num_kv_heads_per_partition * self.head_dim,
                self.num_kv_heads_per_partition * self.head_dim,
            ],
            dim=-1,
        )

        # Apply tau scaling to Q and V
        # Tau scaling has two components:
        # 1. Token-based: tok_q = tanh(gelu(qkv) @ tau_wq.T)
        # 2. Position-based: tau_pos = 1 + (sigmoid(alpha * log(pos+1)) - 0.5)
        # Final: tau = tok + tau_pos
        #
        # For TP, tau weights are sharded by head, but qkv_dim is kept full

        # Get full qkv for tau computation
        # With TP, reconstruct qkv in correct layout [q_full, k_full, v_full]
        # (all-gather would produce [q_0, k_0, v_0, q_1, k_1, v_1] - wrong)
        if self.tp_size > 1:
            from vllm.distributed import tensor_model_parallel_all_gather

            # All-gather once, then reconstruct [q_full, k_full, v_full].
            qkv_full_sharded = tensor_model_parallel_all_gather(qkv.contiguous())
            q_local_dim = q.shape[-1]
            kv_local_dim = k.shape[-1]
            # View as [tokens, ranks, local_qkv] so per-rank q/k/v slices can
            # be regrouped into contiguous full-q, full-k, full-v tensors.
            qkv_full_sharded = qkv_full_sharded.view(
                qkv.shape[0],
                self.tp_size,
                q_local_dim + 2 * kv_local_dim,
            )
            q_full = qkv_full_sharded[:, :, :q_local_dim].reshape(qkv.shape[0], -1)
            k_full = qkv_full_sharded[
                :, :, q_local_dim : q_local_dim + kv_local_dim
            ].reshape(qkv.shape[0], -1)
            v_full = qkv_full_sharded[:, :, q_local_dim + kv_local_dim :].reshape(
                qkv.shape[0], -1
            )
            qkv_full = torch.cat([q_full, k_full, v_full], dim=-1).contiguous()
        else:
            qkv_full = qkv

        # Compute tau scaling factors matching HF implementation exactly:
        # tok_feat = gelu(qkv)
        # tok_q = tanh(tok_feat @ tau_wq.T)  # [num_tokens, num_heads]
        # tau_pos = 1 + (sigmoid(alpha * log(pos+1)) - 0.5)  # [num_heads, num_tokens]
        # tau = (tok_q.T + tau_pos).T  # [num_tokens, num_heads]
        num_tokens = qkv_full.shape[0]
        orig_dtype = q.dtype

        # Token-based component
        tok_feat = F.gelu(qkv_full)  # Apply GELU activation
        tok_q = torch.tanh(tok_feat @ self.tau_wq.t())  # [N, H_per_partition]
        tok_v = torch.tanh(tok_feat @ self.tau_wv.t())  # [N, H_per_partition]

        # Position-based component
        # tau_pos = 1 + (sigmoid(alpha * log(pos+1)) - 0.5)
        # positions is [num_tokens], need to compute for each head
        # tau_alpha: [num_heads_per_partition]
        # clamp guards log() against non-positive values (e.g. position 0
        # after rounding in low-precision dtypes).
        pos_float = (positions.to(orig_dtype) + 1.0).clamp(min=1e-6)
        pos_log = pos_float.log()  # [num_tokens]
        # alpha[:, None] * pos_log[None, :] -> [num_heads, num_tokens]
        tau_pos = 1.0 + (
            torch.sigmoid(self.tau_alpha[:, None] * pos_log[None, :]) - 0.5
        )  # [H_per_partition, N]

        # Combine token and position components
        tau_q = (tok_q + tau_pos.t()).to(orig_dtype)  # [N, H_per_partition]
        tau_v = (tok_v + tau_pos.t()).to(orig_dtype)  # [N, H_per_partition]

        # Reshape q and v to apply per-head tau scaling
        q = q.view(num_tokens, self.num_heads_per_partition, self.head_dim)
        v = v.view(num_tokens, self.num_kv_heads_per_partition, self.head_dim)

        # Apply tau scaling
        # NOTE(review): tau_v has one entry per *attention* head but V only
        # has num_kv_heads_per_partition heads; only the leading entries are
        # used — confirm this matches the HF reference for GQA.
        q = q * tau_q.unsqueeze(-1)
        v = v * tau_v[:, : self.num_kv_heads_per_partition].unsqueeze(-1)

        # Reshape back
        q = q.view(num_tokens, -1)
        v = v.view(num_tokens, -1)

        # Partial RoPE (first 32 dims of each head, per rope_parameters above).
        q, k = self.rotary_emb(positions, q, k)

        # ---- SDPA prefill override ----
        # During prefill (num_tokens >= prefix_attn_len), use PyTorch SDPA
        # with an explicit bidirectional mask for the prefix-LM region.
        # This produces hidden states that closely match the HF reference
        # (which also uses SDPA for prefill).  Decode tokens and short
        # sequences fall through to the configured attention backend (e.g.
        # FLEX_ATTENTION) for KV-cache-aware decoding.
        H = self.num_heads_per_partition
        Hkv = self.num_kv_heads_per_partition
        D = self.head_dim
        P = self._prefix_attn_len  # 730
        # NOTE(review): this branch assumes the batch holds a single prefill
        # sequence laid out contiguously — verify for mixed prefill batches.

        if num_tokens > 1 and num_tokens >= P:
            # [N, H*D] -> [1, H, N, D] for SDPA's batch-first layout.
            q_4d = q.view(num_tokens, H, D).transpose(0, 1).unsqueeze(0)
            k_4d = k.view(num_tokens, Hkv, D).transpose(0, 1).unsqueeze(0)
            v_4d = v.view(num_tokens, Hkv, D).transpose(0, 1).unsqueeze(0)

            # Causal mask with bidirectional prefix region. Reuse the prebuilt
            # mask when possible; some profiling paths can exceed max_context.
            if num_tokens <= self._prefill_mask.shape[-1]:
                bool_mask = self._prefill_mask[:, :, :num_tokens, :num_tokens]
            else:
                bool_mask = torch.tril(
                    torch.ones(
                        1,
                        1,
                        num_tokens,
                        num_tokens,
                        dtype=torch.bool,
                        device=q.device,
                    )
                )
                bool_mask[:, :, :P, :P] = True

            attn_output = F.scaled_dot_product_attention(
                q_4d,
                k_4d,
                v_4d,
                attn_mask=bool_mask,
                scale=self.scaling,
            )
            # [1, H, N, D] -> [N, H*D]
            attn_output = (
                attn_output.squeeze(0)
                .transpose(0, 1)
                .contiguous()
                .view(num_tokens, H * D)
            )

            # Still call self.attn to populate the KV cache
            _ = self.attn(q, k, v)
        else:
            attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)
        return output

Moondream3Config dataclass

Combined configuration for Moondream3 model.

Source code in vllm/model_executor/models/moondream3.py
@dataclass
class Moondream3Config:
    """Combined configuration for Moondream3 model.

    Bundles the text, vision, and region sub-configurations parsed from
    a single raw config dict.
    """

    text: Moondream3TextConfig
    vision: Moondream3VisionConfig
    region: Moondream3RegionConfig

    @classmethod
    def from_dict(cls, d: dict) -> "Moondream3Config":
        """Parse all three sub-configs from one raw config dict."""
        text_cfg = Moondream3TextConfig.from_dict(d)
        vision_cfg = Moondream3VisionConfig.from_dict(d)
        region_cfg = Moondream3RegionConfig.from_dict(d)
        return cls(text=text_cfg, vision=vision_cfg, region=region_cfg)

Moondream3DecoderLayer

Bases: Module

Decoder layer with attention + MLP/MoE.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3DecoderLayer(nn.Module):
    """Decoder layer with attention + MLP/MoE.

    Uses a parallel residual block: attention and MLP both consume the
    same pre-normalized input and their outputs are summed.
    """

    def __init__(
        self,
        config: Moondream3TextConfig,
        cache_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        self.ln = nn.LayerNorm(config.dim, eps=1e-5, bias=True)

        self.attn = Moondream3Attention(
            config=config,
            layer_idx=layer_idx,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Layers from moe_start_layer onward are mixture-of-experts;
        # earlier layers use a dense MLP.
        is_moe_layer = layer_idx >= config.moe_start_layer
        if is_moe_layer:
            self.mlp = Moondream3TextMoE(
                hidden_size=config.dim,
                expert_inner_dim=config.moe_expert_inner_dim,
                num_experts=config.moe_num_experts,
                experts_per_token=config.moe_experts_per_token,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Moondream3TextMLP(
                hidden_size=config.dim,
                intermediate_size=config.ff_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Pre-norm parallel block: both branches read the same normalized
        # input; their outputs are added back to the residual stream.
        normed = self.ln(hidden_states)
        attn_branch = self.attn(positions, normed)
        mlp_branch = self.mlp(normed)
        return hidden_states + attn_branch + mlp_branch

Moondream3DummyInputsBuilder

Bases: BaseDummyInputsBuilder[Moondream3ProcessingInfo]

Dummy inputs builder for profiling.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3DummyInputsBuilder(BaseDummyInputsBuilder[Moondream3ProcessingInfo]):
    """Dummy inputs builder for profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return (
            "<|endoftext|><image><|md_reserved_0|>query<|md_reserved_1|>"
            "What is this image?<|md_reserved_2|>"
        )

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        return {
            "image": self._get_dummy_images(
                width=378,
                height=378,
                num_images=num_images,
            )
        }

Moondream3ForCausalLM

Bases: Module, SupportsMultiModal, SupportsPP

Moondream3 multimodal model for causal language modeling.

Moondream3 has four capabilities:

  • query: Visual QA.
  • caption: Image description.
  • detect: Object detection (bounding boxes).
  • point: Object pointing (x, y coordinates).

All four capabilities are supported. Query and caption use standard autoregressive generation. Detect and point use a custom state machine that intercepts compute_logits and forward to decode coordinates from hidden states and feed Fourier-encoded coordinate embeddings back as the next input.

Detect/point mode is activated by setting SamplingParams(extra_args={"moondream3_task": "detect"}) (or "point").

Source code in vllm/model_executor/models/moondream3.py
(Source listing spans lines 1419–2070 of vllm/model_executor/models/moondream3.py; the raw line-number gutter from the docs renderer has been omitted.)
@MULTIMODAL_REGISTRY.register_processor(
    Moondream3MultiModalProcessor,
    info=Moondream3ProcessingInfo,
    dummy_inputs=Moondream3DummyInputsBuilder,
)
class Moondream3ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Moondream3 multimodal model for causal language modeling.

    Moondream3 has four capabilities:

    - **query**: Visual QA.
    - **caption**: Image description.
    - **detect**: Object detection (bounding boxes).
    - **point**: Object pointing (x, y coordinates).

    All four capabilities are supported. Query and caption use standard
    autoregressive generation. Detect and point use a custom state machine
    that intercepts ``compute_logits`` and ``forward`` to decode
    coordinates from hidden states and feed Fourier-encoded coordinate
    embeddings back as the next input.

    Detect/point mode is activated by setting
    ``SamplingParams(extra_args={"moondream3_task": "detect"})``
    (or ``"point"``).
    """

    # Flags multimodal support to the runner.
    supports_multimodal = True
    # Checkpoint q/k/v projections load into the fused qkv_proj module.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build vision encoder, text decoder, LM head and region module.

        Args:
            vllm_config: Global vLLM configuration (model, quant, cache).
            prefix: Parameter-name prefix for weight loading.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        # Parse config from HuggingFace config
        # (moondream3 nests its real config under hf_config.config).
        config_dict = hf_config.config if hasattr(hf_config, "config") else {}

        self.config = Moondream3Config.from_dict(config_dict)

        with self._mark_tower_model(vllm_config, "image"):
            # Vision encoder
            self.vision = Moondream3VisionEncoder(
                config=self.config.vision,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision"),
            )

            # Vision projection
            self.vision_proj = Moondream3VisionProjection(
                input_dim=self.config.vision.enc_dim,
                inner_dim=self.config.vision.proj_inner_dim,
                output_dim=self.config.text.dim,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_proj"),
            )

        with self._mark_language_model(vllm_config):
            # Text decoder
            self.text = Moondream3TextModel(
                config=self.config.text,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "text"),
            )

            # LM head (with bias - Moondream3 has lm_head bias)
            self.lm_head = ParallelLMHead(
                self.config.text.vocab_size,
                self.config.text.dim,
                bias=True,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        # Region module for point/detect coordinate encoding/decoding.
        self.region = Moondream3RegionModule(
            config=self.config.region,
            prefix=maybe_prefix(prefix, "region"),
        )
        self.logits_processor = LogitsProcessor(self.config.text.vocab_size)
        self.make_empty_intermediate_tensors = self.text.make_empty_intermediate_tensors

        # Detect/point token IDs (from config, with defaults).
        self._coord_id = getattr(hf_config, "coord_token_id", 5)
        self._size_id = getattr(hf_config, "size_token_id", 6)
        self._eos_id = getattr(hf_config, "region_eos_token_id", 0)

        # Detect/point state management and per-step scratch buffers.
        # These buffers are populated by model hooks invoked by the v1 runner.
        self.detect_point_manager = DetectPointStateManager(
            coord_id=self._coord_id,
            size_id=self._size_id,
            eos_id=self._eos_id,
        )
        # Per-row states snapshot (set in on_before_compute_logits).
        self._dp_row_states: list[DetectPointState | None] | None = None
        # Position -> embed info map (set in on_before_model_forward,
        # consumed once in forward).
        self._dp_embed_data: dict[int, dict] | None = None

    def on_new_request(
        self,
        *,
        req_id: str,
        sampling_params: object | None,
    ) -> None:
        """Register detect/point requests from per-request extra args.

        Reads ``moondream3_task`` ("detect"/"point") and the optional
        ``moondream3_max_objects`` cap from the request's extra args and,
        when present, registers the request with the state manager.

        Raises:
            ValueError: If ``moondream3_max_objects`` is not a valid
                integer >= 1.
        """
        if sampling_params is None:
            return
        extra_args = getattr(sampling_params, "extra_args", None) or {}
        task = extra_args.get("moondream3_task")
        if task not in ("detect", "point"):
            # Plain text request — nothing to register.
            return
        mode: Literal["detect", "point"] = task

        raw_max = extra_args.get("moondream3_max_objects", _DEFAULT_MAX_OBJECTS)
        try:
            max_objects = int(raw_max)
        except (TypeError, ValueError):
            raise ValueError(
                "moondream3_max_objects must be an integer, "
                f"got {type(raw_max).__name__}"
            ) from None
        if max_objects < 1:
            raise ValueError(
                f"moondream3_max_objects must be >= 1, got {max_objects}"
            )
        self.detect_point_manager.register_request(req_id, mode, max_objects)

    def on_requests_finished(self, req_ids: Iterable[str]) -> None:
        """Drop detect/point state for every finished request id."""
        remove = self.detect_point_manager.remove_request
        for finished_id in req_ids:
            remove(finished_id)

    def on_before_model_forward(
        self,
        *,
        req_ids: list[str],
        logits_indices: torch.Tensor,
        device: torch.device,
    ) -> None:
        """Prepare pending coordinate/size embed replacements for forward.

        Builds a {token position -> embed info} map consumed by ``forward``
        to overwrite input embeddings for detect/point decode steps. Under
        pipeline parallelism the pending state is first broadcast from the
        last rank so every rank patches identically.
        """
        manager = self.detect_point_manager
        if not manager.has_active_requests():
            self._dp_embed_data = None
            return

        pp_group = get_pp_group()
        if pp_group.world_size > 1:
            self._sync_dp_pending_embeds(
                req_ids=req_ids, device=device, pp=pp_group
            )

        embed_data: dict[int, dict[str, Any]] = {}
        for row, req_id in enumerate(req_ids):
            state = manager.get_state(req_id)
            if state is None or state.pending_embed_type is None:
                continue
            position = int(logits_indices[row].item())
            kind = state.pending_embed_type
            if kind == "coord":
                embed_data[position] = {
                    "type": "coord",
                    "value": state.pending_embed_coord,
                }
            elif kind == "size":
                embed_data[position] = {
                    "type": "size",
                    "w": state.pending_embed_w,
                    "h": state.pending_embed_h,
                }
        self._dp_embed_data = embed_data if embed_data else None

    def on_before_compute_logits(self, *, req_ids: list[str]) -> None:
        """Snapshot per-row detect/point states for ``compute_logits``."""
        manager = self.detect_point_manager
        if manager.has_active_requests():
            self._dp_row_states = [manager.get_state(rid) for rid in req_ids]
        else:
            self._dp_row_states = None

    def on_after_sample(
        self,
        *,
        req_ids: list[str],
        sampled_token_ids: torch.Tensor,
    ) -> None:
        """Advance detect/point state machines with the sampled tokens."""
        manager = self.detect_point_manager
        if not manager.has_active_requests():
            return
        for row, req_id in enumerate(req_ids):
            if manager.get_state(req_id) is None:
                # Plain text request — nothing to advance.
                continue
            token = int(sampled_token_ids[row, 0].item())
            manager.update_after_sample(req_id, token)

    def get_per_request_extra_output(
        self,
        *,
        req_ids: list[str],
    ) -> dict[str, dict[str, Any]] | None:
        """Return per-request final text overrides for active requests.

        Only the last PP rank produces output; other ranks return None.
        """
        manager = self.detect_point_manager
        if not (get_pp_group().is_last_rank and manager.has_active_requests()):
            return None

        overrides: dict[str, dict[str, Any]] = {}
        for req_id in req_ids:
            if manager.get_state(req_id) is None:
                continue
            json_result = manager.get_json_result(req_id)
            if json_result is not None:
                overrides[req_id] = {"text_override": json_result}
        return overrides if overrides else None

    def _sync_dp_pending_embeds(
        self,
        *,
        req_ids: list[str],
        device: torch.device,
        pp: Any,
    ) -> None:
        """Broadcast detect/point pending embed data across PP ranks.

        Row layout of the sync tensor: ``[kind, a, b]`` where kind is
        0=no pending embed, 1=coord (a=coordinate), 2=size (a=w, b=h).
        """
        sync = torch.zeros(len(req_ids), 3, device=device, dtype=torch.float32)

        manager = self.detect_point_manager
        if pp.is_last_rank:
            # Serialize the authoritative state held on the last rank.
            for row, req_id in enumerate(req_ids):
                state = manager.get_state(req_id)
                if state is None or state.pending_embed_type is None:
                    continue
                kind = state.pending_embed_type
                if kind == "coord":
                    sync[row, 0] = 1.0
                    sync[row, 1] = state.pending_embed_coord or 0.0
                elif kind == "size":
                    sync[row, 0] = 2.0
                    sync[row, 1] = state.pending_embed_w or 0.0
                    sync[row, 2] = state.pending_embed_h or 0.0

        torch.distributed.broadcast(
            sync,
            src=pp.last_rank,
            group=pp.device_group,
        )

        if pp.is_last_rank:
            return

        # Deserialize onto every non-last rank.
        for row, req_id in enumerate(req_ids):
            state = manager.get_state(req_id)
            if state is None:
                continue
            kind = int(sync[row, 0].item())
            if kind == 0:
                state.pending_embed_type = None
            elif kind == 1:
                state.pending_embed_type = "coord"
                state.pending_embed_coord = sync[row, 1].item()
            elif kind == 2:
                state.pending_embed_type = "size"
                state.pending_embed_w = sync[row, 1].item()
                state.pending_embed_h = sync[row, 2].item()

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for *modality* (images only)."""
        return "<image>" if modality == "image" else None

    def get_language_model(self) -> nn.Module:
        """Expose the text decoder backbone."""
        return self.text

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Encoder token count equals the image token count (identity)."""
        return num_image_tokens

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Connector token count equals the vision token count (identity)."""
        return num_vision_tokens

    def _split_pixel_values(
        self,
        pixel_values: object,
    ) -> list[torch.Tensor]:
        """Normalize ``pixel_values`` into a list of 4D crop tensors.

        Accepts a 3D/4D/5D tensor or a sequence of 3D/4D tensors
        (or array-likes); everything else raises.

        Raises:
            ValueError: On unsupported tensor ranks.
            TypeError: On unsupported container types.
        """
        if isinstance(pixel_values, torch.Tensor):
            ndim = pixel_values.dim()
            if ndim == 5:
                # One 4D crop stack per image.
                return [crops.contiguous() for crops in pixel_values]
            if ndim == 4:
                return [pixel_values.contiguous()]
            if ndim == 3:
                # Single crop: add the crop dimension.
                return [pixel_values.unsqueeze(0).contiguous()]
            raise ValueError(
                f"Unsupported pixel_values shape {tuple(pixel_values.shape)}."
            )

        if isinstance(pixel_values, (list, tuple)):
            normalized: list[torch.Tensor] = []
            for entry in pixel_values:
                tensor = (
                    entry if isinstance(entry, torch.Tensor) else torch.as_tensor(entry)
                )
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)
                elif tensor.dim() != 4:
                    raise ValueError(
                        f"Unsupported pixel_values element shape "
                        f"{tuple(tensor.shape)}."
                    )
                normalized.append(tensor.contiguous())
            return normalized

        raise TypeError(
            "pixel_values must be a tensor or a sequence of tensors, "
            f"got {type(pixel_values)!r}."
        )

    def _split_tilings(
        self,
        tilings: object,
        expected: int,
    ) -> list[tuple[int, int] | None]:
        """Normalize ``tilings`` into ``expected`` int pairs (or None).

        Raises:
            TypeError: If ``tilings`` is not None/tensor/sequence.
            ValueError: On a count mismatch or a malformed entry.
        """
        if tilings is None:
            return [None] * expected

        if isinstance(tilings, torch.Tensor):
            items = tilings.tolist()
        elif isinstance(tilings, (list, tuple)):
            items = list(tilings)
        else:
            raise TypeError(
                "tilings must be None, a tensor or a sequence of tuples, "
                f"got {type(tilings)!r}."
            )

        if len(items) != expected:
            raise ValueError(
                "Mismatch between the number of pixel_values entries "
                f"({expected}) and tilings ({len(items)})."
            )

        result: list[tuple[int, int] | None] = []
        for entry in items:
            if entry is None:
                result.append(None)
                continue
            if isinstance(entry, torch.Tensor):
                entry = entry.tolist()
            if not (isinstance(entry, (list, tuple)) and len(entry) == 2):
                raise ValueError(
                    f"Each tiling entry must be a pair of integers, got {entry!r}."
                )
            result.append((int(entry[0]), int(entry[1])))
        return result

    def _parse_image_inputs(self, **kwargs: object) -> list[Moondream3ImageInput]:
        """Build one Moondream3ImageInput per image from multimodal kwargs."""
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return []

        crops_list = self._split_pixel_values(pixel_values)
        tilings = self._split_tilings(kwargs.get("tilings"), len(crops_list))

        inputs: list[Moondream3ImageInput] = []
        for crops, tiling in zip(crops_list, tilings):
            if crops.dim() != 4:
                raise ValueError(
                    f"Expected 4D tensor for crops, got {tuple(crops.shape)}."
                )
            inputs.append(Moondream3ImageInput(pixel_values=crops, tiling=tiling))
        return inputs

    def _encode_image_input(self, image_input: Moondream3ImageInput) -> torch.Tensor:
        """Encode one image's crops into projected vision embeddings.

        Runs the vision encoder over all crops, reconstructs the local
        features from tiled crops (or uses the single global crop),
        pools them to a fixed grid, and projects the concatenated
        global+local features into the text embedding space.
        """
        pixel_values = image_input.pixel_values
        if pixel_values.dim() != 4:
            raise ValueError(
                f"Expected 4D tensor for crops, got {tuple(pixel_values.shape)}."
            )

        # Match the encoder's parameter device/dtype before running it.
        device = self.vision.patch_emb.weight.device
        dtype = self.vision.patch_emb.weight.dtype
        pixel_values = pixel_values.to(device=device, dtype=dtype)

        features = self.vision(pixel_values)

        # Grid size = crop_size / patch_size (e.g., 378 / 14 = 27)
        grid_size = self.config.vision.crop_size // self.config.vision.enc_patch_size
        enc_dim = self.config.vision.enc_dim
        # Crop 0 is the global (whole-image) crop; the rest are local tiles.
        global_features = features[0]

        if features.shape[0] > 1:
            if image_input.tiling is None:
                raise ValueError(
                    "Missing tiling metadata for multi-crop Moondream image."
                )
            # Reassemble the per-tile feature grids into one spatial map.
            local = features[1:].contiguous().view(-1, grid_size, grid_size, enc_dim)
            reconstructed = reconstruct_from_crops(
                local,
                image_input.tiling,
                overlap_margin=self.config.vision.overlap_margin,
                patch_size=1,
            )
        else:
            reconstructed = global_features.view(grid_size, grid_size, enc_dim)

        # To channels-first for adaptive_avg_pool2d.
        recon = reconstructed.permute(2, 0, 1).contiguous()
        # Mirror HF reference behavior: reconstructed local features are pooled
        # to enc_n_layers x enc_n_layers. For moondream3-preview this is 27x27.
        pooled_size = self.config.vision.enc_n_layers
        if pooled_size != grid_size:
            logger.warning_once(
                "Moondream3 pooled_size (%d) differs from crop grid (%d). "
                "Using enc_n_layers to match HF reference behavior.",
                pooled_size,
                grid_size,
            )
        recon = F.adaptive_avg_pool2d(recon, output_size=(pooled_size, pooled_size))
        # Back to [tokens, enc_dim] token layout.
        recon = recon.permute(1, 2, 0).contiguous().view(-1, enc_dim)

        # Concatenate global and pooled local features per token, then project
        # to the text model dimension.
        combined = torch.cat([global_features, recon], dim=-1).unsqueeze(0)
        projected = self.vision_proj(combined).squeeze(0)

        # Note: Vision embeddings are already synchronized across TP ranks
        # because the vision projection uses RowParallelLinear which performs
        # all-reduce internally, ensuring identical outputs on all ranks.

        return projected

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Generate 729 vision embeddings per image (27x27 patches)."""
        parsed = self._parse_image_inputs(**kwargs)
        # An empty kwargs set yields an empty embedding list.
        return [self._encode_image_input(entry) for entry in parsed]

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text decoder, first patching detect/point embeddings.

        For detect/point decode positions, the token embedding is replaced
        in place with the region module's Fourier-encoded coordinate or
        size embedding prepared by ``on_before_model_forward``.
        """
        pending = self._dp_embed_data
        if pending and inputs_embeds is not None:
            num_tokens = inputs_embeds.shape[0]
            for position, info in pending.items():
                if position >= num_tokens:
                    # Position is not part of this forward batch.
                    continue
                kind = info["type"]
                if kind == "coord":
                    coord = torch.tensor(
                        [[info["value"]]],
                        device=inputs_embeds.device,
                        dtype=inputs_embeds.dtype,
                    )
                    # encode_coordinate returns [1, dim].
                    inputs_embeds[position] = self.region.encode_coordinate(
                        coord
                    ).squeeze(0)
                elif kind == "size":
                    # encode_size returns [1, dim].
                    inputs_embeds[position] = self.region.encode_size(
                        info["w"],
                        info["h"],
                        device=inputs_embeds.device,
                        dtype=inputs_embeds.dtype,
                    ).squeeze(0)
            # One-shot buffer: clear so stale embeds never leak into the
            # next step.
            self._dp_embed_data = None

        return self.text(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits, overriding rows owned by detect/point requests.

        For detect/point rows the vocab logits are forced to a single
        token (coord/size/EOS) so sampling is deterministic, while the
        actual coordinates are decoded from the row's hidden state via
        the region module and stashed on the request state for the next
        step's embedding replacement.
        """
        row_states = self._dp_row_states
        if row_states is None or not any(s is not None for s in row_states):
            # No detect/point requests — standard path.
            return self.logits_processor(self.lm_head, hidden_states)

        # Compute standard logits for ALL rows (needed for text requests
        # and for the continue/stop sparse check in decode_x_or_stop).
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            # Non-driver rank (e.g. TP rank without logits) — nothing to do.
            return logits

        for i, state in enumerate(row_states):
            if state is None:
                continue
            h = hidden_states[i : i + 1]  # [1, dim]

            if state.step == "decode_x_or_stop":
                # Check continue/stop via sparse lm_head logits.
                coord_score = logits[i, self._coord_id].item()
                eos_score = logits[i, self._eos_id].item()

                if state.finished or eos_score > coord_score:
                    # Model wants to stop — force EOS.
                    logits[i] = float("-inf")
                    logits[i, self._eos_id] = 0.0
                else:
                    # Decode x coordinate from the *same* hidden state.
                    coord_logits = self.region.decode_coordinate(h)
                    # Normalize the argmax bin index to [0, 1).
                    x_bin = torch.argmax(coord_logits, dim=-1)
                    x_val = x_bin.float().item() / coord_logits.shape[-1]
                    state.current_x = x_val
                    state.pending_embed_type = "coord"
                    state.pending_embed_coord = x_val
                    # Force COORD_ID token.
                    logits[i] = float("-inf")
                    logits[i, self._coord_id] = 0.0

            elif state.step == "decode_y":
                # Decode y the same way and force another COORD_ID token.
                coord_logits = self.region.decode_coordinate(h)
                y_bin = torch.argmax(coord_logits, dim=-1)
                y_val = y_bin.float().item() / coord_logits.shape[-1]
                state.current_y = y_val
                state.pending_embed_type = "coord"
                state.pending_embed_coord = y_val
                logits[i] = float("-inf")
                logits[i, self._coord_id] = 0.0

            elif state.step == "decode_size":
                # Detect mode only: decode box width/height, force SIZE_ID.
                w, h_val = self.region.decode_size(h)
                state.pending_embed_type = "size"
                state.pending_embed_w = w
                state.pending_embed_h = h_val
                logits[i] = float("-inf")
                logits[i, self._size_id] = 0.0

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with remapping from HuggingFace format.

        Renames HF parameter paths to this module's layout, shards MoE
        expert weights and tau-attention weights across tensor-parallel
        ranks, and loads each tensor via the parameter's ``weight_loader``
        (falling back to the default loader).

        Returns:
            The set of parameter names that were loaded.
        """
        # Hoisted: the original re-imported this inside the loop at four
        # separate call sites and re-queried the TP rank/size each time.
        from vllm.distributed import get_tensor_model_parallel_rank

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        def shard_rows(weight: torch.Tensor) -> torch.Tensor:
            """Keep this rank's contiguous slice along dim 0 (experts/heads)."""
            per_rank = weight.shape[0] // tp_size
            start = tp_rank * per_rank
            return weight[start : start + per_rank].contiguous()

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Map from HF naming to vLLM naming:
            #   model.vision.* -> vision.*, model.text.* -> text.*
            if name.startswith("model."):
                name = name[len("model.") :]

            # Vision projection: vision.proj_mlp.fc1 -> vision_proj.fc1
            name = name.replace("vision.proj_mlp.", "vision_proj.")

            # Text embedding: text.wte (no suffix) -> text.wte.weight
            if name == "text.wte":
                name = "text.wte.weight"

            # LM head: text.lm_head -> lm_head
            name = name.replace("text.lm_head.", "lm_head.")

            # Attention mapping
            name = name.replace(".attn.qkv.", ".attn.qkv_proj.")
            name = name.replace(".attn.proj.", ".attn.out_proj.")

            # Tau attention scaling weights
            # HF format: .attn.tau.alpha -> .attn.tau_alpha
            name = name.replace(".attn.tau.alpha", ".attn.tau_alpha")
            name = name.replace(".attn.tau.wq", ".attn.tau_wq")
            name = name.replace(".attn.tau.wv", ".attn.tau_wv")

            # MoE router mapping: mlp.router -> mlp.gate
            name = name.replace(".mlp.router.", ".mlp.gate.")

            # MoE expert weights (only 3D weights are MoE; 2D are plain MLP):
            #   fc1.weight: [n_experts, expert_inner_dim * 2, hidden] (gate+up)
            #   fc2.weight: [n_experts, hidden, expert_inner_dim] (down)
            # Each rank stores n_experts/tp_size experts; map to the custom
            # MoE attribute names fc1_weight / fc2_weight.
            if ".mlp.fc1.weight" in name and loaded_weight.dim() == 3:
                loaded_weight = shard_rows(loaded_weight)
                name = name.replace(".mlp.fc1.weight", ".mlp.fc1_weight")
            if ".mlp.fc2.weight" in name and loaded_weight.dim() == 3:
                loaded_weight = shard_rows(loaded_weight)
                name = name.replace(".mlp.fc2.weight", ".mlp.fc2_weight")

            # Tau weights are sharded by head along dim 0:
            #   tau_alpha: [num_heads] -> [num_heads/tp]
            #   tau_wq/tau_wv: [num_heads, qkv_dim] -> [num_heads/tp, qkv_dim]
            #     (full qkv_dim kept for the all-gather)
            if ".tau_alpha" in name or ".tau_wq" in name or ".tau_wv" in name:
                loaded_weight = shard_rows(loaded_weight)

            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params

_sync_dp_pending_embeds

_sync_dp_pending_embeds(
    *, req_ids: list[str], device: device, pp: Any
) -> None

Broadcast detect/point pending embed data across PP ranks.

Source code in vllm/model_executor/models/moondream3.py
def _sync_dp_pending_embeds(
    self,
    *,
    req_ids: list[str],
    device: torch.device,
    pp: Any,
) -> None:
    """Broadcast detect/point pending embed data across PP ranks."""
    num_reqs = len(req_ids)
    sync = torch.zeros(
        num_reqs,
        3,
        device=device,
        dtype=torch.float32,
    )

    dp_mgr = self.detect_point_manager
    if pp.is_last_rank:
        for i, req_id in enumerate(req_ids):
            st = dp_mgr.get_state(req_id)
            if st is None or st.pending_embed_type is None:
                continue
            if st.pending_embed_type == "coord":
                sync[i, 0] = 1.0
                sync[i, 1] = st.pending_embed_coord or 0.0
            elif st.pending_embed_type == "size":
                sync[i, 0] = 2.0
                sync[i, 1] = st.pending_embed_w or 0.0
                sync[i, 2] = st.pending_embed_h or 0.0

    torch.distributed.broadcast(
        sync,
        src=pp.last_rank,
        group=pp.device_group,
    )

    if not pp.is_last_rank:
        for i, req_id in enumerate(req_ids):
            st = dp_mgr.get_state(req_id)
            if st is None:
                continue
            state_type = int(sync[i, 0].item())
            if state_type == 0:
                st.pending_embed_type = None
            elif state_type == 1:
                st.pending_embed_type = "coord"
                st.pending_embed_coord = sync[i, 1].item()
            elif state_type == 2:
                st.pending_embed_type = "size"
                st.pending_embed_w = sync[i, 1].item()
                st.pending_embed_h = sync[i, 2].item()

embed_multimodal

embed_multimodal(**kwargs: object) -> MultiModalEmbeddings

Generate 729 vision embeddings per image (27x27 patches).

Source code in vllm/model_executor/models/moondream3.py
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Generate 729 vision embeddings per image (27x27 patches)."""
    image_inputs = self._parse_image_inputs(**kwargs)
    if not image_inputs:
        return []

    return [self._encode_image_input(image_input) for image_input in image_inputs]

get_per_request_extra_output

get_per_request_extra_output(
    *, req_ids: list[str]
) -> dict[str, dict[str, Any]] | None

Return per-request final text overrides for active requests.

Source code in vllm/model_executor/models/moondream3.py
def get_per_request_extra_output(
    self,
    *,
    req_ids: list[str],
) -> dict[str, dict[str, Any]] | None:
    """Return per-request final text overrides for active requests."""
    dp_mgr = self.detect_point_manager
    if not get_pp_group().is_last_rank or not dp_mgr.has_active_requests():
        return None

    per_request_extra: dict[str, dict[str, Any]] = {}
    for req_id in req_ids:
        state = dp_mgr.get_state(req_id)
        if state is None:
            continue
        json_str = dp_mgr.get_json_result(req_id)
        if json_str is not None:
            per_request_extra[req_id] = {"text_override": json_str}
    return per_request_extra or None

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights with remapping from HuggingFace format.

Source code in vllm/model_executor/models/moondream3.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, remapping HuggingFace names to vLLM names.

    Per-weight remappings:
      * strip the leading ``model.`` prefix
      * ``vision.proj_mlp.*`` -> ``vision_proj.*``
      * ``text.wte`` (suffix-less embedding) -> ``text.wte.weight``
      * ``text.lm_head.*`` -> ``lm_head.*``
      * attention / tau / MoE-router renames (see inline comments)

    3-D MoE expert tensors and tau attention tensors are additionally
    sharded along dim 0 (experts or heads) so each TP rank keeps only
    its own slice.

    Returns:
        The set of parameter names that were actually loaded on this rank.
    """

    def _shard_dim0(w: torch.Tensor) -> torch.Tensor:
        """Keep this rank's contiguous slice of dim 0 (experts or heads)."""
        # Imported lazily so the TP group is only consulted when a
        # sharded tensor is actually encountered (matches the original
        # in-branch import behavior).
        from vllm.distributed import get_tensor_model_parallel_rank

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        per_rank = w.shape[0] // tp_size
        start = tp_rank * per_rank
        return w[start : start + per_rank].contiguous()

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        # model.vision.* -> vision.*, model.text.* -> text.*
        if name.startswith("model."):
            name = name[len("model.") :]

        # Vision projection: vision.proj_mlp.fc1 -> vision_proj.fc1
        name = name.replace("vision.proj_mlp.", "vision_proj.")

        # Text embedding is stored without the .weight suffix.
        if name == "text.wte":
            name = "text.wte.weight"

        # LM head: text.lm_head -> lm_head
        name = name.replace("text.lm_head.", "lm_head.")

        # Attention projections.
        name = name.replace(".attn.qkv.", ".attn.qkv_proj.")
        name = name.replace(".attn.proj.", ".attn.out_proj.")

        # Tau attention scaling weights: .attn.tau.<x> -> .attn.tau_<x>
        name = name.replace(".attn.tau.alpha", ".attn.tau_alpha")
        name = name.replace(".attn.tau.wq", ".attn.tau_wq")
        name = name.replace(".attn.tau.wv", ".attn.tau_wv")

        # MoE router mapping: mlp.router -> mlp.gate
        name = name.replace(".mlp.router.", ".mlp.gate.")

        # MoE expert weights (3-D only; 2-D fc1/fc2 belong to the dense MLP):
        # fc1.weight: [n_experts, expert_inner_dim * 2, hidden_size] (gate+up)
        # fc2.weight: [n_experts, hidden_size, expert_inner_dim] (down)
        # Shard by expert and rename to the custom fused-MoE parameter.
        if ".mlp.fc1.weight" in name and loaded_weight.dim() == 3:
            loaded_weight = _shard_dim0(loaded_weight)
            name = name.replace(".mlp.fc1.weight", ".mlp.fc1_weight")

        if ".mlp.fc2.weight" in name and loaded_weight.dim() == 3:
            loaded_weight = _shard_dim0(loaded_weight)
            name = name.replace(".mlp.fc2.weight", ".mlp.fc2_weight")

        # Tau tensors place heads on dim 0:
        # tau_alpha: [num_heads]; tau_wq/tau_wv: [num_heads, qkv_dim].
        # Shard heads across TP ranks; qkv_dim stays full for the
        # all-gather in the attention path.
        if ".tau_alpha" in name or ".tau_wq" in name or ".tau_wv" in name:
            loaded_weight = _shard_dim0(loaded_weight)

        # Tensors with no matching local parameter are skipped silently
        # (e.g. weights owned by another pipeline rank).
        if name in params_dict:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

    return loaded_params

on_before_model_forward

on_before_model_forward(
    *,
    req_ids: list[str],
    logits_indices: Tensor,
    device: device,
) -> None

Prepare pending coordinate/size embed replacements for forward.

Source code in vllm/model_executor/models/moondream3.py
def on_before_model_forward(
    self,
    *,
    req_ids: list[str],
    logits_indices: torch.Tensor,
    device: torch.device,
) -> None:
    """Stage pending coordinate/size embedding replacements for forward.

    Builds ``self._dp_embed_data`` mapping token position -> embed spec
    for every request with a pending coord/size embed; sets it to None
    when nothing is pending.
    """
    manager = self.detect_point_manager
    if not manager.has_active_requests():
        self._dp_embed_data = None
        return

    # With pipeline parallelism, pending-embed state must be broadcast
    # before it can be consumed here.
    pp = get_pp_group()
    if pp.world_size > 1:
        self._sync_dp_pending_embeds(req_ids=req_ids, device=device, pp=pp)

    pending: dict[int, dict[str, Any]] = {}
    for idx, request_id in enumerate(req_ids):
        state = manager.get_state(request_id)
        if state is None:
            continue
        embed_type = state.pending_embed_type
        if embed_type is None:
            continue
        token_pos = int(logits_indices[idx].item())
        if embed_type == "coord":
            pending[token_pos] = {
                "type": "coord",
                "value": state.pending_embed_coord,
            }
        elif embed_type == "size":
            pending[token_pos] = {
                "type": "size",
                "w": state.pending_embed_w,
                "h": state.pending_embed_h,
            }
    self._dp_embed_data = pending if pending else None

on_new_request

on_new_request(
    *, req_id: str, sampling_params: object | None
) -> None

Register detect/point requests from per-request extra args.

Source code in vllm/model_executor/models/moondream3.py
def on_new_request(
    self,
    *,
    req_id: str,
    sampling_params: object | None,
) -> None:
    """Register detect/point requests from per-request extra args."""
    if sampling_params is None:
        return
    extra = getattr(sampling_params, "extra_args", None) or {}
    dp_task = extra.get("moondream3_task")
    if dp_task == "detect":
        mode: Literal["detect", "point"] = "detect"
    elif dp_task == "point":
        mode = "point"
    else:
        return

    raw_max = extra.get("moondream3_max_objects", _DEFAULT_MAX_OBJECTS)
    try:
        max_obj = int(raw_max)
    except (TypeError, ValueError):
        raise ValueError(
            "moondream3_max_objects must be an integer, "
            f"got {type(raw_max).__name__}"
        ) from None
    if max_obj < 1:
        raise ValueError(f"moondream3_max_objects must be >= 1, got {max_obj}")
    self.detect_point_manager.register_request(req_id, mode, max_obj)

Moondream3ImageInput dataclass

Container holding per-image inputs for embedding.

Source code in vllm/model_executor/models/moondream3.py
@dataclass(frozen=True)
class Moondream3ImageInput:
    """Container holding per-image inputs for embedding.

    Immutable so an instance can be passed around the encode path safely.
    """

    # Preprocessed pixel tensor consumed by the vision encoder.
    pixel_values: torch.Tensor
    # Optional crop-tiling pair, or None when no tiling applies.
    # NOTE(review): presumably (rows, cols) — confirm against the processor.
    tiling: tuple[int, int] | None

Moondream3MultiModalProcessor

Bases: BaseMultiModalProcessor[Moondream3ProcessingInfo]

Multimodal processor for Moondream3.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3MultiModalProcessor(BaseMultiModalProcessor[Moondream3ProcessingInfo]):
    """Multimodal processor for Moondream3."""

    # Placeholder string marking image positions in the prompt text.
    image_placeholder: str = "<image>"

    @cached_property
    def image_placeholder_tokens(self) -> list[int]:
        """Token ids the tokenizer produces for the placeholder string."""
        encoded = self.info.get_tokenizer().encode(
            self.image_placeholder,
            add_special_tokens=False,
        )
        if not encoded:
            raise ValueError(
                f"Tokenizer could not encode placeholder {self.image_placeholder!r}."
            )
        return encoded

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how multimodal tensors are batched per image."""
        field_config = {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            # Tiling metadata stays on CPU.
            "tilings": MultiModalFieldConfig.batched("image", keep_on_cpu=True),
        }
        return field_config

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # The Moondream3 HF processor does NOT expand the placeholder, so
        # vLLM must apply the prompt updates itself (<image> -> 729 tokens).
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        """Expand each placeholder into the fixed-size vision-token span."""
        max_size = self.info.get_image_size_with_most_features()
        span_len = self.info.get_num_image_tokens(
            image_width=max_size.width,
            image_height=max_size.height,
        )
        placeholder = self.image_placeholder_tokens
        return [
            PromptReplacement(
                modality="image",
                target=placeholder,
                # Repeat the first placeholder token to fill the span.
                replacement=[placeholder[0]] * span_len,
            ),
        ]

Moondream3ProcessingInfo

Bases: BaseProcessingInfo

Processing info for Moondream3.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3ProcessingInfo(BaseProcessingInfo):
    """Processing info for Moondream3."""

    def get_hf_config(self):
        """Return the HuggingFace model config."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        """Return the custom Moondream3 HF processor."""
        # Local import: resolved only when a processor is requested.
        from vllm.transformers_utils.processors.moondream3 import Moondream3Processor

        return self.ctx.get_hf_processor(Moondream3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """At most one image per prompt."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of vision tokens per image.

        Moondream3 always emits a fixed 27x27 vision grid (729 tokens)
        after the projection path, regardless of input crop tiling, so
        the image dimensions are irrelevant here.
        """
        return 729

    def get_image_size_with_most_features(self) -> ImageSize:
        """Canonical image size used for dummy/profiling inputs."""
        return ImageSize(width=378, height=378)

    def get_max_image_tokens(self) -> int:
        """Upper bound on vision tokens per image (fixed grid)."""
        return 729

Moondream3RegionConfig dataclass

Configuration for Moondream3 region module (point/detect).

Source code in vllm/model_executor/models/moondream3.py
@dataclass
class Moondream3RegionConfig:
    """Configuration for Moondream3 region module (point/detect)."""

    dim: int = 2048
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024
    size_feat_dim: int = 512
    size_out_dim: int = 2048

    @classmethod
    def from_dict(cls, d: dict) -> "Moondream3RegionConfig":
        """Build a config from a raw dict.

        Accepts either the full model config (containing a "region"
        sub-dict) or the region sub-dict itself; missing keys fall back
        to the defaults above.
        """
        cfg = d.get("region", d)
        defaults = (
            ("dim", 2048),
            ("coord_feat_dim", 256),
            ("coord_out_dim", 1024),
            ("size_feat_dim", 512),
            ("size_out_dim", 2048),
        )
        return cls(**{key: cfg.get(key, fallback) for key, fallback in defaults})

Moondream3RegionModule

Bases: Module

Region module for coordinate encoding/decoding (point/detect).

This module handles Fourier feature encoding of coordinates and sizes for the point and detect capabilities. It is used by Moondream3's custom detect/point decode state machine, integrated into vLLM's decode loop via model runner hooks.

The module is small (~14M params) and uses plain nn.Linear layers (replicated on all TP ranks, no parallelization needed).

Source code in vllm/model_executor/models/moondream3.py
class Moondream3RegionModule(nn.Module):
    """Coordinate/size encoder-decoder for the point/detect capabilities.

    Wraps Fourier-feature encodings of coordinates and box sizes plus the
    small linear heads that decode them back into bin logits. It is used
    by Moondream3's detect/point decode state machine, integrated into
    vLLM's decode loop via model runner hooks.

    At roughly 14M parameters it is intentionally kept as plain nn.Linear
    layers, replicated on every TP rank (no parallelization needed).
    """

    def __init__(self, config: Moondream3RegionConfig, prefix: str = ""):
        super().__init__()
        self._config = config

        # Fourier frequency matrices for coordinate/size encoding
        # (values come from the checkpoint).
        self.coord_features = nn.Parameter(torch.empty(1, config.coord_feat_dim // 2))
        self.size_features = nn.Parameter(torch.empty(2, config.size_feat_dim // 2))

        # Coordinate projection heads.
        self.coord_encoder = nn.Linear(config.coord_feat_dim, config.dim)
        self.coord_decoder = nn.Linear(config.dim, config.coord_out_dim)

        # Size projection heads.
        self.size_encoder = nn.Linear(config.size_feat_dim, config.dim)
        self.size_decoder = nn.Linear(config.dim, config.size_out_dim)

        # Shared pre-decode layer norm.
        self.ln = nn.LayerNorm(config.dim)

    def _fourier_features(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Fourier feature mapping: x @ w -> cat(cos, sin)."""
        angles = 2 * math.pi * x @ w
        cos_part = angles.cos()
        sin_part = angles.sin()
        return torch.cat([cos_part, sin_part], dim=-1)

    def encode_coordinate(self, coord_value: torch.Tensor) -> torch.Tensor:
        """Encode a coordinate value into an embedding.

        Args:
            coord_value: Scalar or tensor of shape [..., 1] with float
                coordinate values in [0, 1].

        Returns:
            Embedding of shape [..., dim].
        """
        # Bring the input to [..., 1] so the matmul sees a column vector.
        if coord_value.dim() == 0:
            coord_value = coord_value.reshape(1, 1)
        elif coord_value.dim() == 1:
            coord_value = coord_value.unsqueeze(-1)
        return self.coord_encoder(
            self._fourier_features(coord_value, self.coord_features)
        )

    def decode_coordinate(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to coordinate-bin logits.

        Args:
            hidden_states: Tensor of shape [..., dim].

        Returns:
            Logits of shape [..., coord_out_dim] (1024 bins).
        """
        return self.coord_decoder(self.ln(hidden_states))

    def encode_size(
        self, w: float, h: float, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Encode a (width, height) pair into an embedding.

        Args:
            w: Width value (float).
            h: Height value (float).
            device: Target device.
            dtype: Target dtype.

        Returns:
            Embedding of shape [1, dim].
        """
        wh = torch.tensor([w, h], device=device, dtype=dtype)
        feats = self._fourier_features(wh, self.size_features)
        return self.size_encoder(feats.unsqueeze(0))

    def decode_size(self, hidden_states: torch.Tensor) -> tuple[float, float]:
        """Decode (width, height) floats from hidden states.

        Applies ln + size_decoder, takes the argmax over each of the two
        1024-bin rows, then maps bins back from log2 scale.

        Args:
            hidden_states: Tensor of shape [..., dim].

        Returns:
            Tuple (w_float, h_float).
        """
        bin_logits = self.size_decoder(self.ln(hidden_states)).view(2, -1)
        bins = torch.argmax(bin_logits, dim=-1)
        # Inverse of the log2 binning: bin b -> 2 ** (10 * b / 1023 - 10).
        sizes = torch.pow(2.0, (bins.float() / 1023.0) * 10.0 - 10.0)
        return sizes[0].item(), sizes[1].item()

_fourier_features

_fourier_features(x: Tensor, w: Tensor) -> Tensor

Fourier feature mapping: x @ w -> cat(cos, sin).

Source code in vllm/model_executor/models/moondream3.py
def _fourier_features(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Fourier feature mapping: x @ w -> cat(cos, sin)."""
    f = 2 * math.pi * x @ w
    return torch.cat([f.cos(), f.sin()], dim=-1)

decode_coordinate

decode_coordinate(hidden_states: Tensor) -> Tensor

Decode coordinate logits from hidden states.

Parameters:

Name Type Description Default
hidden_states Tensor

Tensor of shape [..., dim].

required

Returns:

Type Description
Tensor

Logits of shape [..., coord_out_dim] (1024 bins).

Source code in vllm/model_executor/models/moondream3.py
def decode_coordinate(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project normalized hidden states to coordinate-bin logits.

    Args:
        hidden_states: Tensor of shape [..., dim].

    Returns:
        Logits of shape [..., coord_out_dim] (1024 bins).
    """
    normed = self.ln(hidden_states)
    return self.coord_decoder(normed)

decode_size

decode_size(hidden_states: Tensor) -> tuple[float, float]

Decode size (width, height) from hidden states.

Applies ln + size_decoder, argmax on 2x1024 bins, then converts from log-scale bins to float values.

Parameters:

Name Type Description Default
hidden_states Tensor

Tensor of shape [..., dim].

required

Returns:

Type Description
tuple[float, float]

Tuple (w_float, h_float).

Source code in vllm/model_executor/models/moondream3.py
def decode_size(self, hidden_states: torch.Tensor) -> tuple[float, float]:
    """Decode (width, height) floats from hidden states.

    Applies ln + size_decoder, takes the argmax over each of the two
    1024-bin rows, then maps bins back from log2 scale.

    Args:
        hidden_states: Tensor of shape [..., dim].

    Returns:
        Tuple (w_float, h_float).
    """
    bin_logits = self.size_decoder(self.ln(hidden_states)).view(2, -1)
    bins = torch.argmax(bin_logits, dim=-1)
    # Inverse of the log2 binning: bin b -> 2 ** (10 * b / 1023 - 10).
    sizes = torch.pow(2.0, (bins.float() / 1023.0) * 10.0 - 10.0)
    return sizes[0].item(), sizes[1].item()

encode_coordinate

encode_coordinate(coord_value: Tensor) -> Tensor

Encode a coordinate value into an embedding.

Parameters:

Name Type Description Default
coord_value Tensor

Scalar or tensor of shape [..., 1] with float coordinate values in [0, 1].

required

Returns:

Type Description
Tensor

Embedding of shape [..., dim].

Source code in vllm/model_executor/models/moondream3.py
def encode_coordinate(self, coord_value: torch.Tensor) -> torch.Tensor:
    """Encode a coordinate value into an embedding.

    Args:
        coord_value: Scalar or tensor of shape [..., 1] with float
            coordinate values in [0, 1].

    Returns:
        Embedding of shape [..., dim].
    """
    # Bring the input to [..., 1] so the matmul sees a column vector.
    if coord_value.dim() == 0:
        coord_value = coord_value.reshape(1, 1)
    elif coord_value.dim() == 1:
        coord_value = coord_value.unsqueeze(-1)
    return self.coord_encoder(
        self._fourier_features(coord_value, self.coord_features)
    )

encode_size

encode_size(
    w: float, h: float, device: device, dtype: dtype
) -> Tensor

Encode width and height into an embedding.

Parameters:

Name Type Description Default
w float

Width value (float).

required
h float

Height value (float).

required
device device

Target device.

required
dtype dtype

Target dtype.

required

Returns:

Type Description
Tensor

Embedding of shape [1, dim].

Source code in vllm/model_executor/models/moondream3.py
def encode_size(
    self, w: float, h: float, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Encode a (width, height) pair into an embedding.

    Args:
        w: Width value (float).
        h: Height value (float).
        device: Target device.
        dtype: Target dtype.

    Returns:
        Embedding of shape [1, dim].
    """
    wh = torch.tensor([w, h], device=device, dtype=dtype)
    feats = self._fourier_features(wh, self.size_features)
    return self.size_encoder(feats.unsqueeze(0))

Moondream3TextConfig dataclass

Configuration for Moondream3 text decoder.

Source code in vllm/model_executor/models/moondream3.py
@dataclass
class Moondream3TextConfig:
    """Configuration for Moondream3 text decoder."""

    dim: int = 2048
    ff_dim: int = 8192
    n_layers: int = 24
    vocab_size: int = 51200
    max_context: int = 4096
    n_heads: int = 32
    n_kv_heads: int = 32
    prefix_attn: int = 730  # BOS + 729 vision tokens
    prefix_lm_left_padding: int = 1  # include BOS in prefix-lm span
    rope_theta: float = 1500000.0
    # MoE configuration
    moe_start_layer: int = 4
    moe_num_experts: int = 64
    moe_experts_per_token: int = 8
    moe_expert_inner_dim: int = 1024

    @classmethod
    def from_dict(cls, d: dict) -> "Moondream3TextConfig":
        """Build a config from a raw dict.

        Accepts either the full model config (containing a "text"
        sub-dict) or the text sub-dict itself; missing keys fall back to
        the defaults above. MoE fields live in a nested "moe" dict.
        """
        cfg = d.get("text", d)
        moe = cfg.get("moe", {})
        return cls(
            dim=cfg.get("dim", 2048),
            ff_dim=cfg.get("ff_dim", 8192),
            n_layers=cfg.get("n_layers", 24),
            vocab_size=cfg.get("vocab_size", 51200),
            max_context=cfg.get("max_context", 4096),
            n_heads=cfg.get("n_heads", 32),
            n_kv_heads=cfg.get("n_kv_heads", 32),
            prefix_attn=cfg.get("prefix_attn", 730),
            prefix_lm_left_padding=cfg.get("prefix_lm_left_padding", 1),
            rope_theta=cfg.get("rope_theta", 1500000.0),
            moe_start_layer=moe.get("start_layer", 4),
            moe_num_experts=moe.get("n_experts", 64),
            moe_experts_per_token=moe.get("n_experts_per_tok", 8),
            moe_expert_inner_dim=moe.get("expert_inner_dim", 1024),
        )

Moondream3TextMLP

Bases: Module

Standard MLP for non-MoE layers (layers 0-3).

Source code in vllm/model_executor/models/moondream3.py
class Moondream3TextMLP(nn.Module):
    """Dense feed-forward block for the non-MoE decoder layers (0-3).

    fc1 (column-parallel) -> gelu_pytorch_tanh -> fc2 (row-parallel).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Up-projection, sharded by output dim across TP ranks.
        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn("gelu_pytorch_tanh")
        # Down-projection, sharded by input dim (all-reduces internally).
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2 (parallel linears return tuples)."""
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out

Moondream3TextMoE

Bases: Module

Mixture of Experts layer for layers 4+ with expert parallelism.

Moondream3 uses a custom GeGLU activation: gelu(h) * (g + 1) where fc1 outputs [gate, up] and the activation is gelu(gate) * (up + 1).

Uses expert parallelism where each GPU stores num_experts/tp_size experts. Routing and communication handled via all-to-all or replicated computation.

Checkpoint format: - fc1.weight: [num_experts, expert_inner_dim * 2, hidden_size] (gate+up) - fc2.weight: [num_experts, hidden_size, expert_inner_dim] (down) - router.weight: [num_experts, hidden_size] - router.bias: [num_experts]

Source code in vllm/model_executor/models/moondream3.py
class Moondream3TextMoE(nn.Module):
    """Mixture of Experts layer for layers 4+ with expert parallelism.

    Moondream3 uses a custom GeGLU activation: gelu(h) * (g + 1)
    where fc1 outputs [gate, up] and the activation is gelu(gate) * (up + 1).

    Uses expert parallelism where each GPU stores num_experts/tp_size experts.
    Routing and communication handled via all-to-all or replicated computation.

    Checkpoint format:
    - fc1.weight: [num_experts, expert_inner_dim * 2, hidden_size] (gate+up)
    - fc2.weight: [num_experts, hidden_size, expert_inner_dim] (down)
    - router.weight: [num_experts, hidden_size]
    - router.bias: [num_experts]
    """

    def __init__(
        self,
        hidden_size: int,
        expert_inner_dim: int,
        num_experts: int,
        experts_per_token: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.expert_inner_dim = expert_inner_dim
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token

        # Expert parallelism: each GPU stores a subset of experts
        self.tp_size = get_tensor_model_parallel_world_size()
        self.experts_per_rank = num_experts // self.tp_size
        self.num_local_experts = self.experts_per_rank

        # Router (gate) - use ReplicatedLinear for compatibility
        # (bias=True matches the checkpoint's router.bias tensor).
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=True,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Local expert weights (only store experts_per_rank experts)
        # fc1: [experts_per_rank, expert_inner_dim * 2, hidden_size]
        # fc2: [experts_per_rank, hidden_size, expert_inner_dim]
        self.fc1_weight = nn.Parameter(
            torch.empty(self.num_local_experts, expert_inner_dim * 2, hidden_size)
        )
        self.fc2_weight = nn.Parameter(
            torch.empty(self.num_local_experts, hidden_size, expert_inner_dim)
        )
        # Start on the fused-kernel path; forward() flips this off the
        # first time the fused call raises (see the except clause there).
        self._use_fused_moe = True

        # Global -> local expert index map for the fused kernel;
        # entries of -1 mark experts owned by other TP ranks.
        local_expert_start = get_tensor_model_parallel_rank() * self.experts_per_rank
        expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
        expert_map[local_expert_start : local_expert_start + self.num_local_experts] = (
            torch.arange(self.num_local_experts, dtype=torch.int32)
        )
        self.register_buffer("_expert_map", expert_map, persistent=False)

        # Preserve Moondream3's exact GeGLU variant (gelu(h) * (g + 1)) by
        # adding +1 bias to the second half of the fused fc1 activations.
        fused_w1_bias = torch.zeros(
            self.num_local_experts,
            expert_inner_dim * 2,
            dtype=torch.float32,
        )
        fused_w1_bias[:, expert_inner_dim:] = 1.0
        self.register_buffer("_fused_w1_bias", fused_w1_bias, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with expert parallelism and custom GeGLU activation."""

        # Get router logits and compute top-k (top-k is taken over raw
        # logits; softmax is applied only to the selected experts below).
        router_logits, _ = self.gate(x)  # [num_tokens, num_experts]
        topk_logits, topk_ids = torch.topk(
            router_logits, self.experts_per_token, dim=-1
        )
        # Softmax over selected experts
        topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)

        # Fast path: fused MoE kernel, with the +1 "up" bias injected via
        # the quant-config hook (see _fused_w1_bias in __init__).
        if self._use_fused_moe and x.is_cuda:
            try:
                out = fused_experts(
                    hidden_states=x.contiguous(),
                    w1=self.fc1_weight,
                    w2=self.fc2_weight,
                    topk_weights=topk_weights.contiguous(),
                    topk_ids=topk_ids.contiguous(),
                    activation=MoEActivation.GELU,
                    global_num_experts=self.num_experts,
                    expert_map=self._expert_map,
                    quant_config=biased_moe_quant_config(self._fused_w1_bias, None),
                )
                out = tensor_model_parallel_all_reduce(out)
                return out
            except (NotImplementedError, RuntimeError) as exc:
                # Permanently disable the fused path; subsequent calls go
                # straight to the Python loop below.
                self._use_fused_moe = False
                logger.warning_once(
                    "Disabling fused Moondream3 MoE path and falling back to "
                    "the Python expert loop: %s",
                    str(exc),
                )

        tp_rank = get_tensor_model_parallel_rank()
        # Compute local expert range
        local_expert_start = tp_rank * self.experts_per_rank

        # Fallback path for environments where fused kernels are unavailable.
        out = x.new_zeros(x.shape)

        for local_expert_idx in range(self.num_local_experts):
            global_expert_id = local_expert_start + local_expert_idx

            # Find tokens assigned to this expert
            token_pos, which_k = (topk_ids == global_expert_id).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue

            # Get tokens and their routing weights
            x_tok = x.index_select(0, token_pos)  # [n_tokens, hidden_size]
            gate_tok = topk_weights[token_pos, which_k]  # [n_tokens]

            # fc1: [expert_inner_dim * 2, hidden_size]
            # h_full: [n_tokens, expert_inner_dim * 2]
            h_full = F.linear(x_tok, self.fc1_weight[local_expert_idx])

            # GeGLU with (g + 1): h, g = split; output = gelu(h) * (g + 1)
            # HF MoE uses exact GELU (not tanh approximation).
            h, g = h_full.chunk(2, dim=-1)  # Each [n_tokens, expert_inner_dim]
            h = F.gelu(h) * (g + 1.0)

            # fc2: [hidden_size, expert_inner_dim]
            # y: [n_tokens, hidden_size]
            y = F.linear(h, self.fc2_weight[local_expert_idx])

            # Apply routing weight
            y = y * gate_tok.unsqueeze(-1)

            # Accumulate output
            out.index_add_(0, token_pos, y)

        # All-reduce to combine results from all experts across GPUs
        out = tensor_model_parallel_all_reduce(out)

        return out

forward

forward(x: Tensor) -> Tensor

Forward pass with expert parallelism and custom GeGLU activation.

Source code in vllm/model_executor/models/moondream3.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass with expert parallelism and custom GeGLU activation."""

    # Get router logits and compute top-k (top-k is taken over raw
    # logits; softmax is applied only to the selected experts below).
    router_logits, _ = self.gate(x)  # [num_tokens, num_experts]
    topk_logits, topk_ids = torch.topk(
        router_logits, self.experts_per_token, dim=-1
    )
    # Softmax over selected experts
    topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)

    # Fast path: fused MoE kernel, with the +1 "up" bias injected via
    # the quant-config hook (see _fused_w1_bias in __init__).
    if self._use_fused_moe and x.is_cuda:
        try:
            out = fused_experts(
                hidden_states=x.contiguous(),
                w1=self.fc1_weight,
                w2=self.fc2_weight,
                topk_weights=topk_weights.contiguous(),
                topk_ids=topk_ids.contiguous(),
                activation=MoEActivation.GELU,
                global_num_experts=self.num_experts,
                expert_map=self._expert_map,
                quant_config=biased_moe_quant_config(self._fused_w1_bias, None),
            )
            out = tensor_model_parallel_all_reduce(out)
            return out
        except (NotImplementedError, RuntimeError) as exc:
            # Permanently disable the fused path; subsequent calls go
            # straight to the Python loop below.
            self._use_fused_moe = False
            logger.warning_once(
                "Disabling fused Moondream3 MoE path and falling back to "
                "the Python expert loop: %s",
                str(exc),
            )

    tp_rank = get_tensor_model_parallel_rank()
    # Compute local expert range
    local_expert_start = tp_rank * self.experts_per_rank

    # Fallback path for environments where fused kernels are unavailable.
    out = x.new_zeros(x.shape)

    for local_expert_idx in range(self.num_local_experts):
        global_expert_id = local_expert_start + local_expert_idx

        # Find tokens assigned to this expert
        token_pos, which_k = (topk_ids == global_expert_id).nonzero(as_tuple=True)
        if token_pos.numel() == 0:
            continue

        # Get tokens and their routing weights
        x_tok = x.index_select(0, token_pos)  # [n_tokens, hidden_size]
        gate_tok = topk_weights[token_pos, which_k]  # [n_tokens]

        # fc1: [expert_inner_dim * 2, hidden_size]
        # h_full: [n_tokens, expert_inner_dim * 2]
        h_full = F.linear(x_tok, self.fc1_weight[local_expert_idx])

        # GeGLU with (g + 1): h, g = split; output = gelu(h) * (g + 1)
        # HF MoE uses exact GELU (not tanh approximation).
        h, g = h_full.chunk(2, dim=-1)  # Each [n_tokens, expert_inner_dim]
        h = F.gelu(h) * (g + 1.0)

        # fc2: [hidden_size, expert_inner_dim]
        # y: [n_tokens, hidden_size]
        y = F.linear(h, self.fc2_weight[local_expert_idx])

        # Apply routing weight
        y = y * gate_tok.unsqueeze(-1)

        # Accumulate output
        out.index_add_(0, token_pos, y)

    # All-reduce to combine results from all experts across GPUs
    out = tensor_model_parallel_all_reduce(out)

    return out

Moondream3TextModel

Bases: Module

Text decoder model.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3TextModel(nn.Module):
    """Text decoder model (embedding, decoder stack, final LayerNorm)."""

    def __init__(
        self,
        config: Moondream3TextConfig,
        cache_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        # Token embedding table, sharded across tensor-parallel ranks.
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.dim,
            prefix=f"{prefix}.wte",
        )

        # Pipeline-parallel aware construction: only the layer slice
        # assigned to this PP rank is materialized.
        blocks_prefix = maybe_prefix(prefix, "blocks")
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: Moondream3DecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=blocks_prefix,
        )

        self.post_ln = nn.LayerNorm(config.dim, eps=1e-5, bias=True)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.dim
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for the given token ids."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; may be None when inputs_embeds is given.
            positions: Position ids for rotary/positional handling in layers.
            intermediate_tensors: Hidden states handed over from the
                previous pipeline-parallel rank (required on non-first ranks).
            inputs_embeds: Precomputed embeddings (e.g. with vision features
                merged in); takes precedence over input_ids.

        Returns:
            Final hidden states on the last PP rank, otherwise an
            IntermediateTensors payload for the next PP rank.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                assert input_ids is not None
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Only execute the layers owned by this pipeline-parallel rank.
        # (Dropped an unused `enumerate` index from the original loop.)
        for layer in islice(self.blocks, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.post_ln(hidden_states)
        return hidden_states

Moondream3VisionAttention

Bases: Module

Self-attention for vision encoder (bidirectional).

Uses native PyTorch scaled_dot_product_attention to avoid dependency on vLLM forward context during memory profiling.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3VisionAttention(nn.Module):
    """Bidirectional self-attention for the vision encoder.

    Attention is computed with torch's native scaled_dot_product_attention
    so this module has no dependency on the vLLM forward context (which is
    unavailable during memory profiling).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Fused QKV projection, column-parallel across TP ranks.
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        # Output projection, row-parallel (reduces across TP ranks).
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = num_heads // tp_size
        self.scale = self.head_dim**-0.5

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply bidirectional attention.

        Args:
            hidden_states: (batch, seq_len, hidden_size)

        Returns:
            (batch, seq_len, hidden_size) attention output.
        """
        bsz, seq, _ = hidden_states.shape

        fused_qkv, _ = self.qkv_proj(hidden_states)

        # Lay each of q/k/v out as (batch, heads, seq_len, head_dim) for SDPA.
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(
                bsz, seq, self.num_heads_per_partition, self.head_dim
            ).transpose(1, 2)

        queries, keys, values = (
            split_heads(part) for part in fused_qkv.chunk(3, dim=-1)
        )

        # No mask: vision tokens attend bidirectionally.
        attn_out = F.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale
        )

        # Merge heads back into the hidden dimension.
        merged = attn_out.transpose(1, 2).contiguous().view(bsz, seq, -1)

        out, _ = self.out_proj(merged)
        return out

forward

forward(hidden_states: Tensor) -> Tensor

Forward pass using native PyTorch SDPA.

Parameters:

Name Type Description Default
hidden_states Tensor

(batch, seq_len, hidden_size)

required

Returns:

Name Type Description
output Tensor

(batch, seq_len, hidden_size)

Source code in vllm/model_executor/models/moondream3.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward pass using native PyTorch SDPA.

    Args:
        hidden_states: (batch, seq_len, hidden_size)

    Returns:
        output: (batch, seq_len, hidden_size)
    """
    batch_size, seq_len, _ = hidden_states.shape

    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.chunk(3, dim=-1)

    # Reshape to (batch, num_heads, seq_len, head_dim)
    q = q.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)
    q = q.transpose(1, 2)
    k = k.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)
    k = k.transpose(1, 2)
    v = v.view(batch_size, seq_len, self.num_heads_per_partition, self.head_dim)
    v = v.transpose(1, 2)

    # Use PyTorch's scaled_dot_product_attention (bidirectional, no mask)
    out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)

    # Reshape back to (batch, seq_len, hidden_size)
    out = out.transpose(1, 2).contiguous()
    out = out.view(batch_size, seq_len, -1)

    out, _ = self.out_proj(out)
    return out

Moondream3VisionBlock

Bases: Module

Transformer block for vision encoder.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3VisionBlock(nn.Module):
    """Pre-norm transformer block for the vision encoder."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Attention sub-layer with its own pre-normalization.
        self.ln1 = nn.LayerNorm(hidden_size, eps=1e-5)
        self.attn = Moondream3VisionAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        # MLP sub-layer with its own pre-normalization.
        self.ln2 = nn.LayerNorm(hidden_size, eps=1e-5)
        self.mlp = Moondream3VisionMLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Residual attention followed by residual MLP."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

Moondream3VisionConfig dataclass

Configuration for Moondream3 vision encoder.

Source code in vllm/model_executor/models/moondream3.py
@dataclass
class Moondream3VisionConfig:
    """Configuration for Moondream3 vision encoder."""

    enc_dim: int = 1152
    enc_patch_size: int = 14
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304
    enc_n_heads: int = 16
    proj_inner_dim: int = 8192
    crop_size: int = 378
    max_crops: int = 12
    overlap_margin: int = 4

    @classmethod
    def from_dict(cls, d: dict) -> "Moondream3VisionConfig":
        """Build a config from a raw dict.

        Accepts either a flat dict of vision settings or a dict with a
        nested "vision" sub-dict; missing keys keep the dataclass defaults.
        """
        src = d.get("vision", d)
        defaults = cls()
        return cls(
            enc_dim=src.get("enc_dim", defaults.enc_dim),
            enc_patch_size=src.get("enc_patch_size", defaults.enc_patch_size),
            enc_n_layers=src.get("enc_n_layers", defaults.enc_n_layers),
            enc_ff_dim=src.get("enc_ff_dim", defaults.enc_ff_dim),
            enc_n_heads=src.get("enc_n_heads", defaults.enc_n_heads),
            proj_inner_dim=src.get("proj_inner_dim", defaults.proj_inner_dim),
            crop_size=src.get("crop_size", defaults.crop_size),
            max_crops=src.get("max_crops", defaults.max_crops),
            overlap_margin=src.get("overlap_margin", defaults.overlap_margin),
        )

Moondream3VisionEncoder

Bases: Module

Vision encoder (SigLIP-style ViT).

Source code in vllm/model_executor/models/moondream3.py
class Moondream3VisionEncoder(nn.Module):
    """Vision encoder (SigLIP-style ViT)."""

    def __init__(
        self,
        config: Moondream3VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        # Linear patch embedding over flattened (3 * P * P) pixel patches.
        self.patch_emb = nn.Linear(
            config.enc_patch_size * config.enc_patch_size * 3,
            config.enc_dim,
            bias=True,
        )

        # Learned absolute position embeddings, one per patch
        # (27x27 = 729 patches for a 378px crop with 14px patches).
        patches_per_crop = (config.crop_size // config.enc_patch_size) ** 2
        self.pos_emb = nn.Parameter(
            torch.zeros(1, patches_per_crop, config.enc_dim)
        )

        # ViT trunk.
        self.blocks = nn.ModuleList(
            Moondream3VisionBlock(
                hidden_size=config.enc_dim,
                intermediate_size=config.enc_ff_dim,
                num_heads=config.enc_n_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{i}",
            )
            for i in range(config.enc_n_layers)
        )

        self.post_ln = nn.LayerNorm(config.enc_dim, eps=1e-5)

    def create_patches(self, images: torch.Tensor) -> torch.Tensor:
        """Convert images to patch embeddings.

        Args:
            images: (batch, channels, height, width)

        Returns:
            patches: (batch, num_patches, patch_dim)
        """
        p = self.config.enc_patch_size
        n_batch = images.shape[0]
        rows = images.shape[2] // p
        cols = images.shape[3] // p

        # Carve the image into non-overlapping p x p tiles:
        # (batch, channels, rows, cols, p, p).
        tiles = images.unfold(2, p, p).unfold(3, p, p)
        # Move channels after the tile grid so each patch flattens to
        # channels * p * p contiguous values.
        tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
        return tiles.view(n_batch, rows * cols, -1)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images.

        Args:
            pixel_values: (batch, channels, height, width)

        Returns:
            features: (batch, num_patches, hidden_size)
        """
        # Patchify, embed, and add absolute position information.
        x = self.patch_emb(self.create_patches(pixel_values)) + self.pos_emb

        for block in self.blocks:
            x = block(x)

        return self.post_ln(x)

create_patches

create_patches(images: Tensor) -> Tensor

Convert images to patch embeddings.

Parameters:

Name Type Description Default
images Tensor

(batch, channels, height, width)

required

Returns:

Name Type Description
patches Tensor

(batch, num_patches, patch_dim)

Source code in vllm/model_executor/models/moondream3.py
def create_patches(self, images: torch.Tensor) -> torch.Tensor:
    """Convert images to patch embeddings.

    Args:
        images: (batch, channels, height, width)

    Returns:
        patches: (batch, num_patches, patch_dim)
    """
    p = self.config.enc_patch_size
    n_batch = images.shape[0]
    grid_h = images.shape[2] // p
    grid_w = images.shape[3] // p

    # Slice the image into a (grid_h, grid_w) grid of p x p tiles:
    # (batch, channels, grid_h, grid_w, p, p).
    tiled = images.unfold(2, p, p).unfold(3, p, p)

    # Reorder to (batch, grid_h, grid_w, channels, p, p) so that each
    # flattened patch holds channels * p * p contiguous values.
    tiled = tiled.permute(0, 2, 3, 1, 4, 5).contiguous()

    # Collapse the tile grid into a single patch axis.
    return tiled.view(n_batch, grid_h * grid_w, -1)

forward

forward(pixel_values: Tensor) -> Tensor

Encode images.

Parameters:

Name Type Description Default
pixel_values Tensor

(batch, channels, height, width)

required

Returns:

Name Type Description
features Tensor

(batch, num_patches, hidden_size)

Source code in vllm/model_executor/models/moondream3.py
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Encode images.

    Args:
        pixel_values: (batch, channels, height, width)

    Returns:
        features: (batch, num_patches, hidden_size)
    """
    # Patchify and project into the encoder dimension.
    feats = self.patch_emb(self.create_patches(pixel_values))

    # Add learned absolute position embeddings.
    feats = feats + self.pos_emb

    # Run the ViT trunk.
    for blk in self.blocks:
        feats = blk(feats)

    # Final normalization.
    return self.post_ln(feats)

Moondream3VisionMLP

Bases: Module

MLP for vision encoder blocks.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3VisionMLP(nn.Module):
    """Two-layer GELU MLP used inside vision encoder blocks."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Up-projection is column-parallel and down-projection row-parallel,
        # so the intermediate activation stays sharded across TP ranks.
        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn("gelu_pytorch_tanh")
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """fc1 -> tanh-approximated GELU -> fc2."""
        up, _ = self.fc1(x)
        down, _ = self.fc2(self.act(up))
        return down

Moondream3VisionProjection

Bases: Module

Projects vision features to text embedding dimension.

Source code in vllm/model_executor/models/moondream3.py
class Moondream3VisionProjection(nn.Module):
    """Projects vision features to text embedding dimension."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        output_dim: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # The input concatenates global and local features, hence 2x width.
        self.fc1 = ColumnParallelLinear(
            input_dim * 2,
            inner_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn("gelu_pytorch_tanh")
        self.fc2 = RowParallelLinear(
            inner_dim,
            output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """fc1 -> tanh-approximated GELU -> fc2."""
        hidden, _ = self.fc1(x)
        projected, _ = self.fc2(self.act(hidden))
        return projected

reconstruct_from_crops

reconstruct_from_crops(
    crops: Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> Tensor

Reconstruct features from overlapping crops.

Source code in vllm/model_executor/models/moondream3.py
def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """Reconstruct features from overlapping crops.

    Each crop carries a margin of ``overlap_margin * patch_size`` units on
    every side; interior margins are dropped when stitching, while outer
    edges of the full image keep theirs.
    """
    rows, cols = tiling
    crop_h, crop_w = crops[0].shape[:2]
    margin = overlap_margin * patch_size

    # Interior (non-margin) extent contributed by each crop.
    inner_h = crop_h - 2 * margin
    inner_w = crop_w - 2 * margin

    canvas = torch.zeros(
        (inner_h * rows + 2 * margin, inner_w * cols + 2 * margin, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for idx, crop in enumerate(crops):
        row, col = divmod(idx, cols)

        # Keep the margin only at the outer edges of the stitched image.
        left = 0 if col == 0 else margin
        right = crop_w if col == cols - 1 else crop_w - margin
        top = 0 if row == 0 else margin
        bottom = crop_h if row == rows - 1 else crop_h - margin

        base_x = col * inner_w
        base_y = row * inner_h

        canvas[base_y + top : base_y + bottom, base_x + left : base_x + right] = crop[
            top:bottom, left:right
        ]

    return canvas