From the series: Deep Learning Webinars 2020
Abhijit Bhattacharjee, MathWorks
Get a more detailed look into designing, customizing, and training advanced neural networks. Learn about the extended deep learning framework in MATLAB®, which enables you to implement advanced network architectures such as generative adversarial networks (GANs), variational autoencoders (VAEs), or Siamese networks.
You will also learn how to:
I'm very happy to be here today to talk to you about some of the advanced capabilities in MATLAB for deep learning. So today's agenda is going to look like this.
Well, first, let's talk about an interesting example we're going to go through. We're going to talk about Generative Adversarial Networks, or GANs. That's one of the advanced types of neural networks. Here, you see a GAN that's training.
I won't explain all the details just now; we'll cover them later. But GANs are a good representative of advanced neural networks, because a lot of mainstream applications are now driven by them. If you've used a face-aging app, seen deepfakes on the internet, or seen anything else that generates a realistic-looking fake image, chances are it's driven by a GAN, which is a type of deep neural network.
So we're going to talk about that, and the agenda is as follows. We're going to go into a little bit of background. But because this is an advanced session, that background will be very light. So not to worry, we won't spend too much time there.
Then I'm going to spend most of the time on the next two bullet points. I'm going to compare framework features, and by that I mean the features in older releases of MATLAB versus newer releases of MATLAB.
I'm also going to point out where MATLAB is similar to open source frameworks such as TensorFlow and PyTorch. You'll see that deep learning in MATLAB is now quite similar to deep learning in those frameworks, but with some additional benefits. And then, lastly, we'll finish up with some of the ways we can help you achieve your deep learning goals.
So now, let's start with the background and put everything into context very quickly. I showed this slide in the previous session, and I'm showing it again here just to make sure everyone's on the same page.
AI or deep learning models don't exist in a vacuum. If you're an engineer, or you actually want to apply them in the real world, they exist in the context of a system. So what does that system design usually look like? We generally think of the first step as data preparation. In machine learning, without data, and in particular without good, well-prepared data, you're not going to get very far.
So what does data preparation usually mean? Preprocessed, cleaned data that's perhaps labeled by humans, which is where human insight comes in. Or, if the data is expensive to collect, you might need to simulate it using some sort of realistic simulator. Either way, data is a very large part of the process, and in many cases creating the data set properly is very time consuming.
Today, we're not going to focus too much on this because what we're going to focus on is the second part, AI modeling. So after you have your data, of course, the fun, sexy part is creating the AI model. And in particular, last session, which was two days ago, we talked a lot about designing and tuning models. And today, we're going to talk much more about designing a particular type of model, a more advanced type of neural network.
But AI modeling or deep learning modeling involves a lot of other processes as well, like running models on GPUs or interoperating between open source frameworks and MATLAB. After you deliver an AI model, it should be tested in the context of a system, and there are a lot of testing and validation tools within MathWorks that can help you with that.
And then, finally, chances are that your business is selling a product or service that contains AI, rather than delivering AI models themselves to your customers. So the question of deployment comes up, and we have a lot of resources to help you deploy your AI and deep learning models to embedded devices, the enterprise, or the cloud. As I said before, we're going to focus on this second block today, AI modeling.
I showed this slide earlier if you attended last time, but I'm going to show it once more to highlight something else: the evolution of deep learning within MATLAB. We've been in this space for five or six years now, and we've invested heavily, at a furious pace. You can see the different features we've added over time.
In 2019 in particular, we added the features we'll talk about today: automatic differentiation, plus the other algorithm improvements that enable advanced networks like GANs and the others we're going to discuss.
And we've been recognized as a leader for 2020 by a third-party consulting and research firm, Gartner. And I'd like to highlight that they have recognized us even above Microsoft and Google in this space. So they believe we have the completeness of vision and ability to execute in this area. So we're very proud and humbled to be recognized in that way.
That does it for the background. Now let's talk about the framework features that let us model advanced neural networks. A moment ago I showed a chronological view of what we've been doing. We release twice a year, an a release and a b release, so R2019a and prior is about two releases ago.
In R2019a and prior, we could address more limited types of neural networks. What were those? We had convolutional neural networks, or CNNs. We had recurrent neural networks, one example being the LSTM. We had the ability to combine the two in an architecture known as C-LSTM. And, of course, we always had plain vanilla neural networks, which you could describe as multilayer perceptron models.
These cover a lot of ground, and many deep neural network applications use them, but there were still many kinds of networks MATLAB couldn't address. Through R2019a, we only had a simpler framework, which let you build, broadly speaking, series, directed acyclic graph (DAG), or recurrent networks. That changed as of R2019b.
What we did in R2019b and R2020a is extend the framework to support advanced or custom neural network architectures. To be clear, we extended the existing framework; we didn't replace it. You now have both the simple framework and the extended framework going forward, and neither one is being deprecated. What we've done is add a low-level API to complement the high-level APIs that were already there.
So what does the extended framework add to MATLAB? One of the key things it adds is the ability to do generalized unsupervised deep learning. Unsupervised means machine learning on unlabeled data rather than labeled data. One important type of unsupervised learning that deep learning enables is called generative modeling, and we'll talk about that as well.
We also added the ability to build multiple-input, multiple-output (MIMO) neural networks. MIMO networks have other uses too, but they're really important for the kind of unsupervised learning I just described.
We also added the ability to do functional programming. If you're not familiar with that, here's a brief explanation. In the simpler API, you architect a neural network as a set of layers that you put together. But some deep learning operations can't be described as a layer; instead, you have to describe them mathematically, as a function. Some advanced neural networks require this ability.
This is the same programming paradigm you see in TensorFlow or PyTorch, whereas the layer-based approach is more like Keras. So now that we have functional programming, you can basically do anything that you could do in PyTorch or TensorFlow.
And finally, we've added complete customization of the training process, which matters for a number of reasons. Before, only partial customization was possible. I won't go into detail about everything that's been added here, but you now have complete control.
All of this is possible because of a few key enabling technologies we've added. The first is shared weights: many advanced neural networks are not one neural network but several networks working together, so they need to share a parameter space, or weights. Automatic differentiation is also key, because pretty much every neural network is trained through gradient descent, which means you need to compute derivatives, or gradients, and doing that automatically makes the job a lot easier. And flexible training structures let us do even more.
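To make automatic differentiation less abstract, here is a minimal sketch of reverse-mode autodiff for scalars, written in Python rather than MATLAB. It is not how dlgradient is implemented; the `Var` class and its methods are hypothetical names for illustration. Each operation records its local derivative, and `backward()` applies the chain rule from the output back to every input.

```python
import math


class Var:
    """A scalar that remembers how it was computed, so gradients can flow back."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        self._parents = parents  # tuple of (parent Var, local derivative) pairs

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def tanh(self):
        t = math.tanh(self.value)
        return Var(t, ((self, 1.0 - t * t),))

    def backward(self):
        """Reverse-mode pass: visit nodes from the output back, applying the chain rule."""
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node._parents:
                    visit(parent)
                order.append(node)  # post-order: inputs are appended before outputs

        visit(self)
        self.grad = 1.0
        for node in reversed(order):  # outputs first, then their inputs
            for parent, local in node._parents:
                parent.grad += local * node.grad


# y = tanh(w*x + b): a single backward() call yields dy/dw, dy/dx and dy/db together.
w, x, b = Var(0.5), Var(2.0), Var(-0.3)
y = (w * x + b).tanh()
y.backward()
```

One reverse pass produces the gradient with respect to every input at once, which is why this style of differentiation scales to networks with millions of weights.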
So now I want to ask the question and do a little bit of review of what a typical deep learning training loop consists of. I think this is really important, even if you're familiar, just to put everyone on the same page. And if you're not, this will be illuminating for you.
The first step in a deep learning training loop, after you've loaded and preprocessed your data and are actually training the model, is to split the entire data set into smaller chunks called mini-batches. One full pass through all of the mini-batches, that is, through the entire data set, is called an epoch. I'm going to introduce a lot of terminology on this slide, so I'll go slowly, and at the end I'll leave some time for you to screen capture it if you'd like to keep a reference.
After you've chunked up your data, you go through each mini-batch and preprocess it further, if necessary. Then you pass the mini-batch's input data forward through the network. This forward pass is also called inference.
Now, you might know the term "inference" in relation to using a trained neural network. That is, if you have an already trained neural network and you're using it to do something, that's called inference. But you also perform inference during the training phase of a neural network. In fact, that's the third step here: passing the input data forward is inference.
The next step in the loop: once you've passed the data forward and gotten the inference output, you compare the result from the neural network to your objective, which usually means computing a loss function. The loss function is usually a comparison between the predicted answer from the neural network and the real answer, that is, the labels in the training data. Everything you need to compare the network's result to the real answer is encapsulated in this loss function. This idea of a loss function will become very important as we move forward.
After you compute the loss, you differentiate it, that is, you compute the gradients of the loss with respect to the parameters of every operation or layer in the neural network. We're not going to go into the details of the calculus here; there's no need. Suffice it to say that there is a lot of calculus involved, and because of that you need an efficient gradient computation method. That's what happens at this stage of training.
After you compute the gradients, the next step is to update the weights and biases, which we call the parameters, of the model. Step 6 is the learning step: the actual learning happens when the model updates its weights and biases to improve the objective, that is, to decrease the loss.
Together, steps 5 and 6 are what's commonly called back propagation. (Strictly speaking, back propagation is the gradient computation in step 5, and step 6 is the optimizer's update, but the two go hand in hand.) It's a key term that's often thrown around in deep learning training, but you don't see it very much, especially if you've only been exposed to a simpler API for neural networks.
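Here is a minimal Python sketch of step 1, splitting a data set into mini-batches. The `minibatches` helper is a hypothetical name, not a MATLAB function. Iterating over everything it yields once is one epoch, and reshuffling between epochs is a common (assumed, not required) choice.

```python
import random


def minibatches(samples, batch_size, shuffle=True, seed=None):
    """Yield successive mini-batches; one full pass over all of them is one epoch."""
    indices = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [samples[i] for i in indices[start:start + batch_size]]


data = list(range(10))
batches = list(minibatches(data, batch_size=4, shuffle=False))
# batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]; the last mini-batch is smaller
```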
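For classification, the loss in step 4 is typically cross-entropy. This Python sketch (helper names are hypothetical; MATLAB's extended framework has its own built-in `crossentropy`) computes it per sample and averages over a mini-batch. The loss is small when the network puts high probability on the true class and large when it is confidently wrong.

```python
import math


def softmax(logits):
    """Turn raw network outputs (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target_class):
    """Loss for one sample: minus the log-probability given to the true class."""
    return -math.log(softmax(logits)[target_class])


def batch_loss(batch_logits, targets):
    """Average the per-sample losses over a mini-batch."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Two samples, three classes: the first prediction is confident and right,
# the second is confident and wrong, so it dominates the batch loss.
loss = batch_loss([[4.0, 0.0, 0.0], [4.0, 0.0, 0.0]], [0, 2])
```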
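Why does the gradient method need to be efficient? The Python sketch below compares the exact gradient of softmax cross-entropy with respect to the logits, which has the well-known closed form softmax(z) minus the one-hot target, against central finite differences. Finite differences need two loss evaluations per parameter, which is hopeless for millions of weights; analytic or automatic gradients get all of them for roughly the cost of one extra pass. Function names here are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    return -math.log(softmax(logits)[target])


def grad_cross_entropy(logits, target):
    """Exact gradient of the loss w.r.t. the logits: softmax(z) - one_hot(target)."""
    probs = softmax(logits)
    return [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]


def numerical_grad(f, xs, h=1e-6):
    """Central finite differences: two loss evaluations per input (slow, used as a check)."""
    grads = []
    for k in range(len(xs)):
        up = list(xs)
        up[k] += h
        down = list(xs)
        down[k] -= h
        grads.append((f(up) - f(down)) / (2 * h))
    return grads


z, t = [1.0, -0.5, 2.0], 2
analytic = grad_cross_entropy(z, t)
numeric = numerical_grad(lambda v: cross_entropy(v, t), z)
```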
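Step 6 is the optimizer's update. As an illustration, here is stochastic gradient descent with momentum in Python, applied to a one-parameter toy loss. The function and its argument names are assumptions for this sketch; setting `momentum=0` gives plain gradient descent.

```python
def sgd_momentum_step(params, grads, velocity, learn_rate=0.01, momentum=0.9):
    """One update: v <- momentum * v - learn_rate * g, then p <- p + v.

    With momentum=0 this reduces to plain stochastic gradient descent.
    """
    new_velocity = [momentum * v - learn_rate * g for v, g in zip(velocity, grads)]
    new_params = [p + v for p, v in zip(params, new_velocity)]
    return new_params, new_velocity


# Toy objective: loss(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
params, velocity = [0.0], [0.0]
for _ in range(300):
    grads = [2.0 * (params[0] - 3.0)]
    params, velocity = sgd_momentum_step(params, grads, velocity, learn_rate=0.1)
```

After the loop, `params[0]` has moved from 0 to (very nearly) 3, the minimizer of the toy loss.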
And then, finally, there's usually some visualization of the process. So typically, after you've performed this update, you will spit out a graph or a plot or a picture of some sort, or some metrics might come out on the command line to indicate what's happening at each step.
So I'm going to pause here for a few seconds if you'd like to screen cap this. Of course, we're going to send you the slides. But I'm going to refer to all this terminology throughout, so you may want to have a reference as we go forward.
While you do that, let me explain one other point. The reason I'm going into such granular detail about deep learning training is that this is what we need to understand in order to see how an advanced neural network is trained. This process isn't unique to MATLAB. It's the general process in any deep learning framework you work with, whether PyTorch, TensorFlow, or Keras.
Now, if you're familiar with simpler types of frameworks, like MATLAB from before or Keras, the layer-based APIs usually have a single command that does all seven steps. You don't have to do anything at all; it abstracts all of that away.
So when you simply use trainNetwork in MATLAB, or model.fit in Keras, that one command does all of this looping for you, and you have limited control over what happens inside the loop. But if you want more advanced control, you have to do all these steps yourself. That's why I'm going through this process.
So I want to introduce a few key terms that come up in this extended framework. You don't need to memorize them right now, but I'll point out that we have many new functions in MATLAB for the extended framework, and they all start with the prefix dl.
You can see some of these functions here: dlarray, which is a new data container; dlgradient, which computes the gradients; and dlupdate, which applies an update function to the network's learnable parameters. So chances are, when you see a term prefixed with dl, it's part of the extended framework that's been added to MATLAB. Like I said before, the extended framework doesn't replace anything. It only adds to what was there, and the simple framework is still available.
So first we're going to show an example using MNIST. I know you might be tired of it by now if you've attended previous webinars, but I chose this example because it's very simple to understand. We're going to classify images of digits from 0 through 9. We'll do so first with the simple framework, and then I'll do a side-by-side comparison reimplementing it with the advanced, or extended, framework, so that you can see how much more code is involved, but also what benefits you get from it.
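To see what a single command like trainNetwork or model.fit hides, here is a sketch of the whole loop written out by hand in Python, with each of the seven steps labeled. A one-feature logistic regression stands in for a neural network, and its gradients are derived by hand where a real framework would use automatic differentiation. All names and hyperparameter values are assumptions for illustration.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean_log_loss(w, b, samples):
    """Binary cross-entropy averaged over (x, y) samples; probabilities are clamped."""
    total = 0.0
    for x, y in samples:
        p = min(max(sigmoid(w * x + b), 1e-12), 1 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(samples)


def train_logistic(data, epochs=50, batch_size=8, learn_rate=0.5, seed=0):
    """Explicit training loop for a one-feature logistic-regression 'network'."""
    rng = random.Random(seed)
    data = list(data)  # copy so shuffling doesn't reorder the caller's list
    w, b = 0.0, 0.0
    history = []
    for _ in range(epochs):
        rng.shuffle(data)  # step 1: a fresh split into mini-batches every epoch
        batch_losses = []
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            xs = [x for x, _ in batch]  # step 2: no extra preprocessing needed here
            ys = [y for _, y in batch]
            preds = [sigmoid(w * x + b) for x in xs]  # step 3: forward pass
            batch_losses.append(mean_log_loss(w, b, batch))  # step 4: loss
            # step 5: gradients of the loss w.r.t. w and b (derived by hand here)
            dw = sum((p - y) * x for p, y, x in zip(preds, ys, xs)) / len(batch)
            db = sum(p - y for p, y in zip(preds, ys)) / len(batch)
            w -= learn_rate * dw  # step 6: update the parameters
            b -= learn_rate * db
        history.append(sum(batch_losses) / len(batch_losses))  # step 7: metric to plot
    return w, b, history


# Toy data: label 1 for positive x, 0 for negative x.
dataset = [(i / 10, 1 if i > 0 else 0) for i in range(-20, 21) if i != 0]
w, b, history = train_logistic(dataset)
```

Because you own every line of the loop, changing the loss, the update rule, or what gets recorded in step 7 is a local edit rather than something you have to request from a built-in trainer.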
So let's jump into MATLAB now. And again, please feel free to ask questions throughout using the Q&A panel. All right. So let's see.
Let's just do a really quick run through. Again, we trained this model many times in the previous webinar that we had a couple of days ago. But I'll just quickly run through it so everybody's on the same page. We're going to load some data. I'm not going to go into detail about how that happens.
Now that we've loaded the data, we have 60,000 training images and 10,000 test images in memory. To show what they look like, I'm going to visualize some of them. So here are some digits from the data set.
Now, here's the simple way of training this neural network. We define the architecture, which is defined as layers. We define a few hyperparameters. We don't need too many for this problem, so I'm just going to define those hyperparameters.
And then we execute the training using one command, trainNetwork. Let's go ahead and run that.
Now, in this example, as I mentioned before, all of those seven steps of deep learning training loop are encapsulated in the single command. That's all happening under the hood. So it's iterating over the whole data set. It's doing the forward propagation inference and the back propagation and everything. All of that's happening under the hood.
And as you can see, we're gaining accuracy and we're diminishing the loss. These are the objectives we're looking for in this training. I'm going to go ahead and stop this. It's pretty much well-trained at this point, very, very simple problem.
And we also get this really nice built-in visualization, which shows the loss and the accuracy as training progresses, plus a bunch of additional information, such as whether we're using a GPU (which I am), which iteration we're on, and so on. There's a lot of interesting and useful information here.
Unfortunately, when we move to advanced neural networks, we're responsible for creating those metrics ourselves. But that limitation is also an opportunity: you get to visualize things exactly the way you want and monitor the training in whatever way works best for you.
So let's see how this compares with the advanced approach. Oh, I missed one step. Lastly, you use the trained network on the test data. In the end, I've tested my neural network, and it gets 93% accuracy on the test set. So that's a quick review of how the simple neural network framework works; it does a lot of the work for you.
Now I'm going to open the extended framework version of the same problem. In fact, there are two ways to do this in the extended framework, so I'm going to go through it twice to show you both approaches.
Let me put these side by side so we can go through them together. I'll clear all this output. On the left we have the simple approach, and on the right we have the dlnetwork approach, which is the first of the advanced methods. In both cases, loading the data is exactly the same. But something does change when it comes to creating the network.
Let me make the text a little larger, since my screen resolution is very high, so hopefully you can see the code on the screen now. Eagle-eyed viewers will notice that on the left we have an extra final layer, the classification layer, which we don't have on the right. That layer is the loss function: it defines the objective of the neural network training, and it doesn't exist on the right-hand side.
The reason is that we're going to add it later. That's something you have to customize if you use this approach, so we don't have the loss function here.
Next, let me show the hyperparameters side by side. Notice that on the left, with the simple framework, we don't specify many hyperparameters. We've taken the time at MathWorks to come up with good default values so that, if you're solving a problem the simpler way, you don't have to worry about too many things. Here, the only hyperparameters I've chosen are the number of images in each mini-batch and the learning rate; nothing else is specified.
But on the right, you'll notice that I have a lot more hyperparameters being shown. I'm actually going through and specifying a lot more. And this is because we're going to be customizing the training and a lot more needs to be specified here. So this is something that's going to be characteristic of the more advanced training loops that we're going through.
The next question is, how do we train the network? With the simple framework, it's a single command. In the extended framework, it's a loop that you write yourself. In fact, it's a double loop: the outer loop runs over epochs, that is, full passes through the data, and the inner loop runs over the mini-batches within each epoch.
And all of those steps that I talked about are happening here. So you're explicitly doing things like evaluating the gradients. You're explicitly doing things like doing the updates of the weights and biases or back propagation. So all of this is happening on data that is contained in dlarrays, which are the new data container for this extended framework.
One of the key pieces is the model gradients function listed here, which you have to write yourself. Writing it gives you the flexibility to specify the loss however you like: you can write any kind of custom loss function, and any kind of forward propagation or back propagation routine that you want. But it is required, so let me show it to you.
Let's open this model gradients function. It's down here; I'll zoom in again. It takes in the network's input data and the expected outputs, and it does the forward pass here. Remember that inference, the forward pass, is step 3 of training.
Then we compute the loss, which compares the output of the model to the real answer. Then we do the back propagation step of taking the gradient, that is, computing the derivatives of the loss with respect to the learnable parameters, as you see here. Because you're doing all of these steps yourself, you can specify exactly what you want.
Here we've used a built-in loss function, cross-entropy (crossentropy in MATLAB). But I don't have to use it; I now have the flexibility to write anything.
So the question really comes down to, what does your neural network need? Do you need this level of customization? Do you need to do things you can't do otherwise? Then yes, use the extended framework. If the built-in functionality is sufficient for your problem, there's no need to go into all this detail and write your own training loop.
Now let me open one more comparison script: the second advanced approach, which some types of models require. Let me put this side by side as well. All right.
The second approach is called the model function approach. The dlnetwork approach is an intermediate, hybrid step; the model function approach gives you the highest degree of customization, with very low-level control over the internals of the neural network. Let me explain how you'd solve the same digit classification problem this way.
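The value of a user-written model gradients function is that the training loop never needs to know what the loss is. This Python sketch mirrors that split: each hypothetical `model_gradients_*` function does the forward pass, computes a loss, and returns gradients (derived by hand here, where MATLAB would use dlgradient), while one generic `train` loop works with either. The Huber loss is just one example of a custom loss, chosen for illustration.

```python
def model_gradients_mse(params, xs, ys):
    """Forward pass, loss, and gradients for y ~ w*x + b with mean squared error."""
    w, b = params
    preds = [w * x + b for x in xs]  # forward pass
    n = len(xs)
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    dw = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    db = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    return loss, (dw, db)


def model_gradients_huber(params, xs, ys, delta=1.0):
    """Same model, custom robust loss: quadratic for small errors, linear for large ones."""
    w, b = params
    n = len(xs)
    loss, dw, db = 0.0, 0.0, 0.0
    for x, y in zip(xs, ys):
        r = w * x + b - y  # residual from the forward pass
        if abs(r) <= delta:
            loss += 0.5 * r * r
            g = r
        else:
            loss += delta * (abs(r) - 0.5 * delta)
            g = delta if r > 0 else -delta
        dw += g * x
        db += g
    return loss / n, (dw / n, db / n)


def train(model_gradients, xs, ys, steps=2000, learn_rate=0.05):
    """Generic loop: it never looks inside the loss, it only calls model_gradients."""
    params = (0.0, 0.0)
    for _ in range(steps):
        _, grads = model_gradients(params, xs, ys)
        params = tuple(p - learn_rate * g for p, g in zip(params, grads))
    return params


xs, ys = [0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 3.0, 5.0, 7.0, 9.0]  # y = 2x + 1
fit_mse = train(model_gradients_mse, xs, ys)
fit_huber = train(model_gradients_huber, xs, ys)
```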
As before, loading the data is the same in the simple approach and the model function approach. No surprises there. But notice that we create a neural network from layers on the left, and on the right there's no network definition at all.
There's a reason for that: the model function approach writes a neural network as a series of function operations, or mathematical equations, instead of as a series of layers. There are some neural networks whose operations can't be described as a series of layers. For those, you have to go down to the low level and describe them mathematically. So here, to help you understand where this comes from, we describe all of these operations mathematically.
So where is the neural network? It's all the way down in a function, so let me show you that now. This function, which I've called Model, takes the input, which is the images, and runs it through a composition of functions. First, we do a convolution. Then we take the output of that and plug it into another function, relu. Then we take the output of that and plug it into another function, maxpool, and so on.
So we're composing functions here rather than using a layer-based API, but we're doing exactly the same operations in both cases. Because we're writing them as functions, I now have very low-level control over how things are done, and I can add any kind of mathematical expression here.
With that said, the rest of this is quite similar. We define hyperparameters as before, and in this approach you also have to define a lot of them. There's one other aspect that's special to this approach: I have to define the model parameters themselves, the weights and biases, which I'm doing here.
Now, some of this might be over the heads of some of you, and that's OK. I just wanted to expose you to it for now. But those of you who have used TensorFlow or PyTorch will find this very familiar, because in those frameworks, too, much of the programming is functional programming. The simpler approach on the left is comparable to something like Keras.
Now that you've seen we can do all of these things in MATLAB, I'm going to show you a real-world example, because you would never use these advanced capabilities for such a simple problem. Let's go to a more complicated problem.
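Here is the same idea in miniature, in Python: a "network" that is nothing but a composition of functions, maxpool(relu(conv(x))), applied to a 1-D signal. In MATLAB's extended framework the corresponding operations are functions such as dlconv, relu, and maxpool acting on dlarray data; the Python helpers and the parameter dictionary below are hypothetical stand-ins.

```python
def conv1d(signal, kernel, bias=0.0):
    """'Valid' 1-D convolution (cross-correlation, as most deep learning libraries compute it)."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k)) + bias
            for i in range(len(signal) - k + 1)]


def relu(xs):
    return [max(0.0, x) for x in xs]


def maxpool1d(xs, size=2):
    """Non-overlapping max pooling; a trailing partial window is dropped."""
    return [max(xs[i:i + size]) for i in range(0, len(xs) - size + 1, size)]


def model(x, params):
    """The whole network is a composition of functions: maxpool(relu(conv(x)))."""
    return maxpool1d(relu(conv1d(x, params["kernel"], params["bias"])))


params = {"kernel": [1.0, -1.0], "bias": 0.0}
out = model([0.0, 2.0, 1.0, 3.0, 3.0, 0.0, 4.0], params)
# conv -> [-2, 1, -2, 0, 3, -4]; relu -> [0, 1, 0, 0, 3, 0]; maxpool -> [1.0, 0.0, 3.0]
```

Because `model` is ordinary code, inserting an operation that has no layer equivalent is just another function call in the chain.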
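In the model function approach you create and own the weights and biases yourself. As one common scheme, the Python sketch below uses Glorot (Xavier) uniform initialization, drawing weights from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)), and starts biases at zero. The scheme, the fully connected layer sizes (784 inputs, as for flattened 28-by-28 MNIST images), and the nested-dictionary layout are all assumptions for illustration, not necessarily what the demo uses.

```python
import math
import random


def glorot_uniform(fan_in, fan_out, rng):
    """Weight matrix (fan_out rows x fan_in columns) drawn from U(-limit, limit)."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_in)] for _ in range(fan_out)]


def init_params(layer_sizes, seed=0):
    """Hypothetical parameter store: one weight matrix and bias vector per layer."""
    rng = random.Random(seed)
    params = {}
    for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:]), start=1):
        params[f"fc{i}"] = {
            "weights": glorot_uniform(n_in, n_out, rng),
            "bias": [0.0] * n_out,  # biases commonly start at zero
        }
    return params


params = init_params([784, 100, 10])
```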
All right. Now we're back to our presentation. So the next step that we're going to go over is modeling an actually advanced neural network.
So what are some of the things you can model with the extended framework? I've been talking about GANs for a while, and that's the example we're going through today, because I think it has broad appeal. People are really captivated by the idea that a neural network can generate something almost as realistic as a real object. Of course, this GAN example is fairly simple and not very realistic, but it conveys the same concepts. So GANs are one example.
There are other examples of advanced neural networks that require the functionality we just talked about. For example, variational autoencoders, or VAEs. These have become popular recently for applications like anomaly detection and noise reduction, because they're very good at reconstructing data. If you train a VAE on what we'd call normal data, you can use it to analyze whether new data is normal, and even to reproduce what the normal data should look like. That's why VAEs are becoming very popular for anomaly detection and noise reduction.
Another example, which we've been seeing requests for, is Siamese networks. Siamese networks are unique in that they're often used for one-shot learning. What is one-shot learning? Well, you know how neural networks usually need huge numbers of training samples? You might need 10,000 pictures of cats just to teach a neural network what a cat looks like.
What if you could train it with one picture of a cat? That's one-shot learning. Siamese networks are so called because they're like two networks conjoined together: two branches that share the same weights. And because of this joining of two networks, we need the advanced programming functionality.
Another example I like to point out is the attention network, or attention mechanism. These are more advanced, cutting-edge neural networks that can do very sophisticated things. For example, this one takes a picture as input, and the output of the neural network is a sentence describing the picture. That's what we call image captioning: imagine a neural network trained to describe this picture as a dog sitting on some grass. That's the type of model we're talking about with attention mechanisms.
And there are many more that aren't listed here. I've just highlighted a few key application areas and types of neural networks; many more can be built with the same functionality.
So let's go into more detail about GANs. As I said, GAN stands for generative adversarial network, so let's talk about each word.
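The Siamese idea depends on the shared weights mentioned earlier: both inputs pass through the same branch with the same parameters, and the network then compares the two embeddings. Below is a minimal Python sketch of that, plus the contrastive loss, which is one common (assumed here) training objective that pulls same-class pairs together and pushes different-class pairs at least `margin` apart. The tiny linear-plus-tanh branch and all names are hypothetical.

```python
import math


def embed(x, weights):
    """Tiny branch network: a linear map followed by tanh."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]


def siamese_distance(x1, x2, weights):
    """Both inputs go through the SAME weights; then the embeddings are compared."""
    e1, e2 = embed(x1, weights), embed(x2, weights)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(e1, e2)))


def contrastive_loss(distance, same_class, margin=1.0):
    """Small distance for matching pairs, at least `margin` for non-matching ones."""
    if same_class:
        return distance ** 2
    return max(0.0, margin - distance) ** 2


W = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]]  # the one set of weights both branches share
same = siamese_distance([1.0, 0.0, 2.0], [1.0, 0.0, 2.0], W)
diff = siamese_distance([1.0, 0.0, 2.0], [-1.0, 3.0, 0.0], W)
```

Since there is only one `W`, a gradient step on it changes both branches at once, which is exactly what "shared weights" means in practice.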
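At the core of an attention mechanism is a set of weights saying how much each part of the input (for captioning, each image region) should contribute to the current output. The Python sketch below uses scaled dot-product scoring followed by a softmax; captioning models often use a different scoring function (for example, an additive one), so treat this as one illustrative form. The region vectors and the query standing in for the decoder state are made up.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(query, keys, values):
    """Weight each value by how well its key matches the query, then average."""
    scale = math.sqrt(len(query))  # scaled dot-product scoring
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    weights = softmax(scores)  # attention weights: non-negative, sum to 1
    context = [sum(w * value[j] for w, value in zip(weights, values))
               for j in range(len(values[0]))]
    return weights, context


# Hypothetical setup: three image-region feature vectors serve as keys and values,
# and the query stands in for the caption decoder's current state.
regions = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, context = dot_product_attention([4.0, -4.0], regions, regions)
```

Here the first region matches the query best, so it receives the largest weight and dominates the context vector passed on to the next step of the caption decoder.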
Generative means that it generates data. That's the point of a GAN.
Adversarial means that it's trained adversarially. In fact, a GAN is not one single network; it's two networks competing against each other. Let's use an analogy.
The generator can be thought of as an art forger: an artist whose job is to make realistic fakes of paintings. On the flip side, a different network, called the discriminator, is like the detective trying to catch the forger. Whenever the detective sees an image, he or she wants to determine whether it's real or fake.
So the generator's job is to create as realistic a fake image as possible, while the discriminator's job is to figure out whether an image is real, meaning it comes from the training set, or fake, meaning it came from the generator. If you architect the problem properly, what happens is that initially both networks are very bad at their jobs: the generator makes really bad images, and the discriminator doesn't really know what's going on. As training progresses, the generator gets better and better at making fake images, and the discriminator gets better and better at telling whether an image is real or fake.
Now, the generator network does have an input, which you can think of as the network's inspiration. Generally, this input is random noise or a structured pattern of some sort. For example, you may have seen neural style transfer applications, where you take a photograph and convert it to look like a van Gogh or Picasso painting. In a GAN-based version of that, the photograph you took is the input to the generator, and the generator dreams up what it would look like as a Picasso. Applications like that are often built with GANs.
So this is the general concept. And I'm sure, if you're interested in GANs, you can find much more well thought out literature on this. I'm trying to be a little brief here.
But to be a little more clear about this, there are many, many types of GANs. At last count, I remember seeing there are at least 700 papers on GANs. So there are a huge number of variations.
They can operate not just on images, but videos or sound or other types of media as well. The input to a GAN can be random or can be structured. There are many types of networks you can use as a generator. There are many types of networks you can use as a discriminator.
To understand the example we're about to go through, I'm going to go into a little mathematical detail, so let's visualize the math a bit.
Let's talk about one of the networks, which is called the generator. Its job, as we said, is to take in some input, which could be random noise or it could be some structure, and then turn that into a fake image, which will be as realistic as possible. And it learns to do so through adversarial training.
Now, we're going to call the input z. We call that a latent vector, but think of it as the noise input. It goes through the network, and the output is G of z, which is the fake image.
The next thing to talk about is the discriminator network. The discriminator is its own neural network, trained to classify images; in fact, it is an image classifier. It's simple enough if you understand image classification: it takes in an image and outputs a label, and the label in this case is real or fake. So its job is to determine whether the image is real or not.
At the same time, the discriminator is also going to be fed with data from the real training set. These are the real photographs of these objects, and we'll call them x. Those real photographs also go in as inputs to the discriminator network.
Now, what's different about how we train this is that once the discriminator processes the information, either the real information, x, or the fake information, G of z, we're going to help the discriminator out and tell it whether it got it right or not, that is, whether the image was fake or real. Generally, you don't do this in an image classification network the way we do it here, but we're going to tell it. We're going to help it.
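To make the "tell it whether it got it right" step concrete, here is a small Python sketch (an illustration, not the MATLAB code from the demo) of the usual choice: binary cross-entropy with target 1 for real images x and target 0 for generated images G(z). The function names and the example score are hypothetical.

```python
import math

def sigmoid(logit: float) -> float:
    """Map the discriminator's raw score (logit) to a probability of 'real'."""
    return 1.0 / (1.0 + math.exp(-logit))

def binary_cross_entropy(prob_real: float, target: int, eps: float = 1e-12) -> float:
    """Loss for one discriminator guess; target is 1 for a real image x, 0 for a fake G(z)."""
    p = min(max(prob_real, eps), 1.0 - eps)  # clamp so log() never sees exactly 0
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))

REAL, FAKE = 1, 0  # the training loop knows where each image came from

score = sigmoid(2.0)  # discriminator is fairly sure the image is real (about 0.88)
print(binary_cross_entropy(score, REAL))  # right answer: small loss
print(binary_cross_entropy(score, FAKE))  # wrong answer: large loss
```

The target label is what carries the "help": the same score is cheap if the image really was real and expensive if it was a fake.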
So the way we're going to do that is as follows. We have to write a loss function that's really smart: one that helps the generator get as good as possible, which means minimizing the generator's loss, and at the same time helps the discriminator get as good as possible, which means minimizing its own loss.
The problem is that they have to compete. They're in competition. So you really have to write this loss function in a smart way.
So the output of the discriminator, during training at least, is a value between 0 and 1, where 1 indicates real and values toward 0 indicate fake. There are actually two components to the loss function that we're going to see. Let's not worry too much about the exact math here. But if you look at the generator's side, this is a function that we are going to minimize, and minimizing it drives D of G of z, the discriminator's score for a fake image, toward 1. In other words, the generator is trying to get the discriminator to call its fake data real.
So what that means is we're trying to get the generator as good as possible. But at the same time, there's another component to this loss function, where the discriminator really wants to make sure that it recognizes a fake image as fake and a real image as real. And you'll notice that the comparison on the second side is basically 1 minus the comparison on the first side. So at the same time as you're minimizing one, you're maximizing the other.
So what that means is that rather than the discriminator's output asymptoting to either 0 or 1, the trained networks end up at an equilibrium, at a kind of saddle point in the middle. I'm just showing a representation here. So a properly trained GAN is going to settle somewhere around 0.5, and that's because the two parts of the loss function are competing against each other.
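For reference, the standard GAN objective that this two-part loss is usually based on can be written as below, with D(·) the discriminator's probability that its input is real. The discriminator maximizes V and the generator minimizes it; in practice the generator is often trained on the second line instead (the "non-saturating" form), which is the version that explicitly drives D(G(z)) toward 1. The exact loss code in the demo may differ in detail.

```latex
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]
```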
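A short derivation shows where the 0.5 comes from in the idealized case. Write the standard GAN objective V(D, G) as an integral over images x, with p_g the distribution of generated images G(z). For a fixed generator, the discriminator maximizes the integrand pointwise (take a = p_data(x) and b = p_g(x)), and once the generator matches the data distribution the best it can do is output 1/2 everywhere. Real training only approximates this, which is why the observed value hovers around 0.5 rather than landing on it exactly.

```latex
V(D, G) = \int_x \Big[\, p_{\text{data}}(x) \log D(x) + p_g(x) \log\big(1 - D(x)\big) \Big]\, dx

\frac{d}{dD}\Big[\, a \log D + b \log(1 - D) \Big] = \frac{a}{D} - \frac{b}{1 - D} = 0
\;\Longrightarrow\;
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}

p_g = p_{\text{data}} \;\Longrightarrow\; D^*(x) = \tfrac{1}{2},
\qquad V(D^*, G) = \log\tfrac{1}{2} + \log\tfrac{1}{2} = -\log 4
```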
So with that, I'm going to actually go through an example of this. We're going to use a particular variant of a GAN called a C-GAN. As I mentioned, there are something like 700 papers, so there are all different types of GANs out there. I'm choosing a specific one that I think is a little more relevant and a little more practical.
So C-GAN stands for Conditional Generative Adversarial Network. It is a type of GAN that can take advantage of labels during the training process. With a regular GAN, if you want it to generate images of flowers, it'll just take in any noise data and generate any kind of flower that is in your data set.
But what if you want to specify that you don't want just any flower, you want a specific flower, like a sunflower? Give me a fake image of a sunflower rather than some other flower, like a lily. Now I'm using a label to control the output of the GAN. And if I'm using a label to control the output, then this is what we call conditional. So if we ask for a particular type, we get that type as output.
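One common way to feed a label into a generator is to encode it as a one-hot vector and concatenate it with the noise vector z. The Python sketch below shows that idea only; the demo's network may combine label and noise differently (for example, with a learned embedding), and the five class names and their ordering are assumed for illustration.

```python
import random

def one_hot(class_index: int, num_classes: int) -> list[float]:
    """Encode a class label as a vector with a single 1.0 at the class position."""
    vec = [0.0] * num_classes
    vec[class_index] = 1.0
    return vec

def conditional_generator_input(class_index: int, num_classes: int,
                                latent_dim: int, rng: random.Random) -> list[float]:
    """Concatenate random noise z with the one-hot label, so the generator sees both."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return z + one_hot(class_index, num_classes)

# Hypothetical class ordering for a five-category flower data set.
classes = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
rng = random.Random(0)
x_in = conditional_generator_input(classes.index("sunflowers"), len(classes),
                                   latent_dim=8, rng=rng)
print(len(x_in))  # 8 noise values + 5 label entries = 13
```

Changing only the label part while keeping z fixed is what lets the same noise be steered toward a sunflower or a tulip.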
So this is the example we're going to run through with you now. Let's jump into MATLAB and open up the example I'm talking about. I'm going to close all the other ones. All right, here we go.
So this is an example that's in our documentation, and I'm going to share these examples with you. But I want to go through this one to illustrate what's happening. The first diagram that you see here relates to a general GAN, and as you see, it's the same diagram we've seen before: noise goes into the generator, and then the real and fake images go to the discriminator.
The difference in the conditional GAN is that we add a second input to both the generator and the discriminator, the labels, so that we can conditionally generate images based on a particular label. In this code, I won't go into too much detail, but I'll highlight some of the areas that I think are very important.
We're going to load some training data, and this is basically a data set of flower images. Let me show you the data. I'm going to open up an Explorer window, and so let's zoom in.
So you see here, we have five different categories of flowers. And under each folder, we have images of that type of flower. And they're pretty numerous and varied. So generally, when you want to train a GAN, you want robust examples. You want it to be really good, so give it a lot of good examples. So that's our data set.
And now, the first thing we're going to do after importing the data is define the first network, the generator. What I want to point out about this generator network is that the key operations happen here in these transposed convolution blocks. This operation is what allows an input of noise to turn into an image.
Other GANs will have other blocks. But for now, I just want to highlight these key blocks here. I mean, you need all the other stuff as well. But I think it's fair to say that the key stuff is happening here.
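To see how a transposed convolution grows a small input into a larger one, here is a minimal 1-D sketch in plain Python (not MATLAB). It assumes no padding or cropping; the 2-D layers in the generator apply the same idea along both height and width. Each input value stamps a scaled copy of the kernel into the output, and with stride s and kernel length k, an input of length n produces (n - 1)s + k outputs.

```python
def transposed_conv1d(x: list[float], kernel: list[float], stride: int) -> list[float]:
    """1-D transposed convolution without padding: each input value adds a scaled
    copy of the kernel into the output, shifted by `stride` positions per input."""
    out_len = (len(x) - 1) * stride + len(kernel)
    out = [0.0] * out_len
    for i, xi in enumerate(x):
        for k, wk in enumerate(kernel):
            out[i * stride + k] += xi * wk
    return out

y = transposed_conv1d([1.0, 2.0, 3.0], kernel=[1.0, 1.0, 1.0, 1.0], stride=2)
print(len(y), y)  # 8 [1.0, 1.0, 3.0, 3.0, 5.0, 5.0, 3.0, 3.0]
```

Three inputs became eight outputs; stacking several such layers is how a generator expands a short noise vector into a full-size image.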
So I'm going to go ahead and run this section because it defines the generator. The way that we're defining the generator is with a layer-based approach, and because we're using the layer-based approach, we convert it all into a dlnetwork at the end, which I highlighted for you before.
And let me actually open the same network up in the Deep Network Designer because it's illustrative to see what it looks like in its architecture. It's not all that complicated, but it's also not very straightforward. So let's open this up. Now, usually the Deep Network Designer takes a bit to open the first time. Let's see what happens here.
All right. So we're going to import this network that I just defined, and here's the generator. OK, let's zoom in here.
So from a bird's-eye view, we see that this is mostly a series network. There are, of course, two inputs to this neural network: the random noise input, and the labels, because it's a conditional GAN.
I also want to point out that there are some special icons here. These indicate that these are custom layers. These layers were written just for this problem. They don't exist in the layer library at the left.
So you can use the Deep Network Designer with custom layers. You just won't be able to do very much with them in this interface; you'll have to write the code. But anyway, we have ourselves a generator network.
Now, the next thing we're going to do is work on the second network, which is the discriminator. Notice that the key blocks here are convolutions, not transposed convolutions. In fact, this is a very typical convolutional neural network. As I said, this is an image classifier, and typical image classifiers take an image as input and provide a label as output. So when we open this network up in the Deep Network Designer, it should look very familiar to you.
So once again, we are describing this in terms of layers, and then we convert it all into a dlnetwork at the end. So let's open the discriminator network now, and you'll see that it's laid out much like the generator, except that the two inputs are the images, whether they're real or fake, and the labels. And now we have actual convolution layers instead of transposed convolution layers.
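The discriminator's convolutions go the other way and shrink their input. A minimal 1-D Python sketch of a strided "valid" convolution (computed as cross-correlation, as deep learning layers do) shows an input of length n reduced to floor((n - k)/s) + 1 values; the 8-element input and the averaging kernel are arbitrary choices for illustration.

```python
def conv1d(x: list[float], kernel: list[float], stride: int) -> list[float]:
    """1-D 'valid' convolution (cross-correlation, as in deep learning layers):
    slide the kernel over the input, moving `stride` positions each step."""
    out_len = (len(x) - len(kernel)) // stride + 1
    return [
        sum(x[i * stride + k] * kernel[k] for k in range(len(kernel)))
        for i in range(out_len)
    ]

signal = [1.0, 1.0, 3.0, 3.0, 5.0, 5.0, 3.0, 3.0]
print(conv1d(signal, kernel=[0.25, 0.25, 0.25, 0.25], stride=2))  # [2.0, 4.0, 4.0]
```

Repeated downsampling like this condenses an image into a small set of features, which the final layers turn into a single real-or-fake score.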
So it might seem like in this GAN that, well, what's so advanced about this? We were able to use Deep Network Designer or a layer-based approach to create these neural networks. Why do we need the advanced training structures that we're talking about? It seems pretty simple to me.
Well, this is where it's going to get a little hairy. First, we have to define the loss function. And as I mentioned on that slide, the loss function is no simple matter. We have two networks that are competing against each other, so the loss function can't be described as a single layer. It's a sophisticated operation.
And so let me show you that operation. If we go down to the function, the loss function is described down here. So you can see that we have the generator loss, we have the discriminator loss, and then we combine them.
And that loss function is called here, in the model gradients function. Let me go through and highlight those areas. We're doing the forward pass, like we talked about, through both networks.
We compute the loss function, and then we take the gradients of both of these networks. This operation, this combining of the neural networks, is the part that cannot be done using the simple framework, which is why we're using the extended framework here.
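The pattern of "forward pass through both networks, compute both losses, take each network's gradients" can be sketched with a deliberately tiny, hypothetical toy in Python. This is not the demo's MATLAB function: each "network" here has two scalar parameters, and the gradients are derived by hand where MATLAB would use automatic differentiation (dlgradient, evaluated through dlfeval).

```python
import math

def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))

def model_gradients(params: dict, x_real: float, z: float):
    """Hypothetical toy GAN step with scalar 'networks':
    generator G(z) = a*z + b, discriminator D(x) = sigmoid(w*x + c).
    Returns both losses and hand-derived gradients for each network."""
    a, b, w, c = params["a"], params["b"], params["w"], params["c"]
    # Forward pass through both networks.
    x_fake = a * z + b                 # G(z)
    s_real = sigmoid(w * x_real + c)   # D(x)
    s_fake = sigmoid(w * x_fake + c)   # D(G(z))
    # Discriminator loss: push D(x) toward 1 and D(G(z)) toward 0.
    loss_d = -math.log(s_real) - math.log(1.0 - s_fake)
    # Generator loss (non-saturating form): push D(G(z)) toward 1.
    loss_g = -math.log(s_fake)
    # Gradients of loss_d with respect to the discriminator's parameters only.
    grads_d = {"w": -(1.0 - s_real) * x_real + s_fake * x_fake,
               "c": -(1.0 - s_real) + s_fake}
    # Gradients of loss_g with respect to the generator's parameters; the chain
    # rule passes through the discriminator, hence the factor w.
    grads_g = {"a": -(1.0 - s_fake) * w * z,
               "b": -(1.0 - s_fake) * w}
    return loss_d, loss_g, grads_d, grads_g

params = {"a": 0.5, "b": 0.0, "w": 1.0, "c": 0.0}
loss_d, loss_g, grads_d, grads_g = model_gradients(params, x_real=2.0, z=1.0)

# One plain gradient-descent step per network, each on its own loss.
learn_rate = 0.1
for name, g in grads_d.items():
    params[name] -= learn_rate * g
for name, g in grads_g.items():
    params[name] -= learn_rate * g
print(loss_d, loss_g, params)
```

The generator's gradient has to flow through the discriminator, which is exactly the coupling between two networks that a single built-in loss layer cannot express.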
So with that said, once you have written those functions-- let's go all the way back, right here. Once you've written those functions and properly architected the loss and the gradients and all that stuff, you can go forward and specify your hyperparameters, and then you can train the network. So let's go do that.
So I'm going to go and train the network. And as you see, this is a training loop. Now, in a custom training loop, you're responsible for creating your own visualization. Although that may not be as easy as before, it also gives you the opportunity to produce a plot that shows you exactly what you want to see. In this example, we're going to produce a plot that shows us both some generated images as we go and the loss functions of both of these neural networks.
So this is a giant training loop. I won't go into details here, but it's very much the same concept as what we saw before. I also will not actually train it for you because this GAN took me about 2 and 1/2 hours to train. Instead, I'm going to show you a video. And in fact, we showed you this video earlier.
So as you see, as I train this model (this is now sped up about 400 times), we are producing output from the GAN with each iteration, and we're also computing the loss, or score, of these models as we go forward. And as you can see, although this is not a very advanced GAN, so the results aren't super realistic, the images are clearly much more realistic than they were at the start.
And one thing I'll point out about the loss is that it's not asymptoting down to 0 for the discriminator. The discriminator's curve is the orange one. It's leveling off at roughly 0.5 to 0.6, because the discriminator is competing against the generator.
And then the last step: how do we use the trained GAN? The trained GAN does not require the discriminator or the loss function. It only requires the generator.
So as you'll see here, I can choose a class. Let's say I choose the class called tulips. Then the key thing we're doing here is using the predict method of the generator network. We're adding in some random noise, which we call z, and we're also making sure to choose the right class, T. And then we get generated images of tulips.
If I run this section again, it'll be randomized because we're adding some other random noise. If I choose a different class, like sunflowers, now we're going to see sunflowers. So this is what a trained GAN looks like. This is how it's used.
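The inference step can be sketched in Python as follows: only the generator is called, and a fresh noise vector per image is why rerunning the section gives new images of the same class. The toy_generator function is a hypothetical stand-in for a trained network, and class_index=4 stands for "tulips" under an assumed class ordering.

```python
import random

def generate_for_class(generator, class_index: int, latent_dim: int,
                       num_images: int, rng: random.Random) -> list:
    """Produce several images of one class using only the trained generator.
    The discriminator and the loss function play no part at this stage."""
    images = []
    for _ in range(num_images):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]  # fresh noise per image
        images.append(generator(z, class_index))
    return images

def toy_generator(z: list[float], class_index: int) -> list[float]:
    """Hypothetical stand-in for a trained generator, only so the sketch runs."""
    return [class_index + zi for zi in z]

rng = random.Random(42)
tulips = generate_for_class(toy_generator, class_index=4, latent_dim=3,
                            num_images=2, rng=rng)
print(tulips[0] != tulips[1])  # new noise gives a different image of the same class
```

Fixing the random seed makes the output reproducible; changing class_index while keeping the noise switches the flower type.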
All right. With that said, let me conclude with a few final thoughts. And then for those of you who are able to stay past the allotted time, we're going to have a little bit more Q&A. But let's finish up with a few final thoughts.
We've covered a lot today. I want to point out that we have a lot of documentation and examples. One of today's examples was drawn from our documentation, but there are many examples of various different networks. Today we were focused on GANs, but there are examples of other networks there as well.
And to point out exactly where you can find these: if I open the MATLAB Help, every toolbox has its own page. If you go to Deep Learning Toolbox and then click on Examples, there are examples in the Customization area. This is where you'll see all sorts of advanced neural networks and examples of how you can use them. And we're always coming out with more with every release. So this is where you can find more information.