Supervised Learning is one of the techniques of Machine Learning in which a machine infers a mathematical function from a labelled dataset. By a labelled dataset, I mean values which exist in pairs: an input value which has a known output. For example, a plot of land of 100 sq. feet has a price of, say, 233676 INR.
Before diving into the math of Machine Learning, let's have an overview and consider an example. There could be two types of problems:
- Regression problem, in which the output for an input is a real number. For example, the price of land could take any value between an upper and a lower limit.
- Classification problem, in which the output can take up only a fixed number of values. For example, a tumour could be either benign (non-cancerous) or malignant (cancerous).
All the code used in this post can be found on my GitHub profile:
https://github.com/majeedk526/Univariate-Linear-Regression
Consider a company which produces chocolates. Over time, the company has produced chocolates of varying weights at varying costs. Consider the following dataset which the company has maintained.
MATLAB:
x = [5; 10; 15; 20; 25]; % weights (g)
Y = [10; 25; 23; 28; 40]; % costs (INR)
Let's plot the data for better visualisation.
MATLAB:
plot(x, Y, 'ko', 'MarkerFaceColor', 'b'); % plot for visualisation
title('Linear regression');
xlabel('Weights (g)');
ylabel('Cost (INR)');
Clearly, we can see the trend in the data: as the weight of a chocolate increases, its cost increases. Supervised learning is all about making the machine learn this trend.
There are many algorithms which can be used to find the mathematical function which represents this trend. Some of these algorithms include
- Linear Regression
- Logistic Regression
- Neural Networks
- Support Vector Machines
- and the list could go on...
Let's train our model using Linear Regression to find the best possible straight line. More specifically, this is called linear regression with one variable, or univariate linear regression.
Since we are doing linear regression, let's use a linear hypothesis function. A hypothesis function is a rough approximation which the linear regression algorithm improves until it best fits the data.
hθ(x) = θ0 + θ1·x
The equation is a simple line equation with gradient θ1 and y-intercept θ0.
θ0 and θ1 are more precisely referred to as parameters.
MATLAB:
theta = zeros(2,1); % 2x1 matrix with initial elements zero
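To make the hypothesis concrete, here is a minimal sketch of it in MATLAB; the anonymous function h below is only for illustration and is not used in the rest of the code.
MATLAB:
h = @(x, theta) theta(1) + theta(2) .* x; % h(x) = theta0 + theta1*x
For example, with theta = [2; 1], h(5, theta) returns 2 + 1*5 = 7.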
So here is what we are going to do. We will substitute a value of x (weight) into our hypothesis function, which will give a 'cost' value. This might be different from the actual corresponding 'cost' value given in the table; we will call this difference the 'error'. This error is defined for every point in our dataset. Since an error could be positive or negative, the errors could cancel and sum to zero, so instead we take the average of the squared errors.
The squared-mean error function is defined as
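J(θ0, θ1) = (1/(2m)) · Σ ( hθ(x(i)) − y(i) )², with the sum taken over i = 1, …, m (this is exactly the quantity the cost function later in the post computes)

where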
m - length of dataset
h - hypothesis function
y - output (Cost)
During our calculations we have to try different values of θ0 and θ1 which minimise the error. To find them we will use another algorithm called gradient descent.
The gradient descent algorithm:
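repeat until convergence {
θ0 := θ0 − α · ∂J/∂θ0
θ1 := θ1 − α · ∂J/∂θ1
}
where both parameters are updated simultaneously in every iteration.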
The alpha term is called the learning rate and it controls how fast the values of θ0 and θ1 change. This is a crucial part: if alpha is too small it would take forever to find a best-fit line; if it is too high it may overshoot the optimum values of θ0 and θ1 which give us the minimum error, i.e. our best-fit line.
MATLAB:
alpha = 0.001; % learning rate

function [theta, theta_history, j_history] = gradient_descent(X, Y, theta, alpha, itr, m)
    for i = 1:itr
        j_history(i) = cost(X, Y, m, theta); % record the error before this update
        tmp0 = theta(1) - alpha * (1/m) * sum(X*theta - Y);
        tmp1 = theta(2) - alpha * (1/m) * sum((X*theta - Y) .* X(:,2));
        theta(1) = tmp0; % update both parameters simultaneously
        theta(2) = tmp1;
        theta_history(i,1) = tmp0; % keep the parameter values of every iteration
        theta_history(i,2) = tmp1;
    end
end
Here are the derivative terms. These are derived by taking partial derivatives of the error function.
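∂J/∂θ0 = (1/m) · Σ ( hθ(x(i)) − y(i) )
∂J/∂θ1 = (1/m) · Σ ( hθ(x(i)) − y(i) ) · x(i)
These are exactly the expressions computed by sum(X*theta - Y) and sum((X*theta - Y) .* X(:,2)) in the gradient_descent function above.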
MATLAB:
m = length(x); % number of samples
function J = cost(X, Y, m, theta)
    J = (1/(2*m)) * sum((X*theta - Y).^2); % squared-mean error
end
Since x and Y are defined as matrices, we can make use of matrix multiplication. To do this, we add a column of 1s to x.
MATLAB:
X = [ones(m,1), x]; % add 1s in the first column
itr = 10; % number of iterations
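Because each row of X is now [1, x(i)], the product X*theta evaluates the hypothesis for every sample at once; this is the expression used inside both cost and gradient_descent. As a quick sketch (h_all is just an illustrative name):
MATLAB:
h_all = X * theta; % m x 1 vector of predicted costs, one per chocolate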
Now let's call our gradient descent function:
[theta, theta_history, j_history] = gradient_descent(X, Y, theta, alpha, itr, m);
If we plot our j_history data, which is the cost calculated after each iteration, you can see that the cost decreases roughly exponentially.
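A minimal way to produce this plot, assuming the gradient_descent call above has been made, could be:
MATLAB:
plot(1:itr, j_history, 'b-o');
title('Cost vs iteration');
xlabel('Iteration');
ylabel('Cost J');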
And here is the plot of the hypothesis function after each iteration. You can see how our hypothesis becomes more and more accurate after each iteration.
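One way to draw the hypothesis line after every iteration, using the theta_history returned by gradient_descent, might look like this:
MATLAB:
figure; hold on;
plot(x, Y, 'ko', 'MarkerFaceColor', 'b'); % original data points
for i = 1:itr
    plot(x, X * theta_history(i,:)', 'r-'); % hypothesis after iteration i
end
hold off;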
And here is the best-fit line
After the machine has learnt from the dataset, we get the approximately linear function shown above. This function can be used to predict the cost for weights which were not present in the training dataset. And that's all supervised learning is!
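For instance, a prediction for a weight that is not in the training data could be made directly from the learnt parameters (the 18 g weight below is just an example value):
MATLAB:
new_weight = 18; % grams, not present in the training set
predicted_cost = theta(1) + theta(2) * new_weight; % predicted cost in INR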
Note that in practice the relation between cost and weight might not be linear and may involve many more features. By a feature I mean some measurable characteristic of a sample.
Consider the following examples
As you can see, the relation could take any shape!
If you see any error, please feel free to correct me! :)