#!/usr/bin/env python3

import os.path, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import os

import cv2
import matplotlib.pyplot as pl
import networkx as nx
import numpy as np

from opensfm import dataset
from opensfm import features
from opensfm import matching
from opensfm import tracking
from opensfm import io
import opensfm.reconstruction as reconstruct


class FigureWrapper:
    """ Bundles a matplotlib figure with the subplot grid layout used to fill it.

    All attributes are plain public fields read directly by the plotting helpers.
    """

    def __init__(self, figure, rows, cols):
        # Subplot grid dimensions.
        self.rows = rows
        self.cols = cols
        # The figure the subplots are added to.
        self.figure = figure


class ImagePair:
    """ Names and loaded pixel arrays of the two images being compared. """

    def __init__(self, im1, im2, im1_array, im2_array):
        # Image names as used by the data set.
        self.im1, self.im2 = im1, im2
        # Image arrays, in the same order as the names.
        self.im1_array, self.im2_array = im1_array, im2_array


def redirect_print():
    """ Points sys.stdout at the null device so printed output is discarded.

    Note: the handle opened on os.devnull is not closed by this function, and
    restoring the previous stream with reset_print does not close it either.

    :return: The sys.stdout object that was active before the redirection.
    """
    previous_stdout = sys.stdout
    sys.stdout = open(os.devnull, 'w')
    return previous_stdout


def reset_print(f):
    """ Reinstalls a stream as sys.stdout.

    Intended to undo redirect_print by passing back the value it returned.

    :param f: The stream to install as sys.stdout.
    """
    sys.stdout = f


def show_images(plot, im1, im2):
    """ Draws two 3-channel images side by side in the supplied subplot.

    The canvas is as tall as the taller image and as wide as both images together;
    the area below the shorter image stays zero (black).
    """
    height1, width1, _ = im1.shape
    height2, width2, _ = im2.shape

    canvas = np.zeros((max(height1, height2), width1 + width2, 3), dtype=im1.dtype)
    canvas[:height1, :width1] = im1
    canvas[:height2, width1:] = im2

    plot.imshow(canvas)


def plot_points(plot, im1, p1, point_format1='ob', im2=None, p2=None, point_format2='ob'):
    """ Plots normalized points for one or two side-by-side images in the supplied subplot.

    Points are converted to pixel coordinates with features.denormalized_image_coordinates.
    Points of the second image are shifted right by the width of the first image, matching
    the side-by-side canvas drawn by show_images.

    :param plot: The subplot to draw into.
    :param im1: First image array; its shape is unpacked as (height, width, channels).
    :param p1: Normalized points for the first image, one row per point.
    :param point_format1: Matplotlib format string for the first image's points.
    :param im2: Optional second image array.
    :param p2: Optional normalized points for the second image.
    :param point_format2: Matplotlib format string for the second image's points.
    """
    # Read the first image size unconditionally: its width is also the x offset of the
    # second image, which is needed even when the first image has no points to draw.
    h1, w1, _ = im1.shape

    if p1.shape[0] != 0:
        p1d = features.denormalized_image_coordinates(p1, w1, h1)
        plot.plot(p1d[:, 0], p1d[:, 1], point_format1)

    if im2 is None or p2 is None or p2.shape[0] == 0:
        return

    h2, w2, _ = im2.shape
    p2d = features.denormalized_image_coordinates(p2, w2, h2)
    plot.plot(p2d[:, 0] + w1, p2d[:, 1], point_format2)


def plot_matches(plot, im1, im2, p1, p2, line_format='c', point_format='ob'):
    """ Draws corresponding points of two side-by-side images joined by lines.

    Nothing is drawn when there are no points or when p1 and p2 differ in shape.

    :param plot: The subplot to draw into.
    :param im1: First image array.
    :param im2: Second image array.
    :param p1: Normalized points in the first image.
    :param p2: Normalized points in the second image, row-aligned with p1.
    :param line_format: Matplotlib format string for the connecting lines.
    :param point_format: Matplotlib format string for the end points of both images.
    """
    nothing_to_draw = p1.shape[0] == 0 or p1.shape != p2.shape
    if nothing_to_draw:
        return

    height1, width1, _ = im1.shape
    height2, width2, _ = im2.shape
    left = features.denormalized_image_coordinates(p1, width1, height1)
    right = features.denormalized_image_coordinates(p2, width2, height2)

    # The second image sits to the right of the first, so offset its x by width1.
    for start, end in zip(left, right):
        plot.plot([start[0], end[0] + width1], [start[1], end[1]], line_format)

    plot_points(plot, im1, p1, point_format, im2, p2, point_format)


def plot_circle(subplot, diameter, format):
    """ Plots a circle centred on the origin using 300 samples.

    Note: the ``diameter`` value is applied as the radius of the drawn circle
    (x = diameter * sin(angle), y = diameter * cos(angle)).

    :param subplot: The subplot.
    :param diameter: Scale of the circle; used as its radius.
    :param format: Matplotlib format string for the circle.
    """
    theta = np.linspace(0, 2 * np.pi, 300)
    subplot.plot(diameter * np.sin(theta), diameter * np.cos(theta), format)


def create_subplot(figure, rows, columns, index, title, x_lim, y_lim, font_size=12, axis=False, aspect=None, grid=False):
    """ Creates a subplot with fixed limits and a title drawn just above the axes.

    Limits and grid are applied to the new Axes object itself rather than through the
    pyplot state machine, so they land on this subplot even when ``figure`` is not the
    current pyplot figure.

    :param figure: The figure to add the subplot to.
    :param rows: Number of subplot rows in the figure grid.
    :param columns: Number of subplot columns in the figure grid.
    :param index: Position of the subplot in the grid (as passed to add_subplot).
    :param title: Title text, centred above the subplot.
    :param x_lim: (left, right) x-axis limits.
    :param y_lim: (bottom, top) y-axis limits; callers pass (height, 0) for image coordinates.
    :param font_size: Title font size.
    :param axis: Whether to keep the axis lines and labels visible.
    :param aspect: Optional aspect ratio passed to set_aspect.
    :param grid: Whether to draw a major grid.
    :return: The created subplot.
    """
    subplot = figure.add_subplot(rows, columns, index)
    subplot.set_xlim(x_lim[0], x_lim[1])
    subplot.set_ylim(y_lim[0], y_lim[1])
    # Title placed in axes coordinates slightly above the top edge.
    subplot.text(0.5, 1.04,
                 title,
                 horizontalalignment='center',
                 fontsize=font_size,
                 transform=subplot.transAxes)

    if not axis:
        subplot.axis('off')

    if aspect is not None:
        subplot.set_aspect(aspect)

    if grid:
        subplot.grid(visible=True, which='major', color='0.3')

    return subplot


