MNIST Tutorial in Google's JAX


 In this tutorial we will learn how to use Google's JAX library to do MNIST classification. We will use the MNIST dataset to train a neural network to classify handwritten digits.
 The MNIST dataset is a set of images of handwritten digits, and the goal is to train a neural network that recognizes them.
 First, we will import the necessary libraries.
 Second, we will load the MNIST dataset.
 Third, we will define the neural network model using JAX.
 Finally, we will train the model with plain stochastic gradient descent, using JAX's automatic differentiation to compute the gradients.

 Import the necessary libraries.

from pathlib import Path
from IPython.core.debugger import set_trace

import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from fastai.vision import *
from fastai.data.all import *
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from PIL import Image
from numpy import asarray

Load the MNIST dataset.

MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl' #MNIST dataset URL

def get_data(validation_per):
    path = untar_data(URLs.MNIST)

    # build lists of image arrays and labels, one digit folder at a time
    x_train_ls = []
    y_train_ls = []
    for i in range(0, 10):
        x_ls = [asarray(Image.open(f)) for f in (path/"training"/str(i)).ls()]
        y_ls = [i for f in (path/"training"/str(i)).ls()]
        y_train_ls.extend(y_ls)
        x_train_ls.extend(x_ls)

    # convert the lists to numpy arrays
    train_x, train_y = np.array(x_train_ls), np.array(y_train_ls)

    # shuffle images and labels the same way by reusing the same seed
    seed = np.random.randint(0, 10000)
    np.random.seed(seed)
    np.random.shuffle(train_x)
    np.random.seed(seed)
    np.random.shuffle(train_y)

    # split into training and validation sets
    training_set = int(train_x.shape[0] * (1 - (validation_per/100)))
    training_x, training_y = train_x[:training_set], train_y[:training_set]
    validation_x, validation_y = train_x[training_set:], train_y[training_set:]

    return map(jnp.array, (training_x, training_y, validation_x, validation_y))


 In the above code we load the MNIST dataset. We use the untar_data function from the fastai library to download it, read each digit folder into numpy arrays, shuffle, split off a validation set, and finally use map to convert the numpy arrays to JAX arrays.

x_train , y_train , x_valid,y_valid = get_data(20) #20% of the data will be used for validation
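 As a quick sanity check (a small sketch, not part of the original tutorial), we can look at the shapes and display one of the images with matplotlib before flattening:

x_train.shape, y_train.shape, x_valid.shape, y_valid.shape  # e.g. (48000, 28, 28), (48000,), (12000, 28, 28), (12000,)
plt.imshow(x_train[0], cmap="gray")  # show the first training digit
plt.title(int(y_train[0]))           # its label
plt.show()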

def normalize(x, m, s): return (x-m)/s

train_mean,train_std = x_train.mean(),x_train.std()
train_mean,train_std

x_train = normalize(x_train, train_mean, train_std)
# NB: use the training mean/std, not the validation statistics, for the validation set
x_valid = normalize(x_valid, train_mean, train_std)

 In the above code we normalize the data, using the mean and standard deviation of the training set for both the training and validation sets.
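 As a quick check (a small sketch, not in the original post), the normalized training set should now have mean close to 0 and standard deviation close to 1:

print(x_train.mean(), x_train.std())   # roughly 0.0 and 1.0
print(x_valid.mean(), x_valid.std())   # close to, but not exactly, 0 and 1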

x_train = x_train.reshape(x_train.shape[0], x_train.shape[1]*x_train.shape[2]) # flatten each 28x28 image into a row of 784 pixels
x_valid = x_valid.reshape(x_valid.shape[0], x_valid.shape[1]*x_valid.shape[2])
x_train.shape, x_valid.shape

n,m = x_train.shape # n is the number of samples, m is the number of features (784)
c = y_train.max()+1 # number of classes
n,m,c

 In the above code we flatten the images with reshape, so each sample becomes a 784-dimensional row. We also take the maximum value in y_train and add 1 to get the number of classes (10).


nh = 50 #number of hidden units will be 50
seed=1
key = random.PRNGKey(seed)
w1 = random.normal(key=key,shape=(m,nh)) * math.sqrt(2/m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key,shape=(nh ,1)) * math.sqrt(2/nh)
b2 = jnp.ones(1)

 In the above code we initialize the weights with random.normal scaled by sqrt(2/fan_in), which is Kaiming (He) initialization, and we initialize the biases with jnp.zeros and jnp.ones. Note that JAX's random functions are stateless: they take an explicit PRNG key, and here the same key is reused for both weight matrices.
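 Because the same key is reused above, w1 and w2 are drawn from the same random stream. A common JAX pattern (a small sketch, not required for the rest of the tutorial) is to split the key so each parameter gets its own independent stream:

key = random.PRNGKey(seed)
key, k1, k2 = random.split(key, 3)              # one sub-key per weight matrix
w1 = random.normal(k1, (m, nh)) * math.sqrt(2/m)
w2 = random.normal(k2, (nh, 1)) * math.sqrt(2/nh)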

def lin(x, w, b): # linear layer
    return x@w + b


def relu(x): # ReLU activation
    return jnp.maximum(0, x)

def leaky_relu(x): # Leaky ReLU activation
    return jnp.maximum(0.01*x, x)


def model(xb, w1, w2, b1, b2): # model: linear -> leaky ReLU -> linear
    l1 = lin(xb, w1, b1)
    l2 = leaky_relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

 The model function defines the neural network we will train. It chains the following layers (see the short check after this list):
 1. Linear layer
 2. Leaky ReLU layer
 3. Linear layer
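 As a quick check of the forward pass (a minimal sketch, not in the original post), the model maps a batch of flattened images to one output per sample:

preds = model(x_valid, w1, w2, b1, b2)
preds.shape   # (number of validation samples, 1)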

def mse(x, y): # mean squared error loss
    return ((jnp.squeeze(x) - y)**2).mean()

 In the above code we define the mse loss. jnp.squeeze removes the trailing dimension of size 1 from the model output so it lines up with the targets, and mean averages the squared error over the batch.
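 For example (a small usage sketch, treating the digit labels as regression targets just as the tutorial does at this stage), we can evaluate the loss of the untrained model on the training set:

preds = model(x_train, w1, w2, b1, b2)  # shape (n, 1)
mse(preds, y_train)                     # a single scalar loss value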

 Now we will rebuild the same pieces in a PyTorch-style design, with each layer as a class.

class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self): raise Exception('not implemented')

class Relu(Module):
    def forward(self, inp): return jnp.maximum(0., inp) - 0.5

class Lin(Module):
    def __init__(self, w, b):
        self.w, self.b = w, b

    def forward(self, inp): return inp@self.w + self.b

class Mse(Module):
    def forward(self, inp, targ): return ((jnp.squeeze(inp) - targ)**2).mean()

In the above code we define the Relu, Lin and Mse classes. The base Module class stores the inputs and output of each call, Relu applies jnp.maximum (shifted down by 0.5), Lin is a linear layer, and Mse is the same mean squared error loss as before.

def Model(x, targ, w1, b1, w2, b2):
    layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    loss = Mse()
    for l in layers: x = l(x)
    return loss(x, targ)

lr = 0.001

 In the above code we set the learning rate. The Model function chains the layer objects and returns the loss, so we can differentiate it directly with respect to the parameters.

loss , grads = jax.value_and_grad(Model,argnums=[2,3,4,5])(x_valid,y_valid,w1,b1,w2,b2)

 In the above code we compute the loss and the gradients in one call. jax.value_and_grad returns the value of Model together with its gradients with respect to the arguments listed in argnums (here w1, b1, w2 and b2).
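 The gradients come back as a tuple in the same order as argnums, so a single SGD step (a minimal sketch using the learning rate defined above) looks like this:

w1 = w1 - lr * grads[0]
b1 = b1 - lr * grads[1]
w2 = w2 - lr * grads[2]
b2 = b2 - lr * grads[3]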

class NegativeLogLikelyhood(Module):
    def forward(self, inp, targ): return -inp[[i for i in range(inp.shape[0])], targ].mean()

 In the above code we define the NegativeLogLikelyhood class. For each row it picks out the log probability of the true class using integer indexing, and the loss is the negative mean of those values.
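 To see what the indexing does (a tiny illustrative sketch with made-up numbers), each row contributes only the entry at its target column:

log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.8, 0.1]]))   # two samples, three classes
targets = jnp.array([0, 1])
NegativeLogLikelyhood()(log_probs, targets)          # -(log 0.7 + log 0.8) / 2 ≈ 0.290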

def log_softmax(x):
    return jnp.log(jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1, keepdims=True))

In the above code we define the log_softmax function. jnp.exp exponentiates the logits, jnp.sum normalizes each row so it sums to 1, and jnp.log takes the logarithm, giving log probabilities for each class.
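 Exponentiating large logits can overflow, so in practice a numerically stable version (a hedged sketch of the standard log-sum-exp trick, not the author's code) subtracts the per-row maximum first:

def log_softmax_stable(x):
    x = x - jnp.max(x, axis=-1, keepdims=True)         # shift so the largest logit is 0
    return x - jnp.log(jnp.sum(jnp.exp(x), axis=-1, keepdims=True))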

nh = 50    # number of hidden units
seed = 1
bs = 64    # batch size
key = random.PRNGKey(seed)
w1 = random.normal(key=key, shape=(m, nh)) * math.sqrt(2/m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key, shape=(nh, 10)) * math.sqrt(2/nh)  # 10 outputs, one per class
b2 = jnp.ones(10)
lr = 0.01
epochs = 5

# with 10 outputs the Mse loss no longer matches the integer labels, so the model
# now ends with log_softmax and the NegativeLogLikelyhood loss defined above
def Model(x, targ, w1, b1, w2, b2):
    layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    loss = NegativeLogLikelyhood()
    for l in layers: x = l(x)
    return loss(log_softmax(x), targ)

loss, grads = jax.value_and_grad(Model, argnums=[2, 3, 4, 5])(x_train[0:64], y_train[0:64], w1, b1, w2, b2)
print("Initial loss:", loss)

for epoch in range(epochs):
    for i in range(((n - 1)//bs) + 1):   # n is the number of training rows
        start_index = i*bs
        end_index = (i+1)*bs
        loss, grads = jax.value_and_grad(Model, argnums=[2, 3, 4, 5])(x_train[start_index:end_index], y_train[start_index:end_index], w1, b1, w2, b2)
        w1 -= lr*grads[0]
        b1 -= lr*grads[1]
        w2 -= lr*grads[2]
        b2 -= lr*grads[3]

    print(loss)


 In the above code we train the model for the given number of epochs. For every mini-batch we compute the loss (the negative log likelihood of the log softmax outputs) and its gradients with jax.value_and_grad, then take a plain SGD step by subtracting the learning rate times each gradient from the corresponding parameter, and print the loss at the end of each epoch.
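 Once training has run, a simple way to check the model (a hedged sketch, not part of the original post) is to take the argmax over the 10 outputs on the validation set and compare it with the labels:

val_logits = Lin(w2, b2)(Relu()(Lin(w1, b1)(x_valid)))   # forward pass without the loss
accuracy = (jnp.argmax(val_logits, axis=-1) == y_valid).mean()
print("Validation accuracy:", accuracy)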

 If you run the above program, you should see the loss decreasing as training progresses.

 The full code for this tutorial is available at https://github.com/Saify-Technologies/MNIST-example-deep-learning-in-googles-jax.git.




Taher Ali Badnawarwala

Taher Ali is driven to create something special. He loves swimming, family, and AI from the depth of his heart, and he enjoys writing and making videos about AI and its applications.

