# -*- coding: utf-8 -*-
import time
import json
import math
import asyncio
import aiohttp
import requests
import numpy as np
from collections import deque

from engine import detect_signal
from indicators import calc_mf_stats
from logger import append_log
from mempool import fetch_mempool, normalize_mempool


# ================================================================
# GLOBAL STATE
# ================================================================

# Rolling window of the last 120 trade prices, one appended per run_cycle();
# fed to compute_smooth(). "2 minutes" assumes ~1 cycle per second — main()
# sleeps 1 s *after* each cycle, so the real span is somewhat longer.
PRICE_WINDOW = deque(maxlen=120)   # ~2 minutes of price data (see note above)
# (time.time(), price) pairs, one per run_cycle(); read by calc_price_change_5m().
# 600 samples is ~10 minutes under the same ~1 cycle/s assumption.
PRICE_HISTORY = deque(maxlen=600)  # ~10 minutes - (timestamp, price) pairs

# Simulated long-position state. Only run_cycle() mutates it: a BUY decision
# opens at the last trade price, a SELL decision closes. Nothing in this file
# sends orders to an exchange.
HAS_POSITION = False
BUY_PRICE = None  # entry price while a position is open, None when flat
ENTRY_TS = None   # time.time() at entry, None when flat

# Closed-trade statistics, updated on each executed SELL in run_cycle() and
# printed by print_trade_summary(). PnL values are per-trade percentages.
TOTAL_TRADES = 0
WINS = 0          # trades with pnl_pct > 0
LOSSES = 0        # trades with pnl_pct <= 0 (break-even counts as a loss)
TOTAL_PNL = 0.0   # plain sum of per-trade percentages (not compounded)
BEST_TRADE = 0.0  # starts at 0.0, so it never reports a value below 0
WORST_TRADE = 0.0 # starts at 0.0, so it never reports a value above 0


# ================================================================
# HELPERS
# ================================================================

# Binance API host names.
# NOTE(review): not referenced by get_aggtrades()/get_orderbook(), which call
# data-api.binance.vision directly — confirm whether this list is still needed.
BASE_URLS = [
    "https://api1.binance.com",
    "https://api2.binance.com",
    "https://api3.binance.com",
]

def _parse_aggtrades(data):
    """Convert a decoded aggTrades JSON payload into the bot's trade dicts.

    Args:
        data: Decoded JSON from the aggTrades endpoint. Each entry is expected
            to carry "p" (price), "q" (quantity) and "m" (buyer-is-maker flag).

    Returns:
        List of {"p": float, "q": float, "isBuyerMaker": bool} dicts in the
        order received. Returns [] when the payload is not a list (e.g. an
        error object) or any entry is malformed, so one bad response cannot
        crash the caller's loop.
    """
    if not isinstance(data, list):
        print(f"Error fetching aggTrades: unexpected payload {data!r}")
        return []

    trades = []
    try:
        for t in data:
            trades.append({
                "p": float(t["p"]),
                "q": float(t["q"]),
                # Binance Vision API uses "m" for "buyer is maker"
                "isBuyerMaker": bool(t["m"]),
            })
    except (KeyError, TypeError, ValueError) as e:
        print(f"Error parsing aggTrades: {e}")
        return []

    return trades


def get_aggtrades(symbol="BTCUSDT", limit=500):
    """Fetch recent aggregated trades for ``symbol`` from data-api.binance.vision.

    Args:
        symbol: Exchange symbol, e.g. "BTCUSDT".
        limit: Number of aggTrades to request.

    Returns:
        List of {"p": float, "q": float, "isBuyerMaker": bool} dicts, or [] on
        any network/HTTP/JSON error or malformed payload (the error is printed).
    """
    url = "https://data-api.binance.vision/api/v3/aggTrades"
    try:
        # Let requests build and encode the query string.
        r = requests.get(url, params={"symbol": symbol, "limit": limit}, timeout=3)
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        # Best effort: the bot loop treats [] as "no trades this cycle".
        print(f"Error fetching aggTrades: {e}")
        return []

    return _parse_aggtrades(data)


def build_mf_from_trades(trades):
    """
    Aggregate aggTrades into the money-flow input expected by calc_mf_stats().

    isBuyerMaker=True  -> the seller was the aggressor -> counted as sell volume
    isBuyerMaker=False -> the buyer was the aggressor  -> counted as buy volume
    Volumes are the trades' "q" quantities.

    Each timeframe key ("1", "5", "15") gets its own dict holding the SAME
    totals over the whole batch; no per-timeframe split is done here.
    """
    totals = {"buy": 0.0, "sell": 0.0}
    for trade in trades:
        side = "sell" if trade["isBuyerMaker"] else "buy"
        totals[side] += trade["q"]

    # dict(totals) gives every timeframe an independent copy.
    return {timeframe: dict(totals) for timeframe in ("1", "5", "15")}


