#!/usr/bin/python

# Version: 0.1
# Author: Ferdinand Mütsch <muetsch@kit.edu>

# Script to convert nuPlan maps to CommonRoad format.
# Currently only converts the lane and lane connector layers.
#
# Limitations / open TODOs:
# - Parse traffic lights / regulatory elements
# - Parse walkways
# - Parse crosswalks
# - Cannot convert directly to Lanelet2; CommonRoad is required as an intermediate step

# Setup (Python 3.9):
# pip install git+https://github.com/motional/nuplan-devkit.git  # install nuplan devkit
# pip install -r https://raw.githubusercontent.com/motional/nuplan-devkit/master/requirements.txt  # install devkit dependencies
# pip install commonroad-all commonroad-scenario-designer lanelet2

# Usage
# Step 1: nuPlan -> CommonRoad
# python nuplan_to_commonroad_map.py --map_root ~/data/nuplan/v1.1/dataset/maps
#
# Step 2: CommonRoad -> Lanelet2
# crdesigner --input-file map.xml --output-file map.osm crlanelet2 --proj "epsg:32619"

import argparse
import logging
from typing import List, cast, Tuple, Optional

import numpy as np
from commonroad.common.common_lanelet import LaneletType
from commonroad.common.file_writer import CommonRoadFileWriter
from commonroad.common.writer.file_writer_interface import OverwriteExistingFile
from commonroad.planning.planning_problem import PlanningProblemSet
from commonroad.scenario.lanelet import LaneletNetwork, Lanelet
from commonroad.scenario.scenario import Scenario
from lanelet2.core import Lanelet as L2Lanelet, LineString3d, getId, Point3d
from nuplan.common.maps.abstract_map_objects import LaneGraphEdgeMapObject
from nuplan.common.maps.maps_datatypes import SemanticMapLayer
from nuplan.common.maps.nuplan_map.lane import NuPlanLane
from nuplan.common.maps.nuplan_map.lane_connector import NuPlanLaneConnector
from nuplan.database.maps_db.gpkg_mapsdb import GPKGMapsDB
from nuplan.database.maps_db.map_api import NuPlanMapWrapper

# Named logger used for the script's progress messages.
logger = logging.getLogger(name='NuPlan2CommonRoadMapConverter')
# Configure the root logger at INFO so those progress messages are actually emitted.
logging.basicConfig(level='INFO')


