I have a project at work where I need to find the nearest neighbor for a set of points on a map. We store these points in SQL Server, which has spatial query support.
It looks like scikit-learn got support for nearest neighbors using Haversine distance (what we need to use for measuring map distance) earlier in 2019 (https://github.com/scikit-learn/scikit-learn/pull/13289), but I just needed something quick, and keeping all the work in SQL Server was ideal in this case.
Then I found this post, which does a lovely job of showing how we can perform some simple calculations to reduce the set of points that we need to measure distance between:
The post describes how to find all the nearest neighbors for a single point. In my case, I have a large set of points, and for each one, I want its nearest neighbors within the same set. So, how can we modify the query to handle this?
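For reference, the single-point case can be sketched in T-SQL roughly as follows. This is not the linked post’s query, just an approximation of the same idea: a cheap latitude/longitude bounding box narrows the candidates before the exact distance is measured. The table variable @locations, the sample coordinates, and the variable names are all made up for illustration. The constant 111.045 is the approximate number of kilometers per degree of latitude, and a degree of longitude shrinks by the cosine of the latitude.

```sql
-- Hypothetical sample data standing in for a real locations table.
DECLARE @locations TABLE (Id INT PRIMARY KEY, [Location] GEOGRAPHY NOT NULL);
INSERT INTO @locations (Id, [Location]) VALUES
    (1, geography::Point(47.6062, -122.3321, 4326)),  -- Seattle
    (2, geography::Point(47.6101, -122.2015, 4326)),  -- Bellevue
    (3, geography::Point(47.2529, -122.4443, 4326)),  -- Tacoma
    (4, geography::Point(45.5152, -122.6784, 4326));  -- Portland

DECLARE @originId INT = 1;          -- The single point we search around
DECLARE @radiusKm FLOAT = 16.0934;  -- Search radius (~= 10 miles)
DECLARE @origin GEOGRAPHY = (SELECT [Location] FROM @locations WHERE Id = @originId);

-- Half-widths of the bounding box, in degrees.
DECLARE @latDelta FLOAT = @radiusKm / 111.045;
DECLARE @longDelta FLOAT = @radiusKm / (111.045 * COS(RADIANS(@origin.Lat)));

SELECT TOP (5)
    l.Id,
    @origin.STDistance(l.[Location]) AS DistanceM
FROM @locations AS l
WHERE l.Id <> @originId
    -- Cheap pre-filter: keep only points inside the bounding box
    AND l.[Location].Lat BETWEEN @origin.Lat - @latDelta AND @origin.Lat + @latDelta
    AND l.[Location].Long BETWEEN @origin.Long - @longDelta AND @origin.Long + @longDelta
    -- Exact check: the corners of the box lie outside the radius
    AND @origin.STDistance(l.[Location]) <= @radiusKm * 1000
ORDER BY DistanceM;
```

With a single origin point, one TOP (N) over the whole result is enough, which is exactly what stops working once every point in the set becomes an origin.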
A cross join is an easy way of pairing every point with every other point in the set. But that introduces a new problem: now we can’t simply use a LIMIT N clause as shown in the blog post (or TOP N in my case, since we’re using SQL Server), because that would cap the entire result set, and we’ll be getting back the nearest neighbors for every point in the set. We need to get the TOP N per point.
To address that, we can use a window function to rank the nearest neighbors for each point, and then keep only the rows whose rank is within the top N. The resulting query ends up looking like this:
DECLARE @projectId INT = 123; -- Set this to the project ID of interest
DECLARE @topRankLimit INT = 5; -- Show at most N nearest neighbors for each location
DECLARE @distanceThresholdKM FLOAT = 16.0934; -- Only show nearest neighbors within 16.0934km (~= 10 miles)

WITH nn AS (
    SELECT
        loc.Id,
        loc.[Location].Lat AS Lat,
        loc.[Location].Long AS Long,
        nearest_neighbor_loc.Id AS NearestNeighborId,
        nearest_neighbor_loc.[Location].Lat AS NearestNeighborLat,
        nearest_neighbor_loc.[Location].Long AS NearestNeighborLong,
        loc.[Location].STDistance(nearest_neighbor_loc.[Location]) AS DistanceM,
        -- Rank each point's neighbors from nearest to farthest
        ROW_NUMBER() OVER (
            PARTITION BY loc.Id
            ORDER BY loc.Id, loc.[Location].STDistance(nearest_neighbor_loc.[Location])
        ) AS DistanceRank
    FROM ProjectLocations AS loc
    CROSS JOIN ProjectLocations AS nearest_neighbor_loc
    WHERE loc.Id <> nearest_neighbor_loc.Id
        -- Bounding box pre-filter: 111.045 km per degree of latitude
        AND nearest_neighbor_loc.[Location].Lat
            BETWEEN (loc.[Location].Lat - (@distanceThresholdKM / 111.045))
            AND (loc.[Location].Lat + (@distanceThresholdKM / 111.045))
        AND nearest_neighbor_loc.[Location].Long
            BETWEEN (loc.[Location].Long - (@distanceThresholdKM / (111.045 * COS(RADIANS(loc.[Location].Lat)))))
            AND (loc.[Location].Long + (@distanceThresholdKM / (111.045 * COS(RADIANS(loc.[Location].Lat)))))
        AND loc.[Location].STDistance(nearest_neighbor_loc.[Location]) IS NOT NULL
)
SELECT *
FROM nn
WHERE DistanceRank <= @topRankLimit
    AND (DistanceM / 1000.0) <= @distanceThresholdKM
ORDER BY Id, DistanceM DESC
;
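The ranking step is easier to see in isolation. Here is a minimal sketch that leaves out the spatial parts and works on a handful of made-up, precomputed distances. The table variable @pairs and its values are hypothetical and exist only to show how ROW_NUMBER() with PARTITION BY yields a top N per point.

```sql
-- Hypothetical precomputed distances between pairs of points (not real data).
DECLARE @pairs TABLE (Id INT, NeighborId INT, DistanceM FLOAT);
INSERT INTO @pairs (Id, NeighborId, DistanceM) VALUES
    (1, 2, 500.0), (1, 3, 1200.0), (1, 4, 300.0),
    (2, 1, 500.0), (2, 3, 900.0),  (2, 4, 700.0);

DECLARE @topRankLimit INT = 2;  -- Keep the 2 nearest neighbors per point

WITH ranked AS (
    SELECT
        Id,
        NeighborId,
        DistanceM,
        -- Numbering restarts at 1 for each Id; nearest neighbor gets rank 1
        ROW_NUMBER() OVER (PARTITION BY Id ORDER BY DistanceM) AS DistanceRank
    FROM @pairs
)
SELECT Id, NeighborId, DistanceM, DistanceRank
FROM ranked
WHERE DistanceRank <= @topRankLimit
ORDER BY Id, DistanceRank;
```

For this sample data, point 1 keeps neighbors 4 and 2, and point 2 keeps neighbors 1 and 4. The rank has to be computed in a CTE (or subquery) because a window function can’t be referenced directly in the WHERE clause of the same SELECT.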
I plan to update this post soon with some additional discussion and explanation of the query.