diff --git a/amplify/utils/libero_utils/env_utils.py b/amplify/utils/libero_utils/env_utils.py index 02efa8e..e45db6d 100644 --- a/amplify/utils/libero_utils/env_utils.py +++ b/amplify/utils/libero_utils/env_utils.py @@ -52,7 +52,16 @@ def get_task_emb(task_suite, task_name, dataset_path=None): task_file = f"{task_name}_demo.hdf5" task_file_path = os.path.join(task_suite_path, task_file) with h5py.File(task_file_path, 'r') as f: - task_emb = torch.tensor(f['text_emb'][()]) + if 'text_emb' in f: + task_emb = torch.tensor(f['text_emb'][()]) + else: + # Fallback to the preprocessed text embeddings if not in the dataset + root_dir = get_root_dir() + fallback_path = os.path.join(root_dir, 'preprocessed_data', task_suite, 'text', f'{task_name}.hdf5') + if not os.path.exists(fallback_path): + raise ValueError(f"text_emb not found in {task_file_path} and fallback text embedding missing at {fallback_path}. Please run: python -m preprocessing.preprocess_libero mode=text suite='{task_suite}'") + with h5py.File(fallback_path, 'r') as fallback_f: + task_emb = torch.tensor(fallback_f['text_emb'][()]) return task_emb