Skip to content

Commit

Permalink
Merge pull request scikit-optimize#898 from kartikayyer/master
Browse files Browse the repository at this point in the history
[MRG] Add options to plot_objective
  • Loading branch information
holgern authored May 18, 2020
2 parents 3d1610c + 0d9c4f7 commit bf6e244
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions skopt/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def _format_scatter_plot_axes(ax, space, ylabel, plot_dims,

else: # diagonal plots
ax_.set_ylim(*diagonal_ylim)
if not iscat[i]:
low, high = dim_i.bounds
ax_.set_xlim(low, high)
ax_.yaxis.tick_right()
ax_.yaxis.set_label_position('right')
ax_.yaxis.set_ticks_position('both')
Expand Down Expand Up @@ -537,7 +540,8 @@ def partial_dependence(space, model, i, j=None, sample_points=None,

def plot_objective(result, levels=10, n_points=40, n_samples=250, size=2,
zscale='linear', dimensions=None, sample_source='random',
minimum='result', n_minimum_search=None, plot_dims=None):
minimum='result', n_minimum_search=None, plot_dims=None,
show_points=True, cmap='viridis_r'):
"""Plot a 2-d matrix with so-called Partial Dependence plots
of the objective function. This shows the influence of each
search-space dimension on the objective function.
Expand Down Expand Up @@ -586,7 +590,7 @@ def plot_objective(result, levels=10, n_points=40, n_samples=250, size=2,
levels : int, default=10
Number of levels to draw on the contour plot, passed directly
to `plt.contour()`.
to `plt.contourf()`.
n_points : int, default=40
Number of points at which to evaluate the partial dependence
Expand Down Expand Up @@ -654,6 +658,14 @@ def plot_objective(result, levels=10, n_points=40, n_samples=250, size=2,
`sample_source` and/or `minimum` is set to
'expected_minimum' or 'expected_minimum_random'.
show_points: bool, default = True
Choose whether to show evaluated points in the
contour plots.
cmap: str or Colormap, default = 'viridis_r'
Color map for contour plots. Passed directly to
`plt.contourf()`
Returns
-------
ax : `Matplotlib.Axes`
Expand Down Expand Up @@ -717,7 +729,7 @@ def plot_objective(result, levels=10, n_points=40, n_samples=250, size=2,
else:
ax_ = ax
ax_.plot(xi, yi)
ax_.axvline(minimum[i], linestyle="--", color="r", lw=1)
ax_.axvline(minimum[index], linestyle="--", color="r", lw=1)

# lower triangle
elif i > j:
Expand All @@ -728,9 +740,10 @@ def plot_objective(result, levels=10, n_points=40, n_samples=250, size=2,
index1, index2,
samples, n_points)
ax_.contourf(xi, yi, zi, levels,
locator=locator, cmap='viridis_r')
ax_.scatter(x_samples[:, index2], x_samples[:, index1],
c='k', s=10, lw=0.)
locator=locator, cmap=cmap)
if show_points:
ax_.scatter(x_samples[:, index2], x_samples[:, index1],
c='k', s=10, lw=0.)
ax_.scatter(minimum[index2], minimum[index1],
c=['r'], s=100, lw=0., marker='*')
ylabel = "Partial dependence"
Expand Down

0 comments on commit bf6e244

Please sign in to comment.