def create_four_figure(title, single_column):
    """ Creates a figure laid out for four subplots.

    :param title: Bold figure title.
    :param single_column: Stack the subplots in a 4x1 column (True) or use a 2x2 grid (False).
    :return: The figure, the number of grid rows and the number of grid columns.
    """
    if single_column:
        size, rows, cols = (12, 21), 4, 1
    else:
        size, rows, cols = (24, 12), 2, 2

    fig = pl.figure(figsize=size)
    fig.suptitle(title, fontsize=14, fontweight='bold')
    return fig, rows, cols


def plot_points_sub(figure, rows, columns, index, title, im1, im2, p1, p2, point_format1, point_format2):
    """ Adds a subplot showing both images side by side with each image's points overlaid.

    :return: The created subplot.
    """
    canvas_width = im1.shape[1] + im2.shape[1]
    canvas_height = np.max([im1.shape[0], im2.shape[0]])
    # y limits are (height, 0) so the image is displayed top-down.
    subplot = create_subplot(figure, rows, columns, index, title, (0, canvas_width), (canvas_height, 0))

    show_images(subplot, im1, im2)
    plot_points(subplot, im1, p1, point_format1, im2, p2, point_format2)
    return subplot


def plot_matches_sub(figure, rows, columns, index, title, im1, im2, p1, p2, line_format, point_format):
    """ Adds a subplot showing both images side by side with p1/p2 drawn as matches.

    :return: The created subplot.
    """
    canvas_width = im1.shape[1] + im2.shape[1]
    canvas_height = np.max([im1.shape[0], im2.shape[0]])
    # y limits are (height, 0) so the image is displayed top-down.
    subplot = create_subplot(figure, rows, columns, index, title, (0, canvas_width), (canvas_height, 0))

    show_images(subplot, im1, im2)
    plot_matches(subplot, im1, im2, p1, p2, line_format, point_format)
    return subplot


def display_figure(figure, save_figs, data=None, file_name='',
                   left=0.01, bottom=0.01, right=0.99, top=0.93, wspace=0.04, hspace=0.14):
    """ Adjusts the subplot margins of a figure, then saves it to disk or shows it.

    When saving, the figure is written to ``<args.dataset>/plot_inliers/<file_name>`` at
    100 dpi and then closed. When showing, the top margin is lowered by 0.02 and pl.show()
    is called.

    :param figure: The figure to display or save.
    :param save_figs: True to save the figure to file, False to show it.
    :param data: The data set. Not used here; the output directory comes from ``args.dataset``.
    :param file_name: File name for the saved figure.
    :param left: Left margin passed to subplots_adjust.
    :param bottom: Bottom margin passed to subplots_adjust.
    :param right: Right margin passed to subplots_adjust.
    :param top: Top margin passed to subplots_adjust (reduced by 0.02 when showing).
    :param wspace: Horizontal spacing passed to subplots_adjust.
    :param hspace: Vertical spacing passed to subplots_adjust.
    """
    figure.subplots_adjust(left=left,
                           bottom=bottom,
                           right=right,
                           top=top if save_figs else top - 0.02,
                           wspace=wspace,
                           hspace=hspace)

    if save_figs:
        # NOTE(review): relies on a module-level ``args`` (parsed CLI arguments) defined
        # outside this section rather than on ``data``.
        p = os.path.join(args.dataset, 'plot_inliers')
        io.mkdir_p(p)
        figure.savefig(os.path.join(p, file_name), dpi=100)
        # Close this specific figure; pl.close() with no argument closes whichever figure
        # pyplot considers current, which need not be ``figure``.
        pl.close(figure)
    else:
        pl.show()


def reproject_tracks(im, tracks, reconstruction):
    """ Projects the 3D points of the given tracks into one image of the reconstruction.

    :param im: The name of the image (a key of reconstruction.shots).
    :param tracks: Track ids whose points exist in the reconstruction.
    :param reconstruction: A Reconstruction.
    :return: An array with one projected point per track, or an empty (0, 2) array
             when no tracks are given.
    """
    projected = [
        reconstruction.shots[im].project(reconstruction.points[str(track_id)].coordinates)
        for track_id in tracks
    ]
    if not projected:
        return np.empty((0, 2), int)
    return np.array(projected)


def reprojection_errors(reprojections, observations, scale):
    """ Calculates reprojection errors with respect to observations.

    :param reprojections: The reprojections, one row per point.
    :param observations: The observations, row-aligned with the reprojections.
    :param scale: Factor applied to the errors and to the observation norms.
    :return: A tuple (errors, mean, std, min_norm, max_norm, corr): the scaled per-point errors,
             their per-column mean and standard deviation, the minimum and maximum error norm,
             and the correlation coefficient between the scaled observation norms and the error
             norms (NaN with fewer than two points). For empty input the errors have shape
             (0, 2), mean and std are zero vectors of length 2, both norms are 0 and corr is NaN.
    """

    if not len(reprojections):
        # Builtin ``float`` instead of the ``np.float`` alias (removed in NumPy 1.24); the
        # shapes are real tuples so np.zeros receives an integer length.
        return np.empty((0, 2), float), np.zeros((2,), float), np.zeros((2,), float), 0, 0, np.nan

    errors = scale * np.subtract(reprojections, observations)
    mean = np.mean(errors, axis=0)
    std = np.std(errors, axis=0)

    errors_norm = np.linalg.norm(errors, axis=1)
    min_norm = np.min(errors_norm)
    max_norm = np.max(errors_norm)

    observations_norm = scale * np.linalg.norm(observations, axis=1)

    # A correlation coefficient needs at least two samples.
    if len(observations_norm) > 1:
        corr = np.corrcoef(observations_norm, errors_norm)[0, 1]
    else:
        corr = np.nan

    return errors, mean, std, min_norm, max_norm, corr


