class ChunkedPolicy(PreTrainedPolicy, nn.Module):
    """MLP policy that predicts a fixed-length chunk of future actions.

    A single state vector is mapped through a two-layer MLP to a flat vector
    of ``chunk_size * action_dim`` values, which is then reshaped to
    ``(-1, chunk_size, action_dim)``.
    """

    def __init__(self, config):
        """Build the network from the feature shapes declared in ``config``.

        Args:
            config: Policy configuration. ``config.output_features['action']``
                must expose a ``shape`` whose first entry is the action
                dimension. ``config.input_features['observation.state']`` is
                read the same way for the network input width.
        """
        super().__init__(config)
        self.chunk_size = 16  # Predict 16 future actions
        self.action_dim = config.output_features['action'].shape[0]
        # The input width was previously an undefined name (`obs_dim`), which
        # raised NameError at construction. It is now derived from the same
        # feature key that forward()/select_action() read.
        # NOTE(review): assumes config.input_features mirrors the layout of
        # config.output_features -- confirm against the config class.
        obs_dim = config.input_features['observation.state'].shape[0]
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 512),
            nn.ReLU(),
            nn.Linear(512, self.action_dim * self.chunk_size),
        )

    def forward(self, batch):
        """Compute the MSE loss between predicted and ground-truth chunks.

        Args:
            batch: Mapping with ``'observation.state'`` (network input) and
                ``'action'`` (target; expected to be
                ``[batch, chunk_size, action_dim]``).

        Returns:
            A ``(loss, info)`` tuple where ``info`` is an empty dict.
        """
        # Predict action chunk
        obs = batch['observation.state']
        action_chunk = self.network(obs)
        # Unflatten the network output into one row per chunk step.
        action_chunk = action_chunk.reshape(-1, self.chunk_size, self.action_dim)
        # Supervise with ground truth chunk
        true_actions = batch['action']  # Shape: [batch, chunk_size, action_dim]
        loss = nn.functional.mse_loss(action_chunk, true_actions)
        return loss, {}

    def select_action(self, obs):
        """Return the first action of a freshly predicted chunk.

        The chunk is recomputed on every call and only its first step is
        returned; the remaining ``chunk_size - 1`` predictions are discarded.

        Args:
            obs: Mapping containing ``'observation.state'``.

        Returns:
            Tensor obtained by indexing the reshaped chunk at step 0, i.e.
            ``action_chunk[:, 0, :]``.
        """
        with torch.no_grad():
            action_chunk = self.network(obs['observation.state'])
            action_chunk = action_chunk.reshape(-1, self.chunk_size, self.action_dim)
            # Return first action in chunk
            return action_chunk[:, 0, :]