from copy import copy
from math import sqrt

-from onmt.utils.misc import fn_args
-

def build_torch_optimizer(model, opt):
    """Builds the PyTorch optimizer.
@@ -87,17 +85,14 @@ def build_torch_optimizer(model, opt):

    if opt.model_dtype == 'fp16':
        import apex
-        static_loss_scale = opt.loss_scale
-        dynamic_loss_scale = opt.loss_scale == 0
-        # TODO: clean this up when APEX unify its optimizer API.
-        if opt.optim.startswith('fused'):
-            namespace = apex.optimizers  # Faster wrapper.
-        else:
-            namespace = apex.fp16_utils
-        optimizer = namespace.FP16_Optimizer(
+        loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
+        model, optimizer = apex.amp.initialize(
+            [model, model.generator],
            optimizer,
-            static_loss_scale=static_loss_scale,
-            dynamic_loss_scale=dynamic_loss_scale)
+            opt_level=opt.apex_opt_level,
+            loss_scale=loss_scale,
+            keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)
+
    return optimizer

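A minimal sketch of the apex.amp.initialize pattern used in the hunk above, assuming apex is installed and a CUDA device is available; the toy model, optimizer, and the opt_level/loss_scale values are illustrative stand-ins for OpenNMT-py's objects and for opt.apex_opt_level / opt.loss_scale.

import torch
import apex

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# amp.initialize() patches the model and optimizer for mixed precision and
# returns them; loss_scale is either a fixed float or "dynamic", which is
# what the hunk above selects when opt.loss_scale == 0.
model, optimizer = apex.amp.initialize(
    model, optimizer, opt_level="O1", loss_scale="dynamic")
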
@@ -317,10 +312,9 @@ def backward(self, loss):
        """Wrapper for backward pass. Some optimizer requires ownership of the
        backward pass."""
        if self._with_fp16_wrapper:
-            kwargs = {}
-            if "update_master_grads" in fn_args(self._optimizer.backward):
-                kwargs["update_master_grads"] = True
-            self._optimizer.backward(loss, **kwargs)
+            import apex
+            with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                scaled_loss.backward()
        else:
            loss.backward()

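Why scale_loss: fp16 gradients can underflow to zero, so amp multiplies the loss by its current loss scale before backpropagation and unscales the gradients again before the optimizer step. Continuing the sketch above (model and optimizer already returned by apex.amp.initialize):

# Backpropagate through the scaled loss; gradients are unscaled on exit from
# the context manager, so the optimizer can step as usual afterwards.
loss = model(torch.randn(2, 8).cuda()).sum()
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
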
@@ -336,7 +330,9 @@ def step(self):
                self._optimizer.update_master_grads()
            if hasattr(self._optimizer, "clip_master_grads") and \
                    self._max_grad_norm > 0:
-                self._optimizer.clip_master_grads(self._max_grad_norm)
+                import apex
+                torch.nn.utils.clip_grad_norm_(
+                    apex.amp.master_params(self._optimizer), self._max_grad_norm)
        for group in self._optimizer.param_groups:
            group['lr'] = learning_rate
            if not self._with_fp16_wrapper and self._max_grad_norm > 0:
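
With amp, gradient clipping is applied to the parameters amp tracks for the optimizer (the fp32 master copies when the opt_level keeps separate master weights), exposed by apex.amp.master_params. Continuing the same sketch, with an illustrative max norm:

# Clip by global norm over the parameters amp maintains for this optimizer.
torch.nn.utils.clip_grad_norm_(
    apex.amp.master_params(optimizer), max_norm=5.0)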