#!/usr/bin/env python3
import argparse
import json
import logging
import os.path
import sys

# Make the repository root importable before pulling in the opensfm package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import numpy as np

import opensfm.dataset as dataset
import opensfm.io as io
from opensfm import pygeometry, types

logger = logging.getLogger(__name__)


def import_bundler(
    data_path, bundle_file, list_file, track_file, reconstruction_file=None
):
    """Build a reconstruction and a tracks file from Bundler's output.

    The image list is always copied to ``<data_path>/image_list.txt`` (paths
    rewritten relative to ``data_path``).  When a bundle file is available,
    its cameras and points are converted into a reconstruction, the point
    observations are written to ``track_file`` and, optionally, the
    reconstruction is serialized to ``reconstruction_file``.

    Args:
        data_path: OpenSfM working folder; created if missing.
        bundle_file: path to Bundler's ``bundle.out``; may be None or missing.
        list_file: path to Bundler's ``list.txt``; the first whitespace
            separated token of every line is taken as an image path relative
            to the directory containing ``list_file``.
        track_file: destination path for the tab-separated tracks file.
        reconstruction_file: destination path for the JSON reconstruction, or
            None to skip writing it.

    Returns:
        The reconstruction built from the bundle file, or None when
        ``bundle_file`` is empty/None or does not point to an existing file.

    Raises:
        ValueError: if the bundle file has no header line, or declares more
            cameras than there are images in ``list_file``.
    """

    # Init OpenSfM working folder.
    io.mkdir_p(data_path)

    # Copy image list.
    list_dir = os.path.dirname(list_file)
    with io.open_rt(list_file) as fin:
        list_lines = fin.read().splitlines()
    ordered_shots = []
    image_list = []
    for line in list_lines:
        fields = line.split()
        if not fields:
            # Blank lines carry no image; indexing them would raise.
            continue
        image_path = os.path.join(list_dir, fields[0])
        image_list.append(os.path.relpath(image_path, data_path))
        ordered_shots.append(os.path.basename(image_path))
    with io.open_wt(os.path.join(data_path, "image_list.txt")) as fout:
        fout.write("\n".join(image_list) + "\n")

    # Check for bundle_file
    if not bundle_file or not os.path.isfile(bundle_file):
        return None

    with io.open_rt(bundle_file) as fin:
        lines = fin.readlines()
    # An optional first line containing "#" is treated as a comment.
    offset = 1 if lines and "#" in lines[0] else 0
    if len(lines) <= offset:
        raise ValueError(
            "bundle file {} has no header line".format(bundle_file)
        )

    # header: "<number of cameras> <number of points>"
    # split() (no argument) tolerates repeated/trailing whitespace.
    num_shot, num_point = map(int, lines[offset].split())
    offset += 1

    if num_shot > len(ordered_shots):
        raise ValueError(
            "bundle file lists {} cameras but the image list has only {} "
            "entries".format(num_shot, len(ordered_shots))
        )

    # initialization
    reconstruction = types.Reconstruction()

    # cameras: each one occupies 5 lines (intrinsics, 3 rotation rows,
    # translation); camera i corresponds to the i-th listed image.
    for i in range(num_shot):
        # Creating a model for each shot.
        shot_key = ordered_shots[i]
        focal, k1, k2 = map(float, lines[offset].split())

        if focal > 0:
            # The image is read only to learn its size, which is needed to
            # normalize the focal length.
            im = io.imread(os.path.join(data_path, image_list[i]))
            height, width = im.shape[0:2]
            camera = pygeometry.Camera.create_perspective(
                focal / max(width, height), k1, k2
            )
            camera.id = "camera_" + str(i)
            camera.width = width
            camera.height = height
            reconstruction.add_camera(camera)

            # Shots
            rotation_values = []
            for k in range(3):
                rotation_values.extend(lines[offset + 1 + k].split())
            R = np.array([float(v) for v in rotation_values]).reshape(3, 3)
            t = np.array([float(v) for v in lines[offset + 4].split()])
            # Reverse y and z (rows 1 and 2 of R, components 1 and 2 of t).
            R[1:3] *= -1
            t[1:3] *= -1
            pose = pygeometry.Pose()
            pose.set_rotation_matrix(R)
            pose.translation = t

            reconstruction.create_shot(shot_key, camera.id, pose)
        else:
            # Non-positive focal: no camera/shot is created for this image.
            logger.warning("ignoring failed image {}".format(shot_key))
        offset += 5

    # tracks: each point occupies 3 lines (position, color, view list).
    track_lines = []
    for i in range(num_point):
        coordinates = lines[offset].split()
        color = lines[offset + 1].split()
        point = reconstruction.create_point(i, list(map(float, coordinates)))
        point.color = list(map(int, color))

        view_line = lines[offset + 2].split()

        # First token is the view count, followed by 4 tokens per view:
        # camera index, feature id, x, y.
        num_view, view_list = int(view_line[0]), view_line[1:]

        for k in range(num_view):
            shot_key = ordered_shots[int(view_list[4 * k])]
            # Observations in images that were skipped above are dropped.
            if shot_key in reconstruction.shots:
                camera = reconstruction.shots[shot_key].camera
                scale = max(camera.width, camera.height)
                # Columns: shot, track id, feature id, x, y, r, g, b.
                # x/y are divided by the larger image side; y is negated.
                v = "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format(
                    shot_key,
                    i,
                    view_list[4 * k + 1],
                    float(view_list[4 * k + 2]) / scale,
                    -float(view_list[4 * k + 3]) / scale,
                    point.color[0],
                    point.color[1],
                    point.color[2],
                )
                track_lines.append(v)
        offset += 3

    # save track file
    with io.open_wt(track_file) as fout:
        fout.write("\n".join(track_lines))

    # save reconstruction
    if reconstruction_file is not None:
        with io.open_wt(reconstruction_file) as fout:
            obj = io.reconstructions_to_json([reconstruction])
            io.json_dump(obj, fout)
    return reconstruction


def main():
    """Parse the command line and convert Bundler output for one dataset."""
    parser = argparse.ArgumentParser(description='Convert output from Bundler to OpenSfM reconstructions')
    parser.add_argument('dataset', help='path to the dataset to be processed')
    parser.add_argument('--list', help='the list.txt file')
    parser.add_argument('--bundleout', help='the bundle.out file')
    args = parser.parse_args()

    # Fall back to list.txt in the current working directory.
    list_file = args.list or 'list.txt'

    if not args.bundleout:
        print("No bundle.out given.  Importing only the list of images list.txt.")

    if os.path.exists(list_file):
        print('Converting output from Bundler to OpenSfM')
        data = dataset.DataSet(args.dataset)
        # NOTE(review): these reach into name-mangled private DataSet
        # methods; confirm they exist in the opensfm version in use.
        tracks_file = data._DataSet__tracks_graph_file()
        reconstruction_file = data._DataSet__reconstruction_file(None)
        import_bundler(args.dataset, args.bundleout, list_file, tracks_file, reconstruction_file)
    else:
        print('Cannot find list.txt')


# The guard sits after the definitions so import_bundler exists when called.
if __name__ == "__main__":
    main()
