Commit 1a6e737

disable sharded loss if lambda_cosine
1 parent 84e472f commit 1a6e737

2 files changed: 13 additions & 6 deletions

onmt/utils/loss.py

Lines changed: 9 additions & 6 deletions
@@ -167,7 +167,7 @@ def __call__(self,
             return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
-            loss, stats = self._compute_loss(batch, **shard)
+            loss, stats = self._compute_loss(batch, normalization, **shard)
             loss.backward()
             batch_stats.update(stats)
         return None, batch_stats
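
With normalization passed into _compute_loss, each shard's loss can be scaled inside that call before the per-shard loss.backward(), rather than divided once afterwards. A minimal sketch of the scaling this implies (an assumption about where the division happens; it is not shown in this diff):

import torch

def normalize_shard_loss(loss: torch.Tensor, normalization: float) -> torch.Tensor:
    # Hypothetical helper: scale a shard's summed loss so that the per-shard
    # backward() calls above accumulate gradients of the normalized batch loss.
    return loss / float(normalization)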
@@ -243,9 +243,7 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt,
                           range_, attns=None):
         shard_state = {
             "output": output,
-            "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
-            "enc_src": enc_src,
-            "enc_tgt": enc_tgt
+            "target": batch.tgt[range_[0] + 1: range_[1], :, 0]
         }
         if self.lambda_coverage != 0.0:
             coverage = attns.get("coverage", None)
@@ -283,10 +281,15 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt,
                 "align_head": attn_align,
                 "ref_align": ref_align[:, range_[0] + 1: range_[1], :]
             })
+        if self.lambda_cosine != 0.0:
+            shard_state.update({
+                "enc_src": enc_src,
+                "enc_tgt": enc_tgt
+            })
         return shard_state

     def _compute_loss(self, batch, normalization, output, target,
-                      enc_src, enc_tgt, std_attn=None,
+                      enc_src=None, enc_tgt=None, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
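
Taken together, these two hunks make enc_src and enc_tgt optional: they are packed into shard_state only when the cosine penalty is active, and _compute_loss now defaults them to None. A minimal sketch of how a gated cosine term over the encoder states could enter the loss (helper name, pooling and weighting are assumptions, not the fork's actual code):

import torch
import torch.nn.functional as F

def cosine_penalty(lambda_cosine, enc_src=None, enc_tgt=None):
    # Hypothetical helper: contributes a term only when the penalty is enabled
    # and the encoder states were placed into shard_state.
    if lambda_cosine == 0.0 or enc_src is None or enc_tgt is None:
        return 0.0
    # Assumed pooling over time: (len, batch, dim) -> (batch, dim).
    src_vec = enc_src.mean(dim=0)
    tgt_vec = enc_tgt.mean(dim=0)
    # Penalize dissimilarity between source and target sentence vectors.
    return lambda_cosine * (1.0 - F.cosine_similarity(src_vec, tgt_vec, dim=-1)).sum()

Because such a term needs the whole enc_src and enc_tgt, it cannot be split along the target-length axis the way output and target are, which is presumably why the commit disables the sharded loss whenever lambda_cosine is set.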
@@ -400,7 +403,7 @@ def shards(state, shard_size, eval_only=False):
         # over the shards, not over the keys: therefore, the values need
         # to be re-zipped by shard and then each shard can be paired
         # with the keys.
-        for shard_tensors in zip(*values):
+        for i, shard_tensors in enumerate(zip(*values)):
             yield dict(zip(keys, shard_tensors))

         # Assumed backprop'd
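
For reference, the re-zipping described in the comment turns a dict of split tensors into one dict per shard. A standalone toy version (not the actual shards(), which also filters the state through filter_shard_state and drives the backward pass):

import torch

def toy_shards(state, shard_size):
    # Split every tensor in `state` into chunks of `shard_size` along dim 0,
    # then re-zip the chunks so each yielded shard is a dict with the same
    # keys as the full state.
    keys, values = zip(*((k, torch.split(v, shard_size, dim=0))
                         for k, v in state.items() if v is not None))
    for shard_tensors in zip(*values):
        yield dict(zip(keys, shard_tensors))

state = {"output": torch.randn(4, 3, 8),
         "target": torch.zeros(4, 3, dtype=torch.long)}
for shard in toy_shards(state, 2):
    print({k: tuple(v.shape) for k, v in shard.items()})  # two shards of length 2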

onmt/utils/parse.py

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,10 @@ def validate_train_opts(cls, opt):
         assert len(opt.attention_dropout) == len(opt.dropout_steps), \
             "Number of attention_dropout values must match accum_steps values"

+        assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \
+            "-lambda_cosine loss is not implemented for max_generator_batches > 0."
+
+
     @classmethod
     def validate_translate_opts(cls, opt):
         if opt.beam_size != 1 and opt.random_sampling_topk != 1:
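
The new check ties the two files together: the cosine term is wired only into the unsharded loss path, so enabling it requires turning generator sharding off. An illustrative reproduction of the validation, with SimpleNamespace standing in for the parsed training options:

from types import SimpleNamespace

def check(opt):
    # Mirrors the validation added to validate_train_opts above.
    assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \
        "-lambda_cosine loss is not implemented for max_generator_batches > 0."

check(SimpleNamespace(max_generator_batches=0, lambda_cosine=1.0))   # passes: unsharded path
check(SimpleNamespace(max_generator_batches=32, lambda_cosine=1.0))  # raises AssertionError

In practice this means a training run that sets -lambda_cosine also has to pass -max_generator_batches 0.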
