#!/usr/bin/env python3

import argparse
import matplotlib.pyplot as pl
import numpy as np

from opensfm import io
from opensfm.dataset import DataSet
from opensfm.large.metadataset import MetaDataSet


def set_axes_limits(samples):
    """Fit the current pyplot axes to ``samples`` with a 5% margin per side.

    Args:
        samples: array-like with one row per point. Column 1 drives the
            x-axis limits and column 0 the y-axis limits (the callers in
            this file pass ``[lat, lon]`` rows).

    An empty ``samples`` leaves the axes limits untouched, because there is
    no extent to fit and ``np.min`` would raise on a zero-size array.
    """
    # Use the argument itself; do not rely on a module-level ``positions``
    # variable happening to exist in the calling script.
    samples = np.asarray(samples)
    if samples.size == 0:
        return

    lower = np.min(samples, axis=0)
    upper = np.max(samples, axis=0)
    span_x = upper[1] - lower[1]
    span_y = upper[0] - lower[0]

    # Pad each side by 1/20 of the data extent so points do not sit on the
    # frame of the plot.
    pl.xlim(lower[1] - span_x / 20, upper[1] + span_x / 20)
    pl.ylim(lower[0] - span_y / 20, upper[0] + span_y / 20)


def set_gps_labels(fig):
    """Title ``fig`` and label the current pyplot axes for a GPS plot.

    Args:
        fig: figure that receives the 'GPS positions' super-title. The axis
            labels are applied through pyplot, i.e. to the current axes.
    """
    fig.suptitle('GPS positions', fontsize=18)

    axis_labels = ((pl.xlabel, 'Longitude'), (pl.ylabel, 'Latitude'))
    for apply_label, text in axis_labels:
        apply_label(text, fontsize=16)


def plot_clusters(ax, positions, labels, centers=None):
    """Scatter ``positions`` on ``ax``, colored by ``labels``.

    Args:
        ax: axes-like object providing ``plot`` and ``scatter``.
        positions: 2-D array with one row per point; column 1 is drawn on
            the x axis and column 0 on the y axis.
        labels: per-point values passed to ``scatter`` as the color argument.
        centers: optional 2-D array laid out like ``positions``; when given,
            each row is drawn as a black 'x' marker.
    """
    if centers is not None:
        center_x, center_y = centers[:, 1], centers[:, 0]
        ax.plot(center_x, center_y, 'xk')

    # Marker area shrinks linearly with the number of points and is clamped
    # at 2 so that very large datasets remain visible.
    point_count = positions.shape[0]
    scaled_size = 30 * (1 - point_count / 100000.)
    marker_size = np.max([2, scaled_size])

    point_x, point_y = positions[:, 1], positions[:, 0]
    ax.scatter(point_x, point_y, c=labels, lw=0, s=marker_size)


# Color and marker codes cycled through so that consecutive clusters get
# distinguishable matplotlib format strings.
_CLUSTER_COLORS = ['r', 'g', 'm', 'c', 'y', 'b']
_CLUSTER_MARKERS = ['o', '+', 'x', '*']


def _cluster_format(index):
    """Return the matplotlib format string (color + marker) for a cluster."""
    color = _CLUSTER_COLORS[index % len(_CLUSTER_COLORS)]
    marker = _CLUSTER_MARKERS[index % len(_CLUSTER_MARKERS)]
    return color + marker


def _gps_position_map(meta_data):
    """Map each image from ``meta_data.images_with_gps()`` to ``[lat, lon]``."""
    return {image: [lat, lon]
            for image, lat, lon in meta_data.images_with_gps()}


