#!/usr/bin/env python3

import argparse
import os.path

import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Wedge

from opensfm import dataset
from opensfm import features
from opensfm import io


def plot_features(image, points):
    """Draw detected features on top of an image in the current axes.

    The image is shown with ``imshow`` and every feature is overlaid as a
    semi-transparent ``Wedge`` centred on the feature position. Each wedge
    sweeps from ``angle + 1`` to ``angle - 1`` degrees, i.e. almost the whole
    circle, leaving a narrow gap at the feature angle.

    Args:
        image: array whose first two dimensions are (height, width). Both
            single-channel (2-D) and multi-channel (3-D) arrays are accepted.
        points: 2-D array with one row per feature. Column 2 is read as the
            feature size and scaled by ``max(width, height)``; column 3 is
            read as the feature angle. The array is also passed to
            ``features.denormalized_image_coordinates`` to obtain the wedge
            centres (presumably pixel coordinates; confirm against that
            helper).
    """
    # Only the first two dimensions are needed; slicing keeps this working
    # for images without a channel axis.
    h, w = image.shape[:2]
    pixels = features.denormalized_image_coordinates(points, w, h)
    sizes = points[:, 2] * max(w, h)
    angles = points[:, 3]

    # Reuse the current axes (creating one only if the figure has none) so
    # that a title already set by the caller stays on the axes we draw into,
    # instead of stacking a second axes over it.
    ax = plt.gca()
    ax.imshow(image)

    patches = [
        Wedge(center, radius, angle + 1, angle - 1)
        for center, radius, angle in zip(pixels, sizes, angles)
    ]

    ax.add_collection(PatchCollection(patches, alpha=0.4))


if __name__ == "__main__":
    # Command-line entry point: plot the detected features of a single image
    # (--image) or of every image in the dataset, and either show each figure
    # or save it when --save_figs is given.
    parser = argparse.ArgumentParser(description='Plot detected features')
    parser.add_argument('dataset',
                        help='path to the dataset to be processed')
    parser.add_argument('--image',
                        help='name of the image to show')
    parser.add_argument('--save_figs',
                        help='save figures instead of showing them',
                        action='store_true')
    args = parser.parse_args()

    data = dataset.DataSet(args.dataset)

    # Saved figures are written to <dataset>/plot_features/<image>.jpg.
    output_dir = os.path.join(args.dataset, 'plot_features')

    images = [args.image] if args.image else data.images()
    for image in images:
        features_data = data.load_features(image)
        if not features_data:
            # No features could be loaded for this image; skip it.
            continue
        points = features_data.points
        print("plotting {0} points".format(len(points)))
        plt.figure()
        plt.title('Image: ' + image + ', features: ' + str(len(points)))
        # plot_features draws into the current figure and returns nothing,
        # so there is no result to keep.
        plot_features(data.load_image(image), points)

        if args.save_figs:
            # Created lazily so the directory only appears when there is
            # at least one figure to save.
            io.mkdir_p(output_dir)
            plt.savefig(os.path.join(output_dir, image + '.jpg'))
            # Release the figure once it is on disk.
            plt.close()
        else:
            plt.show()
