Job - nnMetrics.py
This page gives an in-depth description of the `nnMetrics.py` file.
Function definitions
Sample file
import numpy as np
def metricSupportFn(outputs, labels):
    """Return the sample count and the summed per-sample label agreement.

    Each value in ``outputs`` is rounded to the nearest integer and compared
    element-wise with the corresponding row of ``labels``. Every sample
    contributes the fraction of its positions that agree, so a fully
    matching row adds 1.0 and a half-matching row adds 0.5.

    Args:
        outputs: 2-D torch tensor of model outputs, one row per sample.
            Presumably probabilities whose rounding yields 0/1 predictions
            -- TODO confirm against the training loop.
        labels: iterable of torch tensors (or a 2-D tensor), one target row
            per sample, aligned with the rows of ``outputs``.

    Returns:
        ``(total, correct)``: ``total`` is the number of samples and
        ``correct`` is the sum of the per-sample agreement fractions.
        An empty batch yields ``(0, 0.0)``.
    """
    # Detach and move to host memory before converting to NumPy, so that
    # GPU tensors and tensors that require grad are both accepted. The
    # labels previously skipped the .cpu() call that outputs received.
    label_arr = np.array([t.detach().cpu().numpy() for t in labels])
    output_arr = outputs.detach().cpu().numpy()
    if output_arr.shape[0] == 0:
        return 0, 0.0
    # np.round rounds halves to even, exactly like the former
    # np.matrix.round(...) call, so predictions are unchanged.
    matches = np.round(output_arr) == label_arr
    # Fraction of matching positions in each row (vectorized per-row loop).
    per_sample = matches.mean(axis=1)
    return len(per_sample), per_sample.sum()
def metricSupportFn2(outputs, labels, threshold=0.5):
    """Count confusion-matrix entries pooled over all classes and samples.

    For each class column, a prediction is positive when its output is
    strictly greater than ``threshold``. That binary prediction is compared
    with the 0/1 label at the same position, and the counts are summed
    across every class (micro-averaged).

    Args:
        outputs: 2-D array-like or torch tensor of scores; only the first
            ``labels.shape[1]`` columns are scored.
        labels: 2-D array-like or torch tensor of 0/1 targets, one row per
            sample and one column per class. Entries that are neither 0 nor
            1 are ignored.
        threshold: decision threshold applied to ``outputs``. The default of
            0.5 matches the previously hard-coded value.

    Returns:
        ``(tp, fp, tn, fn)`` as Python ints.
    """
    # Torch tensors: detach from the autograd graph and move to host memory
    # so NumPy can consume them (CPU, GPU, or requires-grad alike).
    if hasattr(outputs, "detach"):
        outputs = outputs.detach().cpu().numpy()
    if hasattr(labels, "detach"):
        labels = labels.detach().cpu().numpy()
    outputs = np.asarray(outputs)
    labels = np.asarray(labels)

    # shape[1] (not len(labels[0])) also works for an empty batch.
    n_classes = labels.shape[1]
    predicted = outputs[:, :n_classes] > threshold
    actual_pos = labels == 1
    actual_neg = labels == 0

    # One vectorized pass replaces a per-class sklearn confusion_matrix call.
    tp = int(np.count_nonzero(predicted & actual_pos))
    fp = int(np.count_nonzero(predicted & actual_neg))
    tn = int(np.count_nonzero(~predicted & actual_neg))
    fn = int(np.count_nonzero(~predicted & actual_pos))
    return (tp, fp, tn, fn)
def optimizerFn(model, lr, *, momentum=0.9):
    """Build an SGD optimizer over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: learning rate.
        momentum: SGD momentum factor. Defaults to 0.9, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A ``torch.optim.SGD`` optimizer.
    """
    import torch.optim as optim
    return optim.SGD(model.parameters(), lr=lr, momentum=momentum)
def criterionFn():
    """Return a fresh ``torch.nn.CrossEntropyLoss`` with default arguments."""
    # NOTE(review): the metric helpers in this file round / threshold each
    # output column independently at 0.5, which looks like a multi-label
    # setup; if so, BCEWithLogitsLoss may be the intended criterion.
    # Confirm before changing.
    from torch import nn
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn
def transformFn():
    """Build the image preprocessing pipeline.

    Images are resized to 224x224 and randomly flipped horizontally and
    vertically (each with torchvision's default probability). They are
    then converted to a tensor.
    """
    from torchvision import transforms as T
    steps = [
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.ToTensor(),
    ]
    return T.Compose(steps)
Last updated