Skip to content

Commit d96caa2

Browse files
committed
test: category clustering statistics
1 parent 0dc3edb commit d96caa2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/test_fr.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,33 @@ def test_category_crp(data):
428428
assert crp['possible'].iloc[0] == 1
429429

430430

431+
def test_category_clustering():
432+
"""Test category clustering statistics."""
433+
subject = [1] * 2
434+
435+
# category of study and list items (two cases from category
436+
# clustering tests)
437+
study_category = [list('abcd') * 4] * 2
438+
recall_str = ['aaabbbcccddd', 'aabbcdcd']
439+
recall_category = [list(s) for s in recall_str]
440+
441+
# unique item codes (needed for merging study and recall events;
442+
# not actually needed for the stats)
443+
study_item = [[i for i in range(len(c))] for c in study_category]
444+
recall_item = [[0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11], [0, 4, 1, 5, 2, 3, 6, 7]]
445+
446+
# create merged free recall data
447+
raw = fr.table_from_lists(
448+
subject, study_item, recall_item, category=(study_category, recall_category)
449+
)
450+
data = fr.merge_free_recall(raw, list_keys=['category'])
451+
452+
# test ARC and LBC stats
453+
stats = fr.category_clustering(data, 'category')
454+
np.testing.assert_allclose(stats.loc[1, 'arc'], 0.667, rtol=0.011)
455+
np.testing.assert_allclose(stats.loc[1, 'lbc'], 3.2, rtol=0.011)
456+
457+
431458
def test_lag_rank(data):
432459
"""Test lag rank analysis."""
433460
stat = fr.lag_rank(data)

0 commit comments

Comments
 (0)