class NuPlan2CommonRoadMapConverter:
    """Convert the lane and lane-connector layers of a nuPlan map into a CommonRoad scenario.

    Every nuPlan lane / lane connector becomes one CommonRoad ``Lanelet`` whose id is the
    nuPlan object id converted to ``int``. Left/right adjacency and successor/predecessor
    relations are copied from the nuPlan lane graph.
    """

    def __init__(self, map_root: str, map_name: str = 'us-ma-boston', map_version: str = 'nuplan-maps-v1.0', target_crs_epsg: Optional[int] = None):
        """Load all layers of a nuPlan map and optionally reproject them.

        :param map_root: nuPlan maps directory, passed to ``GPKGMapsDB``.
        :param map_name: name of the nuPlan map to load.
        :param map_version: nuPlan map version string, passed to ``GPKGMapsDB``.
        :param target_crs_epsg: optional EPSG code; when given, every vector layer is
            reprojected in place to that CRS. Numeric strings (e.g. ``'32619'``) are accepted.
        :raises ValueError: if ``target_crs_epsg`` is not an integer EPSG code.
        """
        self.nuplan_map: NuPlanMapWrapper = NuPlanMapWrapper(
            GPKGMapsDB(map_version, map_root),
            map_name
        )
        self.nuplan_map.initialize_all_layers()

        if target_crs_epsg is not None:
            # Validate up front so input like 'epsg:32619' fails here with a clear message
            # instead of deep inside the CRS transformation.
            try:
                epsg = int(target_crs_epsg)
            except (TypeError, ValueError) as err:
                raise ValueError(f'target_crs_epsg must be an integer EPSG code, got {target_crs_epsg!r}') from err
            # NOTE: reaches into the wrapper's private vector-map dict to reproject each layer.
            for layer in self.nuplan_map._vector_map.values():
                layer.to_crs(epsg=epsg, inplace=True)

    def get(self) -> 'Scenario':
        """Build a CommonRoad scenario with one lanelet per nuPlan lane / lane connector.

        Two passes are made: all lanelets are created first (so every referenced id
        exists), then adjacency and successor/predecessor relations are filled in.

        :return: a new ``Scenario`` whose lanelet network holds the converted lanes.
        """
        # Only lanelets are added to this scenario; dt presumably has no effect on the
        # exported map (TODO confirm).
        scenario: Scenario = Scenario(dt=0.05, author='Default Author', affiliation='None', source='None', tags=set())
        lanelets: LaneletNetwork = scenario.lanelet_network

        all_lanes = cast(List[NuPlanLane], self._get_map_objects(SemanticMapLayer.LANE))
        all_connectors = cast(List[NuPlanLaneConnector], self._get_map_objects(SemanticMapLayer.LANE_CONNECTOR))
        all_edges: List[LaneGraphEdgeMapObject] = [*all_lanes, *all_connectors]

        # pass 1: geometry of lanes and lane connectors
        for edge in all_edges:
            lanelets.add_lanelet(self._to_lanelet(edge))

        # pass 2: relations, which may reference any lanelet created in pass 1
        for edge in all_edges:
            lanelet: Lanelet = lanelets.find_lanelet_by_id(int(edge.id))
            if lanelet is None:
                continue
            self._add_relations(edge, lanelet)

        return scenario

    def _get_map_objects(self, layer: 'SemanticMapLayer') -> list:
        """Return the map object for every feature id ('fid') in the given vector layer."""
        fids = self.nuplan_map._get_vector_map_layer(layer)['fid'].tolist()
        return [self.nuplan_map.get_map_object(fid, layer) for fid in fids]

    def _to_lanelet(self, edge: 'LaneGraphEdgeMapObject') -> 'Lanelet':
        """Convert one nuPlan lane / lane connector into a CommonRoad lanelet (geometry only)."""
        left_coords = list(zip(*edge.left_boundary.linestring.xy))
        right_coords = list(zip(*edge.right_boundary.linestring.xy))

        # Left, right and center polylines are all resampled to the same vertex count.
        n_vertices = max(len(left_coords), len(right_coords))
        left = self.sample_2d(left_coords, n_vertices)
        right = self.sample_2d(right_coords, n_vertices)
        center = self.get_center_line(left, right)

        return Lanelet(
            lanelet_id=int(edge.id),
            lanelet_type={LaneletType.URBAN},
            left_vertices=np.array(left),
            right_vertices=np.array(right),
            center_vertices=np.array(center),
        )

    @staticmethod
    def _add_relations(edge: 'LaneGraphEdgeMapObject', lanelet: 'Lanelet') -> None:
        """Copy adjacency and successor/predecessor links of ``edge`` onto ``lanelet``.

        A neighbor is treated as same-direction when it has the same roadblock id as ``edge``.
        """
        neighbors = edge.adjacent_edges
        left_neighbor = neighbors[0]
        right_neighbor = neighbors[1]

        if left_neighbor is not None:
            lanelet.adj_left = int(left_neighbor.id)
            lanelet.adj_left_same_direction = edge.get_roadblock_id() == left_neighbor.get_roadblock_id()
        if right_neighbor is not None:
            lanelet.adj_right = int(right_neighbor.id)
            lanelet.adj_right_same_direction = edge.get_roadblock_id() == right_neighbor.get_roadblock_id()
        for successor in edge.outgoing_edges:
            lanelet.add_successor(int(successor.id))
        for predecessor in edge.incoming_edges:
            lanelet.add_predecessor(int(predecessor.id))

    @staticmethod
    def to_file(scenario: 'Scenario', out_file: str):
        """Write ``scenario`` (with an empty planning problem set) to ``out_file``, overwriting it if it exists."""
        CommonRoadFileWriter(scenario, PlanningProblemSet(None)).write_to_file(out_file, OverwriteExistingFile.ALWAYS)

    @classmethod
    def get_center_line(cls, left: List[Tuple[float, float]], right: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
        """Compute a centerline between two boundaries.

        A temporary Lanelet2 lanelet is built from the boundaries (z = 0). Its centerline
        is resampled to ``max(len(left), len(right))`` points.
        """
        # hacky and awkward way of getting a centerline between two boundaries
        l2_lanelet: L2Lanelet = L2Lanelet(
            getId(),
            LineString3d(getId(), [Point3d(getId(), x, y, 0) for x, y in left]),
            LineString3d(getId(), [Point3d(getId(), x, y, 0) for x, y in right]),
        )
        centerline = [(p.x, p.y) for p in l2_lanelet.centerline]
        return cls.sample_2d(centerline, max(len(left), len(right)))

    @staticmethod
    def sample_2d(coords: List[Tuple[float, float]], n: int) -> List[Tuple[float, float]]:
        """Resample a 2D polyline to exactly ``n`` points spaced evenly along its arc length.

        The first and last input points are kept. Sampling by arc length rather than by
        vertex index keeps the i-th points of two boundaries resampled to the same ``n``
        roughly opposite each other, even when their vertex densities differ.

        :param coords: polyline vertices as ``(x, y)`` pairs; must not be empty.
        :param n: number of output points.
        :return: ``n`` interpolated ``(x, y)`` tuples.
        :raises ValueError: if ``coords`` is empty or not a sequence of 2D points.
        """
        points = np.asarray(list(coords), dtype=float)
        if points.size == 0:
            raise ValueError('cannot resample an empty polyline')
        if points.ndim != 2 or points.shape[1] != 2:
            raise ValueError(f'expected a sequence of (x, y) pairs, got array of shape {points.shape}')

        # Drop consecutive duplicate vertices: zero-length segments would make the
        # arc-length parameter non-increasing, which np.interp does not support.
        seg_lengths = np.hypot(*np.diff(points, axis=0).T)
        points = points[np.concatenate(([True], seg_lengths > 0))]
        arc = np.concatenate(([0.0], np.cumsum(seg_lengths[seg_lengths > 0])))

        targets = np.linspace(0.0, arc[-1], n)
        x_new = np.interp(targets, arc, points[:, 0])
        y_new = np.interp(targets, arc, points[:, 1])
        return list(zip(x_new, y_new))


if __name__ == '__main__':
    # Command-line entry point: load a nuPlan map, convert it and write a CommonRoad XML file.
    parser = argparse.ArgumentParser(description='Convert a nuPlan map to CommonRoad format.')
    parser.add_argument('--map_root', type=str, required=True, help='Absolute path to your nuPlan maps/ directory (containing the .json file)')
    parser.add_argument('--map', type=str, required=False, default='us-ma-boston', help='nuPlan map to convert')
    parser.add_argument('--map_version', type=str, required=False, default='nuplan-maps-v1.0', help='nuPlan map version to use')
    # type=int matches the converter's Optional[int] parameter and lets argparse reject
    # non-numeric input such as 'epsg:32619' with a clear usage error.
    parser.add_argument('--crs', type=int, required=False, default=None, help='Optionally convert coordinates to the CRS defined by this numeric EPSG code (e.g. 32619)')
    parser.add_argument('--out', type=str, required=False, default='map.xml', help='Target file path')
    args = parser.parse_args()

    logger.info('Loading map ...')
    converter = NuPlan2CommonRoadMapConverter(
        args.map_root, args.map, args.map_version, args.crs
    )

    logger.info('Converting ...')
    NuPlan2CommonRoadMapConverter.to_file(converter.get(), args.out)
    logger.info('Wrote %s', args.out)