@@ -186,22 +186,23 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
186186 self .tgt_vocab = tgt_vocab
187187 self .normalize_by_length = normalize_by_length
188188
189- def _make_shard_state (self , batch , output , range_ , attns ):
189+ def _make_shard_state (self , batch , output , enc_src , enc_tgt , range_ , attns ):
190190 """See base class for args description."""
191191 if getattr (batch , "alignment" , None ) is None :
192192 raise AssertionError ("using -copy_attn you need to pass in "
193193 "-dynamic_dict during preprocess stage." )
194194
195195 shard_state = super (CopyGeneratorLossCompute , self )._make_shard_state (
196- batch , output , range_ , attns )
196+ batch , output , enc_src , enc_tgt , range_ , attns )
197197
198198 shard_state .update ({
199199 "copy_attn" : attns .get ("copy" ),
200200 "align" : batch .alignment [range_ [0 ] + 1 : range_ [1 ]]
201201 })
202202 return shard_state
203203
204- def _compute_loss (self , batch , output , target , copy_attn , align ,
204+ def _compute_loss (self , batch , normalization , output , target ,
205+ copy_attn , align , enc_src = None , enc_tgt = None ,
205206 std_attn = None , coverage_attn = None ):
206207 """Compute the loss.
207208
@@ -244,8 +245,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
244245 offset_align = align [correct_mask ] + len (self .tgt_vocab )
245246 target_data [correct_mask ] += offset_align
246247
248+ if self .lambda_cosine != 0.0 :
249+ cosine_loss , num_ex = self ._compute_cosine_loss (enc_src , enc_tgt )
250+ loss += self .lambda_cosine * (cosine_loss / num_ex )
251+ else :
252+ cosine_loss = None
253+ num_ex = 0
254+
247255 # Compute sum of perplexities for stats
248- stats = self ._stats (loss .sum ().clone (), scores_data , target_data )
256+ stats = self ._stats (loss .sum ().clone (),
257+ cosine_loss .clone () if cosine_loss is not None
258+ else cosine_loss ,
259+ scores_data , target_data , num_ex )
249260
250261 # this part looks like it belongs in CopyGeneratorLoss
251262 if self .normalize_by_length :
0 commit comments