#!/usr/bin/env python3

import argparse
from itertools import combinations

import matplotlib.pyplot as pl
import networkx as nx
import numpy as np

from opensfm import dataset
from opensfm import features
from opensfm import io
from opensfm import tracking


def plot_matches(im1, im2, p1, p2):
    """Draw two images side by side with their matched points connected.

    The images are pasted onto a single 3-channel canvas (``im1`` on the
    left, ``im2`` on the right, zero-filled below the shorter one) and drawn
    on the current matplotlib figure. Each pair of corresponding points is
    joined by a cyan line and every point is marked with a blue dot.

    Args:
        im1: left image, an array whose shape unpacks as (height, width,
            channels).
        im2: right image, same layout as ``im1``.
        p1: points in ``im1``, one row per match; converted with
            ``features.denormalized_image_coordinates`` before plotting
            (presumably normalized image coordinates -- confirm with caller).
        p2: points in ``im2``, row-aligned with ``p1``.
    """
    height1, width1, _ = im1.shape
    height2, width2, _ = im2.shape

    # Canvas tall enough for the taller image and wide enough for both.
    canvas = np.zeros(
        (max(height1, height2), width1 + width2, 3), dtype=im1.dtype)
    canvas[:height1, :width1, :] = im1
    canvas[:height2, width1:width1 + width2, :] = im2

    pixels1 = features.denormalized_image_coordinates(p1, width1, height1)
    pixels2 = features.denormalized_image_coordinates(p2, width2, height2)

    pl.imshow(canvas)

    # Points of the right image are shifted by the width of the left image
    # so that they land on the right half of the canvas.
    for left, right in zip(pixels1, pixels2):
        xs = [left[0], right[0] + width1]
        ys = [left[1], right[1]]
        pl.plot(xs, ys, 'c')

    pl.plot(pixels1[:, 0], pixels1[:, 1], 'ob')
    pl.plot(pixels2[:, 0] + width1, pixels2[:, 1], 'ob')


if __name__ == "__main__":
    # Command-line entry point: either draw the image graph of a dataset or
    # plot the tracks shared by pairs of its images.
    parser = argparse.ArgumentParser(description='Plot tracks')
    parser.add_argument('dataset',
                        help='path to the dataset to be processed')
    parser.add_argument('--image',
                        help='show tracks for a specific image')
    parser.add_argument('--images',
                        help='show tracks between a subset of images (separated by commas)')
    parser.add_argument('--graph',
                        help='display image graph',
                        action='store_true')
    parser.add_argument('--save_figs',
                        help='save figures instead of showing them',
                        action='store_true')
    args = parser.parse_args()

    data = dataset.DataSet(args.dataset)
    tracks_manager = data.load_tracks_manager()
    images = tracks_manager.get_shot_ids()

    if args.graph:
        # The weighted image graph is only needed for this mode, so build it
        # here rather than unconditionally.
        image_graph = tracking.as_weighted_graph(tracks_manager)
        weights = [attrs['weight']
                   for _, _, attrs in image_graph.edges(data=True)]
        draw_options = dict(
            edge_color=weights,
            edge_cmap=pl.get_cmap('Blues'), edge_vmin=0, edge_vmax=200)

        if hasattr(nx, 'draw_graphviz'):
            # Legacy networkx still ships the convenience wrapper.
            nx.draw_graphviz(image_graph, **draw_options)
        else:
            # Newer networkx releases no longer provide draw_graphviz;
            # compute the Graphviz layout explicitly and hand it to nx.draw.
            positions = nx.nx_agraph.graphviz_layout(image_graph)
            nx.draw(image_graph, positions, **draw_options)
        pl.axis('off')
        pl.show()
    else:
        # Plot matches between images
        if args.image:
            pairs = [(args.image, other)
                     for other in images if other != args.image]
        elif args.images:
            # Tolerate spaces around the commas and ignore empty entries
            # (e.g. "a.jpg, b.jpg," ).
            subset = [name.strip()
                      for name in args.images.split(',') if name.strip()]
            pairs = combinations(subset, 2)
        else:
            pairs = combinations(images, 2)

        min_common_tracks = 10   # pairs sharing fewer tracks are skipped
        figures_per_batch = 10   # figures opened before each blocking show()
        output_dir = args.dataset + '/plot_tracks'

        pending_figures = 0      # figures created but not yet shown
        for im1, im2 in pairs:
            t, p1, p2 = tracking.common_tracks(tracks_manager, im1, im2)
            if len(t) < min_common_tracks:
                continue

            pl.figure(figsize=(20, 10))
            pl.title('Images: ' + im1 + ' - ' + im2 +
                     ', matches: ' + str(len(t)))
            plot_matches(data.load_image(im1),
                         data.load_image(im2), p1, p2)

            if args.save_figs:
                # Created on first use so that nothing is written when no
                # pair passes the threshold.
                io.mkdir_p(output_dir)
                pl.savefig(output_dir + '/' + im1 + '_' + im2 + '.jpg',
                           dpi=100)
                # Release the figure right away; it is never displayed.
                pl.close()
            else:
                pending_figures += 1
                if pending_figures >= figures_per_batch:
                    pending_figures = 0
                    pl.show()

        # Show whatever is left from the last, incomplete batch.
        if pending_figures > 0:
            pl.show()
