From 818bc5e9b30caae6bb5d9b5cb00e829bbfe08833 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Fri, 10 Jan 2025 16:17:48 -0800
Subject: [PATCH] also handle special tokens that are already in vocab

Signed-off-by: Alexandros Koumparoulis
---
 .../tokenizers/sentencepiece_tokenizer.py          | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
index 038e21f3d60a..ac22dbbbd7bc 100644
--- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
+++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -70,6 +70,7 @@ def __init__(
         self.id_to_special_token = {}
         self.trim_spm_separator_after_special_token = trim_spm_separator_after_special_token
         self.spm_separator_id = self.tokenizer.piece_to_id(spm_separator)
+        self.spm_separator = spm_separator
 
         if special_tokens:
             if not self.legacy:
@@ -104,7 +105,19 @@ def text_to_tokens(self, text):
                 next_token = min(indices, key=indices.get)
                 next_idx = idx + indices[next_token]
 
-                tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
+                tok = self.tokenizer.encode_as_pieces(text[idx:next_idx])
+                # Chat-templates insert a space between a special token and the first word
+                # (e.g. "[INST] who"), so the text after the special token is tokenized
+                # with a leading spm separator piece instead of starting at the word itself.
+                if (
+                    self.trim_spm_separator_after_special_token
+                    and len(tokens) > 0
+                    and tokens[-1] in self.special_token_to_id
+                    and len(tok) > 0
+                    and tok[0] == self.spm_separator
+                ):
+                    tok.pop(0)
+                tokens.extend(tok)
                 tokens.append(next_token)
                 idx = next_idx + len(next_token)
 
@@ -268,6 +281,9 @@ def add_special_tokens(self, special_tokens):
                     self.special_token_to_id[token] = self.vocab_size
                     self.id_to_special_token[self.vocab_size] = token
                     self.vocab_size += 1
+                elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
+                    self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
+                    self.id_to_special_token[self.special_token_to_id[token]] = token
         else:
             raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))
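
Below is a minimal usage sketch (not part of the patch) of the behavior the change targets. The model path and the "[INST]"/"[/INST]" strings are assumptions for illustration; the constructor arguments follow the __init__ referenced in the hunks above, where user-supplied special tokens require legacy mode.

# Sketch only: assumes a SentencePiece model whose vocabulary already contains
# "[INST]" and "[/INST]". Path and token strings are hypothetical.
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(
    model_path="/path/to/tokenizer.model",   # hypothetical path
    special_tokens=["[INST]", "[/INST]"],    # tokens that may already exist in the vocab
    legacy=True,                             # special-token handling requires legacy mode
)

# With this patch, add_special_tokens() maps tokens already present in the vocab to
# their existing piece ids instead of skipping them, so they are recognized during
# text_to_tokens().
print(tokenizer.special_token_to_id)

# With trim_spm_separator_after_special_token enabled, a standalone separator piece
# ("▁") produced right after a recognized special token is dropped, so the space that
# chat templates insert after "[INST]" does not add an extra piece.
print(tokenizer.text_to_tokens("[INST] who are you? [/INST]"))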