Skip to content

Commit

Permalink
Update spm_train and test it
Browse files Browse the repository at this point in the history
  • Loading branch information
ShigekiKarita committed Aug 24, 2019
1 parent 9f51838 commit f7f09b5
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/test_sentencepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_spm_compatibility():
# test encode and decode
sp = spm.SentencePieceProcessor()
sp.Load(f"{bpemodel}.model")
txt = "test sentencepiece."
txt = "test sentencepiece.[noise]"
actual = sp.EncodeAsPieces(txt)
expect = "▁ te s t ▁ s en t en c e p ie c e .".split()
expect = "▁ te s t ▁ s en t en c e p ie c e . [noise]".split()
assert actual == expect
assert sp.DecodePieces(actual) == txt
33 changes: 33 additions & 0 deletions test_utils/test_spm.bats
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env bats
# -*- mode:sh -*-

# Bats per-test hook: create a scratch directory for the test's output files.
# The path is stored in the global `tmpdir`, which teardown() removes again.
setup() {
    # The template has no directory part, so the directory is created
    # relative to the current working directory.
    tmpdir=$(mktemp -d test_spm.XXXXXX)
    # Quote the expansion so the path is printed verbatim (no word splitting
    # or pathname expansion).
    echo "$tmpdir"
}

# Bats per-test hook: remove the scratch directory created by setup().
teardown() {
    # Only remove when tmpdir is actually set (setup may have failed before
    # assigning it); quote the path and end option parsing with `--` so the
    # value is always treated as a single literal operand.
    if [ -n "${tmpdir:-}" ]; then
        rm -rf -- "$tmpdir"
    fi
}

# End-to-end check of the utils/spm_* wrappers: train a small unigram model,
# compare its vocabulary with the reference file, then round-trip a sentence
# containing a user-defined symbol through spm_encode and spm_decode.
@test "spm_xxx" {
    testfile=test/tedlium2.txt
    nbpe=100
    bpemode=unigram
    bpemodel=$tmpdir/test_spm

    # NOTE(review): the leading valueless --user_defined_symbols looks redundant
    # with the valued flag at the end of the command; it is kept as-is here —
    # confirm it has no effect on the trained vocab before removing it.
    # The bracketed symbol list is quoted so the shell cannot treat it as a
    # pathname pattern.
    utils/spm_train --user_defined_symbols --input="${testfile}" --vocab_size="${nbpe}" --model_type="${bpemode}" \
        --model_prefix="${bpemodel}" --input_sentence_size=100000000 \
        --character_coverage=1.0 --bos_id=-1 --eos_id=-1 \
        --unk_id=0 --user_defined_symbols='[laughter],[noise],[vocalized-noise]'

    # The trained vocabulary must match the reference vocabulary exactly.
    diff "${bpemodel}.vocab" test/tedlium2.vocab

    txt="test sentencepiece.[noise]"

    # "$txt" and "$enc" are quoted: both contain "[noise]", which an unquoted
    # expansion would subject to pathname expansion against the current directory.
    enc=$(echo "$txt" | utils/spm_encode --model="${bpemodel}.model" --output_format=piece)
    [ "$enc" = "▁ te s t ▁ s en t en c e p ie c e . [noise]" ]

    # Decoding the pieces must reproduce the original text.
    dec=$(echo "$enc" | utils/spm_decode --model="${bpemodel}.model" --input_format=piece)
    [ "$dec" = "$txt" ]
}
8 changes: 2 additions & 6 deletions utils/spm_train
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# https://github.com/pytorch/fairseq/blob/master/LICENSE

from __future__ import absolute_import, division, print_function, unicode_literals

import shlex
import sys

import sentencepiece as spm


# Script entry point: join the command-line arguments (sys.argv[1:]) with
# spaces and pass the resulting string to spm.SentencePieceTrainer.Train.
# NOTE(review): two Train calls are shown here; this looks like the removed
# line (shlex.quote + stripping single quotes) followed by the added line
# (plain join) of a diff, with indentation lost — confirm which one is live.
if __name__ == "__main__":
spm.SentencePieceTrainer.Train(" ".join(map(shlex.quote, sys.argv[1:])).replace("\'", ""))
spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))

0 comments on commit f7f09b5

Please sign in to comment.