
Commit 3658fa3

Jay Mahadeokar authored; facebook-github-bot committed
aligned training task and CE related changes
Summary: This diff adds:

1. An aligned training task specifically for cross-entropy criterion training using prod data and prod-like models.
2. A few changes to correctly register the task and criterions.
3. Changes to the trainer code for propagating the accuracy metrics we care about during training.

A couple of things are hacky right now:
- The reporting is not modular (this needs to be thought through for fairseq in general).
- The get-dummy-batch logic could be specific to the task instead of specific to the dataset.

Reviewed By: myleott

Differential Revision: D14670482

fbshipit-source-id: dc077247b2ae9d26a8e842a386ec5faa5771e836
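For context, the trainer change below reads an `'acc'` key out of each batch's logging output. A minimal sketch of the kind of logging output a cross-entropy criterion could emit so that lookup finds a value — the helper name `build_logging_output` and the exact dict keys besides `'loss'`/`'acc'` are illustrative assumptions, not fairseq's actual criterion API:

```python
# Hypothetical helper: summarize one batch for the trainer's meters.
# Assumption: predictions/targets are flat lists of token ids.
def build_logging_output(predictions, targets, total_loss):
    """Report loss plus accuracy so the trainer can pick up 'acc'."""
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return {
        'loss': total_loss,
        'acc': correct / len(targets),   # read via logging_output.get('acc', 0)
        'sample_size': len(targets),
    }

out = build_logging_output([1, 2, 2, 0], [1, 2, 3, 0], total_loss=2.5)
# 3 of 4 predictions match, so out['acc'] is 0.75
```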
1 parent 3a64ace commit 3658fa3

File tree

1 file changed

+8
-0
lines changed


fairseq/trainer.py

```diff
@@ -271,6 +271,10 @@ def train_step(self, samples, dummy_batch=False):
                 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
             )
             self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
+            if 'train_acc' in self.meters:
+                self.meters['train_acc'].update(
+                    logging_output.get('acc', 0), sample_size)
+
             if 'nll_loss' in logging_output:
                 self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
         except OverflowError as e:
@@ -340,6 +344,10 @@ def valid_step(self, sample, raise_oom=False):
             # update meters for validation
             ntokens = logging_output.get('ntokens', 0)
             self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
+            if 'valid_acc' in self.meters:
+                self.meters['valid_acc'].update(
+                    logging_output.get('acc', 0), sample_size)
+
             if 'nll_loss' in logging_output:
                 self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
```
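The diff weights each accuracy value by `sample_size` when updating the meter. A simplified stand-in for that meter pattern — this `AverageMeter` is a minimal sketch for illustration, not fairseq's actual meter implementation:

```python
# Simplified stand-in for a weighted running-average meter, mirroring
# meters['train_acc'].update(acc, sample_size) in the diff above.
class AverageMeter:
    """Tracks a weighted running average of reported values."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # Weight each value by its sample size so large batches
        # contribute proportionally more to the average.
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0

meters = {'train_acc': AverageMeter()}
logging_output = {'acc': 0.75}  # hypothetical criterion output
sample_size = 8

# The guarded update from the diff, using the stand-in meter:
if 'train_acc' in meters:
    meters['train_acc'].update(logging_output.get('acc', 0), sample_size)
```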
