In [2]:
import os, cv2, glob
import numpy as np
import pandas as pd
import random, tqdm
# import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp
import torchmetrics
from tqdm import tqdm
import cv2
import numpy as np
import torch
import albumentations as album
import segmentation_models_pytorch as smp
from torchinfo import summary


In [3]:
def get_validation_augmentation():   
    # Resize to (height=720, width=1280)
    resize = album.Resize(height=480, width=640, always_apply=True),

    # Add sufficient padding to ensure image is divisible by 32
    # test_transform = [
    #     album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0),
    # ]
    return album.Compose(resize)

def get_preprocessing(preprocessing_fn=None): 
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
        
    return album.Compose(_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')



# process_fn = get_preprocessing(preprocessing_fn)

In [3]:
def visualize_segmentation(images, masks, outputs):
    """
    Visualizes input image, ground truth mask, and predicted mask.
    """
    # Convert tensors to CPU and detach
    images = images.cpu().detach()
    masks = masks.cpu().detach()
    outputs = torch.sigmoid(outputs).cpu().detach()  # Apply sigmoid to get the output mask
    
    # Apply argmax to reduce the 2-channel output to 1 channel (binary mask)
    outputs = torch.argmax(outputs, dim=1)  # This will convert the (batch_size, 2, H, W) into (H, W)
    masks = torch.argmax(masks, dim=1)      # Convert the (batch_size, 2, H, W) ground truth mask similarly
    
    # Loop through the batch and visualize
    for i in range(min(images.shape[0], 1)):  # Visualize one sample per batch
        fig, axs = plt.subplots(1, 3, figsize=(15, 10))
        
        # Plot the input image
        axs[0].imshow(images[i].permute(1, 2, 0))  # Convert to HWC format
        axs[0].set_title("Input Image")
        axs[0].axis('off')
        
        # Plot the ground truth mask
        axs[1].imshow(masks[i], cmap='gray')  # Ground truth mask (now single-channel)
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis('off')
        
        # Plot the predicted mask
        axs[2].imshow(outputs[i], cmap='gray')  # Predicted mask (now single-channel)
        axs[2].set_title("Predicted Mask")
        axs[2].axis('off')
        
        plt.show()

In [4]:

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing function used during training
def get_preprocessing(preprocessing_fn=None):
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
        
    return album.Compose(_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

# Model and preprocessing initialization
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'  # Can also use softmax2d for multi-class



In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class Conv2dReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(Conv2dReLU, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.block(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv1 = Conv2dReLU(in_channels, out_channels)
        self.conv2 = Conv2dReLU(out_channels, out_channels)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    
    def forward(self, x, skip):
        # Upsample
        x = self.up(x)
        
        # Resize skip connection to match the size of x
        if x.size() != skip.size():
            skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=True)
        
        # Concatenate the skip connection (encoder output)
        x = torch.cat([x, skip], dim=1)
        
        # Apply convolutions
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UNetResNet50(nn.Module):
    def __init__(self, num_classes=2, pretrained=True):
        super(UNetResNet50, self).__init__()

        # Load ResNet50 pre-trained on ImageNet
#         self.encoder = models.resnet50(pretrained=pretrained)
        self.encoder = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        

        # Encoder layers from ResNet-50 (for skip connections)
        self.encoder_layers = [
            nn.Sequential(self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.encoder.maxpool),  # (64, H/4, W/4)
            self.encoder.layer1,  # (256, H/4, W/4)
            self.encoder.layer2,  # (512, H/8, W/8)
            self.encoder.layer3,  # (1024, H/16, W/16)
            self.encoder.layer4   # (2048, H/32, W/32)
        ]

        # Decoder (Upsampling blocks)
        self.decoder4 = DecoderBlock(2048 + 1024, 512)  # Block for layer4 + layer3
        self.decoder3 = DecoderBlock(512 + 512, 256)    # Block for layer3 + layer2
        self.decoder2 = DecoderBlock(256 + 256, 128)    # Block for layer2 + layer1
        self.decoder1 = DecoderBlock(128 + 64, 64)      # Block for layer1 + conv1

        # Final segmentation head
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(64, num_classes, kernel_size=3, padding=1),
            # nn.Sigmoid()  # Use sigmoid for binary segmentation
        )

    def forward(self, x):
        # Save original input size for final upsampling
        original_size = x.shape[2:]  # (H, W)

        # Encoder forward pass
        x0 = self.encoder_layers[0](x)  # Initial convolution block (conv1)
        x1 = self.encoder_layers[1](x0)  # Skip connection 1 (layer1)
        x2 = self.encoder_layers[2](x1)  # Skip connection 2 (layer2)
        x3 = self.encoder_layers[3](x2)  # Skip connection 3 (layer3)
        x4 = self.encoder_layers[4](x3)  # Skip connection 4 (layer4)

        # Decoder forward pass
        x = self.decoder4(x4, x3)  # Decoder for layer4 + skip3
        x = self.decoder3(x, x2)   # Decoder for layer3 + skip2
        x = self.decoder2(x, x1)   # Decoder for layer2 + skip1
        x = self.decoder1(x, x0)    # Decoder for layer1 + initial conv1 output

        # Upsample the final output to match the input size dynamically
        x = F.interpolate(x, size=original_size, mode='bilinear', align_corners=True)

        # Final segmentation output
        x = self.segmentation_head(x)

        return x

model = UNetResNet50(num_classes=2, pretrained=True)
model.to(DEVICE)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Load the model checkpoint (adjust the path)
model.load_state_dict(torch.load('models/base_line_e50_v2/model_epoch_24.pth'))
model.eval()

# Preprocessing and augmentations for validation
def get_validation_augmentation():   
    resize = album.Resize(height=480, width=640, always_apply=True)
    return album.Compose([resize])

# Function to perform inference on a single image
def infer_single_image(image):
    # Read the image
    # image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Apply augmentations
    augmented = get_validation_augmentation()(image=image)
    image = augmented['image']
    
    # Apply preprocessing
    preprocessed = get_preprocessing(preprocessing_fn)(image=image)
    image_tensor = torch.tensor(preprocessed['image']).unsqueeze(0).to(DEVICE)  # Add batch dimension
    
    # Perform inference
    with torch.no_grad():
        output = model(image_tensor)
        output = torch.sigmoid(output).cpu().numpy()  # For binary segmentation
        output = np.argmax(output, axis=1)  # Convert to 1-channel output if needed
    
    return output, image



Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /home/somusan/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth" to /home/somusan/.cache/torch/hub/checkpoints/dino_resnet50_pretrain.pth
100%|██████████| 90.0M/90.0M [00:32<00:00, 2.88MB/s]


In [6]:
summary(model, input_size=(4, 3, 480, 640))


Layer (type:depth-idx)                        Output Shape              Param #
UNetResNet50                                  [4, 2, 480, 640]          --
├─ResNet: 1-1                                 --                        --
│    └─Conv2d: 2-1                            [4, 64, 240, 320]         9,408
│    └─BatchNorm2d: 2-2                       [4, 64, 240, 320]         128
│    └─ReLU: 2-3                              [4, 64, 240, 320]         --
│    └─MaxPool2d: 2-4                         [4, 64, 120, 160]         --
│    └─Sequential: 2-5                        [4, 256, 120, 160]        --
│    │    └─Bottleneck: 3-1                   [4, 256, 120, 160]        75,008
│    │    └─Bottleneck: 3-2                   [4, 256, 120, 160]        70,400
│    │    └─Bottleneck: 3-3                   [4, 256, 120, 160]        70,400
│    └─Sequential: 2-6                        [4, 512, 60, 80]          --
│    │    └─Bottleneck: 3-4                   [4, 512, 60, 80]          379,392

In [7]:
import numpy as np
import cv2

def draw_segmentation_map(labels, palette):
    """
    :param labels: Label array from the model. Should be of shape 
        <height x width>. No channel information required.
    :param palette: List containing color information.
        e.g. [[0, 255, 0], [255, 255, 0]] 
    """
    # create Numpy arrays containing zeros
    # to be used to fill them with respective red, green, and blue pixels
    red_map = np.zeros_like(labels).astype(np.uint8)
    green_map = np.zeros_like(labels).astype(np.uint8)
    blue_map = np.zeros_like(labels).astype(np.uint8)

    for label_num in range(0, len(palette)):
        index = labels == label_num
        red_map[index] = palette[label_num][0]
        green_map[index] = palette[label_num][1]
        blue_map[index] = palette[label_num][2]
        
    segmentation_map = np.stack([red_map, green_map, blue_map], axis=2)
    return segmentation_map

def image_overlay(image, segmented_image):
    """
    :param image: Image in RGB format.
    :param segmented_image: Segmentation map in RGB format. 
    """
    alpha = 0.8  # transparency for the original image
    beta = 1.0   # transparency for the segmentation map
    gamma = 0    # scalar added to each sum
    segmented_image = np.uint8(segmented_image)
    image = np.array(image)
    # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    
    # Overlay the segmentation map on the original image
    overlay = cv2.addWeighted(image, alpha, segmented_image, beta, gamma)
    
    return overlay

# Updated LABEL_COLORS_LIST for 2 classes: background and road
LABEL_COLORS_LIST = [
    [0, 0, 0],      # background (black)
    [255, 0, 0]     # road (red)
]

# Assuming 'mask' is of size (1, 480, 640) as the output of your model
# Step 1: Remove the channel dimension (1) from the mask
def process_inference_mask(mask):
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = np.squeeze(mask, axis=0)  # Shape becomes (480, 640)
    return mask



In [17]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

# Ensure the VideoCapture is opened correctly
cap = cv2.VideoCapture("/home/opencvuniv/Work/somusan/blogpost/road_seg/infer_vod/blr_indian_road_dashcam.mp4")

# Prepare to save the output video using the 'H264' codec
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use 'H264' codec if available
out = cv2.VideoWriter('/home/opencvuniv/Work/somusan/blogpost/road_seg/infer_vod/indian_road_v11.mp4', fourcc, 20.0, (int(cap.get(3)), int(cap.get(4))))

frame_cnt = 0

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("End of video stream.")
            break
        
        # Ensure frame has the correct size
        original_height, original_width = frame.shape[:2]

        # Run inference on the frame
        result, img = infer_single_image(frame)

        # Ensure the processed mask matches the frame size
        processed_mask = process_inference_mask(result)
        segmentation_map = draw_segmentation_map(processed_mask, LABEL_COLORS_LIST)
        
        # Ensure the output image size matches the input frame size
        overlay_image = image_overlay(img, segmentation_map)

        # Resize to match the original frame size, if needed
        if overlay_image.shape[:2] != (original_height, original_width):
            overlay_image = cv2.resize(overlay_image, (original_width, original_height))
        
        overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_BGR2RGB)
        # Write the frame to the output video
        out.write(overlay_image)
       
        frame_cnt += 1
        print(f"Processed frame {frame_cnt}")

finally:
    # Release resources
    cap.release()
    out.release()
    print("Video processing completed.")


Processed frame 1
Processed frame 2
Processed frame 3
Processed frame 4
Processed frame 5
Processed frame 6
Processed frame 7
Processed frame 8
Processed frame 9
Processed frame 10
Processed frame 11
Processed frame 12
Processed frame 13
Processed frame 14
Processed frame 15
Processed frame 16
Processed frame 17
Processed frame 18
Processed frame 19
Processed frame 20
Processed frame 21
Processed frame 22
Processed frame 23
Processed frame 24
Processed frame 25
Processed frame 26
Processed frame 27
Processed frame 28
Processed frame 29
Processed frame 30
Processed frame 31
Processed frame 32
Processed frame 33
Processed frame 34
Processed frame 35
Processed frame 36
Processed frame 37
Processed frame 38
Processed frame 39
Processed frame 40
Processed frame 41
Processed frame 42
Processed frame 43
Processed frame 44
Processed frame 45
Processed frame 46
Processed frame 47
Processed frame 48
Processed frame 49
Processed frame 50
Processed frame 51
Processed frame 52
Processed frame 53
Pr