def triangulate_tracks(tracks, reconstruction, tracks_manager, min_ray_angle):
    """ Attempts to triangulate each track into the reconstruction.

    A track counts as triangulated when its id is present in ``reconstruction.points``
    after the triangulation attempt.

    :param tracks: The array of track ids.
    :param reconstruction: The reconstruction the triangulator works on.
    :param tracks_manager: The tracks manager.
    :param min_ray_angle: The minimum ray angle difference for a triangulation to be considered valid.
    :return: A boolean array with one entry per track telling whether it was triangulated.
    """
    handler = reconstruct.TrackHandlerTrackManager(tracks_manager, reconstruction)
    triangulator = reconstruct.TrackTriangulator(reconstruction, handler)

    outcome = []
    for track in tracks:
        track_id = str(track)
        # A reprojection threshold of 1 avoids excluding tracks because of their error.
        triangulator.triangulate(track_id, 1, min_ray_angle, iterations=10)
        outcome.append(track_id in reconstruction.points)

    if not outcome:
        return np.empty((0,), bool)
    return np.array(outcome)


def find_reconstruction(images, data):
    """ Returns the first loaded reconstruction whose shots include every given image.

    :param images: List of image names.
    :param data: The data set.
    :return: The first matching reconstruction, or None if none contains all images.
    """
    for candidate in data.load_reconstruction():
        if all(im in candidate.shots for im in images):
            return candidate
    return None


def reconstruction_tracks(tracks, reconstruction):
    """ Splits track rows by whether their track id (first column) is a point of the reconstruction.

    :param tracks: 2-D array of track rows; column 0 holds the track id.
    :param reconstruction: The reconstruction.
    :return: Two arrays: the rows whose track is in the reconstruction, and the rows whose
             track is not. An empty side has shape (0, tracks.shape[1]).
    """
    inside, outside = [], []
    for row in tracks:
        destination = inside if str(row[0]) in reconstruction.points else outside
        destination.append(row)

    def as_array(rows):
        if rows:
            return np.array(rows)
        return np.empty((0, tracks.shape[1]), int)

    return as_array(inside), as_array(outside)


def load_points(im, data):
    """ Loads the x and y coordinates of the feature points of an image.

    :param im: The name of the image.
    :param data: The data set.
    :return: A float64 array with one (x, y) row per feature (the first two point columns).
    """
    loaded = data.load_features(im)
    assert loaded
    return np.array(loaded.points[:, :2], dtype=np.float64)


def load_common_tracks(im1, im2, tracks_manager):
    """ Loads the tracks shared by two images with the matching feature id in each image.

    :param im1: The name of the first image.
    :param im2: The name of the second image.
    :param tracks_manager: The track manager.
    :return: An array with rows (track id, feature id in im1, feature id in im2);
             an empty (0, 3) array when the images share no tracks.
    """
    observations1 = tracks_manager.get_shot_observations(im1)
    observations2 = tracks_manager.get_shot_observations(im2)
    shared_ids, _, _ = tracking.common_tracks(tracks_manager, im1, im2)

    rows = [np.array([int(track_id), observations1[track_id].id, observations2[track_id].id])
            for track_id in shared_ids]
    return np.array(rows) if rows else np.empty((0, 3), int)


def load_tracks(im, tracks_manager):
    """ Loads the tracks observed in an image together with the observing feature id.

    :param im: The name of the image.
    :param tracks_manager: The track manager.
    :return: An array with rows (track id, feature id); an empty (0, 2) array when the
             image has no observations.
    """
    observations = tracks_manager.get_shot_observations(im)
    rows = [np.array([int(track_id), observation.id])
            for track_id, observation in observations.items()]
    if not rows:
        return np.empty((0, 2), int)
    return np.array(rows)


def matches(a1, a2, map_function):
    """ Splits the rows of a1 by whether their mapped value equals some row of a2.

    Each row ``r`` of a1 is mapped with ``map_function(r)`` and compared with the rows of a2
    using the same notion of equality as ``np.array_equal`` (same shape and same elements).
    The rows of a2 are indexed once in a hash set, so each row of a1 costs one lookup instead
    of a scan over all of a2.

    :param a1: 2-D numpy array whose rows are classified.
    :param a2: Iterable of rows (e.g. a 2-D numpy array) to match against.
    :param map_function: Function mapping a row of a1 to the value compared with the rows of a2.
    :return: A tuple (m, nm): the original (unmapped) rows of a1 whose mapped value equals a row
             of a2, and the rows of a1 that do not. An empty result has shape (0, a1.shape[1]).
    """
    def row_key(value):
        # Shape plus flattened element values: for non-NaN data two keys are equal exactly
        # when np.array_equal would consider the arrays equal.
        arr = np.asarray(value)
        return arr.shape, tuple(arr.ravel().tolist())

    lookup = {row_key(arr2) for arr2 in a2}

    m = []
    nm = []
    for arr1 in a1:
        if row_key(map_function(arr1)) in lookup:
            m.append(arr1)
        else:
            nm.append(arr1)

    m = np.array(m) if len(m) else np.empty((0, a1.shape[1]), int)
    nm = np.array(nm) if len(nm) else np.empty((0, a1.shape[1]), int)

    return m, nm


def thresholds(im1_array, im2_array, key, default, data):
    """ Retrieves a threshold from the config and converts it to pixels for two images.

    The pixel value is the threshold times the larger of the image's height and width.

    :param im1_array: First image array.
    :param im2_array: Second image array.
    :param key: Threshold config key.
    :param default: Value used when the key is missing from data.config.
    :param data: Dataset.
    :return: The threshold config value and its pixel equivalent for each image.
    """
    threshold = data.config.get(key, default)

    def to_pixels(image):
        return threshold * np.max(image.shape[:2])

    return threshold, to_pixels(im1_array), to_pixels(im2_array)