def calc_cvd_from_trades(trades):
    """
    Compute the batch's cumulative volume delta (CVD) in millions of USD.

    CVD = taker-buy notional - taker-sell notional, where notional = p * q and
    isBuyerMaker=True marks a taker sell. The result is hard-clamped to
    [-10.0, +10.0].

    Returns: clamped CVD as a float.
    """
    # Slot 0 accumulates taker-buy notional, slot 1 taker-sell notional.
    notional = [0.0, 0.0]
    for trade in trades:
        notional[bool(trade["isBuyerMaker"])] += trade["p"] * trade["q"]

    taker_buy_usd, taker_sell_usd = notional
    delta_millions = (taker_buy_usd - taker_sell_usd) / 1_000_000

    return float(max(-10.0, min(10.0, delta_millions)))


def calc_whale_from_trades(trades, threshold_usd=50_000):
    """
    Sum the USD notional of "whale" trades.

    A trade is a whale trade when price * quantity >= ``threshold_usd``
    (inclusive).

    Args:
        trades: Iterable of trade dicts with numeric "p" (price) and "q" (quantity).
        threshold_usd: Minimum notional for a trade to count (default $50k).

    Returns: Combined whale notional in millions of USD, as a float.
    """
    notionals = (trade["p"] * trade["q"] for trade in trades)
    whale_usd = 0.0
    for usd in notionals:
        if usd >= threshold_usd:
            whale_usd += usd
    return float(whale_usd / 1_000_000)


def calc_price_change_5m(current_price, lookback_sec=300, history=None, now=None):
    """
    Percentage change of ``current_price`` versus the price ~``lookback_sec`` ago.

    The reference price is the history sample whose timestamp is closest to
    ``now - lookback_sec``. On ties, the earliest such sample wins. While the
    history covers less than ``lookback_sec``, the oldest sample is the closest
    one. The result then reflects a shorter window than the name suggests.

    Args:
        current_price: Current price.
        lookback_sec: How far back to look, in seconds (default 300 = 5 min).
        history: Iterable of (timestamp, price) pairs; defaults to the
            module-level PRICE_HISTORY.
        now: Reference "current" timestamp; defaults to time.time().

    Returns: Percentage change as a float (e.g. 0.15 = 0.15%), or 0.0 when the
        history is empty or the reference price is 0.
    """
    samples = PRICE_HISTORY if history is None else history
    if not samples:
        return 0.0

    current_time = time.time() if now is None else now
    target_time = current_time - lookback_sec

    # Linear scan: history is small (bounded deque) and may not be indexable.
    closest_price = None
    min_time_diff = float("inf")
    for ts, price in samples:
        time_diff = abs(ts - target_time)
        if time_diff < min_time_diff:  # strict "<": earliest sample wins ties
            min_time_diff = time_diff
            closest_price = price

    if closest_price is None or closest_price == 0:
        return 0.0

    price_change_pct = ((current_price - closest_price) / closest_price) * 100.0
    return float(price_change_pct)


def _orderbook_ratios(ob, levels=5):
    """Compute (depth_ratio, best_ratio) from a decoded depth snapshot.

    depth_ratio = summed bid quantity / summed ask quantity over the first
    ``levels`` levels of each side; best_ratio = best bid price / best ask
    price. Each ratio falls back to 1.0 when its denominator is not positive.
    Both fall back to 1.0 when the snapshot is malformed or a side is empty.
    """
    try:
        bids = sum(float(x[1]) for x in ob["bids"][:levels])
        asks = sum(float(x[1]) for x in ob["asks"][:levels])
        best_bid = float(ob["bids"][0][0])
        best_ask = float(ob["asks"][0][0])
    except (KeyError, IndexError, TypeError, ValueError):
        # Narrow on purpose: a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit.
        return 1.0, 1.0

    depth_ratio = bids / asks if asks > 0 else 1.0
    best_ratio = best_bid / best_ask if best_ask > 0 else 1.0
    return depth_ratio, best_ratio


def get_orderbook(symbol="BTCUSDT", limit=50, levels=5):
    """Fetch the order book for ``symbol`` and return (depth_ratio, best_ratio).

    Args:
        symbol: Exchange symbol, e.g. "BTCUSDT".
        limit: Depth requested from the API.
        levels: How many top levels per side enter depth_ratio (default 5,
            the original hard-coded value).

    Returns:
        (depth_ratio, best_ratio) floats; (1.0, 1.0) on any fetch error
        (printed) or malformed response.
    """
    url = "https://data-api.binance.vision/api/v3/depth"
    try:
        r = requests.get(url, params={"symbol": symbol, "limit": limit}, timeout=3)
        r.raise_for_status()
        ob = r.json()
    except Exception as e:
        # Best effort: neutral ratios keep the cycle running.
        print(f"[OB ERROR] {e}")
        return 1.0, 1.0

    return _orderbook_ratios(ob, levels)


