# SAHI

<div style="text-align: center">
  <img src="https://learnopencv.com/wp-content/uploads/2023/06/sliced_inference.gif" alt="Image"  width=500 height=400>
</div>






Traditional object detection models often struggle with small objects, which occupy only a few pixels and carry little contextual information. SAHI addresses this by slicing images into smaller, overlapping patches in which small objects appear relatively larger and are easier to detect. The same slicing can be used to augment the dataset during fine-tuning and, at inference time, to run the detector on each patch and merge the predictions back into the full image.

To learn more about Slicing Aided Hyper Inference (SAHI), see the article linked below.

Link: https://learnopencv.com/slicing-aided-hyper-inference/


## Import Dependencies

In [1]:
!pip install -qq -U sahi


In [7]:
# import required functions, classes
import os
import cv2
import PIL
import random
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps, ImageStat
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import requests
import zipfile


from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, predict, get_prediction
from sahi.utils.file import download_from_url
from sahi.prediction import visualize_object_predictions
from sahi.utils.cv import read_image
from IPython.display import Image


import torch
import torchvision.transforms as T
from torchvision.transforms import v2 as Tv2
from torchvision import tv_tensors
from torchvision.transforms import functional as F
from torchvision.transforms.functional import to_pil_image
import torchvision.models.detection as detection
from torchvision import transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

Let's set the random seeds for reproducibility.

In [8]:
def set_seeds():
    # fix random seeds
    SEED_VALUE = 42

    random.seed(SEED_VALUE)
    np.random.seed(SEED_VALUE)
    torch.manual_seed(SEED_VALUE)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED_VALUE)
        torch.cuda.manual_seed_all(SEED_VALUE)
        torch.backends.cudnn.deterministic = True
        # benchmark mode picks algorithms non-deterministically, so disable it
        torch.backends.cudnn.benchmark = False
set_seeds()

In [9]:
!rm -rf /content/SeaDroneSee_test

## Download test subset

In [5]:
# Ensure the base directory exists
if not os.path.exists('SeaDroneSee_test'):
    os.mkdir('SeaDroneSee_test')

# Ensure the Model_ckpt directory exists inside the base directory
if not os.path.exists('SeaDroneSee_test/Model_ckpt'):
    os.mkdir('SeaDroneSee_test/Model_ckpt')

def download_file(url, save_name):
    if not os.path.exists(save_name):
        # Handling potential redirection in requests
        with requests.get(url, allow_redirects=True) as r:
            if r.status_code == 200:
                with open(save_name, 'wb') as f:
                    f.write(r.content)
            else:
                print("Failed to download the file, status code:", r.status_code)

def unzip(zip_file=None, target_dir=None):
    try:
        with zipfile.ZipFile(zip_file, 'r') as z:
            z.extractall(target_dir)
            print("Extracted all to:", target_dir)
    except zipfile.BadZipFile:
        print("Invalid file or error during extraction: Bad Zip File")
    except Exception as e:
        print("An error occurred:", e)

# Correct Dropbox link for test images (Ensure this is the direct download link or properly redirects)
download_url = 'https://www.dropbox.com/scl/fi/4qidpahgu9mogam33uxlz/SeaDroneSee_test.zip?rlkey=1gt6mebuppxg4ehzhicwqafav&st=rtuwbmuo&dl=1'
save_path = 'SeaDroneSee_test/SeaDroneSee_test.zip'
download_file(download_url, save_path)

# Correct Dropbox link for model checkpoint
model_ckpt_url = 'https://www.dropbox.com/scl/fi/xmftrum0a8rgjp82j6n65/model_ckpt.zip?rlkey=aywwl28rbcbiejggdps0durfu&st=dda61bld&dl=1'
model_save_path = 'SeaDroneSee_test/Model_ckpt.zip'
download_file(model_ckpt_url, model_save_path)

In [None]:
# Unzip test images to SeaDroneSee_test
unzip(zip_file=save_path, target_dir='SeaDroneSee_test')
# Unzip model checkpoint to Model_ckpt folder inside SeaDroneSee_test
unzip(zip_file=model_save_path, target_dir='SeaDroneSee_test/Model_ckpt')

In [None]:
pwd

In [None]:
cd SeaDroneSee_test

#### Class Mapping

Let's define the class mapping and assign a unique color to each class ID.

In [6]:
classes_to_idx = {
    0: 'ignored',
    1: 'swimmer',
    2: 'boat',
    3: 'jetski',
    4: 'life_saving_appliances',
    5: "buoy"
}

# Mapping category IDs to colors
category_colors = {
    0: 'black',   # ignored
    1: 'red',     # swimmer
    2: 'orange',   # boat
    3: 'blue',    # jetski
    4: 'purple',  # life saving appliances
    5: 'yellow'   # buoy
}

## Load Fine-tuned Faster RCNN Model Checkpoint

Let’s load the best model checkpoint and set `box_nms_thresh` to 0.3, so that non-maximum suppression removes overlapping bounding boxes predicted for instances of the same class. Because we load our own fine-tuned Faster R-CNN weights, we pass `pretrained=False`. `load_state_dict` verifies that the checkpoint's keys match every layer of the model, and the model is then set to eval mode for inference.



In [10]:
checkpoint_path = "Model_ckpt/model_ckpt/Mobilenet_5e-4_best_model_checkpoint_epoch_27.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Function to load the trained model
def load_model(checkpoint_path, device):
    model = detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, num_classes=len(classes_to_idx),box_nms_thresh=0.3)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model
model = load_model(checkpoint_path, device)
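
To make the effect of `box_nms_thresh=0.3` concrete, here is a minimal, pure-Python sketch of greedy non-maximum suppression for a single class. It is an illustration only, not torchvision's implementation (which runs on tensors and applies NMS per class inside the detector); the helper names `iou` and `greedy_nms` and the toy boxes are assumptions made for this example.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thresh=0.3):
    """Keep the highest-scoring boxes; drop any box whose IoU with a kept box exceeds iou_thresh."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= iou_thresh for k in keep):
            keep.append(i)
    return keep


# Toy boxes (hypothetical): the first two overlap heavily, the third is far away
toy_boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
toy_scores = [0.9, 0.8, 0.7]
kept = greedy_nms(toy_boxes, toy_scores, iou_thresh=0.3)
print(kept)  # [0, 2]
```

A lower threshold suppresses more aggressively: with 0.3, two boxes of the same class only both survive if they overlap by at most 30% IoU.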

In [10]:
detection_model = AutoDetectionModel.from_pretrained(
   model_type='torchvision',
   model=model, #Faster RCNN Model
   confidence_threshold=0.7,
   image_size=5436, #Image's longest dimension
   device="cpu", # or "cuda:0"
   load_at_init=True,
)

The slice height and slice width control the dimensions of the sliding window. Since our model was trained on patches half the size of the image in each dimension, we set the slice height and slice width to half the image height and width.

In [16]:
img_path = 'test/7882.jpg'
# File name without directory or extension, e.g. '7882'
img_filename = os.path.splitext(os.path.basename(img_path))[0]
print(img_filename)

img_pil = PIL.Image.open(img_path)
W, H = img_pil.size
# Slice size: half the image size in each dimension
s_h, s_w = H // 2, W // 2

7882



`get_sliced_prediction` returns a result object whose `object_prediction_list` holds the detected object instances, each with its bbox, score and category ID.

In [17]:
result = get_sliced_prediction(
   img_path,
   detection_model,
   slice_height=s_h,
   slice_width=s_w,
   overlap_height_ratio=0.2,
   overlap_width_ratio=0.2,
)

Performing prediction on 9 slices.
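
Why 9 slices when each slice is half the image in each dimension? With a 20% overlap the window advances by only 80% of a slice, so two windows no longer cover a row and a third, clamped to the image border, is needed. The sketch below illustrates this with a simple sliding-window generator. It is an approximation of the idea written for this example, not SAHI's own code; the function name `sliding_window_boxes` and the 4000x3000 image size are assumptions.

```python
def sliding_window_boxes(image_w, image_h, slice_w, slice_h,
                         overlap_w=0.2, overlap_h=0.2):
    """Return [x_min, y_min, x_max, y_max] windows that cover the whole image."""
    step_x = slice_w - int(overlap_w * slice_w)
    step_y = slice_h - int(overlap_h * slice_h)
    boxes = []
    y = 0
    while True:
        # Clamp the window to the bottom edge, keeping its full height
        y_max = min(y + slice_h, image_h)
        y_min = max(0, y_max - slice_h)
        x = 0
        while True:
            # Clamp the window to the right edge, keeping its full width
            x_max = min(x + slice_w, image_w)
            x_min = max(0, x_max - slice_w)
            boxes.append([x_min, y_min, x_max, y_max])
            if x + slice_w >= image_w:
                break
            x += step_x
        if y + slice_h >= image_h:
            break
        y += step_y
    return boxes


# Hypothetical 4000x3000 image, sliced at half its size with 20% overlap
boxes = sliding_window_boxes(4000, 3000, slice_w=2000, slice_h=1500)
print(len(boxes))  # 9
```

With zero overlap the same call would produce a 2x2 grid of 4 slices; the overlap is what reduces the risk of an object being cut in half at a slice boundary.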


In the output below, the category IDs are correct, but the category names (for example `person` for ID 1) come from the COCO label set rather than from our classes. We will work around this with custom helper functions that map each category ID to our own classes and draw the bounding boxes accordingly.


In [18]:
result.object_prediction_list

[ObjectPrediction<
     bbox: BoundingBox: <(1825.7972106933594, 876.7810821533203, 1899.0037841796875, 908.668701171875), w: 73.20657348632812, h: 31.887619018554688>,
     mask: None,
     score: PredictionScore: <value: 0.9916843175888062>,
     category: Category: <id: 1, name: person>>,
 ObjectPrediction<
     bbox: BoundingBox: <(1271.9078369140625, 1302.0615844726562, 1319.986328125, 1335.0255432128906), w: 48.0784912109375, h: 32.963958740234375>,
     mask: None,
     score: PredictionScore: <value: 0.9903111457824707>,
     category: Category: <id: 1, name: person>>,
 ObjectPrediction<
     bbox: BoundingBox: <(1842.816162109375, 1008.3194732666016, 1934.3472290039062, 1044.4426727294922), w: 91.53106689453125, h: 36.123199462890625>,
     mask: None,
     score: PredictionScore: <value: 0.9834246039390564>,
     category: Category: <id: 1, name: person>>,
 ObjectPrediction<
     bbox: BoundingBox: <(2409.4612426757812, 1173.730094909668, 2436.6535034179688, 1199.152221679687
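
One way to sidestep the COCO names is to ignore `category.name` and look up the label from the category ID in our own mapping. The sketch below shows the idea with a stand-in object built from `SimpleNamespace`, which mimics only the `category.id` and `category.name` attributes seen in the output above; the helper name `custom_label` is hypothetical.

```python
from types import SimpleNamespace

# Same mapping as the notebook's classes_to_idx (category ID -> class name)
classes_to_idx = {
    0: 'ignored',
    1: 'swimmer',
    2: 'boat',
    3: 'jetski',
    4: 'life_saving_appliances',
    5: 'buoy',
}


def custom_label(prediction, id_to_name=classes_to_idx):
    """Return our class name for a prediction, based on its category ID only."""
    return id_to_name.get(prediction.category.id, 'unknown')


# Stand-in for a SAHI ObjectPrediction (assumption: only these attributes are needed)
fake_prediction = SimpleNamespace(category=SimpleNamespace(id=1, name='person'))
print(custom_label(fake_prediction))  # swimmer
```

The same lookup works for colors, which is what the drawing utility later in the notebook does with `category_colors`.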

In [19]:
img = cv2.imread(img_path,cv2.IMREAD_UNCHANGED)
img_converted = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
numpydata = np.asarray(img_converted)

This custom `draw_bounding_boxes()` utility takes the image and the `object_prediction_list` from SAHI’s `get_sliced_prediction` and draws each bounding box in the color assigned to its category ID.

In [20]:
def draw_bounding_boxes(image, object_prediction_list):
    draw = ImageDraw.Draw(image)
    font_size = int(min(image.size) * 0.008)  # Adjust font size based on image size
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    font = ImageFont.truetype(font_path, font_size) if os.path.exists(font_path) else ImageFont.load_default()

    for prediction in object_prediction_list:
        bbox = prediction.bbox.to_xywh()
        category_id = prediction.category.id
        x, y, w, h = bbox
        x1, y1, x2, y2 = x, y, x + w, y + h
        color = category_colors.get(category_id, 'white')  # Default to white if category_id is unknown
        draw.rectangle([x1, y1, x2, y2], outline=color, width=6)
        # draw.text((x1, y1 - font_size), str(classes_to_idx[category_id]), fill=color, font=font)

    return image

In [21]:
# Draw bounding boxes
image_with_bboxes = draw_bounding_boxes(img_pil, result.object_prediction_list)

# Define the output path
output_directory = 'sahi_output_data'
output_path = os.path.join(output_directory, f'result_{img_filename}.png')

# Create the directory if it doesn't exist
os.makedirs(output_directory, exist_ok=True)

# Save the resulting image
image_with_bboxes.save(output_path)

# Display the image (optional, if running in an environment that supports it)
image_with_bboxes.show()

## **Without SAHI: Forward Pass on the Original Image**

Now let’s skip SAHI and patch creation altogether: we resize the original image to the training image size of (382, 216) and pass it directly to our fine-tuned Faster R-CNN model.
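
A quick back-of-the-envelope calculation shows why this is hard on small objects. The sketch below rescales the width of the first swimmer detected earlier (about 73 px) to the resized image. It assumes the original image is 5436 px wide, taken from the "longest dimension" comment in the detection model cell; the helper name `scaled_length` is hypothetical.

```python
def scaled_length(length_px, src_size, dst_size):
    """Length of an object after resizing an image from src_size to dst_size pixels."""
    return length_px * dst_size / src_size


SRC_WIDTH = 5436      # assumption: original image width (longest dimension)
DST_WIDTH = 382       # width the image is resized to before the forward pass
swimmer_width = 73.2  # approximate width of the first detection in the SAHI output

resized_width = scaled_length(swimmer_width, SRC_WIDTH, DST_WIDTH)
print(round(resized_width, 1))  # 5.1
```

Under these assumptions the swimmer shrinks to roughly 5 px wide, which leaves the detector very little signal to work with.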


In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_transform():
    transforms = []
    transforms.append(Tv2.ToDtype(torch.float, scale=True))
    transforms.append(Tv2.ToPureTensor())
    return Tv2.Compose(transforms)


# Function to show image with bounding boxes and save it to disk
def show_image_with_boxes(img, targets, category_colors, classes_to_idx, output_path):
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)
    boxes = targets['boxes'].cpu().numpy()
    labels = targets['labels'].cpu().numpy()
    scores = targets['scores'].cpu().numpy()
    for bbox, label, score in zip(boxes, labels, scores):
        if score >= 0.3:  # Only show boxes with confidence score >= 0.3
            w = bbox[2] - bbox[0]
            h = bbox[3] - bbox[1]
            color = category_colors.get(label, 'gray')  # Use gray for unmapped classes
            rect = patches.Rectangle((bbox[0], bbox[1]), w, h, linewidth=2, edgecolor=color, facecolor='none')
            ax.add_patch(rect)
            # ax.text(bbox[0], bbox[1], f'{classes_to_idx[label]}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor=color, alpha=0.5))
    plt.axis('off')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

# Function to predict and visualize the output for a single image
def predict_and_visualize(image_path, model, device, category_colors, classes_to_idx, output_path):
    # Load and transform the image
    transforms = get_transform()

    img = PIL.Image.open(image_path).convert("RGB")
    img = img.resize((382, 216), PIL.Image.BILINEAR)  # Resize the image
    # print(img.size)
    img_tensor = F.to_tensor(img)
    img_trans = transforms(img_tensor).to(device)


    # Perform prediction
    with torch.no_grad():
        output = model(torch.unsqueeze(img_trans,dim=0))[0]

    # Visualize and save the output
    show_image_with_boxes(img, output, category_colors, classes_to_idx, output_path)

model = load_model(checkpoint_path, device)
# Path to the image you want to predict

img_path = 'test/7882.jpg'
img_filename_temp = img_path.split('/')[1]
img_filename = img_filename_temp.split('.')[0]


# Path to save the output image
output_path = f'{img_filename}_fwd_pass_full_size.jpg'

# Predict and visualize the output for the single image
predict_and_visualize(img_path, model, device, category_colors, classes_to_idx, output_path)




## Inference Comparison between SAHI and Original Image Forward Pass

## Comparison 1

test/7882.jpg

<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/11i9pi4rbdbi6jkovloce/7882_fwd_pass_full_size.jpg?rlkey=xfjemqbp8lnjd5bdtz5vlgctq&st=st332fpi&dl=1" alt="Image" >
</div>






<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/mty3rjdnygiiewvq0r47g/SAHI-7882.png?rlkey=wnvgj58q3xiaq2t44gpec2cvq&st=4ijscb9f&dl=1" alt="Image" >
</div>


## Comparison 2

test/2843.jpg

<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/r791wuqwu2qz8xbyl9k78/2843_fwd_pass_full_size.jpg?rlkey=gyo8u7tlvgcr5h0yzpwb5i7cf&st=b0gs3xki&dl=1" alt="Image" >
</div>


<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/3omc3dyu11tdsvdd2b8is/SAHI-2843.png?rlkey=213f1gazbylucazryjaoxoutw&st=5wrmzp7o&dl=1" alt="Image" >
</div>


## Comparison 3

test/1070.jpg

<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/avh3l9wuq0poumo9962ux/1070_fwd_pass_Full-Size.jpg?rlkey=mpxy3d05ljsdvhzbca9hjj1d6&st=n97vuzj4&dl=1" alt="Image" >
</div>


<div style="text-align: center">
  <img src="https://www.dropbox.com/scl/fi/4k1absw5i4ss41mxrez8o/1070-With-SAHI.png?rlkey=dycnkku1w8v3ks0k13ly19rbk&st=ckaiu1dh&dl=1" alt="Image" >
</div>


**The results are pretty impressive indeed!**