Spectral clustering is a clustering technique that can be used to segment images. The scikit-learn spectral_clustering() function implements the normalized graph cuts spectral clustering algorithm. This algorithm represents an image as a graph of connected pixels. "Graph" here is the same mathematical concept as in Chapter 8, Text Mining and Social Network Analysis. The algorithm then partitions the graph by choosing cuts that balance two goals. First, the cuts should run along strong intensity gradients, where neighboring pixels are only weakly connected. Second, the resulting segments should not be too small, because the cost of each cut is normalized by the volume of the segments it creates.
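A minimal NumPy sketch can make the normalized-cut idea concrete. It handles the two-way case on a hypothetical six-node graph. The weight matrix W and the helper names normalized_cut_bipartition and ncut_value are made up for illustration.

The sketch follows the spectral relaxation of the normalized cut described by Shi and Malik. It takes the eigenvector for the second-smallest eigenvalue of the normalized graph Laplacian and splits the nodes by its sign. It is not scikit-learn's implementation, which embeds the nodes using several eigenvectors and then assigns labels with k-means or discretization.

```python
import numpy as np

def normalized_cut_bipartition(W):
    """Two-way split via the spectral relaxation of the normalized cut:
    threshold the second eigenvector of the normalized graph Laplacian."""
    d = W.sum(axis=1)                                   # node degrees
    d_inv_sqrt = 1.0 / np.sqrt(d)
    # Symmetric normalized Laplacian: L = I - D^(-1/2) W D^(-1/2)
    L = np.eye(len(W)) - d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]
    _, eigvecs = np.linalg.eigh(L)                      # eigenvalues in ascending order
    # Rescale by D^(-1/2) to get the solution of (D - W) y = lambda D y
    y = d_inv_sqrt * eigvecs[:, 1]
    return (y > 0).astype(int)                          # split the nodes by sign

def ncut_value(W, labels):
    """Ncut(A, B) = cut(A, B) / vol(A) + cut(A, B) / vol(B)."""
    A = labels == 0
    B = ~A
    cut = W[np.ix_(A, B)].sum()                         # weight of edges crossing the cut
    return cut / W[A].sum() + cut / W[B].sum()          # vol = sum of degrees in a segment

# Hypothetical graph: triangles {0, 1, 2} and {3, 4, 5} joined by one weak edge (2-3)
W = np.array([
    [0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 0.0, 0.1, 0.0, 0.0],
    [0.0, 0.0, 0.1, 0.0, 1.0, 1.0],
    [0.0, 0.0, 0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
])

labels = normalized_cut_bipartition(W)
print(labels)                                         # nodes 0-2 vs. nodes 3-5 (label order may flip)
print(ncut_value(W, labels))                          # cost of the spectral split
print(ncut_value(W, np.array([0, 1, 1, 1, 1, 1])))    # cost of cutting off node 0 alone
```

Eigenvectors are only defined up to sign, so which group receives label 0 can vary. Nodes 0 to 2 always end up together, however. The spectral split has an Ncut of about 0.033, far lower than the cost of splitting off node 0 on its own.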
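import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_extraction.image import img_to_graph
from sklearn.cluster import spectral_clustering
from sklearn.datasets import load_digits

# Take the first 8 x 8 digit image; the mask selects its non-zero (ink) pixels
digits = load_digits()
img = digits.images[0].astype(float)
mask = img.astype(bool)

# Build the pixel graph (edge values are intensity gradients between
# neighboring pixels), then turn the gradients into similarities
graph = img_to_graph(img, mask=mask)
graph.data = np.exp(-graph.data / graph.data.std())

# Split the masked pixels into 3 clusters; background pixels keep the label -1
labels = spectral_clustering(graph, n_clusters=3)
label_im = -np.ones(mask.shape)
label_im[mask] = labels

# fignum=False draws into the current axes instead of opening a new figure
plt.matshow(img, fignum=False)
plt.gca().axis('off')
plt.title('Original')

plt.figure()
plt.matshow(label_im, fignum=False)
plt.gca().axis('off')
plt.title('Clustered')
plt.show()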
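The graph construction in the listing can be checked on a tiny example. The sketch below uses a hypothetical 2 x 2 image without a mask. For each pair of horizontally or vertically neighboring pixels, img_to_graph() stores the absolute intensity difference between them. The np.exp(-data / std) line from the listing then turns those gradients into similarities. As a result, pixels separated by a large intensity step are only weakly linked, and cutting between them is cheap.

```python
import numpy as np
from sklearn.feature_extraction.image import img_to_graph

# Hypothetical 2 x 2 image. Pixels are numbered row by row:
# 0=(0,0), 1=(0,1), 2=(1,0), 3=(1,1)
tiny = np.array([[0.0, 1.0],
                 [3.0, 7.0]])

# Sparse pixel graph: for horizontal/vertical neighbors i and j,
# entry (i, j) holds the absolute intensity difference between them
graph = img_to_graph(tiny)
gradients = graph.toarray()

# Same transform as the listing: each stored value g becomes exp(-g / std),
# so small gradients give weights near 1 and large gradients weights near 0
scale = graph.data.std()
graph.data = np.exp(-graph.data / scale)
affinity = graph.toarray()

print(gradients[0, 1], gradients[1, 3])   # neighbors: |0 - 1| and |1 - 7|
print(affinity[0, 1] > affinity[1, 3])    # the larger step gives the weaker link
print(affinity[0, 3])                     # diagonally adjacent pixels are not connected
```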
Refer to the following screenshot for the end result:
The code is in the clustering_spectral.ipynb file in this book's code bundle.
The spectral_clustering() function is documented at http://scikit-learn.org/stable/modules/generated/sklearn.cluster.spectral_clustering.html (retrieved December 2015).
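The same segmentation can also be written with the SpectralClustering estimator class, passing the weighted graph as a precomputed affinity matrix. This sketch assumes a scikit-learn version in which SpectralClustering accepts sparse input when affinity='precomputed'. The random_state=0 argument is added here only to make the k-means label assignment repeatable; it is not part of the original listing.

```python
import numpy as np
from sklearn.cluster import SpectralClustering
from sklearn.datasets import load_digits
from sklearn.feature_extraction.image import img_to_graph

# Same preprocessing as the listing: first digit image, non-zero pixels only
img = load_digits().images[0].astype(float)
mask = img.astype(bool)
graph = img_to_graph(img, mask=mask)
graph.data = np.exp(-graph.data / graph.data.std())

# affinity='precomputed' tells the estimator that `graph` already holds
# pairwise similarities between the masked pixels
model = SpectralClustering(n_clusters=3, affinity='precomputed', random_state=0)
labels = model.fit_predict(graph)

# Background pixels keep the label -1, as in the listing
label_im = -np.ones(mask.shape)
label_im[mask] = labels
print(label_im)
```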