
Simple Linear Regression on JAX, PyTorch and TensorFlow

taher

Published on: 2022-07-11

In this post, we will learn about linear regression and implement it in JAX, PyTorch and TensorFlow. The example was first presented by Jeremy Howard in the fastai tutorials (https://www.youtube.com/user/howardjeremyp), which build everything from scratch in PyTorch. We thought of remaking it in JAX and TensorFlow and comparing which of the three is easiest to develop with.

What is Linear Regression?

Linear regression is a model that predicts the value of an unknown variable as a linear combination of the observed values of one or more known variables.

Let’s understand linear regression with an example.

Suppose we are given a data set consisting of pairs of the form

{x, y}

where x is the input value and y is the target value. We will use x to predict y. The input is not limited to a single variable; there could be more.
We need to find a formula that predicts y.

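The formula is:

y = a*x

In the above formula, a holds the weights of the model. In the code below, x gets an extra column of ones, so a has two entries (a slope and an intercept) and y = x@a works out to a[0]*x + a[1].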
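To see why a single matrix product can hold both a slope and an intercept, here is a small NumPy sketch. The three input values are made up for illustration; the weights (3, 2) match the ones used to generate the data below. With a column of ones in x, each prediction x@a equals 3*x + 2.

```python
import numpy as np

# Three illustrative input values (hypothetical, only for this check)
inputs = np.array([-1.0, 0.0, 0.5])

# Design matrix: first column holds the inputs, second column is all ones
X = np.ones((3, 2))
X[:, 0] = inputs

a = np.array([3.0, 2.0])  # slope 3, intercept 2

y_matmul = X @ a                  # y = a*x written as a matrix product
y_explicit = 3.0 * inputs + 2.0   # the same line written as slope*x + intercept

print(y_matmul)
print(np.allclose(y_matmul, y_explicit))  # True
```

The ones column is what lets all three libraries below fit a line with a plain matrix multiplication instead of a separate bias term.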

To experiment with all three libraries, we will follow these steps:
1. Import the libraries.
2. Create a data set.
3. Create a model.
4. Train the model.
5. Evaluate the model using a graph.

Let's first start with PyTorch, as shown by Jeremy Howard in the fastai tutorials.

PyTorch
Importing all libraries

import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from fastai.torch_core import tensor  # fastai's tensor() accepts values directly, e.g. tensor(3.,2)

Creating a data set

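n = 100  # number of points
x = torch.ones(n, 2)
x[:,0].uniform_(-1., 1)  # first column: uniform random values in [-1, 1); second column stays 1
a = tensor(3., 2)  # true weights: slope 3, intercept 2
y = x@a + 0.25*torch.randn(n)  # targets with Gaussian noise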

In the above code, we create a data set with 100 points. x is a matrix with two columns: the first column holds uniform random numbers between -1 and 1, and the second column is all ones. y is the target value, computed as y = x@a + 0.25*randn(n), i.e. y = 3*x + 2 plus Gaussian noise.
Now we plot the data set with matplotlib.

plt.scatter(x[:,0], y);

def mse(ypred, y):
  return ((ypred - y)**2).mean()

In the above code, we define a function that calculates the mean squared error. It takes two arguments: ypred, the predicted values, and y, the actual values. This is the loss function for our model.
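Written out, the function above averages the squared differences over the n points, where the predictions come from the model y = x@a:

```latex
\mathrm{MSE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \left(\hat{y}_i - y_i\right)^2,
\qquad \hat{y} = x\,a
```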

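a = tensor(-.5, .5)  # initial (untrained) guess for the weights
ypred = x@a  # predictions with the initial guess
print(mse(ypred, y))  # loss before training
plt.scatter(x[:,0], y);
plt.scatter(x[:,0], ypred, c="orange");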

In the above code, we plot the data set together with the predictions ypred = x@a (in orange) for an initial, untrained guess of the weights a. Because the guess is not trained yet, the predicted line does not follow the data.

a = nn.Parameter(a)  # let PyTorch track gradients for the weights

In the above code, we wrap a in nn.Parameter. The parameter a holds the weights of the model; as a Parameter it requires gradients, so loss.backward() will store the gradient of the loss in a.grad.

def update():
  ypred = x@a  # predicted values
  loss = mse(ypred, y)  # loss
  if t % 10 == 0: print(loss)  # print the loss every 10 steps
  loss.backward()  # backpropagate: computes the gradient of the loss and stores it in a.grad
  with torch.no_grad():  # the update itself must not be tracked by autograd
    a.sub_(lr*a.grad)  # gradient-descent step on the weights
    a.grad.zero_()  # reset the gradient, since PyTorch accumulates gradients across calls

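In the above code, we define a function that updates the weights. It takes no arguments and uses the global variables x, y, a, lr and t. It computes the predictions and the loss, calls loss.backward() to get the gradient of the loss with respect to a, and then moves a a small step against that gradient, scaled by the learning rate lr.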
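For this model, the gradient that loss.backward() stores in a.grad has a closed form: the derivative of the mean squared error with respect to a is (2/n) * x.T @ (x@a - y). The NumPy sketch below checks that formula against a central finite-difference estimate. The data are generated the same way as in the PyTorch section, but with NumPy and a fixed seed; the helper name grad_mse is our own, hypothetical name.

```python
import numpy as np

def mse(ypred, y):
    # Mean squared error, same definition as in the post
    return ((ypred - y) ** 2).mean()

def grad_mse(X, a, y):
    # Closed-form gradient of mse(X @ a, y) with respect to a (hypothetical helper)
    n = len(y)
    return 2.0 / n * X.T @ (X @ a - y)

rng = np.random.default_rng(0)
n = 100
X = np.ones((n, 2))
X[:, 0] = rng.uniform(-1.0, 1.0, n)
y = X @ np.array([3.0, 2.0]) + 0.25 * rng.standard_normal(n)
a = np.array([-0.5, 0.5])  # current (untrained) weights

# Central finite differences, one weight at a time
eps = 1e-6
numeric = np.zeros_like(a)
for j in range(len(a)):
    step = np.zeros_like(a)
    step[j] = eps
    numeric[j] = (mse(X @ (a + step), y) - mse(X @ (a - step), y)) / (2 * eps)

analytic = grad_mse(X, a, y)
print(analytic, numeric)
print(np.allclose(analytic, numeric, atol=1e-5))  # True
```

Autograd computes this derivative exactly; the finite-difference estimate is only a sanity check that the formula and the code agree.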

lr = 1e-1  # learning rate
epoch = 100  # number of epochs
for t in range(epoch):
  update()  # one gradient-descent step

In the above code, we run update() for 100 iterations. In each iteration, the weights are updated using the gradient of the loss function.
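For comparison with the three libraries, the same training loop can be sketched in plain NumPy by plugging the closed-form gradient into the update a = a - lr * gradient. This is not PyTorch code; it assumes Gaussian noise, a fixed seed and a hypothetical helper name fit_gd, and it checks the result against NumPy's exact least-squares solution.

```python
import numpy as np

def fit_gd(X, y, a0, lr=0.1, steps=100):
    # Plain gradient descent on the mean squared error (hypothetical helper)
    a = np.array(a0, dtype=float)
    n = len(y)
    for t in range(steps):
        grad = 2.0 / n * X.T @ (X @ a - y)  # d(MSE)/da
        a -= lr * grad                      # same step as a.sub_(lr * a.grad)
    return a

rng = np.random.default_rng(42)
n = 100
X = np.ones((n, 2))
X[:, 0] = rng.uniform(-1.0, 1.0, n)
y = X @ np.array([3.0, 2.0]) + 0.25 * rng.standard_normal(n)

a_gd = fit_gd(X, y, a0=[-0.5, 0.5])
a_ls, *_ = np.linalg.lstsq(X, y, rcond=None)  # exact least-squares solution

print(a_gd)  # close to [3, 2]
print(a_ls)
```

With lr = 0.1, 100 steps are enough for gradient descent to land very close to the least-squares answer on this data.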

from matplotlib import animation, rc
rc('animation', html='jshtml')

a = nn.Parameter(tensor(-100.,100.))

fig = plt.figure()
plt.scatter(x[:,0], y, c='orange')
line, = plt.plot(x[:,0], x@a.detach())
plt.close()

def animate(i):
  update()
  line.set_ydata(x@a.detach())
  return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)

In the above code, we animate the training, starting from deliberately bad weights (-100, 100). animate(i) receives the frame number i, performs one update() step and redraws the fitted line. FuncAnimation renders 100 frames (np.arange(0, 100)), one update per frame, with a delay of 20 milliseconds between frames (interval=20).

TensorFlow
Importing all libraries

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

Creating a data set

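n = 100  # total number of points
x = np.ones(shape=(n,2))
x[:,0] = np.linspace(-1, 1, num=n)  # first column: evenly spaced values in [-1, 1]; second column stays 1
a = np.array([3,2])  # true weights: slope 3, intercept 2
y = x@a + 0.75*np.random.rand(n)  # targets with uniform noise in [0, 0.75)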


In the above code, we create a data set with 100 points. x is a matrix with two columns: the first column holds 100 evenly spaced values between -1 and 1 (np.linspace), and the second column is all ones. The target is y = x@a + 0.75*np.random.rand(n), i.e. y = 3*x + 2 plus noise drawn uniformly from [0, 0.75). Unlike the PyTorch version, this noise is uniform rather than Gaussian.
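One consequence of the uniform noise: np.random.rand draws from [0, 1), so 0.75*np.random.rand(n) has mean 0.375 rather than 0, and the best-fit intercept lands near 2.375 instead of 2. The quick NumPy check below rebuilds the same data set (the fixed seed is our assumption, added only for reproducibility) and fits it with least squares.

```python
import numpy as np

np.random.seed(0)  # fixed seed, only so the check is reproducible
n = 100
x = np.ones(shape=(n, 2))
x[:, 0] = np.linspace(-1, 1, num=n)
a = np.array([3, 2])
y = x @ a + 0.75 * np.random.rand(n)  # same construction as the TensorFlow section

fit, *_ = np.linalg.lstsq(x, y, rcond=None)
print(fit)  # slope near 3, intercept near 2 + 0.375
```

So when the trained TensorFlow (or JAX) weights end up around (3, 2.4) rather than (3, 2), that is the noise mean, not a training bug.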

plt.scatter(x[:,0], y);

def mse(ypred, y):
  return tf.reduce_mean((ypred-y)**2)

In the above code, we define the same mean squared error function, using tf.reduce_mean instead of .mean() so that it works with TensorFlow tensors. It takes the predicted values ypred and the actual values y, and it is the loss function for our model.

plt.scatter(x[:,0], y);
ypred = x@np.array([-.5, .5])  # predictions for an initial (untrained) guess of the weights
plt.scatter(x[:,0], ypred, c="orange");

In the above code, we plot the data set together with the predictions (in orange) for an initial, untrained guess of the weights, (-0.5, 0.5).

a = tf.Variable(tf.reshape(np.array([-.5,.5]),shape=(2,1)));

In the above code, we create the weights a as a tf.Variable, reshaped into a (2, 1) column so that it can be matrix-multiplied with x. A tf.Variable is trainable and can be updated in place with assign_sub.

x_new = tf.constant(x)
ypred_tensor = x_new@a
y_tensor = tf.constant(y)

The above code converts the NumPy arrays into TensorFlow tensors. Both x and the weights are float64, so the matrix product works without a dtype cast.

def update():
  with tf.GradientTape() as tape:  # record operations on a so the gradient can be computed
    ypred = x_new@a  # predicted values, shape (n, 1)
    loss = mse(tf.reshape(ypred, shape=(n,)), y_tensor)  # loss
  if t % 10 == 0: print(loss)  # print the loss every 10 steps
  grad = tape.gradient(loss, a)  # gradient of the loss with respect to a
  a.assign_sub(lr * grad)  # gradient-descent step on the weights

In the above code, we define a function that updates the weights. It takes no arguments. tf.GradientTape records the operations that compute the loss, tape.gradient then returns the gradient of the loss with respect to a, and a.assign_sub moves the weights a small step against that gradient. Trainable tf.Variable objects are watched by the tape automatically, so an explicit tape.watch(a) is optional.

lr = 1e-1  # learning rate
for t in range(100):
  update()  # one gradient-descent step, using update() defined above

In the above code, we run update() for 100 iterations; in each iteration, the weights are updated using the gradient of the loss function.

from matplotlib import animation, rc
rc('animation', html='jshtml')

a = tf.Variable(tf.reshape(np.array([-100.,.1]),shape=(2,1)))

fig = plt.figure()
plt.scatter(x[:,0], y, c='orange')
line, = plt.plot(x[:,0], tf.reshape(x_new@a,shape=(100,)))
plt.close()

def animate(i):
  update()
  line.set_ydata(tf.reshape(x_new@a,shape=(100,)))
  return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)

In the above code, we animate the training the same way as in the PyTorch version, starting from the weights (-100, 0.1). Each of the 100 frames runs one update() step and redraws the line, with 20 milliseconds between frames (interval=20). tf.reshape turns the (n, 1) prediction into a flat vector of length 100 for plotting.

Now we will use Google's JAX.
Importing all libraries
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

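Creating a data set

n = 100  # total number of points
x = jnp.ones(shape=(n,2))
x_ = jnp.linspace(-1, 1, num=n)
x = x.at[:,0].set(x_)  # JAX arrays are immutable, so .at[].set() returns an updated copy
a = jnp.array([3,2])  # true weights: slope 3, intercept 2
seed = 1
key = jax.random.PRNGKey(seed)  # explicit random key
y = x@a + 0.75*jax.random.uniform(key, (n,))  # uniform noise in [0, 0.75), drawn with the key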

In the above code, we create a data set with 100 points. x is a matrix with two columns: the first column holds 100 evenly spaced values between -1 and 1, and the second column is all ones. The target is y = x@a plus 0.75 times uniform random noise in [0, 1), i.e. y = 3*x + 2 plus noise in [0, 0.75).


plt.scatter(x[:,0], y);

def mse(ypred,y):
  return jnp.mean((ypred-y)**2)


In the above code, we define the mean squared error again, this time with jnp.mean so that JAX can differentiate it. It takes the predicted values ypred and the actual values y, and it is the loss function for our model.

a = jnp.array([-.5,.5])
ypred = x@a
def model(x,a):
  return x@a

In the above code, we define the model function. It takes two arguments, x and a, and returns the predictions x@a.

def grad_func(y,x,a):
   return mse(model(x,a),y)

In the above code, we define the loss as a function of y, x and a. Despite its name, grad_func returns the loss value, not the gradient; JAX derives the gradient from this function in the next step.

lr = 1e-1 #This is the learning rate
def update():
  global a  # a is reassigned below (JAX arrays are immutable), so rebind the module-level a
  loss,grad = jax.value_and_grad(grad_func,argnums=(2))(y,x,a) #This is to calculate the gradient and value of the loss function in JAX
  if t%10 == 0 : print(loss) #This is to print the loss function
  a -= lr*grad #This is to update the weights

In the above code, we define a function that updates the weights. It takes no arguments. jax.value_and_grad(grad_func, argnums=(2)) transforms grad_func into a function that returns both the loss and its gradient with respect to the third argument, a; unlike TensorFlow, JAX needs no gradient tape. Because JAX arrays are immutable, a -= lr*grad creates a new array, which is why update() declares a as global.
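The global-variable style above works, but JAX is built around pure functions. The sketch below is our own restructuring, not the original notebook's code: the hypothetical functions step and train pass the weights in and return the updated weights, which also lets jax.jit compile the update step. The data are built the same way as above.

```python
import jax
import jax.numpy as jnp

def mse(ypred, y):
    return jnp.mean((ypred - y) ** 2)

def loss_fn(a, x, y):
    # Loss as a function of the weights (first argument, so the gradient is taken w.r.t. a)
    return mse(x @ a, y)

@jax.jit
def step(a, x, y, lr):
    # One gradient-descent step: returns the new weights and the current loss
    loss, g = jax.value_and_grad(loss_fn)(a, x, y)
    return a - lr * g, loss

def train(a, x, y, lr=0.1, steps=100):
    for t in range(steps):
        a, loss = step(a, x, y, lr)
        if t % 10 == 0:
            print(loss)
    return a

key = jax.random.PRNGKey(1)
n = 100
x = jnp.ones((n, 2)).at[:, 0].set(jnp.linspace(-1, 1, num=n))
y = x @ jnp.array([3.0, 2.0]) + 0.75 * jax.random.uniform(key, (n,))

a = train(jnp.array([-0.5, 0.5]), x, y)
print(a)  # slope near 3, intercept near 2.375 (the uniform noise has mean 0.375)
```

Because step has no side effects, the same function could be reused for the animation by keeping the returned weights in a local variable instead of a global.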

for t in range(100):
  update()  # one gradient-descent step; update() reads t as a global


In the above code, we run update() for 100 iterations, updating the weights with the gradient of the loss function each time.

from matplotlib import animation, rc
rc('animation', html='jshtml')

a = jnp.array([-100.,.1]) #This is the weights of the formula

fig = plt.figure()
plt.scatter(x[:,0], y, c='orange') #This is to plot the data set
line, = plt.plot(x[:,0], x@a) #This is the fitted line for the current weights, redrawn in animate()
plt.close()

def animate(i): #This is to animate the graph
  update()
  line.set_ydata(x@a) #This is to update the graph
  return line, #This is to return the graph


animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)


In the above code, we animate the training as in the other two versions, starting from the weights (-100, 0.1). Each of the 100 frames runs one update() step and redraws the line, with 20 milliseconds between frames (interval=20).

 

You can view all three Google Colab notebooks in this repo.

In conclusion, comparing TensorFlow, PyTorch and JAX, we found JAX to be the best library for this example: in our experience it was the fastest and most efficient, and it is also easy to use. Second, we found that PyTorch suits developers best.

In an upcoming post, we will try to reimplement the full fastai library in JAX, since it was the fastest and most efficient library in our comparison.
So stay tuned for the next post.





About Us

We are Saify Technologies, a software development company located in India that develops custom web and mobile applications. We specialize in artificial intelligence technology, and we have completed many e-commerce projects with different technology stacks, including Java, Flutter, PHP, Python, C#, Swift and many more.