Skip to content

Commit e1eefbb

Browse files
feat: make scatter plot legends clickable (#690)
* feat(wip): make scatter plot legends clickable wip: click changes the colors of points in an unexpected way * fix: build * fix: make legend ui more intuitive to the user * fix: address PR feedback + rectify logic to filter out values on legend click * chore: add comments * basic principle * fix: selectedpoints computation * chore: improve tooltips * chore: cleanup --------- Co-authored-by: Moritz Heckmann <[email protected]>
1 parent 51c005b commit e1eefbb

File tree

4 files changed

+114
-80
lines changed

4 files changed

+114
-80
lines changed

src/utils/indicesOf.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
export function indicesOf<T>(array: T[], predicate: (value: T, index: number) => boolean): number[] {
2+
const indices = new Array<number>(array.length);
3+
let count = 0;
4+
5+
for (let i = 0; i < array.length; i++) {
6+
if (predicate(array[i]!, i)) {
7+
indices[count++] = i;
8+
}
9+
}
10+
11+
indices.length = count;
12+
13+
return indices;
14+
}

src/vis/scatter/ScatterVis.tsx

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,17 @@ import { VIS_NEUTRAL_COLOR } from '../general/constants';
2626
import { EColumnTypes, ENumericalColorScaleType, EScatterSelectSettings, ICommonVisProps } from '../interfaces';
2727
import { BrushOptionButtons } from '../sidebar/BrushOptionButtons';
2828

29-
function Legend({ categories, colorMap, onClick }: { categories: string[]; colorMap: (v: number | string) => string; onClick: (string) => void }) {
29+
function Legend({
30+
categories,
31+
hiddenCategoriesSet,
32+
colorMap,
33+
onClick,
34+
}: {
35+
categories: string[];
36+
hiddenCategoriesSet?: Set<string>;
37+
colorMap: (v: number | string) => string;
38+
onClick: (category: string) => void;
39+
}) {
3040
return (
3141
<ScrollArea
3242
data-testid="PlotLegend"
@@ -40,9 +50,9 @@ function Legend({ categories, colorMap, onClick }: { categories: string[]; color
4050
`}
4151
>
4252
<Stack gap={0}>
43-
{categories.map((c) => {
44-
return <LegendItem key={c} color={colorMap(c)} label={c} onClick={() => onClick(c)} filtered={false} />;
45-
})}
53+
{categories.map((c) => (
54+
<LegendItem key={c} color={colorMap(c)} label={c} onClick={() => onClick(c)} filtered={hiddenCategoriesSet?.has(c) ?? false} />
55+
))}
4656
</Stack>
4757
</ScrollArea>
4858
);
@@ -79,6 +89,8 @@ export function ScatterVis({
7989
const [shiftPressed, setShiftPressed] = React.useState(false);
8090
const [showLegend, setShowLegend] = React.useState(false);
8191

92+
const [hiddenCategoriesSet, setHiddenCategoriesSet] = React.useState<Set<string>>(new Set<string>());
93+
8294
// const [ref, { width, height }] = useResizeObserver();
8395
const { ref, width, height } = useElementSize();
8496

@@ -130,10 +142,11 @@ export function ScatterVis({
130142
}
131143

132144
const { subplots, scatter, splom, facet, shapeScale } = useDataPreparation({
133-
value,
145+
hiddenCategoriesSet,
146+
numColorScaleType: config.numColorScaleType,
134147
status,
135148
uniqueSymbols,
136-
numColorScaleType: config.numColorScaleType,
149+
value,
137150
});
138151

139152
const regressions = React.useMemo<{
@@ -295,49 +308,6 @@ export function ScatterVis({
295308
return undefined;
296309
}
297310

298-
/* const legendPlots: PlotlyTypes.Data[] = [];
299-
300-
if (value.shapeColumn) {
301-
legendPlots.push({
302-
x: [null],
303-
y: [null],
304-
type: 'scatter',
305-
mode: 'markers',
306-
showlegend: true,
307-
legendgroup: 'shape',
308-
hoverinfo: 'all',
309-
310-
hoverlabel: {
311-
namelength: 10,
312-
bgcolor: 'black',
313-
align: 'left',
314-
bordercolor: 'black',
315-
},
316-
// @ts-ignore
317-
legendgrouptitle: {
318-
text: truncateText(value.shapeColumn.info.name, true, 20),
319-
},
320-
marker: {
321-
line: {
322-
width: 0,
323-
},
324-
symbol: value.shapeColumn ? value.shapeColumn.resolvedValues.map((v) => shapeScale(v.val as string)) : 'circle',
325-
color: VIS_NEUTRAL_COLOR,
326-
},
327-
transforms: [
328-
{
329-
type: 'groupby',
330-
groups: value.shapeColumn.resolvedValues.map((v) => getLabelOrUnknown(v.val)),
331-
styles: [
332-
...[...new Set<string>(value.shapeColumn.resolvedValues.map((v) => getLabelOrUnknown(v.val)))].map((c) => {
333-
return { target: c, value: { name: c } };
334-
}),
335-
],
336-
},
337-
],
338-
});
339-
}
340-
*/
341311
if (value.colorColumn && value.colorColumn.type === EColumnTypes.CATEGORICAL) {
342312
// Get distinct values
343313
const colorValues = uniq(value.colorColumn.resolvedValues.map((v) => v.val ?? 'Unknown') as string[]);
@@ -529,7 +499,7 @@ export function ScatterVis({
529499
}
530500

531501
if (scatter) {
532-
const ids = event.points.map((point) => scatter.ids[point.pointIndex]) as string[];
502+
const ids = event.points.map((point) => scatter.ids[scatter.filter[point.pointIndex]!]) as string[];
533503
mergeIntoSelection(ids);
534504
}
535505

@@ -570,7 +540,22 @@ export function ScatterVis({
570540

571541
{status === 'success' && layout && legendData?.color.mapping && showLegend ? (
572542
<div style={{ gridArea: 'legend', overflow: 'hidden' }}>
573-
<Legend categories={legendData.color.categories} colorMap={legendData.color.mappingFunction} onClick={() => {}} />
543+
<Legend
544+
categories={legendData.color.categories}
545+
colorMap={legendData.color.mappingFunction}
546+
hiddenCategoriesSet={hiddenCategoriesSet}
547+
onClick={(category: string) => {
548+
setHiddenCategoriesSet((prevSet) => {
549+
const newSet = new Set(prevSet);
550+
if (newSet.has(category)) {
551+
newSet.delete(category);
552+
} else {
553+
newSet.add(category);
554+
}
555+
return newSet;
556+
});
557+
}}
558+
/>
574559
</div>
575560
) : null}
576561
</div>

src/vis/scatter/useData.tsx

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,35 +119,54 @@ export function useData({
119119
{
120120
...BASE_DATA,
121121
type: 'scattergl',
122-
x: scatter.plotlyData.x,
123-
y: scatter.plotlyData.y,
122+
x: scatter.filter.map((index) => scatter.plotlyData.x[index]),
123+
y: scatter.filter.map((index) => scatter.plotlyData.y[index]),
124124
// text: scatter.plotlyData.text,
125-
textposition: scatter.plotlyData.text.map((_, i) => textPositionOptions[i % textPositionOptions.length]),
126-
...(isEmpty(selectedSet) ? {} : { selectedpoints: selectedList.map((idx) => scatter.idToIndex.get(idx)) }),
125+
textposition: scatter.filter.map((index) => textPositionOptions[index % textPositionOptions.length]),
126+
...(isEmpty(selectedSet)
127+
? {}
128+
: { selectedpoints: selectedList.map((idx) => scatter.idToIndex.get(idx)).filter((v) => v !== undefined && v !== null) }),
127129
mode: config.showLabels === ELabelingOptions.NEVER || config.xAxisScale === 'log' || config.yAxisScale === 'log' ? 'markers' : 'text+markers',
128130
...(config.showLabels === ELabelingOptions.NEVER
129131
? {}
130132
: config.showLabels === ELabelingOptions.ALWAYS
131133
? {
132-
text: scatter.plotlyData.text.map((t) => truncateText(value.idToLabelMapper(t), true, 10)),
133-
// textposition: 'top center',
134+
text: scatter.filter.map((index) => truncateText(value.idToLabelMapper(scatter.plotlyData.text[index]!), true, 10)),
134135
}
135136
: {
136-
text: scatter.plotlyData.text.map((t, i) => (visibleLabelsSet.has(scatter.ids[i]!) ? truncateText(value.idToLabelMapper(t), true, 10) : '')),
137-
// textposition: 'top center',
137+
text: scatter.filter.map((index, i) =>
138+
visibleLabelsSet.has(scatter.ids[index]!)
139+
? truncateText(value.idToLabelMapper(value.idToLabelMapper(scatter.plotlyData.text[index]!)), true, 10)
140+
: '',
141+
),
138142
}),
139-
hovertext: value.validColumns[0].resolvedValues.map((v, i) =>
140-
`${value.idToLabelMapper(v.id)}
141-
${(value.resolvedLabelColumns ?? []).map((l) => `<br />${columnNameWithDescription(l.info)}: ${getLabelOrUnknown(l.resolvedValues[i]?.val)}`)}
142-
${value.colorColumn ? `<br />${columnNameWithDescription(value.colorColumn.info)}: ${getLabelOrUnknown(value.colorColumn.resolvedValues[i]?.val)}` : ''}
143-
${value.shapeColumn && value.shapeColumn.info.id !== value.colorColumn?.info.id ? `<br />${columnNameWithDescription(value.shapeColumn.info)}: ${getLabelOrUnknown(value.shapeColumn.resolvedValues[i]?.val)}` : ''}`.trim(),
144-
),
143+
hovertext: scatter.filter.map((i) => {
144+
const resolvedLabelString =
145+
value.resolvedLabelColumns?.length > 0
146+
? value.resolvedLabelColumns.map((l) => `<b>${columnNameWithDescription(l.info)}</b>: ${getLabelOrUnknown(l.resolvedValues[i]?.val)}<br />`)
147+
: '';
148+
const idString = `<b>${value.idToLabelMapper(scatter.plotlyData.text[i]!)}</b><br />`;
149+
const xString = `<b>${columnNameWithDescription(value.validColumns[0]!.info)}</b>: ${getLabelOrUnknown(value.validColumns[0]!.resolvedValues[i]?.val)}<br />`;
150+
const yString = `<b>${columnNameWithDescription(value.validColumns[1]!.info)}</b>: ${getLabelOrUnknown(value.validColumns[1]!.resolvedValues[i]?.val)}<br />`;
151+
const colorColumnString = value.colorColumn
152+
? `<b>${columnNameWithDescription(value.colorColumn.info)}</b>: ${getLabelOrUnknown(value.colorColumn.resolvedValues[i]?.val)}<br />`
153+
: '';
154+
const shapeColumnString =
155+
value.shapeColumn && value.shapeColumn.info.id !== value.colorColumn?.info.id
156+
? `<b>${columnNameWithDescription(value.shapeColumn.info)}</b>: ${getLabelOrUnknown(value.shapeColumn.resolvedValues[i]?.val)}<br />`
157+
: '';
158+
159+
return `${idString}${xString}${yString}${resolvedLabelString}${colorColumnString}${shapeColumnString}`;
160+
}),
145161
marker: {
146162
textfont: {
147163
color: VIS_NEUTRAL_COLOR,
148164
},
149165
size: 8,
150-
color: value.colorColumn && mappingFunction ? value.colorColumn.resolvedValues.map((v) => mappingFunction(v.val)) : VIS_NEUTRAL_COLOR,
166+
color:
167+
value.colorColumn && mappingFunction
168+
? scatter.filter.map((index) => mappingFunction(value.colorColumn.resolvedValues[index]!.val as string))
169+
: VIS_NEUTRAL_COLOR,
151170
symbol: value.shapeColumn ? value.shapeColumn.resolvedValues.map((v) => shapeScale(v.val as string)) : 'circle',
152171
opacity: fullOpacityOrAlpha,
153172
},

src/vis/scatter/useDataPreparation.ts

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ import * as React from 'react';
33
import * as d3v7 from 'd3v7';
44
import groupBy from 'lodash/groupBy';
55
import isFinite from 'lodash/isFinite';
6+
import range from 'lodash/range';
67
import sortBy from 'lodash/sortBy';
78

8-
import { FetchColumnDataResult } from './utils';
99
import { PlotlyTypes } from '../../plotly';
10+
import { NAN_REPLACEMENT } from '../general';
1011
import { columnNameWithDescription } from '../general/layoutUtils';
1112
import { ENumericalColorScaleType } from '../interfaces';
13+
import { FetchColumnDataResult } from './utils';
14+
import { indicesOf } from '../../utils/indicesOf';
1215

1316
function getStretchedDomains(x: number[], y: number[]) {
1417
let xDomain = d3v7.extent(x);
@@ -26,15 +29,17 @@ function getStretchedDomains(x: number[], y: number[]) {
2629
}
2730

2831
export function useDataPreparation({
32+
hiddenCategoriesSet,
33+
numColorScaleType,
2934
status,
30-
value,
3135
uniqueSymbols,
32-
numColorScaleType,
36+
value,
3337
}: {
38+
hiddenCategoriesSet?: Set<string>;
39+
numColorScaleType: ENumericalColorScaleType;
3440
status: string;
35-
value: FetchColumnDataResult;
3641
uniqueSymbols: string[];
37-
numColorScaleType: ENumericalColorScaleType;
42+
value: FetchColumnDataResult | null;
3843
}) {
3944
const subplots = React.useMemo(() => {
4045
if (!(status === 'success' && value.subplots && value.subplots.length > 0 && value.subplots[0])) {
@@ -78,37 +83,48 @@ export function useDataPreparation({
7883
return undefined;
7984
}
8085

86+
const filter =
87+
value.colorColumn && hiddenCategoriesSet
88+
? indicesOf(
89+
value.colorColumn.resolvedValues,
90+
(e, index) =>
91+
!hiddenCategoriesSet.has((e.val ?? NAN_REPLACEMENT) as string) &&
92+
isFinite(value.validColumns[0]!.resolvedValues[index]!.val) &&
93+
isFinite(value.validColumns[1]!.resolvedValues[index]!.val),
94+
)
95+
: range(value.validColumns[0].resolvedValues.length);
96+
const ids = value.validColumns[0].resolvedValues.map((v) => v.id);
97+
8198
// Get shared range for all plots
8299
const { xDomain, yDomain } = getStretchedDomains(
83100
value.validColumns[0].resolvedValues.map((v) => v.val as number),
84101
value.validColumns[1].resolvedValues.map((v) => v.val as number),
85102
);
86103

87-
const ids = value.validColumns[0].resolvedValues.map((v) => v.id);
104+
const x = value.validColumns[0].resolvedValues.map((v) => v.val as number);
105+
const y = value.validColumns[1].resolvedValues.map((v) => v.val as number);
88106

89107
const idToIndex = new Map<string, number>();
90-
ids.forEach((v, i) => {
91-
idToIndex.set(v, i);
108+
filter.forEach((v, i) => {
109+
idToIndex.set(ids[v]!, i);
92110
});
93111

94-
const x = value.validColumns[0].resolvedValues.map((v) => v.val as number);
95-
const y = value.validColumns[1].resolvedValues.map((v) => v.val as number);
96-
97112
return {
98113
plotlyData: {
99-
validIndices: x.map((_, i) => (isFinite(x[i]) && isFinite(y[i]) ? i : null)).filter((i) => i !== null) as number[],
114+
validIndices: filter,
100115
x,
101116
y,
102-
text: value.validColumns[0].resolvedValues.map((v) => v.id),
117+
text: ids,
103118
},
119+
filter,
104120
ids,
105121
xDomain,
106122
yDomain,
107123
xLabel: columnNameWithDescription(value.validColumns[0].info),
108124
yLabel: columnNameWithDescription(value.validColumns[1].info),
109125
idToIndex,
110126
};
111-
}, [status, value]);
127+
}, [status, value, hiddenCategoriesSet]);
112128

113129
// Case when we have a scatterplot matrix
114130
const splom = React.useMemo(() => {

0 commit comments

Comments
 (0)