@@ -233,6 +233,60 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
233233
234234 set_random_seed (self .opt .seed , self .opt .cuda )
235235
236+ if self .preprocess_opt is not None :
237+ self .logger .info ("Loading preprocessor" )
238+ self .preprocessor = []
239+
240+ for function_path in self .preprocess_opt :
241+ function = get_function_by_path (function_path )
242+ self .preprocessor .append (function )
243+
244+ if self .tokenizer_opt is not None :
245+ self .logger .info ("Loading tokenizer" )
246+
247+ if "type" not in self .tokenizer_opt :
248+ raise ValueError (
249+ "Missing mandatory tokenizer option 'type'" )
250+
251+ if self .tokenizer_opt ['type' ] == 'sentencepiece' :
252+ if "model" not in self .tokenizer_opt :
253+ raise ValueError (
254+ "Missing mandatory tokenizer option 'model'" )
255+ import sentencepiece as spm
256+ sp = spm .SentencePieceProcessor ()
257+ model_path = os .path .join (self .model_root ,
258+ self .tokenizer_opt ['model' ])
259+ sp .Load (model_path )
260+ self .tokenizer = sp
261+ elif self .tokenizer_opt ['type' ] == 'pyonmttok' :
262+ if "params" not in self .tokenizer_opt :
263+ raise ValueError (
264+ "Missing mandatory tokenizer option 'params'" )
265+ import pyonmttok
266+ if self .tokenizer_opt ["mode" ] is not None :
267+ mode = self .tokenizer_opt ["mode" ]
268+ else :
269+ mode = None
270+ # load can be called multiple times: modify copy
271+ tokenizer_params = dict (self .tokenizer_opt ["params" ])
272+ for key , value in self .tokenizer_opt ["params" ].items ():
273+ if key .endswith ("path" ):
274+ tokenizer_params [key ] = os .path .join (
275+ self .model_root , value )
276+ tokenizer = pyonmttok .Tokenizer (mode ,
277+ ** tokenizer_params )
278+ self .tokenizer = tokenizer
279+ else :
280+ raise ValueError ("Invalid value for tokenizer type" )
281+
282+ if self .postprocess_opt is not None :
283+ self .logger .info ("Loading postprocessor" )
284+ self .postprocessor = []
285+
286+ for function_path in self .postprocess_opt :
287+ function = get_function_by_path (function_path )
288+ self .postprocessor .append (function )
289+
236290 if load :
237291 self .load ()
238292
@@ -294,60 +348,6 @@ def load(self):
294348 raise ServerModelError ("Runtime Error: %s" % str (e ))
295349
296350 timer .tick ("model_loading" )
297- if self .preprocess_opt is not None :
298- self .logger .info ("Loading preprocessor" )
299- self .preprocessor = []
300-
301- for function_path in self .preprocess_opt :
302- function = get_function_by_path (function_path )
303- self .preprocessor .append (function )
304-
305- if self .tokenizer_opt is not None :
306- self .logger .info ("Loading tokenizer" )
307-
308- if "type" not in self .tokenizer_opt :
309- raise ValueError (
310- "Missing mandatory tokenizer option 'type'" )
311-
312- if self .tokenizer_opt ['type' ] == 'sentencepiece' :
313- if "model" not in self .tokenizer_opt :
314- raise ValueError (
315- "Missing mandatory tokenizer option 'model'" )
316- import sentencepiece as spm
317- sp = spm .SentencePieceProcessor ()
318- model_path = os .path .join (self .model_root ,
319- self .tokenizer_opt ['model' ])
320- sp .Load (model_path )
321- self .tokenizer = sp
322- elif self .tokenizer_opt ['type' ] == 'pyonmttok' :
323- if "params" not in self .tokenizer_opt :
324- raise ValueError (
325- "Missing mandatory tokenizer option 'params'" )
326- import pyonmttok
327- if self .tokenizer_opt ["mode" ] is not None :
328- mode = self .tokenizer_opt ["mode" ]
329- else :
330- mode = None
331- # load can be called multiple times: modify copy
332- tokenizer_params = dict (self .tokenizer_opt ["params" ])
333- for key , value in self .tokenizer_opt ["params" ].items ():
334- if key .endswith ("path" ):
335- tokenizer_params [key ] = os .path .join (
336- self .model_root , value )
337- tokenizer = pyonmttok .Tokenizer (mode ,
338- ** tokenizer_params )
339- self .tokenizer = tokenizer
340- else :
341- raise ValueError ("Invalid value for tokenizer type" )
342-
343- if self .postprocess_opt is not None :
344- self .logger .info ("Loading postprocessor" )
345- self .postprocessor = []
346-
347- for function_path in self .postprocess_opt :
348- function = get_function_by_path (function_path )
349- self .postprocessor .append (function )
350-
351351 self .load_time = timer .tick ()
352352 self .reset_unload_timer ()
353353 self .loading_lock .set ()
@@ -491,6 +491,7 @@ def unload(self):
491491 del self .translator
492492 if self .opt .cuda :
493493 torch .cuda .empty_cache ()
494+ self .stop_unload_timer ()
494495 self .unload_timer = None
495496
496497 def stop_unload_timer (self ):
0 commit comments