
A Clever Technique to Improve Neural Network Training
September-13-2024
I don’t usually get excited about coding neural networks: defining layers, writing the forward pass, and so on.
In fact, for most machine learning engineers, this task can be quite monotonous.
For me, the real challenge and enjoyment come from optimizing and maximizing every bit of GPU memory.

Today, I want to share a discovery I stumbled upon a couple of years ago while optimizing the training process.
What I'm about to reveal might seem quite obvious once you hear it, but I think many people overlook its subtlety.
Let’s dive in!
What was I doing?
I was working on an image classification task; let’s use MNIST as an example.
Normalizing pixel values is a standard practice to stabilize and improve model training.
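With torchvision, that conversion normally happens inside the dataset transform. My exact transform isn't shown here, but it was essentially the usual ToTensor step (treat this as a representative sketch):
from torchvision import transforms

# ToTensor converts each image's 8-bit pixel values (0-255) into 32-bit floats in [0, 1],
# and it does so on the CPU, before the batch is ever moved to the GPU.
transform = transforms.ToTensor()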

Here’s what my code looked like:
- First, I loaded and transformed the dataset, then defined the model, and so on:
device = torch.device("cuda")
train_data = MNIST(..., transform=transform)
train_loader = DataLoader(train_data, batch_size=..., shuffle=True)
model = MyModel(...).to(device)
optimizer = ...
criterion = ...
- Next, I had the usual training loop where I moved the data to the GPU before each training iteration, as shown below:
for epoch in range(epochs):
    for inputs, labels in train_loader:
        # move the batch (already converted to float32 on the CPU) to the GPU
        inputs, labels = inputs.to(device=device), labels.to(device=device)
        # regular training loop below
        ...
Here’s what the profiler output looked like:

- Most of the time and resources were spent in the GPU kernels (the actual training computation).
- However, a notable amount of time was used for transferring data from CPU to GPU (indicated by Memcpy).
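For reference, a breakdown like this can be captured with torch.profiler. The sketch below reuses the names from the setup above; the ten-step cutoff and the sort key are arbitrary choices:
from torch.profiler import profile, ProfilerActivity

# Profile a handful of training steps and see where the time goes.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device=device), labels.to(device=device)
        loss = criterion(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step >= 10:
            break

# The "Memcpy HtoD" rows are the CPU-to-GPU transfers.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))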
At first glance, it seems like there’s little we can do to optimize this since it involves data transfer, which we can’t avoid.

But here’s the trick.
Remember that the original dataset consisted of 8-bit integer pixel values, which the transform converted to 32-bit floats.

Next, we transferred these 32-bit floating-point tensors to the GPU, which meant that normalizing on the CPU made us move four bytes per pixel instead of one, roughly quadrupling the amount of data copied on every batch.
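A quick back-of-the-envelope check makes the difference concrete (the batch size of 64 is just an assumed example):
import torch

# An MNIST-shaped batch: 64 images, 1 channel, 28x28 pixels each
batch_uint8 = torch.randint(0, 256, (64, 1, 28, 28), dtype=torch.uint8)
batch_float = batch_uint8.float() / 255.0  # what the CPU-side transform was handing over

print(batch_uint8.numel() * batch_uint8.element_size())  # 50176 bytes (~49 KB)
print(batch_float.numel() * batch_float.element_size())  # 200704 bytes (~196 KB), four times as much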
That’s when I had my "Aha!" moment!
I realized that moving the normalization step to after the data transfer would solve this issue, as we’d be transferring 8-bit integers instead of 32-bit floats.
train_data = MNIST( ... )  # dataset left un-normalized: samples stay as raw 8-bit pixel values
train_loader = DataLoader(train_data, batch_size=..., shuffle=True)

for epoch in range(epochs):
    for inputs, labels in train_loader:
        # transfer the compact uint8 tensors first...
        inputs, labels = inputs.to(device=device), labels.to(device=device)
        # ...then convert to float32 and normalize on the GPU
        inputs = inputs.float() / 255.0
        # regular training loop below
        ...
As a result, I observed a significant reduction in the Memcpy time, which makes perfect intuitive sense.

I had trained several models before, but I never realized how such a subtle optimization could make a difference.
Of course, this technique isn’t applicable to all neural network tasks; in NLP, for example, the inputs that get transferred are already compact integer token IDs, and the 32-bit float embeddings are produced on the GPU anyway.
Nevertheless, whenever I’ve had the chance to use this trick, I’ve seen noticeable improvements.
Isn’t that a clever trick?