diff --git a/app.py b/app.py index 29b4452..b1d0665 100644 --- a/app.py +++ b/app.py @@ -37,6 +37,16 @@ def gradio_inference( height=1024 ): """Wrapper function for Gradio interface""" + # Check if mask has been drawn + if image_data is None or "layers" not in image_data or not image_data["layers"]: + raise gr.Error("Please draw a mask over the clothing area before generating!") + + # Check if mask is empty (all black) + mask = image_data["layers"][0] + mask_array = np.array(mask) + if np.all(mask_array < 10): + raise gr.Error("The mask is empty! Please draw over the clothing area you want to replace.") + # Use temporary directory with tempfile.TemporaryDirectory() as tmp_dir: # Save inputs to temp directory @@ -86,6 +96,14 @@ def gradio_inference( # gr.Video("example/github.mp4", label="Demo Video: How to use the tool") with gr.Column(): + gr.Markdown(""" + ### ⚠️ Important: + 1. Click 'Edit' on the Model Image + 2. Draw a mask over the clothing area you want to replace + 3. Click 'Save' when done + 4. Then click 'Generate Try-On' + """) + with gr.Row(): with gr.Column(): image_input = gr.ImageMask(