What is the KNN Algorithm?
KNN, or K-Nearest Neighbors, is a classification algorithm in supervised machine learning. It is widely used in pattern recognition, data mining, and intrusion detection. KNN is nonparametric and a lazy learner, first developed in 1951 by Evelyn Fix and Joseph Hodges.
Applications
- Pattern recognition
- Data preprocessing
- Recommendation Engines
- Predictive analytics
How does it work?
Given a test point, KNN predicts its class by calculating the distance between the test point and the training points.
If you had a large catalog of plant photos and used it as training data, you could identify any new plant photo you capture in the future, provided that plant is already registered in the catalog.
When a new data point is entered, it is compared with its K nearest neighbors and classified accordingly.
KNN derives the classification from the distances between the new data point and the training points of each category, typically using the Euclidean distance (in two dimensions, sqrt((x1 - x2)^2 + (y1 - y2)^2); for example, the distance between (2, 3) and (3, 3) is sqrt(1 + 0) = 1). The closer the new data point is to the points of a given category, the more likely it belongs to that category, and the final label is the majority class among the k nearest neighbors.
Code Snippets
import math

def euclidean_distance(point_a, point_b):
    # Straight-line distance between two points of equal dimension
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(point_a, point_b)))

def classifyTestData(data, test, k=3):
    # Create distance array
    distance_array = []
    # For each class in data:
    for class_label in data:
        # For each data point in that class:
        for dataPoint in data[class_label]:
            # Calculate Euclidean distance of test from the training point
            distance = euclidean_distance(test, dataPoint)
            # Append (euclidean distance, class) to distance array
            distance_array.append((distance, class_label))
    # Sort list in ascending order of distance
    distance_array.sort()
    # Select the k nearest neighbors
    nearest_neighbors = distance_array[:k]
    # Create a frequency counter for each class in data, initialized to 0
    frequency = {class_label: 0 for class_label in data}
    # Tally the class of each of the k nearest neighbors
    for distance, class_label in nearest_neighbors:
        frequency[class_label] += 1
    # Return the class with the highest frequency
    max_class = max(frequency, key=frequency.get)
    return max_class
Example Usage:
def main():
    # Create a dictionary of classes with data points as tuples
    data = {
        0: [(1, 2), (2, 3)],
        1: [(3, 4), (4, 5)],
        2: [(5, 6), (6, 7)]
    }
    # Create a testing point variable, test
    test = (3, 3)
    # Create a variable k for the number of neighbors
    k = 3
    # Print result returned from classifyTestData function
    result = classifyTestData(data, test, k)
    print(f"The classified class for the test point {test} is: {result}")

main()
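In practice, a library implementation is usually preferred over a hand-rolled one. The sketch below shows how the same classification might look with scikit-learn's KNeighborsClassifier (this assumes scikit-learn is installed; the flattened feature and label lists mirror the dictionary used above):

from sklearn.neighbors import KNeighborsClassifier

# Flatten the class -> points dictionary into feature/label arrays
X = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
y = [0, 0, 1, 1, 2, 2]

# Fit a KNN classifier with the same k as above
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

# Predict the class of the same test point
print(knn.predict([(3, 3)]))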
KD Trees
- K-dimensional tree: the data in each node is a k-dimensional point in space
- Constructed by recursively splitting the data along one dimension at a time
- The dimension used for splitting alternates with each level of the tree, e.g. the root splits its children along the x-axis, those children split along the y-axis, the next level along the z-axis, and so on (see the construction sketch after this list)
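As a rough illustration of that alternating split, here is a minimal sketch of KD tree construction; the KDNode class and build_kd_tree function are hypothetical names used for this example, not taken from a particular library:

class KDNode:
    def __init__(self, point, left=None, right=None):
        self.point = point    # the k-dimensional point stored at this node
        self.left = left      # subtree of points below the splitting value
        self.right = right    # subtree of points above the splitting value

def build_kd_tree(points, depth=0):
    # Recursively split the data along one dimension,
    # cycling through the dimensions as the depth increases
    if not points:
        return None
    k = len(points[0])         # dimensionality of the data
    axis = depth % k           # axis alternates: x, y, z, x, ...
    points = sorted(points, key=lambda p: p[axis])
    median = len(points) // 2  # the median point becomes the splitting node
    return KDNode(
        point=points[median],
        left=build_kd_tree(points[:median], depth + 1),
        right=build_kd_tree(points[median + 1:], depth + 1),
    )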
Ball Trees
- More useful than a KD tree in higher-dimensional spaces
- One ball (the root node) encompasses all data points and is split into smaller balls (child nodes) nested within it, each defined by its own center and radius (see the query sketch after this list)
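scikit-learn ships a ready-made ball tree. The sketch below, which assumes scikit-learn and NumPy are available, builds one over the sample points from the earlier example and queries the k = 3 nearest neighbors of the test point:

import numpy as np
from sklearn.neighbors import BallTree

# Training points from the earlier example
X = np.array([(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])

# Build the ball tree over the training points
tree = BallTree(X)

# Query the 3 nearest neighbors of the test point (3, 3)
distances, indices = tree.query(np.array([(3, 3)]), k=3)
print(indices)    # indices of the 3 closest training points
print(distances)  # their distances from the test point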
Complexity
Time:
- Best case: O(log n)
- Average case: O(log n): the search space is halved at each level of a balanced KD tree as the algorithm descends
- Worst case: O(n*d): in high-dimensional spaces the tree can become unbalanced and the algorithm may visit every node rather than pruning parts of the space
Space:
- Best case: O(1)
- Average case: O(n*d): storing n data points with d dimensions takes n * d space; memory grows linearly with the number of dimensions
- Worst case: O(nd)
Advantages and Drawbacks
Advantages:
- Easy to implement and adaptable
- Very few hyperparameters to tune (essentially k and the distance metric)
Drawbacks:
- Not scalable for large datasets
- Prone to overfitting
- Victim of the “Curse of Dimensionality”