Skip to content

vllm.transformers_utils.processors.moondream3

Custom processor for Moondream3 model.

Moondream3Processor

Bases: ProcessorMixin

Constructs a Moondream3 processor which handles image preprocessing and tokenization for the Moondream3 multimodal model.

Parameters:

Name Type Description Default
tokenizer PreTrainedTokenizerBase | None

The tokenizer to use for text processing.

None
chat_template str | None

Optional chat template string.

None
crop_size int

Size of each image crop.

378
max_crops int

Maximum number of crops per image.

12
overlap_margin int

Margin for overlapping crops in patches.

4
patch_size int

Size of each patch.

14
Source code in vllm/transformers_utils/processors/moondream3.py
class Moondream3Processor(ProcessorMixin):
    """
    Constructs a Moondream3 processor which handles image preprocessing
    and tokenization for the Moondream3 multimodal model.

    Args:
        tokenizer: The tokenizer to use for text processing.
        chat_template: Optional chat template string.
        crop_size: Size of each image crop.
        max_crops: Maximum number of crops per image.
        overlap_margin: Margin for overlapping crops in patches.
        patch_size: Size of each patch.
    """

    # ProcessorMixin wiring: the tokenizer is the only managed sub-component;
    # image preprocessing is implemented directly on this class.
    attributes = ["tokenizer"]
    valid_kwargs = [
        "chat_template",
        "crop_size",
        "max_crops",
        "overlap_margin",
        "patch_size",
    ]

    tokenizer_class = "AutoTokenizer"
    # Use separate tokenizer repo
    _tokenizer_repo = "moondream/starmie-v1"

    # Default chat template for Moondream3
    # Moondream uses special tokens for prompting:
    # - Token 0 (<|endoftext|>): BOS token (ALWAYS present at position 0)
    # - Token 1 (<|md_reserved_0|>): Start of instruction
    # - Token 2 (<|md_reserved_1|>): Separator before question
    # - Token 3 (<|md_reserved_2|>): End of question / start of answer
    #
    # Task routing based on text prefix:
    #   "detect <obj>"  → detect<|md_reserved_1|> <obj><|md_reserved_2|>
    #   "point <obj>"   → point<|md_reserved_1|> <obj><|md_reserved_2|>
    #   otherwise        → query<|md_reserved_1|><text><|md_reserved_2|>
    #
    # Format with image:
    #   <|endoftext|><image><|md_reserved_0|>{task}<|md_reserved_1|>{q}<|md_reserved_2|>
    # Format without image:
    #   <|endoftext|><|md_reserved_0|>{task}<|md_reserved_1|>{q}<|md_reserved_2|>
    _default_chat_template = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{% if message['content'] is string %}"
        # Simple string content (with image assumed) - route by prefix
        "<|endoftext|><image><|md_reserved_0|>"
        "{% if message['content'].startswith('detect ') %}"
        "detect<|md_reserved_1|> {{ message['content'][7:] }}<|md_reserved_2|>"
        "{% elif message['content'].startswith('point ') %}"
        "point<|md_reserved_1|> {{ message['content'][6:] }}<|md_reserved_2|>"
        "{% else %}"
        "query<|md_reserved_1|>{{ message['content'] }}<|md_reserved_2|>"
        "{% endif %}"
        "{% else %}"
        # List content - always start with BOS
        "<|endoftext|>"
        "{% for content in message['content'] %}"
        "{% if content['type'] == 'image' or content['type'] == 'image_url' %}"
        "<image>"
        "{% elif content['type'] == 'text' %}"
        "<|md_reserved_0|>"
        "{% if content['text'].startswith('detect ') %}"
        "detect<|md_reserved_1|> {{ content['text'][7:] }}<|md_reserved_2|>"
        "{% elif content['text'].startswith('point ') %}"
        "point<|md_reserved_1|> {{ content['text'][6:] }}<|md_reserved_2|>"
        "{% else %}"
        "query<|md_reserved_1|>{{ content['text'] }}<|md_reserved_2|>"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] }}"
        "{% endif %}"
        "{% endfor %}"
    )

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase | None = None,
        chat_template: str | None = None,
        crop_size: int = 378,
        max_crops: int = 12,
        overlap_margin: int = 4,
        patch_size: int = 14,
        **kwargs,
    ):
        # Placeholder token expanded into image patch embeddings downstream.
        self.image_token = "<image>"
        self.crop_size = crop_size
        self.max_crops = max_crops
        self.overlap_margin = overlap_margin
        self.patch_size = patch_size

        # Number of patches per crop (27x27 = 729 for 378/14)
        self.patches_per_crop = (crop_size // patch_size) ** 2

        # Use default chat template if none provided
        if chat_template is None:
            chat_template = self._default_chat_template

        super().__init__(tokenizer, chat_template=chat_template)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        **kwargs,
    ):
        """
        Load the processor, using a separate tokenizer repo.

        The moondream3 model uses a custom tokenizer from 'moondream/starmie-v1'
        instead of having tokenizer files in the model repo.
        """
        from transformers import AutoTokenizer

        tokenizer = kwargs.pop("tokenizer", None)

        # Forward only loading-related kwargs to AutoTokenizer; everything
        # else is either processor config (popped below) or ignored.
        tokenizer_kwargs = {
            "trust_remote_code": kwargs.get("trust_remote_code", False),
        }
        for key in (
            "cache_dir",
            "force_download",
            "local_files_only",
            "revision",
            "subfolder",
            "token",
            "use_fast",
        ):
            if key in kwargs:
                tokenizer_kwargs[key] = kwargs[key]

        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)

        if tokenizer is None:
            # Prefer model-local tokenizer files first. If unavailable, fall
            # back to moondream's dedicated tokenizer repository.
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **tokenizer_kwargs
                )
            except Exception:
                tokenizer = AutoTokenizer.from_pretrained(
                    cls._tokenizer_repo, **tokenizer_kwargs
                )

        # Configure special tokens for Moondream3
        # BOS and EOS are both token 0 (<|endoftext|>), matching the native
        # config (TokenizerConfig.bos_id=0, eos_id=0). This is standard for
        # GPT-2 style models where <|endoftext|> signals both start and end.
        # Token 1 (<|md_reserved_0|>) is a template delimiter, NOT the EOS.
        tokenizer.bos_token = "<|endoftext|>"
        tokenizer.bos_token_id = 0
        tokenizer.eos_token = "<|endoftext|>"
        tokenizer.eos_token_id = 0

        # Extract processor-specific kwargs
        crop_size = kwargs.pop("crop_size", 378)
        max_crops = kwargs.pop("max_crops", 12)
        overlap_margin = kwargs.pop("overlap_margin", 4)
        patch_size = kwargs.pop("patch_size", 14)
        chat_template = kwargs.pop("chat_template", None)

        # Set default chat template on tokenizer if not already set
        if chat_template is None:
            chat_template = cls._default_chat_template
        if tokenizer.chat_template is None:
            tokenizer.chat_template = chat_template

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            crop_size=crop_size,
            max_crops=max_crops,
            overlap_margin=overlap_margin,
            patch_size=patch_size,
        )

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        **kwargs: Unpack[Moondream3ProcessorKwargs],
    ) -> BatchFeature:
        """
        Process images and text for Moondream3 model.

        Args:
            images: Input images (PIL Image, numpy array, or list thereof).
            text: Input text or list of texts.
            **kwargs: Additional processing arguments.

        Returns:
            BatchFeature with processed inputs.
        """
        output_kwargs = self._merge_kwargs(
            Moondream3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Process images
        image_features = {}
        if images is not None:
            processed_images = []
            # Each tiling is the (rows, cols) grid chosen by select_tiling.
            tilings = []

            images_list = images if isinstance(images, list) else [images]
            for image in images_list:
                pixel_values, tiling = self.preprocess_image(
                    image, **output_kwargs["images_kwargs"]
                )
                processed_images.append(pixel_values)
                tilings.append(tiling)

            if processed_images:
                image_features["pixel_values"] = processed_images
                image_features["tilings"] = tilings

        # Process text
        if text is not None:
            if not isinstance(text, list):
                text = [text]

            # Get text kwargs, remove keys we set ourselves
            text_kwargs = output_kwargs.get("text_kwargs", {}).copy()
            text_kwargs.pop("return_tensors", None)
            text_kwargs.pop("add_special_tokens", None)

            # Tokenize text
            tokenized = self.tokenizer(
                text,
                add_special_tokens=True,
                return_tensors="pt",
                **text_kwargs,
            )

            output = BatchFeature(data=dict(tokenized))

            # Add image features
            if image_features:
                output["pixel_values"] = image_features["pixel_values"]
                output["tilings"] = image_features["tilings"]

            return output

        # If only images were provided
        return BatchFeature(data=image_features)

    def preprocess_image(
        self,
        image: Image.Image,
        max_crops: int = 12,
        overlap_margin: int = 4,
        crop_size: int = 378,
        patch_size: int = 14,
        convert_to_rgb: bool = True,
        return_tensors: str = "pt",
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Preprocess an image using overlap-and-resize cropping strategy.

        Args:
            image: Input PIL Image.
            max_crops: Maximum number of crops.
            overlap_margin: Margin for overlapping in patches.
            crop_size: Size of each crop.
            patch_size: Size of each patch.
            convert_to_rgb: Whether to convert to RGB.
            return_tensors: Return type ("pt" for PyTorch).

        Returns:
            Tuple of (pixel_values tensor, tiling tuple).
        """
        if convert_to_rgb:
            image = convert_image_mode(image, "RGB")

        # Convert to numpy array
        image_array = np.array(image)
        original_h, original_w = image_array.shape[:2]

        margin_pixels = patch_size * overlap_margin
        total_margin_pixels = margin_pixels * 2

        # Each local crop overlaps its neighbors by `overlap_margin` patches
        # on every side; the non-overlapping "window" is what tiles the image.
        crop_patches = crop_size // patch_size
        crop_window_patches = crop_patches - (2 * overlap_margin)
        crop_window_size = crop_window_patches * patch_size

        # NOTE(review): for tiny images these dims can be non-positive;
        # select_tiling returns (1, 1) in that case.
        tiling = select_tiling(
            original_h - total_margin_pixels,
            original_w - total_margin_pixels,
            crop_window_size,
            max_crops,
        )

        # +1 for the global whole-image crop stored at index 0.
        n_crops = tiling[0] * tiling[1] + 1
        crops = np.zeros((n_crops, crop_size, crop_size, 3), dtype=np.uint8)

        # target_size is (height, width).
        target_size = (
            tiling[0] * crop_window_size + total_margin_pixels,
            tiling[1] * crop_window_size + total_margin_pixels,
        )

        # Resize image
        # PIL's resize takes (width, height), hence the index swap below.
        pil_img = Image.fromarray(image_array)
        resized = pil_img.resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
        resized_array = np.asarray(resized)

        # Create global crop
        global_pil = pil_img.resize(
            (crop_size, crop_size), resample=Image.Resampling.LANCZOS
        )
        crops[0] = np.asarray(global_pil)

        # Create local crops
        for i in range(tiling[0]):
            for j in range(tiling[1]):
                y0 = i * crop_window_size
                x0 = j * crop_window_size
                y_end = min(y0 + crop_size, resized_array.shape[0])
                x_end = min(x0 + crop_size, resized_array.shape[1])

                # Edge crops may be smaller than crop_size; the remainder of
                # the pre-zeroed buffer acts as padding.
                crop_region = resized_array[y0:y_end, x0:x_end]
                crop_idx = 1 + i * tiling[1] + j
                h_slice = slice(None, crop_region.shape[0])
                w_slice = slice(None, crop_region.shape[1])
                crops[crop_idx, h_slice, w_slice] = crop_region

        # Normalize: (x - 0.5) / 0.5 = 2*x - 1
        # Convert to float and normalize to [-1, 1]
        pixel_values = crops.astype(np.float32) / 255.0
        pixel_values = (pixel_values - 0.5) / 0.5

        # Convert to tensor: (n_crops, H, W, C) -> (n_crops, C, H, W)
        pixel_values = np.transpose(pixel_values, (0, 3, 1, 2))

        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values)

        return pixel_values, tiling

    def get_num_image_tokens(self) -> int:
        """Return the number of image tokens (729 = 27x27 patches)."""
        return self.patches_per_crop

    def batch_decode(self, *args, **kwargs):
        """Forward to tokenizer's batch_decode."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to tokenizer's decode."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # Tokenizer inputs (e.g. input_ids) plus the image-side outputs
        # produced by __call__.
        tokenizer_input_names = self.tokenizer.model_input_names
        return tokenizer_input_names + ["pixel_values", "tilings"]

__call__

__call__(
    images: ImageInput = None,
    text: TextInput
    | PreTokenizedInput
    | list[TextInput]
    | list[PreTokenizedInput] = None,
    **kwargs: Unpack[Moondream3ProcessorKwargs],
) -> BatchFeature

Process images and text for Moondream3 model.

Parameters:

Name Type Description Default
images ImageInput

Input images (PIL Image, numpy array, or list thereof).

None
text TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput]

Input text or list of texts.

None
**kwargs Unpack[Moondream3ProcessorKwargs]

Additional processing arguments.

{}

Returns:

Type Description
BatchFeature

BatchFeature with processed inputs.

Source code in vllm/transformers_utils/processors/moondream3.py
def __call__(
    self,
    images: ImageInput = None,
    text: TextInput
    | PreTokenizedInput
    | list[TextInput]
    | list[PreTokenizedInput] = None,
    **kwargs: Unpack[Moondream3ProcessorKwargs],
) -> BatchFeature:
    """
    Prepare image and/or text inputs for the Moondream3 model.

    Args:
        images: A single image or a list of images.
        text: A single text or a list of texts.
        **kwargs: Extra processor keyword arguments.

    Returns:
        BatchFeature containing tokenized text and/or image features.
    """
    merged_kwargs = self._merge_kwargs(
        Moondream3ProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )

    # Preprocess every image into (pixel_values, tiling) pairs.
    image_features = {}
    if images is not None:
        batch = images if isinstance(images, list) else [images]
        pairs = [
            self.preprocess_image(img, **merged_kwargs["images_kwargs"])
            for img in batch
        ]
        if pairs:
            image_features["pixel_values"] = [pv for pv, _ in pairs]
            image_features["tilings"] = [t for _, t in pairs]

    if text is None:
        # Image-only call: return just the image features.
        return BatchFeature(data=image_features)

    texts = text if isinstance(text, list) else [text]

    # Strip keys we pin explicitly in the tokenizer call below.
    extra_text_kwargs = dict(merged_kwargs.get("text_kwargs", {}))
    for reserved in ("return_tensors", "add_special_tokens"):
        extra_text_kwargs.pop(reserved, None)

    encoded = self.tokenizer(
        texts,
        add_special_tokens=True,
        return_tensors="pt",
        **extra_text_kwargs,
    )

    result = BatchFeature(data=dict(encoded))

    # Attach image features alongside the tokenized text.
    if image_features:
        result["pixel_values"] = image_features["pixel_values"]
        result["tilings"] = image_features["tilings"]

    return result

batch_decode

batch_decode(*args, **kwargs)

Forward to tokenizer's batch_decode.

Source code in vllm/transformers_utils/processors/moondream3.py
def batch_decode(self, *args, **kwargs):
    """Delegate batch decoding to the underlying tokenizer."""
    # Pure pass-through; the processor adds no post-processing of its own.
    tokenizer = self.tokenizer
    return tokenizer.batch_decode(*args, **kwargs)

decode

decode(*args, **kwargs)

Forward to tokenizer's decode.

Source code in vllm/transformers_utils/processors/moondream3.py
def decode(self, *args, **kwargs):
    """Delegate decoding of a single sequence to the underlying tokenizer."""
    # Pure pass-through; the processor adds no post-processing of its own.
    tokenizer = self.tokenizer
    return tokenizer.decode(*args, **kwargs)

from_pretrained classmethod

from_pretrained(pretrained_model_name_or_path, **kwargs)

Load the processor, using a separate tokenizer repo.

The moondream3 model uses a custom tokenizer from 'moondream/starmie-v1' instead of having tokenizer files in the model repo.

Source code in vllm/transformers_utils/processors/moondream3.py
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    **kwargs,
):
    """
    Build the processor, sourcing the tokenizer from a dedicated repo.

    Moondream3 ships its tokenizer in 'moondream/starmie-v1' rather than
    alongside the model weights, so the model repo is tried first and the
    dedicated tokenizer repo is used as a fallback.
    """
    from transformers import AutoTokenizer

    tokenizer = kwargs.pop("tokenizer", None)

    # Forward only the loading-related options to AutoTokenizer.
    passthrough_keys = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "revision",
        "subfolder",
        "token",
        "use_fast",
    )
    tok_kwargs = {k: kwargs[k] for k in passthrough_keys if k in kwargs}
    tok_kwargs["trust_remote_code"] = kwargs.get("trust_remote_code", False)

    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tok_kwargs)

    if tokenizer is None:
        # Prefer tokenizer files shipped with the model; fall back to
        # moondream's dedicated tokenizer repository if they are missing.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path, **tok_kwargs
            )
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(
                cls._tokenizer_repo, **tok_kwargs
            )

    # Moondream3 uses <|endoftext|> (id 0) as BOTH bos and eos, matching
    # the native config (TokenizerConfig.bos_id=0, eos_id=0) — the GPT-2
    # convention. <|md_reserved_0|> (id 1) is only a prompt delimiter.
    tokenizer.bos_token = "<|endoftext|>"
    tokenizer.bos_token_id = 0
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.eos_token_id = 0

    # Pull out the processor-level configuration.
    chat_template = kwargs.pop("chat_template", None)
    crop_size = kwargs.pop("crop_size", 378)
    max_crops = kwargs.pop("max_crops", 12)
    overlap_margin = kwargs.pop("overlap_margin", 4)
    patch_size = kwargs.pop("patch_size", 14)

    # Fall back to the class default template, and install it on the
    # tokenizer only when it does not already carry one.
    if chat_template is None:
        chat_template = cls._default_chat_template
    if tokenizer.chat_template is None:
        tokenizer.chat_template = chat_template

    return cls(
        tokenizer=tokenizer,
        chat_template=chat_template,
        crop_size=crop_size,
        max_crops=max_crops,
        overlap_margin=overlap_margin,
        patch_size=patch_size,
    )

