Skip to content

Commit

Permalink
refactor(utils/data_model_utils/tabular): get_cell_ngrams and get_nei…
Browse files Browse the repository at this point in the history
…ghbor_cell_ngrams yield nothing if mention is not tabular (#504)

Fixes #471.
  • Loading branch information
Hiromu Hota authored Sep 11, 2020
1 parent b5765ab commit 9113fb3
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Changed
* `@HiromuHota`_: Log a stack trace on parsing error for better debug experience.
(`#478 <https://github.com/HazyResearch/fonduer/issues/478>`_)
(`#479 <https://github.com/HazyResearch/fonduer/pull/479>`_)
* `@HiromuHota`_: :func:`get_cell_ngrams` and :func:`get_neighbor_cell_ngrams` yield
nothing when the mention is not tabular.
(`#471 <https://github.com/HazyResearch/fonduer/issues/471>`_)
(`#504 <https://github.com/HazyResearch/fonduer/pull/504>`_)

Deprecated
^^^^^^^^^^
Expand Down
14 changes: 12 additions & 2 deletions src/fonduer/utils/data_model_utils/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def get_cell_ngrams(
"""Get the ngrams that are in the Cell of the given mention, not including itself.
Note that if a candidate is passed in, all of its Mentions will be searched.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose Cell is being searched
:param attrib: The token attribute type (e.g. words, lemmas, poses)
Expand All @@ -235,11 +236,13 @@ def get_cell_ngrams(
"""
spans = _to_spans(mention)
for span in spans:
if not span.sentence.is_tabular():
continue

for ngram in get_sentence_ngrams(
span, attrib=attrib, n_min=n_min, n_max=n_max, lower=lower
):
yield ngram
if span.sentence.is_tabular():
for ngram in chain.from_iterable(
[
tokens_to_ngrams(
Expand Down Expand Up @@ -271,6 +274,7 @@ def get_neighbor_cell_ngrams(
Note that if a candidate is passed in, all of its Mentions will be
searched. If `directions=True``, each ngram will be returned with a
direction in {'UP', 'DOWN', 'LEFT', 'RIGHT'}.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose neighbor Cells are being searched
:param dist: The Cell distance within which a neighbor Cell must be to be
Expand All @@ -286,11 +290,13 @@ def get_neighbor_cell_ngrams(
# TODO: Fix this to be more efficient (optimize with SQL query)
spans = _to_spans(mention)
for span in spans:
if not span.sentence.is_tabular():
continue

for ngram in get_sentence_ngrams(
span, attrib=attrib, n_min=n_min, n_max=n_max, lower=lower
):
yield ngram
if span.sentence.is_tabular():
root_cell = span.sentence.cell
for sentence in chain.from_iterable(
[
Expand Down Expand Up @@ -337,6 +343,7 @@ def get_row_ngrams(
"""Get the ngrams from all Cells that are in the same row as the given Mention.
Note that if a candidate is passed in, all of its Mentions will be searched.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose row Cells are being searched
:param attrib: The token attribute type (e.g. words, lemmas, poses)
Expand Down Expand Up @@ -370,6 +377,7 @@ def get_col_ngrams(
"""Get the ngrams from all Cells that are in the same column as the given Mention.
Note that if a candidate is passed in, all of its Mentions will be searched.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose column Cells are being searched
:param attrib: The token attribute type (e.g. words, lemmas, poses)
Expand Down Expand Up @@ -404,6 +412,7 @@ def get_aligned_ngrams(
Note that if a candidate is passed in, all of its Mentions will be
searched.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose row and column Cells are being searched
:param attrib: The token attribute type (e.g. words, lemmas, poses)
Expand Down Expand Up @@ -439,6 +448,7 @@ def get_head_ngrams(
ngrams in the topmost cell in the column, depending on the axis parameter.
Note that if a candidate is passed in, all of its Mentions will be searched.
Also note that if the mention is not tabular, nothing will be yielded.
:param mention: The Mention whose head Cells are being returned
:param axis: Which axis {'row', 'col'} to search. If None, then both row
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_e2e(database_session):
featurizer.apply(split=1, train=True, parallelism=PARALLEL)
assert session.query(Feature).count() == 214
num_feature_keys = session.query(FeatureKey).count()
assert num_feature_keys == 1281
assert num_feature_keys == 1278

# Test Dropping FeatureKey
# Should force a row deletion
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_e2e(database_session):
num_features = session.query(Feature).count()
assert num_features == len(train_cands[0]) + len(train_cands[1])
num_feature_keys = session.query(FeatureKey).count()
assert num_feature_keys == 4577
assert num_feature_keys == 4555
F_train = featurizer.get_feature_matrices(train_cands)
assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
assert F_train[1].shape == (len(train_cands[1]), num_feature_keys)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_incremental(database_session):
featurizer.update(new_docs, parallelism=PARALLEL)
assert session.query(Feature).count() == len(train_cands[0])
num_feature_keys = session.query(FeatureKey).count()
assert num_feature_keys == 2526
assert num_feature_keys == 2514
F_train = featurizer.get_feature_matrices(train_cands)
assert F_train[0].shape == (len(train_cands[0]), num_feature_keys)
assert len(featurizer.get_keys()) == num_feature_keys
Expand Down
12 changes: 10 additions & 2 deletions tests/utils/data_model_utils/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_get_cell_ngrams(mention_setup):

# when a mention is not tabular
assert mentions[0].get_span() == "Sample"
assert list(get_cell_ngrams(mentions[0])) == ["markdown"]
assert list(get_cell_ngrams(mentions[0])) == []


def test_get_neighbor_cell_ngrams(mention_setup):
Expand All @@ -216,7 +216,7 @@ def test_get_neighbor_cell_ngrams(mention_setup):

# when a mention is not tabular
assert mentions[0].get_span() == "Sample"
assert list(get_neighbor_cell_ngrams(mentions[0])) == ["markdown"]
assert list(get_neighbor_cell_ngrams(mentions[0])) == []


def test_get_row_ngrams(mention_setup):
Expand All @@ -229,6 +229,10 @@ def test_get_row_ngrams(mention_setup):
"11",
]

# when a mention is not tabular
assert mentions[0].get_span() == "Sample"
assert list(get_row_ngrams(mentions[0])) == []


def test_get_col_ngrams(mention_setup):
"""Test the get_col_ngrams function."""
Expand Down Expand Up @@ -257,6 +261,10 @@ def test_get_aligned_ngrams(mention_setup):
"madras",
]

# when a mention is not tabular
assert mentions[0].get_span() == "Sample"
assert list(get_aligned_ngrams(mentions[0])) == []


def test_get_head_ngrams(mention_setup):
"""Test the get_head_ngrams function."""
Expand Down

0 comments on commit 9113fb3

Please sign in to comment.