Skip to content

Commit 86f2c52

Browse files
Merge branch 'master' of https://github.com/OpenNMT/OpenNMT-py into cosine_loss
2 parents 845c989 + 46c0456 commit 86f2c52

12 files changed

Lines changed: 144 additions & 80 deletions

File tree

.travis.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
dist: xenial
22
language: python
33
python:
4-
- "3.5"
4+
- "3.6"
55
git:
66
depth: false
77
addons:
@@ -13,7 +13,8 @@ addons:
1313
- sox
1414
before_install:
1515
# Install CPU version of PyTorch.
16-
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install torch==1.2.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi
16+
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install torch==1.4.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi
17+
- pip install --upgrade setuptools
1718
- pip install -r requirements.opt.txt
1819
- python setup.py install
1920
env:

onmt/bin/release_model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python
2+
import argparse
3+
import torch
4+
5+
6+
def get_ctranslate2_model_spec(opt):
    """Creates a CTranslate2 model specification from the model options.

    Args:
        opt: The model options stored in the checkpoint. ``encoder_type``,
            ``decoder_type``, ``position_encoding``, ``enc_layers`` and
            ``dec_layers`` are read directly; ``self_attn_type``,
            ``max_relative_positions`` and ``heads`` fall back to
            ``"scaled-dot"``, ``0`` and ``8`` when they are missing.

    Returns:
        A ``ctranslate2.specs.TransformerSpec`` when the options describe a
        vanilla Transformer (Transformer encoder and decoder of equal depth,
        position encoding enabled, scaled-dot self-attention and no relative
        positions), ``None`` otherwise.
    """
    is_vanilla_transformer = (
        opt.encoder_type == "transformer"
        and opt.decoder_type == "transformer"
        and opt.position_encoding
        and opt.enc_layers == opt.dec_layers
        and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
        and getattr(opt, "max_relative_positions", 0) == 0)
    if not is_vanilla_transformer:
        return None
    # Imported only once the model is known to be convertible, so that
    # unsupported models are rejected without needing ctranslate2.
    import ctranslate2
    num_heads = getattr(opt, "heads", 8)
    # Use the depth that the check above validated (enc_layers equals
    # dec_layers here) instead of the separate ``layers`` option.
    # NOTE(review): ``layers`` appears to be only a shortcut that may keep
    # its unset default when both depths are given individually - confirm
    # against onmt/opts.py.
    return ctranslate2.specs.TransformerSpec(opt.enc_layers, num_heads)
20+
21+
22+
def main():
    """Release an OpenNMT-py model for inference.

    Parses ``--model``, ``--output`` and ``--format`` from the command line,
    loads the checkpoint and writes the released model to the output path:

    * ``pytorch``: the checkpoint is saved again with its ``"optim"`` entry
      set to ``None``.
    * ``ctranslate2``: the checkpoint is converted with the CTranslate2
      OpenNMT-py converter, overwriting an existing output.

    Raises:
        ValueError: If ``ctranslate2`` is requested for a model whose
            options do not describe a supported architecture.
    """
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference")
    parser.add_argument("--model", "-m",
                        help="The model path", required=True)
    parser.add_argument("--output", "-o",
                        help="The output path", required=True)
    parser.add_argument("--format",
                        choices=["pytorch", "ctranslate2"],
                        default="pytorch",
                        help="The format of the released model")
    opt = parser.parse_args()

    # Map every storage to the CPU so that a checkpoint saved from GPU
    # training can also be released on a machine without CUDA.
    # NOTE: torch.load unpickles the file; only run this on checkpoints
    # that come from a trusted source.
    model = torch.load(opt.model, map_location=torch.device("cpu"))
    if opt.format == "pytorch":
        # Drop the optimizer entry, keep everything else as loaded.
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        model_spec = get_ctranslate2_model_spec(model["opt"])
        if model_spec is None:
            raise ValueError("This model is not supported by CTranslate2. Go "
                             "to https://github.com/OpenNMT/CTranslate2 for "
                             "more information on supported models.")
        import ctranslate2
        # The converter is given the checkpoint path, not the loaded object.
        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        converter.convert(opt.output, model_spec, force=True)
48+
49+
50+
# Script entry point: run the release tool only when executed directly.
if __name__ == "__main__":
    main()

onmt/bin/server.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import configargparse
33

44
from flask import Flask, jsonify, request
5+
from waitress import serve
56
from onmt.translate import TranslationServer, ServerModelError
7+
import logging
8+
from logging.handlers import RotatingFileHandler
69

710
STATUS_OK = "ok"
811
STATUS_ERROR = "error"
@@ -12,12 +15,22 @@ def start(config_file,
1215
url_root="./translator",
1316
host="0.0.0.0",
1417
port=5000,
15-
debug=True):
18+
debug=False):
1619
def prefix_route(route_function, prefix='', mask='{0}{1}'):
1720
def newroute(route, *args, **kwargs):
1821
return route_function(mask.format(prefix, route), *args, **kwargs)
1922
return newroute
2023