get_num_image_tokens

get_num_image_tokens() -> int

Return the number of image tokens (729 = 27x27 patches).

Source code in vllm/transformers_utils/processors/moondream3.py
def get_num_image_tokens(self) -> int:
    """Number of vision tokens contributed per image (27x27 = 729 patches)."""
    # Computed once in __init__ as (crop_size // patch_size) ** 2.
    token_count = self.patches_per_crop
    return token_count

preprocess_image

preprocess_image(
    image: Image,
    max_crops: int = 12,
    overlap_margin: int = 4,
    crop_size: int = 378,
    patch_size: int = 14,
    convert_to_rgb: bool = True,
    return_tensors: str = "pt",
) -> tuple[Tensor, tuple[int, int]]

Preprocess an image using overlap-and-resize cropping strategy.

Parameters:

Name Type Description Default
image Image

Input PIL Image.

required
max_crops int

Maximum number of crops.

12
overlap_margin int

Margin for overlapping in patches.

4
crop_size int

Size of each crop.

378
patch_size int

Size of each patch.

14
convert_to_rgb bool

Whether to convert to RGB.

True
return_tensors str

Return type ("pt" for PyTorch).

'pt'

Returns:

Type Description
tuple[Tensor, tuple[int, int]]

Tuple of (pixel_values tensor, tiling tuple).

