Writing a model can be a daunting task! Here is an example to help walk you through it:
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import innocuous.Endpoint as magic
from innocuous.MagicObj import MagicObj
import logging
from colorlog import ColoredFormatter
Set up your logging and streams (Below you can find descriptions and examples of each method)
# Logging configuration: colored, timestamped console output on a dedicated
# 'PodApp' logger that emits everything at DEBUG and above.
LOG_LEVEL = logging.DEBUG
datefmt = '%Y-%m-%d %H:%M:%S'
LOGFORMAT = "%(log_color)s[%(asctime)s][%(levelname)-8s]%(reset)s %(log_color)s%(message)s%(reset)s"

# Lower the root threshold so records propagated upward are not filtered out.
logging.root.setLevel(LOG_LEVEL)

# Console handler with the colorlog formatter attached.
console_handler = logging.StreamHandler()
console_handler.setLevel(LOG_LEVEL)
console_handler.setFormatter(ColoredFormatter(LOGFORMAT, datefmt))

# Named app logger used throughout this file (see main()).
log = logging.getLogger('PodApp')
log.setLevel(LOG_LEVEL)
log.addHandler(console_handler)
Create a model class and set your layers and parameters
class Model(nn.Module):
    """Small two-stage CNN classifier for single-channel images.

    Two conv -> ReLU -> max-pool stages followed by one fully connected
    layer. With the default 5x5 kernels and 2x2 pooling, a 28x28 input is
    reduced to a 32 x 4 x 4 feature map before the classifier head.
    Attribute names are kept stable so saved ``state_dict`` checkpoints
    remain loadable.
    """

    def __init__(self, img_channel=1, out_channels=10):
        super(Model, self).__init__()
        # Stage 1: img_channel -> 16 feature maps, spatial size halved by pooling.
        self.cnn1 = nn.Conv2d(in_channels=img_channel, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        # Stage 2: 16 -> 32 feature maps, halved again.
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        # Classifier head; the 32 * 4 * 4 input size assumes a 28x28 image.
        self.fc1 = nn.Linear(32 * 4 * 4, out_channels)

    def forward(self, x):
        """Run a batch of images through the network and return class logits."""
        features = self.maxpool1(self.relu1(self.cnn1(x)))
        features = self.maxpool2(self.relu2(self.cnn2(features)))
        # Flatten all feature dimensions into one vector per sample.
        flat = features.view(features.size(0), -1)
        return self.fc1(flat)
Train your model with your dataset and create a Validation/Test set
def main(lr=0.001, epochs=2, batch_size=256):
    """Fine-tune the CNN on an MNIST-style ImageFolder dataset.

    Loads a pretrained checkpoint via the platform file helper, trains for
    ``epochs`` epochs, validates after each epoch, then saves a per-epoch
    checkpoint and logs epoch-level metrics through ``MagicObj``.

    Args:
        lr: Learning rate for the Adam optimizer.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both train and validation loaders.
    """
    # from main import log, myHandler
    # log.addHandler(myHandler)
    log.info('list cwd2: {}'.format(os.getcwd()))
    log.info('*******************************************************************************')
    log.info('testing logging abilities')

    mj = MagicObj()
    fileHelper = magic.FileHelper()

    # Resolve the dataset location through the platform path helper.
    dataset_path = mj.get_path('/Users/noam/Downloads/mnist')
    train_path = os.path.join(dataset_path, 'train')
    val_path = os.path.join(dataset_path, 'val')

    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ])
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(val_path, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = Model()
    # Warm-start from a checkpoint fetched from the File Manager.
    pretrained_state = fileHelper.get("data://checkpoint/models/checkpoint.pt")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loaded_dict = torch.load(pretrained_state, map_location=device)
    # Merge checkpoint weights into a freshly-initialized state dict so any
    # keys missing from the checkpoint keep their default initialization.
    adapted_dict = model.state_dict()
    adapted_dict.update(loaded_dict)
    model.load_state_dict(adapted_dict)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        eval_loss = 0.0
        eval_acc = 0.0

        # BUGFIX: restore training mode at the start of each epoch. Previously
        # model.eval() set for validation was never undone, so every epoch
        # after the first trained with dropout/batchnorm in eval mode.
        model.train()
        with tqdm(train_loader, unit="batch") as tepoch:
            for data in tepoch:
                images, labels = data
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                _, pred = outputs.max(1)
                num_correct = (pred == labels).sum().item()
                acc = num_correct / images.shape[0]
                tepoch.set_postfix(loss=loss.item(), accuracy=acc)

        # Snapshot the weights once per epoch (previously copied every batch).
        checkpoint = model.state_dict()

        # Validation pass: eval mode, no gradient tracking.
        model.eval()
        with tqdm(val_loader, unit="batch") as tepoch:
            for data in tepoch:
                images, labels = data
                with torch.no_grad():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                eval_loss += loss.item()
                _, pred = outputs.max(1)
                num_correct = (pred == labels).sum().item()
                acc = num_correct / images.shape[0]
                eval_acc += acc
                tepoch.set_postfix(loss=loss.item(), accuracy=acc)

        # Save the epoch checkpoint and log averaged validation metrics.
        mj.torch_save(checkpoint=checkpoint, path='/Users/noam/Downloads', epoch=epoch)
        mj.log(accuracy=eval_acc / len(val_loader), loss=eval_loss / len(val_loader))
stream = logging.StreamHandler()
This constructor call creates and returns a new StreamHandler instance, which sends log output to a stream (sys.stderr by default).
logging.root.setLevel(LOG_LEVEL)
stream = logging.StreamHandler()
stream.setLevel(LOG_LEVEL)
This method sets the threshold for this logger to the passed parameter (LOG_LEVEL). Logging messages which are less severe than LOG_LEVEL will be ignored; logging messages which have severity LOG_LEVEL or higher will be emitted by whichever handler or handlers service this logger, unless a handler’s level has been set to a higher severity level than LOG_LEVEL.
log.info('testing logging abilities')
This method logs a message with level “INFO” on the ‘PodApp’ logger obtained above (not the root logger).
log = logging.getLogger('PodApp')
This method returns a logger with the parameter name (PodApp). If no name is provided, it returns the root logger of the hierarchy.
pretrained_state = fileHelper.get("data://checkpoint/models/checkpoint.pt")
This method fetches and returns a file from the File Manager in the Web UI.