add the action transformer head for multiple actions behavioral cloning, and cite rq transformer paper from kakao brain
lucidrains committed Dec 9, 2024
1 parent 677294b commit 8acf0dd
Showing 3 changed files with 92 additions and 22 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -115,3 +115,14 @@ assert generated_video.shape == (1, 768, 16 + 1, 2, 2)
url = {https://api.semanticscholar.org/CorpusID:270870613}
}
```

```bibtex
@article{Lee2022AutoregressiveIG,
title = {Autoregressive Image Generation using Residual Quantization},
author = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
pages = {11513-11522},
url = {https://api.semanticscholar.org/CorpusID:247244535}
}
```
101 changes: 80 additions & 21 deletions genie2_pytorch/genie2.py
@@ -141,11 +141,22 @@ def __init__(
ff_glu = True,
use_rmsnorm = True,
),
action_transformer_kwargs: dict = dict(
add_value_residual = True,
learned_value_residual_mix = True,
ff_glu = True,
use_rmsnorm = True,
depth = 2,
heads = 4,
attn_dim_head = 64
),
vq_codebook_size = 4096,
vq_kwargs: dict = dict(),
encoder: Module = nn.Identity(),
decoder: Module = nn.Identity(),
vq_commit_loss_weight = 1.,
allow_multiple_actions = False,
max_num_actions = 10,
action_autoregressive_loss_weight = 0.1,
is_video_enc_dec = False # by default will assume image encoder / decoder, but in the future, video diffusion models with temporal compression will likely perform even better, imo
):
@@ -186,12 +197,34 @@ def __init__(
**transformer_kwargs
)

# behavioral cloning loss weight
# action related

has_action_loss = action_autoregressive_loss_weight > 0.
self.to_action_pred = nn.Linear(dim, num_actions, bias = False) if has_action_loss else None
self.allow_multiple_actions = allow_multiple_actions
self.max_num_actions = max_num_actions # in the case multiple actions are allowed, maximum number of actions allowed

has_action_loss = action_autoregressive_loss_weight > 0.
self.has_action_loss = has_action_loss

self.to_action_pred = None

if has_action_loss:
if allow_multiple_actions:
dim_action_transformer = dim // 2

self.action_eos_id = num_actions
self.action_pos_embed = nn.Parameter(torch.zeros(max_num_actions, dim))

self.to_action_pred = nn.Sequential(
nn.Linear(dim, dim_action_transformer, bias = False),
Decoder(
dim = dim_action_transformer,
**action_transformer_kwargs
),
nn.Linear(dim_action_transformer, num_actions + 1, bias = False)
)
else:
self.to_action_pred = nn.Linear(dim, num_actions, bias = False)

self.action_autoregressive_loss_weight = action_autoregressive_loss_weight

self.register_buffer('zero', torch.tensor(0.), persistent = False)
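
> Note (not part of the diff): when `allow_multiple_actions = True`, the new `to_action_pred` is a small hierarchical head, which is what the RQ-Transformer citation added to the README accompanies: each timestep embedding from the main transformer is later concatenated with learned per-action-slot positions and run through a half-width `Decoder`, which predicts `num_actions + 1` logits per slot (the extra class being an end-of-actions token). Below is a standalone sketch of just that head, under assumed values for `dim`, `num_actions` and `max_num_actions`; it mirrors the construction above but is not the repository's class.

```python
# standalone sketch of the multi-action head (assumed toy dimensions)
import torch
from torch import nn
from x_transformers import Decoder

dim, num_actions, max_num_actions = 512, 8, 10
dim_action_transformer = dim // 2

action_pos_embed = nn.Parameter(torch.zeros(max_num_actions, dim))

to_action_pred = nn.Sequential(
    nn.Linear(dim, dim_action_transformer, bias = False),
    Decoder(
        dim = dim_action_transformer,
        depth = 2,
        heads = 4,
        attn_dim_head = 64,
        ff_glu = True,
        use_rmsnorm = True
        # the commit also passes add_value_residual / learned_value_residual_mix; omitted here for brevity
    ),
    nn.Linear(dim_action_transformer, num_actions + 1, bias = False)  # + 1 for the end-of-actions token
)

# one embedding per (batch, time) position from the main transformer, plus a
# learned position per action slot, as done later in `forward`
batch, time = 2, 4
action_embed = torch.randn(batch * time, 1, dim)
pos = action_pos_embed.unsqueeze(0).expand(batch * time, -1, -1)

logits = to_action_pred(torch.cat((action_embed, pos), dim = 1))
print(logits.shape)  # (batch * time, 1 + max_num_actions, num_actions + 1)
```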
@@ -276,6 +309,11 @@ def generate(
maybe_next_actions = [*map(int, maybe_next_action.split(','))]
maybe_next_actions = [*set(maybe_next_actions)]

if not self.allow_multiple_actions:
assert len(maybe_next_actions) == 1, f'you cannot interact with multiple actions if `allow_multiple_actions` is not set to `True`'
else:
assert len(maybe_next_actions) <= self.max_num_actions, f'maximum number of actions is set at {self.max_num_actions}'

next_action = tensor(maybe_next_actions, device = self.device)
next_action = rearrange(next_action, 'a -> 1 1 a')
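
> Note (not part of the diff): `generate` now accepts a comma separated prompt when multiple actions are allowed. A tiny runnable illustration of that parsing path, with an assumed prompt string and the surrounding sampling loop omitted:

```python
# assumed example prompt; duplicates collapse through set(), order is not guaranteed
import torch
from einops import rearrange

maybe_next_action = '3,7,3'

maybe_next_actions = [*map(int, maybe_next_action.split(','))]
maybe_next_actions = [*set(maybe_next_actions)]            # e.g. [3, 7]

next_action = torch.tensor(maybe_next_actions)
next_action = rearrange(next_action, 'a -> 1 1 a')         # (batch = 1, time = 1, num actions)

assert next_action.shape == (1, 1, len(maybe_next_actions))
```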

@@ -524,23 +562,6 @@ def forward(
rotary_pos_emb = time_rotary_pos
)

# maybe action prediction

if return_loss and self.has_action_loss:
is_single_action = actions.ndim == 2 or actions.shape[-1] == 1
assert is_single_action

action_time_len = tokens_seq_len // spatial_repeat_factor
round_down_by_space_len = action_time_len * spatial_repeat_factor
action_embed = reduce(embed[:, :round_down_by_space_len], 'b (t s) d -> b t d', 'mean', t = action_time_len)

action_logits = self.to_action_pred(action_embed)

if actions.ndim == 3:
actions = rearrange(actions, '... 1 -> ...')

action_labels = actions[:, 1:]

# project out

tokens = self.model_to_latent(embed)
@@ -565,11 +586,49 @@ def forward(
commit_loss * self.vq_commit_loss_weight
)

# maybe behavioral cloning
# maybe action loss

action_loss = self.zero

if self.has_action_loss:
is_single_action = actions.ndim == 2 or actions.shape[-1] == 1

if not self.allow_multiple_actions:
assert is_single_action, 'you need to set `allow_multiple_actions = True` on init to learn and decode multiple actions'

action_time_len = tokens_seq_len // spatial_repeat_factor
round_down_by_space_len = action_time_len * spatial_repeat_factor
action_embed = reduce(embed[:, :round_down_by_space_len], 'b (t s) d -> b t d', 'mean', t = action_time_len)

if is_single_action:
action_logits = self.to_action_pred(action_embed)

if actions.ndim == 3:
actions = rearrange(actions, '... 1 -> ...')

action_labels = actions[:, 1:]

else:
actions, _ = pack_one(actions, 'b n *')
inp_num_actions = actions.shape[-1]
assert inp_num_actions <= self.max_num_actions, f'maximum number of actions is set at {self.max_num_actions}'

action_embed = rearrange(action_embed, 'b t d -> (b t) 1 d')
action_pos_embed = repeat(self.action_pos_embed[:inp_num_actions], 'a d -> bt a d', bt = action_embed.shape[0])

action_embed = torch.cat((action_embed, action_pos_embed), dim = -2)

action_logits = self.to_action_pred(action_embed)

# prepare the action labels, adding the action end token appropriately

action_labels = actions[:, 1:]
action_labels = F.pad(action_labels, (0, 1), value = -1)
num_actions_per_time = (action_labels >= 0).sum(dim = -1, keepdim = True)
action_labels = action_labels.scatter(-1, num_actions_per_time, self.action_eos_id)
action_labels = rearrange(action_labels, 'b t a -> (b t) a')

# cross entropy loss for predicted action on the action transformer head (hierarchical transformer)

action_loss = F.cross_entropy(
rearrange(action_logits, 'b n l -> b l n'),
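
> Note (not part of the diff): the label preparation above is the subtle part of the multi-action loss. Labels are padded with -1, the number of real actions per timestep is counted, and the end-of-actions id is scattered into the first padding slot. A small worked example under assumed toy values (`num_actions = 8`, so the end-of-actions id is 8), with the time shift `actions[:, 1:]` omitted:

```python
# worked example of the end-of-actions label construction (assumed toy values)
import torch
import torch.nn.functional as F

action_eos_id = 8                                            # assumed num_actions = 8
actions = torch.tensor([[[3, 7, -1], [2, -1, -1]]])          # (batch = 1, time = 2, up to 3 actions), -1 = no action

action_labels = F.pad(actions, (0, 1), value = -1)           # make room for the end-of-actions token
num_actions_per_time = (action_labels >= 0).sum(dim = -1, keepdim = True)
action_labels = action_labels.scatter(-1, num_actions_per_time, action_eos_id)

print(action_labels)
# tensor([[[ 3,  7,  8, -1],
#          [ 2,  8, -1, -1]]])

# the `b n l -> b l n` rearrange just above exists because F.cross_entropy
# expects the class dimension in position 1; the remaining -1 entries are
# padding slots (presumably masked via an ignore index in the truncated call)
```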
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "genie2-pytorch"
version = "0.0.12"
version = "0.0.14"
description = "Genie2"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