Source code in vllm/transformers_utils/processors/moondream3.py
def preprocess_image(
    self,
    image: Image.Image,
    max_crops: int = 12,
    overlap_margin: int = 4,
    crop_size: int = 378,
    patch_size: int = 14,
    convert_to_rgb: bool = True,
    return_tensors: str = "pt",
) -> tuple[torch.Tensor, tuple[int, int]]:
    """
    Preprocess an image using overlap-and-resize cropping strategy.

    Args:
        image: Input PIL Image.
        max_crops: Maximum number of crops.
        overlap_margin: Margin for overlapping in patches.
        crop_size: Size of each crop.
        patch_size: Size of each patch.
        convert_to_rgb: Whether to convert to RGB.
        return_tensors: Return type ("pt" for PyTorch).

    Returns:
        Tuple of (pixel_values tensor, tiling tuple).
    """
    if convert_to_rgb:
        image = convert_image_mode(image, "RGB")

    # Convert to numpy array
    image_array = np.array(image)
    original_h, original_w = image_array.shape[:2]

    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2

    # Each local crop overlaps its neighbors by `overlap_margin` patches on
    # every side; the non-overlapping "window" is what tiles the image.
    crop_patches = crop_size // patch_size
    crop_window_patches = crop_patches - (2 * overlap_margin)
    crop_window_size = crop_window_patches * patch_size

    # NOTE(review): for tiny images these dims can be non-positive;
    # select_tiling returns (1, 1) in that case.
    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # +1 for the global whole-image crop stored at index 0.
    n_crops = tiling[0] * tiling[1] + 1
    crops = np.zeros((n_crops, crop_size, crop_size, 3), dtype=np.uint8)

    # target_size is (height, width).
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    # Resize image
    # PIL's resize takes (width, height), hence the index swap below.
    pil_img = Image.fromarray(image_array)
    resized = pil_img.resize(
        (int(target_size[1]), int(target_size[0])),
        resample=Image.Resampling.LANCZOS,
    )
    resized_array = np.asarray(resized)

    # Create global crop
    global_pil = pil_img.resize(
        (crop_size, crop_size), resample=Image.Resampling.LANCZOS
    )
    crops[0] = np.asarray(global_pil)

    # Create local crops
    for i in range(tiling[0]):
        for j in range(tiling[1]):
            y0 = i * crop_window_size
            x0 = j * crop_window_size
            y_end = min(y0 + crop_size, resized_array.shape[0])
            x_end = min(x0 + crop_size, resized_array.shape[1])

            # Edge crops may be smaller than crop_size; the remainder of the
            # pre-zeroed buffer acts as padding (becomes -1 after normalize).
            crop_region = resized_array[y0:y_end, x0:x_end]
            crop_idx = 1 + i * tiling[1] + j
            h_slice = slice(None, crop_region.shape[0])
            w_slice = slice(None, crop_region.shape[1])
            crops[crop_idx, h_slice, w_slice] = crop_region

    # Normalize: (x - 0.5) / 0.5 = 2*x - 1
    # Convert to float and normalize to [-1, 1]
    pixel_values = crops.astype(np.float32) / 255.0
    pixel_values = (pixel_values - 0.5) / 0.5

    # Convert to tensor: (n_crops, H, W, C) -> (n_crops, C, H, W)
    pixel_values = np.transpose(pixel_values, (0, 3, 1, 2))

    if return_tensors == "pt":
        pixel_values = torch.from_numpy(pixel_values)

    return pixel_values, tiling

