Aadit Deshpande

(he/him)


Contact: f20190077 [at] pilani.bits-pilani.ac.in
GitHub: aadit3003
Twitter: @Aadit🏳️‍🌈
CV (PDF)

Last Updated: Feb 2023


Macaws Plebs 🦜🧐 - A PyTorch Mnemonic

14 Feb 2023

I’ve been working on a project involving split learning, i.e., privacy-preserving decentralized neural network training. Some data is too sensitive to share, such as patient records or medical images. In that case, each client runs its own network up to a certain layer and sends only the resulting activations to a single server network. The server completes the forward pass. During backpropagation, it sends the gradients for those activations back to the clients. I’ll have a more detailed blog post on this in the future when I’m finished with the project, so keep an eye out for that.

As a PyTorch beginner, I was really confused by how inconsistently different tutorials ordered the steps of a training loop. However, I found an excellent starting point: Deep Learning with PyTorch: A 60 Minute Blitz (although it took me well over 60 minutes to understand it fully). Inspired by this tutorial, I decided to create an acronym/mnemonic [1] to streamline the mise en place for training a neural network in PyTorch.

The Acronym - MCOZ PLBS

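| Letter | Step | Description |
|---|---|---|
| M | Model | An instance of your neural network class (a subclass of 'nn.Module'). |
| C | Criterion | Pick a loss function (BCE, MSE, etc.). |
| O | Optimizer | Pick an optimizer (Adam, SGD, etc.). |
| Z | Zero | Reset the gradients before each new batch, since PyTorch accumulates them by default. |
| P | Predict | Run the forward pass to get the model’s predictions on the input. |
| L | Loss | Compute the loss between the predictions and the true labels. |
| B | Backward | Backpropagate the loss to compute the gradients. |
| S | Step | Update the parameters via the optimizer. |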
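Why does Z come before everything else? Here is a minimal sketch using a single hypothetical weight 'w' and input 'x'; neither comes from the original code. Calling backward twice without zeroing adds the second gradient on top of the first.

```python
import torch

# One learnable weight and a fixed input, so the gradient is easy to check by hand.
w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0])
opt = torch.optim.SGD([w], lr=0.1)  # only used here for its zero_grad()

def loss_fn():
    # loss = (w * x)^2, so d(loss)/dw = 2 * (w * x) * x = 2 * 6 * 3 = 36
    return ((w * x) ** 2).sum()

loss_fn().backward()
first_grad = w.grad.clone()        # 36.0

loss_fn().backward()               # no zero_grad(): the new gradient is ADDED to w.grad
accumulated_grad = w.grad.clone()  # 72.0

opt.zero_grad()                    # Z: clears w.grad (set to None or zeros, depending on the PyTorch version)
loss_fn().backward()
reset_grad = w.grad.clone()        # back to 36.0

print(first_grad.item(), accumulated_grad.item(), reset_grad.item())  # 36.0 72.0 36.0
```

Without Z, every update would use a running sum of the gradients from all previous batches.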

I pronounce it as 'Macaws Plebs' (picture an elitist condescending macaw 🦜🧐).

Macaws in Action

Here's how I use 'MCOZ PLBS' to train a typical ResNet classifier. We start by defining three key objects: the model (M), a loss function or criterion (C), and an optimizer (O). (There are a few more steps, such as subclassing 'nn.Module' or changing the final FC layer, but I've omitted them for brevity.)


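import torch
from torch import nn, optim
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet34().to(device) # M
criterion = nn.CrossEntropyLoss() # C
optimizer = optim.Adam(model.parameters()) # O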

Now for the main event: the training loop, here for 50 epochs. For each batch, we start by resetting the gradients (Z), because PyTorch adds new gradients to the ones already stored. Then come the forward pass (P), the loss calculation (L), the backward pass (B), and finally the parameter update (S).


for epoch in range(50):
  running_loss = 0.0

  for i, data in enumerate(trainloader, 0):
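    inputs, labels = data
    # Move the batch to the same device as the model
    inputs, labels = inputs.to(device), labels.to(device)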

    optimizer.zero_grad() # Z

    outputs = model(inputs) # P
    loss = criterion(outputs, labels) # L
    loss.backward() # B
    optimizer.step() # S

    running_loss += loss.item()

    if i % 100 == 99:
      print(f"[{epoch}, {i}] loss: {running_loss/100}")
      running_loss = 0.0

print("DONE!!")
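You may want to run the recipe end to end without a dataset or a GPU. Here is a minimal, CPU-only sketch that swaps the ResNet for a tiny two-layer network. It replaces 'trainloader' with hypothetical synthetic 2-D data labelled by which side of a line each point falls on. The letters line up with the loop above.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: label is 1 if x0 + x1 > 0, else 0.
X = torch.randn(512, 2)
y = (X.sum(dim=1) > 0).long()
trainloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))  # M
criterion = nn.CrossEntropyLoss()                                     # C
optimizer = optim.Adam(model.parameters(), lr=0.01)                   # O

epoch_losses = []
for epoch in range(20):
    total = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()              # Z
        outputs = model(inputs)            # P
        loss = criterion(outputs, labels)  # L
        loss.backward()                    # B
        optimizer.step()                   # S
        total += loss.item() * inputs.size(0)
    epoch_losses.append(total / len(X))    # mean loss over the epoch

with torch.no_grad():
    accuracy = (model(X).argmax(dim=1) == y).float().mean().item()

print(f"first epoch loss {epoch_losses[0]:.3f}, last {epoch_losses[-1]:.3f}, accuracy {accuracy:.2%}")
```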

A Split Macaw

'MCOZ PLBS' makes it really easy to implement all kinds of variations of the vanilla training loop. For example, here is a simple split learning classification training loop with a client network and a server network. The client holds the data and runs the forward pass on it up to a 'cut layer'. After that, the server takes over without ever seeing the raw data. Essentially, we have two of almost everything: two models (M), one shared loss function (C), and two optimizers (O). (The code is inspired by this workshop by the Camera Culture group, MIT.)


client_model = Client({"cut_layer":3}).to(device) # M
server_model = Server({"cut_layer":3}).to(device) # M
criterion = nn.CrossEntropyLoss() # C
# Each optimizer only updates its own side's parameters
client_optimizer = optim.Adam(client_model.parameters()) # O
server_optimizer = optim.Adam(server_model.parameters()) # O
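'Client' and 'Server' aren't defined above; they come from the workshop code. Here is a rough sketch of what they could look like, assuming a hypothetical small CNN for 28x28 grayscale images. Both classes take the same config dict and keep opposite halves of one layer list, split at 'cut_layer'. The architecture and class bodies are assumptions for illustration, not taken from the workshop.

```python
import torch
from torch import nn

def make_layers():
    # Hypothetical full network for 28x28 grayscale inputs, as a flat list of layers.
    return [
        nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # layers 0-2
        nn.Flatten(), nn.Linear(8 * 14 * 14, 64), nn.ReLU(),        # layers 3-5
        nn.Linear(64, 10),                                          # layer 6
    ]

class Client(nn.Module):
    """Hypothetical client: runs the first config["cut_layer"] layers on its private data."""
    def __init__(self, config):
        super().__init__()
        self.layers = nn.Sequential(*make_layers()[: config["cut_layer"]])

    def forward(self, x):
        return self.layers(x)

class Server(nn.Module):
    """Hypothetical server: runs the remaining layers on the client's activations."""
    def __init__(self, config):
        super().__init__()
        self.layers = nn.Sequential(*make_layers()[config["cut_layer"]:])

    def forward(self, activations):
        return self.layers(activations)

client_model = Client({"cut_layer": 3})
server_model = Server({"cut_layer": 3})
images = torch.randn(4, 1, 28, 28)
activations = client_model(images)  # (4, 8, 14, 14): all the server ever sees
logits = server_model(activations)  # (4, 10)
print(activations.shape, logits.shape)
```

Moving 'cut_layer' changes how much computation stays on the client, and what the shared activations look like.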

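In the training loop, the work is divided between the two sides. The client runs the first part of the forward pass on its private data. The server, which never sees the raw data, finishes the forward pass, computes the loss, and backpropagates through its own layers. It then sends the gradient with respect to the cut-layer activations back to the client, which uses it to finish backpropagation. In other words, we run the Z, P, B, and S steps of MCOZ PLBS once per side, with a single L on the server. We can also scale this to multiple clients and train the network without any of the clients needing to share their data with the server.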


for epoch in range(50):
  running_loss = 0.0

  for i, data in enumerate(trainloader, 0):
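    inputs, labels = data
    # Move the batch to the same device as the models
    inputs, labels = inputs.to(device), labels.to(device)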

    client_optimizer.zero_grad() # Z
    server_optimizer.zero_grad() # Z

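    activations = client_model(inputs) # P (client, up to the cut layer)
    # "Send" the activations: the server gets a copy that is cut off from the
    # client's autograd graph but still records a gradient of its own.
    server_inputs = activations.detach().clone().requires_grad_(True)
    outputs = server_model(server_inputs) # P (server)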

    loss = criterion(outputs, labels) # L

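    loss.backward() # B (server; also fills server_inputs.grad)
    server_optimizer.step() # S
    # "Send" server_inputs.grad back and finish backprop on the client
    activations.backward(server_inputs.grad) # B
    client_optimizer.step() # S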

    running_loss += loss.item()

    if i % 100 == 99:
      print(f"[{epoch}, {i}] loss: {running_loss/100}")
      running_loss = 0.0

print("DONE!!")
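A quick way to convince yourself that the detach-and-send trick doesn't change the maths is to compare gradients. The sketch below uses hypothetical tiny 'client' and 'server' networks. It compares the gradients from the split backward pass against a single ordinary backward pass through the combined network.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical tiny client and server halves (stand-ins for Client/Server above).
client = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
server = nn.Sequential(nn.Linear(8, 3))
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(5, 4)
labels = torch.tensor([0, 1, 2, 1, 0])

# Reference: one ordinary forward/backward pass through an identical combined network.
full = nn.Sequential(copy.deepcopy(client), copy.deepcopy(server))
criterion(full(inputs), labels).backward()
reference = [p.grad.clone() for p in full.parameters()]

# Split version: the server only receives a detached copy of the cut-layer activations.
activations = client(inputs)                                     # P (client)
server_inputs = activations.detach().clone().requires_grad_(True)
loss = criterion(server(server_inputs), labels)                  # P (server) + L
loss.backward()                                                  # B (server)
activations.backward(server_inputs.grad)                         # B (client)
split = [p.grad for p in list(client.parameters()) + list(server.parameters())]

matches = all(torch.allclose(r, s) for r, s in zip(reference, split))
print(matches)  # True
```

Because the client's backward call starts from exactly the gradient that would have flowed through the cut layer anyway, both halves end up with the same gradients as ordinary end-to-end training.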

And that’s it! You now have all the ingredients to train your own neural network in PyTorch. I’d love to hear from my readers if you have any fun acronyms or mnemonics for CS-ey stuff.

[1] The spelling of the word ‘mnemonic’ has been permanently stamped into my brain since age 11, when I misspelled it in a National Spelling Bee as “P-N-E-M-O-N-I-C” 💀. In my defense, pneumonia and pneumonic sound very similar.
