[FEATURE-REQUEST] Suggested implementation for scatter plot of large data #2391

Open
stianke opened this issue Sep 1, 2023 · 0 comments

stianke commented Sep 1, 2023

Hi! I had an issue similar to #653: I needed to plot a large 1D data set over time, with something like df.viz.scatter() but able to handle large data sets. I wanted to see potential single outliers in the data, so straight downsampling was not an option. After some experimentation, I found a simple way to do this using the capabilities vaex already provides.
I would like to know whether you would be interested in adding native support for something like this.
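To make the "straight downsampling loses outliers" point concrete, here is a small NumPy-only sketch on synthetic data. It uses np.histogram2d as a stand-in for vaex's df.count(binby=...) and a fixed stride as the example of straight downsampling; the data, the stride and the 800 x 600 plot size are all assumptions for illustration.

```python
import numpy as np

# Synthetic stand-in for a large 1D signal: 1M samples of noise plus one spike
rng = np.random.default_rng(0)
n = 1_000_000
t = np.arange(n, dtype=float)
y = rng.normal(0.0, 1.0, n)
y[123_457] = 50.0  # the single outlier we care about

# Straight downsampling: keep every 1000th sample (the spike's index is not a multiple of 1000)
y_strided = y[::1000]
kept_by_striding = bool(np.any(y_strided > 10.0))

# Rasterization: one bin per pixel of an 800 x 600 plot area; any non-empty bin is "lit"
counts, _, _ = np.histogram2d(t, y, bins=[800, 600],
                              range=[[t.min(), t.max()], [y.min(), y.max()]])
raster = np.minimum(counts, 1)
# The spike lands in the top row of bins (y index 599); ordinary noise never gets that high
kept_by_raster = bool(raster[:, -1].any())

print("outlier survives striding:", kept_by_striding)
print("outlier survives rasterization:", kept_by_raster)
```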

The idea is to give matplotlib a rasterized image of the scatter plot instead of all the data points. All data points that fall into the same screen pixel can be treated as one. I suppose this is a kind of downsampling, but one that doesn't visibly change the scatter plot. In practice, I do this by using vaex.dataframe.DataFrame.count() to generate a heatmap of the counts, with one bin per screen pixel, and then limiting the value of all nonzero bins to 1. This effectively produces a black/white scatter plot where each data point is one pixel large. We can then draw a marker around each point with scipy.ndimage.grey_dilation() and set the marker color. I have provided a code example below.

import vaex
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy.ndimage

df = vaex.open('data_8M_samples.hdf5') # Open Data

fig = plt.figure(1) # Create figure
ax = fig.subplots()
plt.grid(True)  # Enable gridlines

# Fetch size of plot area in pixels
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
plot_width_pixels = round(bbox.width * fig.dpi)
plot_height_pixels = round(bbox.height * fig.dpi)
xlims = df['time'].minmax()
ylims = df['data_1'].minmax()  # Must be the same column that is binned below

# Make a heatmap of the data counts with bins equal to the screen resolution of the figure
counts = df.count(None,
                  binby=[df['time'], df['data_1']],
                  shape=[plot_width_pixels, plot_height_pixels],
                  limits=[xlims, ylims])
monochrome_scatter_plot = np.minimum(counts, 1)
monochrome_scatter_plot = np.rot90(monochrome_scatter_plot)

# Create the marker that represents a data point
marker_radius = 5
marker_color = [0, 0.4470, 0.7410]
xx, yy = np.mgrid[-marker_radius:marker_radius+1, -marker_radius:marker_radius+1]
footprint = np.logical_or(xx-yy == 0, xx+yy == 0) # This particular marker is for a cross

# Apply the marker to the scatter plot
monochrome_scatter_plot_with_markers = scipy.ndimage.grey_dilation(monochrome_scatter_plot, footprint=footprint)

# Convert monochrome to RGB with alpha channel
color_scatter_plot = np.stack([monochrome_scatter_plot_with_markers * marker_color[0], # Red channel
                               monochrome_scatter_plot_with_markers * marker_color[1], # Green channel
                               monochrome_scatter_plot_with_markers * marker_color[2], # Blue channel
                               monochrome_scatter_plot_with_markers], axis=2) # Alpha channel
plt.imshow(color_scatter_plot, extent=[xlims[0], xlims[1], ylims[0], ylims[1]], aspect='auto')
plt.show(block=True)
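For reference, the same count, clip, rotate, dilate and colorize pipeline can be written without vaex, which makes it easy to check the image orientation. This is a sketch, not part of vaex: np.histogram2d stands in for df.count(binby=[x, y], shape=..., limits=...), assuming the same (x, y) axis order that the example above relies on when it calls np.rot90, and rasterize_scatter is a hypothetical helper name.

```python
import numpy as np
import scipy.ndimage


def rasterize_scatter(x, y, width, height, xlims, ylims, footprint, color):
    """Hypothetical helper: RGBA image of shape (height, width, 4) for
    imshow(img, extent=[*xlims, *ylims], aspect='auto')."""
    # One bin per pixel; axis 0 is x and axis 1 is y. Points outside the limits are ignored.
    counts, _, _ = np.histogram2d(x, y, bins=[width, height], range=[xlims, ylims])
    mono = np.minimum(counts, 1)      # any non-empty pixel becomes 1
    mono = np.rot90(mono)             # rows now run from max y (top) to min y (bottom)
    mono = scipy.ndimage.grey_dilation(mono, footprint=footprint)  # stamp the marker
    r, g, b = color
    # Color channels hold the marker color where lit; alpha is 0 elsewhere so the axes show through
    return np.stack([mono * r, mono * g, mono * b, mono], axis=2)


# Cross-shaped footprint of radius 1, built the same way as in the example above
xx, yy = np.mgrid[-1:2, -1:2]
cross = np.logical_or(xx - yy == 0, xx + yy == 0)

# A single point in the middle of a 5 x 5 pixel view
img = rasterize_scatter(np.array([2.5]), np.array([2.5]), 5, 5,
                        (0.0, 5.0), (0.0, 5.0), cross, (0, 0.4470, 0.7410))
print(img[:, :, 3].astype(int))  # alpha channel: a cross centred on pixel (2, 2)
```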

This example is non-interactive, but with some modifications it can be made fully interactive, similar to the interactive widgets in Jupyter notebooks. I don't use notebooks, so to test the concept I made an interactive version as a traditional Python script that recalculates the counts whenever the view is zoomed, panned or resized (source code below). Thanks to the alpha channel of the rasterized image, it also supports showing multiple data series in the same scatter plot.
Legends could be added manually, but I have not done so here. Drawing lines between the data points is not possible, though, as there could be data points outside the plot area that are not accounted for.
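A small sketch of why the counts have to be recomputed on every zoom, and why points outside the view cannot be joined by lines: with the bin limits taken from the current view, two samples that share a pixel in the full view get separate pixels once zoomed in, and samples outside the view are dropped from the counts entirely. Again np.histogram2d stands in for df.count(binby=..., limits=...); lit_pixels and the four sample points are hypothetical.

```python
import numpy as np


def lit_pixels(x, y, xlims, ylims, width=800, height=600):
    """Hypothetical helper: number of non-empty pixels for the given view limits."""
    counts, _, _ = np.histogram2d(x, y, bins=[width, height], range=[xlims, ylims])
    return int(np.count_nonzero(counts))


# Two samples very close together inside a long time series, plus the two endpoints
x = np.array([0.0, 500.3, 500.5, 1000.0])
y = np.array([0.0, 1.0, 1.0, 2.0])

# Full view: the two middle samples fall into the same 1.25-unit-wide pixel column
full_view = lit_pixels(x, y, (0.0, 1000.0), (0.0, 2.0))
# Zoomed view: the middle samples separate, and the endpoints are outside the limits
zoomed_view = lit_pixels(x, y, (499.0, 501.0), (0.5, 1.5))

print(full_view, zoomed_view)
```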

[Image: scatter_demo]

Source code for the interactive version:

import vaex
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy.ndimage


def main():
    df = vaex.open('data_8M_samples.hdf5')
    df.my_viz.my_scatter(df['time'], df['data_1'], PlotMarker(shape='filled-circle', radius=5, color=[0, 0.4470, 0.7410]))
    df.my_viz.my_scatter(df['time'], df['data_2'], PlotMarker(shape='cross', radius=5, color=[0.8500, 0.3250, 0.0980]))
    plt.grid(True)
    plt.title('No downsampling!')
    plt.show(block=True)

# Marker used to represent a data point in the scatter plot
class PlotMarker:
    def __init__(self, shape='filled-circle', radius=5, color=None):
        if color is None:
            color = [0, 0.4470, 0.7410]
        self.shape = shape
        self.radius = radius
        self.color = color

# Custom interactive scatter plot
@vaex.register_dataframe_accessor('my_viz', override=True)
class ScatterPlot(object):
    def __init__(self, df):
        self.df = df

    def my_scatter(self, x, y, marker=PlotMarker()):
        # Get data limits
        x_lims = x.minmax()
        y_lims = y.minmax()

        # get axis limits
        ax = plt.gca()
        fig = ax.figure
        if len(ax.get_images()) == 0:
            # Zoom slightly out on the y-axis, to ensure all data is easily visible
            ylim_range = (y_lims[1] - y_lims[0])
            y_lims = [y_lims[0] - 0.05 * ylim_range, y_lims[1] + 0.1 * ylim_range]
        else:
            # If another scatter is already plotted, zoom out if necessary, but don't zoom in
            ylim_range = max(y_lims[1], ax.get_ylim()[1]) - min(y_lims[0], ax.get_ylim()[0])
            y_lims[0] = min(y_lims[0] - 0.05 * ylim_range, ax.get_ylim()[0])
            y_lims[1] = max(y_lims[1] + 0.05 * ylim_range, ax.get_ylim()[1])

        ax.set_xlim(x_lims)
        ax.set_ylim(y_lims)

        im, panning = None, None
        def update_plot(_=None):
            nonlocal im, panning

            if panning: # Panning will result in a constant stream of callbacks. Updating each time is laggy.
                return

            # Fetch size of plot area in pixels
            bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            [width_pixels, height_pixels] = [round(bbox.width * fig.dpi), round(bbox.height * fig.dpi)]

            # Get axis limits to calculate scatter plot with
            [xlims, ylims] = [ax.get_xlim(), ax.get_ylim()]

            # Make a heatmap of the data counts with bins equal to the screen resolution of the figure
            counts = self.df.count(None, binby=[x, y], shape=[width_pixels, height_pixels], limits=[xlims, ylims])
            color_plot = _make_scatter_plot_image(counts, marker)

            # Show image
            if im is None:
                im = ax.imshow(color_plot, extent=[xlims[0], xlims[1], ylims[0], ylims[1]], aspect='auto')
            else:
                # When refreshing the plot, update the old image
                im.set(data=color_plot, extent=[xlims[0], xlims[1], ylims[0], ylims[1]])
            fig.canvas.draw_idle()  # Schedule a redraw instead of drawing synchronously inside the callback

        update_plot()

        ax.callbacks.connect('xlim_changed', update_plot)
        ax.callbacks.connect('ylim_changed', update_plot)
        fig.canvas.mpl_connect('resize_event', update_plot)

        # When panning the view, the constant callbacks to update_plot cause lots of lag. Only update when panning is finished
        def panning_started(_=None):
            nonlocal panning
            panning = True

        def panning_stopped(_=None):
            nonlocal panning
            panning = False
            update_plot()
        fig.canvas.mpl_connect('button_press_event', panning_started)
        fig.canvas.mpl_connect('button_release_event', panning_stopped)


def _make_scatter_plot_image(counts, marker):
    xx, yy = np.mgrid[-marker.radius:marker.radius+1, -marker.radius:marker.radius+1]
    if marker.shape == 'filled-circle': # Circle (filled in)
        footprint = xx**2 + yy**2 < (marker.radius+0.5) ** 2
    elif marker.shape == 'hollow-circle': # Circle (not filled in)
        footprint = np.logical_and((marker.radius-0.5) ** 2 < xx**2 + yy**2, xx**2 + yy**2 < (marker.radius+0.5) ** 2)
    elif marker.shape == 'filled-square': # Square (filled in)
        footprint = np.ones(shape=xx.shape)
    elif marker.shape == 'hollow-square': # Square (not filled in)
        footprint = np.logical_or(np.logical_or(xx == marker.radius, xx == -marker.radius), np.logical_or(yy == marker.radius, yy == -marker.radius))
    elif marker.shape == 'cross': # Cross
        footprint = np.logical_or(xx-yy == 0, xx+yy == 0)
    else:
        raise ValueError(f'Marker {marker.shape} in _make_scatter_plot_image not recognized')

    monochrome_scatter_plot = np.minimum(counts, 1)
    monochrome_scatter_plot = np.rot90(monochrome_scatter_plot)
    monochrome_scatter_plot_with_markers = scipy.ndimage.grey_dilation(monochrome_scatter_plot, footprint=footprint)
    color_plot = np.stack([monochrome_scatter_plot_with_markers * marker.color[0],
                           monochrome_scatter_plot_with_markers * marker.color[1],
                           monochrome_scatter_plot_with_markers * marker.color[2],
                           monochrome_scatter_plot_with_markers], axis=2)
    return color_plot


if __name__ == '__main__':
    main()
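To see what each marker shape in _make_scatter_plot_image looks like before plotting, the footprints can be printed as ASCII art. This is a standalone sketch: marker_footprint and as_ascii are hypothetical helpers that reuse the same formulas as the function above (written with abs(), | and & instead of np.logical_or/np.logical_and), and radius 2 is an arbitrary choice.

```python
import numpy as np


def marker_footprint(shape, radius):
    """Hypothetical helper: boolean footprint using the same formulas as _make_scatter_plot_image."""
    xx, yy = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    r2 = xx**2 + yy**2
    if shape == 'filled-circle':
        return r2 < (radius + 0.5) ** 2
    if shape == 'hollow-circle':
        return ((radius - 0.5) ** 2 < r2) & (r2 < (radius + 0.5) ** 2)
    if shape == 'filled-square':
        return np.ones(xx.shape, dtype=bool)
    if shape == 'hollow-square':
        return (np.abs(xx) == radius) | (np.abs(yy) == radius)
    if shape == 'cross':
        return (xx == yy) | (xx == -yy)
    raise ValueError(f'Marker {shape} not recognized')


def as_ascii(footprint):
    """Render a boolean footprint with '#' for set pixels and '.' for empty ones."""
    return '\n'.join(''.join('#' if v else '.' for v in row) for row in footprint)


for name in ['filled-circle', 'hollow-circle', 'filled-square', 'hollow-square', 'cross']:
    print(name)
    print(as_ascii(marker_footprint(name, 2)))
    print()
```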