#!/usr/bin/env python3

import argparse
import os
import matplotlib.pyplot as pl
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from textwrap import wrap
from matplotlib import cm

from opensfm import dataset
from opensfm import io


def small_colorbar(ax, mappable=None):
    """Draw a narrow colorbar in a new axes appended to the right of ``ax``.

    The extra axes is created with ``size="5%"`` and ``pad=0.05``.
    ``mappable`` is forwarded unchanged to ``pl.colorbar``.
    """
    colorbar_axes = make_axes_locatable(ax).append_axes(
        "right", size="5%", pad=0.05
    )
    pl.colorbar(mappable=mappable, cax=colorbar_axes)


def depth_colormap(d, cmap=None, invalid_val=0, invalid_color=(.5, .5, .5)):
    """Color-map a depth image and paint invalid pixels with a flat color.

    Args:
        d: 2D array of depth values (it is indexed as rows x columns).
        cmap: colormap name or instance handed to ``ScalarMappable``;
            ``None`` lets matplotlib pick its default colormap.
        invalid_val: value of ``d`` that marks a pixel as invalid.
        invalid_color: RGB triple assigned to every invalid pixel.

    Returns:
        A tuple ``(rgb, sm)`` where ``rgb`` is the color-mapped image with
        the alpha channel dropped and ``sm`` is the ``ScalarMappable`` used
        to build it (useful as the ``mappable`` of a colorbar).
    """
    # Pass the colormap spec straight to ScalarMappable, which resolves names
    # and None itself.  The former cm.get_cmap() helper is deprecated and has
    # been removed from recent matplotlib releases.
    sm = cm.ScalarMappable(cmap=cmap)
    sm.set_array(d)
    # to_rgba yields RGBA per pixel; keep only the RGB channels.
    rgb = sm.to_rgba(d)[:, :, :3]
    rgb[d == invalid_val] = invalid_color
    return rgb, sm


def plot_depthmap(im, title, depth, clean, plane, score, nghbr):
    """Fill the current figure with a 2x3 grid describing one depthmap.

    Panels, in subplot order: the image itself (headed by ``title``), the
    color-mapped raw depth, the color-mapped clean depth, the score map,
    the neighbor map and the color-coded plane normals.  Every panel except
    the first and the last gets a small colorbar.
    """

    def open_panel(position, heading, title_y):
        # Make one grid cell current and lift its title to ``title_y``.
        panel = pl.subplot(2, 3, position)
        panel.set_title(heading).set_y(title_y)
        return panel

    open_panel(1, title, 1.05)
    pl.imshow(im)

    # Depth panels go through depth_colormap so invalid pixels are greyed out
    # and the colorbar is tied to the mappable that produced the colors.
    for position, heading, values in ((2, "Depth", depth),
                                      (3, "Clean depth", clean)):
        panel = open_panel(position, heading, 1.08)
        colored, scalar_map = depth_colormap(values)
        pl.imshow(colored)
        small_colorbar(panel, mappable=scalar_map)

    # Score and neighbor maps are shown as-is; the colorbar is created
    # without an explicit mappable.
    for position, heading, values in ((4, "Score", score),
                                      (5, "Neighbor", nghbr)):
        panel = open_panel(position, heading, 1.08)
        pl.imshow(values)
        small_colorbar(panel)

    open_panel(6, "Plane normal", 1.02)
    pl.imshow(color_plane_normals(plane))


