**Simple Linear Regression in JAX, PyTorch and TensorFlow**

In this post, we will learn about linear regression and implement it in JAX, PyTorch and TensorFlow. The example was first presented by Jeremy Howard in the fast.ai tutorials (https://www.youtube.com/user/howardjeremyp), which build everything from scratch in PyTorch. We decided to remake it using JAX and TensorFlow and to compare the three libraries to see which is the easiest to develop with.

**What is Linear Regression?**

Linear regression is a model that predicts the value of an unknown (target) variable from the observed values of one or more known (input) variables, assuming the relationship between them is linear.

Let’s understand linear regression with an example.

Suppose we are given a data set consisting of pairs of the form

{x, y}

where x holds the input values and y is the target value. We will use x to predict y. The input is not limited to a single variable; there can be several.

We need to find a formula to predict y.

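The formula is

y = a*x

where a holds the weights of the formula. In the code below, x also gets a second column filled with ones, so the matrix product x@a works out to a[0]*x + a[1]: a straight line with slope a[0] and intercept a[1].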
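A small NumPy sketch of that idea, using illustrative inputs and the weights [3, 2] that appear later in the post: stacking a column of ones next to x makes the matrix product equal to slope*x + intercept.

```python
import numpy as np

x_raw = np.array([-1.0, 0.0, 0.5, 1.0])              # illustrative input values
X = np.stack([x_raw, np.ones_like(x_raw)], axis=1)   # columns: [x, 1]
a = np.array([3.0, 2.0])                             # slope 3, intercept 2

y_matrix = X @ a                # matrix form used throughout the post
y_line = 3.0 * x_raw + 2.0      # the same values written as a line
print(y_matrix, y_line)
```

The ones column is what lets a single weight vector carry both the slope and the intercept.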

For each of the three libraries we will follow the same steps:

1. Import the libraries.
2. Create a data set.
3. Create a model.
4. Train the model.
5. Evaluate the model using a graph.

Let's start with PyTorch, as shown by Jeremy Howard in the fast.ai tutorials.

**PyTorch**

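Importing the libraries:

```python
from fastai.basics import *            # fastai's star import, as in the original tutorial
from fastai.torch_core import tensor   # tensor(3., 2) builds a torch tensor from values
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt        # provides plt.scatter / plt.plot used below
```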

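Creating a data set:

```python
n = 100                             # number of points
x = torch.ones(n, 2)                # column 1 stays all ones (the intercept term)
x[:, 0].uniform_(-1., 1)            # column 0: uniform random values in [-1, 1)
a = tensor(3., 2)                   # true weights: slope 3, intercept 2
y = x @ a + 0.25 * torch.randn(n)   # targets = line + Gaussian noise
```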

In the above code, we have created a data set with 100 points. x is a matrix with two columns: the first column holds uniform random numbers between -1 and 1, and the second column is all ones. y is the target value, computed as y = x@a + 0.25*randn(n), i.e. the line 3x + 2 plus Gaussian noise with a standard deviation of 0.25.
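Because the data is generated from known weights, we can check what training should converge to. The sketch below is our own NumPy check, not part of the original notebooks: it rebuilds an equivalent data set with a fixed seed (so the numbers differ from the notebook's) and solves the least-squares problem in closed form with `np.linalg.lstsq`.

```python
import numpy as np

rng = np.random.default_rng(0)              # fixed seed for reproducibility
n = 100
x = np.ones((n, 2))
x[:, 0] = rng.uniform(-1.0, 1.0, size=n)    # column 0: uniform in [-1, 1), column 1: ones
a_true = np.array([3.0, 2.0])
y = x @ a_true + 0.25 * rng.standard_normal(n)

# Closed-form least squares: the weights that minimise the mean squared error
a_best, *_ = np.linalg.lstsq(x, y, rcond=None)
print(a_best)  # close to [3, 2], but not exact because of the noise
```

Gradient descent on the MSE loss should end up near these weights.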

Now we plot the data set with matplotlib:

```python
plt.scatter(x[:, 0], y);
```

Next we define the loss function:

```python
def mse(ypred, y):
    # mean squared error between predictions and targets
    return ((ypred - y) ** 2).mean()
```

In the above code, we have created a function to calculate the mean squared error. It takes two arguments: ypred, the predicted values, and y, the actual values, and returns the average of their squared differences. This is the loss function for our model.
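A quick check of the loss on three hand-picked values, written with NumPy so it runs without any of the three libraries; the numbers are only illustrative:

```python
import numpy as np

def mse(ypred, y):
    # mean of the squared differences, the same formula as the post's mse()
    return ((ypred - y) ** 2).mean()

ypred = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, 2.0, 5.0])
print(mse(ypred, y))  # (0 + 0 + 4) / 3 = 1.333...
```

Squaring makes large errors count much more than small ones: the single miss of 2 dominates the loss.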

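Before training, we guess some weights and plot the predictions of that guess (orange) against the data:

```python
a = tensor(-1., 1)                        # an initial guess for the weights
ypred = x @ a                             # predictions of the guess
print(mse(ypred, y))                      # loss of the guess
plt.scatter(x[:, 0], y)
plt.scatter(x[:, 0], ypred, c="orange");
```

In the above code, we have created a graph showing the data set and the predicted values. The predictions are simply x@a; the noise term only appears in the targets y. Training should move the orange points onto the trend of the data.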

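```python
a = nn.Parameter(a)  # make the weights trainable
```

Wrapping a in nn.Parameter marks it as a tensor that requires gradients, so PyTorch records the operations performed on it and can compute the gradient of the loss with respect to a. These weights are what the model uses to predict the output.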
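For this model the gradient that `loss.backward()` stores in `a.grad` has a closed form: for L(a) = mean((x@a - y)**2) it is 2/n * x.T @ (x@a - y). The NumPy sketch below is our own illustration on a tiny three-point data set: it computes that formula and checks it against a central finite-difference estimate.

```python
import numpy as np

def mse_loss(x, y, a):
    return np.mean((x @ a - y) ** 2)

def mse_grad(x, y, a):
    # analytic gradient of mean((x @ a - y)**2) with respect to a
    n = len(y)
    return 2.0 / n * x.T @ (x @ a - y)

x = np.array([[-1.0, 1.0], [0.0, 1.0], [1.0, 1.0]])  # columns: [x, 1]
y = np.array([-1.0, 2.0, 5.0])                        # exactly 3x + 2
a = np.array([-1.0, 1.0])                             # an initial guess

g = mse_grad(x, y, a)

# independent check: central finite differences, one weight at a time
eps = 1e-6
g_fd = np.array([
    (mse_loss(x, y, a + eps * e) - mse_loss(x, y, a - eps * e)) / (2 * eps)
    for e in np.eye(2)
])
print(g, g_fd)
```

Since these y values lie exactly on 3x + 2, `mse_grad(x, y, np.array([3.0, 2.0]))` is zero: the minimum of the loss is where the gradient vanishes.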

```python
def update():
    ypred = x @ a                 # predicted values
    loss = mse(ypred, y)          # loss of the current weights
    if t % 10 == 0: print(loss)   # print the loss every 10 steps (t is the global loop counter)
    loss.backward()               # backpropagate: store d(loss)/d(a) in a.grad
    with torch.no_grad():         # the update itself must not be tracked by autograd
        a.sub_(lr * a.grad)       # gradient-descent step on the weights
        a.grad.zero_()            # reset the gradient, since backward() accumulates
```

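In the above code, update() performs one step of gradient descent. It computes the predictions and the loss, calls loss.backward() to get the gradient of the loss with respect to a (stored in a.grad), and then moves the weights a small step against that gradient, scaled by the learning rate lr. The step runs inside torch.no_grad() so that it is not itself recorded for differentiation, and the gradient is reset to zero afterwards because PyTorch accumulates gradients across backward() calls.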
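The update rule is a ← a - lr * gradient. As a rough NumPy mirror of update(), assuming the closed-form gradient 2/n * x.T @ (x@a - y) in place of `loss.backward()` and a tiny illustrative data set, a single step already lowers the loss:

```python
import numpy as np

x = np.array([[-1.0, 1.0], [0.0, 1.0], [1.0, 1.0]])  # columns: [x, 1]
y = np.array([-1.0, 2.0, 5.0])                        # points on 3x + 2
a = np.array([-1.0, 1.0])                             # initial guess
lr = 1e-1                                             # learning rate

def loss_and_grad(a):
    residual = x @ a - y
    loss = np.mean(residual ** 2)
    grad = 2.0 / len(y) * x.T @ residual  # stands in for what loss.backward() puts in a.grad
    return loss, grad

loss_before, grad = loss_and_grad(a)
a = a - lr * grad                          # the step done by a.sub_(lr * a.grad)
loss_after, _ = loss_and_grad(a)
print(loss_before, loss_after)
```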

```python
lr = 1e-1       # learning rate
epochs = 100    # number of training steps

for t in range(epochs):
    update()    # one gradient-descent step; update() reads the global t
```

In the above code, the loop calls update() 100 times. Each step uses the whole data set, so every iteration is one epoch of full-batch gradient descent, and the weights should end up close to the true values 3 and 2.
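The learning rate matters. The NumPy sketch below (our own, with a synthetic data set and a hypothetical `train` helper written as a pure function) runs 100 full-batch steps with lr = 0.1, as in the post, and with lr = 1.5. For this quadratic loss, gradient descent is stable only when lr is below 2 divided by the largest eigenvalue of the Hessian 2/n * x.T @ x; with x uniform in [-1, 1] plus a ones column that eigenvalue is close to 2, so the limit is about lr = 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
x = np.ones((n, 2))
x[:, 0] = rng.uniform(-1.0, 1.0, size=n)
y = x @ np.array([3.0, 2.0]) + 0.25 * rng.standard_normal(n)

def train(a, lr, steps=100):
    # hypothetical helper: returns the weights after `steps` gradient-descent updates
    for _ in range(steps):
        a = a - lr * 2.0 / n * x.T @ (x @ a - y)
    return a

def mse(a):
    return np.mean((x @ a - y) ** 2)

a0 = np.array([-1.0, 1.0])
a_good = train(a0, lr=0.1)   # converges
a_bad = train(a0, lr=1.5)    # diverges
print(mse(a_good), mse(a_bad))

hessian = 2.0 / n * x.T @ x
print(2.0 / np.linalg.eigvalsh(hessian).max())  # largest stable learning rate, about 1
```

Writing the loop as a function of the initial weights, rather than mutating global state, also makes it easy to compare runs side by side.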

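```python
from matplotlib import animation, rc
rc('animation', html='jshtml')              # render animations as interactive HTML in the notebook

a = nn.Parameter(tensor(-100., 100.))       # restart from a deliberately bad guess
fig = plt.figure()
plt.scatter(x[:, 0], y, c='orange')         # the data
line, = plt.plot(x[:, 0], x @ a.detach())   # the current fitted line
plt.close()                                 # avoid displaying the static figure

def animate(i):
    update()                                # one gradient-descent step per frame
    line.set_ydata(x @ a.detach())          # redraw the line with the new weights
    return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)
```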

In the above code, animate(i) is called by FuncAnimation once per frame, with the frame number i. Each call runs one gradient-descent step and redraws the fitted line with the new weights, so the animation shows the line moving onto the data from a deliberately bad starting guess of (-100, 100). np.arange(0, 100) gives 100 frames, and interval=20 is the delay between frames in milliseconds, not a number of iterations.

**TensorFlow**

Importing the libraries:

```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
```

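Creating a data set:

```python
n = 100                                  # total number of points
x = np.ones(shape=(n, 2))                # column 1 stays all ones (the intercept term)
x[:, 0] = np.linspace(-1, 1, num=n)      # column 0: evenly spaced values in [-1, 1]
a = np.array([3., 2.])                   # true weights: slope 3, intercept 2
y = x @ a + 0.75 * np.random.rand(n)     # targets = line + uniform noise in [0, 0.75)
```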

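In the above code, we have created a data set with 100 points. x is a matrix with two columns: the first holds 100 evenly spaced values between -1 and 1 (np.linspace, rather than the random values of the PyTorch version) and the second is all ones. The targets are y = x@a + 0.75*rand(n). Note that np.random.rand draws uniform noise from [0, 1), so the noise has a mean of 0.375 rather than 0 and the best-fit intercept comes out near 2.375 instead of 2; 0.25*np.random.randn(n) would match the PyTorch version.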
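To see the effect of the noise choice, here is a NumPy sketch with a fixed seed (our own check, not from the notebooks) that fits both versions in closed form with `np.linalg.lstsq`. With uniform noise on [0, 0.75) the intercept lands near 2 + 0.375; with zero-mean Gaussian noise it stays near 2.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100
x = np.ones((n, 2))
x[:, 0] = np.linspace(-1, 1, num=n)      # same inputs as the TensorFlow/JAX versions
a_true = np.array([3.0, 2.0])

y_uniform = x @ a_true + 0.75 * rng.random(n)          # like 0.75*np.random.rand(n)
y_normal = x @ a_true + 0.25 * rng.standard_normal(n)  # like the PyTorch version

fit_uniform, *_ = np.linalg.lstsq(x, y_uniform, rcond=None)
fit_normal, *_ = np.linalg.lstsq(x, y_normal, rcond=None)
print(fit_uniform)  # intercept shifted up by roughly 0.375
print(fit_normal)   # intercept close to 2
```

Because the linspace inputs are symmetric around 0, the fitted intercept here is just the mean of y, which is why the mean of the noise shows up directly in it.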

```python
plt.scatter(x[:, 0], y);

def mse(ypred, y):
    # mean squared error, using TensorFlow ops
    return tf.reduce_mean((ypred - y) ** 2)
```

In the above code, we have defined the same mean squared error loss as before, written with tf.reduce_mean so that it works on TensorFlow tensors and can be differentiated by TensorFlow.

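Before training, we plot the predictions of an initial guess for the weights (orange) against the data:

```python
ypred = x @ np.array([-.5, .5])           # predictions of the initial guess
plt.scatter(x[:, 0], y)
plt.scatter(x[:, 0], ypred, c="orange");
```

In the above code, we have created a graph showing the data set and the predicted values. The predictions are x@a for the guessed weights; the noise term only belongs to the targets y.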

```python
a = tf.Variable(tf.reshape(np.array([-.5, .5]), shape=(2, 1)))  # trainable weights as a (2, 1) column
```

tf.Variable holds the trainable weights of the formula, shaped (2, 1) as a column vector. Unlike a constant tensor, a Variable can be updated in place, and tf.GradientTape watches trainable Variables automatically.

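```python
x_new = tf.constant(x)      # inputs as a TensorFlow tensor
ypred_tensor = x_new @ a    # predictions, shape (100, 1)
y_tensor = tf.constant(y)   # targets as a TensorFlow tensor
```

The code above converts the NumPy arrays to TensorFlow tensors.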

```python
def update():
    with tf.GradientTape() as tape:                      # record operations for differentiation
        ypred = x_new @ a                                # predictions, shape (n, 1)
        loss = mse(tf.reshape(ypred, (n,)), y_tensor)    # flatten to (n,) to match y
    if t % 10 == 0: print(loss)                          # print the loss every 10 steps
    grad = tape.gradient(loss, a)                        # d(loss)/d(a); Variables are watched automatically
    a.assign_sub(lr * grad)                              # gradient-descent step, in place
```

In the above code, update() performs one gradient-descent step. tf.GradientTape records the operations that compute the loss, tape.gradient(loss, a) returns the gradient of the loss with respect to a, and a.assign_sub updates the variable in place. The predictions, of shape (100, 1), are reshaped to (100,) before computing the loss so that they line up element by element with y.
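The reshape in update() is not cosmetic. The predictions have shape (100, 1) and y has shape (100,); subtracting them broadcasts to a (100, 100) matrix, so the loss would silently be averaged over every pair of points. NumPy follows the same broadcasting rules as TensorFlow, so a small NumPy sketch with placeholder values shows the shapes:

```python
import numpy as np

n = 100
ypred = np.zeros((n, 1))  # shape of x_new @ a: a column vector
y = np.ones(n)            # shape of the targets: a flat vector

wrong = (ypred - y) ** 2               # broadcasts to (n, n)
right = (ypred.reshape(n) - y) ** 2    # element-wise, shape (n,)
print(wrong.shape, right.shape)        # (100, 100) (100,)
```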

```python
lr = 1e-1            # learning rate
for t in range(100):
    update()         # one gradient-descent step, as defined above
```

In the above code, the loop calls update() 100 times, exactly as in the PyTorch version; only the update step itself uses TensorFlow.

```python
from matplotlib import animation, rc
rc('animation', html='jshtml')                   # render animations as interactive HTML

a = tf.Variable(tf.reshape(np.array([-100., .1]), shape=(2, 1)))  # restart from a bad guess
fig = plt.figure()
plt.scatter(x[:, 0], y, c='orange')              # the data
line, = plt.plot(x[:, 0], tf.reshape(x_new @ a, shape=(100,)))    # the current fitted line
plt.close()                                      # avoid displaying the static figure

def animate(i):
    update()                                     # one gradient-descent step per frame
    line.set_ydata(tf.reshape(x_new @ a, shape=(100,)))
    return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)
```

As in the PyTorch version, animate(i) runs one gradient-descent step per frame and redraws the fitted line; tf.reshape flattens the (100, 1) predictions into a vector of 100 y values for the line. interval=20 is the delay between frames in milliseconds.

**JAX**

Now we will use Google's JAX.

Importing the libraries:

```python
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt   # used for the plots below
```

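Creating a data set:

```python
n = 100                                          # total number of points
x = jnp.ones(shape=(n, 2))                       # column 1 stays all ones
x = x.at[:, 0].set(jnp.linspace(-1, 1, num=n))   # immutable arrays: .at[].set() returns a new array
a = jnp.array([3., 2.])                          # true weights: slope 3, intercept 2
seed = 1
key = jax.random.PRNGKey(seed)                   # JAX's explicit random-number key
y = x @ a + 0.75 * jax.random.uniform(key, (n,)) # targets = line + uniform noise in [0, 0.75)
```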

In the above code, we have created a data set with 100 points. x is a matrix with two columns: the first holds evenly spaced values between -1 and 1 and the second is all ones. Because JAX arrays are immutable, the first column is filled with x.at[:, 0].set(...), which returns a new array. The targets are y = x@a plus 0.75 times uniform random noise from [0, 1), the same recipe as in the TensorFlow version.

```python
plt.scatter(x[:, 0], y);

def mse(ypred, y):
    # mean squared error, using jax.numpy
    return jnp.mean((ypred - y) ** 2)
```

In the above code, we have defined the same mean squared error loss as before, written with jnp.mean so that JAX can differentiate it.

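```python
a = jnp.array([-.5, .5])   # initial guess for the weights
ypred = x @ a              # predictions of the guess

def model(x, a):
    return x @ a           # linear model: predictions for inputs x and weights a
```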

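In the above code, we have created the model function. It takes the inputs x and the weights a and returns the predictions x@a. Passing the weights in as an explicit argument (rather than keeping them in a stateful object, as nn.Parameter and tf.Variable do) is what lets JAX differentiate the loss with respect to them.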

```python
def grad_func(y, x, a):
    return mse(model(x, a), y)   # the loss for weights a
```

Despite its name, grad_func does not compute a gradient: it takes y, x and a and returns the loss. In update() below, jax.value_and_grad turns it into a function that returns both the loss and its gradient with respect to a.
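`jax.value_and_grad(f, argnums=2)` returns a new function that, called with the same arguments as f, gives back the pair (value of f, gradient of f with respect to its third argument). To make that contract concrete without JAX, here is a NumPy stand-in, `numeric_value_and_grad`, a hypothetical helper that approximates the gradient with central finite differences; JAX itself uses automatic differentiation, which gives exact derivatives rather than approximations. The three-point data set is illustrative.

```python
import numpy as np

def numeric_value_and_grad(f, argnums, eps=1e-6):
    # hypothetical stand-in for jax.value_and_grad with a single integer argnums;
    # assumes the differentiated argument is a 1-D array
    def wrapped(*args):
        value = f(*args)
        p = np.asarray(args[argnums], dtype=float)
        grad = np.zeros_like(p)
        for i in range(p.size):
            step = np.zeros_like(p)
            step[i] = eps
            plus = list(args)
            plus[argnums] = p + step
            minus = list(args)
            minus[argnums] = p - step
            grad[i] = (f(*plus) - f(*minus)) / (2 * eps)
        return value, grad
    return wrapped

def mse(ypred, y):
    return np.mean((ypred - y) ** 2)

def model(x, a):
    return x @ a

def grad_func(y, x, a):  # same signature as in the post: returns the loss
    return mse(model(x, a), y)

x = np.array([[-1.0, 1.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([-1.0, 2.0, 5.0])
a = np.array([-0.5, 0.5])

loss, grad = numeric_value_and_grad(grad_func, argnums=2)(y, x, a)
print(loss, grad)
```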

```python
lr = 1e-1  # learning rate

def update():
    global a                                   # a is rebound below, so it must be declared global
    loss, grad = jax.value_and_grad(grad_func, argnums=2)(y, x, a)  # loss and d(loss)/d(a)
    if t % 10 == 0: print(loss)                # print the loss every 10 steps
    a -= lr * grad                             # gradient-descent step; creates a new array
```

In the above code, update() performs one gradient-descent step. jax.value_and_grad(grad_func, argnums=2) creates a function that returns both the value of grad_func (the loss) and its gradient with respect to argument number 2, the weights a. JAX differentiates the function itself, so there is no tape object to manage as in TensorFlow. Because JAX arrays are immutable, a -= lr*grad creates a new array and rebinds the name a, which is why update() declares a as global.

```python
def loop():
    global t                 # update() reads the loop counter t as a global
    for t in range(100):
        update()

loop()                       # run the 100 training steps
```

In the above code, loop() calls update() 100 times. t has to be global here: update() reads t to decide when to print the loss, and a plain local loop variable inside loop() would not be visible to it.

```python
from matplotlib import animation, rc
rc('animation', html='jshtml')              # render animations as interactive HTML

a = jnp.array([-100., .1])                  # restart from a deliberately bad guess
fig = plt.figure()
plt.scatter(x[:, 0], y, c='orange')         # the data
line, = plt.plot(x[:, 0], x @ a)            # the current fitted line
plt.close()                                 # avoid displaying the static figure

def animate(i):
    update()                                # one gradient-descent step per frame
    line.set_ydata(x @ a)                   # redraw the line with the new weights
    return line,

animation.FuncAnimation(fig, animate, np.arange(0, 100), interval=20)
```

As before, animate(i) runs one gradient-descent step per frame and redraws the fitted line; interval=20 is the delay between frames in milliseconds. JAX arrays can be passed straight to matplotlib, which converts them to NumPy arrays.

You can view all three Google Colab notebooks in this repo.

Comparing TensorFlow, PyTorch and JAX on this example, we found JAX to be the best library to use: in our experience it was the fastest and most efficient of the three, and it was easy to use. In second place, we found PyTorch to be the best suited for developers.

In an upcoming post we will try to reimplement the full fastai library using JAX, since in this comparison it was the fastest and most efficient library.

So stay tuned for the next post.
