# mesh_mensaje_clase.py

import paho.mqtt.client as mqtt
from meshtastic import mqtt_pb2, mesh_pb2, portnums_pb2, telemetry_pb2
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import base64
from datetime import datetime, timezone
import ssl
import time
import json
import os

# ------------------------------------------------------------
# 🔧 Cargar configuración desde config.json
# ------------------------------------------------------------
_ruta_config = os.path.join(os.path.dirname(__file__), 'mensa-config.json')
with open(_ruta_config, 'r', encoding='utf-8') as f:
    _config_datos = json.load(f)

CLAVES_CANAL = _config_datos['CLAVES_CANAL']
NOMBRE_MODELOS = _config_datos['NOMBRE_MODELOS']
NOMBRE_ROLES = _config_datos['NOMBRE_ROLES']


# ------------------------------------------------------------
# 📦 Proceso_Crudo: genera diccionario de mensaje crudo
# ------------------------------------------------------------
class Proceso_Crudo:
    def __init__(self, callback_crudo=None):
        self.callback_crudo = callback_crudo

    def procesar(self, paquete, desde_nodo_str, id_recep_str, canal_nombre, canal_indice):
        """Devuelve un diccionario con los metadatos crudos del mensaje."""
        try:
            # Determinar si está encriptado
            encriptado = paquete.HasField("encrypted")
            if encriptado:
                portnum = -1
                portnum_name = "ENCRYPTED_UNKNOWN"
                payload_size = len(paquete.encrypted)
                payload_hex = paquete.encrypted.hex()
            else:
                portnum = getattr(paquete.decoded, 'portnum', -1)
                try:
                    portnum_name = portnums_pb2.PortNum.Name(portnum)
                except ValueError:
                    portnum_name = f"UNKNOWN_{portnum}"
                payload_size = len(paquete.decoded.payload) if paquete.HasField("decoded") else 0
                payload_hex = paquete.decoded.payload.hex() if paquete.HasField("decoded") else ""

            crudo = {
                "portnum": portnum,
                "portnum_name": portnum_name,
                "id_nodo": desde_nodo_str,
                "id_recep": id_recep_str,
                "canal_nombre": canal_nombre,
                "canal_indice": canal_indice,
                "encriptado": 1 if encriptado else 0,
                "payload_size": payload_size,
                "payload_hex": payload_hex,
                "fecha_recepcion": datetime.now(timezone.utc).isoformat()
            }

            if self.callback_crudo:
                self.callback_crudo(crudo)

            return crudo

        except Exception:
            return None


# ------------------------------------------------------------
# 📍 Proceso_Posicion
# ------------------------------------------------------------
class Proceso_Posicion:
    def __init__(self, callback_procesado=None):
        self.callback_procesado = callback_procesado

    def _safe_float(self, val, default=0.0):
        if val is None:
            return float(default)
        try:
            return float(val)
        except (ValueError, TypeError):
            return float(default)

    def proceso(self, paquete, desde_nodo_str, canal_nombre):
        try:
            posicion = mesh_pb2.Position()
            posicion.ParseFromString(paquete.decoded.payload)

            lat = getattr(posicion, 'latitude_i', 0) / 1e7
            lon = getattr(posicion, 'longitude_i', 0) / 1e7
            alt = getattr(posicion, 'altitude', 0)
            time_rx = getattr(posicion, 'time', 0)
            satelites = getattr(posicion, 'sats_in_view', 0)
            velocidad = getattr(posicion, 'ground_speed', 0)

            if time_rx > 0:
                fecha_gps = datetime.fromtimestamp(time_rx, tz=timezone.utc).isoformat()
            else:
                fecha_gps = None

            datos = {
                "id_nodo": desde_nodo_str,
                "canal_nombre": canal_nombre,
                "latitud": self._safe_float(lat),
                "longitud": self._safe_float(lon),
                "altitud": self._safe_float(alt, 0),
                "fecha_gps": fecha_gps,
                "satelites": self._safe_float(satelites, 0),
                "velocidad": self._safe_float(velocidad, 0)
            }

            if self.callback_procesado:
                self.callback_procesado(datos)

            return datos

        except Exception:
            return None


