MNIST Tutorial in Google's JAX
In this tutorial we will learn how to use Google's JAX library to do MNIST classification. The MNIST dataset is a set of images of handwritten digits, and the goal is to train a neural network that recognizes them.
We will proceed in four steps:
1. Import the necessary libraries.
2. Load the MNIST dataset.
3. Define the neural network model using JAX.
4. Train the model with plain stochastic gradient descent, using JAX's automatic differentiation to compute the gradients.
Import the necessary libraries.
from pathlib import Path
from IPython.core.debugger import set_trace
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from fastai.vision import *
from fastai.data.all import *
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from PIL import Image
from numpy import asarray
Load the MNIST dataset.
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl' #MNIST dataset URL
def get_data(validation_per):
    path = untar_data(URLs.MNIST)
    # create lists of images and labels, one digit folder at a time
    x_train_ls = []
    y_train_ls = []
    for i in range(0, 10):
        x_ls, y_ls = [asarray(Image.open(f)) for f in (path/"training"/str(i)).ls()], [i for f in (path/"training"/str(i)).ls()]
        y_train_ls.extend(y_ls)
        x_train_ls.extend(x_ls)
    # convert the lists to numpy arrays
    train_x, train_y = np.array(x_train_ls), np.array(y_train_ls)
    # randomly shuffle the data (same seed for both arrays so images and labels stay aligned)
    seed = np.random.randint(0, 10000)
    np.random.seed(seed)
    np.random.shuffle(train_x)
    np.random.seed(seed)
    np.random.shuffle(train_y)
    # split into training and validation sets
    training_set = int(train_x.shape[0] * (1 - (validation_per / 100)))
    training_x, training_y, validation_x, validation_y = train_x[:training_set], train_y[:training_set], train_x[training_set:], train_y[training_set:]
    return map(jnp.array, (training_x, training_y, validation_x, validation_y))
In the above code we load the MNIST dataset. We use the untar_data function from the fastai library to download it, and the map function to convert the NumPy arrays to JAX arrays.
x_train, y_train, x_valid, y_valid = get_data(20)  # 20% of the data will be used for validation
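As a quick sanity check (a short sketch, not part of the original code), we can look at the shapes of the arrays returned by get_data; the exact counts depend on the 20% validation split:

# Sketch: inspect the shapes of the loaded arrays
print(x_train.shape, y_train.shape)  # training images and labels
print(x_valid.shape, y_valid.shape)  # validation images and labels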
def normalize(x, m, s): return (x-m)/s
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std
x_train = normalize(x_train, train_mean, train_std)
# NB: use the training mean, not the validation mean, for the validation set
x_valid = normalize(x_valid, train_mean, train_std)
In the above code we normalize the data, using the mean and standard deviation of the training set.
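After normalization the training set should have a mean close to 0 and a standard deviation close to 1; a quick sketch to verify (not in the original code):

# Sketch: verify the normalization of the training set
print(x_train.mean(), x_train.std())  # should be approximately 0 and 1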
import numpy as np
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1]*x_train.shape[2])  # flatten each row: the last two dimensions are merged into 784 features
x_valid = x_valid.reshape(x_valid.shape[0], x_valid.shape[1]*x_valid.shape[2])
x_train.shape, x_valid.shape

n, m = x_train.shape   # n is the number of samples and m is the number of features
c = y_train.max() + 1  # number of classes
n, m, c
In the above code we reshape the data: the NumPy reshape function merges the last two dimensions so that each 28x28 image becomes a flat vector of 784 features. We also compute the number of classes by taking the maximum value in y_train and adding 1.
nh = 50  # number of hidden units
seed = 1
key = random.PRNGKey(seed)
w1 = random.normal(key=key, shape=(m, nh)) * math.sqrt(2/m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key, shape=(nh, 1)) * math.sqrt(2/nh)
b2 = jnp.ones(1)
In the above code we initialize the model parameters. The weights are drawn with random.normal and scaled by sqrt(2/fan_in), which is Kaiming (He) initialization; the first-layer bias b1 is initialized to zeros with jnp.zeros and the second-layer bias b2 to ones with jnp.ones.
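To see why the sqrt(2/m) scaling matters, here is a small sketch (the key and variable names below are only for this illustration) comparing the scale of the first-layer pre-activations with and without Kaiming scaling:

# Sketch: effect of Kaiming (He) scaling on the first-layer pre-activations
key_demo = random.PRNGKey(42)  # hypothetical key, used only for this comparison
w_plain = random.normal(key=key_demo, shape=(m, nh))                    # no scaling
w_scaled = random.normal(key=key_demo, shape=(m, nh)) * math.sqrt(2/m)  # Kaiming scaling
print((x_train @ w_plain).std())   # grows with the number of input features (roughly sqrt(784) = 28)
print((x_train @ w_scaled).std())  # stays of order 1 (roughly sqrt(2))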
def lin(x, w, b):  # linear function
    return x @ w + b

def relu(x):  # ReLU function
    return jnp.maximum(0, x)

def leaky_relu(x):  # Leaky ReLU function
    return jnp.maximum(0.01*x, x)

def model(xb, w1, w2, b1, b2):  # model function
    l1 = lin(xb, w1, b1)
    l2 = leaky_relu(l1)
    l3 = lin(l2, w2, b2)
    return l3
The model function defines the neural network we will train. It contains the following layers:
1. Linear layer
2. Leaky ReLU layer
3. Linear layer
def mse(x, y):  # mean squared error, used as the loss function
    return ((jnp.squeeze(x) - y)**2).mean()
In the above code we define the mse loss. jnp.squeeze removes the extra dimension from the model output, and .mean() averages the squared error.
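Putting these pieces together, here is a quick sketch of one forward pass and its loss on the validation set (the variable names follow the code above):

# Sketch: forward pass and MSE loss on the validation set
preds = model(x_valid, w1, w2, b1, b2)
print(preds.shape)          # (number of validation samples, 1)
print(mse(preds, y_valid))  # mean squared error between the squeezed predictions and the labels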
Now we will implement the same pieces in a PyTorch-style design, using small Module classes.
class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self):
        raise Exception('not implemented')

class Relu(Module):
    def forward(self, inp):
        return jnp.maximum(0., inp) - 0.5

class Lin(Module):
    def __init__(self, w, b):
        self.w, self.b = w, b

    def forward(self, inp):
        return inp @ self.w + self.b

class Mse(Module):
    def forward(self, inp, targ):
        return ((jnp.squeeze(inp) - targ)**2).mean()
In the above code we define the Relu, Lin and Mse classes. Relu uses jnp.maximum to clamp negative values to zero (and shifts the result down by 0.5), Lin applies the affine transformation inp @ w + b, and Mse squeezes the extra dimension from the model output before taking the mean of the squared error.
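Because Module.__call__ stores the arguments and dispatches to forward, every layer instance can be called like a function; a tiny sketch:

# Sketch: a layer instance is callable and runs its forward pass
relu = Relu()
print(relu(jnp.array([-1.0, 2.0])))  # -> [-0.5, 1.5]: maximum with 0, then shifted down by 0.5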
def Model(x, targ, w1, b1, w2, b2):
    layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    loss = Mse()
    for l in layers:
        x = l(x)
    return loss(x, targ)
lr = 0.001
In the above code we set the learning rate that will be used for the gradient descent updates.
loss, grads = jax.value_and_grad(Model, argnums=[2, 3, 4, 5])(x_valid, y_valid, w1, b1, w2, b2)
In the above code we compute the loss and the gradients in a single call. jax.value_and_grad returns both the value of Model and its gradients with respect to the arguments listed in argnums, here the weights and biases.
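With the gradients in hand, one stochastic gradient descent step is just a parameter update; a minimal sketch (the full training loop below does the same thing repeatedly, batch by batch):

# Sketch: a single SGD update using the gradients returned by jax.value_and_grad
w1 -= lr * grads[0]
b1 -= lr * grads[1]
w2 -= lr * grads[2]
b2 -= lr * grads[3]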
class NegativeLogLikelyhood(Module):
    def forward(self, inp, targ):
        return -inp[[i for i in range(inp.shape[0])], targ].mean()
In the above code we define the NegativeLogLikelyhood class. For each row of the input it picks out the (log-)probability assigned to the target class and returns the negative of the mean.
def log_softmax(x):
    # softmax over the class dimension (keepdims so the division broadcasts row by row)
    return jnp.log(jnp.exp(x) / jnp.sum(jnp.exp(x), axis=-1, keepdims=True))
In the above code we define the log_softmax function. We exponentiate the inputs with jnp.exp, divide by the row-wise sum of the exponentials with jnp.sum, and take the log with jnp.log.
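Note that Model above still scores its output with the Mse loss; log_softmax and NegativeLogLikelyhood are defined but not wired in. As a sketch (assuming, as in the re-initialization below, that the last linear layer outputs 10 logits, one per digit), a cross-entropy version of the model could look like this:

# Sketch: a hypothetical cross-entropy variant of Model (not part of the original code)
def CrossEntropyModel(x, targ, w1, b1, w2, b2):
    # same Lin -> Relu -> Lin layers as Model, scored with log-softmax + negative log likelihood
    layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
    loss = NegativeLogLikelyhood()
    for l in layers:
        x = l(x)
    return loss(log_softmax(x), targ)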
nh = 50  # number of hidden units
seed = 1
bs = 64  # batch size
key = random.PRNGKey(seed)
w1 = random.normal(key=key, shape=(m, nh)) * math.sqrt(2/m)
b1 = jnp.zeros(nh)
w2 = random.normal(key=key, shape=(nh, 10)) * math.sqrt(2/nh)
b2 = jnp.ones(10)
lr = 0.01
epochs = 5
loss, grads = jax.value_and_grad(Model, argnums=[2, 3, 4, 5])(x_train[0:64], y_train[0:64], w1, b1, w2, b2)
print("Initial Loss:", loss)
totalNumerOfRows = x_train.shape[0]  # number of training samples
for epoch in range(epochs):
    for i in range(((totalNumerOfRows - 1) // bs) + 1):
        start_index = i * bs
        end_index = (i + 1) * bs
        loss, grads = jax.value_and_grad(Model, argnums=[2, 3, 4, 5])(x_train[start_index:end_index], y_train[start_index:end_index], w1, b1, w2, b2)
        w1 -= lr * grads[0]
        b1 -= lr * grads[1]
        w2 -= lr * grads[2]
        b2 -= lr * grads[3]
print(loss)
In the above code we re-initialize the network with 10 output units (one per digit class), set the batch size, learning rate and number of epochs, and print the initial loss. We then train the model with mini-batch stochastic gradient descent: for each batch, jax.value_and_grad returns the loss and the gradients, and every parameter is updated by subtracting the learning rate times its gradient. The outer loop repeats this for the chosen number of epochs.
When you run the above program you will see the loss decreasing.
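To check how well the network actually classifies digits, we can compute the accuracy on the validation set. This is a sketch not included in the original code; it reuses the Lin and Relu layers from above and assumes the 10-logit output of the final layer:

# Sketch: classification accuracy on the validation set
def predict(x, w1, b1, w2, b2):
    # forward pass only: the same Lin -> Relu -> Lin layers as Model, without the loss
    for l in [Lin(w1, b1), Relu(), Lin(w2, b2)]:
        x = l(x)
    return x

val_logits = predict(x_valid, w1, b1, w2, b2)
val_preds = jnp.argmax(val_logits, axis=1)
print("Validation accuracy:", (val_preds == y_valid).mean())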
All of the code above is available at https://github.com/Saify-Technologies/MNIST-example-deep-learning-in-googles-jax.git.