@@ -162,7 +162,8 @@ def __call__(self,
         shard_state = self._make_shard_state(
             batch, output, enc_src, enc_tgt, trunc_range, attns)
         if shard_size == 0:
-            loss, stats = self._compute_loss(batch, normalization, **shard_state)
+            loss, stats = self._compute_loss(batch, normalization,
+                                             **shard_state)
             return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
@@ -238,7 +239,8 @@ def __init__(self, criterion, generator, normalization="sents",
         self.lambda_align = lambda_align
         self.lambda_cosine = lambda_cosine

-    def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns=None):
         shard_state = {
             "output": output,
             "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
@@ -283,7 +285,8 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None)
             })
         return shard_state

-    def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None,
+    def _compute_loss(self, batch, normalization, output, target,
+                      enc_src, enc_tgt, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
@@ -312,7 +315,7 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt,
             max_src = enc_src.max(axis=0)[0]
             max_tgt = enc_tgt.max(axis=0)[0]
             cosine_loss = torch.nn.functional.cosine_similarity(
-                max_src.float(), max_tgt.float(), dim=1)
+                    max_src.float(), max_tgt.float(), dim=1)
             ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
             cosine_loss = ones - cosine_loss
             num_ex = cosine_loss.size(0)
@@ -322,7 +325,6 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt,
             cosine_loss = None
             num_ex = 0

-
         stats = self._stats(loss.clone() * normalization,
                             cosine_loss.clone() if cosine_loss is not None
                             else cosine_loss,
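The only functional code touched by these hunks is the cosine-regularization path of _compute_loss: source and target encoder states are max-pooled over time, their per-example cosine similarity is taken, and 1 - cos enters the loss alongside the usual NMT term before being reported through _stats. Below is a minimal standalone sketch of that term, assuming enc_src and enc_tgt are (seq_len, batch, dim) tensors as in the diff; the helper name and the final sum reduction are assumptions for illustration, not part of the patch.

import torch
import torch.nn.functional as F

def cosine_consistency_loss(enc_src, enc_tgt):
    # Pool each sequence over time: (seq_len, batch, dim) -> (batch, dim).
    max_src = enc_src.max(dim=0)[0]
    max_tgt = enc_tgt.max(dim=0)[0]
    # Per-example cosine similarity between the pooled representations.
    cos = F.cosine_similarity(max_src.float(), max_tgt.float(), dim=1)
    # 1 - cos per example; num_ex lets the caller average over the batch,
    # mirroring the num_ex bookkeeping in the patched code.
    num_ex = cos.size(0)
    return (1.0 - cos).sum(), num_ex

# Example usage with random encoder states (batch of 4, hidden size 8).
loss, n = cosine_consistency_loss(torch.randn(5, 4, 8), torch.randn(7, 4, 8))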