@@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True):
     else:
         compute = NMTLossCompute(
             criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
-            lambda_align=opt.lambda_align)
+            lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine)
     compute.to(device)
 
     return compute
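
Note: the new `lambda_cosine` value is read from `opt` here, so the flag must also be registered with the option parser. A minimal sketch of that registration, assuming it sits next to the existing `-lambda_align` flag in `opts.py` (group placement and help text are assumptions, not part of this patch):

    # opts.py -- assumed registration, mirroring the existing -lambda_align flag
    group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
              help="Weight of the cosine loss between max-pooled "
                   "source and target encoder states (0.0 disables it).")
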
@@ -123,6 +123,8 @@ def __call__(self,
                  batch,
                  output,
                  attns,
+                 enc_src,
+                 enc_tgt,
                  normalization=1.0,
                  shard_size=0,
                  trunc_start=0,
@@ -157,18 +159,19 @@ def __call__(self,
         if trunc_size is None:
             trunc_size = batch.tgt.size(0) - trunc_start
         trunc_range = (trunc_start, trunc_start + trunc_size)
-        shard_state = self._make_shard_state(batch, output, trunc_range, attns)
+        shard_state = self._make_shard_state(
+            batch, output, enc_src, enc_tgt, trunc_range, attns)
         if shard_size == 0:
-            loss, stats = self._compute_loss(batch, **shard_state)
-            return loss / float(normalization), stats
+            loss, stats = self._compute_loss(batch, normalization, **shard_state)
+            return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
-            loss, stats = self._compute_loss(batch, **shard)
-            loss.div(float(normalization)).backward()
+            loss, stats = self._compute_loss(batch, normalization, **shard)
+            loss.backward()
             batch_stats.update(stats)
         return None, batch_stats
 
-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, cosine_loss, scores, target, num_ex):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
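
With the new signature, every call site of `__call__` must now supply the source- and target-side encoder states, and the returned `loss` is already divided by `normalization`, so callers must not divide again. A hypothetical call site in the trainer (`self.train_loss`, `outputs`, `attns`, and the truncation arguments follow OpenNMT-py's `Trainer._gradient_accumulation`; how `enc_src`/`enc_tgt` are produced is not shown in this patch):

    # trainer.py -- hypothetical caller; enc_src / enc_tgt are the encoder
    # outputs for source and target, each of shape (len, batch, hidden)
    loss, batch_stats = self.train_loss(
        batch, outputs, attns, enc_src, enc_tgt,
        normalization=normalization,
        shard_size=self.shard_size,
        trunc_start=j, trunc_size=trunc_size)
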
@@ -182,7 +185,9 @@ def _stats(self, loss, scores, target):
         non_padding = target.ne(self.padding_idx)
         num_correct = pred.eq(target).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
-        return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
+        return onmt.utils.Statistics(
+            loss.item(), cosine_loss.item() if cosine_loss is not None else 0,
+            num_non_padding, num_correct, num_ex)
 
     def _bottle(self, _v):
         return _v.view(-1, _v.size(2))
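
The `Statistics` constructor is now called with five arguments, so `onmt.utils.Statistics` must be extended in the same commit. A minimal sketch of the assumed constructor (field names are guesses inferred from this call site; the actual change is not shown in this diff):

    # onmt/utils/statistics.py -- assumed extension matching the call above
    def __init__(self, loss=0, cosine_loss=0, n_words=0, n_correct=0, num_ex=0):
        self.loss = loss
        self.cosine_loss = cosine_loss  # summed (1 - cosine) over examples
        self.n_words = n_words
        self.n_correct = n_correct
        self.num_ex = num_ex            # examples contributing to cosine_loss
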
@@ -227,15 +232,18 @@ class NMTLossCompute(LossComputeBase):
227232 """
228233
229234 def __init__ (self , criterion , generator , normalization = "sents" ,
230- lambda_coverage = 0.0 , lambda_align = 0.0 ):
235+ lambda_coverage = 0.0 , lambda_align = 0.0 , lambda_cosine = 0.0 ):
231236 super (NMTLossCompute , self ).__init__ (criterion , generator )
232237 self .lambda_coverage = lambda_coverage
233238 self .lambda_align = lambda_align
239+ self .lambda_cosine = lambda_cosine
234240
235- def _make_shard_state (self , batch , output , range_ , attns = None ):
241+ def _make_shard_state (self , batch , output , enc_src , enc_tgt , range_ , attns = None ):
236242 shard_state = {
237243 "output" : output ,
238244 "target" : batch .tgt [range_ [0 ] + 1 : range_ [1 ], :, 0 ],
245+ "enc_src" : enc_src ,
246+ "enc_tgt" : enc_tgt
239247 }
240248 if self .lambda_coverage != 0.0 :
241249 coverage = attns .get ("coverage" , None )
@@ -275,7 +283,7 @@ def _make_shard_state(self, batch, output, range_, attns=None):
             })
         return shard_state
 
-    def _compute_loss(self, batch, output, target, std_attn=None,
+    def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):
 
         bottled_output = self._bottle(output)
@@ -284,6 +292,7 @@ def _compute_loss(self, batch, output, target, std_attn=None,
         gtruth = target.view(-1)
 
         loss = self.criterion(scores, gtruth)
+
         if self.lambda_coverage != 0.0:
             coverage_loss = self._compute_coverage_loss(
                 std_attn=std_attn, coverage_attn=coverage_attn)
@@ -296,7 +305,29 @@ def _compute_loss(self, batch, output, target, std_attn=None,
             align_loss = self._compute_alignement_loss(
                 align_head=align_head, ref_align=ref_align)
             loss += align_loss
-        stats = self._stats(loss.clone(), scores, gtruth)
+
+        loss = loss / float(normalization)
+
+        if self.lambda_cosine != 0.0:
+            # Max-pool the encoder states over time, then penalize
+            # 1 - cosine(src, tgt) between the pooled vectors.
+            max_src = enc_src.max(dim=0)[0]
+            max_tgt = enc_tgt.max(dim=0)[0]
+            cosine_loss = torch.nn.functional.cosine_similarity(
+                max_src.float(), max_tgt.float(), dim=1)
+            ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
+            cosine_loss = ones - cosine_loss
+            num_ex = cosine_loss.size(0)
+            cosine_loss = cosine_loss.sum()
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
+        stats = self._stats(loss.clone() * normalization,
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores, gtruth, num_ex)
 
         return loss, stats
 
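
For reference, the cosine penalty added to `_compute_loss` can be exercised on its own. A self-contained sketch that mirrors the patched code path (the time-major `(len, batch, hidden)` layout follows OpenNMT-py's encoder outputs; `cosine_penalty` is a name introduced here for illustration):

    import torch
    import torch.nn.functional as F

    def cosine_penalty(enc_src, enc_tgt, lambda_cosine):
        # Max-pool each sequence over time -> one vector per example.
        max_src = enc_src.max(dim=0)[0]   # (batch, hidden)
        max_tgt = enc_tgt.max(dim=0)[0]   # (batch, hidden)
        # 1 - cosine similarity, summed over the batch, as in the patch.
        cos = F.cosine_similarity(max_src.float(), max_tgt.float(), dim=1)
        cosine_loss = (1.0 - cos).sum()
        num_ex = cos.size(0)
        return lambda_cosine * cosine_loss / num_ex, cosine_loss, num_ex

    enc_src = torch.randn(7, 4, 16)   # (src_len, batch, hidden)
    enc_tgt = torch.randn(9, 4, 16)   # (tgt_len, batch, hidden)
    penalty, cosine_loss, num_ex = cosine_penalty(enc_src, enc_tgt, 0.5)
    print(penalty.item(), cosine_loss.item(), num_ex)

Because the cosine term is averaged over `num_ex` examples while the NLL term is divided by `normalization`, the two terms sit on different scales; `lambda_cosine` has to absorb that difference.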