Skip to content

galilai-group/cjepa

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

120 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

C-JEPA

Causal-JEPA: Learning World Models through Object-Level Latent Interventions

by Heejeong Nam, Quentin Le Lidec*,Lucas Maes*, Yann LeCun, Randall Balestriero.

architecture

Summary

World models require robust relational understanding to support prediction, reasoning, and control. While object-centric representations provide a useful abstraction, they are not sufficient to capture interaction-dependent dynamics. We therefore propose C-JEPA, a simple and flexible object-centric world model that extends masked joint embedding prediction from image patches to object-centric representations. By applying object-level masking that requires an object's state to be inferred from other objects, C-JEPA induces latent interventions with counterfactual-like effects and prevents shortcut solutions, making interaction reasoning essential. Empirically, C-JEPA leads to consistent gains in visual question answering, with an absolute improvement of about 20% in counterfactual reasoning compared to the same architecture without object-level masking. On agent control tasks, C-JEPA enables substantially more efficient planning by using only 1% of the total latent input features required by patch-based world models, while achieving comparable performance. Finally, we provide a formal analysis demonstrating that object-level masking induces a causal inductive bias via latent interventions.

Setup

Dataset Preparation

Please refer to DATASET.md for dataset preparation.

Object-Centric Model / Representations for C-JEPA

C-JEPA relies on object-centric encoders to extract object-centric representations. You can train the encoder by yourself, or download the model checkpoints from HuggingFace, or download the pre-extracted slot representations.

visualization

1. Train Object-Centric Encoder

  • Train VideoSAUR on CLEVRER

    PYTHONPATH=. python src/third_party/videosaur/videosaur/train.py \
        src/third_party/videosaur/configs/videosaur/clevrer_dinov2_hf.yml \
        dataset.train_shards="/to/data/path/clevrer-train-{000000..0000XX}.tar" \
        dataset.val_shards="/to/data/path/clevrer-val-{000000..0000XX}.tar" \
        globals.BATCH_SIZE_PER_GPU=32 \
    
  • Train SAVi on CLEVRER

    • Refer to slotformer repo to setup data for train savi. Then you have to modify data_root and gpus in src/third_party/slotformer/base_slots/configs/stosavi_clevrer_params.py to point to the data directory. After that, you can run the code below to train SAVi on CLEVRER.

    • Before running the code, replace whole src/third_party/nerv/nerv/utils/misc.py with src/custom_codes/misc.py. This is because the original code is based on pytorch-lightning==0.8.* while we are using pytorch-lightning==2.6.*.

    PYTHONPATH=. torchrun --nproc_per_node=3 src/aloe_train.py \
      --task base_slots \
      --params src/third_party/slotformer/base_slots/configs/stosavi_clevrer_params.py \
      --exp_name clevrer_savi_reproduce \
      --out_dir $OUTDIR \
      --fp16 --ddp --cudnn
    
  • Train VideoSAUR on PushT

    PYTHONPATH=. python src/third_party/videosaur/videosaur/train.py \
        src/third_party/videosaur/configs/videosaur/pusht_dinov2_hf.yml \
        dataset.train_shards="pusht_mixed/train/pusht-train-{000000..0000xx}.tar" \
        dataset.val_shards="pusht_mixed/validation/pusht-val-{000000..00000x}.tar"
    

2. Model Checkpoints

Dataset Encoder Hyperparams Checkpoint Link
CLEVRER VideoSAUR clevrer_dinov2_hf.yml Checkpoint
CLEVRER SAVi stosavi_clevrer_params.py Checkpoint
Push-T VideoSAUR pusht_dinov2_hf.yml Checkpoint

3. Pre-extracted Slot Representations

Dataset Encoder Config Checkpoint Link
CLEVRER VideoSAUR clevrer_dinov2_hf.yml Checkpoint
CLEVRER SAVi stosavi_clevrer_params.py Checkpoint
PUSHT VideoSAUR pusht_dinov2_hf.yml Checkpoint

Train / Download C-JEPA

Extract slot representations with the checkpoints

