diff --git a/feat/plotting.py b/feat/plotting.py index fa7c736c..53e7773b 100644 --- a/feat/plotting.py +++ b/feat/plotting.py @@ -329,7 +329,7 @@ def draw_vectorfield( return ax -def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs): +def draw_muscles(currx, curry, au=None, ax=None, cmap="Blues", *args, **kwargs): """Draw Muscles Args: @@ -756,7 +756,7 @@ def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs): del kwargs[muscle] for muscle in todraw.keys(): if todraw[muscle] == "heatmap": - muscles[muscle].set_color(get_heat(muscle, au, facet)) + muscles[muscle].set_color(get_heat(muscle, au, facet, cmap)) else: muscles[muscle].set_color(todraw[muscle]) ax.add_patch(muscles[muscle], *args, **kwargs) @@ -805,7 +805,7 @@ def draw_muscles(currx, curry, au=None, ax=None, *args, **kwargs): return ax -def get_heat(muscle, au, log): +def get_heat(muscle, au, log, cmap="Blues"): """Function to create heatmap from au vector Args: @@ -817,7 +817,7 @@ def get_heat(muscle, au, log): Returns: color of muscle according to its au value """ - q = sns.color_palette("Blues", 151) + q = sns.color_palette(cmap, 151) unit = 0 aus = { "masseter_l": 15, @@ -882,6 +882,7 @@ def plot_face( muscles=None, ax=None, feature_range=False, + cmap="Blues", color="k", linewidth=1, linestyle="-", @@ -937,7 +938,7 @@ def plot_face( au = minmax_scale(au, feature_range=(0, 100 * muscle_scaler)) else: au = muscle_scaler.transform(np.array(au).reshape(-1, 1)).squeeze() - ax = draw_muscles(currx, curry, ax=ax, au=au, **muscles) + ax = draw_muscles(currx, curry, cmap, ax=ax, au=au, **muscles) if gaze is not None and len((gaze)) != 4: warnings.warn(