def create_matches_figure(im1, im2, data, save_figs=False, single_column=False):
    """ Plots the features, symmetric matches as well as the robust matching inliers and outliers.

    When fewer than eight symmetric matches are found, a message is printed, the figure
    created for the plots is closed and nothing is displayed.

    :param im1: Name of the first image.
    :param im2: Name of the second image.
    :param data: Data set.
    :param save_figs: Boolean specifying if figures should be saved to file.
    :param single_column: Boolean specifying if the subplots should be ordered in a single column.
    """

    fig, rows, cols = \
        create_four_figure('Matches ({0}): {1} - {2}'.format(data.feature_type().upper(), im1, im2), single_column)

    im1_array = data.load_image(im1)
    im2_array = data.load_image(im2)

    # Calculate symmetric matches.
    features_data1 = data.load_features(im1)
    features_data2 = data.load_features(im2)
    assert features_data1
    assert features_data2
    p1 = features_data1.points
    p2 = features_data2.points

    symmetric_matches = matching.match_brute_force_symmetric(p1, p2, data.config)
    symmetric_matches = np.array(symmetric_matches)

    if symmetric_matches.shape[0] < 8:
        print('Not enough matches for eight point algorithm: ' + str(symmetric_matches.shape[0]))
        # Close the figure created above; otherwise it stays registered with pyplot and is
        # shown empty by the next pl.show() (or accumulates when saving figures).
        pl.close(fig)
        return

    # Plot features
    features_title = 'Features (loaded): {0} - {1}, {2} - {3}'.format(im1, p1.shape[0], im2, p2.shape[0])
    plot_points_sub(fig, rows, cols, 1, features_title, im1_array, im2_array, p1, p2, 'ob', 'om')

    # Plot symmetric matches.
    s_matches1 = p1[symmetric_matches[:, 0]]
    s_matches2 = p2[symmetric_matches[:, 1]]

    plot_matches_sub(fig, rows, cols, 2,
                     'Symmetric matches (calculated): {0}'.format(symmetric_matches.shape[0]),
                     im1_array, im2_array,
                     s_matches1, s_matches2,
                     'c', 'ob')

    # Calculate robust matches and plot inliers.
    cameras = data.load_camera_models()
    camera1 = cameras[data.load_exif(im1)['camera']]
    camera2 = cameras[data.load_exif(im2)['camera']]
    robust_matches = matching.robust_match(p1, p2, camera1, camera2, symmetric_matches, data.config)
    # Normalize an empty result to shape (0, 2) so the column indexing below works even if
    # robust_match hands back an empty 1-D array or list (NOTE(review): confirm its contract).
    robust_matches = np.asarray(robust_matches)
    if robust_matches.size == 0:
        robust_matches = np.empty((0, 2), int)

    r_matches1 = p1[robust_matches[:, 0]]
    r_matches2 = p2[robust_matches[:, 1]]

    threshold, pixels1, pixels2 = thresholds(im1_array, im2_array, 'robust_matching_threshold', 0.006, data)
    plot_matches_sub(
        fig, rows, cols, 3,
        'Robust matching inliers (RANSAC 7-point algorithm, calculated): {0}. Threshold: {1:.2g} ({2:.1f} - {3:.1f} pixels)'
        .format(robust_matches.shape[0], threshold, pixels1, pixels2),
        im1_array, im2_array,
        r_matches1, r_matches2,
        'g', 'oy')

    # Plot robust matching outliers: symmetric matches that are not robust matches.
    outliers = matches(symmetric_matches, robust_matches, lambda a: a)[1]
    outliers1 = p1[outliers[:, 0]]
    outliers2 = p2[outliers[:, 1]]

    plot_matches_sub(fig, rows, cols, 4,
                     'Robust matching outliers (calculated): {0}. Threshold: {1:.2g} ({2:.1f} - {3:.1f} pixels)'
                     .format(outliers.shape[0], threshold, pixels1, pixels2),
                     im1_array, im2_array,
                     outliers1, outliers2,
                     'r', 'om')

    display_figure(fig, save_figs, data, '{0}_{1}_{2}_matches.jpg'.format(im1, im2, data.feature_type()))


def plot_common_tracks(fw, ip, index, p1, p2, tracks, robust_tracks, linked_tracks, robust_matches):
    """ Plots the common tracks of an image pair as matches, drawing tracks that correspond to
        robust matches (cyan lines, blue points) apart from tracks linked through other image
        correspondences (green lines, yellow points).

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Subplot index.
    :param p1: Feature points for the first image.
    :param p2: Feature points for the second image.
    :param tracks: Common tracks.
    :param robust_tracks: Tracks corresponding to robust matches; columns 1 and 2 are feature ids.
    :param linked_tracks: Common tracks linked by other image correspondences, same layout.
    :param robust_matches: Array of feature ids for robust matches.
    """
    robust1, robust2 = p1[robust_tracks[:, 1]], p2[robust_tracks[:, 2]]
    linked1, linked2 = p1[linked_tracks[:, 1]], p2[linked_tracks[:, 2]]

    title = 'Common tracks (loaded): {0}. Robust match tracks: {1} / {2}. Linked tracks: {3}'.format(
        tracks.shape[0], robust_tracks.shape[0], robust_matches.shape[0], linked_tracks.shape[0])

    subplot = plot_matches_sub(fw.figure, fw.rows, fw.cols, index, title,
                               ip.im1_array, ip.im2_array, robust1, robust2, 'c', 'ob')
    plot_matches(subplot, ip.im1_array, ip.im2_array, linked1, linked2, 'g', 'oy')


def plot_opencv_find_homography(fw, ip, index, tracks, track_points1, track_points2, data):
    """ Plot inliers and outliers of the tracks from OpenCV find homography.

    The RANSAC reprojection threshold is the 'homography_threshold' config value, applied
    directly to the (normalized) track points.

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Subplot index.
    :param tracks: Common tracks.
    :param track_points1: Feature points corresponding to common tracks for image 1.
    :param track_points2: Feature points corresponding to common tracks for image 2.
    :param data: Data set.
    """

    threshold, pixels1, pixels2 = thresholds(ip.im1_array, ip.im2_array, 'homography_threshold', 0.004, data)
    _, mask = cv2.findHomography(track_points1, track_points2, cv2.RANSAC, threshold)

    if mask is None:
        # No mask returned (e.g. the estimation failed): report every track as an outlier
        # instead of crashing on the boolean indexing below.
        inliers = np.zeros(track_points1.shape[0], dtype=bool)
    else:
        # ravel (unlike squeeze) always yields a 1-D mask, even for a single point.
        inliers = np.asarray(mask).ravel().astype(bool)

    inliers1 = track_points1[inliers, :]
    inliers2 = track_points2[inliers, :]

    outliers1 = track_points1[~inliers, :]
    outliers2 = track_points2[~inliers, :]

    title = \
        'OpenCV find homography inliers (calculated): {0}. Outliers: {1}. Outlier ratio: {2:.3f}. Threshold: {3:.2g} '\
        .format(inliers.sum(), (~inliers).sum(), float((~inliers).sum()) / tracks.shape[0], threshold) + \
        '({0:.1f} - {1:.1f} pixels)'.format(pixels1, pixels2)

    subplot = plot_matches_sub(fw.figure, fw.rows, fw.cols, index, title, ip.im1_array, ip.im2_array,
                               outliers1, outliers2, 'r', 'om')
    plot_matches(subplot, ip.im1_array, ip.im2_array, inliers1, inliers2, 'c', 'ob')


