Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MaverickRen committed Feb 10, 2025
1 parent 0ae333b commit 3c850cb
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 32 deletions.
2 changes: 1 addition & 1 deletion LDM/tools/calvin_ldm_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env bash
CONFIG="./configs/calvin_ldm_debug.py"
CONFIG="./configs/calvin_ldm.py"
GPUS=1
NNODES=1
NODE_RANK=0
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Download CALVIN dataset follow the official instructions and organize it as foll
│ │ └── magvit_init.pth
│ └──
```
Use the script ./LDM/tools/calvin_ldm_train.sh to start LDM training. Training requires the [Magvit weights](https://huggingface.co/maverickrzw/VideoWorld_CALVIN/tree/main), which we pre-trained on natural image reconstruction, as initialization. Upon completion, the latent codes of the training set are automatically saved to ./LDM/work_dirs/calvin_ldm_results.pth, and a UMAP visualization of the latent codes is also generated.
Use the script ./LDM/tools/calvin_ldm_train.sh to start LDM training. Training requires the [Magvit weights](https://huggingface.co/maverickrzw/VideoWorld-GoBattle/tree/main), which we pre-trained on natural image reconstruction, as initialization. Upon completion, the latent codes of the training set are automatically saved to ./LDM/work_dirs/calvin_ldm_results.pth, and a UMAP visualization of the latent codes is also generated.
```
cd LDM
bash ./tools/calvin_ldm_train.sh
Expand Down
6 changes: 3 additions & 3 deletions VideoWorld/configs/calvin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
type='VQGANEncoder',
init_cfg=dict(
type='Pretrained',
checkpoint='work_dirs/configs/vqgan_fsq_imagenet1k_style-2_256x256_ep60_calvin_16code_hand/iter_460000_new.pth'),
checkpoint='./work_dirs/calvin_fsq.pth'),
width_mults=(1,1,1,2,2,4,4),
),
neck=dict(
type='InternLMGenModel',
pretrain_path='work_dirs/init/Intern_300m',
pretrain_path='./work_dirs/Intern_300m',
vq_num=64000,
sepcial_token_num=2+64000,
use_text=False
Expand Down Expand Up @@ -128,7 +128,7 @@
# collate_fn=dict(type='default_collate'),
dataset=dict(
type='CALVINEnvValDataset',
data_root = "/mnt/bn/panxuran/calvin/task_ABCD_D/",
data_root = "./data/calvin/task_ABCD_D/",
pipeline=test_pipeline,
num_sequences=20),
)
Expand Down
6 changes: 3 additions & 3 deletions VideoWorld/configs/calvin_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
type='VQGANEncoder',
init_cfg=dict(
type='Pretrained',
checkpoint='work_dirs/configs/vqgan_fsq_imagenet1k_style-2_256x256_ep60_calvin_16code_hand/iter_460000_new.pth'),
checkpoint='./work_dirs/calvin_fsq.pth'),
width_mults=(1,1,1,2,2,4,4),
),
neck=dict(
type='InternLMGenModel',
pretrain_path='work_dirs/init/Intern_300m',
pretrain_path='./work_dirs/Intern_300m',
vq_num=64000,
sepcial_token_num=2+64000,
use_text=False
Expand Down Expand Up @@ -128,7 +128,7 @@
# collate_fn=dict(type='default_collate'),
dataset=dict(
type='CALVINEnvValDataset',
data_root = "/mnt/bn/panxuran/calvin/task_ABCD_D/",
data_root = "./data/calvin/task_ABCD_D/",
pipeline=test_pipeline,
num_sequences=20),
)
Expand Down
4 changes: 2 additions & 2 deletions VideoWorld/configs/go_battle_vs_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
type='VQGANEncoder',
init_cfg=dict(
type='Pretrained',
checkpoint='work_dirs/init/fsq/vqgan/16code_10000.pth'),
checkpoint='./work_dirs/go_fsq.pth'),
width_mults=(1,1,1,2,2,4,4),
),

neck=dict(
type='InternLMGenModel',
pretrain_path='work_dirs/init/Intern_300m',
pretrain_path='./work_dirs/Intern_300m',
vq_num=64000,
sepcial_token_num=3+64000+2,
use_text=True
Expand Down
23 changes: 2 additions & 21 deletions VideoWorld/falcon/models/algorithms/calvin_GR1_wostate_vq_idm.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,32 +466,13 @@ def forward_train(self, img, input_ids, pred_label=None, attention_mask=None, **
logits, loss, hidden_state = self.neck(inputs_embeds=stacked_inputs, attention_mask=stacked_attention_mask, labels=stacked_label, return_dict=True)
x = hidden_state.reshape(batch_size, sequence_length, n_tokens, self.hidden_size)

# import pdb;pdb.set_trace()
if self.act_pred:
action_embedding = x[:, :, act_query_token_start_i:act_query_token_start_i+self.la_act_scope] if not self.fix_act_pred else x[:, :, act_query_token_start_i-1:act_query_token_start_i+self.la_act_scope]
for pred_act_mlp in self.pred_act_mlps:
action_embedding = pred_act_mlp(action_embedding)
# action_embedding = action_embedding.mean(dim=2)
action_embedding = self.act_embed_fuse(action_embedding.permute(0, 1, 3, 2)).squeeze(-1)
arm_action_preds = self.pred_arm_act(action_embedding) # (b, l, act_dim - 1)
gripper_action_preds = self.pred_gripper_act(action_embedding) # (b, l, 1)
loss_state = self.state_loss(arm_action_preds, arm_action)
loss_gripper = self.gripper_loss(gripper_action_preds, gripper_action)
losses = { "loss_v": loss, "loss_state": loss_state, "loss_gripper": loss_gripper}
elif self.la_act_pred:
losses = { "loss_v": loss}

# if self.sup_actions:

# loss_rgb = self.rgb_loss(obs_preds[:, :-1], obs_targets[:, 1:])
# loss_hand_rgb = self.rgb_loss(obs_hand_preds[:, :-1], obs_hand_targets[:, 1:])
# print("--------", input_ids[0], input_ids[1], "--------")
losses = { "loss_v": loss}

return losses


def rollout_pred_rgb(self, img, seq_input_ids, pred_label=None, seq_attention_mask=None, index=None, **kwargs):
# import pdb;pdb.set_trace()

import cv2
scene = kwargs.pop('scene')
robot_obs = kwargs.pop('state')
Expand Down
2 changes: 1 addition & 1 deletion VideoWorld/tools/calvin_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env bash
CONFIG="./configs/calvin_train_debug.py"
CONFIG="./configs/calvin_train.py"
GPUS=1
NNODES=1
NODE_RANK=0
Expand Down

0 comments on commit 3c850cb

Please sign in to comment.