Updates to SentencePieceTokenizer #11

Open · wants to merge 4 commits into master
10 changes: 0 additions & 10 deletions tkseem/__init__.py
@@ -1,13 +1,3 @@
# from tkseem.tokenizers import (
# WordTokenizer,
# CharacterTokenizer,
# AutoTokenizer,
# CharacterTokenizer,
# DisjointLetterTokenizer,
# RandomTokenizer,
# SentencePieceTokenizer,
# )

from tkseem.character_tokenizer import CharacterTokenizer
from tkseem.disjoint_letters_tokenizer import DisjointLetterTokenizer
from tkseem.morphological_tokenizer import MorphologicalTokenizer
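With the commented-out bulk import removed, the public classes are exposed through the explicit per-module imports above. A minimal import sketch, assuming __init__.py also re-exports SentencePieceTokenizer further down (that line is hidden by the truncated diff):

from tkseem import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer()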
65 changes: 47 additions & 18 deletions tkseem/sentencepiece_tokenizer.py
@@ -6,38 +6,67 @@


class SentencePieceTokenizer(BaseTokenizer):
""" Sentencepiece based tokenization.
"""
"""Sentencepiece based tokenization."""

def train(self, file_path, model_type="bpe"):
""" Train using sentence piece
def train(self, file_path, **kwargs):
"""Train using sentence piece

Args:
file_path (str): file to train
model_type (str, optional): train using sp. Defaults to "bpe".
file_path (str): file to train
kwargs: additional arguments to pass to the SentencePieceTrainer. See https://github.com/google/sentencepiece/blob/master/doc/options.md
"""
print("Training SentencePiece ...")
self.model = io.BytesIO()

if kwargs.get("vocab_size"):
print(
f"WARNING: Vocab size is being overwritten to {kwargs.get('vocab_size')}"
)
self.vocab_size = kwargs.get("vocab_size")
kwargs.pop("vocab_size")

if kwargs.get("special_tokens"):
print(
f"WARNING: Special tokens are being overwritten to {kwargs.get('special_tokens')}"
)
self.special_tokens = kwargs.get("special_tokens")
kwargs.pop("special_tokens")

# Preserve default values from previous versions.
# pop with a default both reads the value and removes the key, and avoids
# a KeyError when the caller does not pass the option at all.
model_type = kwargs.pop("model_type", "bpe")
character_coverage = kwargs.pop("character_coverage", 1.0)
unk_id = kwargs.pop("unk_id", 0)
pad_id = kwargs.pop("pad_id", 1)
bos_id = kwargs.pop("bos_id", -1)
eos_id = kwargs.pop("eos_id", -1)
normalization_rule_name = kwargs.pop("normalization_rule_name", "identity")

spm.SentencePieceTrainer.train(
input=file_path,
model_writer=self.model,
vocab_size=self.vocab_size,
model_type=model_type,
character_coverage=1.0,
unk_id=0,
pad_id=1,
bos_id=-1,
eos_id=-1,
character_coverage=character_coverage,
unk_id=unk_id,
pad_id=pad_id,
bos_id=bos_id,
eos_id=eos_id,
user_defined_symbols=self.special_tokens,
normalization_rule_name="identity",
normalization_rule_name=normalization_rule_name,
**kwargs,
)
self.save_model("m.model")
self.sp = spm.SentencePieceProcessor(model_file="m.model")
self.sp = spm.SentencePieceProcessor(model_proto=self.model.getvalue())
self.vocab_size = self.sp.vocab_size()

def tokenize(self, text):
"""Tokenize using the frequency dictionary
"""Tokenize using the frequency dictionary

Args:
text (str): input string
@@ -72,7 +101,7 @@ def token_to_id(self, token):
return self.sp.piece_to_id(token)

def encode(self, text):
""" Convert string to a list of ids
"""Convert string to a list of ids

Args:
text (str): input string
@@ -83,7 +112,7 @@ def encode(self, text):
return self.sp.encode(text, out_type=int)

def decode(self, encoded):
""" Decode ids
"""Decode ids

Args:
encoded (list): list of ids to decode
@@ -94,7 +123,7 @@ def decode(self, encoded):
return self.sp.id_to_piece(encoded)

def detokenize(self, tokens):
""" Convert tokens to a string
"""Convert tokens to a string

Args:
tokens (list): list of tokens
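For reference, a minimal usage sketch of the updated train API, assuming extra kwargs are forwarded to spm.SentencePieceTrainer.train as shown in the diff above (the corpus path and option values below are illustrative, not part of the PR):

from tkseem.sentencepiece_tokenizer import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer()

# With no extra arguments the previous defaults still apply:
# model_type="bpe", character_coverage=1.0, normalization_rule_name="identity".
tokenizer.train("data.txt")

# Extra keyword arguments are passed through to SentencePiece; passing
# vocab_size or special_tokens overrides the tokenizer's own settings and
# prints a warning, per the checks at the top of train().
tokenizer.train(
    "data.txt",
    model_type="unigram",
    vocab_size=8000,
    character_coverage=0.9995,
)

ids = tokenizer.encode("example input text")
print(tokenizer.decode(ids))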