
Commit f35b34e

add lambda_cosine, move normalization to compute_loss, adapt stats
1 parent f44d1d2 commit f35b34e

5 files changed

Lines changed: 88 additions & 22 deletions

onmt/models/model.py

Lines changed: 14 additions & 2 deletions
@@ -1,5 +1,6 @@
 """ Onmt NMT Model base class definition """
 import torch.nn as nn
+import torch


 class NMTModel(nn.Module):
@@ -17,7 +18,8 @@ def __init__(self, encoder, decoder):
         self.encoder = encoder
         self.decoder = decoder

-    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
+    def forward(self, src, tgt, lengths, bptt=False,
+                with_align=False, encode_tgt=False):
         """Forward propagate a `src` and `tgt` pair for training.
         Possible initialized with a beginning decoder state.

@@ -44,12 +46,22 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False):

         enc_state, memory_bank, lengths = self.encoder(src, lengths)

+        if encode_tgt:
+            # tgt for zero shot alignment loss
+            tgt_lengths = torch.Tensor(tgt.size(1))\
+                .type_as(memory_bank) \
+                .long() \
+                .fill_(tgt.size(0))
+            embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
+        else:
+            memory_bank_tgt = None
+
         if bptt is False:
             self.decoder.init_state(src, memory_bank, enc_state)
         dec_out, attns = self.decoder(dec_in, memory_bank,
                                       memory_lengths=lengths,
                                       with_align=with_align)
-        return dec_out, attns
+        return dec_out, attns, memory_bank, memory_bank_tgt

     def update_dropout(self, dropout):
         self.encoder.update_dropout(dropout)
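
The encode_tgt branch re-runs the same encoder over the target sentence, so forward() now returns the decoder output together with the source memory bank and, when requested, the target memory bank. A minimal, self-contained sketch of the length-vector trick and the second encoder pass, using a toy embedding encoder rather than the real OpenNMT-py encoder (all names below are illustrative):

import torch
import torch.nn as nn

# Toy stand-in for an OpenNMT-py encoder: returns (state, memory_bank, lengths).
class ToyEncoder(nn.Module):
    def __init__(self, vocab=100, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)

    def forward(self, seq, lengths):
        memory_bank = self.emb(seq.squeeze(-1))  # (len, batch, dim)
        return None, memory_bank, lengths

encoder = ToyEncoder()
src = torch.randint(0, 100, (5, 3, 1))  # (src_len, batch, nfeats)
tgt = torch.randint(0, 100, (6, 3, 1))  # (tgt_len, batch, nfeats)
src_lengths = torch.tensor([5, 5, 5])

_, memory_bank, _ = encoder(src, src_lengths)

# Same idea as the encode_tgt branch: every target in the batch is padded to
# tgt.size(0), so the length vector is that value repeated batch-size times.
tgt_lengths = torch.full((tgt.size(1),), tgt.size(0), dtype=torch.long)
_, memory_bank_tgt, _ = encoder(tgt, tgt_lengths)

print(memory_bank.shape, memory_bank_tgt.shape)  # (5, 3, 8) and (6, 3, 8)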

onmt/opts.py

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,8 @@ def model_opts(parser):
               help='Train a coverage attention layer.')
     group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0,
               help='Lambda value for coverage loss of See et al (2017)')
+    group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
+              help='Lambda value for cosine alignment loss #TODO cite')
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")

onmt/trainer.py

Lines changed: 15 additions & 7 deletions
@@ -70,7 +70,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
                            model_dtype=opt.model_dtype,
                            earlystopper=earlystopper,
                            dropout=dropout,
-                           dropout_steps=dropout_steps)
+                           dropout_steps=dropout_steps,
+                           encode_tgt=True if opt.lambda_cosine > 0 else False)
     return trainer


@@ -107,7 +108,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
                  n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
                  report_manager=None, with_align=False, model_saver=None,
                  average_decay=0, average_every=1, model_dtype='fp32',
-                 earlystopper=None, dropout=[0.3], dropout_steps=[0]):
+                 earlystopper=None, dropout=[0.3], dropout_steps=[0],
+                 encode_tgt=False):
         # Basic attributes.
         self.model = model
         self.train_loss = train_loss
@@ -132,6 +134,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
         self.earlystopper = earlystopper
         self.dropout = dropout
         self.dropout_steps = dropout_steps
+        self.encode_tgt = encode_tgt

         for i in range(len(self.accum_count_l)):
             assert self.accum_count_l[i] > 0
@@ -314,11 +317,13 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt

                 # F-prop through the model.
-                outputs, attns = valid_model(src, tgt, src_lengths,
-                                             with_align=self.with_align)
+                outputs, attns, enc_src, enc_tgt = valid_model(
+                    src, tgt, src_lengths,
+                    with_align=self.with_align)

                 # Compute loss.
-                _, batch_stats = self.valid_loss(batch, outputs, attns)
+                _, batch_stats = self.valid_loss(
+                    batch, outputs, attns, enc_src, enc_tgt)

                 # Update statistics.
                 stats.update(batch_stats)
@@ -361,8 +366,9 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             if self.accum_count == 1:
                 self.optim.zero_grad()

-            outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
-                                        with_align=self.with_align)
+            outputs, attns, enc_src, enc_tgt = self.model(
+                src, tgt, src_lengths, bptt=bptt,
+                with_align=self.with_align, encode_tgt=self.encode_tgt)
             bptt = True

             # 3. Compute loss.
@@ -371,6 +377,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                     batch,
                     outputs,
                     attns,
+                    enc_src,
+                    enc_tgt,
                     normalization=normalization,
                     shard_size=self.shard_size,
                     trunc_start=j,
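
The trainer only asks the model to encode the target when the cosine term is actually in use, so the default configuration pays for no extra encoder pass. A condensed sketch of that wiring (an illustrative class, not the real Trainer API):

class CosineAwareTrainer:
    """Illustrative wiring only: encode the target only when the loss needs it."""

    def __init__(self, model, train_loss, lambda_cosine=0.0):
        self.model = model
        self.train_loss = train_loss
        # Same condition as build_trainer above; equivalent to lambda_cosine > 0.
        self.encode_tgt = True if lambda_cosine > 0 else False

    def step(self, batch, src, tgt, src_lengths, normalization):
        # forward() now returns four values; enc_tgt is None when
        # encode_tgt is False, and the loss has to tolerate that.
        outputs, attns, enc_src, enc_tgt = self.model(
            src, tgt, src_lengths, encode_tgt=self.encode_tgt)
        loss, stats = self.train_loss(
            batch, outputs, attns, enc_src, enc_tgt,
            normalization=normalization)
        return loss, stats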

onmt/utils/loss.py

Lines changed: 41 additions & 11 deletions
@@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True):
     else:
         compute = NMTLossCompute(
             criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
-            lambda_align=opt.lambda_align)
+            lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine)
     compute.to(device)

     return compute
@@ -123,6 +123,8 @@ def __call__(self,
                 batch,
                 output,
                 attns,
+                enc_src,
+                enc_tgt,
                 normalization=1.0,
                 shard_size=0,
                 trunc_start=0,
@@ -157,18 +159,19 @@ def __call__(self,
         if trunc_size is None:
             trunc_size = batch.tgt.size(0) - trunc_start
         trunc_range = (trunc_start, trunc_start + trunc_size)
-        shard_state = self._make_shard_state(batch, output, trunc_range, attns)
+        shard_state = self._make_shard_state(
+            batch, output, enc_src, enc_tgt, trunc_range, attns)
         if shard_size == 0:
-            loss, stats = self._compute_loss(batch, **shard_state)
-            return loss / float(normalization), stats
+            loss, stats = self._compute_loss(batch, normalization, **shard_state)
+            return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
             loss, stats = self._compute_loss(batch, **shard)
-            loss.div(float(normalization)).backward()
+            loss.backward()
             batch_stats.update(stats)
         return None, batch_stats

-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, cosine_loss, scores, target, num_ex):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
@@ -182,7 +185,9 @@ def _stats(self, loss, scores, target):
         non_padding = target.ne(self.padding_idx)
         num_correct = pred.eq(target).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
-        return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
+        return onmt.utils.Statistics(
+            loss.item(), cosine_loss.item() if cosine_loss is not None else 0,
+            num_non_padding, num_correct, num_ex)

     def _bottle(self, _v):
         return _v.view(-1, _v.size(2))
@@ -227,15 +232,18 @@ class NMTLossCompute(LossComputeBase):
     """

     def __init__(self, criterion, generator, normalization="sents",
-                 lambda_coverage=0.0, lambda_align=0.0):
+                 lambda_coverage=0.0, lambda_align=0.0, lambda_cosine=0.0):
         super(NMTLossCompute, self).__init__(criterion, generator)
         self.lambda_coverage = lambda_coverage
         self.lambda_align = lambda_align
+        self.lambda_cosine = lambda_cosine

-    def _make_shard_state(self, batch, output, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns=None):
         shard_state = {
             "output": output,
             "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
+            "enc_src": enc_src,
+            "enc_tgt": enc_tgt
         }
         if self.lambda_coverage != 0.0:
             coverage = attns.get("coverage", None)
@@ -275,7 +283,7 @@ def _make_shard_state(self, batch, output, range_, attns=None):
             })
         return shard_state

-    def _compute_loss(self, batch, output, target, std_attn=None,
+    def _compute_loss(self, batch, normalization, output, target, enc_src, enc_tgt, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
@@ -284,6 +292,7 @@ def _compute_loss(self, batch, output, target, std_attn=None,
         gtruth = target.view(-1)

         loss = self.criterion(scores, gtruth)
+
         if self.lambda_coverage != 0.0:
             coverage_loss = self._compute_coverage_loss(
                 std_attn=std_attn, coverage_attn=coverage_attn)
@@ -296,7 +305,28 @@ def _compute_loss(self, batch, output, target, std_attn=None,
             align_loss = self._compute_alignement_loss(
                 align_head=align_head, ref_align=ref_align)
             loss += align_loss
-        stats = self._stats(loss.clone(), scores, gtruth)
+
+        loss = loss/float(normalization)
+
+        if self.lambda_cosine != 0.0:
+            max_src = enc_src.max(axis=0)[0]
+            max_tgt = enc_tgt.max(axis=0)[0]
+            cosine_loss = torch.nn.functional.cosine_similarity(
+                max_src.float(), max_tgt.float(), dim=1)
+            ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
+            cosine_loss = ones - cosine_loss
+            num_ex = cosine_loss.size(0)
+            cosine_loss = cosine_loss.sum()
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
+        stats = self._stats(loss.clone() * normalization,
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores, gtruth, num_ex)

         return loss, stats
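
The new term max-pools each memory bank over time and penalises 1 - cosine similarity between the pooled source and target representations, averaged over the batch and scaled by lambda_cosine. A self-contained sketch of just that term on random tensors (shapes follow the (len, batch, dim) memory banks; the function name is illustrative, not part of this commit):

import torch
import torch.nn.functional as F

def cosine_alignment_term(enc_src, enc_tgt, lambda_cosine):
    """Weighted mean over the batch of 1 - cos(max-pooled src, max-pooled tgt)."""
    max_src = enc_src.max(dim=0)[0]  # (batch, dim)
    max_tgt = enc_tgt.max(dim=0)[0]  # (batch, dim)
    cos = F.cosine_similarity(max_src.float(), max_tgt.float(), dim=1)
    cosine_loss = (1.0 - cos).sum()  # summed over the batch, as in _compute_loss
    num_ex = cos.size(0)
    return lambda_cosine * cosine_loss / num_ex, cosine_loss, num_ex

# 7 source steps, 9 target steps, batch of 4, 16-dim encoder states.
enc_src = torch.randn(7, 4, 16)
enc_tgt = torch.randn(9, 4, 16)
weighted, raw, n = cosine_alignment_term(enc_src, enc_tgt, lambda_cosine=1.0)
print(weighted.item(), raw.item(), n)

Normalization also moves in this file: _compute_loss now divides the loss by normalization itself, so __call__ no longer divides before returning or calling backward, and _stats receives loss.clone() * normalization so that the reported cross-entropy and perplexity stay on the same scale as before.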

onmt/utils/statistics.py

Lines changed: 16 additions & 2 deletions
@@ -17,12 +17,14 @@ class Statistics(object):
    * elapsed time
    """

-    def __init__(self, loss=0, n_words=0, n_correct=0):
+    def __init__(self, loss=0, cosine_loss=0, n_words=0, n_correct=0, num_ex=0):
         self.loss = loss
         self.n_words = n_words
         self.n_correct = n_correct
         self.n_src_words = 0
         self.start_time = time.time()
+        self.cosine_loss = cosine_loss
+        self.num_ex = num_ex

     @staticmethod
     def all_gather_stats(stat, max_size=4096):
@@ -81,6 +83,10 @@ def update(self, stat, update_n_src_words=False):
         self.loss += stat.loss
         self.n_words += stat.n_words
         self.n_correct += stat.n_correct
+        # print("LOSS update", stat.loss)
+        # print("ZS_LOSS update", stat.zs_loss)
+        self.cosine_loss += stat.cosine_loss
+        self.num_ex += stat.num_ex

         if update_n_src_words:
             self.n_src_words += stat.n_src_words
@@ -97,6 +103,10 @@ def ppl(self):
         """ compute perplexity """
         return math.exp(min(self.loss / self.n_words, 100))

+    def cos(self):
+        # print("ZS LOSS", self.zs_loss)
+        return self.cosine_loss / self.num_ex
+
     def elapsed_time(self):
         """ compute elapsed time """
         return time.time() - self.start_time
@@ -113,8 +123,12 @@ def output(self, step, num_steps, learning_rate, start):
         step_fmt = "%2d" % step
         if num_steps > 0:
             step_fmt = "%s/%5d" % (step_fmt, num_steps)
+        if self.cosine_loss != 0:
+            cos_log = "cos: %4.2f; " % (self.cos())
+        else:
+            cos_log = ""
         logger.info(
-            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
+            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + cos_log +
              "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
             % (step_fmt,
                self.accuracy(),
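
Statistics now carries the summed cosine loss and the number of examples it covers, and cos() reports the running per-example average that shows up as "cos:" in the training log line above. A small illustration of the accumulation, assuming the patched onmt package is importable:

from onmt.utils.statistics import Statistics

total = Statistics()
# Two batches: summed (1 - cos) of 1.2 over 4 examples, then 0.8 over 4 more.
total.update(Statistics(loss=50.0, cosine_loss=1.2, n_words=100, n_correct=60, num_ex=4))
total.update(Statistics(loss=40.0, cosine_loss=0.8, n_words=90, n_correct=55, num_ex=4))
print(total.cos())  # (1.2 + 0.8) / (4 + 4) = 0.25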
