@@ -167,7 +167,7 @@ def __call__(self,
             return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
-            loss, stats = self._compute_loss(batch, **shard)
+            loss, stats = self._compute_loss(batch, normalization, **shard)
             loss.backward()
             batch_stats.update(stats)
         return None, batch_stats
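
Note: `normalization` (the token count used to scale the loss) is now threaded into `_compute_loss` for each shard. A minimal sketch of the idea, assuming a summed NLL scaled by a global token count (names here are illustrative, not the PR's code):

    import torch.nn.functional as F

    def sharded_nll(logits, target, normalization):
        # Sum NLL over this shard's tokens, then divide by the global
        # token count so per-shard backward() calls accumulate to the
        # gradient of the mean loss over the whole batch.
        loss = F.cross_entropy(logits, target, reduction="sum")
        return loss / float(normalization)

With sharding, `loss.backward()` runs once per shard, so any scaling has to happen before each backward call rather than once on a summed loss.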
@@ -243,9 +243,7 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt,
                           range_, attns=None):
         shard_state = {
             "output": output,
-            "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
-            "enc_src": enc_src,
-            "enc_tgt": enc_tgt
+            "target": batch.tgt[range_[0] + 1: range_[1], :, 0]
         }
         if self.lambda_coverage != 0.0:
             coverage = attns.get("coverage", None)
@@ -283,10 +281,15 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt,
                 "align_head": attn_align,
                 "ref_align": ref_align[:, range_[0] + 1: range_[1], :]
             })
+        if self.lambda_cosine != 0.0:
+            shard_state.update({
+                "enc_src": enc_src,
+                "enc_tgt": enc_tgt
+            })
         return shard_state

     def _compute_loss(self, batch, normalization, output, target,
-                      enc_src, enc_tgt, std_attn=None,
+                      enc_src=None, enc_tgt=None, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
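
`enc_src`/`enc_tgt` now ride along in `shard_state` only when `self.lambda_cosine != 0.0`, and become optional keyword arguments of `_compute_loss`, so shards without those keys still unpack cleanly. The hunk does not show how the cosine term itself is computed; a hedged sketch of one plausible form, assuming encoder states shaped `(len, batch, dim)` and mean pooling over time (both assumptions):

    import torch.nn.functional as F

    def cosine_term(enc_src, enc_tgt, lambda_cosine):
        # Mean-pool over time: (len, batch, dim) -> (batch, dim).
        src_vec = enc_src.mean(dim=0)
        tgt_vec = enc_tgt.mean(dim=0)
        # Penalize dissimilar source/target encodings: 1 - cos(u, v).
        cos = F.cosine_similarity(src_vec, tgt_vec, dim=-1)
        return lambda_cosine * (1.0 - cos).sum()

Defaulting the arguments to `None` keeps existing callers working unchanged when the cosine loss is disabled.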
@@ -400,7 +403,7 @@ def shards(state, shard_size, eval_only=False):
         # over the shards, not over the keys: therefore, the values need
         # to be re-zipped by shard and then each shard can be paired
         # with the keys.
-        for shard_tensors in zip(*values):
+        for i, shard_tensors in enumerate(zip(*values)):
             yield dict(zip(keys, shard_tensors))

         # Assumed backprop'd
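
For context, `shards()` turns a dict of tensors into a stream of per-shard dicts. A simplified sketch of the re-zipping step this hunk touches, assuming every value is a tensor split along dim 0 (the real function also handles `None` values and gradient plumbing):

    import torch

    def shards_sketch(state, shard_size):
        # Split each value into chunks of at most shard_size rows.
        keys, values = zip(*((k, torch.split(v, shard_size, dim=0))
                             for k, v in state.items()))
        # zip(*values) iterates over shards, not keys, so each shard's
        # tensors are re-paired with the keys. The index i mirrors the
        # commit's enumerate(); it is unused in the lines shown here.
        for i, shard_tensors in enumerate(zip(*values)):
            yield dict(zip(keys, shard_tensors))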