def plot_two_view_reconstruction(fw, ip, index, track_points1, track_points2, data):
    """ Plots the inliers (cyan/blue) and outliers (red/magenta) of a general two view
        reconstruction computed from the common track points.

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Subplot index.
    :param track_points1: Feature points corresponding to common tracks for image 1.
    :param track_points2: Feature points corresponding to common tracks for image 2.
    :param data: Data set.
    """
    camera_models = data.load_camera_models()
    cam1, cam2 = (camera_models[data.load_exif(name)['camera']] for name in (ip.im1, ip.im2))

    threshold, pixels1, pixels2 = thresholds(ip.im1_array, ip.im2_array, 'five_point_algo_threshold', 0.006, data)
    refine_iterations = data.config['five_point_refine_rec_iterations']
    _, _, inliers, _ = reconstruct.two_view_reconstruction_general(
        track_points1, track_points2, cam1, cam2, threshold, refine_iterations)

    # Every point not selected by ``inliers`` is an outlier.
    is_outlier = np.ones(track_points1.shape[0], dtype=bool)
    is_outlier[inliers] = False

    title = ('CSfM two view reconstruction inliers (calculated): {0}. Outliers: {1}. Threshold: {2:.2g} '
             .format(len(inliers), is_outlier.sum(), threshold)
             + '({0:.1f} - {1:.1f} pixels)'.format(pixels1, pixels2))

    subplot = plot_matches_sub(fw.figure, fw.rows, fw.cols, index, title, ip.im1_array, ip.im2_array,
                               track_points1[inliers, :], track_points2[inliers, :], 'c', 'ob')
    plot_matches(subplot, ip.im1_array, ip.im2_array,
                 track_points1[is_outlier, :], track_points2[is_outlier, :], 'r', 'om')


def plot_bootstrapped_reconstruction(fw, ip, index, p1, p2, robust_tracks, linked_tracks, tracks_manager, data):
    """ Plot successfully reconstructed and failed 3D points by bootstrapping a reconstruction.

    Output printed while bootstrapping is discarded. If bootstrapping yields no reconstruction,
    an empty subplot with a failure title is created instead.

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Subplot index.
    :param p1: Feature points for the first image.
    :param p2: Feature points for the second image.
    :param robust_tracks: Tracks corresponding to robust matches.
    :param linked_tracks: Common tracks linked by other image correspondences.
    :param tracks_manager: The tracks manager.
    :param data: Data set.
    """
    import contextlib

    # Silence the bootstrap output. The context managers restore sys.stdout and close the
    # null-device handle even if bootstrapping raises.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        _, pm1, pm2 = tracking.common_tracks(tracks_manager, ip.im1, ip.im2)
        reconstruction, _ = reconstruct.bootstrap_reconstruction(data, tracks_manager, ip.im1, ip.im2, pm1, pm2)

    threshold, pixels1, pixels2 = thresholds(ip.im1_array, ip.im2_array, 'triangulation_threshold', 0.004, data)

    if reconstruction:
        robust_rec_tracks, failed_robust = reconstruction_tracks(robust_tracks, reconstruction)
        non_robust_rec_tracks, failed_non_robust = reconstruction_tracks(linked_tracks, reconstruction)
        failed_rec_tracks = np.vstack((failed_robust, failed_non_robust))

        robust_rec_points1 = p1[robust_rec_tracks[:, 1]]
        robust_rec_points2 = p2[robust_rec_tracks[:, 2]]

        non_robust_rec_points1 = p1[non_robust_rec_tracks[:, 1]]
        non_robust_rec_points2 = p2[non_robust_rec_tracks[:, 2]]

        failed_rec_points1 = p1[failed_rec_tracks[:, 1]]
        failed_rec_points2 = p2[failed_rec_tracks[:, 2]]

        rec_title = \
            'Bootstrapped 3D points: {0}. Robust {1}. Linked: {2}. Failed: {3}. '.format(
                len(reconstruction.points),
                robust_rec_tracks.shape[0],
                non_robust_rec_tracks.shape[0],
                failed_rec_tracks.shape[0]) + \
            'Triangulation threshold: {0:.2g} ({1:.1f} - {2:.1f} pixels). Min ray angle: {3}.'.format(
                threshold, pixels1, pixels2, data.config['triangulation_min_ray_angle'])

        rec_plot = plot_matches_sub(fw.figure, fw.rows, fw.cols, index, rec_title, ip.im1_array, ip.im2_array,
                                    robust_rec_points1, robust_rec_points2, 'c', 'ob')
        plot_matches(rec_plot, ip.im1_array, ip.im2_array, non_robust_rec_points1, non_robust_rec_points2, 'g', 'oy')
        plot_matches(rec_plot, ip.im1_array, ip.im2_array, failed_rec_points1, failed_rec_points2, 'r', 'om')
    else:
        # Trailing space keeps the two concatenated sentences apart.
        failed_title = \
            'Bootstrap failed. Less than {0} points. Triangulation threshold: {1:.2g} ({2:.1f} - {3:.1f} pixels). '\
            .format(data.config['five_point_algo_min_inliers'], threshold, pixels1, pixels2) +\
            'Min ray angle: {0}.'.format(data.config['triangulation_min_ray_angle'])
        create_subplot(fw.figure, fw.rows, fw.cols, index, failed_title, (0, 100), (0, 100))


def create_tracks_figure(im1, im2, data, save_figs=False, single_column=False):
    """ Plots the common tracks as well as the find homography, two view reconstruction and bootstrapped
        reconstruction inliers and outliers.

    When the images share fewer than five tracks, a message is printed, the figure created
    for the plots is closed and nothing is displayed.

    :param im1: Name of the first image.
    :param im2: Name of the second image.
    :param data: Data set.
    :param save_figs: Boolean specifying if figures should be saved to file.
    :param single_column: Boolean specifying if the subplots should be ordered in a single column.
    """

    fig, rows, cols = \
        create_four_figure('Tracks ({0}): {1} - {2}'.format(data.feature_type().upper(), im1, im2), single_column)

    fw = FigureWrapper(fig, rows, cols)
    ip = ImagePair(im1, im2, data.load_image(im1), data.load_image(im2))

    p1 = load_points(im1, data)
    p2 = load_points(im2, data)

    # Retrieve tracks and robust matches from file
    robust_matches = data.find_matches(im1, im2)
    robust_matches = np.array(robust_matches) if len(robust_matches) else np.empty((0, 2), int)

    tracks_manager = data.load_tracks_manager()
    tracks = load_common_tracks(im1, im2, tracks_manager)

    if tracks.shape[0] < 5:
        print('Not enough tracks for five point algorithm: ' + str(tracks.shape[0]))
        # Close the figure created above so it is not shown empty by a later pl.show().
        pl.close(fig)
        return

    # Columns 1 and 2 of each track row are the feature ids in im1 and im2.
    track_points1 = p1[tracks[:, 1]]
    track_points2 = p2[tracks[:, 2]]

    robust_tracks, linked_tracks = matches(tracks, robust_matches, lambda a: a[1:])

    plot_common_tracks(fw, ip, 1, p1, p2, tracks, robust_tracks, linked_tracks, robust_matches)
    plot_opencv_find_homography(fw, ip, 2, tracks, track_points1, track_points2, data)
    plot_two_view_reconstruction(fw, ip, 3, track_points1, track_points2, data)
    plot_bootstrapped_reconstruction(fw, ip, 4, p1, p2, robust_tracks, linked_tracks, tracks_manager, data)

    display_figure(fig, save_figs, data, '{0}_{1}_{2}_tracks.jpg'.format(im1, im2, data.feature_type()))


