Skip to content

Commit ec28446

Browse files
committed
[aggr-] add rank aggregator, cmds addcol-aggregate/sheetrank
1 parent b41afba commit ec28446

File tree

5 files changed

+200
-8
lines changed

5 files changed

+200
-8
lines changed

tests/aggregators-cols.vdj

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!vd -p
2+
{"sheet": "global", "col": null, "row": "disp_date_fmt", "longname": "set-option", "input": "%b %d, %Y", "keystrokes": "", "comment": null}
3+
{"longname": "open-file", "input": "sample_data/test.jsonl", "keystrokes": "o"}
4+
{"sheet": "test", "col": "key2", "row": "", "longname": "key-col", "input": "", "keystrokes": "!", "comment": "toggle current column as a key column"}
5+
{"sheet": "test", "col": "key2", "row": "", "longname": "addcol-aggregate", "input": "count", "comment": "add column(s) with aggregator of rows grouped by key columns"}
6+
{"sheet": "test", "col": "qty", "row": "", "longname": "type-float", "input": "", "keystrokes": "%", "comment": "set type of current column to float"}
7+
{"sheet": "test", "col": "qty", "row": "", "longname": "addcol-aggregate", "input": "rank sum", "comment": "add column(s) with aggregator of rows grouped by key columns"}
8+
{"sheet": "test", "col": "qty_sum", "row": "", "longname": "addcol-sheetrank", "input": "", "comment": "add column with the rank of each row based on its key columns"}

tests/golden/aggregators-cols.tsv

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
key2 key2_count key1 qty qty_rank qty_sum test_sheetrank amt
2+
foo 2 2016-01-01 11:00:00 1.00 1 31.00 5
3+
0 2016-01-01 1:00 2.00 1 66.00 2 3
4+
baz 3 4.00 1 292.00 4 43.2
5+
#ERR 0 #ERR #ERR 1 0.00 1 #ERR #ERR
6+
bar 2 2017-12-25 8:44 16.00 2 16.00 3 .3
7+
baz 3 32.00 2 292.00 4 3.3
8+
0 2018-07-27 4:44 64.00 2 66.00 2 9.1
9+
bar 2 2018-07-27 16:44 1 16.00 3
10+
baz 3 2018-07-27 18:44 256.00 3 292.00 4 .01
11+
foo 2 2018-10-20 18:44 30.00 2 31.00 5 .01

visidata/aggregators.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import functools
44
import collections
55
import statistics
6+
import itertools
67

7-
from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData
8+
from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData, SettableColumn
89
from visidata import vd, anytype, vlen, asyncthread, wrapply, AttrDict, date, INPROGRESS, stacktrace, TypedExceptionWrapper
910

