Transfer Learning with Diabetic Retinopathy Diagnosis

Dev Patel
Published in The Startup · Dec 28, 2020

Feature maps of Image 1 from the dataset used in Part 1, extracted with the VGG16 model

We’re back with CNNs and building accurate classification models. I recommend reading the last part (and clapping it up too) so that you’re familiar with the content, the approach, and the dataset we’re using.

To give you a refresher, we designed a CNN to diagnose Diabetic Retinopathy (DR) from a dataset of 3,662 images of patients, each labeled as No DR, Mild, Moderate, Severe, or Proliferative DR. After training and evaluating the model, however, the accuracy only reached 23%, as the loss (loss.item()) never went down and stayed around 1.6.
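The loss value itself hints at what went wrong. Assuming the model was trained with standard cross-entropy over the five classes (what nn.CrossEntropyLoss computes), a classifier that assigns equal probability to every class scores ln 5 ≈ 1.609 no matter which class is correct, so a loss stuck around 1.6 means the model was effectively guessing. A quick check in plain Python:

```python
import math

NUM_CLASSES = 5  # No DR, Mild, Moderate, Severe, Proliferative DR

def cross_entropy(probs, target):
    """Cross-entropy for a single sample: -log(probability given to the true class)."""
    return -math.log(probs[target])

# A model that has learned nothing spreads its probability evenly over the classes
uniform = [1.0 / NUM_CLASSES] * NUM_CLASSES

# The loss is identical whichever class is the true label
losses = [cross_entropy(uniform, t) for t in range(NUM_CLASSES)]
print(round(losses[0], 3))  # 1.609, i.e. math.log(5)
```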

After rebuilding parts of the model and plotting learning-rate graphs, I found that the problem was not the architecture, the preprocessing, or the optimizer. Rather, it was a data issue: the model did not have enough data to train its filters on.

This is a key motivation behind transfer learning: many models nowadays take advantage of optimized, pre-established networks and weights and adapt them to their own dataset. It’s a great way to boost your model’s performance while reaping the benefits of these high-complexity architectures.

Well, what is transfer learning anyways?

Simply put, it means transferring what a model learned on one task and reusing it for a similar task. In practice, it’s a way of initializing a model’s parameters (its weights) from a pre-trained network and then adding your own layers for your specific dataset.

There are countless pre-trained models that have been trained on millions of images, with well-trained filters that extract high-dimensional feature maps from a wide variety of images.

Examples of feature maps extracted from a multitude of radiographs. Credit: Sivaramakrishnan Rajaraman, “Visualizing abnormalities in chest radiographs through salient network activations in Deep Learning”

The vast data these models are trained on is what lets them generalize to new, well-formatted data. Because these models take so long to train, developers can start from saved checkpoints of the pre-trained weights, reuse only the parts of the model they need, and fine-tune the parameters on their own training data.
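Checkpoints in PyTorch are usually saved state_dicts: dictionaries of parameter tensors, kept separate from the model code. The sketch below uses a tiny stand-in network and a temporary file path chosen purely for illustration. It saves a checkpoint, restores it into a fresh copy of the architecture, and then loads only part of it with strict=False, which is one way a pre-trained backbone can be reused without its original classifier head.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Tiny stand-in: layer "0" plays the backbone, layer "2" the classifier head
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 5))

model = make_model()
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

# Save only the learned parameters, not the Python object
torch.save(model.state_dict(), path)

# Rebuild the same architecture and load the saved weights into it
restored = make_model()
restored.load_state_dict(torch.load(path))
x = torch.randn(3, 8)
print(torch.equal(model(x), restored(x)))  # True

# Reuse only part of the checkpoint: keep the "backbone", leave the new head random
backbone_only = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}
fresh = make_model()
result = fresh.load_state_dict(backbone_only, strict=False)
print(result.missing_keys)  # the head's parameters were not loaded
```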

That said, transfer learning models come with a few constraints: their architecture is largely fixed (removing convolutional layers or otherwise experimenting with the architecture throws away the pre-trained weights for those parts), and training can still take a long time, even on small datasets.

For PyTorch, its torchvision.models library has several pre-trained computer vision models including AlexNet, VGG, ResNet, SqueezeNet, DenseNet, Inception v3, GoogLeNet, ShuffleNet v2, MobileNet v2, ResNeXt, Wide ResNet, and MNASNet.

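For our purposes, we will use ResNet for its “residual learning”: each block learns a residual that is added back to its own input through a shortcut connection, which lets very deep networks train effectively where plain stacks of convolutional layers of the same depth struggle.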
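The core of that idea fits in a few lines. The block below is a simplified sketch of a ResNet-style basic block (not torchvision’s own BasicBlock class, though it follows the same layout): two 3×3 convolutions compute a residual F(x), and the input is added back before the final ReLU. It assumes the input and output have the same number of channels and the same spatial size, so the shortcut can be a plain identity.

```python
import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Simplified residual block: output = ReLU(F(x) + x)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x                                 # the shortcut path
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))              # F(x), the learned residual
        return self.relu(out + identity)             # add the shortcut back in

block = BasicResidualBlock(64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])
```

Because the block only has to learn the difference from its input, pushing F(x) toward zero gives back (roughly) the identity, which is part of why very deep stacks of these blocks remain trainable.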

Moving on to building, training, and testing.

Let’s load the dataset from the previous part and change some of the formatting. Notice that I used dicts to store many of the values; I did this intentionally so that the training function would be more readable.

If you have followed along from Part 1, the only section of this code you will need is lines 61–102.

From here, it’s all fairly similar apart from some slight formatting changes. Notice that I created a few new dicts: one for the image_dataset call (which holds both the train and val sets), one for the individual dataloaders of both sets, and one for the sizes of the datasets. You can check the size of each set by looking up its key in the dataset_sizes dict.

Let’s also preview the images using an updated imshow function:

The output should look like the following (resizing may distort the images, and because the loader shuffles, you will likely see different images):

Batch Size = 10: DR Images from the trainloader

Now comes the training function, which I adapted from the PyTorch documentation. The main change is that the learning rate now decays over the epochs, and only the best model is kept. After each epoch, the function calculates both the training and validation accuracy and updates the saved model only if these accuracies improve.

The function uses a single loop to compute accuracy for both phases and relies on conditional statements to run only the steps the current phase needs: updating the model during train and only evaluating it during val.

Within each phase, the loss and the number of correct predictions are accumulated over the batches, so that we can evaluate the model’s performance at the end of every epoch.
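Since the function itself isn’t reproduced here, the sketch below follows the structure of the train_model function in the official PyTorch transfer learning tutorial (an assumption about which documentation was used). One loop handles both phases, gradients are enabled only during train, the scheduler steps once per epoch, and a copy of the weights is kept whenever validation accuracy improves. The text above mentions both accuracies; this sketch tracks validation accuracy alone, as the tutorial does. It returns the best accuracy alongside the model, and the toy tensors at the end exist only to show a call.

```python
import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset

def train_model(model, dataloaders, dataset_sizes, criterion, optimizer,
                scheduler, device, num_epochs=25):
    since = time.time()
    best_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            # train() enables dropout and batch-norm updates; eval() freezes them
            if phase == "train":
                model.train()
            else:
                model.eval()
            running_loss, running_corrects = 0.0, 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Build the autograd graph only when the weights will be updated
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if phase == "train":
                scheduler.step()  # decay the learning rate once per epoch

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"epoch {epoch} {phase}: loss {epoch_loss:.4f}, acc {epoch_acc:.4f}")

            # Keep a copy of the weights whenever validation accuracy improves
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_wts = copy.deepcopy(model.state_dict())

    print(f"done in {time.time() - since:.0f}s, best val acc {best_acc:.4f}")
    model.load_state_dict(best_wts)  # return the best weights, not the last ones
    return model, best_acc

# Toy usage: random features and labels instead of retina images
torch.manual_seed(0)
x, y = torch.randn(60, 8), torch.randint(0, 5, (60,))
loaders = {
    "train": DataLoader(TensorDataset(x[:50], y[:50]), batch_size=10, shuffle=True),
    "val": DataLoader(TensorDataset(x[50:], y[50:]), batch_size=10),
}
sizes = {"train": 50, "val": 10}
net = nn.Linear(8, 5)
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
sched = lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)
net, best = train_model(net, loaders, sizes, nn.CrossEntropyLoss(), opt, sched,
                        torch.device("cpu"), num_epochs=3)
```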

This setup also produces a far better model. In the last part I tuned everything manually; here, the learning rate decays automatically on a fixed schedule, and the function stores only the best model.
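To see what a decaying learning rate means concretely, here is the schedule produced by PyTorch’s StepLR, the scheduler used in the PyTorch transfer learning tutorial. The values lr=0.001, step_size=7 and gamma=0.1 are that tutorial’s settings, assumed here rather than taken from this post. The learning rate is multiplied by gamma every step_size epochs, i.e. lr = 0.001 · 0.1^⌊epoch / 7⌋.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# A single dummy parameter is enough to drive the optimizer and scheduler
params = [nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate used in this epoch
    optimizer.step()   # a real epoch of training would go here
    scheduler.step()   # multiplies the lr by gamma every 7 epochs

# Epochs 0-6 use 1e-3, epochs 7-13 use 1e-4, epoch 14 uses 1e-5
print([f"{lr:.0e}" for lr in lrs])
```

A step schedule is easy to reason about: larger steps early for fast progress, smaller ones later to settle down. Note that it is a fixed schedule, not a search over hyperparameters.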

Now that we’ve created a training function, we can set up our model and actually start training it. In this case, we’ll use ResNet, but for a given dataset, testing other pre-trained models is also worthwhile.

You’ll have to leave this running for a while, depending on your hardware and dataset size. For me, total training took about 4 hours and 25 minutes because I ran the model on the CPU. A GPU works too (and is typically much faster); just make sure to set the device variable to “cuda”.
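Rather than editing the device string by hand, a common PyTorch idiom picks the GPU automatically when one is available. The model’s parameters are moved once, and every batch has to be moved to the same device inside the training loop. The toy linear layer here just stands in for the ResNet.

```python
import torch
import torch.nn as nn

# Pick the GPU if PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 5).to(device)    # move the parameters once
batch = torch.randn(8, 10).to(device)  # move each batch as it comes out of the loader
output = model(batch)
print(output.device.type)  # "cuda" on a GPU machine, "cpu" otherwise
```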

After training, the results are astronomical compared to the previous model.

With a training accuracy of 98.95% and a validation accuracy of 80.16%, the model starts to look like a serious contender to the error rate of a typical ophthalmologist, although the gap between the two scores suggests some overfitting.

Even so, keep in mind that data typically differs across sources, and in the real world this validation accuracy would likely be lower. Training on both high-quality and low-quality images is essential for a model that is actually usable in the real world.

Although this is a pretty straightforward project, it does showcase the growth computer vision has seen as a field, with innovative companies and passionate builders working together to solve some of the world’s most challenging issues. This opens so many doors, not just for the underprivileged and the victims of these diseases, but also for transforming these industries and fields with just a few lines of code. I can get behind that, and I hope that you can too.

Thanks for taking the time to read this and I hope you got something out of it. If you want to get more technical or simply reach out to me, you can find me on LinkedIn, Email, or GitHub. You can also subscribe to my newsletter here.

Forgot to introduce myself: I’m Dev

I’m an innovator, developer, and proud nerd looking to leave a mark on the world through technology and science.

I love connecting with new people so feel free to reach out to me if you have any questions or want to talk technical.