def compute_smooth(prices: list):
    """Return a trend score in [-1, 1] from the least-squares slope of ``prices``.

    Fits price = m * index + b, divides the slope m by 0.01% of the mean price
    (plus 1e-8 to keep the denominator non-zero), then squashes with tanh.

    Returns 0.0 when there are fewer than 5 samples, the fit fails, or the
    result is not finite (e.g. NaN prices). This keeps NaN out of the engine state.
    """
    if len(prices) < 5:
        return 0.0

    y = np.asarray(prices, dtype=float)
    x = np.arange(len(y))
    try:
        slope = np.polyfit(x, y, 1)[0]
    except (np.linalg.LinAlgError, ValueError, TypeError):
        return 0.0

    smooth = math.tanh(slope / (y.mean() * 0.0001 + 1e-8))
    return float(smooth) if math.isfinite(smooth) else 0.0


def print_trade_summary():
    """Print aggregate statistics of closed trades to stdout.

    Reads (never writes) the module-level counters TOTAL_TRADES, WINS, LOSSES,
    TOTAL_PNL, BEST_TRADE and WORST_TRADE. Prints nothing before the first
    closed trade.
    """
    if TOTAL_TRADES == 0:
        return

    rule = "=" * 60
    report = [
        "\n" + rule,
        "TRADING SUMMARY",
        rule,
        f"Total Trades: {TOTAL_TRADES}",
        f"Wins: {WINS} | Losses: {LOSSES}",
        f"Win Rate: {WINS / TOTAL_TRADES * 100:.1f}%",
        f"Total PnL: {TOTAL_PNL:+.2f}%",
        f"Avg PnL: {TOTAL_PNL / TOTAL_TRADES:+.3f}%",
        f"Best Trade: {BEST_TRADE:+.2f}%",
        f"Worst Trade: {WORST_TRADE:+.2f}%",
        rule + "\n",
    ]
    print("\n".join(report))


# ================================================================
# ONE CYCLE
# ================================================================

