Image Matching with SIFT and RANSAC: A Practical Example

Image Matching with SIFT and RANSAC

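This document demonstrates image matching with the SIFT (Scale-Invariant Feature Transform) and RANSAC (RANdom SAmple Consensus) algorithms. Each query image is compared against every image in a database folder: SIFT descriptors are matched and filtered with a ratio test, a homography is estimated with RANSAC, and the database image with the most RANSAC inliers is reported as the best match. The code snippets below illustrate the process.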

Import Libraries

import cv2
import os
import matplotlib.pyplot as plt
import numpy as np

Define Query Images

# Paths of the query images that are matched against the database folder.
# All entries use a single '/' separator (the first entry previously had a
# stray double slash).
query_images = [
    'query/hw7_poster_1.jpg',
    'query/hw7_poster_2.jpg',
    'query/hw7_poster_3.jpg',
]

Function to Show Combined Images

def _to_three_channels(image):
    """Return *image* as an (H, W, 3) array, replicating a 2-D image per channel."""
    if image.ndim == 2:
        return np.repeat(image[:, :, np.newaxis], 3, axis=2)
    return image


def show_combined(image_1, image_2):
    """Display two images side by side in a single matplotlib figure.

    Both images are pasted, top-aligned, onto one uint8 canvas whose height
    is the taller of the two and whose width is the sum of both widths.
    Canvas pixels not covered by either image keep the fill value 1.

    Args:
        image_1: Left image, an array of shape (H, W, 3) or (H, W).
        image_2: Right image, an array of shape (H, W, 3) or (H, W).

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Single-channel images cannot be assigned into the 3-channel canvas
    # slices below, so expand them first.
    image_1 = _to_three_channels(image_1)
    image_2 = _to_three_channels(image_2)

    h1, w1 = image_1.shape[:2]
    h2, w2 = image_2.shape[:2]
    height, width = max(h1, h2), w1 + w2

    ret = np.ones((height, width, 3), dtype=np.uint8)
    ret[:h1, :w1] = image_1
    ret[:h2, w1:] = image_2

    plt.figure(figsize=(20, 15))
    plt.imshow(ret)
    plt.xticks([])
    plt.yticks([])
    plt.show()

Image Matching Process

folder_name = 'Database'

# Newer OpenCV releases (4.4+) expose SIFT in the main module; older contrib
# builds only provide it under cv2.xfeatures2d, so fall back to that.
if hasattr(cv2, 'SIFT_create'):
    sift = cv2.SIFT_create()
else:
    sift = cv2.xfeatures2d.SIFT_create()
bf = cv2.BFMatcher()

# A match is kept only if it is clearly better than the second-best candidate.
RATIO_TEST_THRESHOLD = 0.75
# A homography has 8 degrees of freedom, so at least 4 point pairs are needed.
MIN_HOMOGRAPHY_POINTS = 4


def _show_image(picture):
    """Show a single image in a large figure without axis ticks."""
    plt.figure(figsize=(20, 15))
    plt.imshow(picture)
    plt.xticks([])
    plt.yticks([])
    plt.show()


for image in query_images:
    original_query_image = cv2.imread(image)
    if original_query_image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        print("Could not read query image:", image)
        continue
    query_image = cv2.cvtColor(original_query_image, cv2.COLOR_BGR2RGB)

    kps1, descs1 = sift.detectAndCompute(query_image, None)
    if descs1 is None:
        print("No SIFT features found in query image:", image)
        continue

    # Reset the per-query "best so far" state so that a query without any
    # valid match can never reuse the result of the previous query.
    best_image_url = None
    max_matches = 0
    kp = None
    dsc = None
    database_image_final = None
    best_matches_final = []
    src = None
    dst = None
    mask_fin = None

    # Sorted so that ties between database images are resolved reproducibly.
    for t in sorted(os.listdir(folder_name)):
        database_image_path = os.path.join(folder_name, t)
        database_image = cv2.imread(database_image_path)
        if database_image is None:
            # Not an image (or unreadable); skip instead of crashing.
            continue
        database_image = cv2.cvtColor(database_image, cv2.COLOR_BGR2RGB)
        kps2, descs2 = sift.detectAndCompute(database_image, None)
        if descs2 is None:
            continue
        matches = bf.knnMatch(descs1, descs2, k=2)

        # Apply ratio test. Each kept match is wrapped in a one-element list
        # because cv2.drawMatchesKnn expects a list of match lists. Entries
        # with fewer than two neighbours cannot be ratio-tested and are dropped.
        good = [
            [pair[0]]
            for pair in matches
            if len(pair) == 2
            and pair[0].distance < RATIO_TEST_THRESHOLD * pair[1].distance
        ]
        if len(good) < MIN_HOMOGRAPHY_POINTS:
            continue

        src_pts = np.float32([kps1[l[0].queryIdx].pt for l in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kps2[l[0].trainIdx].pt for l in good]).reshape(-1, 1, 2)

        h, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC)
        if h is None or mask is None:
            # RANSAC could not fit a homography to these correspondences.
            continue

        # The mask flags RANSAC inliers with 1; their count is the score.
        inlier_count = int(mask.sum())
        if max_matches < inlier_count:
            max_matches = inlier_count
            best_image_url = database_image_path
            kp = kps2
            dsc = descs2
            database_image_final = database_image
            best_matches_final = good
            src = src_pts
            dst = dst_pts
            mask_fin = mask

    if database_image_final is None:
        print("No database match found for:", image)
        continue

    # Use the returned images rather than relying on in-place drawing.
    SIFT_query = cv2.drawKeypoints(
        query_image, kps1, None,
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    SIFT_target = cv2.drawKeypoints(
        database_image_final, kp, None,
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)

    print("================ BEST MATCH =====================")
    show_combined(query_image, database_image_final)

    print("================ SIFT keypoints =====================")
    show_combined(SIFT_query, SIFT_target)

    # All matches that survived the ratio test.
    print("================ SIFT descriptors =====================")
    res = cv2.drawMatchesKnn(query_image, kps1, database_image_final, kp,
                             best_matches_final, None, flags=2)
    _show_image(res)

    # Keep only the matches that RANSAC marked as inliers of the homography.
    matches_mask = mask_fin.ravel()
    inline_matches = [
        match
        for match, is_inlier in zip(best_matches_final, matches_mask)
        if is_inlier
    ]

    res_ransak = cv2.drawMatchesKnn(query_image, kps1, database_image_final, kp,
                                    inline_matches, None, flags=2)

    print("================ RANSAC inliers =====================")
    _show_image(res_ransak)