Optionally deep clone run outputs #1142

Open
dmadisetti opened this issue Apr 16, 2024 · 14 comments
Labels
bug Something isn't working

Comments

@dmadisetti
Contributor

dmadisetti commented Apr 16, 2024

Describe the bug

I noticed when I did the qmd export for Quarto that plots don't render correctly.

This doesn't affect the reactive notebook, since the plot object is rendered to HTML (essentially creating a snapshot) before the next cell is executed. However, in script mode the output is just a reference to the current global state.

Suggestion: deep copies can be expensive, so maybe make the deep copy optional?
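A minimal sketch of what an opt-in deep copy could look like. It assumes a simplified runner in which each cell is a plain function that reads and writes one shared namespace and returns its output. `run_cells` and its `deep_copy_outputs` flag are hypothetical and are not marimo's `app.run()` API.

```python
import copy
from typing import Any, Callable

# Hypothetical cell type: takes the shared namespace, returns the cell's output.
Cell = Callable[[dict[str, Any]], Any]


def run_cells(cells: list[Cell], deep_copy_outputs: bool = False) -> list[Any]:
    """Run cells in order against one shared namespace.

    With deep_copy_outputs=False, each output is a reference to live global
    state, so later mutations show up in earlier outputs. With True, each
    output is snapshotted with copy.deepcopy right after its cell runs.
    """
    namespace: dict[str, Any] = {}
    outputs = []
    for cell in cells:
        out = cell(namespace)
        outputs.append(copy.deepcopy(out) if deep_copy_outputs else out)
    return outputs


def make_series(ns):
    ns["series"] = ["x", "x^2"]
    return ns["series"]


def add_cube(ns):
    ns["series"].append("x^3")
    return ns["series"]


cells = [make_series, add_cube]
print(run_cells(cells))
# [['x', 'x^2', 'x^3'], ['x', 'x^2', 'x^3']]
print(run_cells(cells, deep_copy_outputs=True))
# [['x', 'x^2'], ['x', 'x^2', 'x^3']]
```

Keeping the flag off by default matches the concern about cost: `copy.deepcopy` walks the whole object graph of every output.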

Environment

3.12-dev (I haven't pulled in a few days, but I don't think you've made any changes that would affect this)

Code to reproduce

Relevant export test:

 
```python
# Appended to the end of the notebook file, where `app` is defined.
import matplotlib.pyplot as plt

if __name__ == "__main__":
    objs, _ = app.run()
    for i, o in enumerate(objs):
        # Unpack the Line2D cases
        if isinstance(o, list):
            o, = o
        # Quick hack for testing
        if "matplotlib" in str(type(o)):
            o.get_figure().savefig(f"{i}.png")
```
For completeness, here is the `plots.py` tutorial:
````python
# Copyright 2024 Marimo. All rights reserved.
import marimo

__generated_with = "0.1.69"
app = marimo.App()


@app.cell(hide_code=True)
def __(mo):
    mo.md("# Plotting")
    return


@app.cell(hide_code=True)
def __(check_dependencies):
    check_dependencies()
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        marimo supports several popular plotting libraries, including matplotlib,
        plotly, seaborn, and altair.

        This tutorial gives examples using matplotlib; other libraries are
        used similarly.
        """
    )
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md("## Matplotlib")
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        To show a plot, include it in the last expression of a cell (just
        like any other output).

        ```python3
        # create the plot in the last line of the cell
        import matplotlib.pyplot as plt
        plt.plot([1, 2])
        ```
        """
    )
    return


@app.cell
def __(plt):
    plt.plot([1, 2])
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        ```python3
        # create a plot
        plt.plot([1, 2])
        # ... do some work ...
        # make plt.gca() the last line of the cell
        plt.gca()
        ```
        """
    )
    return


@app.cell
def __(plt):
    plt.plot([1, 2])
    # ... do some work ...
    # make plt.gca() the last line of the cell
    plt.gca()
    return


@app.cell(hide_code=True)
def __(mo, plt_show_explainer):
    mo.accordion(plt_show_explainer)
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        **A new figure every cell.** Every cell starts with an empty figure for
        the imperative `pyplot` API.
        """
    )
    return


@app.cell
def __(np):
    x = np.linspace(start=-4, stop=4, num=100, dtype=float)
    return x,


@app.cell
def __(plt, x):
    plt.plot(x, x)
    plt.plot(x, x**2)
    plt.gca()
    return


@app.cell
def __(plt, x):
    plt.plot(x, x**3)
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        To build a figure over multiple cells, use the object-oriented API and
        create your own axis:
        """
    )
    return


@app.cell
def __(plt, x):
    _, axis = plt.subplots()
    axis.plot(x, x)
    axis.plot(x, x**2)
    axis
    return axis,


@app.cell
def __(axis, x):
    axis.plot(x, x**3)
    axis
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        ### Draw plots interactively

        Draw plots interactively by parametrizing them with UI elements.
        """
    )
    return


@app.cell
def __(mo):
    exponent = mo.ui.slider(1, 5, value=1, step=1, label='exponent')

    mo.md(
        f"""
        **Visualizing powers.**

        {exponent}
        """
    )
    return exponent,


@app.cell
def __(exponent, mo, plt, x):
    import functools


    @functools.cache
    def _plot(exponent):
        plt.plot(x, x**exponent)
        return plt.gca()


    _tex = (
        f"$$f(x) = x^{exponent.value}$$" if exponent.value > 1 else "$$f(x) = x$$"
    )

    mo.md(
        f"""

        {_tex}

        {mo.as_html(_plot(exponent.value))}
        """
    )
    return functools,


@app.cell(hide_code=True)
def __(mo):
    mo.md("## Other libraries")
    return


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        marimo also supports these other plotting libraries:

        - Plotly
        - Seaborn
        - Altair

        Just output their figure objects as the last expression of a cell,
        or embed them in markdown with `mo.as_html`.

        If you would like another library to be integrated into marimo, please
        get in touch.
        """
    )
    return


@app.cell(hide_code=True)
def __(missing_packages, mo):
    module_not_found_explainer = mo.md(
        """
        ## Oops!

        It looks like you're missing a package that this tutorial
        requires.

        Close marimo, install **`numpy`** and **`matplotlib`**, then
        open this tutorial once more.

        If you use `pip`, run

        ```
        pip install numpy matplotlib
        ```

        at your command line.
        """
    ).callout(kind='warn')

    def check_dependencies():
        if missing_packages:
            return module_not_found_explainer
    return check_dependencies, module_not_found_explainer


@app.cell(hide_code=True)
def __():
    plt_show_explainer = {
        "Using `plt.show()`": """
        You can use `plt.show()` or `figure.show()` to display
        plots in the console area of a cell. Keep in mind that console
        outputs are not shown in the app view.
        """
    }
    return plt_show_explainer,


@app.cell
def __():
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import numpy as np
        missing_packages = False
    except ModuleNotFoundError:
        missing_packages = True

    if not missing_packages:
        matplotlib.rcParams['figure.figsize'] = (6, 2.4)
    return matplotlib, missing_packages, np, plt


@app.cell
def __():
    import marimo as mo
    return mo,


import matplotlib.pyplot as plt

if __name__ == "__main__":
    objs, _ = app.run()
    for i, o in enumerate(objs):
        # Unpack the Line2D cases
        if isinstance(o, list):
            o, = o
        # Quick hack for testing
        if "matplotlib" in str(type(o)):
            o.get_figure().savefig(f"{i}.png")
````

This produces `5.png`, `7.png`, `11.png`, `12.png`, `14.png`, and `15.png`, where 5 and 7 are identical to each other, and the rest (11, 12, 14, and 15) are identical to each other. This isn't the same behavior as in the export, but it is similar, and still incorrect.

[Images: 5.png, 7.png, 11.png, 12.png, 14.png, 15.png]

dmadisetti added the `bug` label on Apr 16, 2024
@mscolnick
Contributor

mscolnick commented Apr 24, 2024

We do some fancy-ish things in `marimo/_output/mpl.py`, which is called twice.

I wonder whether that code only runs in edit mode (and not as a script), which would explain why you are seeing this issue, and whether we need to clear the plots in run mode. @akshayka, any thoughts?

@dmadisetti
Contributor Author

This does not happen with the way islands generates cells, but when I just checked the behavior with the script, I now get this error:

```
Traceback (most recent call last):
  File "/home/dylan/test.py", line 293, in <module>
    objs, _ = app.run()
              ^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_ast/app.py", line 335, in run
    return self._run_sync()
           ^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_ast/app.py", line 267, in _run_sync
    outputs[cid] = execute_cell(cell._cell, glbls)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_ast/cell.py", line 444, in execute_cell
    exec(cell.body, glbls)
  File "/tmp/marimo_3443916/__marimo__cell_TqIu_.py", line 11, in <module>
  File "/home/dylan/src/marimo/marimo/_plugins/ui/_core/ui_element.py", line 267, in value
    == ctx.ui_element_registry.get_cell(self._id)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_plugins/ui/_core/registry.py", line 200, in get_cell
    raise NotImplementedError
NotImplementedError
```

@mscolnick
Contributor

@dmadisetti, this is likely from #1150. Do you have a small snippet of code that reproduces it?

@dmadisetti
Contributor Author

Not exactly the same error:

```
Traceback (most recent call last):
  File "/home/dylan/1086.py", line 21, in <module>
    app.run()
  File "/home/dylan/src/marimo/marimo/_ast/app.py", line 335, in run
    return self._run_sync()
           ^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_ast/app.py", line 267, in _run_sync
    outputs[cid] = execute_cell(cell._cell, glbls)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_ast/cell.py", line 445, in execute_cell
    return eval(cell.last_expr, glbls)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/marimo_3531073/__marimo__cell_Hbol__output.py", line 7, in <module>
  File "/home/dylan/downloads/manim/lib/python3.11/site-packages/matplotlib/pyplot.py", line 527, in show
    return _get_backend_mod().show(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dylan/src/marimo/marimo/_output/mpl.py", line 64, in show
    _internal_show(manager.canvas)
  File "/home/dylan/src/marimo/marimo/_output/mpl.py", line 47, in _internal_show
    CellOp.broadcast_console_output(
  File "/home/dylan/src/marimo/marimo/_messaging/ops.py", line 179, in broadcast_console_output
    assert cell_id is not None
           ^^^^^^^^^^^^^^^^^^^
AssertionError
```

Might be worth adding this as a unit test: https://github.com/marimo-team/marimo/blob/main/marimo/_smoke_tests/bugs/1086.py

I'll see if I can boil it down to something smaller, but the first stack trace is just `plots.py` from the tutorial, with the full source and initial output in the original summary.

@mscolnick
Contributor

`1086.py` should be good enough to work from, and we can add a unit test for it. Thanks!

@mscolnick
Contributor

@akshayka just squashed three related bugs. Can you give it another shot with your own code? (We have it running against our tutorials now.)

@dmadisetti
Contributor Author

Yes! This fixes the crash, and it also fixes most of the redundant plots in the tutorial. However, the shared axis object still reflects the global state:

[Images: 14.png, 15.png]

```python
@app.cell  # cell 14
def __(plt, x):
    _, axis = plt.subplots()
    axis.plot(x, x)
    axis.plot(x, x**2)
    axis
    return axis,


@app.cell  # cell 15
def __(axis, x):
    axis.plot(x, x**3)
    axis
    return
```
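Why cells 14 and 15 end up looking the same in script mode can be illustrated without matplotlib. In this sketch, `FakeAxis` is a hypothetical stand-in for a matplotlib `Axes` that only records which series were plotted: both cells output the *same* object, so anything saved after the run reflects its final state.

```python
class FakeAxis:
    """Hypothetical stand-in for a matplotlib Axes: records plotted series."""

    def __init__(self) -> None:
        self.lines: list[str] = []

    def plot(self, label: str) -> None:
        self.lines.append(label)


outputs = []

# "Cell 14": create the axis, plot two series, output the axis.
axis = FakeAxis()
axis.plot("x")
axis.plot("x^2")
outputs.append(axis)

# "Cell 15": mutate the same axis, then output it again.
axis.plot("x^3")
outputs.append(axis)

# Both outputs are one object, so saving them after the run
# (as the savefig loop above does) yields identical images.
print(outputs[0] is outputs[1])  # True
print(outputs[0].lines)  # ['x', 'x^2', 'x^3']
```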

@dmadisetti
Contributor Author

This is actually a bug beyond just scripts now:

[Image]

Run order will change the output of this. The correct behavior should be:

```diff
 plot 1: x, x^2
 plot 2: x, x^2, x^3
-plot 3: x, x^2, x^3, x^4
+plot 3: x, x^2, x^4
```

@dmadisetti
Contributor Author

Or, treat the script's response as correct:

```diff
-plot 1: x, x^2
-plot 2: x, x^2, x^3
+plot 1: x, x^2, x^3, x^4
+plot 2: x, x^2, x^3, x^4
 plot 3: x, x^2, x^3, x^4
```

Code:

```python
import marimo

app = marimo.App()


@app.cell
def __(np):
    x = np.linspace(start=-4, stop=4, num=100, dtype=float)
    return x,


@app.cell
def __(plt, x):
    _, axis = plt.subplots()
    axis.plot(x, x)
    axis.plot(x, x**2)
    axis
    return axis,


@app.cell
def __(axis, x):
    axis.plot(x, x**3)
    axis
    return


@app.cell
def __(axis, x):
    axis.plot(x, x**4)
    axis
    return


@app.cell
def __():
    import marimo as mo
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import numpy as np
        missing_packages = False
    except ModuleNotFoundError:
        missing_packages = True

    if not missing_packages:
        matplotlib.rcParams['figure.figsize'] = (6, 2.4)
    return mo, matplotlib, missing_packages, np, plt


import matplotlib.pyplot as plt

if __name__ == "__main__":
    objs, v = app.run()
    print(v)
    for i, o in enumerate(objs):
        # Unpack the Line2D cases
        if isinstance(o, list):
            o, = o
        # Quick hack for testing
        if "matplotlib" in str(type(o)):
            o.get_figure().savefig(f"{i}.png")
```

@akshayka
Contributor

> run order will change the output of this. Correct behavior should be ...
>
> ```diff
>  plot 1: x, x^2
>  plot 2: x, x^2, x^3
> -plot 3: x, x^2, x^3, x^4
> +plot 3: x, x^2, x^4
> ```

Don't they all share the same axis object though? So it makes sense that plot 3 has x^3.

> OR, have the script response as being correct

Okay, I now see how providing a fresh figure isn't sufficient. Thanks for the example. This, I suppose, isn't specific to plots; it applies to any data structure that is mutated and output across multiple cells.

  1. For the particular use case of integrating with Quarto or exporting to HTML, I think islands and `marimo export html` will now do the right thing, right?
  2. We could still consider adding an optional deep clone to `app.run()`, but I wonder if there's a better solution. Perhaps we could register a `_mime_` method on each output that returns the snapshotted HTML?
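A rough sketch of the `_mime_` idea, assuming marimo's convention that an object's `_mime_()` method returns a `(mimetype, data)` tuple. The wrapper renders the value to HTML when it is created, so later mutations of the underlying object don't leak into the stored output. `SnapshotOutput` and `render_html` are hypothetical names; a real version would use the library-specific formatter (for example, rendering a matplotlib figure to an image) rather than `repr`.

```python
import html


def render_html(obj) -> str:
    # Hypothetical renderer: escapes repr() output; a real one would
    # dispatch to the formatter registered for the object's type.
    return f"<pre>{html.escape(repr(obj))}</pre>"


class SnapshotOutput:
    """Freeze an output's HTML at the moment its cell finishes."""

    def __init__(self, obj) -> None:
        self._html = render_html(obj)

    def _mime_(self) -> tuple[str, str]:
        return ("text/html", self._html)


series = ["x", "x^2"]
snap = SnapshotOutput(series)
series.append("x^3")  # a later cell mutates the object

mimetype, data = snap._mime_()
print(mimetype)  # text/html
print(data)  # still shows only x and x^2
```

Compared with deep copying, this stores only the rendered string, which sidesteps objects that are expensive or impossible to copy, at the cost of no longer having the live object available after the run.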

@dmadisetti
Contributor Author

"Correct" is relative to the design. I think both cases are valid, but ideally you'd want consistency.

  1. Yes, export and islands run first and then access the outputs, so Case 2 is what happens, but the result still depends on cell order; in particular, the matplotlib z-order will change.

  2. I wonder what the performance hit would actually be. Maybe it's fine (premature optimization and all that). Alternatively, what about adding a check for whether a known data structure is among a cell's defs, and cloning only those specific object types?
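A sketch of the second suggestion: keep a registry of types that are known to be mutated in place across cells, and deep-copy only values of those types when recording a cell's defs, leaving everything else as a cheap reference. `CLONE_TYPES` and `snapshot_defs` are hypothetical; built-in containers stand in here for whatever types (plot axes, say) a real registry would hold.

```python
import copy
from typing import Any

# Hypothetical registry of types to snapshot; all other values pass through.
CLONE_TYPES: tuple[type, ...] = (list, dict, set)


def snapshot_defs(defs: dict[str, Any]) -> dict[str, Any]:
    """Deep-copy only values whose type is in CLONE_TYPES."""
    return {
        name: copy.deepcopy(value) if isinstance(value, CLONE_TYPES) else value
        for name, value in defs.items()
    }


big_table = tuple(range(100_000))  # immutable, so kept as a reference
defs = {"series": ["x", "x^2"], "table": big_table}
snap = snapshot_defs(defs)

defs["series"].append("x^3")  # a later cell mutates the list
print(snap["series"])  # ['x', 'x^2']
print(snap["table"] is big_table)  # True
```

The cost of copying is then paid only for the types in the registry, which bounds the performance hit.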

Global hidden state is really hard to get around. The dataflow tutorial makes it pretty clear that objects are hard to handle, so this could just be covered under that disclaimer.

Just an aside: dataflow expectations for parallelism are definitely Case 1.

@akshayka
Contributor

> Correct is relative to design, I think both cases are valid- but ideally you'd want consistency

Yeah, consistency is key.

Export snapshots while running, so I think you get the same semantics as when running as a notebook (at least, that's what happened when I ran it just now). But, like you say, the more important thing is to define the semantics and ensure consistency across different ways of running.

> Just an aside, dataflow expectations for parallelism are definitely Case 1

To double check -- what is Case 1? :)

@dmadisetti
Contributor Author

Oh, whoops. To be explicit:

Case 1

```mermaid
graph TD;
    cell1("Cell 1\nShows x, x^2")-->cell2("Cell 2\nShows x, x^2, x^3");
    cell1-->cell3("Cell 3\nShows x, x^2, x^4");
```

This can be parallelized because the tasks are independent.

Case 2

```mermaid
graph TD;
    cell1("Cell 1\nShows x, x^2, x^3, x^4")-->cell2("Cell 2\nShows x, x^2, x^3, x^4");
    cell3("Cell 3\nShows x, x^2, x^3, x^4");
    cell2 --"order takes precedence in mutating axis"-->cell3;
```

Case 2 has a hidden dependency, since the global object is built up across cells.
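The two cases can be reproduced with a toy scheduler. In Case 1, each child cell receives its own deep copy of its parent's output, so the children are independent and their order (or running them in parallel) doesn't change anything. In Case 2, all cells share one mutable object, so every output is that same object and its contents depend on execution order. The `run` helper and cell functions are hypothetical, not marimo's runtime.

```python
import copy


def cell1():
    # Stand-in for the shared axis: the series plotted so far.
    return ["x", "x^2"]


def cell2(series):
    series.append("x^3")
    return series


def cell3(series):
    series.append("x^4")
    return series


def run(children, isolate):
    """Run cell1, then each child cell on cell1's output.

    If isolate is True, each child receives its own deep copy (Case 1);
    otherwise all cells share and mutate one object (Case 2).
    """
    parent = cell1()
    outputs = {"cell1": parent}
    for child in children:
        arg = copy.deepcopy(parent) if isolate else parent
        outputs[child.__name__] = child(arg)
    return outputs


# Case 1: outputs match the first graph, whatever order the children run in.
print(run([cell2, cell3], isolate=True))
# {'cell1': ['x', 'x^2'], 'cell2': ['x', 'x^2', 'x^3'], 'cell3': ['x', 'x^2', 'x^4']}
print(run([cell2, cell3], isolate=True) == run([cell3, cell2], isolate=True))  # True

# Case 2: every output is the same object, and its contents depend on order.
print(run([cell2, cell3], isolate=False)["cell1"])  # ['x', 'x^2', 'x^3', 'x^4']
print(run([cell3, cell2], isolate=False)["cell1"])  # ['x', 'x^2', 'x^4', 'x^3']
```

The reordered series in the last line is the toy analogue of the matplotlib z-order change mentioned earlier in the thread.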

@akshayka
Contributor

Okay great, yes — that makes total sense. On the same page.