def plot_reconstructed_tracks(fw, ip, index, p1, p2, tracks, rec_tracks, non_rec_tracks):
    """ Plots the common tracks as matches, reconstructed ones in cyan/blue and the tracks
        missing from the reconstruction in red/magenta.

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Subplot index.
    :param p1: Feature points for the first image.
    :param p2: Feature points for the second image.
    :param tracks: Common tracks.
    :param rec_tracks: Common tracks included in the reconstruction.
    :param non_rec_tracks: Common tracks not included in the reconstruction.
    """
    kept1, kept2 = p1[rec_tracks[:, 1]], p2[rec_tracks[:, 2]]
    dropped1, dropped2 = p1[non_rec_tracks[:, 1]], p2[non_rec_tracks[:, 2]]

    n_common = tracks.shape[0]
    n_reconstructed = rec_tracks.shape[0]
    title = ('Common tracks: {0}. Reconstructed points: {1}. Triangulation and reprojection error removals: {2}. All loaded.'
             .format(n_common, n_reconstructed, n_common - n_reconstructed))

    canvas_width = ip.im1_array.shape[1] + ip.im2_array.shape[1]
    canvas_height = np.max([ip.im1_array.shape[0], ip.im2_array.shape[0]])
    subplot = create_subplot(fw.figure, fw.rows, fw.cols, index, title,
                             (0, canvas_width), (canvas_height, 0), 15)

    show_images(subplot, ip.im1_array, ip.im2_array)
    plot_matches(subplot, ip.im1_array, ip.im2_array, kept1, kept2, 'c', 'ob')
    plot_matches(subplot, ip.im1_array, ip.im2_array, dropped1, dropped2, 'r', 'om')


def plot_reprojected_tracks(fw, ip, index, p1, p2, tracks, rec_tracks, non_rec_tracks, reconstruction, tracks_manager, data):
    """ Reprojects tracks included in and excluded from reconstruction and plots reprojections on top of observations.

    Tracks missing from the reconstruction are re-triangulated with triangulate_tracks, which
    works on ``reconstruction`` itself; the successful ones are then reprojected as well.

    :param fw: Figure wrapper.
    :param ip: Image pair.
    :param index: Sub plot index.
    :param p1: Feature points for the first image.
    :param p2: Feature points for the second image.
    :param tracks: Common tracks.
    :param rec_tracks: Common tracks included in the reconstruction.
    :param non_rec_tracks: Common tracks not included in the reconstruction.
    :param reconstruction: Reconstruction.
    :param tracks_manager: The tracks manager.
    :param data: Data set
    """
    observed1, observed2 = p1[rec_tracks[:, 1]], p2[rec_tracks[:, 2]]

    # Reproject the tracks already in the reconstruction before re-triangulating the others.
    rp1 = reproject_tracks(ip.im1, rec_tracks[:, 0], reconstruction)
    rp2 = reproject_tracks(ip.im2, rec_tracks[:, 0], reconstruction)

    min_ray_angle = data.config['triangulation_min_ray_angle']
    succeeded = triangulate_tracks(non_rec_tracks[:, 0], reconstruction, tracks_manager, min_ray_angle)
    retriangulated = non_rec_tracks[succeeded]
    failed = non_rec_tracks[~succeeded]

    reprojected1 = reproject_tracks(ip.im1, retriangulated[:, 0], reconstruction)
    reprojected2 = reproject_tracks(ip.im2, retriangulated[:, 0], reconstruction)

    first_line = 'Common tracks (loaded): {0}. Reprojected points (calculated): {1}.\n'.format(
        tracks.shape[0], rec_tracks.shape[0])
    second_line = 'Rejected tracks: {0}. Re-triangulated: {1}. Failed: {2} (behind a camera or min ray angle: {3} deg).'.format(
        non_rec_tracks.shape[0], succeeded.sum(), (~succeeded).sum(), min_ray_angle)
    title = first_line + second_line

    canvas_width = ip.im1_array.shape[1] + ip.im2_array.shape[1]
    canvas_height = np.max([ip.im1_array.shape[0], ip.im2_array.shape[0]])
    subplot = create_subplot(fw.figure, fw.rows, fw.cols, index, title,
                             (0, canvas_width), (canvas_height, 0), font_size=15)
    show_images(subplot, ip.im1_array, ip.im2_array)

    # Drawn in order: observations, their reprojections, re-triangulated observations,
    # their reprojections, and finally the tracks that could not be triangulated.
    layers = [
        (observed1, observed2, 'ob'),
        (rp1, rp2, '+w'),
        (p1[retriangulated[:, 1]], p2[retriangulated[:, 2]], 'om'),
        (reprojected1, reprojected2, '+w'),
        (p1[failed[:, 1]], p2[failed[:, 2]], 'or'),
    ]
    for points1, points2, point_format in layers:
        plot_points(subplot, ip.im1_array, points1, point_format, ip.im2_array, points2, point_format)