24+
if debug:
25+
logger = logging.getLogger("main")
26+
log_format = logging.Formatter(
27+
"[%(asctime)s %(levelname)s] %(message)s")
28+
file_handler = RotatingFileHandler(
29+
"debug_requests.log",
30+
maxBytes=1000000, backupCount=10)
31+
file_handler.setFormatter(log_format)
32+
logger.addHandler(file_handler)
33+
2134
app = Flask(__name__)
2235
app.route = prefix_route(app.route, url_root)
2336
translation_server = TranslationServer()
@@ -73,6 +86,8 @@ def unload_model(model_id):
7386
@app.route('/translate', methods=['POST'])
7487
def translate():
7588
inputs = request.get_json(force=True)
89+
if debug:
90+
logger.info(inputs)
7691
out = {}
7792
try:
7893
trans, scores, n_best, _, aligns = translation_server.run(inputs)
@@ -90,7 +105,8 @@ def translate():
90105
except ServerModelError as e:
91106
out['error'] = str(e)
92107
out['status'] = STATUS_ERROR
93-
108+
if debug:
109+
logger.info(out)
94110
return jsonify(out)
95111

96112
@app.route('/to_cpu/<int:model_id>', methods=['GET'])
@@ -109,8 +125,7 @@ def to_gpu(model_id):
109125
out['status'] = STATUS_OK
110126
return jsonify(out)
111127

112-
app.run(debug=debug, host=host, port=port, use_reloader=False,
113-
threaded=True)
128+
serve(app, host=host, port=port)
114129

115130

116131
def _get_parser():

onmt/decoders/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
5656

