|
|
""" |
|
|
materialize.py |
|
|
|
|
|
Dataset and collator materialization for VLA training. |
|
|
""" |
|
|
|
|
|
import time
from pathlib import Path
from typing import Optional, Tuple

from vitra.datasets.dataset import MultipleWeightedDataset, MultipleDatasetWeightedDistributedBatchSampler
from vitra.utils.data_utils import PaddedCollatorForHandPrediction
|
|
|
|
|
def get_vla_dataset_and_collator(
    data_root_dir: Path,
    data_mix: str,
    padding_side: str = "right",
    augmentation: bool = False,
    shard_num: int = 1,
    shard_index: int = 0,
    seed: int = 42,
    future_action_window_size: int = 0,
    batch_size: int = 32,
    processor=None,
    normalization: bool = True,
    flip_augmentation: float = 1.0,
    set_none_ratio: float = 0.0,
    statistics_type: str = 'ori',
    action_type: str = "angle",
    use_rel: bool = False,
    rel_mode: str = "step",
    clip_len: Optional[int] = None,
    state_mask_prob: float = 0.1,
) -> Tuple[MultipleWeightedDataset, PaddedCollatorForHandPrediction, MultipleDatasetWeightedDistributedBatchSampler]:
    """
    Create VLA dataset, batch sampler, and collator for training.

    The dataset is built with ``MultipleWeightedDataset.load_datasets`` using
    no past action/image window (both fixed to 0). The batch sampler always
    shuffles and drops the last incomplete batch.

    Args:
        data_root_dir: Root directory containing datasets
        data_mix: Comma-separated list of dataset names to mix
        padding_side: Token padding side ("right" or "left"), forwarded to the collator
        augmentation: Whether to apply data augmentation
        shard_num: Number of distributed shards (passed to the sampler as ``num_replicas``)
        shard_index: Index of current shard (passed to the sampler as ``rank``);
            must satisfy ``0 <= shard_index < shard_num``
        seed: Random seed for sampling
        future_action_window_size: Number of future actions to predict
        batch_size: Batch size per device
        processor: Text processor with tokenizer. Required despite the ``None``
            default: ``processor.tokenizer.model_max_length`` and
            ``processor.tokenizer.pad_token_id`` are read to build the collator.
        normalization: Whether to normalize actions
        flip_augmentation: Probability of flip augmentation
        set_none_ratio: Ratio of setting actions to None
        statistics_type: Type of statistics to use. Currently unused by this
            function (it is not forwarded to the dataset loader).
        action_type: Type of action representation
        use_rel: Whether to use relative actions
        rel_mode: Relative action mode
        clip_len: Forwarded unchanged to the dataset loader (``None`` by default);
            presumably a clip length limit -- confirm against ``load_datasets``.
        state_mask_prob: Forwarded unchanged to the dataset loader; presumably
            the probability of masking the state input -- confirm against
            ``load_datasets``.

    Returns:
        Tuple of (dataset, collator, batch_sampler)

    Raises:
        ValueError: If ``shard_index`` is outside ``[0, shard_num)`` or if
            ``processor`` is ``None``.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not 0 <= shard_index < shard_num:
        raise ValueError(
            "Shard index must be in [0, shard_num); "
            f"got shard_index={shard_index}, shard_num={shard_num}."
        )

    # The collator below needs the tokenizer; fail before the (potentially
    # expensive) dataset loading instead of crashing afterwards.
    if processor is None:
        raise ValueError(
            "A processor exposing `.tokenizer` is required to build the collator; got processor=None."
        )

    multi_dataset = MultipleWeightedDataset.load_datasets(
        data_root_dir,
        data_mix,
        action_past_window_size=0,
        action_future_window_size=future_action_window_size,
        image_past_window_size=0,
        augmentation=augmentation,
        normalization=normalization,
        processor=processor,
        flip_augmentation=flip_augmentation,
        set_none_ratio=set_none_ratio,
        action_type=action_type,
        use_rel=use_rel,
        rel_mode=rel_mode,
        clip_len=clip_len,
        state_mask_prob=state_mask_prob,
    )

    batch_sampler = MultipleDatasetWeightedDistributedBatchSampler(
        multi_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        seed=seed,
        num_replicas=shard_num,
        rank=shard_index,
    )

    collator = PaddedCollatorForHandPrediction(
        processor.tokenizer.model_max_length,
        processor.tokenizer.pad_token_id,
        padding_side=padding_side,
    )

    return multi_dataset, collator, batch_sampler