1011
vd.help_aggregators = '''# Choose Aggregators
@@ -76,7 +77,7 @@ def aggregators_set(col, aggs):
7677

7778

7879
class Aggregator:
79-
def __init__(self, name, type, funcValues=None, helpstr='foo'):
80+
def __init__(self, name, type, funcValues=None, helpstr=''):
8081
'Define aggregator `name` that calls funcValues(values)'
8182
self.type = type
8283
self.funcValues = funcValues # funcValues(values)
@@ -92,13 +93,48 @@ def aggregate(self, col, rows): # wrap builtins so they can have a .type
9293
return None
9394
raise e
9495

96+
class ListAggregator(Aggregator):
    '''A list aggregator is an aggregator that returns a list of values, generally
    one value per input row, unlike ordinary aggregators that operate on rows
    and return only a single value.

    To implement a new list aggregator, subclass ListAggregator,
    and override aggregate() and aggregate_list().'''

    def __init__(self, name, type, helpstr='', listtype=None):
        '''*name*, *type* and *helpstr* are passed through to Aggregator.
        *listtype* determines the type of the column created by addcol_aggregate()
        for list aggrs. If it is None, then the new column will match the type of the input column'''
        super().__init__(name, type, helpstr=helpstr)
        self.listtype = listtype

    def aggregate(self, col, rows) -> list:
        '''Return the values from aggregate_list() minus those the sheet's null
        predicate rejects, so the result can be shorter than *rows*.
        (The original author notes this drops errors as well as nulls.)
        Override in subclass.'''
        vals = self.aggregate_list(col, rows)
        # isNullFunc() returns the predicate; build it once instead of once per value
        is_null = col.sheet.isNullFunc()
        return [v for v in vals if not is_null(v)]

    def aggregate_list(self, col, row_group) -> list:
        '''Return a list of results, which will be one result per input row.
        *row_group* is an iterable that holds a "group" of rows to run the aggregator on.
        rows in *row_group* are not necessarily in the same order they are in the sheet.
        This base implementation returns the typed value of *col* for each row.
        Override in subclass.'''
        return [col.getTypedValue(r) for r in row_group]
95123

96124
@VisiData.api
def aggregator(vd, name, funcValues, helpstr='', *, type=None):
    '''Register a simple aggregator called *name* in ``vd.aggregators``.
    The aggregator calls ``funcValues(values)`` to aggregate *values*.
    Pass *type* to force the type of the aggregated column (default is to use the type of the source column).'''
    new_aggr = Aggregator(name, type, funcValues=funcValues, helpstr=helpstr)
    vd.aggregators[name] = new_aggr
101129

130+
@VisiData.api
def aggregator_list(vd, name, helpstr='', type=anytype, listtype=anytype):
    '''Define list aggregator *name*, registered in ``vd.aggregators`` as a ListAggregator.
    Unlike ``vd.aggregator()``, no ``funcValues`` is taken; the values come from
    ``ListAggregator.aggregate_list()``.
    *type* is passed to the ListAggregator as its aggregator type (default anytype).
    Use *listtype* to force the type of the new column created by addcol-aggregate (default anytype).
    If *listtype* is None, it will match the type of the source column.'''
    vd.aggregators[name] = ListAggregator(name, type, helpstr=helpstr, listtype=listtype)
137+
102138
## specific aggregator implementations
103139

104140
def mean(vals):
@@ -147,10 +183,49 @@ def __init__(self, pct, helpstr=''):
147183
def aggregate(self, col, rows):
148184
return _percentile(sorted(col.getValues(rows)), self.pct/100, key=float)
149185

150-
151186
def quantiles(q, helpstr):
    '''Return a list of q-1 PercentileAggregators, one for each cut point
    round(100*i/q) with i running from 1 to q-1; each gets the same *helpstr*.'''
    cut_aggrs = []
    for i in range(1, q):
        pct = round(100*i/q)
        cut_aggrs.append(PercentileAggregator(pct, helpstr))
    return cut_aggrs
153188

189+
def aggregate_groups(sheet, col, rows, aggr) -> list:
    '''Returns a list, containing the result of the aggregator applied to each row,
    in the same order as *rows*.
    *sheet* supplies the row key used for grouping (``sheet.rowkey``).
    *col* is the column whose typed values are aggregated; they are also the
    secondary sort key, so rows inside a group are ordered by value.
    *rows* is a list of visidata rows.
    *aggr* is an Aggregator object.
    Rows are grouped by their key columns. Null key column cells are considered equal,
    so nulls are grouped together. Cells with exceptions do not group together.
    Each exception cell is grouped by itself, with only one row in the group.
    Fails via vd.fail() if keys or values cannot be compared for sorting.
    '''
    def _key_progress(prog):
        # identity function that ticks the progress bar each time it is called
        def identity(val):
            prog.addProgress(1)
            return val
        return identity

    # four passes over *rows*: build, sort, aggregate, re-sort.
    # Sized from *rows* (not sheet.nRows) so the total is right for a subset of rows.
    with Progress(gerund='ranking', total=4*len(rows)) as prog:
        p = _key_progress(prog) # increment progress every time p() is called
        # compile row data, for each row a tuple: (group_key, value, rownum)
        rowdata = [(sheet.rowkey(r), col.getTypedValue(r), p(rownum)) for rownum, r in enumerate(rows)]
        # sort by row key and column value to prepare for grouping
        try:
            rowdata.sort(key=p)
        except TypeError as e:
            vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
        rowvals = []
        # group by row key; groupby needs the data sorted by that key, done above
        for _, group in itertools.groupby(rowdata, key=lambda v: v[0]):
            # within a group, the rows have already been sorted by col_val
            group = list(group)
            group_rows = [rows[rownum] for _, _, rownum in group]
            if isinstance(aggr, ListAggregator):   # for list aggregators, each row gets its own value
                aggr_vals = aggr.aggregate_list(col, group_rows)
                rowvals += [(rownum, v) for (_, _, rownum), v in zip(group, aggr_vals)]
            else:                                  # for normal aggregators, each row in the group gets the same value
                aggr_val = aggr.aggregate(col, group_rows)
                rowvals += [(rownum, aggr_val) for _, _, rownum in group]
            prog.addProgress(len(group))
        # sort by unique rownum, to make results match the original row order;
        # rownum never ties, so the aggregated values themselves are never compared
        rowvals.sort(key=p)
        rowvals = [v for rownum, v in rowvals]
    return rowvals
154229

155230
vd.aggregator('min', min, 'minimum value')
156231
vd.aggregator('max', max, 'maximum value')
@@ -161,8 +236,8 @@ def quantiles(q, helpstr):
161236
vd.aggregator('sum', vsum, 'sum of values')
162237
vd.aggregator('distinct', set, 'distinct values', type=vlen)
163238
vd.aggregator('count', lambda values: sum(1 for v in values), 'number of values', type=int)
164-
vd.aggregator('list', list, 'list of values', type=anytype)
165-
vd.aggregator('stdev', stdev, 'standard deviation of values', type=float)
239+
vd.aggregator_list('list', 'list of values', type=anytype, listtype=None)
240+
vd.aggregator('stdev', statistics.stdev, 'standard deviation of values', type=float)
166241

167242
vd.aggregators['q3'] = quantiles(3, 'tertiles (33/66th pctile)')
168243
vd.aggregators['q4'] = quantiles(4, 'quartiles (25/50/75th pctile)')
@@ -267,9 +342,8 @@ def aggregator_choices(vd):
267342

268343

269344
@VisiData.api
270-
def chooseAggregators(vd):
345+
def chooseAggregators(vd, prompt = 'choose aggregators: '):
271346
'''Return a list of aggregator name strings chosen or entered by the user. User-entered names may be invalid.'''
272-
prompt = 'choose aggregators: '
273347
def _fmt_aggr_summary(match, row, trigger_key):
274348
formatted_aggrname = match.formatted.get('key', row.key) if match else row.key
275349
r = ' '*(len(prompt)-3)
@@ -296,10 +370,34 @@ def _fmt_aggr_summary(match, row, trigger_key):
296370
vd.warning(f'aggregator does not exist: {aggr}')
297371
return aggrs
298372

299-
Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'add aggregator to current column')
373+
@Sheet.api
@asyncthread
def addcol_aggregate(sheet, col, aggrnames):
    '''For each aggregator named in *aggrnames*, add a new column at the cursor
    named ``<col.name>_<aggr.name>`` holding that aggregator's result for every row,
    with rows grouped by the sheet's key columns (see aggregate_groups()).
    A name may map to a single aggregator or to a list of them; each gets its own column.
    Names not present in ``vd.aggregators`` are skipped.'''
    for aggrname in aggrnames:
        aggrs = vd.aggregators.get(aggrname)
        # test for a missing aggregator BEFORE wrapping: [None] is truthy, and
        # would otherwise be handed to aggregate_groups() and crash there
        if aggrs is None:
            continue
        if not isinstance(aggrs, list):
            aggrs = [aggrs]
        if not aggrs:
            continue
        for aggr in aggrs:
            values = aggregate_groups(sheet, col, sheet.rows, aggr)
            # list aggregators carry a separate type for the per-row column
            if isinstance(aggr, ListAggregator):
                t = aggr.listtype or col.type
            else:
                t = aggr.type or col.type
            c = SettableColumn(name=f'{col.name}_{aggr.name}', type=t)
            sheet.addColumnAtCursor(c)
            c.setValues(sheet.rows, *values)
389+
390+
Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'Add aggregator to current column')
300391
Sheet.addCommand('z+', 'memo-aggregate', 'cursorCol.memo_aggregate(chooseAggregators(), selectedRows or rows)', 'memo result of aggregator over values in selected rows for current column')
301392
ColumnsSheet.addCommand('g+', 'aggregate-cols', 'addAggregators(selectedRows or source[0].nonKeyVisibleCols, chooseAggregators())', 'add aggregators to selected source columns')
393+
Sheet.addCommand('', 'addcol-aggregate', 'addcol_aggregate(cursorCol, chooseAggregators(prompt="aggregator for groups: "))', 'add column(s) with aggregator of rows grouped by key columns')
394+
395+
vd.addGlobals(
396+
ListAggregator=ListAggregator
397+
)
302398

303399
vd.addMenuItems('''
304400
Column > Add aggregator > aggregate-col
401+
Column > Add column > aggregate > addcol-aggregate
305402
''')
403+

visidata/features/rank.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import itertools
2+
3+
from visidata import Sheet, ListAggregator, SettableColumn
4+
from visidata import vd, anytype, asyncthread
5+
6+
class RankAggregator(ListAggregator):
    '''
    List aggregator that ranks rows by the value of a column within each
    key-column group.
    Ranks start at 1, and each group's rank is 1 higher than the previous group.
    When elements are tied in ranking, each of them gets the same rank.
    '''
    def aggregate(self, col, rows) -> [int]:
        'Same as aggregate_list(): no values are filtered out.'
        return self.aggregate_list(col, rows)

    def aggregate_list(self, col, rows) -> [int]:
        'Return one rank per row in *rows*; reports an error if the sheet has no key columns.'
        if col.sheet.keyCols:
            return self.rank(col, rows)
        vd.error('ranking requires one or more key columns')
        return None

    def rank(self, col, rows):
        'Return a list with the rank of each row in *rows*, in the order of *rows*.'
        sheet = col.sheet
        # one (group_key, rank_key, position) tuple per row
        entries = []
        for pos, row in enumerate(rows):
            entries.append((sheet.rowkey(row), col.getTypedValue(row), pos))
        # ordering by group key, then value, prepares for grouping and ranking
        try:
            entries.sort()
        except TypeError as e:
            vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
        ranked = []
        for _, members in itertools.groupby(entries, key=lambda entry: entry[0]):
            # members of one group are already in value order
            members = list(members)
            member_vals = [val for _key, val, _pos in members]
            member_ranks = rank_sorted_iterable(member_vals)
            ranked.extend((pos, rank) for (_key, _val, pos), rank in zip(members, member_ranks))
        # positions are unique, so this restores the original row order
        ranked.sort()
        return [rank for _pos, rank in ranked]
40+
41+
vd.aggregators['rank'] = RankAggregator('rank', anytype, helpstr='list of ranks, when grouping by key columns', listtype=int)
42+
43+
def rank_sorted_iterable(vals_sorted) -> [int]:
    '''Return a list with one rank per element of *vals_sorted*.
    *vals_sorted* is an iterable whose elements form one group, and it
    must already be sorted. The first run of equal elements gets rank 1,
    the next run rank 2, and so on; equal elements share a rank.'''
    out = []
    # each run of consecutive equal values shares a single rank
    for position, (_, run) in enumerate(itertools.groupby(vals_sorted), 1):
        out.extend(position for _ in run)
    return out
53+
54+
@Sheet.api
@asyncthread
def addcol_sheetrank(sheet, rows):
    '''
    Add an int column named ``<sheet.name>_sheetrank`` at the cursor, holding
    the rank of each row in *rows* within the sheet. Rows are ordered by the
    value of their key columns; rows with equal keys share a rank.
    Reports an error, without adding a column, if the sheet has no key
    columns or the key values cannot be compared.
    '''
    # validate before touching the sheet, so a failed call leaves no empty column behind
    if not sheet.keyCols:
        vd.error('ranking requires one or more key columns')
        return None  # safety net; vd.error() is expected to abort by raising
    rowkeys = [(sheet.rowkey(r), rownum) for rownum, r in enumerate(rows)]
    try:
        rowkeys.sort()
    except TypeError as e:
        vd.fail(f'key column values must be comparable to rank rows: {e.args[0]}')
    ranks = rank_sorted_iterable([rowkey for rowkey, _ in rowkeys])
    # re-sort by the unique rownum to put ranks back in the order of *rows*
    row_ranks = [rank for _, rank in sorted(zip((rownum for _, rownum in rowkeys), ranks))]
    c = SettableColumn(name=f'{sheet.name}_sheetrank', type=int)
    sheet.addColumnAtCursor(c)
    # assign to the same rows the ranks were computed from
    c.setValues(rows, *row_ranks)
73+
74+
Sheet.addCommand('', 'addcol-sheetrank', 'sheet.addcol_sheetrank(rows)', 'add column with the rank of each row based on its key columns')

visidata/tests/test_commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def isTestableCommand(longname, cmdlist):
116116
'sheet': '',
117117
'col': 'Units',
118118
'row': '5',
119+
'addcol-aggregate': 'max',
119120
}
120121

121122
@pytest.mark.usefixtures('curses_setup')

0 commit comments

Comments
 (0)