Contact: f20190077 [at] pilani.bits-pilani.ac.in
GitHub: aadit3003
Twitter: @Aadit🏳️‍🌈
Last Updated: Feb 2023
14 Feb 2023
I’ve been working on a project involving split learning, a form of privacy-preserving, decentralized neural network training. In situations where we don’t want to share raw data (sensitive content like patient records or medical images), multiple client networks instead send only their activations to a single server network, which passes the gradients back during backpropagation. I’ll write a more detailed blog post on this once I’ve finished the project, so keep an eye out for that.
As a PyTorch beginner, I was really confused by how differently tutorials ordered the steps of a training loop. However, I found an excellent starting point: Deep Learning with PyTorch: A 60 Minute Blitz (although it took me well over 60 minutes to understand it fully). Inspired by this article, I decided to create an acronym/mnemonic [1] to streamline the mise en place for training a neural network in PyTorch.
Letter | Step | Description |
---|---|---|
M | Model | An instance of your neural network class (a subclass of 'nn.Module'). |
C | Criterion | Pick a loss function (BCE, MSE, etc.). |
O | Optimizer | Pick an optimizer (Adam, SGD, etc.). |
Z | Zero | Reset the gradients before each new batch (PyTorch accumulates them by default). |
P | Predict | Run the forward pass to get the model’s predictions for the input. |
L | Loss | Calculate the loss (actual vs. predicted labels). |
B | Backward | Backpropagate the loss to compute the gradients. |
S | Step | Update the parameters via the optimizer. |
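Why does Z come first? By default, PyTorch accumulates gradients: each call to `.backward()` adds to whatever is already stored in a parameter's `.grad` instead of overwriting it. Here's a minimal sketch with a single scalar parameter; the tensor `w`, the toy loss `3 * w`, and the SGD optimizer are just for illustration.

```python
import torch

# One learnable parameter and a toy loss, loss = 3 * w, so d(loss)/dw = 3
w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

(3 * w).backward()
print(w.grad)  # tensor(3.)

# A second backward pass without zeroing ADDS to the stored gradient
(3 * w).backward()
print(w.grad)  # tensor(6.)

# Z: reset the gradient, and the next backward pass gives the fresh value again
optimizer.zero_grad()
(3 * w).backward()
print(w.grad)  # tensor(3.)
```

Skip Z in a real loop and every batch's update would be driven by the sum of all previous batches' gradients. (Recent PyTorch versions reset gradients to `None` rather than filling them with zeros, but the effect on the next backward pass is the same.)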
I pronounce it as 'Macaws Plebs' (picture an elitist condescending macaw 🦜🧐).
Here's how I use 'MCOZ PLBS' to train a typical ResNet classifier. We start by defining the key objects: the model (M), a loss function or criterion (C), and an optimizer (O). (There are a few more steps, such as subclassing 'nn.Module' or replacing the final FC layer, but I've omitted them for brevity.)
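```python
import torch
from torch import nn, optim
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet34().to(device)  # M
criterion = nn.CrossEntropyLoss()  # C
optimizer = optim.Adam(model.parameters())  # O
```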
Now for the main event: the training loop, run for 50 epochs. For each mini-batch, we first reset the gradients (Z), then do the forward pass (P), compute the loss (L), run the backward pass (B), and finally update the parameters (S).
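```python
for epoch in range(50):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):  # trainloader: a DataLoader defined elsewhere
        inputs, labels = data
        # Move the batch to the same device as the model
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()              # Z
        outputs = model(inputs)            # P
        loss = criterion(outputs, labels)  # L
        loss.backward()                    # B
        optimizer.step()                   # S

        running_loss += loss.item()
        if i % 100 == 99:  # print the average loss every 100 mini-batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0
print("DONE!!")
```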
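If you want to run the whole recipe end to end without downloading a dataset, here is a small self-contained sketch. The synthetic data, the tiny two-layer model, and the hyperparameters (20 epochs, batch size 32, lr=0.01) are made-up assumptions for illustration; only the MCOZ PLBS structure mirrors the loop above.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical toy data: 512 points in 2D, labelled by which side of the line x + y = 0 they fall on
X = torch.randn(512, 2)
y = (X.sum(dim=1) > 0).long()
trainloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))  # M
criterion = nn.CrossEntropyLoss()                                     # C
optimizer = optim.Adam(model.parameters(), lr=0.01)                   # O

epoch_losses = []
for epoch in range(20):
    total = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()              # Z
        outputs = model(inputs)            # P
        loss = criterion(outputs, labels)  # L
        loss.backward()                    # B
        optimizer.step()                   # S
        total += loss.item()
    epoch_losses.append(total / len(trainloader))  # average loss over the epoch's batches

print(f"first epoch: {epoch_losses[0]:.3f}, last epoch: {epoch_losses[-1]:.3f}")
```

Since the two classes are linearly separable, the average loss should drop steadily across epochs.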
'MCOZ PLBS' makes it easy to implement all kinds of variations on the vanilla training loop. For example, here is a simple split learning training loop for classification, with one client network and one server network. The client holds the data and runs the forward pass on it up to a 'cut layer', after which the server takes over without ever seeing the raw inputs. So we have two of almost everything: two models (M) and two optimizers (O), but a single loss function (C). (The code is inspired by this workshop by the Camera Culture group at MIT.)
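```python
import torch
from torch import nn, optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Client and Server: the two halves of the network ('nn.Module' subclasses, not shown), split at cut_layer
client_model = Client({"cut_layer": 3}).to(device)  # M
server_model = Server({"cut_layer": 3}).to(device)  # M
criterion = nn.CrossEntropyLoss()  # C
client_optimizer = optim.Adam(client_model.parameters())  # O: updates only the client half
server_optimizer = optim.Adam(server_model.parameters())  # O: updates only the server half
```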
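The `Client` and `Server` classes above come from the workshop code and aren't reproduced here. To give a rough feel for what a cut layer means, here is one hypothetical way to split a small `nn.Sequential` stack into a client half and a server half; `split_at_cut_layer`, the layer sizes, and the MNIST-like 28x28 input are my own assumptions, not the workshop's API.

```python
import torch
from torch import nn

def split_at_cut_layer(layers, cut_layer):
    """Hypothetical helper: layers [0, cut_layer) stay on the client, the rest go to the server."""
    client = nn.Sequential(*layers[:cut_layer])
    server = nn.Sequential(*layers[cut_layer:])
    return client, server

# A small fully connected classifier for 28x28 grayscale images with 10 classes (illustrative sizes)
layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
]
client_model, server_model = split_at_cut_layer(layers, cut_layer=3)

x = torch.randn(4, 1, 28, 28)
activations = client_model(x)        # what the client sends to the server: shape (4, 128)
outputs = server_model(activations)  # what the server computes from them: shape (4, 10)
print(activations.shape, outputs.shape)
```

Only the 128-dimensional activations ever cross the client/server boundary; the raw images stay with the client.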
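Now, in the training loop, responsibility is divided. The client runs the forward pass up to the cut layer and sends the resulting activations to the server. The server, which never sees the raw inputs, finishes the forward pass, computes the loss, and backpropagates through its own layers; it then sends the gradient with respect to those activations back to the client, which backpropagates it through the client layers. All of this amounts to doing the MCOZ PLBS steps twice, once per side. The same pattern scales to multiple clients, and none of them has to share its raw data with the server (in this vanilla setup, though, the server does need the labels to compute the loss).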
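```python
for epoch in range(50):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):  # trainloader: a DataLoader defined elsewhere
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        client_optimizer.zero_grad()  # Z
        server_optimizer.zero_grad()  # Z

        # Client: forward pass up to the cut layer
        activations = client_model(inputs)  # P
        # "Send" the activations: a detached copy that starts a fresh autograd graph on the server
        server_inputs = activations.detach().clone().requires_grad_(True)

        # Server: finish the forward pass, compute the loss, backpropagate, update
        outputs = server_model(server_inputs)  # P
        loss = criterion(outputs, labels)      # L
        loss.backward()                        # B (also fills server_inputs.grad)
        server_optimizer.step()                # S

        # Client: backpropagate the gradient returned by the server, then update
        activations.backward(server_inputs.grad)  # B
        client_optimizer.step()                   # S

        running_loss += loss.item()
        if i % 100 == 99:  # print the average loss every 100 mini-batches
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0
print("DONE!!")
```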
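A quick way to convince yourself that the detach-and-send trick loses nothing: for a single batch, the split backward pass (the server's `loss.backward()` followed by the client's `activations.backward(server_inputs.grad)`) should leave the client with the same gradients as ordinary backprop through the whole, unsplit network. The sketch below checks this on a hypothetical toy split with one linear layer per side and random data; it's a sanity check, not part of the training loop.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy split: one linear layer on each side of the cut
client_model = nn.Linear(8, 4)
server_model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(5, 8)
labels = torch.randint(0, 3, (5,))

# 1) Ordinary end-to-end backprop through the unsplit network
criterion(server_model(client_model(inputs)), labels).backward()
reference_grad = client_model.weight.grad.clone()

# Reset both halves before trying the split version
client_model.zero_grad()
server_model.zero_grad()

# 2) Split backprop, exactly as in the loop above
activations = client_model(inputs)
server_inputs = activations.detach().clone().requires_grad_(True)
loss = criterion(server_model(server_inputs), labels)
loss.backward()                           # server side: fills server_inputs.grad
activations.backward(server_inputs.grad)  # client side: continue backprop from the received gradient

print(torch.allclose(client_model.weight.grad, reference_grad))  # True
```

This is just the chain rule split across two machines: `server_inputs.grad` is the gradient of the loss with respect to the activations, and passing it to `activations.backward` lets autograd finish the job on the client side.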
And that’s it! You now have all the ingredients to train your own neural network in PyTorch. I’d love to hear about any fun acronyms or mnemonics you all have for CS-ey stuff.
[1] The spelling of the word ‘mnemonic’ has been permanently stamped into my brain since age 11, when I misspelled it in a National Spelling Bee as “P-N-E-M-O-N-I-C” 💀. In my defense, ‘pneumonia’ and ‘pneumonic’ sound very similar.