#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
gps_track_speed.py (CSV + GPX 지원)

GUI 호환 제공 함수:
- load_track(in_path, return_full=False|True)   # CSV/GPX 자동 인식
- load_csv(in_csv, return_full=False|True)
- load_gpx(in_gpx, return_full=False|True)
- save_folium_map(lat, lon, spd_kmh, out_html, ...)
- maybe_save_png(lat, lon, spd_kmh, out_png, ...)
- save_gpx(lat, lon, spd_kmh, out_gpx, ...)

기능:
- 트랙(선) 속도색 표시 (0=파랑, 고속=빨강, vmax=검정)
- 포인트 클릭 시 속도 팝업
- 포인트 옆에 속도 숫자 라벨 '항상 표시' (흰색 반투명 박스)
"""

from __future__ import annotations

import os
import re
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd


# -----------------------------
# CSV loader
# -----------------------------
def load_csv(in_csv: str, return_full: bool = False):
    """Load a GPS track from a CSV file.

    Column names are matched exactly first, then case-insensitively with
    surrounding whitespace ignored (e.g. ``lat``/``LATITUDE`` both work).
    Latitude, longitude and speed columns are required; altitude, time and
    RTK columns are optional.

    Rows whose latitude, longitude or speed is not a finite number are
    dropped. The optional columns are filtered with the same row mask, so
    every returned list has the same length.

    Args:
        in_csv: Path to the CSV file (read with ``pandas.read_csv``).
        return_full: When True, also return the optional columns.

    Returns:
        ``(lat, lon, spd)`` as lists of floats, or, when ``return_full`` is
        True, ``(lat, lon, spd, ele, time, rtk)``. Each optional entry is
        ``None`` when its column is missing or cannot be converted
        (``ele`` -> float, ``time`` -> str, ``rtk`` -> int).

    Raises:
        ValueError: If the latitude/longitude/speed columns cannot be found,
            or (from pandas) if their values cannot be converted to float.
    """
    df = pd.read_csv(in_csv)

    cols_lower = {str(c).lower().strip(): c for c in df.columns}

    def pick(*cands: str) -> Optional[str]:
        # Exact header match wins over a case-insensitive one.
        for c in cands:
            if c in df.columns:
                return c
        for c in cands:
            k = c.lower().strip()
            if k in cols_lower:
                return cols_lower[k]
        return None

    col_lat = pick("Latitude", "Lat")
    col_lon = pick("Longitude", "Lon", "Lng", "Long")
    col_spd = pick("Speed(km/h)", "Speed", "Speed_kmh", "Speed(kmh)", "Speed km/h")

    if not col_lat or not col_lon or not col_spd:
        raise ValueError(
            "CSV 헤더에서 위도/경도/속도 컬럼을 찾지 못했습니다.\n"
            f"현재 컬럼: {list(df.columns)}\n"
            "필수: Latitude, Longitude, Speed(km/h) (또는 Lat/Lon/Speed 등)"
        )

    col_ele = pick("Altitude", "Alt", "Ele", "Elevation")
    col_time = pick("Time", "Timestamp", "Datetime", "Date_Time", "DateTime")
    col_rtk = pick("RTK", "Fix", "Status")

    lat = df[col_lat].astype(float).to_numpy()
    lon = df[col_lon].astype(float).to_numpy()
    spd = df[col_spd].astype(float).to_numpy()

    # Blank cells are read as NaN; drop those rows everywhere.
    ok = np.isfinite(lat) & np.isfinite(lon) & np.isfinite(spd)
    lat, lon, spd = lat[ok], lon[ok], spd[ok]

    def opt(colname: Optional[str]):
        """Return the optional column filtered by the same row mask, or None."""
        if not colname:
            return None
        return df[colname].to_numpy()[ok]

    ele = opt(col_ele)
    if ele is not None:
        try:
            ele = ele.astype(float)
        except (TypeError, ValueError):
            ele = None

    t = opt(col_time)
    if t is not None:
        t = t.astype(str)

    rtk = opt(col_rtk)
    if rtk is not None:
        if rtk.dtype.kind == "f" and not np.isfinite(rtk).all():
            # A numeric column with blank cells is float64 with NaN. numpy's
            # NaN->int cast does not raise; it yields garbage integers.
            # Treat the column as unconvertible, like any other bad value.
            rtk = None
        else:
            try:
                rtk = rtk.astype(int)
            except (TypeError, ValueError, OverflowError):
                rtk = None

    if return_full:
        return (
            lat.tolist(),
            lon.tolist(),
            spd.tolist(),
            (ele.tolist() if ele is not None else None),
            (t.tolist() if t is not None else None),
            (rtk.tolist() if rtk is not None else None),
        )

    return lat.tolist(), lon.tolist(), spd.tolist()


# -----------------------------
# GPX loader
# -----------------------------
def load_gpx(in_gpx: str, return_full: bool = False):
    """Read track points (``trkpt``) from a GPX file.

    XML namespaces are ignored, so GPX 1.0, GPX 1.1 and files produced by
    ``save_gpx`` are all accepted.

    Per point:
    - lat/lon: ``trkpt`` attributes. Points with missing, unparsable or
      non-finite (NaN/inf) coordinates are skipped.
    - ele: ``<ele>`` child, if present.
    - time: ``<time>`` child text, if present.
    - speed (km/h): ``<speed_kmh>`` inside ``<extensions>`` if present.
      Otherwise a ``<speed>`` element (a GPX 1.0 ``trkpt`` child, or one
      inside ``<extensions>`` such as Garmin TrackPointExtension) is read as
      metres per second and converted to km/h. Otherwise 0.

    Args:
        in_gpx: Path to the GPX file.
        return_full: When True, also return elevation/time/rtk.

    Returns:
        ``(lat, lon, spd)`` lists, or ``(lat, lon, spd, ele, time, None)``
        when ``return_full`` is True. ``ele``/``time`` are ``None`` unless
        every kept point had that value. RTK is never available from GPX.

    Raises:
        ValueError: If fewer than 2 usable track points are found.
    """
    import math
    import xml.etree.ElementTree as ET

    def strip_ns(tag: str) -> str:
        return tag.split("}")[-1] if "}" in tag else tag

    def parse_float(text: Optional[str]) -> Optional[float]:
        """Parse stripped element text as float; None when absent or invalid."""
        try:
            return float((text or "").strip())
        except ValueError:
            return None

    root = ET.parse(in_gpx).getroot()

    lat_list: List[float] = []
    lon_list: List[float] = []
    spd_list: List[float] = []
    ele_list: List[float] = []
    time_list: List[str] = []

    # Collect every trkpt regardless of namespace.
    for trkpt in root.iter():
        if strip_ns(trkpt.tag) != "trkpt":
            continue

        lat = parse_float(trkpt.attrib.get("lat"))
        lon = parse_float(trkpt.attrib.get("lon"))
        if lat is None or lon is None or not (math.isfinite(lat) and math.isfinite(lon)):
            continue

        ele = None
        t = None
        spd_kmh = None
        spd_ms = None

        for ch in trkpt:
            name = strip_ns(ch.tag)
            if name == "ele":
                ele = parse_float(ch.text)
            elif name == "time":
                t = (ch.text or "").strip()
            elif name == "speed":
                spd_ms = parse_float(ch.text)  # GPX 1.0: metres per second
            elif name == "extensions":
                for ex in ch.iter():
                    ex_name = strip_ns(ex.tag)
                    if ex_name == "speed_kmh":
                        spd_kmh = parse_float(ex.text)
                    elif ex_name == "speed":
                        spd_ms = parse_float(ex.text)  # e.g. Garmin TPX, m/s

        if spd_kmh is None and spd_ms is not None:
            spd_kmh = spd_ms * 3.6
        if spd_kmh is None:
            spd_kmh = 0.0  # no speed information in the file for this point

        lat_list.append(lat)
        lon_list.append(lon)
        spd_list.append(spd_kmh)
        if ele is not None:
            ele_list.append(ele)
        if t is not None:
            time_list.append(t)

    if len(lat_list) < 2:
        raise ValueError("GPX에서 트랙포인트(trkpt)를 충분히 찾지 못했습니다. (최소 2개 필요)")

    if return_full:
        return (
            lat_list,
            lon_list,
            spd_list,
            # Only return optional series when they align 1:1 with the points.
            (ele_list if len(ele_list) == len(lat_list) else None),
            (time_list if len(time_list) == len(lat_list) else None),
            None,  # RTK information is not carried in GPX
        )

    return lat_list, lon_list, spd_list


# -----------------------------
# CSV/GPX 자동 로더
# -----------------------------
def load_track(in_path: str, return_full: bool = False):
    """Load a track file, choosing the parser from its extension.

    A ``.gpx`` extension (any letter case) is parsed by ``load_gpx``. Any
    other extension, including none, is treated as CSV and parsed by
    ``load_csv``. ``return_full`` is forwarded unchanged, and the chosen
    loader's result is returned as-is.
    """
    suffix = os.path.splitext(in_path)[1]
    reader = load_gpx if suffix.lower() == ".gpx" else load_csv
    return reader(in_path, return_full=return_full)


# -----------------------------
# Speed -> Color
# -----------------------------
def speed_to_hex(speed_kmh: float, vmax: float = 300.0, threshold: float = 0.90) -> str:
    """Convert a speed into an ``#RRGGBB`` colour string for track rendering.

    The speed is clamped to ``[0, vmax]``. Up to the cut point
    ``vmax * threshold`` the colour blends linearly from blue (``#0000FF``)
    to red (``#FF0000``). Above the cut point, red fades linearly to black,
    reaching ``#000000`` at ``vmax``.

    Fallbacks: a non-positive ``vmax`` is replaced by 300, and a
    non-positive cut point is replaced by ``vmax`` (plain blue->red ramp).
    """
    if vmax <= 0:
        vmax = 300.0
    top = float(vmax)
    speed = max(0.0, min(float(speed_kmh), top))

    cut = top * float(threshold)
    if cut <= 0:
        cut = top

    if speed <= cut:
        # Blue -> red section.
        ratio = s_over_cut = (speed / cut) if cut != 0 else 0.0
        red = int(round(255 * ratio))
        blue = int(round(255 * (1.0 - s_over_cut)))
        return "#{:02X}00{:02X}".format(red, blue)

    # Red -> black section between the cut point and vmax.
    fade_span = vmax - cut
    fade = (speed - cut) / fade_span if fade_span != 0 else 0.0
    return "#{:02X}0000".format(int(round(255 * (1.0 - fade))))


# -----------------------------
# Map creation (Folium)
# -----------------------------
def _speed_legend_html(vmax: float, threshold: float, label_step: int, color_fn) -> str:
    """Return the Jinja macro HTML for the map's fixed speed legend.

    The CSS gradient stops are sampled from ``color_fn`` (``speed_to_hex``
    in practice) at 0, at the blue->red cut point ``vmax * threshold``
    (only when 0 < threshold < 1), and at ``vmax``. This keeps the legend
    bar consistent with the colours drawn on the track. A non-positive
    ``vmax`` falls back to 300, mirroring ``speed_to_hex``.
    """
    if vmax <= 0:
        vmax = 300.0
    vmax = float(vmax)
    frac = float(threshold)

    stops = [(0.0, color_fn(0.0, vmax=vmax, threshold=threshold))]
    if 0.0 < frac < 1.0:
        stops.append((frac, color_fn(vmax * frac, vmax=vmax, threshold=threshold)))
    stops.append((1.0, color_fn(vmax, vmax=vmax, threshold=threshold)))
    gradient = ", ".join(f"{color} {pos * 100:g}%" for pos, color in stops)

    return f"""
    {{% macro html(this, kwargs) %}}
    <div style="
      position: fixed; bottom: 30px; left: 30px; width: 300px; z-index:9999;
      background: rgba(255,255,255,0.92); border: 2px solid #333; border-radius: 8px;
      padding: 10px; font-size: 12px;">
      <div style="font-weight:900; margin-bottom:4px;">MYGPS.CO.KR  Track 분석기</div>
      <div style="font-weight:700; margin-bottom:6px;">Speed (km/h)</div>
      <div style="height:12px; border-radius:6px;
        background: linear-gradient(to right, {gradient});"></div>
      <div style="display:flex; justify-content:space-between; margin-top:6px;">
        <span>0</span><span>{int(vmax)}</span>
      </div>
      <div style="margin-top:6px; color:#444;">
        점 옆 숫자 표시 (label_step={label_step})
      </div>
    </div>
    {{% endmacro %}}
    """


def save_folium_map(
    lat: List[float],
    lon: List[float],
    spd_kmh: List[float],
    out_html: str,
    vmax: float = 300.0,
    threshold: float = 0.90,
    weight: int = 5,
    zoom_start: int = 18,
    tiles: str = "OpenStreetMap",
    point_step: int = 1,    # spacing between clickable circle markers
    label_step: int = 5,    # spacing between speed labels (5/10/20 for dense tracks)
):
    """Render the track as an interactive Folium HTML map and save it.

    The map contains:
    1. one polyline per consecutive point pair, coloured by the mean speed
       of its two endpoints (via ``speed_to_hex``);
    2. a clickable circle marker every ``point_step`` points, with a tooltip
       and a popup showing the index and speed;
    3. an always-visible speed label every ``label_step`` points;
    4. START/END markers and a fixed legend matching the colour ramp.

    Only the first ``min(len(lat), len(lon), len(spd_kmh))`` points are used.
    The map is centred on the mean coordinate at ``zoom_start``. The output
    directory is created if needed.

    Raises:
        ValueError: If fewer than 2 usable points are given.
    """
    n = min(len(lat), len(lon), len(spd_kmh))
    if n < 2:
        raise ValueError("포인트가 너무 적습니다(최소 2개 필요).")

    # Imported here, after validation, so bad input is reported even when
    # folium is not installed.
    import folium
    from branca.element import MacroElement, Template

    center = [float(np.mean(lat[:n])), float(np.mean(lon[:n]))]
    m = folium.Map(location=center, zoom_start=zoom_start, tiles=tiles, control_scale=True)

    # 1) colored segments
    for i in range(n - 1):
        s = (float(spd_kmh[i]) + float(spd_kmh[i + 1])) / 2.0
        color = speed_to_hex(s, vmax=vmax, threshold=threshold)
        p1 = [float(lat[i]), float(lon[i])]
        p2 = [float(lat[i + 1]), float(lon[i + 1])]
        folium.PolyLine([p1, p2], color=color, weight=int(weight), opacity=0.9).add_to(m)

    # 2) clickable point markers
    step_pt = max(1, int(point_step))
    for i in range(0, n, step_pt):
        s = float(spd_kmh[i])
        color = speed_to_hex(s, vmax=vmax, threshold=threshold)
        popup_html = f"<b>Index</b>: {i}<br><b>Speed</b>: {s:.2f} km/h"
        folium.CircleMarker(
            location=[float(lat[i]), float(lon[i])],
            radius=4,
            color=color,
            fill=True,
            fill_color=color,
            fill_opacity=0.9,
            weight=1,
            tooltip=f"{s:.1f} km/h",
            popup=folium.Popup(popup_html, max_width=260),
        ).add_to(m)

    # 3) speed text labels (always visible) in a white semi-transparent box
    step_lb = max(1, int(label_step))
    for i in range(0, n, step_lb):
        s = float(spd_kmh[i])
        color = speed_to_hex(s, vmax=vmax, threshold=threshold)

        html = f"""
        <div style="
            display:inline-block;
            padding:2px 6px;
            border-radius:6px;
            background:rgba(255,255,255,0.75);
            border:1px solid rgba(0,0,0,0.25);
            box-shadow: 0 1px 2px rgba(0,0,0,0.15);
            font-size:12px;
            font-weight:700;
            color:{color};
            white-space:nowrap;
            transform: translate(10px, -10px);
        ">{s:.1f}</div>
        """
        folium.Marker(
            location=[float(lat[i]), float(lon[i])],
            icon=folium.DivIcon(html=html),
        ).add_to(m)

    # start/end
    folium.Marker([float(lat[0]), float(lon[0])], tooltip="START").add_to(m)
    folium.Marker([float(lat[n - 1]), float(lon[n - 1])], tooltip="END").add_to(m)

    # legend + title (gradient stops follow the same ramp as the track)
    macro = MacroElement()
    macro._template = Template(_speed_legend_html(vmax, threshold, step_lb, speed_to_hex))
    m.get_root().add_child(macro)

    os.makedirs(os.path.dirname(os.path.abspath(out_html)) or ".", exist_ok=True)
    m.save(out_html)


# -----------------------------
# PNG
# -----------------------------
def maybe_save_png(lat, lon, spd_kmh, out_png: str, vmax: float = 300.0, threshold: float = 0.90):
    """Save a static PNG of the track (longitude on x, latitude on y).

    Each segment between consecutive points is coloured by the mean speed
    of its endpoints via ``speed_to_hex``. Nothing is done (and matplotlib
    is not imported) when ``out_png`` is empty or fewer than 2 usable
    points are given. The output directory is created if needed.

    An explicit ``matplotlib.figure.Figure`` is used instead of
    ``pyplot``. This avoids drawing into, and then closing, whatever
    pyplot figure the caller (e.g. a GUI) happens to have open.
    """
    if not out_png:
        return

    n = min(len(lat), len(lon), len(spd_kmh))
    if n < 2:
        return

    from matplotlib.figure import Figure

    os.makedirs(os.path.dirname(os.path.abspath(out_png)) or ".", exist_ok=True)

    fig = Figure()
    ax = fig.subplots()
    for i in range(n - 1):
        s = (float(spd_kmh[i]) + float(spd_kmh[i + 1])) / 2.0
        c = speed_to_hex(s, vmax=vmax, threshold=threshold)
        ax.plot([lon[i], lon[i + 1]], [lat[i], lat[i + 1]], color=c, linewidth=2)

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.set_title("MYGPS.CO.KR  Track 분석기")
    fig.tight_layout()
    fig.savefig(out_png, dpi=200)


# -----------------------------
# GPX save
# -----------------------------
def save_gpx(
    lat: List[float],
    lon: List[float],
    spd_kmh: List[float],
    out_gpx: str,
    ele: Optional[List[float]] = None,
    time_str: Optional[List[str]] = None,
    rtk: Optional[List[int]] = None,
    track_name: str = "GPS Track",
):
    """Write the track as a GPX 1.1 file with one ``trk``/``trkseg``.

    Every point gets ``lat``/``lon`` (8 decimals) and an ``<extensions>``
    element holding ``<speed_kmh>`` (3 decimals). Optional per-point data
    is added only when present and valid:
    - ``<ele>``: when convertible to a finite float (NaN/inf are skipped).
    - ``<time>``: when the text starts with ``YYYY-MM-DD[ T]HH:MM:SS``. A
      space date/time separator becomes ``T``, and whitespace in the
      remainder (e.g. before a ``+09:00`` offset) is removed, giving an
      ISO-8601 form.
    - ``<rtk>`` (inside ``<extensions>``): when convertible to int.

    Only the first ``min(len(lat), len(lon), len(spd_kmh))`` points are
    written. The output directory is created if needed.
    """
    import math
    import xml.etree.ElementTree as ET

    def normalize_time(raw) -> Optional[str]:
        """Return an ISO-8601 style timestamp, or None if ``raw`` isn't one."""
        s = str(raw).strip()
        if not re.match(r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}", s):
            return None
        # Replace only the date/time separator at index 10, not every space.
        return s[:10] + "T" + re.sub(r"\s+", "", s[11:])

    n = min(len(lat), len(lon), len(spd_kmh))
    gpx = ET.Element(
        "gpx",
        {"version": "1.1", "creator": "MYGPS.CO.KR Track 분석기", "xmlns": "http://www.topografix.com/GPX/1/1"},
    )
    trk = ET.SubElement(gpx, "trk")
    ET.SubElement(trk, "name").text = track_name
    seg = ET.SubElement(trk, "trkseg")

    for i in range(n):
        pt = ET.SubElement(seg, "trkpt", {"lat": f"{float(lat[i]):.8f}", "lon": f"{float(lon[i]):.8f}"})

        if ele is not None and i < len(ele) and ele[i] is not None:
            try:
                z = float(ele[i])
            except (TypeError, ValueError):
                z = None
            # Blank CSV cells arrive as NaN; "nan" is not a valid GPX decimal.
            if z is not None and math.isfinite(z):
                ET.SubElement(pt, "ele").text = f"{z:.3f}"

        if time_str is not None and i < len(time_str) and time_str[i] is not None:
            ts = normalize_time(time_str[i])
            if ts is not None:
                ET.SubElement(pt, "time").text = ts

        ext = ET.SubElement(pt, "extensions")
        ET.SubElement(ext, "speed_kmh").text = f"{float(spd_kmh[i]):.3f}"
        if rtk is not None and i < len(rtk) and rtk[i] is not None:
            try:
                rtk_text = str(int(rtk[i]))
            except (TypeError, ValueError, OverflowError):
                rtk_text = None
            if rtk_text is not None:
                ET.SubElement(ext, "rtk").text = rtk_text

    try:
        ET.indent(gpx, space="  ", level=0)  # Python 3.9+
    except AttributeError:
        pass  # older Python: write without pretty-printing

    os.makedirs(os.path.dirname(os.path.abspath(out_gpx)) or ".", exist_ok=True)
    ET.ElementTree(gpx).write(out_gpx, encoding="utf-8", xml_declaration=True)