def color_plane_normals(plane, legend_size=100):
    """Encode per-pixel plane vectors as an RGB image with a legend.

    Each vector of ``plane`` is normalized to unit length and every
    component ``c`` is mapped to ``(1 - c) * 128``, clipped to the uint8
    range.  The top-left corner of the result is overwritten with a legend:
    a disc showing the colors of the unit normals
    ``(x, y, -sqrt(1 - x**2 - y**2))`` for ``x, y`` in ``[-1, 1]``.

    Args:
        plane: array of shape (height, width, 3); it is not modified.
        legend_size: side length in pixels of the legend square.  The legend
            is cropped when the image is smaller than this.

    Returns:
        uint8 array with the same shape as ``plane``.  Pixels whose vector
        has zero (or NaN) length are rendered black.
    """
    plane = np.asarray(plane)
    norm = np.linalg.norm(plane, axis=2, keepdims=True)

    # Normalize only where the length is usable.  The remaining entries keep
    # the fill value 1, which maps to color 0 (black) below, instead of
    # producing NaNs whose conversion to uint8 is undefined.
    normal = np.divide(
        plane,
        norm,
        out=np.ones(plane.shape, dtype=norm.dtype),
        where=norm > 0,
    )

    # Build the legend on a square grid and crop it to the image, so images
    # smaller than the legend no longer fail with a broadcasting error.
    coords = np.linspace(-1, 1, legend_size)
    x, y = np.meshgrid(coords, coords)
    r = x**2 + y**2
    z = -np.sqrt(np.clip(1 - r, 0, 1))
    rows = min(legend_size, normal.shape[0])
    cols = min(legend_size, normal.shape[1])
    legend = np.stack([x, y, z], axis=-1)[:rows, :cols]
    inside = (r < 1)[:rows, :cols]
    corner = normal[:rows, :cols]  # view into ``normal``
    corner[inside] = legend[inside]

    # A component of exactly -1 maps to 256, which does not fit in uint8;
    # clip so it saturates at 255 instead of wrapping around.
    return np.clip((1 - normal) * 128, 0, 255).astype(np.uint8)


if __name__ == "__main__":
    # Command-line entry point: plot (or save) one figure per image that has
    # a raw depthmap in the undistorted dataset.
    parser = argparse.ArgumentParser(
        description='Plot the raw and clean depthmaps of a dataset')
    parser.add_argument('dataset',
                        help='path to the dataset to be processed')
    parser.add_argument('--image',
                        help='name of the image to show')
    parser.add_argument('--subfolder',
                        help='undistorted subfolder',
                        default='undistorted')
    parser.add_argument('--save-figs',
                        help='save figures instead of showing them',
                        action='store_true')
    args = parser.parse_args()

    data = dataset.DataSet(args.dataset)
    udata_path = os.path.join(data.data_path, args.subfolder)
    udata = dataset.UndistortedDataSet(data, udata_path, io_handler=data.io_handler)

    if args.image:
        images = [args.image]
    else:
        # No image requested: collect the shots of every undistorted
        # reconstruction.
        reconstructions = udata.load_undistorted_reconstruction()
        images = []
        for reconstruction in reconstructions:
            images.extend(reconstruction.shots.keys())

    for image in images:
        # Images without a raw depthmap are silently skipped.
        if not udata.raw_depthmap_exists(image):
            continue

        depth, plane, score, nghbr, nghbrs = udata.load_raw_depthmap(image)
        # Only the clean depth is plotted; the other two values are unused.
        clean_depth, _clean_plane, _clean_score = udata.load_clean_depthmap(image)

        print("Plotting depthmap for {0}".format(image))
        # Drop figures left over from previous iterations before creating a
        # new one.
        pl.close("all")
        pl.figure(figsize=(30, 16), dpi=90, facecolor='w', edgecolor='k')
        neighbors = "\n".join(wrap("Neighbors: " + ', '.join(nghbrs), 80))
        title = "Shot: " + image + "\n" + neighbors
        plot_depthmap(udata.load_undistorted_image(image), title, depth, clean_depth, plane, score, nghbr)
        pl.tight_layout()

        if args.save_figs:
            # Figures go to <undistorted data path>/plot_depthmaps/<image>.png
            output_dir = os.path.join(udata.data_path, 'plot_depthmaps')
            io.mkdir_p(output_dir)
            pl.savefig(os.path.join(output_dir, image + '.png'))
            pl.close()
        else:
            pl.show()
