How to build a KNN classification model from scratch and visualize it using Streamlit


Although libraries like sklearn have made our lives easier, it is always good practice to build a model from scratch. In this tutorial, we will build a KNN classification model from scratch and then build a web app with Streamlit to visualize it. Below is a demo of the final app.

KNN overview

KNN, or K Nearest Neighbours, can be used for both classification and regression. In this tutorial, we will use it for classification. Since it learns from data whose target labels are known, it is a supervised algorithm. It takes an input, finds the K nearest points to it in the dataset, checks their labels, and classifies the input as the label that occurs most often.

Say we want to build a model that classifies an animal as a dog or a cat based on its weight and height. If K = 3, we find the 3 points nearest to our input and check their labels. If 2 of the 3 have the label ‘dog’, our model classifies the input as ‘dog’. If 2 of the 3 have the label ‘cat’, it classifies the input as ‘cat’.

Steps

  • Normalize the dataset and store it, i.e., make sure all values are between 0 and 1.
  • Take an input data point and find its distance from every record in our dataset. Store the distances in a list.
  • Sort the list of distances and check the labels of the first K records in the sorted list.
  • Classify the input as the label that occurs most often among those first K records.
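The steps above can be sketched end to end in a few lines. This is only a compact illustration, not the code we build in the rest of the tutorial: the name `knn_predict` and the toy cat/dog values are made up, the features are assumed to be normalized already (step 1), and `math.dist` needs Python 3.8 or newer.

```python
import math
from collections import Counter


def knn_predict(features, labels, query, k):
    """Classify `query` by majority vote among its k nearest records."""
    # Step 2: distance from the query to every record, paired with its label.
    distances = [(math.dist(query, row), label)
                 for row, label in zip(features, labels)]
    # Step 3: sort by distance and keep the k closest records.
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Step 4: the label that occurs most often among them wins.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Made-up, already normalized (weight, height) values for illustration.
animals = [(0.10, 0.20), (0.15, 0.10), (0.20, 0.25),
           (0.80, 0.90), (0.85, 0.75), (0.90, 0.80)]
species = ["cat", "cat", "cat", "dog", "dog", "dog"]

prediction = knn_predict(animals, species, (0.7, 0.8), k=3)
print(prediction)  # dog
```

The helper functions below do the same work, but keep the record ids around so that we can draw the neighbours later.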

First, we will create all the helper functions we need. Then we will combine them and add some Streamlit functions to build a web app.

For ease of understanding and visualization, we will work with a dataset that has 2 features and binary labels, i.e., ‘0’ and ‘1’.

Helper Functions

Function to Normalize Data

To normalize a list of values, we iterate over each value and find the difference between the value and the minimum value in the list. We then divide it by the difference between the maximum and minimum values in the list.

The equation to normalize data: $x_{\text{norm}} = \dfrac{x - x_{\min}}{x_{\max} - x_{\min}}$
def min_max_normalize(lst):
    minimum = min(lst)
    maximum = max(lst)
    # Scale each value into [0, 1] relative to the list's range
    normalized = [(val - minimum) / (maximum - minimum) for val in lst]
    return normalized

The function takes in a list of values and returns the normalized values
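One edge case is worth knowing about: if every value in the list is identical, the maximum equals the minimum and the division fails with a ZeroDivisionError. Below is a small sketch of a guarded variant with a usage example. The name `min_max_normalize_safe` and the choice to map a constant list to zeros are my assumptions, not part of the original app.

```python
def min_max_normalize_safe(values):
    """Scale values into [0, 1]; a constant list maps to all zeros."""
    minimum, maximum = min(values), max(values)
    value_range = maximum - minimum
    if value_range == 0:
        # Assumption: with no spread in the data, return 0.0 for every value.
        return [0.0 for _ in values]
    return [(val - minimum) / value_range for val in values]


scaled = min_max_normalize_safe([2, 4, 6, 10])
constant = min_max_normalize_safe([7, 7, 7])
print(scaled)    # [0.0, 0.25, 0.5, 1.0]
print(constant)  # [0.0, 0.0, 0.0]
```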

Function to Calculate Euclidean Distance

This function will be used to calculate the distance between two given points. We will use Euclid’s Formula to calculate the distance.

Formula to calculate Euclidean distance: $d = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$
def distance(element1 , element2):
    x_distance = (element1[0] - element2[0])**2
    y_distance = (element1[1] - element2[1])**2
    return (x_distance + y_distance)**0.5

The function takes in two 2D points and returns the Euclidean distance between them. Since we are considering a dataset with only 2 features, we only use the x and y coordinates. With more features, this function would need to sum the squared differences across all of them.
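For datasets with more than two features, one possible generalization is sketched below. The name `distance_nd` is hypothetical; the rest of this tutorial keeps using the 2D `distance` function above.

```python
import math


def distance_nd(element1, element2):
    """Euclidean distance between two points with any number of features."""
    if len(element1) != len(element2):
        raise ValueError("points must have the same number of features")
    # Sum the squared difference for every feature, then take the square root.
    squared = sum((a - b) ** 2 for a, b in zip(element1, element2))
    return math.sqrt(squared)


d2 = distance_nd((0, 0), (3, 4))        # 2 features
d3 = distance_nd((0, 0, 0), (1, 2, 2))  # 3 features
print(d2, d3)  # 5.0 3.0
```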

Function to find the distance between input point and all points in the dataset

We iterate over each record in the dataset and use our distance function above to calculate its distance from the input point. We then store the distances and sort them.

def find_nearest(x , y , input , k):
    distances = []
    for id,element in enumerate(x):
        distances.append([distance(input , element),id])
    distances = sorted(distances)
    predicted_label = get_label(distances[0:k] , y)
    return predicted_label, distances[0:k] , distances[k:]

The function takes the following parameters as input:

  • x: This is our dataset containing the two features
  • y: This contains the labels for each row in x. They are mapped respectively, i.e., the label for x[i] is y[i]
  • input: This is the 2D point we want to classify, given as a pair containing its two feature values
  • k: The number of nearest neighbours we want our model to consider

First, we create an empty list to store the distances. For each record, we store both its distance from the input and its index in the dataset. The index can later be used with the y array to find the label of that record.

Then we sort the distances. Next, we use the get_label function (discussed below) to get the most frequent label.

Since the distances list is sorted, its first k elements, i.e., distances[0:k], are the k nearest neighbours of our input. We return the predicted label of our input, the k nearest neighbours, and the rest of the neighbours.
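To see why sorting a list of [distance, id] pairs works, here is a small worked example with made-up points. Python compares lists element by element, so the pairs are ordered by distance first and, when two distances are equal, by id.

```python
import math

# Made-up points; ids 1 and 2 are equally far from the query.
points = [(1.0, 1.0), (0.5, 0.0), (0.0, 0.5), (0.0, 0.25)]
query = (0.0, 0.0)
k = 2

distances = []
for idx, point in enumerate(points):
    d = math.sqrt((query[0] - point[0]) ** 2 + (query[1] - point[1]) ** 2)
    distances.append([d, idx])

# Sorted by distance; the tie between ids 1 and 2 is broken by the id.
distances = sorted(distances)
nearest, rest = distances[:k], distances[k:]

nearest_ids = [idx for _, idx in nearest]
rest_ids = [idx for _, idx in rest]
print(nearest_ids)  # [3, 1]
print(rest_ids)     # [2, 0]
```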

Function to find the most occurring Label

We take the k nearest neighbours and check the label of each. In our case, we have only two labels, ‘0’ and ‘1’. If the label of a neighbour is ‘0’, we increment the count for ‘0’, and we do the same for ‘1’. We then compare the two counts and return the label with the higher count. If the counts are tied, we return the label of the single nearest neighbour.

def get_label(neighbours, y):
    zero_count, one_count = 0, 0
    for element in neighbours:
        # element is [distance, id]; the id indexes into y
        if y[element[1]] == 0:
            zero_count += 1
        elif y[element[1]] == 1:
            one_count += 1
    # On a tie, fall back to the label of the single nearest neighbour
    if zero_count == one_count:
        return y[neighbours[0][1]]
    return 1 if one_count > zero_count else 0

The function takes the k nearest neighbours as the input. Each record in neighbours contains the distance from the input point and its original id. We use the id and the y array to get the label of the record. We then check the label and return the predicted label.
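The function above is tied to the labels 0 and 1. If your dataset had more than two classes, a variant along the following lines could work. It is a sketch under my own assumptions: the name `get_label_multiclass` is hypothetical, and ties are resolved in favour of the tied label whose neighbour is nearest, which mirrors the tie rule used above.

```python
from collections import Counter


def get_label_multiclass(neighbours, y):
    """Return the most frequent label among neighbours ([distance, id] pairs)."""
    # neighbours is sorted by distance, so labels are ordered nearest first.
    labels = [y[idx] for _, idx in neighbours]
    counts = Counter(labels)
    top_count = max(counts.values())
    # Walk from the nearest neighbour outwards; the first label that
    # reaches the top count wins, which also settles ties.
    for label in labels:
        if counts[label] == top_count:
            return label


y_example = ["cat", "dog", "dog"]
majority = get_label_multiclass([[0.1, 2], [0.2, 0], [0.3, 1]], y_example)
tie = get_label_multiclass([[0.1, 0], [0.2, 1]], y_example)
print(majority)  # dog
print(tie)       # cat
```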

We have created the required helper functions. Now we will combine them with some Streamlit functions. Exciting Stuff! 😎

Required Libraries

We will use Plotly for our graphs since it produces interactive plots. For our dataset, we will generate one with make_blobs from sklearn.datasets. We will also use the pandas library to create DataFrames.

import streamlit as st
import pandas as pd
from sklearn.datasets import make_blobs
import plotly.express as px
import plotly.graph_objects as go

I will break the rest of the tutorial into three parts:

  1. Input from User
  2. Importing Dataset and Visualizing it
  3. Visualizing the Prediction

User Input

Screenshot of App

We will use Streamlit’s title method to display a title and its slider method to create numerical sliders for user input. Since our data is normalized, we expect the input to be normalized as well, so we restrict the sliders to values between 0 and 1. Alternatively, we could accept raw values from the user and normalize them with the minimum and maximum values of each feature in our dataset.

st.title("KNN Visualize")
x_input = st.slider("Choose X input", min_value=0.0, max_value=1.0,key='x')
y_input = st.slider("Choose Y input", min_value=0.0, max_value=1.0,key='y')
k = st.slider("Choose value of K", min_value=1, max_value=10,key='k')
input = (x_input,y_input)

Every time a slider value changes, the entire Python script is re-run and the variables take the new values from the sliders.
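If you would rather let the user enter raw values instead of normalized ones, the alternative mentioned above could look roughly like this. The function `normalize_input` and the sample numbers are hypothetical, and clamping out-of-range values to [0, 1] is my own design choice. Note that each feature's minimum and maximum would have to be recorded before the dataset itself is normalized.

```python
def normalize_input(value, minimum, maximum):
    """Scale a raw input with a feature's dataset minimum and maximum."""
    if maximum == minimum:
        # Assumption: a constant feature carries no information, so use 0.0.
        return 0.0
    scaled = (value - minimum) / (maximum - minimum)
    # Assumption: clamp values outside the dataset's range into [0, 1].
    return min(1.0, max(0.0, scaled))


inside = normalize_input(5, 0, 10)
above = normalize_input(15, 0, 10)
below = normalize_input(-3, 0, 10)
print(inside, above, below)  # 0.5 1.0 0.0
```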

Importing Dataset and visualizing it

Screenshot of App

x , y = make_blobs(n_samples = 100 , n_features = 2 , centers = 2, random_state= 2)

The make_blobs function creates a dataset for us that looks similar to the distribution in the graph above. In the real world, datasets won’t be so cooperative, but this one will suffice for now. I suggest you draw a scatter plot using matplotlib to see the distribution of the data.

x contains the features and y contains the respective labels

# Normalizing Data
x[:,0] = min_max_normalize(x[:,0])
x[:,1] = min_max_normalize(x[:,1])
# Dataframe
df = pd.DataFrame(x , columns = ['Feature1' , 'Feature2'] )
df['Label'] = y
st.dataframe(df)

First, we use the normalizing helper function we created before to normalize our data. Then we combine the x and y array to create a Dataframe. We use streamlit’s dataframe method to view the dataframe.

# Initial Data Plot
fig = px.scatter(df, x = 'Feature1' , y='Feature2', symbol='Label',symbol_map={'0':'square-dot' , '1':'circle'})
fig.add_trace(
    go.Scatter(x=[input[0]], y=[input[1]], name="Point to Classify")
)
st.plotly_chart(fig)

You can read Plotly’s documentation for a better understanding of the above code. We create a figure with the scatter plot of the dataframe we just created. We also add our input point to better understand where it is located with respect to other points in the dataset. Streamlit’s plotly_chart method takes a Plotly figure as a parameter and plots an interactive graph on our app.

Prediction and Visualizing it

Screenshot of App

# Finding Nearest Neighbours
predicted_label , nearest_neighbours, far_neighbours = find_nearest(x ,y , input ,k)
st.title('Prediction')
st.subheader('Predicted Label : {}'.format(predicted_label))

We use the find_nearest function we created earlier to get the predicted label, along with the distances and ids of the k nearest neighbours and of the far neighbours.

We display the predicted label using Streamlit’s subheader method.

nearest_neighbours = [[neighbour[1],x[neighbour[1],0],x[neighbour[1],1],neighbour[0],y[neighbour[1]]] for neighbour in nearest_neighbours]
nearest_neighbours = pd.DataFrame(nearest_neighbours , columns = ['id','Feature1','Feature2','Distance','Label'])
st.dataframe(nearest_neighbours)

The above code uses the id of each nearest neighbour to combine its id and distance with the Feature1 and Feature2 values of the record and its label. We use the combined list to create a dataframe containing the information about the nearest neighbours. We then use Streamlit’s dataframe method to display it. This dataframe will help us understand the graph below it.

far_neighbours = [[neighbour[1],x[neighbour[1],0],x[neighbour[1],1],neighbour[0],y[neighbour[1]]] for neighbour in far_neighbours]
far_neighbours = pd.DataFrame(far_neighbours , columns = ['id','Feature1','Feature2','Distance','Label'])
fig2 = px.scatter(far_neighbours,x='Feature1',y='Feature2',symbol='Label',symbol_map={'0':'square-dot' , '1':'circle'})

We create a similar dataframe for the far neighbours. We use Plotly to plot a scatter plot. We will now add the input and lines connecting the input to its k nearest neighbours.

for index, neighbour in nearest_neighbours.iterrows():
    # Draw a line from the input point to this neighbour
    fig2.add_trace(
        go.Scatter(
            x=[input[0], neighbour['Feature1']],
            y=[input[1], neighbour['Feature2']],
            mode='lines+markers',
            name='id {}'.format(int(neighbour['id'])),
        )
    )
st.plotly_chart(fig2)

We iterate over each of the neighbours and add a line between the neighbour and our input point to the figure we created. Finally, we use the plotly_chart method to plot the figure.

And that’s it 👏 We have created a KNN classifier from scratch and built a Streamlit app to visualize it.

If you are interested in deploying your streamlit app, check out my tutorial.

You can find the GitHub repo here.

I am still fairly new to the world of Machine Learning. If you find any errors or any piece of code that can be optimized, please let me know! I am always open to feedback 😃