# ------------------------------------------------------------
# 📊 Proceso_Telemetria
# ------------------------------------------------------------
class Proceso_Telemetria:
    def __init__(self, callback_procesado=None):
        self.callback_procesado = callback_procesado

    def _safe_float(self, val, default=0.0):
        if val is None:
            return float(default)
        try:
            return float(val)
        except (ValueError, TypeError):
            return float(default)

    def proceso(self, paquete, desde_nodo_str, canal_nombre):
        try:
            telemetria = telemetry_pb2.Telemetry()
            telemetria.ParseFromString(paquete.decoded.payload)

            bateria = getattr(telemetria.device_metrics, 'battery_level', None)
            voltaje = getattr(telemetria.device_metrics, 'voltage', None)
            canal_util = getattr(telemetria.device_metrics, 'channel_utilization', None)
            aire_util = getattr(telemetria.device_metrics, 'air_util_tx', None)

            temperatura = None
            humedad = None
            presion = None
            if telemetria.HasField('environment_metrics'):
                env = telemetria.environment_metrics
                temperatura = getattr(env, 'temperature', None)
                humedad = getattr(env, 'relative_humidity', None)
                presion = getattr(env, 'barometric_pressure', None)

            datos = {
                "id_nodo": desde_nodo_str,
                "canal_nombre": canal_nombre,
                "bateria": self._safe_float(bateria, None),
                "voltaje": self._safe_float(voltaje, None),
                "canal_util": self._safe_float(canal_util, None),
                "aire_util": self._safe_float(aire_util, None),
                "temperatura": self._safe_float(temperatura, None),
                "humedad": self._safe_float(humedad, None),
                "presion": self._safe_float(presion, None)
            }

            if self.callback_procesado:
                self.callback_procesado(datos)

            return datos

        except Exception:
            return None


# ------------------------------------------------------------
# 👤 Proceso_Usuario
# ------------------------------------------------------------
class Proceso_Usuario:
    def __init__(self, callback_procesado=None):
        self.callback_procesado = callback_procesado

    def proceso(self, paquete, desde_nodo_str, canal_nombre):
        try:
            usuario = mesh_pb2.User()
            usuario.ParseFromString(paquete.decoded.payload)

            nombre_largo = getattr(usuario, 'long_name', '').strip() or "Sin nombre"
            nombre_corto = getattr(usuario, 'short_name', '').strip() or desde_nodo_str[-4:].upper()
            mac_bytes = getattr(usuario, 'macaddr', b'')
            mac = ':'.join(f"{b:02x}" for b in mac_bytes) if mac_bytes else "00:00:00:00:00:00"
            hw_model = getattr(usuario, 'hw_model', 0)
            rol = getattr(usuario, 'role', 0)

            modelo = NOMBRE_MODELOS.get(str(hw_model), f"Modelo_{hw_model}")
            rol_str = NOMBRE_ROLES.get(str(rol), "Desconocido")

            datos = {
                "id_nodo": desde_nodo_str,
                "canal_nombre": canal_nombre,
                "nombre_largo": nombre_largo,
                "nombre_corto": nombre_corto,
                "mac": mac,
                "modelo": modelo,
                "rol": rol_str,
                "firmware": None  # normalmente no viene en NODEINFO
            }

            if self.callback_procesado:
                self.callback_procesado(datos)

            return datos

        except Exception:
            return None


# ------------------------------------------------------------
# 💬 Proceso_Texto
# ------------------------------------------------------------
class Proceso_Texto:
    def __init__(self, callback_procesado=None):
        self.callback_procesado = callback_procesado

    def proceso(self, paquete, desde_nodo_str, id_recep_str, canal_nombre):
        try:
            texto_bytes = paquete.decoded.payload
            try:
                texto = texto_bytes.decode('utf-8').strip()
            except UnicodeDecodeError:
                texto = texto_bytes.decode('utf-8', errors='replace').strip()

            datos = {
                "texto": texto,
                "id_nodo": desde_nodo_str,
                "id_recep": id_recep_str,
                "canal_nombre": canal_nombre
            }

            if self.callback_procesado:
                self.callback_procesado(datos)

            return datos

        except Exception:
            return None


