
Commit aaa220b

francoishernandez authored and vince62s committed
New apex amp API (#1465)
* use new apex amp API
* make apex opt_level an option
1 parent e156cce commit aaa220b

5 files changed: 18 additions & 25 deletions


onmt/model_builder.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -214,8 +214,6 @@ def fix_key(s):
 
     model.generator = generator
     model.to(device)
-    if model_opt.model_dtype == 'fp16':
-        model.half()
 
     return model
```
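The explicit `model.half()` cast is dropped here because, with the new API, the cast is handled by `apex.amp.initialize` when the optimizer is built (see `onmt/utils/optimizers.py` below). A minimal sketch of that behaviour, using a toy module rather than OpenNMT-py's builder and assuming apex and a CUDA device are available:

```python
# Toy illustration (not OpenNMT-py code): the model is built in FP32 and
# apex.amp.initialize performs the cast for opt_level "O2"/"O3".
import torch
import apex

model = torch.nn.Linear(4, 4).cuda()               # no model.half() here
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O2")
print(next(model.parameters()).dtype)              # torch.float16 under O2
```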
onmt/opts.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -179,6 +179,10 @@ def model_opts(parser):
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")
+    group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O2",
+              choices=["O0", "O1", "O2", "O3"],
+              help="For FP16 training, the opt_level to use. "
+                   "See https://nvidia.github.io/apex/amp.html#opt-levels.")
 
 
 def preprocess_opts(parser):
```
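For reference, a minimal sketch of how the new flag behaves from the command line, using plain `argparse` as a stand-in for OpenNMT-py's own option parser (the option names, choices, and default mirror the diff above; everything else is illustrative):

```python
import argparse

# Stand-in parser: OpenNMT-py registers the option via group.add() on its own
# parser, but the parsing semantics are the same.
parser = argparse.ArgumentParser()
parser.add_argument('--apex_opt_level', '-apex_opt_level', type=str, default="O2",
                    choices=["O0", "O1", "O2", "O3"],
                    help="For FP16 training, the opt_level to use. "
                         "See https://nvidia.github.io/apex/amp.html#opt-levels.")

print(parser.parse_args(["-apex_opt_level", "O1"]).apex_opt_level)  # "O1"
print(parser.parse_args([]).apex_opt_level)                         # default "O2"
```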

onmt/trainer.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -294,8 +294,7 @@ def validate(self, valid_iter, moving_average=None):
             valid_model = deepcopy(self.model)
             for avg, param in zip(self.moving_average,
                                   valid_model.parameters()):
-                param.data = avg.data.half() if self.model_dtype == "fp16" \
-                    else avg.data
+                param.data = avg.data
         else:
             valid_model = self.model
 
```
onmt/utils/optimizers.py

Lines changed: 13 additions & 17 deletions
```diff
@@ -7,8 +7,6 @@
 from copy import copy
 from math import sqrt
 
-from onmt.utils.misc import fn_args
-
 
 def build_torch_optimizer(model, opt):
     """Builds the PyTorch optimizer.
@@ -87,17 +85,14 @@ def build_torch_optimizer(model, opt):
 
     if opt.model_dtype == 'fp16':
         import apex
-        static_loss_scale = opt.loss_scale
-        dynamic_loss_scale = opt.loss_scale == 0
-        # TODO: clean this up when APEX unify its optimizer API.
-        if opt.optim.startswith('fused'):
-            namespace = apex.optimizers  # Faster wrapper.
-        else:
-            namespace = apex.fp16_utils
-        optimizer = namespace.FP16_Optimizer(
+        loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
+        model, optimizer = apex.amp.initialize(
+            [model, model.generator],
             optimizer,
-            static_loss_scale=static_loss_scale,
-            dynamic_loss_scale=dynamic_loss_scale)
+            opt_level=opt.apex_opt_level,
+            loss_scale=loss_scale,
+            keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)
+
     return optimizer
 
 
@@ -317,10 +312,9 @@ def backward(self, loss):
         """Wrapper for backward pass. Some optimizer requires ownership of the
         backward pass."""
         if self._with_fp16_wrapper:
-            kwargs = {}
-            if "update_master_grads" in fn_args(self._optimizer.backward):
-                kwargs["update_master_grads"] = True
-            self._optimizer.backward(loss, **kwargs)
+            import apex
+            with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                scaled_loss.backward()
         else:
             loss.backward()
 
@@ -336,7 +330,9 @@ def step(self):
                 self._optimizer.update_master_grads()
             if hasattr(self._optimizer, "clip_master_grads") and \
                     self._max_grad_norm > 0:
-                self._optimizer.clip_master_grads(self._max_grad_norm)
+                import apex
+                torch.nn.utils.clip_grad_norm_(
+                    apex.amp.master_params(self), self._max_grad_norm)
         for group in self._optimizer.param_groups:
             group['lr'] = learning_rate
             if not self._with_fp16_wrapper and self._max_grad_norm > 0:
```
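Taken together, the optimizer-side changes follow the standard apex amp pattern: a single `amp.initialize` call replaces the old `FP16_Optimizer` wrapper, the backward pass goes through `amp.scale_loss`, and gradient clipping operates on the FP32 master parameters. A self-contained sketch of that pattern with a toy model (assumes the NVIDIA apex package and a CUDA device; nothing here is OpenNMT-py's actual optimizer wrapper):

```python
import torch
import apex

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One call replaces the old FP16_Optimizer wrapper: it casts the model for the
# chosen opt_level and patches the optimizer for mixed precision.
model, optimizer = apex.amp.initialize(
    model, optimizer, opt_level="O2", loss_scale="dynamic")

x = torch.randn(4, 8, device="cuda")
loss = model(x).sum()

# Backward through the loss scaler instead of calling loss.backward() directly.
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Clip the FP32 master gradients, then step.
torch.nn.utils.clip_grad_norm_(apex.amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```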

onmt/utils/parse.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -63,10 +63,6 @@ def validate_model_opts(cls, model_opt):
             if model_opt.model_type != "text":
                 raise AssertionError(
                     "--share_embeddings requires --model_type text.")
-        if model_opt.model_dtype == "fp16":
-            logger.warning(
-                "FP16 is experimental, the generated checkpoints may "
-                "be incompatible with a future version")
 
     @classmethod
     def ckpt_model_opts(cls, ckpt_opt):
```

0 commit comments