Use this only if you want to extract slots by yourself. If you are using pre-extracted slots, you can skip this step and directly download the pre-extracted slots from HuggingFace.

  • CLEVRER SAVi slots

    • Refer to slotformer repo to setup data for train savi. Then you have to modify data_root in src/third_party/slotformer/base_slots/configs/stosavi_clevrer_params.py to point to the data directory. After that, you can run the code below to train SAVi on CLEVRER.
    PYTHONPATH=. python src/third_party/slotformer/base_slots/extract_slots.py --params src/third_party/slotformer/base_slots/configs/stosavi_clevrer_params.py  --weight $WEIGHT  --save_path clevrer_savi_slots.pkl
    
  • CLEVRER VideoSAUR

    PYTHONPATH=. python src/third_party/slotformer/base_slots/extract_videosaur.py --weight $WEIGHT --data_root="~/.stable_worldmodel"   --save_path=$SAVE_DIR --dataset="clevrer"  --videosaur_config="src/third_party/videosaur/configs/videosaur/clevrer_dinov2_hf.yml" 
    
  • PushT VideoSAUR Slots

    PYTHONPATH=. python src/third_party/slotformer/base_slots/extract_videosaur.py --weight $WEIGHT --data_root="~/.stable_worldmodel"   --save_path=$SAVE_DIR  --dataset="pusht_expert"  --videosaur_config="src/third_party/videosaur/configs/videosaur/pusht_dinov2_hf.yml"   --params="src/third_party/slotformer/aloe_pusht_params.py"
    
  • Extracted pkl will look like:

    {
        'train': {'0_pixels.mp4': slots, '1_pixels.mp4': slots, ...},  # slots: [T, N, 128] each
        'val': {...},
        'test': {...}
    }
    

2. Run C-JEPA with pre-extracted slot representations

Use scripts below. You can change the config file and the checkpoint path to run with different settings. For example, you can change num_masked_slots to control how many slots are masked.

sh script/clevrer/train_cjepa_from_slot.sh
sh script/pusht/train_cjepa_from_slot.sh 

3. Download C-JEPA checkpoints

  • We release checkpoints for C-JEPA trained with different object-centric backbones and different number of masked slots.
  • (*) means it performs best. (See the paper for details.)
Dataset OC-Backbone Masked Slots / Total Slots Checkpoint Link
CLEVRER VideoSAUR 0/7 checkpoint
CLEVRER VideoSAUR 1/7 checkpoint
CLEVRER VideoSAUR 2/7 checkpoint
CLEVRER VideoSAUR 3/7 checkpoint
CLEVRER VideoSAUR 4/7 (*) checkpoint
CLEVRER SAVi 0/7 checkpoint
CLEVRER SAVi 1/7 checkpoint
CLEVRER SAVi 2/7 (*) checkpoint
CLEVRER SAVi 3/7 checkpoint
CLEVRER SAVi 4/7 checkpoint
PUSHT VideoSAUR 0/4 checkpoint
PUSHT VideoSAUR 1/4 (*) checkpoint
PUSHT VideoSAUR 2/4 checkpoint

Evaluation

Evaluate Control on Push-T

sh scripts/pusht/test_planning.sh

Evaluate Visual Reasoning on CLEVRER

  • We will first rollout slots (from 128 frame to 160 frame) with C-JEPA checkpoint and pre-extracted slots.
# change CKPTPATH and SLOTPATH, `predictor_lr`, `num_masked_slots` before. SLOTPATH should be matched with the checkpoint that you are using (dataset, oc-encoder) and it should point to the pre-extracted slot representations. For example, if you are using `clevrer_savi_2.ckpt`, you should use `clevrer_savi_slots.pkl` and set `num_masked_slots` to 2.
# `predictor_lr` (default=5e-4), `num_masked_slots` should be matched with config is CKPT that you are using (they go to the output name)
# output name will be like : rollout_{SLOTPATH name}_{CKPT lr, CKPT masked slot}
sh scripts/clevrer/rollout_from_slot.sh
  • This will save output file under the same directory as the pre-extracted slot representations.
# rollout_clevrer_slots_{configuration}.pkl :
{
    'train': {'0_pixels.mp4': slots, '1_pixels.mp4': slots, ...},  # slots: [160, N, 128] each
    'val': {...},
    'test': {...}
}
  • Then you can run the code below to evaluate C-JEPA on CLEVRER visual question answering.
  • Again, before running the code, replace whole src/third_party/nerv/nerv/utils/misc.py with src/custom_codes/misc.py. This is because the thrid-party code is based on pytorch-lightning==0.8.* while we are using pytorch-lightning==2.6.*.
  • You should change params manually in src/third_party/slotformer/clevrer_vqa/configs/aloe_clevrer_params-rollout.py. For example,
    • gpu (it should exactly match the number of the visible devices)
    • lr
  • Then run script below to train and test aloe.
sh scripts/clevrer/train_aloe.sh
sh scripts/clevrer/test_aloe.sh

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors