Job - nn.py
This page contains an in-depth description of the nn.py file.
Function/Class definitions
Sample file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.mobilenet import mobilenet_v2
class MobileNetCNN(nn.Module):
    """MobileNetV2 feature extractor followed by a small sigmoid-output head.

    The backbone is torchvision's ``mobilenet_v2`` built without pretrained
    weights, with its last child module removed. The head is
    ``dropout -> Linear(1280, hidden_features) -> dropout ->
    Linear(hidden_features, num_classes) -> sigmoid``. Every output unit is
    squashed independently by a sigmoid, so the outputs are per-class scores
    rather than a softmax distribution.

    Args:
        num_classes: Number of output units. Defaults to 14.
        hidden_features: Width of the hidden linear layer. Defaults to 512.
        dropout: Dropout probability used before each linear layer.
            Defaults to 0.5.
    """

    # Channel count of the backbone's output feature map; fc1 consumes it
    # after global average pooling.
    _BACKBONE_FEATURES = 1280

    def __init__(
        self,
        num_classes: int = 14,
        hidden_features: int = 512,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        # NOTE: torchvision deprecated ``pretrained`` in 0.13 in favour of
        # ``weights=None``; it is kept here so older torchvision still works.
        mobile = mobilenet_v2(pretrained=False)
        # Keep every child except the last one (torchvision's classification
        # head). The nn.Sequential wrapper is deliberate: it fixes the
        # state_dict key prefix ("mobilenet_layer.0.…"). Swapping it for the
        # bare submodule would rename the keys and break loading of any
        # existing checkpoints.
        self.mobilenet_layer = nn.Sequential(*list(mobile.children())[:-1])
        self.fc1 = nn.Linear(self._BACKBONE_FEATURES, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)
        # One Dropout module is shared by both dropout sites in forward().
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class sigmoid scores for a batch of inputs.

        Args:
            x: Batch fed to the MobileNetV2 feature stack; presumably images
                shaped (batch, 3, H, W) — confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, num_classes) with sigmoid-squashed values.
        """
        x = self.mobilenet_layer(x)
        # Global average pool to (batch, C, 1, 1), then flatten to (batch, C).
        # flatten(1) keeps the batch dimension fixed. If C differs from fc1's
        # in_features, fc1 raises instead of the extra channels being
        # reshaped into spurious batch rows.
        x = torch.flatten(F.adaptive_avg_pool2d(x, 1), 1)
        x = self.fc1(self.dropout(x))
        # NOTE(review): there is no activation between fc1 and fc2, so the two
        # layers compose into a single affine map (dropout aside). Adding one
        # (e.g. ReLU) would change the outputs of already-trained weights, so
        # it is left as-is — confirm whether this is intended.
        x = self.fc2(self.dropout(x))
        return self.sigmoid(x)
def getModel():
    """Construct the model used by this job.

    Returns:
        A newly created ``MobileNetCNN`` built with its default settings.
        Each call returns an independent instance whose backbone is freshly
        initialised rather than loaded with pretrained weights.
    """
    # NOTE(review): the camelCase name is kept as-is because external code
    # presumably looks this function up by name — confirm before renaming.
    model = MobileNetCNN()
    return model
Last updated