-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcanvas.py
94 lines (85 loc) · 2.63 KB
/
canvas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import streamlit as st
from streamlit_drawable_canvas import st_canvas
# local
from widgets import Timer
from images import avg_rgb
from segmentation import thresholding
def show_canvas(tog_auto, tog_edit):
    """Render the drawable canvas for the image currently selected in session state.

    The background image, the previously saved drawing and the filename are
    all looked up in ``st.session_state`` at index ``cur_i``.

    Args:
        tog_auto: forwarded unchanged as ``update_streamlit`` to ``st_canvas``.
        tog_edit: when truthy the canvas uses the ``"transform"`` drawing mode,
            otherwise the ``"rect"`` drawing mode.

    Returns:
        Whatever ``st_canvas`` returns for this image's canvas widget.
    """
    session = st.session_state
    idx = session.cur_i
    background = session.pil_imgs[idx]
    saved_drawing = session.json_data[idx]
    name = session.file_imgs[idx].name

    mode = "rect"
    if tog_edit:
        mode = "transform"

    result = st_canvas(
        # A distinct key per image index gives each image its own widget.
        key="canvas%d" % idx,
        background_image=background,
        background_color="#fff",
        # Size the canvas to match the background image exactly.
        width=background.width,
        height=background.height,
        initial_drawing=saved_drawing,
        drawing_mode=mode,
        fill_color="rgba(255, 90, 0, 0.3)",
        stroke_color="rgb(255, 0, 0, 1)",
        stroke_width=1,
        update_streamlit=tog_auto,
        # point_display_radius=0 is not available in streamlit-drawable-canvas 0.8.0.
    )
    st.write(name)
    return result
def canvas_to_states(canvas):
    """Store the shapes drawn on *canvas* in session state for the current image.

    Reads the current image, its filename and the segmentation strength
    (``seg_binary``) from ``st.session_state``, runs ``extract_canvas`` on the
    canvas objects and, on success, writes:

    * ``json_data_tmp`` -- the raw canvas JSON,
    * ``json_out[cur_i]`` -- ``{"filename": ..., "objects": [...]}``,
    * ``cropped_imgs[cur_i]`` -- the segmented crops.

    The update is best-effort. When the canvas has no drawing data yet, the
    function returns without touching the state. When extraction fails, a
    Streamlit warning is shown instead of an exception propagating.

    Args:
        canvas: the value returned by ``st_canvas`` (see ``show_canvas``);
            only its ``json_data`` attribute is read.
    """
    cur_i = st.session_state.cur_i
    img_pil = st.session_state.pil_imgs[cur_i]
    filename = st.session_state.file_imgs[cur_i].name
    strength = st.session_state.seg_binary

    json_data = canvas.json_data
    objects = json_data.get("objects") if json_data else None
    if objects is None:
        # No drawing data reported yet (presumably the case the old
        # "Buffering..." handler was meant for). Nothing to extract, so the
        # previously stored state is kept as-is.
        return

    try:
        json_objs, cropped_img = extract_canvas(objects, img_pil, strength)
        st.session_state.json_data_tmp = json_data
        st.session_state.json_out[cur_i] = {
            "filename": filename,
            "objects": json_objs,
        }
        st.session_state.cropped_imgs[cur_i] = cropped_img
    except Exception as err:  # best-effort: a bad shape must not break the page
        # The previous handler called st.spinner() outside a `with` block,
        # which displays nothing, so failures vanished silently. Surface them.
        st.warning(f"Could not extract the drawn regions of {filename}: {err}")
def extract_canvas(objects, img_pil, seg_binary=0):
    """Crop, segment and summarise every object drawn on the canvas.

    Args:
        objects: list of canvas object dicts, i.e. ``canvas.json_data["objects"]``.
            Each entry is read for ``left``, ``top``, ``width``, ``height``,
            ``scaleX`` and ``scaleY``.
        img_pil: the image the objects were drawn over; each object's box is
            cropped out of it.
        seg_binary: strength passed straight through to ``thresholding``.

    Returns:
        A ``(json_objs, cropped_imgs)`` tuple of lists aligned with *objects*.

        * ``json_objs[i]`` holds object ``i``'s box (``left``, ``top``, and the
          scaled ``width``/``height``) plus the ``red``/``green``/``blue``
          values that ``avg_rgb`` returns for its segmented crop.
        * ``cropped_imgs[i]`` is that segmented crop.
    """
    summaries = []
    segmented_crops = []
    for obj in objects:
        left = obj["left"]
        top = obj["top"]
        # Apply the object's scale factors so resized shapes report the size
        # they occupy on the canvas.
        width = obj["width"] * obj["scaleX"]
        height = obj["height"] * obj["scaleY"]
        # NOTE(review): obj["angle"] is never read, so a rotated shape (if the
        # "transform" drawing mode allows rotation) is cropped as if unrotated.

        crop = img_pil.crop((left, top, left + width, top + height))
        segmented = thresholding(crop, seg_binary)
        segmented_crops.append(segmented)

        red, green, blue = avg_rgb(segmented)
        summaries.append(
            {
                "left": left,
                "top": top,
                "width": width,
                "height": height,
                "red": red,
                "green": green,
                "blue": blue,
            }
        )
    return summaries, segmented_crops