Skip to content

Commit 845c989

Browse files
fix some flake
1 parent f35b34e commit 845c989

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

onmt/utils/loss.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def __call__(self,
162162
shard_state = self._make_shard_state(
163163
batch, output, enc_src, enc_tgt, trunc_range, attns)
164164
if shard_size == 0:
165-
loss, stats = self._compute_loss(batch, normalization, **shard_state)
165+
loss, stats = self._compute_loss(batch, normalization,
166+
**shard_state)
166167
return loss, stats
167168
batch_stats = onmt.utils.Statistics()
168169
for shard in shards(shard_state, shard_size):
@@ -238,7 +239,8 @@ def __init__(self, criterion, generator, normalization="sents",
238239
self.lambda_align = lambda_align
239240
self.lambda_cosine = lambda_cosine
240241

241-
def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None):
242+
def _make_shard_state(self, batch, output, enc_src, enc_tgt,
243+
range_, attns=None):
242244
shard_state = {
243245
"output": output,
244246
"target": batch.tgt[range_[0] + 1: range_[1], :, 0],
@@ -283,7 +285,8 @@ def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None)
283285
})
284286
return shard_state
285287

286-
def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None,
288+
def _compute_loss(self, batch, normalization, output, target,
289+
enc_src, enc_tgt, std_attn=None,
287290
coverage_attn=None, align_head=None, ref_align=None):
288291

289292
bottled_output = self._bottle(output)
@@ -312,7 +315,7 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt,
312315
max_src = enc_src.max(axis=0)[0]
313316
max_tgt = enc_tgt.max(axis=0)[0]
314317
cosine_loss = torch.nn.functional.cosine_similarity(
315-
max_src.float(), max_tgt.float(), dim=1)
318+
max_src.float(), max_tgt.float(), dim=1)
316319
ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
317320
cosine_loss = ones - cosine_loss
318321
num_ex = cosine_loss.size(0)
@@ -322,7 +325,6 @@ def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt,
322325
cosine_loss = None
323326
num_ex = 0
324327

325-
326328
stats = self._stats(loss.clone() * normalization,
327329
cosine_loss.clone() if cosine_loss is not None
328330
else cosine_loss,

onmt/utils/statistics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class Statistics(object):
1717
* elapsed time
1818
"""
1919

20-
def __init__(self, loss=0, cosine_loss=0, n_words=0, n_correct=0, num_ex=0):
20+
def __init__(self, loss=0, cosine_loss=0, n_words=0,
21+
n_correct=0, num_ex=0):
2122
self.loss = loss
2223
self.n_words = n_words
2324
self.n_correct = n_correct

0 commit comments

Comments
 (0)