AI & Data Science
February 3, 2022

ML isn’t all so mysterious – implement your own KNN classifier

Many people see machine learning (ML) systems as black boxes. Some believe these black boxes contain pure magic and can solve any problem at hand. Others do not trust them at all, for what cannot be explained cannot be counted on. ML is not all that mysterious, though. Like any other field of engineering, ML rests on applied mathematics and applied science, and it uses scientific principles to design, build and test systems so that they become reliable and efficient in practice. Implementing a simple KNN classifier can be your first step towards demystifying ML.

What a classifier is and how a KNN classifier works

Classification refers to the practice of using knowledge from labeled objects to sort unlabeled objects into categories. For example, given some facial features of a pool of objects each labeled as a cat or a dog, you can build a system that tells whether a new object is a cat or a dog. The system you build on labeled data is called a classifier, which assigns objects to discrete classes.

The k-nearest-neighbors (KNN) classifier is one of the simplest ML classifiers. It does not even require a model to be estimated. Though simple, KNN has been successful in a large number of classification problems, including recognizing handwritten digits.

How does a KNN classifier work, then? It is quite intuitive. For a human, if an object looks more similar to a cat than to a dog, we see it as a cat. A KNN classifier does the same: given an unlabeled object, if the majority of the k labeled objects with the most similar facial features are cats, it classifies the unlabeled object as a cat, and otherwise as a dog. What if cats and dogs are equally common among the k most similar objects? Such a case might be hard for a human to tell as well, and a KNN classifier can simply break the tie at random.

So where is the “nearest neighbors” component? KNN measures similarity between objects by the distance between them: the shorter the distance, the more similar two objects are. In essence, the k nearest labeled neighbors are the k objects most similar to the object in question, and their majority vote decides the unknown label. Similar objects belong to the same class and should therefore get the same label. Simple as that.
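The idea above fits in a few lines of plain Python. Here is a toy sketch (the feature names and numbers are made up purely for illustration) of distance-based majority voting with k = 3:

```python
import math

# Toy "facial features" for labeled animals: (ear_pointiness, snout_length).
labeled = [((0.9, 0.2), "cat"), ((0.8, 0.3), "cat"),
           ((0.3, 0.9), "dog"), ((0.2, 0.8), "dog")]
query = (0.85, 0.25)  # an unlabeled animal

# Euclidean distance between two feature vectors.
def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# Sort the labeled objects by distance to the query and take the k = 3 nearest.
k = 3
nearest = sorted(labeled, key=lambda item: dist(item[0], query))[:k]
labels = [label for _, label in nearest]

# The majority vote among the nearest neighbors decides the prediction.
prediction = max(set(labels), key=labels.count)
print(prediction)  # two cats vs. one dog among the 3 nearest -> "cat"
```

The two cats sit much closer to the query than either dog, so the vote comes out 2–1 for “cat”.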

Implement a KNN classifier from scratch in Python

Scikit-learn (sklearn) is a great source of inspiration for how to structure your code, and we follow the design of sklearn’s KNeighborsClassifier in our implementation.

The sklearn KNN classifier class takes quite a few parameters. For simplicity, we implement only two of them. One is the number of neighbors, the only parameter KNN needs in order to work. The other is the metric parameter, which defines the distance metric to use. The most common choice is the standard Euclidean distance, and that is what this example implementation uses. By exposing the metric parameter, you can easily swap in other metrics of interest later on your own. Thus, we initialize our KNN classifier as below.
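A minimal sketch of the constructor (the class name KNNClassifier is our own choice; the parameter names mirror sklearn’s):

```python
class KNNClassifier:
    """A minimal KNN classifier, loosely modeled on sklearn's KNeighborsClassifier."""

    def __init__(self, n_neighbors=5, metric="euclidean"):
        # Number of nearest neighbors that vote on a label.
        self.n_neighbors = n_neighbors
        # Distance metric; only "euclidean" is supported in this sketch.
        self.metric = metric
```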

As for class methods, sklearn implements fit, get_params, kneighbors, kneighbors_graph, predict, predict_proba, score and set_params.

Once again, for simplicity, we only implement the crucial ones: the fit, kneighbors, predict and score methods. Since KNN has no model to estimate, the fit method only needs to memorize the observed features and their corresponding labels. Besides keeping copies of the observed inputs and outputs, we also record the distinct classes in the observed data for later use, and implement the fit method as such:
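A sketch of fit along these lines (attribute names such as _X, _y and classes_ are our own, chosen to echo sklearn conventions):

```python
import numpy as np

class KNNClassifier:
    def __init__(self, n_neighbors=5, metric="euclidean"):
        self.n_neighbors = n_neighbors
        self.metric = metric

    def fit(self, X, y):
        # KNN is a lazy learner: "fitting" just memorizes the training data.
        self._X = np.asarray(X, dtype=float)
        self._y = np.asarray(y)
        # The distinct class labels in the observed data, kept for later use.
        self.classes_ = np.unique(self._y)
        return self  # returning self allows sklearn-style method chaining
```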

Now we come to the core of KNN, the predict method. It finds the k nearest neighbors of each unlabeled object and takes their majority vote as the predicted label. The predict method first calls the kneighbors method, which computes, for each unlabeled object, the pairwise distances to all of the labeled objects saved by the fit method. After this step we know, for each unlabeled object, its distance to every labeled object. The kneighbors method then orders the labeled objects by their distances to each unlabeled object, from shortest to longest, and returns the indices (in the observed data) of the top k. With these indices, the predict method can retrieve the corresponding labels (since the features and label of the same object share the same index in X and y) and take the majority vote for each unlabeled object. The code for the predict method:
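A self-contained sketch of kneighbors and predict, following the steps above (one simplification to note: ties are broken deterministically by Counter, which picks the label encountered first, rather than at random):

```python
from collections import Counter

import numpy as np

class KNNClassifier:
    def __init__(self, n_neighbors=5, metric="euclidean"):
        self.n_neighbors = n_neighbors
        self.metric = metric

    def fit(self, X, y):
        self._X = np.asarray(X, dtype=float)
        self._y = np.asarray(y)
        self.classes_ = np.unique(self._y)
        return self

    def kneighbors(self, X):
        X = np.asarray(X, dtype=float)
        # Pairwise Euclidean distances between every query point and every
        # memorized training point; shape (n_queries, n_train).
        dists = np.sqrt(((X[:, None, :] - self._X[None, :, :]) ** 2).sum(axis=2))
        # Indices of the k closest training points per query, nearest first.
        idx = np.argsort(dists, axis=1)[:, : self.n_neighbors]
        return np.take_along_axis(dists, idx, axis=1), idx

    def predict(self, X):
        _, idx = self.kneighbors(X)
        # Look up the labels of the k nearest neighbors of each query point
        # and take the majority vote (ties go to the label seen first here).
        neighbor_labels = self._y[idx]
        return np.array([Counter(row).most_common(1)[0][0]
                         for row in neighbor_labels])
```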

Lastly, we add a score method to evaluate the performance of our classifier. It simply computes the ratio of correctly predicted labels to the total number of predictions, as below:
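Inside the class, score would call self.predict(X) and compare the result against the true labels y; the computation itself boils down to this accuracy helper (a sketch; the standalone function form is our own):

```python
import numpy as np

def accuracy_score(y_true, y_pred):
    # Fraction of predictions that exactly match the true labels.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float((y_true == y_pred).mean())
```

As a method, score then reduces to a one-liner: `return accuracy_score(y, self.predict(X))`.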

Try out your KNN classifier on the Iris dataset

Let’s now test our KNN classifier on the Iris flower dataset. It has three classes, the three species of Iris: setosa, virginica and versicolor. It comes with four features; for simplicity of visualization, we use three of them, namely the sepal length, sepal width and petal length, as inputs. We split the data into training (80%) and test (20%) sets, so that our classifier is trained only on the training set and evaluated on the test set.

Setting the number of neighbors k to an arbitrary value of 8, our trained KNN classifier obtains an accuracy score of 86.7%. Since the classes are equally sized, a baseline classifier that always predicts the same label would give 33.3% accuracy. Not bad for such a simple method!

As a reference, using the same k=8, the sklearn KNeighborsClassifier gives the same score.
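The reference experiment can be sketched with sklearn itself (assuming scikit-learn is installed; note that the exact accuracy depends on which random split you get, so your number need not match 86.7% exactly):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X = iris.data[:, :3]  # sepal length, sepal width, petal length
y = iris.target

# 80/20 train/test split; fixing random_state makes the run reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

clf = KNeighborsClassifier(n_neighbors=8)
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)
print(f"sklearn KNN accuracy: {accuracy:.3f}")
```

Swapping KNeighborsClassifier for our own class from the previous sections should produce the same predictions, since both use Euclidean distances and majority voting.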

We can also visualize our KNN classifier’s predictions in a 3D space. The three axes correspond to the three input features, and the colors denote the predicted class labels. Objects of the same class do cluster together.

You have now implemented your own KNN classifier. Have fun tuning and optimizing it!

You can find the complete code here.


Jin Guo