async def run_cycle(session, symbol="BTCUSDT"):
    """Run one collect -> decide -> (simulated) execute iteration for ``symbol``.

    The cycle:

    * Fetches recent aggTrades and the order book.
    * Derives features: the ``smooth`` trend, the buy ratio from
      ``calc_mf_stats``, CVD, whale volume, the 5-minute price change and
      mempool data.
    * Asks ``detect_signal`` for an action.
    * Updates the module-level position and statistics globals when a BUY
      opens or a SELL closes the simulated long.

    A cycle with no trades returns early. Otherwise the state and decision
    are recorded via ``append_log``.

    Args:
        session: aiohttp client session forwarded to ``fetch_mempool``.
        symbol: Exchange symbol to query (default "BTCUSDT").
    """
    global HAS_POSITION, BUY_PRICE, ENTRY_TS
    global TOTAL_TRADES, WINS, LOSSES, TOTAL_PNL, BEST_TRADE, WORST_TRADE

    t0 = time.time()
    loop = asyncio.get_running_loop()

    # --- TRADES ---
    # get_aggtrades/get_orderbook use blocking `requests`; run them in the
    # default executor so the event loop is not frozen for up to the HTTP timeout.
    trades = await loop.run_in_executor(None, get_aggtrades, symbol)
    if not trades:
        print("No trades")
        return

    price = trades[-1]["p"]
    PRICE_WINDOW.append(price)

    # Save price with timestamp to history (read by calc_price_change_5m)
    PRICE_HISTORY.append((time.time(), price))

    smooth = compute_smooth(list(PRICE_WINDOW))

    # --- ORDERBOOK ---
    depth_ratio, best_level_ratio = await loop.run_in_executor(
        None, get_orderbook, symbol
    )

    # --- MONEY FLOW ---
    mf_raw = build_mf_from_trades(trades)
    mf = calc_mf_stats(mf_raw)

    buy_ratio = mf.get("buy_ratio", 0.5)
    # Per-timeframe ratios are not computed separately: all three reuse
    # buy_ratio (build_mf_from_trades also fills every timeframe identically).
    ratio_1m = buy_ratio
    ratio_5m = buy_ratio
    ratio_15m = buy_ratio

    # --- CVD, WHALE, PRICE_CHANGE ---
    cvd = calc_cvd_from_trades(trades)
    whale = calc_whale_from_trades(trades, threshold_usd=50_000)
    price_change_5m = calc_price_change_5m(price)

    # --- MEMPOOL ---
    try:
        mem = await fetch_mempool(session)
        onchain = normalize_mempool(mem)
    except asyncio.CancelledError:
        # The previous bare `except:` swallowed cancellation, which keeps the
        # task alive during shutdown; always let it propagate.
        raise
    except Exception as e:
        # Best effort: fall back to neutral on-chain features, but say why.
        print(f"[MEMPOOL ERROR] {e}")
        onchain = [0, 0, 0]

    # --- POSITION PNL ---
    if HAS_POSITION and BUY_PRICE is not None:
        pnl_pct_live = (price - BUY_PRICE) / BUY_PRICE * 100.0
    else:
        pnl_pct_live = 0.0

    # --- POSITION AGE ---
    if HAS_POSITION and ENTRY_TS:
        position_age = time.time() - ENTRY_TS
    else:
        position_age = 0.0

    # --- STATE PACK (FULL VERSION FOR ENGINE) ---
    state = {
        "price": price,
        "smooth": smooth,
        "buy_ratio": buy_ratio,
        "cvd": cvd,
        "whale": whale,
        "depth_ratio": depth_ratio,
        "best_level_ratio": best_level_ratio,
        "ratio_1m": ratio_1m,
        "ratio_5m": ratio_5m,
        "ratio_15m": ratio_15m,
        "price_change_pct_5m": price_change_5m,
        "onchain": onchain,
        "has_position": HAS_POSITION,
        "buy_price": BUY_PRICE,
        "position_age": position_age,
        "pnl_pct": pnl_pct_live,
    }

    # --- ENGINE DECISION ---
    decision = detect_signal(state)
    action = decision.get("action")
    executed = False
    exec_info = {}

    # ===== BUY ===== (ignored while a position is already open)
    if action == "BUY":
        if not HAS_POSITION:
            HAS_POSITION = True
            BUY_PRICE = price
            ENTRY_TS = time.time()
            executed = True
            exec_info = {"exec_side": "OPEN_LONG", "exec_price": price}
            print(f"\n?? OPEN LONG at ${price:,.2f}")

    # ===== SELL ===== (ignored while flat)
    elif action == "SELL":
        if HAS_POSITION:
            pnl_pct = (price - BUY_PRICE) / BUY_PRICE * 100
            executed = True
            exec_info = {
                "exec_side": "CLOSE_LONG",
                "exec_price": price,
                "pnl_pct_exec": pnl_pct,
                "hold_time_sec": time.time() - ENTRY_TS,
            }

            TOTAL_TRADES += 1
            TOTAL_PNL += pnl_pct
            BEST_TRADE = max(BEST_TRADE, pnl_pct)
            WORST_TRADE = min(WORST_TRADE, pnl_pct)

            # Break-even (pnl_pct == 0) is counted as a loss.
            if pnl_pct > 0:
                WINS += 1
            else:
                LOSSES += 1

            print(f"\n?? CLOSED at ${price:,.2f} | PnL: {pnl_pct:+.2f}%")
            print_trade_summary()

            HAS_POSITION = False
            BUY_PRICE = None
            ENTRY_TS = None

    decision["executed"] = executed
    decision["exec_info"] = exec_info

    append_log({"state": state, "decision": decision, "ts": time.time()})

    # Enhanced console output
    print(">>>", {
        "price": round(price, 2),
        "action": action,
        "buy_ratio": round(buy_ratio, 3),
        "cvd": round(cvd, 2),
        "whale": round(whale, 2),
        "price_change_5m": round(price_change_5m, 3),
        "executed": executed
    })

    print(f"Cycle time: {time.time() - t0:.4f}s")


# ================================================================
# MAIN LOOP
# ================================================================

async def main(symbol="BTCUSDT", interval_sec=1.0):
    """Print the start banner, then run trading cycles until cancelled.

    Args:
        symbol: Exchange symbol passed to every run_cycle() call.
        interval_sec: Pause after each cycle. The effective period is this
            plus the cycle's own runtime.

    A cycle that raises is reported with its traceback and the loop continues,
    so one bad response or feature bug cannot stop the bot. Cancellation (and
    KeyboardInterrupt, which is not an Exception) still propagates.
    """
    import traceback

    print("\n" + "??"*30)
    print("BOT STARTED - WITH CVD/WHALE/PRICE_CHANGE")
    print("??"*30 + "\n")
    print("??  Warming up... (wait 5-10 minutes for full functionality)")
    print("=" * 60 + "\n")

    async with aiohttp.ClientSession() as session:
        while True:
            try:
                await run_cycle(session, symbol)
            except asyncio.CancelledError:
                # Explicit re-raise: on Python < 3.8 CancelledError is an Exception.
                raise
            except Exception:
                # Top-level boundary of a long-running loop: report, keep going.
                traceback.print_exc()
            await asyncio.sleep(interval_sec)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the bot. Exit without a traceback
        # and show the final closed-trade statistics (prints nothing if none).
        print_trade_summary()