# Simple MNIST Tutorial in Google's JAX

##### Published On: 2022-07-15

In this tutorial we will learn how to use Google's JAX library for MNIST classification: we will train a neural network to classify handwritten digits. The MNIST dataset is a set of 28x28 grayscale images of handwritten digits (0-9), and the goal is to train a network that recognizes which digit each image shows.

The tutorial proceeds in four steps:

1. Import the necessary libraries.
2. Load the MNIST dataset.
3. Define the neural network model using JAX.
4. Train the model with simple mini-batch stochastic gradient descent, using JAX's automatic differentiation to compute the gradients.
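Here is a minimal illustration of the two training ingredients named above. The sketch uses `jax.grad` to differentiate a small mean-squared-error loss with respect to a weight vector, then applies plain gradient-descent steps. The toy data, the bias-free linear model and the learning rate are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

# toy data (assumed for illustration): targets follow y = 2*x0 - 1*x1
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = jnp.array([2.0, -1.0, 1.0, 3.0])

def loss(w):
    # mean squared error of a linear model without a bias term
    return jnp.mean((x @ w - y) ** 2)

grad_loss = jax.grad(loss)  # a new function that returns d(loss)/d(w)

w = jnp.zeros(2)
initial_loss = loss(w)      # mean of (4, 1, 1, 9) = 3.75
lr = 0.1
for step in range(100):
    w = w - lr * grad_loss(w)  # one gradient-descent step

final_loss = loss(w)
print(initial_loss, final_loss, w)  # w approaches [2, -1]
```

The MNIST training loop later in the tutorial follows the same pattern with more parameters and mini-batches of data.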

Import the necessary libraries.

```python
from pathlib import Path
import math

import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from fastai.data.all import *  # provides untar_data, URLs and Path.ls()
from PIL import Image
from numpy import asarray

MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl'  # MNIST dataset URL (not used below; the data is fetched with untar_data)
```
```python
def get_data(validation_per):
    # download (on first call) and extract MNIST as PNG files in training/0 ... training/9
    path = untar_data(URLs.MNIST)

    # create lists of images (as numpy arrays) and their labels
    x_train_ls, y_train_ls = [], []
    for i in range(10):
        files = (path/"training"/str(i)).ls()
        x_train_ls.extend(asarray(Image.open(f)) for f in files)
        y_train_ls.extend([i] * len(files))

    # convert the lists to numpy arrays: images (N, 28, 28), labels (N,)
    train_x, train_y = np.array(x_train_ls), np.array(y_train_ls)

    # randomly shuffle images and labels with one shared permutation
    perm = np.random.permutation(len(train_x))
    train_x, train_y = train_x[perm], train_y[perm]

    # divide the data into training and validation sets
    training_set = int(train_x.shape[0] * (1 - validation_per / 100))
    training_x, training_y = train_x[:training_set], train_y[:training_set]
    validation_x, validation_y = train_x[training_set:], train_y[training_set:]

    # convert the numpy arrays to JAX arrays
    return map(jnp.array, (training_x, training_y, validation_x, validation_y))
```

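The code above loads the MNIST dataset. It uses the `untar_data` function from the fastai library to download and extract the images. It then shuffles the images and labels together and holds out `validation_per` percent of them for validation. Finally, it uses the `map` function to convert the NumPy arrays to JAX arrays.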
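Images and labels must be shuffled with the same ordering. The small self-contained sketch below shows why, using made-up arrays rather than MNIST. One shared permutation keeps every label attached to its image, and the split index comes from the first dimension of the array. The helper name `shuffle_split` is hypothetical. It also uses a seeded `np.random.default_rng` generator for reproducibility, which is an assumption that `get_data` does not make.

```python
import numpy as np

def shuffle_split(x, y, validation_per, seed=0):
    # hypothetical helper mirroring get_data's shuffle-and-split logic
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(x))          # one permutation shared by x and y
    x, y = x[perm], y[perm]
    n_train = int(x.shape[0] * (1 - validation_per / 100))
    return x[:n_train], y[:n_train], x[n_train:], y[n_train:]

# fake "images": image k is filled with the value k and has label k % 10
x = np.stack([np.full((28, 28), k) for k in range(50)])
y = np.arange(50) % 10
x_tr, y_tr, x_va, y_va = shuffle_split(x, y, validation_per=20)

print(x_tr.shape, x_va.shape)  # 40 training and 10 validation images
# every label still matches the image it came with
print(all(img[0, 0] % 10 == lab for img, lab in zip(x_tr, y_tr)))
```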

```python
x_train, y_train, x_valid, y_valid = get_data(20)  # 20% of the data will be used for validation

def normalize(x, m, s): return (x - m) / s

train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

x_train = normalize(x_train, train_mean, train_std)
# NB: use the training (not validation) mean and std for the validation set
x_valid = normalize(x_valid, train_mean, train_std)
```

The code above normalizes the data with the mean and standard deviation of the training set, so that the training inputs have roughly zero mean and unit variance.
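Here is a quick numeric check of this idea on random stand-in data, not MNIST (the arrays below are assumptions). After normalizing with the training statistics, the training set has a mean of about 0 and a standard deviation of about 1. The validation set is shifted and scaled by the same amounts, so it ends up close to, but generally not exactly, standardized.

```python
import jax.numpy as jnp
from jax import random

def normalize(x, m, s): return (x - m) / s

# stand-in pixel data in [0, 255] (random values, not real MNIST images)
k1, k2 = random.split(random.PRNGKey(0))
x_train = random.uniform(k1, (1000, 784), minval=0., maxval=255.)
x_valid = random.uniform(k2, (200, 784), minval=0., maxval=255.)

train_mean, train_std = x_train.mean(), x_train.std()
x_train_n = normalize(x_train, train_mean, train_std)
x_valid_n = normalize(x_valid, train_mean, train_std)  # training stats, not validation stats

print(x_train_n.mean(), x_train_n.std())  # about 0 and 1
print(x_valid_n.mean(), x_valid_n.std())  # close to, but not exactly, 0 and 1
```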

```python
# flatten each 28x28 image into a 784-long row by merging the last two dimensions
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_valid = x_valid.reshape(x_valid.shape[0], x_valid.shape[1] * x_valid.shape[2])
x_train.shape, x_valid.shape

n, m = x_train.shape  # n is the number of samples and m is the number of features
c = y_train.max() + 1  # number of classes
n, m, c
```

The code above uses `reshape` to flatten every image into a vector of 784 pixel values, so each row of `x_train` is one sample. It then takes the largest label in `y_train` and adds 1 to get the number of classes (labels 0-9 give 10 classes).
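A minimal shape check of the two steps above, on a stand-in batch of zeros and made-up labels (assumptions, not MNIST data):

```python
import jax.numpy as jnp

# a stand-in batch of 5 "images" of 28x28 pixels (zeros, used only for the shapes)
x = jnp.zeros((5, 28, 28))
x_flat = x.reshape(x.shape[0], x.shape[1] * x.shape[2])
print(x_flat.shape)  # (5, 784)

# the number of classes is the largest label plus one
y = jnp.array([3, 0, 9, 1, 4])
c = int(y.max()) + 1
print(c)  # 10
```

`x.reshape(x.shape[0], -1)` is an equivalent shorthand: `-1` tells `reshape` to infer the remaining dimension.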

```python
nh = 50  # number of hidden units
seed = 1
key = random.PRNGKey(seed)
key1, key2 = random.split(key)  # use a separate key for each weight matrix
w1 = random.normal(key=key1, shape=(m, nh)) * math.sqrt(2 / m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key2, shape=(nh, 1)) * math.sqrt(2 / nh)
b2 = jnp.ones(1)
```

The code above initializes the parameters of the network. The weights are drawn with `random.normal` and scaled by `sqrt(2/fan_in)`, where `fan_in` is the number of inputs to the layer. This scaling is Kaiming He initialization. The biases are initialized with `jnp.zeros` and `jnp.ones`. With a single output unit, this first model predicts the digit label directly as a number.
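The sketch below checks two points about this initialization in isolation. First, `random.split` gives each weight matrix its own independent key. Second, multiplying standard-normal samples by `sqrt(2/fan_in)` gives weights whose standard deviation is about `sqrt(2/fan_in)`. The sizes `m = 784` and `nh = 50` match the tutorial; the check itself is only illustrative.

```python
import math
import jax.numpy as jnp
from jax import random

m, nh = 784, 50
key1, key2 = random.split(random.PRNGKey(1))  # independent keys for w1 and w2
w1 = random.normal(key1, (m, nh)) * math.sqrt(2 / m)
w2 = random.normal(key2, (nh, 1)) * math.sqrt(2 / nh)

print(float(w1.std()), math.sqrt(2 / m))  # the two numbers should be close
```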

```python
def lin(x, w, b):  # linear layer
    return x @ w + b

def relu(x):  # ReLU activation
    return jnp.maximum(0, x)

def leaky_relu(x):  # Leaky ReLU activation (slope 0.01 for negative inputs)
    return jnp.maximum(0.01 * x, x)

def model(xb, w1, w2, b1, b2):  # forward pass of the network
    l1 = lin(xb, w1, b1)
    l2 = leaky_relu(l1)
    l3 = lin(l2, w2, b2)
    return l3
```

The `model` function defines the forward pass of the neural network that we will train. It contains the following layers:
1. Linear layer
2. Leaky ReLU layer
3. Linear layer
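A self-contained forward-pass check of this three-layer network follows. The functions are repeated so it runs on its own, and it uses a random stand-in mini-batch of 8 rows (an assumption). It shows that the output has one prediction per sample and that Leaky ReLU scales negative inputs by 0.01.

```python
import jax.numpy as jnp
from jax import random

def lin(x, w, b): return x @ w + b
def leaky_relu(x): return jnp.maximum(0.01 * x, x)

def model(xb, w1, w2, b1, b2):
    # Linear -> Leaky ReLU -> Linear, as in the tutorial
    return lin(leaky_relu(lin(xb, w1, b1)), w2, b2)

m, nh, bs = 784, 50, 8
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
xb = random.normal(k1, (bs, m))                 # stand-in mini-batch
w1 = random.normal(k2, (m, nh)) * (2 / m) ** 0.5
w2 = random.normal(k3, (nh, 1)) * (2 / nh) ** 0.5
b1, b2 = jnp.zeros(nh), jnp.ones(1)

out = model(xb, w1, w2, b1, b2)
print(out.shape)  # (8, 1): one prediction per sample
print(leaky_relu(jnp.array([-2.0, 3.0])))  # negative input scaled by 0.01
```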

```python
def mse(x, y):  # mean squared error, used as the loss function
    return ((jnp.squeeze(x) - y)**2).mean()
```

The code above defines the `mse` loss. `jnp.squeeze` removes the extra size-1 dimension from the model output (shape `(n, 1)` becomes `(n,)`) so that it lines up with the targets, and `.mean()` averages the squared errors.
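Here is a worked example, with made-up numbers, of why `jnp.squeeze` is needed. Without it, predictions of shape `(3, 1)` minus targets of shape `(3,)` broadcast to a `(3, 3)` matrix. The result is silently wrong rather than an error.

```python
import jax.numpy as jnp

preds = jnp.array([[1.0], [2.0], [4.0]])  # model output, shape (3, 1)
targ = jnp.array([1.0, 3.0, 2.0])         # targets, shape (3,)

def mse(x, y):
    return ((jnp.squeeze(x) - y) ** 2).mean()

print(mse(preds, targ))              # (0 + 1 + 4) / 3, about 1.667
print((preds - targ).shape)          # (3, 3): broadcasting without squeeze
print(((preds - targ) ** 2).mean())  # 21 / 9, about 2.333: the wrong loss
```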

Next, we rebuild the same network in a PyTorch-style design, with each layer written as a module class.

```python
class Module():
    # base class: calling a module runs forward() and keeps its inputs and output
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self): raise NotImplementedError('not implemented')

class Relu(Module):
    # ReLU shifted down by 0.5 (a trick to move the activations' mean closer to zero)
    def forward(self, inp): return jnp.maximum(0., inp) - 0.5

class Lin(Module):
    def __init__(self, w, b):
        self.w, self.b = w, b

    def forward(self, inp): return inp @ self.w + self.b

class Mse(Module):
    def forward(self, inp, targ): return ((jnp.squeeze(inp) - targ)**2).mean()
```

The code above defines the `Relu`, `Lin` and `Mse` modules:

- `Relu` uses `jnp.maximum` to clip negative inputs to zero elementwise, then subtracts 0.5.
- `Lin` applies a weight matrix and a bias.
- `Mse` uses `jnp.squeeze` to drop the extra dimension from the model output and `.mean()` to average the squared error.
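This small self-contained check shows how the modules behave. The classes are repeated so the snippet runs on its own, and the inputs are made up. Calling a module runs `forward` and stores the inputs in `args` and the result in `out`. `Relu` returns `max(0, x) - 0.5`.

```python
import jax.numpy as jnp

class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out
    def forward(self): raise NotImplementedError

class Relu(Module):
    def forward(self, inp): return jnp.maximum(0., inp) - 0.5

class Lin(Module):
    def __init__(self, w, b): self.w, self.b = w, b
    def forward(self, inp): return inp @ self.w + self.b

relu = Relu()
out = relu(jnp.array([-1.0, 0.0, 2.0]))
print(out)           # values -0.5, -0.5 and 1.5
print(relu.args[0])  # the input is kept on the module

lin = Lin(jnp.eye(2), jnp.array([1.0, -1.0]))
print(lin(jnp.array([[3.0, 4.0]])))  # identity weights plus bias: values 4 and 3
```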

```python
def Model(x, targ, w1, b1, w2, b2):
    # run the input through the layers, then return the MSE loss
    layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    loss = Mse()
    for l in layers: x = l(x)
    return loss(x, targ)

lr = 0.001
```

The code above wraps the layers and the loss into a single `Model` function and sets the learning rate used for the gradient-descent updates.

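```python
import jax

# loss value and gradients w.r.t. w1, b1, w2, b2 (positional arguments 2 to 5)
loss, grads = jax.value_and_grad(Model, argnums=(2, 3, 4, 5))(x_valid, y_valid, w1, b1, w2, b2)
```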

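The code above uses `jax.value_and_grad` to compute the loss and its gradients in one call. The gradients are taken with respect to the parameters selected by `argnums`. Because `argnums` lists four arguments, `grads` is a tuple of four arrays in the order `(w1, b1, w2, b2)`.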
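Here is a tiny example, independent of the network, of what `jax.value_and_grad` returns when `argnums` selects several arguments. It returns the function value plus a tuple with one gradient per selected argument, in the same order. The function `f` is made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x, a, b):
    # f = sum(a * x + b**2); x is data, a and b play the role of parameters
    return jnp.sum(a * x + b ** 2)

x = jnp.array([1.0, 2.0])
a = jnp.array([3.0, 4.0])
b = jnp.array(5.0)

value, grads = jax.value_and_grad(f, argnums=(1, 2))(x, a, b)
print(value)  # 3*1 + 4*2 + 2*25 = 61
print(grads)  # (df/da = x = [1, 2], df/db = 2 * 2 * b = 20)
```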

```python
class NegativeLogLikelyhood(Module):
    # pick each row's log-probability at its target class, then average and negate
    def forward(self, inp, targ): return -inp[jnp.arange(inp.shape[0]), targ].mean()
```

The code above defines the `NegativeLogLikelyhood` module. For each row it selects the log-probability of the correct class, averages these values with `.mean()` and flips the sign, giving the mean negative log-likelihood.
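Here is a worked example of the indexing step, using made-up probabilities for 3 samples and 4 classes. Indexing with `jnp.arange(n)` for the rows and `targ` for the columns picks exactly one entry per row, the log-probability of that row's target class.

```python
import jax.numpy as jnp

# made-up log-probabilities for 3 samples and 4 classes
log_probs = jnp.log(jnp.array([[0.7, 0.1, 0.1, 0.1],
                               [0.2, 0.5, 0.2, 0.1],
                               [0.25, 0.25, 0.25, 0.25]]))
targ = jnp.array([0, 1, 3])

picked = log_probs[jnp.arange(log_probs.shape[0]), targ]  # one entry per row
nll = -picked.mean()
print(jnp.exp(picked))  # probabilities of the targets: 0.7, 0.5, 0.25
print(nll)              # -(log 0.7 + log 0.5 + log 0.25) / 3, about 0.812
```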

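```python
def log_softmax(x):
    # row-wise log-softmax: subtract each row's max for numerical stability,
    # then subtract the log of that row's sum of exponentials
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))
```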

The code above defines `log_softmax`, which turns each row of raw outputs into log-probabilities: `log_softmax(x) = x - log(sum(exp(x)))`, with the sum taken over the classes of each row (`jnp.exp`, `jnp.sum` and `jnp.log`). Subtracting the row maximum first does not change the result, but it keeps `jnp.exp` from overflowing on large inputs.
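The sketch below compares the literal formula `log(exp(x) / sum(exp(x)))` with the max-shifted version on made-up logits. It uses JAX's built-in `jax.nn.log_softmax` as a reference. With large logits, the literal formula's `jnp.exp` overflows to `inf` in float32 and the row becomes `nan`. The shifted version stays finite, and each row of `exp(log_softmax(x))` sums to 1.

```python
import jax
import jax.numpy as jnp

def naive_log_softmax(x):
    # literal formula, computed row-wise
    return jnp.log(jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1, keepdims=True))

def log_softmax(x):
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))

x = jnp.array([[1.0, 2.0, 3.0],
               [1000.0, 1001.0, 1002.0]])  # exp(1000) overflows in float32

print(naive_log_softmax(x))                  # second row is nan
print(log_softmax(x))                        # both rows about [-2.408, -1.408, -0.408]
print(jnp.exp(log_softmax(x)).sum(axis=-1))  # each row sums to 1
print(jnp.allclose(log_softmax(x), jax.nn.log_softmax(x), atol=1e-5))
```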

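```python
import jax

nh = 50  # number of hidden units
seed = 1
bs = 64  # mini-batch size
key = random.PRNGKey(seed)
key1, key2 = random.split(key)
w1 = random.normal(key=key1, shape=(m, nh)) * math.sqrt(2 / m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key2, shape=(nh, 10)) * math.sqrt(2 / nh)  # one output per class
b2 = jnp.ones(10)
lr = 0.01
epochs = 5

def nll_model(x, targ, w1, b1, w2, b2):
    # classifier loss (function name chosen here): Lin -> Relu -> Lin,
    # then log_softmax and the mean negative log-likelihood of the targets
    for l in [Lin(w1, b1), Relu(), Lin(w2, b2)]: x = l(x)
    log_probs = log_softmax(x)
    return -log_probs[jnp.arange(targ.shape[0]), targ].mean()

# compile loss + gradients once; grads come back in the order (w1, b1, w2, b2)
loss_and_grads = jax.jit(jax.value_and_grad(nll_model, argnums=(2, 3, 4, 5)))

loss, grads = loss_and_grads(x_train[0:bs], y_train[0:bs], w1, b1, w2, b2)
print(f"Initial Loss: {loss}")

n_rows = x_train.shape[0]
for epoch in range(epochs):
    for i in range((n_rows - 1) // bs + 1):
        start_index, end_index = i * bs, (i + 1) * bs
        loss, grads = loss_and_grads(x_train[start_index:end_index], y_train[start_index:end_index], w1, b1, w2, b2)
        # SGD step: move each parameter against its own gradient
        w1 = w1 - lr * grads[0]
        b1 = b1 - lr * grads[1]
        w2 = w2 - lr * grads[2]
        b2 = b2 - lr * grads[3]
    print(epoch, loss)
```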

The code above re-initializes the network with 10 outputs, one per digit class. The weights again use Kaiming He scaling, with `jnp.zeros` and `jnp.ones` for the biases. The loss is now the negative log-likelihood of the `log_softmax` outputs instead of the mean squared error. The loop then trains the model for `epochs` passes over the training set, one mini-batch of `bs` rows at a time, and updates every parameter with a plain stochastic gradient descent step.
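The sketch below lets you run the same update logic without downloading MNIST. It repeats the recipe (Lin, ReLU, Lin, log-softmax, negative log-likelihood, plain SGD over mini-batches) on a small synthetic 3-class problem. Two conveniences are swapped in and are not the tutorial's code. The parameters are kept in a tuple and updated all at once with `jax.tree_util.tree_map`, and the built-in `jax.nn.log_softmax` is used. The data, sizes, learning rate and epoch count are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax import random

# synthetic data (assumed): 300 points in 3 well-separated 2-D clusters
k_data, k1, k2 = random.split(random.PRNGKey(0), 3)
centers = jnp.array([[0.0, 2.0], [2.0, -2.0], [-2.0, -2.0]])
y = jnp.arange(300) % 3                              # labels cycle 0, 1, 2, ...
x = centers[y] + 0.5 * random.normal(k_data, (300, 2))

nh, c, lr, bs, epochs = 16, 3, 0.1, 32, 20
# parameters kept in one tuple (w1, b1, w2, b2); weights use sqrt(2/fan_in) scaling
params = (random.normal(k1, (2, nh)) * (2 / 2) ** 0.5, jnp.zeros(nh),
          random.normal(k2, (nh, c)) * (2 / nh) ** 0.5, jnp.zeros(c))

def loss_fn(params, xb, yb):
    w1, b1, w2, b2 = params
    h = jnp.maximum(0.0, xb @ w1 + b1)               # Lin -> ReLU
    log_probs = jax.nn.log_softmax(h @ w2 + b2)      # Lin -> log-softmax
    return -log_probs[jnp.arange(yb.shape[0]), yb].mean()  # mean NLL

@jax.jit
def sgd_step(params, xb, yb):
    loss, grads = jax.value_and_grad(loss_fn)(params, xb, yb)
    # plain SGD applied to every array in the tuple at once
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads), loss

initial_loss = loss_fn(params, x, y)
for epoch in range(epochs):
    for i in range((x.shape[0] - 1) // bs + 1):
        params, _ = sgd_step(params, x[i * bs:(i + 1) * bs], y[i * bs:(i + 1) * bs])
final_loss = loss_fn(params, x, y)
print(initial_loss, final_loss)  # the full-data loss should be lower after training
```

Wrapping the whole step in `jax.jit` compiles the gradient computation and the update together. The last, shorter mini-batch triggers one extra compilation for its new shape.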

When you run the program above, you will see the loss decrease as training progresses.

The full code is available at https://github.com/Saify-Technologies/MNIST-example-deep-learning-in-googles-jax.git.