def create_reconstruction_matches_figure(im1, im2, data, save_figs=False):
    """ Plots the track match inliers and outliers in the loaded reconstruction for the two images as
        well as the reprojections.

    When the images share no tracks or do not appear together in any loaded reconstruction,
    a message is printed, the figure created for the plots is closed and nothing is displayed.

    :param im1: Name of the first image.
    :param im2: Name of the second image.
    :param data: Data set.
    :param save_figs: Boolean specifying if figures should be saved to file.
    """

    fig = pl.figure(figsize=(15, 15))
    fig.suptitle('Reprojected reconstructed 3D points ({0}): {1} - {2}'.format(data.feature_type().upper(), im1, im2),
                 fontsize=18, fontweight='bold')

    fw = FigureWrapper(fig, 2, 1)
    ip = ImagePair(im1, im2, data.load_image(im1), data.load_image(im2))

    p1 = load_points(im1, data)
    p2 = load_points(im2, data)

    tracks_manager = data.load_tracks_manager()
    tracks = load_common_tracks(im1, im2, tracks_manager)
    if tracks.shape[0] < 1:
        print('No tracks to plot.')
        pl.close(fig)
        return

    reconstruction = find_reconstruction([im1, im2], data)
    if reconstruction is None:
        print('{0} and {1} does not exist in a common reconstruction.'.format(im1, im2))
        # Nothing to reproject without a common reconstruction; stop instead of
        # dereferencing None in reconstruction_tracks below.
        pl.close(fig)
        return

    rec_tracks, non_rec_tracks = reconstruction_tracks(tracks, reconstruction)

    plot_reconstructed_tracks(fw, ip, 1, p1, p2, tracks, rec_tracks, non_rec_tracks)
    plot_reprojected_tracks(fw, ip, 2, p1, p2, tracks, rec_tracks, non_rec_tracks, reconstruction, tracks_manager, data)

    display_figure(fig, save_figs, data, '{0}_{1}_{2}_reconstruction.jpg'.format(im1, im2, data.feature_type()))


def create_complete_reconstruction_figure(im, data, save_figs=False, single_column=False):
    """ Plots the inliers, outliers and their respective reprojections for a single image.

    The first subplot shows the points of the tracks included in the reconstruction (blue circles)
    together with their reprojections (white crosses). The second subplot shows the tracks that are
    not in the reconstruction: those for which triangulation succeeded (magenta circles, with their
    reprojections as white crosses) and those for which it failed (red circles).

    Nothing is plotted (and no figure is created) when the image is not in any reconstruction or has
    no tracks; a message is printed instead.

    :param im: Image name.
    :param data: Data set.
    :param save_figs: Boolean specifying if figures should be saved to file.
    :param single_column: Boolean specifying if results should be plotted in a single figure.
    """

    reconstruction = find_reconstruction([im], data)
    if reconstruction is None:
        print('{0} does not exist in a reconstruction.'.format(im))
        return

    tracks_manager = data.load_tracks_manager()
    tracks = load_tracks(im, tracks_manager)
    if tracks.shape[0] < 1:
        print('No tracks exist for {0}.'.format(im))
        return

    im_array = data.load_image(im)
    p = load_points(im, data)

    # The figure is created only after the early exits above, so they do not leave an
    # empty figure open in pyplot.
    fig = pl.figure(figsize=(9, 15) if single_column else (18, 9))
    fig.suptitle('Reprojected reconstructed 3D points ({0}): {1}'.format(data.feature_type().upper(), im),
                 fontsize=14, fontweight='bold')
    fw = FigureWrapper(fig, 2 if single_column else 1, 1 if single_column else 2)

    # Image-space axis limits: x from 0 to width, y from height down to 0 (image orientation).
    x_lim = (0, im_array.shape[1])
    y_lim = (im_array.shape[0], 0)

    rec_tracks, non_rec_tracks = reconstruction_tracks(tracks, reconstruction)
    rec_track_points = p[rec_tracks[:, 1]]

    rp = reproject_tracks(im, rec_tracks[:, 0], reconstruction)

    rec_title = 'Tracks: {0}. Included tracks (loaded): {1} with reprojected points (calculated).'\
        .format(tracks.shape[0], rec_tracks.shape[0])

    rec_plot = create_subplot(fw.figure, fw.rows, fw.cols, 1, rec_title, x_lim, y_lim)
    rec_plot.imshow(im_array)
    plot_points(rec_plot, im_array, rec_track_points, 'ob')
    plot_points(rec_plot, im_array, rp, '+w')

    # Try to triangulate the tracks that are not in the reconstruction; `succeeded` is used
    # as a boolean mask over non_rec_tracks.
    min_ray_angle = data.config['triangulation_min_ray_angle']
    succeeded = triangulate_tracks(non_rec_tracks[:, 0], reconstruction, tracks_manager, min_ray_angle)
    reprojected = reproject_tracks(im, non_rec_tracks[succeeded, 0], reconstruction)

    triangulated_points = p[non_rec_tracks[succeeded, 1]]
    excluded_points = p[non_rec_tracks[~succeeded, 1]]

    non_rec_title = \
        'Rejected tracks: {0}. Re-triangulated: {1}. Failed: {2} (behind a camera or min ray angle: {3} deg).'\
        .format(non_rec_tracks.shape[0], succeeded.sum(), (~succeeded).sum(), min_ray_angle)

    non_rec_plot = create_subplot(fw.figure, fw.rows, fw.cols, 2, non_rec_title, x_lim, y_lim)
    non_rec_plot.imshow(im_array)
    plot_points(non_rec_plot, im_array, triangulated_points, 'om')
    plot_points(non_rec_plot, im_array, excluded_points, 'or')
    plot_points(non_rec_plot, im_array, reprojected, '+w')

    display_figure(fig, save_figs, data, '{0}_{1}_complete_rec.jpg'.format(im, data.feature_type()))


def _correlation_title(template, corr):
    """ Formats the correlation part of a subplot title.

    :param template: Format string with a single positional placeholder for the correlation.
    :param corr: Correlation value; may be None or NaN when it could not be computed.
    :return: The formatted text, or an empty string when the correlation is undefined.
    """
    # A NaN produced by a computation is generally not the np.nan object itself, so an
    # identity test cannot detect it; np.isnan checks the value instead.
    if corr is None or np.isnan(corr):
        return ''
    return template.format(corr)