# Script entry point: parse the command line, draw the figure for the
# requested mode, then show it or save it under ``<dataset>/plot_gps``.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Plot gps location of a splitted dataset')
    parser.add_argument('dataset',
                        help='path to the dataset')
    parser.add_argument('-m', '--mode',
                        help='plot images clusters in different colors',
                        default='positions',
                        choices=['clusters', 'centers', 'neighbors', 'positions', 'hist', 'submodels'])
    parser.add_argument('-s', '--save-figs',
                        help='save figures instead of showing them',
                        action='store_true')
    args = parser.parse_args()

    meta_data = MetaDataSet(args.dataset)

    fig = pl.figure(figsize=(20, 15))

    if args.mode in ('clusters', 'centers'):
        # Both modes scatter the clustered images; only 'centers' also marks
        # the cluster centers.
        _, positions, labels, centers = meta_data.load_clusters()
        if args.mode == 'clusters':
            centers = None

        ax = fig.add_subplot(111)
        ax.set_aspect(1)
        plot_clusters(ax, positions, labels, centers)
        set_axes_limits(positions)
        set_gps_labels(fig)
    elif args.mode == 'neighbors':
        position_map = _gps_position_map(meta_data)

        ax = fig.add_subplot(111)
        ax.set_aspect(1)

        positions = []
        clusters = meta_data.load_clusters_with_neighbors()
        for index, cluster in enumerate(clusters):
            cluster_positions = [position_map[image] for image in cluster]
            if not cluster_positions:
                continue
            positions.extend(cluster_positions)

            # One plot call per cluster instead of one per image: the
            # markers are identical but far fewer artists are created.
            cluster_array = np.array(cluster_positions)
            ax.plot(cluster_array[:, 1], cluster_array[:, 0],
                    _cluster_format(index))

        positions = np.array(positions)

        set_axes_limits(positions)
        set_gps_labels(fig)
    elif args.mode == 'positions':
        positions = np.array(
            [[lat, lon] for _, lat, lon in meta_data.images_with_gps()])

        ax = fig.add_subplot(111)
        ax.set_aspect(1)
        ax.plot(positions[:, 1], positions[:, 0], 'or')
        set_axes_limits(positions)
        set_gps_labels(fig)
    elif args.mode == 'hist':
        # Compare cluster sizes before and after neighbors are added, using
        # the same bins and range for both histograms.
        labels = meta_data.load_clusters()[2]
        cl_count = np.bincount(labels)
        nb_count = [len(cluster)
                    for cluster in meta_data.load_clusters_with_neighbors()]

        range_max = np.max(np.append(cl_count, nb_count)) + 1
        num_bins = np.min([range_max, 50])

        fig.suptitle('Cluster sizes', fontsize=18)
        ax = fig.add_subplot(211)
        ax.hist(cl_count, bins=num_bins, range=(0, range_max))
        ax.set_title("K-means clusters", fontsize=18)
        ax.set_xlabel('Cluster size', fontsize=16)
        ax.set_ylabel('Number of clusters', fontsize=16)

        ax = fig.add_subplot(212)
        ax.hist(nb_count, bins=num_bins, range=(0, range_max), color='r')
        ax.set_title("Clusters with added neighbors", fontsize=18)
        ax.set_xlabel('Cluster size', fontsize=16)
        ax.set_ylabel('Number of clusters', fontsize=16)
    elif args.mode == 'submodels':
        position_map = _gps_position_map(meta_data)

        ax = fig.add_subplot(111)
        ax.set_aspect(1)

        positions = []
        cluster_plots = []
        submodel_paths = meta_data.get_submodel_paths()
        for index, submodel_path in enumerate(submodel_paths):
            data = DataSet(submodel_path)
            cluster_positions = [position_map[im] for im in data.images()]
            if not cluster_positions:
                # A submodel without images has no center to mark; taking
                # the mean of an empty array would fail below.
                continue
            positions.extend(cluster_positions)

            cluster_array = np.array(cluster_positions)
            cluster_plots.append((cluster_array, _cluster_format(index)))

            cluster_center = np.mean(cluster_array, axis=0)
            ax.plot(cluster_center[1], cluster_center[0], 'xk')
            ax.annotate(submodel_path, (cluster_center[1], cluster_center[0]))

        positions = np.array(positions)

        # Image markers are added after every center so that they are drawn
        # in the same order as before; one plot call per submodel.
        for cluster_array, fmt in cluster_plots:
            ax.plot(cluster_array[:, 1], cluster_array[:, 0], fmt)

        set_axes_limits(positions)
        set_gps_labels(fig)

    if args.save_figs:
        plot_path = '{}/plot_gps'.format(args.dataset)
        io.mkdir_p(plot_path)
        pl.savefig('{}/{}.png'.format(plot_path, args.mode), dpi=100)
        pl.close()
    else:
        pl.show()
