Skip to content

Commit

Permalink
fix bug for issue #173
Browse files (browse the repository at this point in the history)
  • Loading branch information
pengzju committed Feb 12, 2024
1 parent 090c1f1 commit 06dca99
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
25 changes: 17 additions & 8 deletions easyeditor/models/ft/ft_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,16 @@ def execute_ft(
target_ids = tok(tgt, return_tensors="pt", padding=True)["input_ids"].to(
device
)
last_token_inds = inputs["attention_mask"].sum(dim=1) - 1
loss_mask = target_ids != tok.unk_token_id
inputs_targets = [txt_ + tgt_ for txt_, tgt_ in zip(txt, tgt)]
inputs_targets = tok(inputs_targets, return_tensors="pt", padding=True).to(device)
# last_token_inds = inputs["attention_mask"].sum(dim=1) - 1
# loss_mask = inputs != tok.unk_token_id
# loss_mask = [:, ]
num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in inputs['input_ids'].cpu()]
num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in inputs_targets['input_ids'].cpu()]
prompt_len = [x + y for x, y in zip(num_pad_toks, num_prompt_toks)]
prompt_target_len = inputs_targets['input_ids'].size(1)
label_mask = torch.tensor([[False] * length + [True] * (prompt_target_len - length) for length in prompt_len]).to(device)
opt.zero_grad()
bs = inputs["input_ids"].shape[0]
if 't5' in hparams.model_name.lower():
Expand Down Expand Up @@ -178,12 +186,13 @@ def execute_ft(
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss.to(lm_logits.dtype)
else:
probs = torch.nn.functional.log_softmax(
model(**inputs).logits[torch.arange(bs), last_token_inds], dim=-1
)
loss = -(torch.gather(probs, 1, target_ids) * loss_mask).sum(
1
) / loss_mask.sum(1)
logits = model(**inputs_targets).logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = inputs_targets['input_ids'][..., 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction='none')
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss.view(bs, -1)
loss = (loss * label_mask[:,1:]).sum(1) / label_mask[:,1:].sum(1)
loss = loss.mean()
print(f"Batch loss {loss.item()}")
loss_meter.update(loss.item(), n=bs)
Expand Down
2 changes: 1 addition & 1 deletion hparams/FT/gpt-j-6B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ max_length: 40
lr: 5e-4
weight_decay: 0
kl_factor: 0
norm_constraint: 5e-5
norm_constraint: false
rewrite_module_tmp: "transformer.h.{}.mlp.fc_out"
layer_module_tmp: "transformer.h.{}"
mlp_module_tmp: "transformer.h.{}.mlp"
Expand Down

0 comments on commit 06dca99

Please sign in to comment.