Another way to make things more performant, and arguably the best one in general, is to use the right data structures and algorithms; in other words, we should design our code better and pick the right tool for the job in the first place. In our case, any spatial query, especially on a large dataset, will benefit from a spatial index. Essentially, a spatial index builds a hierarchical structure based on the spatial distribution of the data itself, so that distances only need to be measured within a small subset of records. Let's try to make use of one in our model:
import numpy as np
from scipy.spatial import cKDTree

class kdNearestNeighbor:
    _kd = None
    y = None

    def __init__(self, N=3):
        self.N = N

    def fit(self, X, y):
        # build a k-d tree index over the training points
        self._kd = cKDTree(X, leafsize=2 * self.N)
        self.y = y

    def predict(self, X):
        # query the index for the N nearest training points
        d, closest = self._kd.query(X, k=self.N)
        # average the targets of those neighbors
        return np.mean(np.take(self.y, closest), axis=1)
As you can see, the code is now even simpler: cKDTree takes care of most of the actual logic behind the scenes. Note that it also exposes a fair number of parameters, which we could tune for an additional performance gain on a specific dataset. But how does it perform? Let's take a look at the following code:
>>> kdKNN = kdNearestNeighbor(N=5)
>>> kdKNN.fit(Xtrain.values, ytrain.values)
>>> %timeit _ = kdKNN.predict(Xtv)
11.3 ms ± 237 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
11.3 ms is less than one percent of our initial runtime! Of course, there is a small catch: cKDTree builds its index during fit. Because of this, the fit method will take considerably longer to run, but most of the time, this is a trade-off we're happy to make.
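To make both points concrete, here is a minimal sketch on synthetic data (the variable names, sizes, and parameter values below are illustrative assumptions, not the chapter's dataset). It builds a tree with explicit construction parameters, queries it in parallel, and cross-checks the neighbor averages against a brute-force search:

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
Xtr = rng.random((1000, 2))   # synthetic training features
ytr = rng.random(1000)        # synthetic targets
Xte = rng.random((10, 2))     # synthetic query points

k = 5
# balanced_tree=True (the default) splits nodes at the median,
# paying more at build time for faster queries; leafsize sets the
# point count at which the tree falls back to brute force.
tree = cKDTree(Xtr, leafsize=2 * k, balanced_tree=True)

# workers=-1 spreads the query over all CPU cores (SciPy >= 1.6)
_, closest = tree.query(Xte, k=k, workers=-1)
pred_tree = np.mean(np.take(ytr, closest), axis=1)

# brute force: full pairwise distance matrix, then the k smallest
dists = np.linalg.norm(Xte[:, None, :] - Xtr[None, :, :], axis=2)
closest_bf = np.argsort(dists, axis=1)[:, :k]
pred_bf = np.mean(np.take(ytr, closest_bf), axis=1)

print(np.allclose(pred_tree, pred_bf))  # True
```

With continuous random features there are no distance ties, so the tree and the brute-force search must pick the same neighbors, and the averaged predictions agree exactly.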
Here are a couple of resources on spatial indexes and other algorithms and data structures in Python:
- Spatial Range Queries Using Python In-Memory Indices, by Alexander Müller: https://www.youtube.com/watch?v=_95bSEqMzUA
- Python Data Structures and Algorithms, by Benjamin Baka: https://www.packtpub.com/application-development/python-data-structures-and-algorithms