We will discuss:

  • Data Collection: How to gather and preprocess multimodal data.
  • Alignment: Techniques to align different modalities temporally and spatially.
  • Training Objectives: Common loss functions and training objectives used in multimodal learning.
  • Evaluation Metrics: How to evaluate the performance of multimodal models.

Data collection

  • The main concerns for the data are, first, scale, and second, alignment.

  • You can rely on large-scale Internet data with naturally emergent aligned data

    • image-text caption pairs (LAION, …)
    • video and audio (YouTube, …)
  • Sometimes, modalities can be added by post-processing

    • e.g., by using individual expert networks to pseudo-label (this works well for RGB inputs, e.g. predicting depth maps); you get explicit alignment this way (see the pseudo-labelling sketch after this list).
    • e.g., extracting the transcript from a video's audio track; this way, you get explicit temporal alignment between video, audio, and text.
    • Going even more synthetic, you can artificially augment your data to generate aligned data
      • e.g., in Unified-IO, they synthetically add shapes to an image, and then ask the model to list the bounding boxes of the synthetically added shapes.
    • You may also use LLMs to "augment"/"diversify" your data: take simple text captions, give the LLM more textual context w.r.t. the current modality (e.g. all the segmented objects and their bounding boxes), and have it produce a more detailed caption.
  • Part of the pretraining data might also be text-only, as text is usually the main way users (and APIs) interact with the model.

  • Good SFT data is also crucial; it doesn't require as much scale, but coverage is quite important.
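
As a concrete illustration of the pseudo-labelling bullet above, here is a minimal sketch that runs an off-the-shelf depth "expert" over unlabeled RGB images to produce an explicitly aligned RGB–depth dataset. It assumes the Hugging Face `transformers` depth-estimation pipeline and the `Intel/dpt-large` checkpoint; the directory names are placeholders.

```python
# Minimal pseudo-labelling sketch: run a depth expert over unlabeled RGB
# images to create explicitly aligned (RGB, depth) pairs.
# Assumptions: `transformers` depth-estimation pipeline, "Intel/dpt-large"
# checkpoint, and placeholder directory names.
from pathlib import Path

from PIL import Image
from transformers import pipeline

depth_expert = pipeline("depth-estimation", model="Intel/dpt-large")

rgb_dir = Path("data/rgb")
out_dir = Path("data/depth_pseudolabels")
out_dir.mkdir(parents=True, exist_ok=True)

for image_path in sorted(rgb_dir.glob("*.jpg")):
    rgb = Image.open(image_path).convert("RGB")
    result = depth_expert(rgb)  # {"predicted_depth": tensor, "depth": PIL image}
    result["depth"].save(out_dir / f"{image_path.stem}_depth.png")
    # Each (rgb, pseudo-depth) pair is now pixel-aligned by construction.
```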

Training objectives

  • If we're using early fusion, with everything in token space, we can just use the cross-entropy loss for every modality

    • Even with early fusion, you must be careful about balancing the loss between different modalities, as some modalities have much higher token counts than others
  • In the usual case (for GPU-poor folks), when we're adding an adaptor to an LLM, we usually have an aligned modality–text dataset $\{(x, t)\}$ and we minimize the $x$-conditioned text generation loss (usually cross-entropy), i.e. $\mathcal{L} = -\log p_\theta(t \mid P(x))$, where $P(x)$ are the projected tokens or "prompt" from $x$ and $t$ is the text prompt (see the first sketch after this list).

  • In the case where we have an adaptor to an LLM, but we also want to generate the modality, we might have the LLM generate special "signal tokens", e.g. $[\text{SIG}_1], \dots, [\text{SIG}_k]$

    • In that case, for example when using an external text-to-modality generator $G_X$ (e.g. a diffusion model), we may have an Output Projector $\Theta$ that maps the signal token representations $h_s$ from the LLM Backbone into features $H_X = \Theta(h_s)$ understandable to the generator $G_X$
    • Given some $X$-text dataset $\{(X, t)\}$, $t$ is first fed into the LLM to generate the corresponding signal token representations $h_s$, which are then mapped into $H_X = \Theta(h_s)$
      • To facilitate alignment of the mapped features $H_X$, the goal is to minimize the distance between $H_X$ and the conditional text representations of $G_X$: $\mathcal{L}_{\text{align}} = \lVert H_X - \tau(t) \rVert_2^2$
      • where $\tau$ is the textual condition encoder in $G_X$ (e.g. the generator's frozen text encoder); see the second sketch after this list
  • In the completely general case, we might have to use a combination of different loss functions to address the multiple tasks.
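
First, a minimal sketch of the adaptor setup: project encoder features into the LLM embedding space, prepend them to the text embeddings, and apply cross-entropy only on the text positions. The modules here (a linear layer standing in for the LLM backbone, random features) are illustrative placeholders, not any specific model's API.

```python
# Sketch of the x-conditioned text generation loss L = -log p(t | P(x)).
# `lm`, `embed_tokens`, and `proj` are illustrative stand-ins; a real setup
# would use a pretrained decoder-only LLM that accepts input embeddings.
import torch
import torch.nn as nn
import torch.nn.functional as F

def conditioned_text_loss(lm, embed_tokens, proj, modality_feats, text_ids):
    prefix = proj(modality_feats)                   # (B, K, d): projected tokens P(x)
    text_emb = embed_tokens(text_ids)               # (B, T, d)
    inputs = torch.cat([prefix, text_emb], dim=1)   # (B, K+T, d)
    logits = lm(inputs)                             # (B, K+T, vocab)
    K = prefix.size(1)
    # In a causal LM, position i predicts token i+1, so positions K-1 .. K+T-2
    # are the ones predicting the T text tokens; the prefix itself gets no loss.
    pred = logits[:, K - 1 : -1, :]
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), text_ids.reshape(-1))

# Toy usage with random tensors and a linear layer standing in for the LLM.
B, K, T, d, vocab, d_enc = 2, 4, 8, 32, 100, 64
lm = nn.Linear(d, vocab)
embed_tokens = nn.Embedding(vocab, d)
proj = nn.Linear(d_enc, d)                          # the trainable adaptor
loss = conditioned_text_loss(lm, embed_tokens, proj,
                             torch.randn(B, K, d_enc),
                             torch.randint(0, vocab, (B, T)))
loss.backward()
```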

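Second, a sketch of the decoding-side alignment objective: the Output Projector maps the LLM's signal-token hidden states into features that are regressed onto the generator's conditional text representations. Both the projector and the stand-in target features below are illustrative placeholders; in practice $\tau(t)$ would come from the generator's frozen text encoder.

```python
# Sketch of the signal-token alignment loss: H_X = Theta(h_s) is pulled
# toward the generator's conditional text representation tau(t).
# All modules and dimensions here are illustrative placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, n_signal, d_llm, d_cond = 2, 4, 32, 16
output_projector = nn.Linear(d_llm, d_cond)   # Theta: signal-token states -> generator features

h_s = torch.randn(B, n_signal, d_llm)         # hidden states of the generated signal tokens
tau_t = torch.randn(B, n_signal, d_cond)      # stand-in for tau(t), the frozen text-condition features

H_X = output_projector(h_s)
alignment_loss = F.mse_loss(H_X, tau_t)       # minimize || H_X - tau(t) ||^2
alignment_loss.backward()
```
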
Evaluation metrics

  • Multimodal Alignment Score: measuring how well different modalities are aligned in the feature space (e.g. cosine similarity between paired embeddings), even if there was no explicit contrastive objective during training (see the sketch after this list).
  • Modality-specific Performance: Evaluating each modality’s performance independently to ensure that integrating multiple modalities does not degrade individual performance
  • Multimodal benchmarks: Some benchmarks are explicitly multi-modal, e.g. MMMU (Massive Multi-discipline Multimodal Understanding and Reasoning), which tests for interleaved text-image understanding, or MULTIBENCH, which tests a wider range of modality combinations ({video, audio, optical flow} or {image, sensor}) across different domains (healthcare, robotics, …)
  • Human evaluation: The gold standard for generation evaluation. It usually consists of a human deciding which of several models' outputs they prefer, according to some guidelines.
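
A minimal sketch of such an alignment score, assuming paired image/text embeddings from the model's encoders (random tensors stand in for them here): mean cosine similarity in the shared feature space.

```python
# Toy cross-modal alignment score: mean cosine similarity between paired
# image/text embeddings. Random tensors stand in for real encoder outputs.
import torch
import torch.nn.functional as F

img_emb = torch.randn(256, 512)   # N paired samples, shared feature dim
txt_emb = torch.randn(256, 512)

img_emb = F.normalize(img_emb, dim=-1)
txt_emb = F.normalize(txt_emb, dim=-1)
score = (img_emb * txt_emb).sum(dim=-1).mean()   # in [-1, 1]; higher = better aligned
print(f"alignment score: {score.item():.3f}")
```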

Challenges to training multi-modal models

Data collection / Alignment

  • Well-aligned, large multimodal datasets are difficult and expensive to collect.
  • I partly discuss mitigations/techniques in the previous answer on how to train multi-modal models.
  • Mitigations:
    • if possible, pseudo-labelling using individual expert networks
    • If pseudo-labelling is not possible (i.e. the label cannot be inferred),
      • use large scale datasets and rely on implicit/emergent alignment
        • e.g. ImageBind
    • Do the pretraining with an imperfectly aligned dataset, but spend the time to create a small "SFT" dataset that is very highly aligned and covers lots of modalities

Lack of robustness to missing modalities

  • Modality-specific missingness is common in real-world data and is especially problematic when the missingness of a modality is predictive of the label
  • e.g. this can happen in healthcare settings, where data acquisition depends on the healthcare professional
    • For example, for a fixed disease, a patient with a milder form may be monitored less and won’t require advanced tests. In that case, the lack of advanced tests may serve as a predictive feature for mildness of the condition.
  • Mitigations:
    • at the design stage
      • use sequential updates to a shared state, which can skip missing modalities instead of padding/imputing.
    • at the training stage
      • masked modelling, i.e. random token dropping (whole modalities can also be dropped; see the sketch after this list)
      • being careful with your data mixture
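
A minimal sketch of the random-dropping mitigation, under the assumption that a training batch is a dict mapping modality names to token tensors; with some probability each non-text modality is dropped entirely so the model sees realistic missingness during training.

```python
# Illustrative modality dropout: randomly drop whole modalities from a batch
# during training so the model learns to handle missing inputs at test time.
# The batch layout (dict of modality -> token tensor) is an assumption.
import torch

def drop_modalities(batch: dict, p_drop: float = 0.3, always_keep=("text",)) -> dict:
    kept = {}
    for name, tokens in batch.items():
        if name not in always_keep and torch.rand(()).item() < p_drop:
            continue                       # simulate this modality being missing
        kept[name] = tokens
    return kept

batch = {
    "text":  torch.randn(4, 16, 32),
    "image": torch.randn(4, 64, 32),
    "audio": torch.randn(4, 128, 32),
}
print(list(drop_modalities(batch).keys()))  # e.g. ['text', 'audio']
```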

Training instabilities

  • Training instabilities have been observed when training a large multimodal “monolithic” model
    • In Chameleon, they report that the divergence is caused by the softmax operation being problematic when training with multiple modalities of significantly varying entropy, due to the translation-invariant property of softmax (i.e., softmax(z) = softmax(z + c)).
    • Because they share all weights of the model across modalities, each modality will try to "compete" with the others by slightly increasing its norms
  • Mitigations, i.e. controlling norm growth (see the sketch after this list):
    • Paying close attention to unusually large values in your forward pass
    • paying attention to the placement of layer norms.
      • In Chameleon, they use layer-norm after the self-attention & MLP, instead of before
      • It bounds the norm of what’s added to the residual stream.
    • QK-Norm (controls the norm growth of the attention logits)
      • layer-norm the queries and keys before doing the dot product
    • z-loss (controls for the divergence in the output logits)
      • Let $z$ denote the model's output logits, which are used to compute class probabilities via a softmax $\sigma(z)_i = e^{z_i} / Z$, where $Z = \sum_j e^{z_j}$.
      • They add an auxiliary loss $\log^2 Z$, referred to as z-loss, with a small coefficient ($10^{-5}$ in Chameleon), to encourage $Z$ to stay close to 1.
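
Minimal sketches of two of these mitigations, with a single attention head and toy dimensions for brevity: QK-norm (layer-norm the queries and keys before the dot product) and the z-loss regularizer on the softmax normalizer $Z$. Shapes and module layout are illustrative, not any particular model's implementation.

```python
# Illustrative sketches of QK-norm and z-loss (single-head attention, toy dims).
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    """Single-head self-attention with layer-norm on queries and keys (QK-norm)."""
    def __init__(self, d):
        super().__init__()
        self.q_proj, self.k_proj, self.v_proj = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        self.q_norm, self.k_norm = nn.LayerNorm(d), nn.LayerNorm(d)
        self.scale = d ** -0.5

    def forward(self, x):                        # x: (B, T, d)
        q = self.q_norm(self.q_proj(x))          # bounded-norm queries
        k = self.k_norm(self.k_proj(x))          # bounded-norm keys
        v = self.v_proj(x)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ v

def z_loss(logits: torch.Tensor, coeff: float = 1e-5) -> torch.Tensor:
    """Penalize log^2 Z, where Z is the softmax normalizer of the output logits."""
    log_Z = torch.logsumexp(logits.float(), dim=-1)   # log Z per position
    return coeff * (log_Z ** 2).mean()

x = torch.randn(2, 16, 64)
logits = torch.randn(2, 16, 50_000)
targets = torch.randint(0, 50_000, (2, 16))
total = F.cross_entropy(logits.view(-1, 50_000), targets.view(-1)) + z_loss(logits)
out = QKNormAttention(64)(x)                     # (2, 16, 64)
```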

Input

  • Different modalities can have very different formats (continuous vs. discrete) and sizes (images vs. a binary feature).
  • Possible Mitigations:
    • Use tokenizers + transformer architecture
      • shared feature space = token embedding space
      • can use VQ-VAE to map continuous data into discrete tokens
      • transformer architecture allows for elegant handling of varying sizes (i.e. sequence length in token space).
      • allows for generation too
    • Each modality has a specific encoder that maps to a shared (continuous) feature space (see the sketch below).
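
A minimal sketch of that second mitigation: modality-specific encoders all projecting into a shared continuous token space, after which the transformer's sequence dimension handles the varying sizes. The tiny linear encoders and dimensions below are stand-ins for real pretrained backbones.

```python
# Each modality gets its own encoder, all projecting into a shared token
# embedding space of dimension d_shared. Encoders and dims are illustrative.
import torch
import torch.nn as nn

d_shared = 256

encoders = nn.ModuleDict({
    "image": nn.Linear(768, d_shared),        # e.g. ViT patch features -> shared space
    "audio": nn.Linear(128, d_shared),        # e.g. mel-spectrogram frames -> shared space
    "text":  nn.Embedding(50_000, d_shared),  # discrete text tokens -> shared space
})

image_feats = torch.randn(2, 196, 768)            # (B, #patches, feat_dim)
audio_feats = torch.randn(2, 400, 128)            # (B, #frames, feat_dim)
text_ids    = torch.randint(0, 50_000, (2, 32))   # (B, #tokens)

# Every modality becomes a (B, T_m, d_shared) sequence; lengths just add up.
sequence = torch.cat([
    encoders["image"](image_feats),
    encoders["audio"](audio_feats),
    encoders["text"](text_ids),
], dim=1)                                          # (B, 196 + 400 + 32, d_shared)
```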