Source code for l2l.optimizees.functions.tools
[docs]def plot(fn, random_state):
"""
Implements plotting of 2D functions generated by FunctionGenerator
:param fn: Instance of FunctionGenerator
"""
import numpy as np
from l2l.matplotlib_ import plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
fig = plt.figure()
ax = fig.gca(projection=Axes3D.name)
# Make data.
X = np.arange(fn.bound[0], fn.bound[1], 0.05)
Y = np.arange(fn.bound[0], fn.bound[1], 0.05)
XX, YY = np.meshgrid(X, Y)
Z = [fn.cost_function([x, y], random_state=random_state) for x, y in zip(XX.ravel(), YY.ravel())]
Z = np.array(Z).reshape(XX.shape)
# Plot the surface.
surf = ax.plot_surface(XX, YY, Z, cmap=cm.coolwarm, linewidth=0, antialiased=False)
# Customize the z axis.
# ax.set_zlim(-1.01, 1.01)
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
W = np.where(Z == np.min(Z))
ax.set(title='Min value is %.2f at (%.2f, %.2f)' % (np.min(Z), X[W[0]], Y[W[1]]))
# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.savefig('function.png')
plt.show()