# ------------------------------------------------------------
# 📡 Ruta_Mensajes: orquestador principal
# ------------------------------------------------------------
class Ruta_Mensajes:
    def __init__(self, topico_base, callback_crudo=None, callback_posicion=None,
                 callback_telemetria=None, callback_usuario=None, callback_texto=None):
        self.topico_base = topico_base
        self.proceso_crudo = Proceso_Crudo(callback_crudo)
        self.proceso_posicion = Proceso_Posicion(callback_posicion)
        self.proceso_telemetria = Proceso_Telemetria(callback_telemetria)
        self.proceso_usuario = Proceso_Usuario(callback_usuario)
        self.proceso_texto = Proceso_Texto(callback_texto)

        self.cliente = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
        self.cliente.on_connect = self.nueva_conexion
        self.cliente.on_message = self.nuevo_mensaje

    def _desencripta_mensaje(self, paquete, desde_nodo_str, canal_nombre):
        clave_b64 = CLAVES_CANAL.get(canal_nombre)
        if not clave_b64:
            return False
        try:
            relleno = (4 - (len(clave_b64) % 4)) % 4
            clave_ajustada = clave_b64 + ('=' * relleno)
            clave_normalizada = clave_ajustada.replace('-', '+').replace('_', '/')
            clave_bytes = base64.b64decode(clave_normalizada)

            vector_inicial = paquete.id.to_bytes(8, "little") + getattr(paquete, "from").to_bytes(8, "little")
            cifrador = Cipher(algorithms.AES(clave_bytes), modes.CTR(vector_inicial), backend=default_backend())
            descifrador = cifrador.decryptor()
            descifrado = descifrador.update(paquete.encrypted) + descifrador.finalize()

            datos = mesh_pb2.Data()
            datos.ParseFromString(descifrado)
            paquete.decoded.CopyFrom(datos)
            return True
        except Exception:
            return False

    def nueva_conexion(self, cliente, userdata, flags, reason_code, properties):
        if reason_code == 0:
            topico = f"{self.topico_base}#"
            cliente.subscribe(topico)
            print(f"✅ Conectado y suscrito a: {topico}")
        else:
            print(f"❌ Error al conectar: {reason_code}")
            
            
    def nuevo_mensaje(self, cliente, userdata, msg):
        try:
            envoltura = mqtt_pb2.ServiceEnvelope()
            envoltura.ParseFromString(msg.payload)
            paquete = envoltura.packet

            desde_nodo_int = getattr(paquete, 'from', 0)
            desde_nodo_str = f"!{desde_nodo_int:08x}"
            id_recep_int = getattr(paquete, 'to', 0xffffffff)
            id_recep_str = f"!{id_recep_int:08x}"
            canal_nombre = getattr(envoltura, 'channel_id', 'unknown')
            canal_indice = getattr(paquete, 'channel', 0)

            # ✅ Registrar mensaje crudo (siempre)
            self.proceso_crudo.procesar(paquete, desde_nodo_str, id_recep_str, canal_nombre, canal_indice)

            # Desencriptar si es necesario
            if paquete.HasField("encrypted"):
                if not self._desencripta_mensaje(paquete, desde_nodo_str, canal_nombre):
                    return

            if not paquete.HasField("decoded"):
                return

            tipo_mensaje = paquete.decoded.portnum

            if tipo_mensaje == portnums_pb2.POSITION_APP:
                self.proceso_posicion.proceso(paquete, desde_nodo_str, canal_nombre)
            elif tipo_mensaje == portnums_pb2.TELEMETRY_APP:
                self.proceso_telemetria.proceso(paquete, desde_nodo_str, canal_nombre)
            elif tipo_mensaje == portnums_pb2.NODEINFO_APP:
                self.proceso_usuario.proceso(paquete, desde_nodo_str, canal_nombre)
            elif tipo_mensaje == portnums_pb2.TEXT_MESSAGE_APP:
                self.proceso_texto.proceso(paquete, desde_nodo_str, id_recep_str, canal_nombre)

        except Exception:
            pass  # silencioso, como en original

    def conectar(self, config_mqtt):
        self.cliente.username_pw_set(config_mqtt['usuario'], config_mqtt['clave'])
        if config_mqtt['puerto'] == 8883:
            self.cliente.tls_set(ca_certs=None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLS)
            self.cliente.tls_insecure_set(True)
        self.cliente.connect(config_mqtt['broker'], config_mqtt['puerto'], 60)
        self.cliente.loop_forever()