{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"colab": {
"name": "recitation6-fall21.ipynb",
"provenance": []
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ZQLWbF3L2AtP"
},
"source": [
"# Recitation - 6\n",
"___\n",
"### Image Dataset\n",
"Custom Dataset & Dataloader\\\n",
"Torchvision ImageFolder Dataset\n",
"### Model\n",
"Residual Block\\\n",
"CNN Models with Residual Block\n",
"### Cosine Similarity\n",
"### Losses\n",
"Center Loss\\\n",
"Triplet Loss\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PFpMWpeb2RAp",
"outputId": "8ff482c9-9e03-48cd-99aa-882ac1268ec6"
},
"source": [
"!unzip mnist"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Archive: mnist.zip\n",
"replace mnist/testing/0/1487.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: n\n",
"replace mnist/testing/0/1768.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: no\n",
"replace mnist/testing/0/192.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Y6xf6dFV3Oh5",
"outputId": "85ea0ed8-af18-47e5-abb5-4aa7f2f9838c"
},
"source": [
"!nvidia-smi"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fri Oct 8 13:33:46 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 470.74 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 73C P0 74W / 149W | 1621MiB / 11441MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1C38DN_I2AtU"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Tz5cjqS82AtW"
},
"source": [
"import os\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"import torch\n",
"import torchvision\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import Dataset, DataLoader"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "thiTppkB2AtZ"
},
"source": [
"## Custom DataSet with DataLoader\n",
"___\n",
"We use a subset of MNIST stored as PNG files, with one folder per digit class."
]
},
{
"cell_type": "code",
"metadata": {
"id": "eNeiL80e2Atb"
},
"source": [
"class ImageDataset(Dataset):\n",
"    def __init__(self, file_list, target_list):\n",
"        self.file_list = file_list\n",
"        self.target_list = target_list\n",
"        self.n_class = len(set(target_list))\n",
"        # Build the transform once rather than on every __getitem__ call\n",
"        self.transform = torchvision.transforms.ToTensor()\n",
"\n",
"    def __len__(self):\n",
"        return len(self.file_list)\n",
"\n",
"    def __getitem__(self, index):\n",
"        # Load the image and convert it to a float tensor of shape (C, H, W)\n",
"        img = Image.open(self.file_list[index])\n",
"        img = self.transform(img)\n",
"        label = self.target_list[index]\n",
"        return img, label"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iM78P7G02Atd"
},
"source": [
"#### Parse the given directory to accumulate all the images"
]
},
{
"cell_type": "code",
"metadata": {
"id": "z6rsYF012Atd"
},
"source": [
"def parse_data(datadir):\n",
"    img_list = []\n",
"    ID_list = []\n",
"    # os.walk yields one (root, dirs, files) triple per folder, e.g. root = 'mnist/training/1'\n",
"    for root, directories, filenames in os.walk(datadir):\n",
"        for filename in filenames:\n",
"            if filename.endswith('.png'):\n",
"                filei = os.path.join(root, filename)\n",
"                img_list.append(filei)\n",
"                # The name of the containing folder is the class ID\n",
"                ID_list.append(os.path.basename(root))\n",
"\n",
"    # Construct a dictionary mapping each class ID to an integer target.\n",
"    # Sorting makes the mapping deterministic across runs (a bare set has no guaranteed order).\n",
"    uniqueID_list = sorted(set(ID_list))\n",
"    class_n = len(uniqueID_list)\n",
"    target_dict = dict(zip(uniqueID_list, range(class_n)))\n",
"    label_list = [target_dict[ID_key] for ID_key in ID_list]\n",
"\n",
"    print('{}\\t\\t{}\\n{}\\t\\t{}'.format('#Images', '#Labels', len(img_list), len(set(label_list))))\n",
"    return img_list, label_list, class_n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "V9T0aVXO2Atf",
"outputId": "26d524b4-6f51-4e8b-d081-53711deb56de"
},
"source": [
"img_list, label_list, class_n = parse_data('mnist/training')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"#Images\t\t#Labels\n",
"5000\t\t10\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "viB0vfgN2Ath",
"outputId": "88cd00c2-41b8-41e6-e401-d38129ce5de9"
},
"source": [
"print(img_list[1888])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mnist/training/7/11854.png\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9HAzCVW12Ati"
},
"source": [
"trainset = ImageDataset(img_list, label_list)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Kj50Kxp62Atj"
},
"source": [
"train_data_item, train_data_label = trainset[0]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iYNQELVY2Atj",
"outputId": "5dc7b0eb-9a90-4b51-ab08-0e51afca5d66"
},
"source": [
"print('data item shape: {}\\t data item label: {}'.format(train_data_item.shape, train_data_label))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"data item shape: torch.Size([1, 28, 28])\t data item label: 0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DmBu6GGc2Atk"
},
"source": [
"dataloader = DataLoader(trainset, batch_size=10, shuffle=True, num_workers=1, drop_last=False)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vzCAeqcS2Atm"
},
"source": [
"## Torchvision DataSet and DataLoader"
]
},
{
"cell_type": "code",
"metadata": {
"id": "z73SLKYB2Atn"
},
"source": [
"imageFolder_dataset = torchvision.datasets.ImageFolder(root='mnist/training/', transform=torchvision.transforms.ToTensor())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MmdS6KKi2Ato"
},
"source": [
"imageFolder_dataloader = DataLoader(imageFolder_dataset, batch_size=10, shuffle=True, num_workers=1)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SoCZgWwn2Atq",
"outputId": "bd3a6bf0-3700-46de-9ac3-be92887add07"
},
"source": [
"print(len(imageFolder_dataset), len(imageFolder_dataset.classes))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"5000 10\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gRAuxUTf2Ats"
},
"source": [
"## Residual Block\n",
"\n",
"ResNet paper: https://arxiv.org/pdf/1512.03385.pdf\n",
"\n",
"Below is a minimal example of the shortcut (skip) connection used in ResNet."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Dc8jnM_W2Att"
},
"source": [
"# This is the simplest possible residual block, with only one convolutional layer.\n",
"# Following the paper, you can extend this block with more layers, a bottleneck, grouped convs (as in ShuffleNet), etc.\n",
"# Or look at more recent architectures such as ResNeXt, RegNet, ResNeSt, SENet, etc.\n",
"class SimpleResidualBlock(nn.Module):\n",
"    def __init__(self, channel_size, stride=1):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(channel_size, channel_size, kernel_size=3, stride=stride, padding=1, bias=False)\n",
"        self.bn1 = nn.BatchNorm2d(channel_size)\n",
"        if stride == 1:\n",
"            # Shapes already match, so the input is passed through unchanged\n",
"            self.shortcut = nn.Identity()\n",
"        else:\n",
"            # A strided 1x1 conv downsamples the input to match the main path\n",
"            self.shortcut = nn.Conv2d(channel_size, channel_size, kernel_size=1, stride=stride)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        out = self.conv1(x)\n",
"        out = self.bn1(out)\n",
"\n",
"        shortcut = self.shortcut(x)\n",
"\n",
"        # Add the shortcut before the final activation\n",
"        out = self.relu(out + shortcut)\n",
"\n",
"        return out"
],
"execution_count": null,
"outputs": []
},
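{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comments in the block above suggest extending it with more layers. As a rough sketch (not part of the original recitation code), a two-convolution residual block in the spirit of the ResNet paper's basic block could look like the cell below. The class name `TwoLayerResidualBlock` and its `in_channels`/`out_channels` arguments are assumptions made for illustration; a 1x1 projection is placed on the shortcut whenever the output shape differs from the input shape."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class TwoLayerResidualBlock(nn.Module):\n",
"    # Hypothetical name: an illustrative sketch, not the recitation's own block.\n",
"    def __init__(self, in_channels, out_channels, stride=1):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
"        self.bn1 = nn.BatchNorm2d(out_channels)\n",
"        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)\n",
"        self.bn2 = nn.BatchNorm2d(out_channels)\n",
"        self.relu = nn.ReLU()\n",
"        if stride == 1 and in_channels == out_channels:\n",
"            self.shortcut = nn.Identity()\n",
"        else:\n",
"            # 1x1 projection so the shortcut matches the main path's shape\n",
"            self.shortcut = nn.Sequential(\n",
"                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
"                nn.BatchNorm2d(out_channels),\n",
"            )\n",
"\n",
"    def forward(self, x):\n",
"        out = self.relu(self.bn1(self.conv1(x)))\n",
"        out = self.bn2(self.conv2(out))\n",
"        # Add the shortcut before the final activation\n",
"        return self.relu(out + self.shortcut(x))\n",
"\n",
"\n",
"# Quick shape check on random input: channels go 64 -> 128, spatial size 28 -> 14\n",
"block = TwoLayerResidualBlock(64, 128, stride=2)\n",
"print(block(torch.randn(2, 64, 28, 28)).shape)"
],
"execution_count": null,
"outputs": []
},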
{
"cell_type": "code",
"metadata": {
"id": "ZaOOXdrr2Atw"
},
"source": [
"# This has hard-coded hidden feature sizes.\n",
"# You can extend this to take in a list of hidden sizes as argument if you want.\n",
"class ClassificationNetwork(nn.Module):\n",
"    def __init__(self, in_features, num_classes):\n",
"        # in_features is the number of input channels (e.g. 3 for RGB images)\n",
"        super().__init__()\n",
"\n",
"        self.layers = nn.Sequential(\n",
"            nn.Conv2d(in_features, 64, kernel_size=3, stride=1, padding=1, bias=False),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.ReLU(),\n",
"            SimpleResidualBlock(64),\n",
"            SimpleResidualBlock(64),\n",
"            SimpleResidualBlock(64),\n",
"            SimpleResidualBlock(64),\n",
"            nn.AdaptiveAvgPool2d((1, 1)),  # For each channel, averages the entire feature map (height & width) down to 1x1\n",
"            nn.Flatten(),  # the above yields batch_size x 64 x 1 x 1; flatten to batch_size x 64\n",
"        )\n",
"        self.linear = nn.Linear(64, num_classes)\n",
"\n",
"    def forward(self, x, return_embedding=False):\n",
"        embedding = self.layers(x)\n",
"\n",
"        if return_embedding:\n",
"            return embedding\n",
"        else:\n",
"            return self.linear(embedding)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Df1Yukn42Atx",
"outputId": "afaa8566-6043-4001-cfb6-66bf130368fc"
},
"source": [
"train_dataset = torchvision.datasets.ImageFolder(root='mnist/training/', \n",
" transform=torchvision.transforms.ToTensor())\n",
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, \n",
" shuffle=True, num_workers=8)\n",
"\n",
"dev_dataset = torchvision.datasets.ImageFolder(root='mnist/testing/', \n",
" transform=torchvision.transforms.ToTensor())\n",
"dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=128, \n",
" shuffle=False, num_workers=8)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" cpuset_checked))\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NvBcX3BY2Atx"
},
"source": [
"numEpochs = 10\n",
"in_features = 3  # input channels: ImageFolder's default loader converts images to RGB\n",
"\n",
"learningRate = 5e-2\n",
"weightDecay = 5e-5\n",
"\n",
"num_classes = len(train_dataset.classes)\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"network = ClassificationNetwork(in_features, num_classes)\n",
"network = network.to(device)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(network.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=0.9)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9ajxMuLT2Atx",
"outputId": "b2016788-e94f-4b69-e233-be3e06faf983"
},
"source": [
"# Train!\n",
"for epoch in range(numEpochs):\n",
"\n",
"    # Train\n",
"    network.train()\n",
"    avg_loss = 0.0\n",
"    for batch_num, (x, y) in enumerate(train_dataloader):\n",
"        optimizer.zero_grad()\n",
"\n",
"        x, y = x.to(device), y.to(device)\n",
"\n",
"        outputs = network(x)\n",
"\n",
"        loss = criterion(outputs, y.long())\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        avg_loss += loss.item()\n",
"\n",
"        # Report the running average every 10 batches, so divide by 10\n",
"        if batch_num % 10 == 9:\n",
"            print('Epoch: {}\\tBatch: {}\\tAvg-Loss: {:.4f}'.format(epoch, batch_num+1, avg_loss/10))\n",
"            avg_loss = 0.0\n",
"\n",
"    # Validate\n",
"    network.eval()\n",
"    num_correct = 0\n",
"    with torch.no_grad():  # no gradients are needed for evaluation\n",
"        for batch_num, (x, y) in enumerate(dev_dataloader):\n",
"            x, y = x.to(device), y.to(device)\n",
"            outputs = network(x)\n",
"            num_correct += (torch.argmax(outputs, dim=1) == y).sum().item()\n",
"\n",
"    print('Epoch: {}, Validation Accuracy: {:.2f}'.format(epoch, num_correct / len(dev_dataset)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" cpuset_checked))\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch: 0\tBatch: 10\tAvg-Loss: 0.4459\n",
"Epoch: 0\tBatch: 20\tAvg-Loss: 0.3750\n",
"Epoch: 0\tBatch: 30\tAvg-Loss: 0.3318\n",
"Epoch: 0\tBatch: 40\tAvg-Loss: 0.3124\n",
"Epoch: 0, Validation Accuracy: 0.15\n",
"Epoch: 1\tBatch: 10\tAvg-Loss: 0.2855\n",
"Epoch: 1\tBatch: 20\tAvg-Loss: 0.2545\n",
"Epoch: 1\tBatch: 30\tAvg-Loss: 0.2309\n",
"Epoch: 1\tBatch: 40\tAvg-Loss: 0.2029\n",
"Epoch: 1, Validation Accuracy: 0.16\n",
"Epoch: 2\tBatch: 10\tAvg-Loss: 0.2189\n",
"Epoch: 2\tBatch: 20\tAvg-Loss: 0.1736\n",
"Epoch: 2\tBatch: 30\tAvg-Loss: 0.1480\n",
"Epoch: 2\tBatch: 40\tAvg-Loss: 0.1320\n",
"Epoch: 2, Validation Accuracy: 0.33\n",
"Epoch: 3\tBatch: 10\tAvg-Loss: 0.1201\n",
"Epoch: 3\tBatch: 20\tAvg-Loss: 0.0961\n",
"Epoch: 3\tBatch: 30\tAvg-Loss: 0.0913\n",
"Epoch: 3\tBatch: 40\tAvg-Loss: 0.0877\n",
"Epoch: 3, Validation Accuracy: 0.58\n",
"Epoch: 4\tBatch: 10\tAvg-Loss: 0.0983\n",
"Epoch: 4\tBatch: 20\tAvg-Loss: 0.0801\n",
"Epoch: 4\tBatch: 30\tAvg-Loss: 0.0660\n",
"Epoch: 4\tBatch: 40\tAvg-Loss: 0.0541\n",
"Epoch: 4, Validation Accuracy: 0.73\n",
"Epoch: 5\tBatch: 10\tAvg-Loss: 0.0688\n",
"Epoch: 5\tBatch: 20\tAvg-Loss: 0.0593\n",
"Epoch: 5\tBatch: 30\tAvg-Loss: 0.0444\n",
"Epoch: 5\tBatch: 40\tAvg-Loss: 0.0480\n",
"Epoch: 5, Validation Accuracy: 0.80\n",
"Epoch: 6\tBatch: 10\tAvg-Loss: 0.0545\n",
"Epoch: 6\tBatch: 20\tAvg-Loss: 0.0464\n",
"Epoch: 6\tBatch: 30\tAvg-Loss: 0.0400\n",
"Epoch: 6\tBatch: 40\tAvg-Loss: 0.0344\n",
"Epoch: 6, Validation Accuracy: 0.91\n",
"Epoch: 7\tBatch: 10\tAvg-Loss: 0.0403\n",
"Epoch: 7\tBatch: 20\tAvg-Loss: 0.0345\n",
"Epoch: 7\tBatch: 30\tAvg-Loss: 0.0325\n",
"Epoch: 7\tBatch: 40\tAvg-Loss: 0.0428\n",
"Epoch: 7, Validation Accuracy: 0.73\n",
"Epoch: 8\tBatch: 10\tAvg-Loss: 0.0538\n",
"Epoch: 8\tBatch: 20\tAvg-Loss: 0.0420\n",
"Epoch: 8\tBatch: 30\tAvg-Loss: 0.0353\n",
"Epoch: 8\tBatch: 40\tAvg-Loss: 0.0426\n",
"Epoch: 8, Validation Accuracy: 0.89\n",
"Epoch: 9\tBatch: 10\tAvg-Loss: 0.0439\n",
"Epoch: 9\tBatch: 20\tAvg-Loss: 0.0312\n",
"Epoch: 9\tBatch: 30\tAvg-Loss: 0.0278\n",
"Epoch: 9\tBatch: 40\tAvg-Loss: 0.0316\n",
"Epoch: 9, Validation Accuracy: 0.87\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uflhIjqcBdUA"
},
"source": [
"## Computing Cosine Similarity between Feature Embeddings\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xGgTtsra2Aty"
},
"source": [
"# Let's try cosine similarity\n",
"\n",
"compute_sim = nn.CosineSimilarity(dim=0)\n",
"\n",
"img_a = dev_dataset[0][0] # this is class 0\n",
"img_b = dev_dataset[1][0] # this is also class 0\n",
"img_c = dev_dataset[51][0] # this is class 1\n",
"img_d = dev_dataset[451][0] # this is class 9"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TfuSGH552Atz"
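},
"source": [
"network.eval()\n",
"with torch.no_grad():  # inference only, so skip gradient tracking\n",
"    # Use the `device` chosen earlier instead of hard-coding .cuda(), so this also runs on CPU\n",
"    feats_a = network(img_a.to(device).unsqueeze(0), return_embedding=True).squeeze(0)\n",
"    feats_b = network(img_b.to(device).unsqueeze(0), return_embedding=True).squeeze(0)\n",
"    feats_c = network(img_c.to(device).unsqueeze(0), return_embedding=True).squeeze(0)\n",
"    feats_d = network(img_d.to(device).unsqueeze(0), return_embedding=True).squeeze(0)"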
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dCuC5HiV2Atz",
"outputId": "01670e37-0e58-4f9f-9646-a728ac968f8e"
},
"source": [
"print(\"CS between two images of class 0: {:.4f}\".format(compute_sim(feats_a, feats_b)))\n",
"print(\"CS between an image of class 0 and image of class 1: {:.4f}\".format(compute_sim(feats_a, feats_c)))\n",
"print(\"CS between an image of class 0 and image of class 9: {:.4f}\".format(compute_sim(feats_a, feats_d)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CS between two images of class 0: 0.9911\n",
"CS between an image of class 0 and image of class 1: 0.8435\n",
"CS between an image of class 0 and image of class 9: 0.8955\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zdbtihjR4bW6"
},
"source": [
"## Center Loss\n",
"___\n",
"The Center Loss code below is adapted from the GitHub repo: https://github.com/KaiyangZhou/pytorch-center-loss\n",
"\n",
"Reference:\n",
"Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016."
]
},
{
"cell_type": "code",
"metadata": {
"id": "VUumis_-2Atz"
},
"source": [
"class CenterLoss(nn.Module):\n",
"    \"\"\"\n",
"    Mean squared Euclidean distance between each feature and the center of its class.\n",
"\n",
"    Args:\n",
"        num_classes (int): number of classes.\n",
"        feat_dim (int): feature dimension.\n",
"        device (torch.device): device on which the class centers are stored.\n",
"    \"\"\"\n",
"    def __init__(self, num_classes, feat_dim, device=torch.device('cpu')):\n",
"        super().__init__()\n",
"        self.num_classes = num_classes\n",
"        self.feat_dim = feat_dim\n",
"        self.device = device\n",
"\n",
"        # One learnable center per class\n",
"        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(self.device))\n",
"\n",
"    def forward(self, x, labels):\n",
"        \"\"\"\n",
"        Args:\n",
"            x: feature matrix with shape (batch_size, feat_dim).\n",
"            labels: ground truth labels with shape (batch_size).\n",
"        \"\"\"\n",
"        batch_size = x.size(0)\n",
"        # Squared distances via ||x||^2 + ||c||^2 - 2 * x . c, with shape (batch_size, num_classes)\n",
"        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \\\n",
"            torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()\n",
"        # Keyword form of addmm_; the positional (beta, alpha, mat1, mat2) overload is deprecated\n",
"        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)\n",
"\n",
"        # Boolean mask selecting, for each sample, the column of its own class\n",
"        classes = torch.arange(self.num_classes).long().to(self.device)\n",
"        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)\n",
"        mask = labels.eq(classes.expand(batch_size, self.num_classes))\n",
"\n",
"        dist = []\n",
"        for i in range(batch_size):\n",
"            value = distmat[i][mask[i]]\n",
"            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability\n",
"            dist.append(value)\n",
"        dist = torch.cat(dist)\n",
"        loss = dist.mean()\n",
"\n",
"        return loss"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JGYKIgEW6vng"
},
"source": [
"class Network(nn.Module):\n",
"    def __init__(self, num_feats, hidden_sizes, num_classes, feat_dim=10):\n",
"        super().__init__()\n",
"\n",
"        self.hidden_sizes = [num_feats] + hidden_sizes + [num_classes]\n",
"\n",
"        # Each stage: strided conv -> ReLU -> residual block\n",
"        layers = []\n",
"        for idx, channel_size in enumerate(hidden_sizes):\n",
"            layers.append(nn.Conv2d(in_channels=self.hidden_sizes[idx],\n",
"                                    out_channels=self.hidden_sizes[idx+1],\n",
"                                    kernel_size=3, stride=2, bias=False))\n",
"            layers.append(nn.ReLU(inplace=True))\n",
"            layers.append(SimpleResidualBlock(channel_size=channel_size))\n",
"\n",
"        self.layers = nn.Sequential(*layers)\n",
"        self.linear_label = nn.Linear(self.hidden_sizes[-2], self.hidden_sizes[-1], bias=False)\n",
"\n",
"        # For creating the embedding to be passed into the Center Loss criterion\n",
"        self.linear_closs = nn.Linear(self.hidden_sizes[-2], feat_dim, bias=False)\n",
"        self.relu_closs = nn.ReLU(inplace=True)\n",
"\n",
"    def forward(self, x, evalMode=False):\n",
"        output = self.layers(x)\n",
"\n",
"        # Global average pooling over height and width, then flatten to (batch_size, channels)\n",
"        output = F.avg_pool2d(output, [output.size(2), output.size(3)], stride=1)\n",
"        output = output.reshape(output.shape[0], output.shape[1])\n",
"\n",
"        # Class logits, each divided by the norm of that class's weight vector\n",
"        label_output = self.linear_label(output)\n",
"        label_output = label_output/torch.norm(self.linear_label.weight, dim=1)\n",
"\n",
"        # Create the feature embedding for the Center Loss\n",
"        closs_output = self.linear_closs(output)\n",
"        closs_output = self.relu_closs(closs_output)\n",
"\n",
"        return closs_output, label_output\n",
"\n",
"def init_weights(m):\n",
"    if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
"        torch.nn.init.xavier_normal_(m.weight.data)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wHOXCqVN3fzt"
},
"source": [
"def train_closs(model, data_loader, test_loader, task='Classification'):\n",
"    model.train()\n",
"\n",
"    for epoch in range(numEpochs):\n",
"        avg_loss = 0.0\n",
"        for batch_num, (feats, labels) in enumerate(data_loader):\n",
"            feats, labels = feats.to(device), labels.to(device)\n",
"\n",
"            optimizer_label.zero_grad()\n",
"            optimizer_closs.zero_grad()\n",
"\n",
"            feature, outputs = model(feats)\n",
"\n",
"            # Total loss = classification loss + weighted center loss\n",
"            l_loss = criterion_label(outputs, labels.long())\n",
"            c_loss = criterion_closs(feature, labels.long())\n",
"            loss = l_loss + closs_weight * c_loss\n",
"\n",
"            loss.backward()\n",
"\n",
"            optimizer_label.step()\n",
"            # Undo the closs_weight scaling on the center gradients, so that\n",
"            # closs_weight does not affect the learning of the centers\n",
"            for param in criterion_closs.parameters():\n",
"                param.grad.data *= (1. / closs_weight)\n",
"            optimizer_closs.step()\n",
"\n",
"            avg_loss += loss.item()\n",
"\n",
"            if batch_num % 50 == 49:\n",
"                print('Epoch: {}\\tBatch: {}\\tAvg-Loss: {:.4f}'.format(epoch+1, batch_num+1, avg_loss/50))\n",
"                avg_loss = 0.0\n",
"\n",
"            torch.cuda.empty_cache()\n",
"            del feats\n",
"            del labels\n",
"            del loss\n",
"\n",
"        if task == 'Classification':\n",
"            val_loss, val_acc = test_classify_closs(model, test_loader)\n",
"            train_loss, train_acc = test_classify_closs(model, data_loader)\n",
"            print('Train Loss: {:.4f}\\tTrain Accuracy: {:.4f}\\tVal Loss: {:.4f}\\tVal Accuracy: {:.4f}'.\n",
"                  format(train_loss, train_acc, val_loss, val_acc))\n",
"        else:\n",
"            # NOTE: test_verify is not defined in this notebook; supply your own verification routine\n",
"            test_verify(model, test_loader)\n",
"\n",
"\n",
"def test_classify_closs(model, test_loader):\n",
"    model.eval()\n",
"    test_loss = []\n",
"    accuracy = 0\n",
"    total = 0\n",
"\n",
"    with torch.no_grad():  # no gradients are needed for evaluation\n",
"        for batch_num, (feats, labels) in enumerate(test_loader):\n",
"            feats, labels = feats.to(device), labels.to(device)\n",
"            feature, outputs = model(feats)\n",
"\n",
"            _, pred_labels = torch.max(F.softmax(outputs, dim=1), 1)\n",
"            pred_labels = pred_labels.view(-1)\n",
"\n",
"            l_loss = criterion_label(outputs, labels.long())\n",
"            c_loss = criterion_closs(feature, labels.long())\n",
"            loss = l_loss + closs_weight * c_loss\n",
"\n",
"            accuracy += torch.sum(torch.eq(pred_labels, labels)).item()\n",
"            total += len(labels)\n",
"            # Repeat the batch loss once per sample so the mean is weighted by batch size\n",
"            test_loss.extend([loss.item()]*feats.size()[0])\n",
"            del feats\n",
"            del labels\n",
"\n",
"    model.train()\n",
"    return np.mean(test_loss), accuracy/total"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BvqVxNry4iCw"
},
"source": [
"numEpochs = 10\n",
"num_feats = 3\n",
"closs_weight = 1\n",
"lr_cent = 0.5\n",
"feat_dim = 10\n",
"\n",
"weightDecay = 5e-5\n",
"\n",
"hidden_sizes = [32, 64]\n",
"num_classes = len(train_dataset.classes)\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"network = Network(num_feats, hidden_sizes, num_classes, feat_dim)\n",
"network.apply(init_weights)\n",
"\n",
"criterion_label = nn.CrossEntropyLoss()\n",
"criterion_closs = CenterLoss(num_classes, feat_dim, device)\n",
"optimizer_label = torch.optim.SGD(network.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=0.9)\n",
"optimizer_closs = torch.optim.SGD(criterion_closs.parameters(), lr=lr_cent)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_p4mDFQq4jkM",
"outputId": "5619d6e7-cafe-4fa3-c05c-088de25c55af"
},
"source": [
"network.train()\n",
"network.to(device)\n",
"train_closs(network, train_dataloader, dev_dataloader)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" cpuset_checked))\n",
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:23: UserWarning: This overload of addmm_ is deprecated:\n",
"\taddmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)\n",
"Consider using one of the following signatures instead:\n",
"\taddmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:1025.)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Loss: 1.9780\tTrain Accuracy: 0.2150\tVal Loss: 1.9809\tVal Accuracy: 0.2080\n",
"Train Loss: 1.2386\tTrain Accuracy: 0.6250\tVal Loss: 1.2105\tVal Accuracy: 0.6480\n",
"Train Loss: 0.9858\tTrain Accuracy: 0.6466\tVal Loss: 0.9397\tVal Accuracy: 0.6620\n",
"Train Loss: 0.9313\tTrain Accuracy: 0.6600\tVal Loss: 0.9176\tVal Accuracy: 0.6620\n",
"Train Loss: 0.5583\tTrain Accuracy: 0.8600\tVal Loss: 0.5493\tVal Accuracy: 0.8600\n",
"Train Loss: 4.5172\tTrain Accuracy: 0.1360\tVal Loss: 4.4614\tVal Accuracy: 0.1320\n",
"Train Loss: 0.4353\tTrain Accuracy: 0.8812\tVal Loss: 0.4230\tVal Accuracy: 0.8980\n",
"Train Loss: 0.4435\tTrain Accuracy: 0.8686\tVal Loss: 0.3987\tVal Accuracy: 0.9060\n",
"Train Loss: 0.3900\tTrain Accuracy: 0.8876\tVal Loss: 0.3889\tVal Accuracy: 0.8920\n",
"Train Loss: 0.5629\tTrain Accuracy: 0.8208\tVal Loss: 0.5867\tVal Accuracy: 0.8140\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6NX8X1wj8RHo"
},
"source": [
"## Triplet Loss\n",
"___\n",
"You can build a dataloader that returns a tuple of three images: two from the same class (the anchor and the positive) and one from a different class (the negative). Triplet loss then pulls the same-class pair closer together while pushing the different-class pair apart by at least a margin.\n",
"\n",
"For more, see: https://github.com/adambielski/siamese-triplet/blob/master/losses.py"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Nvl_ik678PHO",
"outputId": "da48651e-dbf0-4cff-c130-68b02bba0e2f"
},
"source": [
"triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)\n",
"face_img1, label_img1 = trainset[0]\n",
"face_img2, label_img2 = trainset[1]\n",
"face_img3, label_img3 = trainset[-1]\n",
"\n",
"print(label_img1, label_img2, label_img3)\n",
"# face_img1 (anchor) and face_img2 (positive) are from the same class; face_img3 (negative) is from a different class.\n",
"# Raw image tensors are used here for illustration; in practice you would pass network embeddings.\n",
"loss = triplet_loss(face_img1, face_img2, face_img3)\n",
"print(\"Loss={:0.2f}\".format(loss))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 0 1\n",
"Loss=0.85\n"
]
}
]
},
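{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of the triplet dataloader described above (not part of the original recitation code), the cell below wraps any indexable dataset of `(image, label)` pairs and returns `(anchor, positive, negative)` tuples. The class name `TripletDataset` and the toy data are assumptions made for illustration. The sketch assumes every class has at least two samples and that there are at least two classes; positives and negatives are drawn uniformly at random, with no hard-example mining.\n",
"\n",
"With the objects defined earlier in this notebook, `TripletDataset(trainset, label_list)` could be used in place of the toy data, and each image would normally be passed through the network so that the loss is computed on embeddings."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import random\n",
"from collections import defaultdict\n",
"\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"\n",
"class TripletDataset(Dataset):\n",
"    # Hypothetical helper: an illustrative sketch, not the recitation's own code.\n",
"    def __init__(self, base_dataset, labels):\n",
"        self.base = base_dataset\n",
"        self.labels = list(labels)\n",
"        # Group sample indices by class so positives and negatives can be drawn quickly\n",
"        self.by_class = defaultdict(list)\n",
"        for idx, label in enumerate(self.labels):\n",
"            self.by_class[label].append(idx)\n",
"        self.classes = list(self.by_class)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.labels)\n",
"\n",
"    def __getitem__(self, index):\n",
"        anchor_label = self.labels[index]\n",
"        # Positive: a different sample of the same class\n",
"        pos_index = random.choice([i for i in self.by_class[anchor_label] if i != index])\n",
"        # Negative: any sample of a different class\n",
"        neg_label = random.choice([c for c in self.classes if c != anchor_label])\n",
"        neg_index = random.choice(self.by_class[neg_label])\n",
"        return self.base[index][0], self.base[pos_index][0], self.base[neg_index][0]\n",
"\n",
"\n",
"# Toy stand-in data (random 8-dim vectors) so that this cell runs on its own\n",
"toy_labels = [0, 0, 0, 1, 1, 1]\n",
"toy_data = [(torch.randn(8), y) for y in toy_labels]\n",
"triplet_loader = DataLoader(TripletDataset(toy_data, toy_labels), batch_size=3, shuffle=True)\n",
"\n",
"toy_criterion = torch.nn.TripletMarginLoss(margin=1.0, p=2)\n",
"anchor, positive, negative = next(iter(triplet_loader))\n",
"print(anchor.shape, toy_criterion(anchor, positive, negative).item())"
],
"execution_count": null,
"outputs": []
},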
{
"cell_type": "code",
"metadata": {
"id": "j54x6nGy8Ui5"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}