3. Visualization

3.1 Visualization of the training set

In [15]:
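import matplotlib.pyplot as plt
from matplotlib import cm

# Color each training point by its label (y_train[0, :] gives one label per point).
plt.figure(figsize=(12, 7))
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train[0, :], cmap=cm.coolwarm)
plt.title('Training set')
plt.axis('equal');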

3.2 Visualization of the model predictions on our test set

In [16]:
# Color each test point by the score the model predicted for it (one output unit, so column 0).
plt.figure(figsize=(12, 7))
plt.scatter(X_test[:, 0], X_test[:, 1], c=prediction_values[:, 0], cmap=cm.coolwarm)
plt.title('Model predictions on the test set')
plt.axis('equal');
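The colors in the plot above encode the raw output score, not a hard class. If the output unit is a sigmoid (an assumption; the activation is not shown in this section), thresholding the scores at 0.5 turns them into 0/1 labels. The helper name to_class_labels below is hypothetical, and the scores are made-up stand-ins for prediction_values.

```python
import numpy as np

def to_class_labels(probabilities, threshold=0.5):
    """Map scores of shape (m, 1) or (m,) to hard 0/1 labels.

    ASSUMPTION: scores are sigmoid probabilities, so 0.5 is the natural cut-off.
    """
    scores = np.asarray(probabilities).reshape(-1)  # flatten (m, 1) -> (m,)
    return (scores >= threshold).astype(int)

# Hypothetical scores standing in for `prediction_values`.
example_scores = np.array([[0.03], [0.48], [0.51], [0.97]])
print(to_class_labels(example_scores))  # [0 0 1 1]

# In the notebook, the hard labels could color the scatter plot instead, e.g.
# plt.scatter(X_test[:, 0], X_test[:, 1], c=to_class_labels(prediction_values), cmap=cm.coolwarm)
```

Coloring by the raw score, as the cell above does, keeps more information: points near 0.5 show up in pale colors and mark where the model is uncertain.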

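To see how the model separates the classes, we evaluate it on a regular grid covering the input region, draw the predicted scores as a filled contour, and overlay the test-set predictions on top.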

In [17]:
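# Evaluate the model on a 40 x 40 grid covering the input region.
xx = np.linspace(-2, 2, 40)
yy = np.linspace(-1.5, 1.5, 40)
gx, gy = np.meshgrid(xx, yy)
Z = model.predict(np.c_[gx.ravel(), gy.ravel()])  # (1600, 2) points -> (1600, 1) scores
Z = Z.reshape(gx.shape)                           # back to the (40, 40) grid

# Filled contour of the scores: the color change marks the decision boundary.
plt.figure(figsize=(12, 7))
plt.contourf(gx, gy, Z, cmap=cm.coolwarm, alpha=0.8)

axes = plt.gca()
axes.set_xlim([-2, 2])
axes.set_ylim([-1.5, 1.5])
plt.grid(False)   # 'off' as a string is not a valid argument in recent Matplotlib
plt.axis('off')

# Overlay the test-set predictions on the decision surface.
plt.scatter(X_test[:, 0], X_test[:, 1], c=prediction_values[:, 0], cmap=cm.coolwarm)
plt.title('Model predictions on the test set');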
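The cell above relies on a flatten/reshape round trip: meshgrid builds two (40, 40) coordinate arrays, np.c_ with ravel stacks them into a (1600, 2) batch that predict accepts, and reshape folds the scores back onto the grid. The sketch below checks that round trip without a trained model; fake_predict is a hypothetical stand-in for model.predict that simply scores points inside the unit circle as 1.

```python
import numpy as np

# Same grid as in the cell above: 40 x 40 points over [-2, 2] x [-1.5, 1.5].
xx = np.linspace(-2, 2, 40)
yy = np.linspace(-1.5, 1.5, 40)
gx, gy = np.meshgrid(xx, yy)                  # both have shape (40, 40)

# Flatten the grid into a batch of (x, y) points, one point per row.
points = np.c_[gx.ravel(), gy.ravel()]

def fake_predict(batch):
    """Hypothetical stand-in for model.predict: 1.0 inside the unit circle, else 0.0."""
    inside = (batch[:, 0] ** 2 + batch[:, 1] ** 2) < 1.0
    return inside.astype(float).reshape(-1, 1)  # shape (n, 1), like a one-unit Keras output

Z = fake_predict(points).reshape(gx.shape)    # back to (40, 40) for contourf

# With meshgrid's default 'xy' indexing, Z[i, j] is the score at (xx[j], yy[i]).
print(points.shape, Z.shape)  # (1600, 2) (40, 40)
```

Because the ordering is preserved, Z lines up with gx and gy cell by cell. To draw only the boundary line rather than a filled surface, plt.contour(gx, gy, Z, levels=[0.5]) would trace the 0.5 level, assuming the model outputs sigmoid probabilities.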

3.3 Display the weights and biases of our model

In [18]:
# Hidden layer: model.layers[0] is the first Dense layer (the input itself has no weights)
weights0, biases0 = model.layers[0].get_weights()
print("Hidden layer weights", weights0.shape, ":\n", weights0)
print("Hidden layer biases", biases0.shape, ":\n", biases0)

# Output layer
weights1, biases1 = model.layers[1].get_weights()
print("\nOutput layer weights", weights1.shape, ":\n", weights1)
print("Output layer biases", biases1.shape, ":\n", biases1)
Hidden layer weights (2, 8) :
 [[ 1.7777659   0.10601337  1.8901554  -2.580172   -1.6185337   2.404422
   0.09642663 -1.4703599 ]
 [-2.3700235  -0.6269098   2.8601322  -1.181832    2.387554   -0.8151907
  -0.5713202  -2.3834863 ]]
Hidden layer biases (8,) :
 [-0.3669517   1.6662995  -0.4583802  -0.27512434 -0.36231026 -0.21786891
  1.5253848  -0.37401372]

Output layer weights (8, 1) :
 [[-3.6675723]
 [ 2.4735324]
 [-3.0899823]
 [-3.3958952]
 [-3.737522 ]
 [-3.9193368]
 [ 2.7542424]
 [-3.6022592]]
Output layer biases (1,) :
 [1.0367551]
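A Keras Dense layer computes activation(x @ kernel + bias), so the printed arrays are enough to recompute the network by hand: a (2, 8) kernel maps the 2 inputs to 8 hidden units, and an (8, 1) kernel maps those to the single output, for 2·8 + 8 + 8·1 + 1 = 33 parameters in total. The sketch below copies the printed values. The activations are assumptions, since they are not shown in this section: ReLU (or tanh) for the hidden layer and a sigmoid on the output.

```python
import numpy as np

# Parameters copied from the printed output above.
W0 = np.array([
    [1.7777659, 0.10601337, 1.8901554, -2.580172, -1.6185337, 2.404422, 0.09642663, -1.4703599],
    [-2.3700235, -0.6269098, 2.8601322, -1.181832, 2.387554, -0.8151907, -0.5713202, -2.3834863],
])                                                    # (2, 8): 2 inputs -> 8 hidden units
b0 = np.array([-0.3669517, 1.6662995, -0.4583802, -0.27512434,
               -0.36231026, -0.21786891, 1.5253848, -0.37401372])   # (8,)
W1 = np.array([[-3.6675723], [2.4735324], [-3.0899823], [-3.3958952],
               [-3.737522], [-3.9193368], [2.7542424], [-3.6022592]])  # (8, 1)
b1 = np.array([1.0367551])                            # (1,)

# ASSUMPTION: candidate hidden activations; pick the one the model was built with.
ACTIVATIONS = {
    "relu": lambda z: np.maximum(z, 0.0),
    "tanh": np.tanh,
}

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(X, hidden_activation="relu"):
    """Recompute the network output for an (m, 2) batch X.

    ASSUMPTION: hidden activation defaults to ReLU and the output unit is a sigmoid.
    """
    hidden = ACTIVATIONS[hidden_activation](X @ W0 + b0)  # (m, 8)
    return sigmoid(hidden @ W1 + b1)                       # (m, 1)

X_demo = np.array([[0.0, 0.0], [1.5, -1.0]])
print(forward(X_demo))
```

Comparing forward(X_test) with model.predict(X_test) using np.allclose is a quick way to confirm whether the assumed activations match the trained model; if they do not, switching hidden_activation is the first thing to try.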