5757
if self_attn_type == "scaled-dot":
5858
self.self_attn = MultiHeadedAttention(
59-
heads, d_model, dropout=dropout,
59+
heads, d_model, dropout=attention_dropout,
6060
max_relative_positions=max_relative_positions)
6161
elif self_attn_type == "average":
6262
self.self_attn = AverageAttention(d_model,

onmt/inputters/inputter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -822,11 +822,12 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
822822
to iterate over. We implement simple ordered iterator strategy here,
823823
but more sophisticated strategy like curriculum learning is ok too.
824824
"""
825+
dataset_glob = opt.data + '.' + corpus_type + '.[0-9]*.pt'
825826
dataset_paths = list(sorted(
826-
glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt')))
827+
glob.glob(dataset_glob)))
827828
if not dataset_paths:
828829
if is_train:
829-
raise ValueError('Training data %s not found' % opt.data)
830+
raise ValueError('Training data %s not found' % dataset_glob)
830831
else:
831832
return None
832833
if multi:

onmt/model_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
190190
vocab_size = len(tgt_base_field.vocab)
191191
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
192192
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
193+
if model_opt.share_decoder_embeddings:
194+
generator.linear.weight = decoder.embeddings.word_lut.weight
193195

194196
# Load the model states from checkpoint or initialize them.
195197
if checkpoint is not None:

onmt/models/model_saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,5 @@ def _save(self, step, model):
131131
return checkpoint, checkpoint_path
132132

133133
def _rm_checkpoint(self, name):
134-
os.remove(name)
134+
if os.path.exists(name):
135+
os.remove(name)

onmt/opts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def train_opts(parser):
377377
help="IP of master for torch.distributed training.")
378378
group.add('--master_port', '-master_port', default=10000, type=int,
379379
help="Port of master for torch.distributed training.")
380-
group.add('--queue_size', '-queue_size', default=400, type=int,
380+
group.add('--queue_size', '-queue_size', default=40, type=int,
381381
help="Size of queue for each process in producer/consumer")
382382

383383
group.add('--seed', '-seed', type=int, default=-1,

onmt/translate/translation_server.py

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,60 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
233233

234234
set_random_seed(self.opt.seed, self.opt.cuda)
235235

236+
if self.preprocess_opt is not None:
237+
self.logger.info("Loading preprocessor")
238+
self.preprocessor = []
239+
240+
for function_path in self.preprocess_opt:
241+
function = get_function_by_path(function_path)
242+
self.preprocessor.append(function)
243+
244+
if self.tokenizer_opt is not None:
245+
self.logger.info("Loading tokenizer")
246+
247+
if "type" not in self.tokenizer_opt:
248+
raise ValueError(
249+
"Missing mandatory tokenizer option 'type'")
250+
251+
if self.tokenizer_opt['type'] == 'sentencepiece':
252+
if "model" not in self.tokenizer_opt:
253+
raise ValueError(
254+
"Missing mandatory tokenizer option 'model'")
255+
import sentencepiece as spm
256+
sp = spm.SentencePieceProcessor()
257+
model_path = os.path.join(self.model_root,
258+
self.tokenizer_opt['model'])
259+
sp.Load(model_path)
260+
self.tokenizer = sp
261+
elif self.tokenizer_opt['type'] == 'pyonmttok':
262+
if "params" not in self.tokenizer_opt:
263+
raise ValueError(
264+
"Missing mandatory tokenizer option 'params'")
265+
import pyonmttok
266+
if self.tokenizer_opt["mode"] is not None:
267+
mode = self.tokenizer_opt["mode"]
268+
else:
269+
mode = None
270+
# load can be called multiple times: modify copy
271+
tokenizer_params = dict(self.tokenizer_opt["params"])
272+
for key, value in self.tokenizer_opt["params"].items():
273+
if key.endswith("path"):
274+
tokenizer_params[key] = os.path.join(
275+
self.model_root, value)
276+
tokenizer = pyonmttok.Tokenizer(mode,
277+
**tokenizer_params)
278+
self.tokenizer = tokenizer
279+
else:
280+
raise ValueError("Invalid value for tokenizer type")
281+
282+
if self.postprocess_opt is not None:
283+
self.logger.info("Loading postprocessor")
284+
self.postprocessor = []
285+
286+
for function_path in self.postprocess_opt:
287+
function = get_function_by_path(function_path)
288+
self.postprocessor.append(function)
289+
236290
if load:
237291
self.load()
238292

@@ -294,60 +348,6 @@ def load(self):
294348
raise ServerModelError("Runtime Error: %s" % str(e))
295349

296350
timer.tick("model_loading")
297-
if self.preprocess_opt is not None:
298-
self.logger.info("Loading preprocessor")
299-
self.preprocessor = []
300-
301-
for function_path in self.preprocess_opt:
302-
function = get_function_by_path(function_path)
303-
self.preprocessor.append(function)
304-
305-
if self.tokenizer_opt is not None:
306-
self.logger.info("Loading tokenizer")
307-
308-
if "type" not in self.tokenizer_opt:
309-
raise ValueError(
310-
"Missing mandatory tokenizer option 'type'")
311-
312-
if self.tokenizer_opt['type'] == 'sentencepiece':
313-
if "model" not in self.tokenizer_opt:
314-
raise ValueError(
315-
"Missing mandatory tokenizer option 'model'")
316-
import sentencepiece as spm
317-
sp = spm.SentencePieceProcessor()
318-
model_path = os.path.join(self.model_root,
319-
self.tokenizer_opt['model'])
320-
sp.Load(model_path)
321-
self.tokenizer = sp
322-
elif self.tokenizer_opt['type'] == 'pyonmttok':
323-
if "params" not in self.tokenizer_opt:
324-
raise ValueError(
325-
"Missing mandatory tokenizer option 'params'")
326-
import pyonmttok
327-
if self.tokenizer_opt["mode"] is not None:
328-
mode = self.tokenizer_opt["mode"]
329-
else:
330-
mode = None
331-
# load can be called multiple times: modify copy
332-
tokenizer_params = dict(self.tokenizer_opt["params"])
333-
for key, value in self.tokenizer_opt["params"].items():
334-
if key.endswith("path"):
335-
tokenizer_params[key] = os.path.join(
336-
self.model_root, value)
337-
tokenizer = pyonmttok.Tokenizer(mode,
338-
**tokenizer_params)
339-
self.tokenizer = tokenizer
340-
else:
341-
raise ValueError("Invalid value for tokenizer type")
342-
343-
if self.postprocess_opt is not None:
344-
self.logger.info("Loading postprocessor")
345-
self.postprocessor = []
346-
347-
for function_path in self.postprocess_opt:
348-
function = get_function_by_path(function_path)
349-
self.postprocessor.append(function)
350-
351351
self.load_time = timer.tick()
352352
self.reset_unload_timer()
353353
self.loading_lock.set()
@@ -491,6 +491,7 @@ def unload(self):
491491
del self.translator
492492
if self.opt.cuda:
493493
torch.cuda.empty_cache()
494+
self.stop_unload_timer()
494495
self.unload_timer = None
495496

496497
def stop_unload_timer(self):

onmt/utils/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def init_logger(log_file=None, log_file_level=logging.NOTSET):
1717

1818
if log_file and log_file != '':
1919
file_handler = RotatingFileHandler(
20-
log_file, maxBytes=1000, backupCount=10)
20+
log_file, maxBytes=1000000, backupCount=10)
2121
file_handler.setLevel(log_file_level)
2222
file_handler.setFormatter(log_format)
2323
logger.addHandler(file_handler)

0 commit comments

Comments
 (0)