diff --git a/requirements.txt b/requirements.txt index 26cfeb5..90d7a0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ click pytest +pytest-ordering http://www.opengrm.org/twiki/pub/GRM/PyniniDownload/pynini-2.0.0.tar.gz#egg=pynini diff --git a/tests/test_fsts.py b/tests/test_fsts.py new file mode 100644 index 0000000..916fe10 --- /dev/null +++ b/tests/test_fsts.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +import sys +import os +import pytest +import pynini + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../timur'))) + +from timur import fsts +from timur import helpers + +sample_symbols = ["", "", "", "", ""] + +def test_map_fst(): + '''Check that the sample alphabet loads; nothing about the FSTs is asserted yet.''' + syms = helpers.load_alphabet(sample_symbols) + assert(True) + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 455fd4e..9e08d84 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -import sys, os, pytest - +import sys +import os +import pytest import pynini sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../timur'))) @@ -11,6 +12,7 @@ sample_symbols = ["", "", "", "", ""] sample_entries = ["Anüs"] +@pytest.mark.first def test_load_alphabet(): ''' Load the sample symbol set and check for membership. @@ -19,6 +21,7 @@ def test_load_alphabet(): assert(syms.member("")) assert(syms.member("ü")) +@pytest.mark.second def test_load_lexicon(): ''' Load the sample lexicon and check vor invariance. 
diff --git a/timur/data/syms.txt b/timur/data/syms.txt index 39ab4f5..f52a59c 100644 --- a/timur/data/syms.txt +++ b/timur/data/syms.txt @@ -18,12 +18,35 @@ + + + + + + + + + + + + + + + + + + + + + + # intermediate features + # stem type features diff --git a/timur/fsts/__init__.py b/timur/fsts/__init__.py index 35eab7c..1563752 100644 --- a/timur/fsts/__init__.py +++ b/timur/fsts/__init__.py @@ -1,2 +1,4 @@ from .num_fst import num_fst from .phon_fst import phon_fst +from .map_fst import map_fst_map1 +from .map_fst import map_fst_map2 diff --git a/timur/fsts/map_fst.py b/timur/fsts/map_fst.py new file mode 100644 index 0000000..6df5aac --- /dev/null +++ b/timur/fsts/map_fst.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +import pynini + +from timur.helpers import union +from timur.helpers import concat + +def map_fst_map1(symbol_table): + ''' + Modifications of lexical entries (work in progress: `cat` and `cat_ext` are built with pynini.string_map but not used yet; an empty pynini.Fst() is returned) + ''' + + # lexical features to be omitted from the output + cat = pynini.string_map(["", "", "", "", "", "", "", "", "", "", "", ""], input_token_type=symbol_table, output_token_type=symbol_table) + + cat_ext = pynini.string_map(["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], input_token_type=symbol_table, output_token_type=symbol_table) + + return pynini.Fst() + +def map_fst_map2(symbol_table): + ''' + Modifications of lexical entries (stub: the body is only this docstring, so calling it returns None) + ''' diff --git a/timur/helpers/helpers.py b/timur/helpers/helpers.py index 4ac36aa..1336fa2 100644 --- a/timur/helpers/helpers.py +++ b/timur/helpers/helpers.py @@ -4,7 +4,7 @@ def union(*args, token_type="utf8"): args_mod = [] for arg in args: - if type(args) == "str": + if isinstance(arg, str): args_mod.append(pynini.acceptor(arg, token_type=token_type)) else: args_mod.append(arg) @@ -13,8 +13,13 @@ def union(*args, token_type="utf8"): def concat(*args, token_type="utf8"): args_mod = [] conc = pynini.Fst() + 
conc.set_start(conc.add_state()) + conc.set_final(conc.start()) + if isinstance(token_type, pynini.SymbolTable): + conc.set_input_symbols(token_type) + conc.set_output_symbols(token_type) for arg in args: - if type(args) == "str": + if isinstance(arg, str): arg = pynini.acceptor(arg, token_type=token_type) conc = pynini.concat(conc, arg) return conc @@ -30,7 +35,8 @@ def load_alphabet(source, auto_singletons=True): if symbol.isprintable() and not symbol.isspace(): syms.add_symbol(symbol) for symbol in source: - symbol = str(symbol) + if isinstance(symbol, bytes): + symbol = symbol.decode("utf-8") if symbol.startswith('#'): continue syms.add_symbol(symbol.strip()) diff --git a/timur/scripts/timur.py b/timur/scripts/timur.py index b0513cb..2336f11 100644 --- a/timur/scripts/timur.py +++ b/timur/scripts/timur.py @@ -36,8 +36,25 @@ def build(lexicon): lex = helpers.load_lexicon(lexicon, syms) - #phon = phon_fst(syms) - #phon.draw("test.dot") - num_stems = fsts.num_fst(syms) + # add repetitive prefixes + # TODO: move to fst function + print(syms.member("")) + repeatable_prefs = helpers.concat( + "", + helpers.union( + "u r ", + "v o r ", + token_type=syms + ).closure(1), + " ", + token_type=syms + ) + lex = pynini.union(lex, repeatable_prefs) + + map1 = fsts.map_fst_map1(syms) + + lex = pynini.compose(map1, lex) + lex.draw("test.dot") - ANY = construct_any(syms) + #phon = phon_fst(syms) + #num_stems = fsts.num_fst(syms)