import bisect
import copy
import json
import math
import os
import random
import time
from functools import lru_cache
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from vitra.datasets.augment_utils import (
    augmentation_func,
    center_crop_short_side,
    project_to_image_space,
)
from vitra.datasets.interp_utils import interp_mano_state
from vitra.datasets.video_utils import load_video_decord
from vitra.datasets.dataset_utils import (
    compute_new_intrinsics_crop,
    compute_new_intrinsics_resize,
    calculate_fov,
    ActionFeature,
    StateFeature,
)
from vitra.utils.data_utils import (
    read_dataset_statistics,
    GaussianNormalizer,
)


class EpisodicDatasetCore(object):
    """Core dataset class for episodic hand manipulation data.

    Handles loading and processing of video frames, MANO hand parameters,
    and action sequences for hand-centric manipulation tasks.
    """

    def __init__(
        self,
        video_root,
        annotation_file,
        label_folder,
        training_path=None,
        statistics_path=None,
        augmentation=True,
        flip_augmentation=True,
        set_none_ratio=0.0,
        action_type="angle",
        use_rel=False,
        upsample_factor=1.0,
        target_image_width=224,
        clip_len=2000,
        state_mask_prob=0.1,
        action_past_window_size=0,
        action_future_window_size=15,
        image_past_window_size=0,
        image_future_window_size=0,
        rel_mode="step",
        load_images=True,
    ):
        self.video_root = video_root
        annotation_dict = np.load(annotation_file, allow_pickle=True)
        self.label_folder = label_folder
        self.index_frame_pair = annotation_dict['index_frame_pair'].copy()
        self.index_to_episode_id = annotation_dict['index_to_episode_id'].copy()

        if training_path is not None:
            self.training_idx = np.load(training_path, allow_pickle=True)
            self.num_valid_frames = len(self.training_idx)
        else:
            self.training_idx = None
            self.num_valid_frames = len(self.index_frame_pair)

        if statistics_path is not None:
            self.data_statistics = read_dataset_statistics(statistics_path)

        self.global_data_statistics = None
        self.clip_len = clip_len
        self.augmentation = augmentation
        self.target_image_width = target_image_width
        self.flip_augmentation = flip_augmentation
        self.set_none_ratio = set_none_ratio
        self.action_type = action_type
        self.use_rel = use_rel
        assert upsample_factor >= 1.0, "only upsample_factor >= 1.0 is supported"
        self.upsample_factor = upsample_factor
        self.state_mask_prob = state_mask_prob

        self.action_past_window_size = action_past_window_size
        self.action_future_window_size = action_future_window_size
        self.image_past_window_size = image_past_window_size
        self.image_future_window_size = image_future_window_size
        self.rel_mode = rel_mode
        self.load_images = load_images

    def __len__(self):
        return self.num_valid_frames

    @staticmethod
    @lru_cache(maxsize=256)
    def _load_episode_npy(episode_path: str):
        """Load episode data from .npy file with caching.

        Uses LRU cache to keep up to 256 episodes in memory (~256 MB worst case).
        The cache automatically purges old entries when full.

        Args:
            episode_path: Path to the .npy file containing episode data

        Returns:
            Dictionary containing episode information
        """
        return np.load(episode_path, allow_pickle=True).item()

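    # Note: lru_cache keeps the episodes per process. With a multi-worker
    # PyTorch DataLoader each worker holds its own cache, so memory use
    # scales with the number of workers.
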
    def _load_or_cache_episode(self, episode_id):
        """
        Returns the raw episode dict and the pre-extracted camera
        extrinsics (R_w2c, t_w2c). No camera-space MANO tensors are cached.
        """
        root = self.label_folder
        epi_path = os.path.join(root, episode_id + '.npy')
        epi = self._load_episode_npy(epi_path)

        extr = epi['extrinsics']
        R_w2c, t_w2c = extr[:, :3, :3], extr[:, :3, 3]

        return epi, R_w2c, t_w2c

    def _mat2euler(self, R_batch: np.ndarray) -> np.ndarray:
        """Batched XYZ-Euler conversion using SciPy."""
        flat = R_batch.reshape(-1, 3, 3)
        eul = R.from_matrix(flat).as_euler('xyz', degrees=False)
        return eul

    def _prepare_side_window(self, side_dict,
                             R_w2c, t_w2c,
                             idx_window, idx_anchor,
                             *, anchor_frame=True, oob=None, start=None, end=None, upsample_factor=1.0):
        """Gather MANO parameters for one hand over a frame window.

        World-space orientation/translation are mapped into the camera frame
        of the anchor frame, missing frames are replaced by identity rotations
        and flagged in ``kept``, and both the current window and the
        one-step-shifted ("next") window are returned for action computation.
        """
        W = len(idx_window)
        idx_window_extend = np.append(idx_window, np.clip(idx_window[-1] + 1, start, end))

        kept_extend = side_dict['kept_frames'][idx_window_extend].astype(bool)

        R_mano_extend = side_dict['global_orient_worldspace'][idx_window_extend]
        t_mano_extend = side_dict['transl_worldspace'][idx_window_extend]
        hand_P_extend = side_dict['hand_pose'][idx_window_extend]
        joints_worldspace_extend = side_dict['joints_worldspace'][idx_window_extend]

        oob_indices = np.where(oob)[0]
        if len(oob_indices) > 0:
            kept_extend[oob_indices] = False

        if idx_window[-1] + 1 > end:
            kept_extend[-1] = False

        if not np.all(kept_extend):
            identity = np.eye(3, dtype=hand_P_extend.dtype)
            identity_block = np.broadcast_to(identity, (hand_P_extend.shape[1], 3, 3))
            hand_P_extend[~kept_extend] = identity_block
            R_mano_extend[~kept_extend] = identity

        R_cam_extend = R_w2c[idx_anchor] @ R_mano_extend
        t_cam_extend = (R_w2c[idx_anchor] @ t_mano_extend[..., None])[..., 0] + t_w2c[idx_anchor]

        pose_euler_extend = R.from_matrix(hand_P_extend.reshape(-1, 3, 3)) \
            .as_euler('xyz', degrees=False).reshape(-1, 45)

        joints_manospace_extend = (R_mano_extend.transpose(0, 2, 1) @ (joints_worldspace_extend.transpose(0, 2, 1) - t_mano_extend[..., None])).transpose(0, 2, 1)

        if upsample_factor > 1:
            R_cam_extend, t_cam_extend, hand_P_extend, joints_manospace_extend, kept_extend = \
                interp_mano_state(R_cam_extend, t_cam_extend, hand_P_extend, joints_manospace_extend, kept_extend, upsample_factor, method="pchip")

            pose_euler_extend = R.from_matrix(hand_P_extend.reshape(-1, 3, 3)) \
                .as_euler('xyz', degrees=False).reshape(-1, 45)

        R_cam_extend = R_cam_extend[:W + 1]
        t_cam_extend = t_cam_extend[:W + 1]
        hand_P_extend = hand_P_extend[:W + 1]
        pose_euler_extend = pose_euler_extend[:W + 1]
        joints_manospace_extend = joints_manospace_extend[:W + 1]
        kept_extend = kept_extend[:W + 1]

        R_cam = R_cam_extend[:-1]
        t_cam = t_cam_extend[:-1]
        pose_euler = pose_euler_extend[:-1]
        hand_P = hand_P_extend[:-1]
        joints_manospace = joints_manospace_extend[:-1]
        kept = kept_extend[:-1]

        R_cam_next = R_cam_extend[1:]
        t_cam_next = t_cam_extend[1:]
        pose_euler_next = pose_euler_extend[1:]
        hand_P_next = hand_P_extend[1:]
        joints_manospace_next = joints_manospace_extend[1:]
        kept_next = kept_extend[1:]

        return dict(
            R_cam=R_cam.astype(np.float32),
            t_cam=t_cam.astype(np.float32),
            pose_euler=pose_euler.astype(np.float32),
            hand_P=hand_P.astype(np.float32),
            joints_manospace=joints_manospace.astype(np.float32),
            kept=kept,
            R_cam_next=R_cam_next.astype(np.float32),
            t_cam_next=t_cam_next.astype(np.float32),
            pose_euler_next=pose_euler_next.astype(np.float32),
            hand_P_next=hand_P_next.astype(np.float32),
            joints_manospace_next=joints_manospace_next.astype(np.float32),
            kept_next=kept_next,
        )

    def _make_action_window_vec(self, win, anchor_idx: int, *, rel_mode="step", action_type="angle"):
        """
        anchor_idx : the position of t0 inside the window (usually = past)
        rel_mode   : "step"   → Δ(t→t+1)
                     "anchor" → Δ(t→t0)
        action_type: "angle"     → Euler angles (xyz)
                     "keypoints" → root-relative keypoints (21x3 = 63)
        """
        R_cur, t_cur = win['R_cam'], win['t_cam']
        R_nxt, t_nxt = win['R_cam_next'], win['t_cam_next']
        P_cur, P_nxt = win['hand_P'], win['hand_P_next']
        pose_next = win['pose_euler_next']
        kpoints_root_next = win['joints_manospace_next']
        kept, kept_n = win['kept'], win['kept_next']
        W = len(t_cur)

        if action_type == "keypoints":
            abs_next = kpoints_root_next.reshape(W, -1)
        elif action_type == "angle":
            abs_next = pose_next
        else:
            raise ValueError('action_type must be "angle" or "keypoints"')
        action_abs = np.concatenate(
            [t_nxt,
             self._mat2euler(R_nxt),
             abs_next],
            axis=-1).astype(np.float32)
        action_abs = action_abs.reshape(W, -1)

        if rel_mode == "step":
            t_rel = t_nxt - t_cur
            R_rel = R_nxt @ R_cur.transpose(0, 2, 1)
            P_rel = np.matmul(P_nxt, P_cur.transpose(0, 1, 3, 2))
            valid = kept & kept_n

        elif rel_mode == "anchor":
            t_anchor = t_cur[anchor_idx]
            R_anchor = R_cur[anchor_idx]
            P_anchor = P_cur[anchor_idx]

            t_rel = t_nxt - t_anchor
            R_rel = R_nxt @ R_anchor.T
            P_rel = np.matmul(P_nxt, P_anchor.transpose(0, 2, 1))
            valid = kept_n & kept[anchor_idx]

        else:
            raise ValueError('rel_mode must be "step" or "anchor"')

        pose_rel = R.from_matrix(P_rel.reshape(-1, 3, 3)) \
            .as_euler('xyz', degrees=False).reshape(W, 45)

        action_rel = np.concatenate(
            [t_rel,
             self._mat2euler(R_rel),
             pose_rel],
            axis=-1).astype(np.float32)

        action_abs[~valid] = 0.0
        action_rel[~valid] = 0.0
        return action_abs, action_rel, valid

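    # Layout of the per-hand action vectors produced above (derived from the
    # concatenations in _make_action_window_vec):
    #   action_abs : [next translation (3), next global-orient Euler xyz (3),
    #                 next articulation (45 Euler angles, or 63 root-relative
    #                 keypoint coordinates when action_type == "keypoints")]
    #   action_rel : [Δtranslation (3), Δglobal-orient Euler xyz (3),
    #                 Δhand-pose Euler angles (45)]
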
    def _window_indices(self, frame_id, past, future, start, end):
        """
        Returns:
            idx_clip : (W,) indices clipped to [start, end]
            oob      : (W,) bool — slots that were originally OOB
        """
        win = np.arange(-past, future + 1) + frame_id
        oob = (win < start) | (win > end)
        return win.clip(start, end), oob

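    # Worked example for _window_indices:
    #   frame_id=5, past=2, future=3, start=0, end=6
    #   raw window -> [3, 4, 5, 6, 7, 8]
    #   idx_clip   -> [3, 4, 5, 6, 6, 6]
    #   oob        -> [F, F, F, F, T, T]
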
    def _resolve_video_path(self, dataset_name: str = None, video_name: str = None, part_index: int = None) -> str:
        if dataset_name == 'Ego4D':
            if self.clip_len is not None:
                video_path = os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            else:
                video_path = os.path.join(self.video_root, video_name + '.mp4')
            return video_path
        elif dataset_name == 'EgoExo4D':
            if self.clip_len is not None:
                video_path = os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            else:
                video_path = os.path.join(self.video_root, video_name + '.mp4')
            return video_path
        elif dataset_name == 'epic':
            if self.clip_len is not None:
                video_path = os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            else:
                video_path = os.path.join(self.video_root, video_name + '.MP4')
            return video_path
        elif dataset_name == 'somethingsomethingv2':
            if self.clip_len is not None:
                video_path = os.path.join(self.video_root, video_name + '_part' + str(part_index + 1) + '.mp4')
            else:
                video_path = os.path.join(self.video_root, video_name + '.webm')
            return video_path
        else:
            raise ValueError(f'Unknown dataset prefix {dataset_name}')

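    # Example: with clip_len set, dataset_name='Ego4D', video_name='abc' and
    # part_index=2, the resolved path is <video_root>/abc_part3.mp4.
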
    def _mano_forward(self, betas, pose_m):
        """Runs MANO once and returns (vertices, joints) on CPU NumPy.

        Note: relies on a MANO layer being available in scope as ``mano``;
        it is not defined or imported in this module.
        """
        beta_t = torch.tensor(betas).unsqueeze(0).float().cuda()
        pose_t = torch.tensor(pose_m).unsqueeze(0).float().cuda()
        out = mano(betas=beta_t, hand_pose=pose_t)
        return out.vertices[0].cpu().numpy(), out.joints[0].cpu().numpy()

    def _pack_state(self, R_cam, t_cam, pose_euler, idx):
        return np.concatenate([t_cam[idx],
                               self._mat2euler(R_cam[idx][None, ...])[0],
                               pose_euler[idx]])

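    # _pack_state output layout for one hand at frame `idx`:
    #   [translation (3), global-orient Euler xyz (3), articulation
    #    (45 Euler angles, or 63 root-relative keypoint coordinates when the
    #    caller passes reshaped joints instead of pose_euler)]
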
    def _grab_window_images(self,
                            episode_id: str,
                            epi: dict,
                            frame_id: int,
                            past: int,
                            future: int
                            ):
        """
        Returns
        -------
        images : (L, H, W, 3) uint8 – raw RGB frames
        mask   : (L,) bool         – True where real frame, False where pad
        """
        dataset_name = episode_id.split('_')[0]

        decode_table = epi['video_decode_frame']
        T = len(decode_table)
        frame_in_video = decode_table[frame_id]

        idx_win, oob = self._window_indices(frame_id, past, future, 0, T - 1)

        if self.clip_len is not None:
            part_idx = frame_in_video // self.clip_len
            frame_in_part = frame_in_video % self.clip_len
            video_path = self._resolve_video_path(dataset_name, epi['video_name'], part_idx)
            decode_ids = [frame_in_part]
        else:
            video_path = self._resolve_video_path(dataset_name, epi['video_name'])
            decode_ids = decode_table[idx_win]

        for attempt in range(3):
            try:
                imgs, _ = load_video_decord(video_path, frame_index=decode_ids, rotation=False)
                break
            except Exception as e:
                print(f"Warning: failed to load video frames from {video_path} (attempt {attempt+1}/3): {e}")
                time.sleep(0.1)
        else:
            raise RuntimeError(f"Failed to load video frames from {video_path} after 3 attempts")

        images = np.stack(imgs, axis=0)
        mask = ~oob

        return images, mask

    def _find_matching_texts(self, text_list, frame_id):
        """Find text annotations that overlap with the given frame.

        Args:
            text_list: List of tuples (text, (start_frame, end_frame))
            frame_id: Current frame ID to check

        Returns:
            matching_texts: List of matching text annotations
            matching_ranges: List of corresponding time ranges (start_frame, end_frame)

        Note:
            Uses half-open interval [start_frame, end_frame)
        """
        matching_texts = []
        matching_ranges = []

        for text, (start_frame, end_frame) in text_list:
            if start_frame <= frame_id < end_frame:
                matching_texts.append(text)
                matching_ranges.append((start_frame, end_frame))

        return matching_texts, matching_ranges

    def _random_select_text(
        self,
        text,
        text_rephrase,
        hand_type,
        clip_idx,
    ):
        text_list = [text]
        if text_rephrase and isinstance(text_rephrase[hand_type][clip_idx][0], list):
            text_list.extend(text_rephrase[hand_type][clip_idx][0])

        text_selected = random.choice(text_list).strip()
        if not text_selected.endswith('.'):
            text_selected += '.'
        return text_selected

    def _build_instruction(
        self,
        main_type,
        text_clip,
        text_rephrase,
        idx_win,
        oob,
        epi_len,
        frame_id,
        action_past_window_size,
        action_future_window_size,
    ):
        sub_type = 'right' if main_type == 'left' else 'left'

        main_text_selected = self._random_select_text(
            text_clip[main_type][0][0],
            text_rephrase,
            main_type,
            clip_idx=0,
        )

        sub_text_list = text_clip[sub_type]
        has_sub_text = len(sub_text_list) > 0

        sub_oob, sub_idx_win = oob, idx_win
        sub_text_selected = "None."
        sub_win = (0, epi_len)

        if has_sub_text:
            sub_matching_texts, sub_matching_ranges = self._find_matching_texts(sub_text_list, frame_id)
            if len(sub_matching_texts) > 0:
                selected_idx = random.randrange(len(sub_matching_texts))
                sub_win = sub_matching_ranges[selected_idx]
                sub_idx_win, sub_oob = self._window_indices(
                    frame_id,
                    action_past_window_size,
                    action_future_window_size, sub_win[0], sub_win[1] - 1
                )

                sub_text_selected = self._random_select_text(
                    sub_matching_texts[selected_idx].strip(),
                    text_rephrase,
                    sub_type,
                    clip_idx=selected_idx,
                )

        is_main_left = (main_type == 'left')

        idx_win_left = idx_win if is_main_left else sub_idx_win
        idx_win_right = sub_idx_win if is_main_left else idx_win
        oob_left = oob if is_main_left else sub_oob
        oob_right = sub_oob if is_main_left else oob

        text_left = main_text_selected if is_main_left else sub_text_selected
        text_right = sub_text_selected if is_main_left else main_text_selected

        start_left = 0 if is_main_left else (sub_win[0] if has_sub_text else 0)
        start_right = (sub_win[0] if has_sub_text else 0) if is_main_left else 0
        end_left = epi_len - 1 if is_main_left else (sub_win[1] - 1 if has_sub_text else epi_len - 1)
        end_right = (sub_win[1] - 1 if has_sub_text else epi_len - 1) if is_main_left else epi_len - 1

        instruction = f"Left hand: {text_left} Right hand: {text_right}"

        return instruction, idx_win_left, oob_left, idx_win_right, oob_right, start_left, end_left, start_right, end_right

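    # The instruction string always follows the fixed template above, e.g.
    # (with hypothetical annotation texts):
    #   "Left hand: Pick up the knife. Right hand: None."
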
    def _get_2d_traj_cur_to_end(self, idx_frame, epi, intrinsics, hand_type, image_size):
        """Get the 2D trajectory of the hand palm from current frame to episode end.

        Args:
            idx_frame: Current frame index
            epi: Episode data dictionary
            intrinsics: Camera intrinsic matrix
            hand_type: 'left' or 'right' hand
            image_size: (H, W) tuple of image dimensions

        Returns:
            Normalized 2D palm trajectory in image space [0, 1]
        """
        H, W = image_size

        intrinsics = intrinsics.copy()
        intrinsics[0] /= intrinsics[0, 2] * 2
        intrinsics[1] /= intrinsics[1, 2] * 2

        hand_joints_cur_to_end = epi[hand_type]['joints_worldspace'][idx_frame:]
        hand_palm_cur_to_end = np.mean(hand_joints_cur_to_end[:, [0, 2, 5, 9, 13, 17], :], axis=1, keepdims=True)

        extrinsics = epi['extrinsics'].copy()
        extrinsics_cur = extrinsics[idx_frame]
        R_world_to_cam = extrinsics_cur[None, :3, :3].repeat(len(hand_palm_cur_to_end), axis=0)
        t_world_to_cam = extrinsics_cur[None, :3, 3:].repeat(len(hand_palm_cur_to_end), axis=0)

        hand_palm_cur_to_end_cam = (R_world_to_cam @ hand_palm_cur_to_end.transpose(0, 2, 1) + t_world_to_cam).transpose(0, 2, 1)

        uv_palm_cur_to_end = project_to_image_space(hand_palm_cur_to_end_cam, intrinsics, (H, W))
        uv_palm_cur_to_end[..., 0] = np.clip(uv_palm_cur_to_end[..., 0], 0, W)
        uv_palm_cur_to_end[..., 1] = np.clip(uv_palm_cur_to_end[..., 1], 0, H)

        uv_palm_cur_to_end = uv_palm_cur_to_end.reshape(-1, 2)
        uv_palm_cur_to_end = uv_palm_cur_to_end.astype(np.float32)
        uv_palm_cur_to_end[:, 0] /= W
        uv_palm_cur_to_end[:, 1] /= H

        return uv_palm_cur_to_end

    def get_item_frame(
        self, episode_id, frame_id,
        action_past_window_size=0,
        action_future_window_size=0,
        image_past_window_size=0,
        image_future_window_size=0,
        rel_mode: str = "step",
        load_images: bool = True,
    ):
        """
        Vectorised dataloader: builds one sample (instruction, state, action
        window, and optionally an image window) for (episode_id, frame_id).
        """
        epi, R_w2c, t_w2c = self._load_or_cache_episode(episode_id)
        T = len(epi['extrinsics'])

        idx_win, oob = self._window_indices(frame_id,
                                            action_past_window_size,
                                            action_future_window_size, 0, T - 1)
        W = len(idx_win)
        main_type = epi['anno_type']
        sub_type = 'right' if main_type == 'left' else 'left'

        instruction, idx_win_left, oob_left, idx_win_right, oob_right, \
            start_left, end_left, start_right, end_right = self._build_instruction(
                main_type=main_type,
                text_clip=epi['text'],
                text_rephrase=epi.get('text_rephrase'),
                idx_win=idx_win,
                oob=oob,
                epi_len=T,
                frame_id=frame_id,
                action_past_window_size=action_past_window_size,
                action_future_window_size=action_future_window_size,
            )

        win_left = self._prepare_side_window(
            epi['left'], R_w2c, t_w2c, idx_win_left, frame_id, anchor_frame=True,
            oob=oob_left, start=start_left, end=end_left, upsample_factor=self.upsample_factor
        )
        win_right = self._prepare_side_window(
            epi['right'], R_w2c, t_w2c, idx_win_right, frame_id, anchor_frame=True,
            oob=oob_right, start=start_right, end=end_right, upsample_factor=self.upsample_factor
        )
        idx_center = action_past_window_size

        abs_L, rel_L, msk_L = self._make_action_window_vec(
            win_left, anchor_idx=idx_center, rel_mode=rel_mode, action_type=self.action_type
        )

        abs_R, rel_R, msk_R = self._make_action_window_vec(
            win_right, anchor_idx=idx_center, rel_mode=rel_mode, action_type=self.action_type
        )

        action_abs = np.concatenate([abs_L, abs_R], axis=1)
        action_rel = np.concatenate([rel_L, rel_R], axis=1)
        action_mask = np.stack([msk_L, msk_R], axis=1)

        cur_L = self._pack_state(win_left['R_cam'],
                                 win_left['t_cam'],
                                 win_left['pose_euler'] if self.action_type == 'angle' else win_left['joints_manospace'].reshape(W, -1),
                                 idx_center)

        cur_R = self._pack_state(win_right['R_cam'],
                                 win_right['t_cam'],
                                 win_right['pose_euler'] if self.action_type == 'angle' else win_right['joints_manospace'].reshape(W, -1),
                                 idx_center)

        betas_L = epi['left']['beta']
        betas_R = epi['right']['beta']

        current_state = np.concatenate([cur_L, betas_L, cur_R, betas_R])

        current_state_mask = np.array([win_left['kept'][idx_center], win_right['kept'][idx_center]])

        if load_images:
            image_list, image_mask = self._grab_window_images(
                episode_id, epi,
                frame_id,
                image_past_window_size,
                image_future_window_size
            )
            H = image_list[0].shape[0]
            W = image_list[0].shape[1]
        else:
            image_list = None
            image_mask = None
            H, W = epi['intrinsics'][1, 2] * 2, epi['intrinsics'][0, 2] * 2

        dataset_name = episode_id.split('_')[0]
        intrinsics = epi['intrinsics']

        if dataset_name == 'EgoExo4D':
            new_intrinsics = compute_new_intrinsics_crop(intrinsics, 1408, 256 / 448 * 1408, H)
        else:
            new_intrinsics = compute_new_intrinsics_resize(intrinsics, (H, W))

        if self.augmentation:
            try:
                aspect_ratio = np.exp(random.uniform(np.log(1.0), np.log(2.0)))
                target_size = (int(self.target_image_width * aspect_ratio), self.target_image_width)
                augment_params = {
                    'tgt_aspect': aspect_ratio,
                    'flip_augmentation': self.flip_augmentation,
                    'set_none_ratio': self.set_none_ratio,
                }

                uv_traj = self._get_2d_traj_cur_to_end(frame_id, epi, new_intrinsics, main_type, (H, W))
                image_list, new_intrinsics, (action_abs, action_rel, action_mask), \
                    (current_state, current_state_mask), instruction = \
                    augmentation_func(
                        image=image_list,
                        intrinsics=new_intrinsics,
                        actions=(action_abs, action_rel, action_mask),
                        states=(current_state, current_state_mask),
                        captions=instruction,
                        uv_traj=uv_traj,
                        target_size=target_size,
                        augment_params=augment_params,
                        sub_type=sub_type,
                    )

            except Exception as e:
                import traceback
                print(f"Warning: Augmentation failed for episode {episode_id}, frame {frame_id}: "
                      f"{type(e).__name__}: {e}. Falling back to a center crop only.")
                print(f"Traceback:\n{traceback.format_exc()}")
                image_list = center_crop_short_side(image_list[0])[None, ...]
                new_intrinsics[0][2] = 0.5 * image_list[0].shape[1]
                new_intrinsics[1][2] = 0.5 * image_list[0].shape[0]

        if random.random() < self.state_mask_prob:
            current_state_mask = np.array([False, False])
            current_state[:] = 0.0

        fov = calculate_fov(2 * new_intrinsics[1][2], 2 * new_intrinsics[0][2], new_intrinsics)

        if self.use_rel:
            action_list = action_rel
        else:
            rel_L = action_rel[:, :action_rel.shape[1] // 2]
            rel_R = action_rel[:, action_rel.shape[1] // 2:]
            abs_L = action_abs[:, :action_abs.shape[1] // 2]
            abs_R = action_abs[:, action_abs.shape[1] // 2:]

            action_list = np.concatenate([rel_L[:, :6], abs_L[:, 6:], rel_R[:, :6], abs_R[:, 6:]], axis=1)

        result_dict = dict(
            instruction=instruction,
            action_list=action_list,
            action_mask=action_mask,
            current_state=current_state,
            current_state_mask=current_state_mask,
            fov=fov,
            intrinsics=new_intrinsics,
        )

        if image_list is not None:
            result_dict['image_list'] = image_list
        if image_mask is not None:
            result_dict['image_mask'] = image_mask

        return result_dict

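    # Layout of the returned sample (derived from the code above):
    #   current_state = [left state, left betas, right state, right betas],
    #     where each per-hand state is packed by _pack_state and the betas are
    #     the MANO shape parameters stored in the episode (typically 10-dim).
    #   action_list   = per-step vectors for both hands; unless use_rel is set,
    #     translation/orientation come from the relative branch and the
    #     articulation from the absolute branch.
    #   action_mask / current_state_mask mark which hands are valid.
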
    def set_global_data_statistics(self, global_data_statistics):
        self.global_data_statistics = global_data_statistics
        if not hasattr(self, 'gaussian_normalizer'):
            self.gaussian_normalizer = GaussianNormalizer(self.global_data_statistics)

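    # Note: transform_trajectory below relies on self.gaussian_normalizer,
    # which is only created by set_global_data_statistics(); call it before
    # using normalization=True.
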
    def transform_trajectory(
        self,
        sample_dict: dict = None,
        normalization: bool = True,
    ):
        """Pad action and state dimensions to match a unified size."""
        action_np = sample_dict["action_list"]
        state_np = sample_dict["current_state"]

        action_dim = action_np.shape[1]
        state_dim = state_np.shape[0]
        if normalization:
            action_np = self.gaussian_normalizer.normalize_action(action_np)
            state_np = self.gaussian_normalizer.normalize_state(state_np)

        unified_action_dim = ActionFeature.ALL_FEATURES[1]
        unified_state_dim = StateFeature.ALL_FEATURES[1]
        unified_state, unified_state_mask = pad_state_human(
            state_np,
            sample_dict["current_state_mask"],
            action_dim,
            state_dim,
            unified_state_dim
        )
        unified_action, unified_action_mask = pad_action(
            action_np,
            sample_dict["action_mask"],
            action_dim,
            unified_action_dim
        )

        sample_dict["action_list"] = unified_action
        sample_dict["action_mask"] = unified_action_mask
        sample_dict["current_state"] = unified_state
        sample_dict["current_state_mask"] = unified_state_mask
        return sample_dict

    def __getitem__(self, idx):
        if self.training_idx is not None:
            data_id = self.training_idx[idx]
        else:
            data_id = idx
        corr = self.index_frame_pair[data_id]
        episode_id = self.index_to_episode_id[corr[0]]
        sample = self.get_item_frame(
            episode_id, int(corr[1]),
            action_past_window_size=self.action_past_window_size,
            action_future_window_size=self.action_future_window_size,
            image_past_window_size=self.image_past_window_size,
            image_future_window_size=self.image_future_window_size,
            rel_mode=self.rel_mode,
            load_images=self.load_images
        )
        return sample


def pad_state_human(
    state: np.ndarray,
    state_mask: np.ndarray,
    action_dim: int,
    state_dim: int,
    unified_state_dim: int
):
    """
    Expand the state mask, mask invalid state dims, and pad the state to a standard size.

    Args:
        state (ndarray): original state vector, shape [state_dim]
        state_mask (ndarray): per-hand state mask, shape [2] (left, right)
        action_dim (int): original action dimension
        state_dim (int): original state dimension
        unified_state_dim (int): target padded state dimension

    Returns:
        Tuple[Tensor, Tensor]:
            padded current_state [unified_state_dim],
            padded current_state_mask [unified_state_dim]
    """
    current_state = torch.tensor(state, dtype=torch.float32)
    current_state_mask = torch.tensor(state_mask, dtype=torch.bool)

    expanded_state_mask = current_state_mask.repeat_interleave(state_dim // 2)

    current_state_masked = current_state * expanded_state_mask.to(current_state.dtype)

    padded_state = torch.zeros(unified_state_dim, dtype=current_state.dtype)
    padded_mask = torch.zeros(unified_state_dim, dtype=torch.bool)

    padded_state[:action_dim // 2] = current_state_masked[:action_dim // 2].clone()
    padded_mask[:action_dim // 2] = expanded_state_mask[:action_dim // 2].clone()

    padded_state[action_dim // 2:action_dim] = current_state_masked[state_dim // 2:state_dim // 2 + action_dim // 2].clone()
    padded_mask[action_dim // 2:action_dim] = expanded_state_mask[state_dim // 2:state_dim // 2 + action_dim // 2].clone()

    return padded_state, padded_mask


def pad_action(
    actions: np.ndarray = None,
    action_mask: np.ndarray = None,
    action_dim: int = None,
    unified_action_dim: int = None
):
    """
    Expand the action mask per dimension, mask invalid actions, and pad actions to a unified size.

    Args:
        actions (ndarray or None): original actions, shape [T, action_dim] or None.
        action_mask (ndarray): per-hand action mask, shape [T, 2].
        action_dim (int): original action dimension.
        unified_action_dim (int): target padded action dimension.

    Returns:
        Tuple[Optional[Tensor], Tensor]:
            padded actions [T, unified_action_dim] or None,
            padded action mask [T, unified_action_dim]
    """
    action_mask = torch.tensor(action_mask, dtype=torch.bool)

    mask_left = action_mask[:, 0].unsqueeze(1).expand(-1, action_dim // 2)
    mask_right = action_mask[:, 1].unsqueeze(1).expand(-1, action_dim // 2)
    expanded_action_mask = torch.cat((mask_left, mask_right), dim=1)

    if actions is None:
        padding_mask = torch.zeros(
            (action_mask.shape[0], unified_action_dim - action_dim),
            dtype=torch.bool
        )
        action_mask_padded = torch.cat((expanded_action_mask, padding_mask), dim=1)
        return None, action_mask_padded

    actions = torch.tensor(actions, dtype=torch.float32)

    actions_masked = actions * expanded_action_mask.to(actions.dtype)

    padding = torch.zeros(
        (actions.shape[0], unified_action_dim - action_dim),
        dtype=actions.dtype
    )

    actions_padded = torch.cat((actions_masked, padding), dim=1)
    action_mask_padded = torch.cat((expanded_action_mask, padding.bool()), dim=1)

    return actions_padded, action_mask_padded
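

# A minimal usage sketch, not part of the library: the paths below are
# placeholders, and the annotation file is assumed to contain the
# 'index_frame_pair' and 'index_to_episode_id' arrays that __init__ expects.
if __name__ == "__main__":
    dataset = EpisodicDatasetCore(
        video_root="/path/to/videos",               # placeholder path
        annotation_file="/path/to/annotation.npz",  # placeholder path
        label_folder="/path/to/episode_labels",     # placeholder path
        augmentation=False,
        load_images=False,
    )
    print(f"dataset size: {len(dataset)}")
    sample = dataset[0]
    print(sorted(sample.keys()))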