select_tiling

select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]

Determine the optimal number of tiles to cover an image.

Source code in vllm/transformers_utils/processors/moondream3.py
def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """Choose a (rows, cols) tile grid covering an image.

    Picks a grid that covers ``height`` x ``width`` with tiles of side
    ``crop_size``, keeping the total tile count at or below ``max_crops``
    while roughly preserving the image's aspect ratio.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        crop_size: Side length of one tile in pixels.
        max_crops: Upper bound on rows * cols.

    Returns:
        ``(rows, cols)`` with each component >= 1 and, for ``max_crops >= 1``,
        ``rows * cols <= max_crops``.
    """
    # Images no larger than one tile need no tiling at all.
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Minimum tiles per axis needed to cover the image.
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    if min_h * min_w > max_crops:
        # Over budget even at minimum coverage: scale both axes down by a
        # common factor so the product fits (floor only shrinks it further).
        ratio = math.sqrt(max_crops / (min_h * min_w))
        h_tiles = max(1, math.floor(min_h * ratio))
        w_tiles = max(1, math.floor(min_w * ratio))
    else:
        # Distribute the crop budget according to the aspect ratio.
        h_tiles = math.floor(math.sqrt(max_crops * height / width))
        w_tiles = math.floor(math.sqrt(max_crops * width / height))

        h_tiles = max(h_tiles, min_h)
        w_tiles = max(w_tiles, min_w)

    # BUG FIX: the max(1, ...) clamp (first branch) or the min_h/min_w floor
    # (second branch) can push the product back above max_crops — e.g.
    # (height=101, width=3000, crop=100, max=12) used to yield (1, 13).
    # Shrink the larger axis to restore rows * cols <= max_crops.
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = max(1, math.floor(max_crops / h_tiles))
        else:
            h_tiles = max(1, math.floor(max_crops / w_tiles))

    return (max(1, h_tiles), max(1, w_tiles))