def create_reprojection_error_figure(im, data, save_figs=False, single_column=False):
    """ Plots the reprojection error distribution.

    Two scatter plots of 2D reprojection errors (in pixels) are drawn: one for the tracks included in
    the reconstruction and one for the rejected tracks that could be re-triangulated. Circles mark the
    triangulation threshold (red) and the bundle outlier threshold (magenta), both multiplied by the
    largest image dimension.

    Nothing is plotted (and no figure is created) when the image is not in any reconstruction or has
    no tracks; a message is printed instead.

    :param im: Image name.
    :param data: Data set.
    :param save_figs: Boolean specifying if figures should be saved to file.
    :param single_column: Boolean specifying if results should be plotted in a single figure.
    """

    reconstruction = find_reconstruction([im], data)
    if reconstruction is None:
        print('{0} does not exist in a reconstruction.'.format(im))
        return

    tracks_manager = data.load_tracks_manager()
    tracks = load_tracks(im, tracks_manager)
    if tracks.shape[0] < 1:
        print('No tracks exist for {0}.'.format(im))
        return

    im_array = data.load_image(im)
    p = load_points(im, data)

    rec_tracks, non_rec_tracks = reconstruction_tracks(tracks, reconstruction)
    rec_points = p[rec_tracks[:, 1]]
    reproj_rec = reproject_tracks(im, rec_tracks[:, 0], reconstruction)

    # Re-triangulate the tracks that are not in the reconstruction; `succeeded` is used as a
    # boolean mask over non_rec_tracks.
    min_ray_angle = data.config['triangulation_min_ray_angle']
    succeeded = triangulate_tracks(non_rec_tracks[:, 0], reconstruction, tracks_manager, min_ray_angle)
    reproj_non = reproject_tracks(im, non_rec_tracks[succeeded, 0], reconstruction)
    triang_non = p[non_rec_tracks[succeeded, 1]]

    # The config thresholds are multiplied by the largest image dimension so they can be drawn
    # in the same units as the plotted errors.
    scale = np.max(im_array.shape[:2])
    rec_errors, rec_mean, rec_std, rec_min, rec_max, rec_corr = reprojection_errors(reproj_rec, rec_points, scale)
    non_errors, non_mean, non_std, non_min, non_max, non_corr = reprojection_errors(reproj_non, triang_non, scale)

    triang_thld = data.config['triangulation_threshold']
    bundle_thld = data.config['bundle_outlier_fixed_threshold']

    # Symmetric axes large enough for every error and both threshold circles. initial=0.0 keeps
    # np.max defined when there are no errors at all.
    axis_max = np.max(np.abs(np.vstack((rec_errors, non_errors))), initial=0.0)
    axis_max = 1.05 * np.max([axis_max, scale * triang_thld, scale * bundle_thld])
    limits = (-axis_max, axis_max)

    error_title = "Mean: ({0:.1f}, {1:.1f}). Std: ({2:.1f}, {3:.1f}). Max norm: {4:.1f}. Min norm: {5:.1f}."
    thld_title = 'Thresholds: Triangulation: {0:.1f} ({1:.2g}). Bundle outlier: {2:.1f} ({3:.2g}).'\
        .format(scale * triang_thld, triang_thld, scale * bundle_thld, bundle_thld)
    corr_title = ' Distance to principal point - Error norm correlation: {0:.3f}'

    # The figure is created only after the early exits above, so they do not leave an empty
    # figure open in pyplot.
    fig = pl.figure(figsize=(9, 18) if single_column else (18, 9))
    fig.suptitle('Reprojection error distribution in pixels ({0}): {1}'.format(data.feature_type().upper(), im),
                 fontsize=14, fontweight='bold')
    fw = FigureWrapper(fig, 2 if single_column else 1, 1 if single_column else 2)

    zero = np.array([0])

    # Only the correlation line is optional; the rest of the title is always shown.
    rec_title = ('Reconstructed tracks: {0}. '.format(rec_tracks.shape[0]) +
                 error_title.format(rec_mean[0], rec_mean[1], rec_std[0], rec_std[1], rec_max, rec_min) +
                 '\n' + thld_title + '\n' + _correlation_title(corr_title, rec_corr))

    rec_sub = create_subplot(fw.figure, fw.rows, fw.cols, 1, rec_title, limits, limits,
                             aspect='equal', axis=True, grid=True)
    plot_circle(rec_sub, scale * triang_thld, 'r')
    plot_circle(rec_sub, scale * bundle_thld, 'm')
    rec_sub.plot(zero, zero, '+k')
    rec_sub.plot(rec_errors[:, 0], rec_errors[:, 1], 'ob')

    non_rec_title = ('Failed tracks: {0} / {1}. '.format(np.sum(succeeded), non_rec_tracks.shape[0]) +
                     error_title.format(non_mean[0], non_mean[1], non_std[0], non_std[1], non_max, non_min) +
                     '\n' + _correlation_title(corr_title, non_corr))

    non_rec_sub = create_subplot(fw.figure, fw.rows, fw.cols, 2, non_rec_title, limits, limits,
                                 aspect='equal', axis=True, grid=True)
    plot_circle(non_rec_sub, scale * triang_thld, 'r')
    plot_circle(non_rec_sub, scale * bundle_thld, 'm')
    non_rec_sub.plot(zero, zero, '+k')
    non_rec_sub.plot(non_errors[:, 0], non_errors[:, 1], 'om')

    display_figure(fig, save_figs, data, '{0}_{1}_error_dist.jpg'.format(im, data.feature_type()),
                   bottom=0.05, top=0.90 if single_column else 0.85, hspace=0.22)


if __name__ == "__main__":
    # Command line: a data set path and two image names, plus boolean switches.
    parser = argparse.ArgumentParser(description='Plot inlier and outlier matches between images')
    parser.add_argument('dataset', help='path to the data set to be processed')
    parser.add_argument('image1', help='name of the first image to show')
    parser.add_argument('image2', help='name of the second image to show')
    for switch, description in (('--save_figs', 'save figures instead of showing them'),
                                ('--single_col', 'show figures in one column'),
                                ('--plot_rec', 'plots the reconstruction tracks for each image')):
        parser.add_argument(switch, help=description, action='store_true')

    args = parser.parse_args()
    ds = dataset.DataSet(args.dataset)
    image1, image2 = args.image1, args.image2
    save, single = args.save_figs, args.single_col

    # Figures for the image pair.
    print('Plotting matches for {0} - {1}...'.format(image1, image2))
    create_matches_figure(image1, image2, ds, save, single)

    print('Plotting tracks for {0} - {1}...'.format(image1, image2))
    create_tracks_figure(image1, image2, ds, save, single)

    print('Plotting reconstruction matches for {0} - {1}...'.format(image1, image2))
    create_reconstruction_matches_figure(image1, image2, ds, save)

    # Optional per-image reconstruction figures, first image then second image.
    if args.plot_rec:
        for image in (image1, image2):
            print('Plotting complete reconstruction tracks for {0}...'.format(image))
            create_complete_reconstruction_figure(image, ds, save, single)

            print('Plotting reprojection error for {0}...'.format(image))
            create_reprojection_error_figure(image, ds, save, single)