import os
import psycopg2
from psycopg2.pool import ThreadedConnectionPool
import logging
import time
import threading
from contextlib import contextmanager

# Module-level logger used by every ConnectionPool method in this file.
logger = logging.getLogger(__name__)


class ConnectionPool:
    """Thread-safe PostgreSQL connection pool with health monitoring and recovery.

    Wraps ``psycopg2.pool.ThreadedConnectionPool``. Every connection handed
    out is validated first (``SELECT 1``) and is returned to the pool with
    any open transaction rolled back. Broken connections are handed back to
    the pool with ``close=True`` so the pool frees their slot instead of
    counting them as still checked out. A health probe runs at most once per
    ``_health_check_interval`` seconds and rebuilds the pool when it appears
    unusable.

    Args:
        minconn: Minimum number of connections for the underlying pool.
        maxconn: Maximum number of connections for the underlying pool.
        **db_params: Connection keyword arguments forwarded to the underlying
            pool. Values given here take precedence over this class's
            ``connect_timeout`` and keepalive defaults.
    """

    def __init__(self, minconn=2, maxconn=20, **db_params):
        self.db_params = db_params
        self.minconn = minconn
        self.maxconn = maxconn
        self.pool = None
        # Guards creation/replacement of self.pool and the health-check timestamp.
        self._pool_lock = threading.RLock()
        self._last_health_check = 0
        self._health_check_interval = 300  # seconds (5 minutes)
        self._connection_timeout = 30  # seconds, passed as connect_timeout
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._idle_timeout = 7200  # 2 hours

        self._initialize_pool()

    def _initialize_pool(self):
        """Create the underlying ThreadedConnectionPool and store it in ``self.pool``.

        ``self.pool`` is only assigned once construction succeeds.

        Raises:
            Exception: whatever ``ThreadedConnectionPool`` raises while opening
                its initial connections; it is logged and re-raised.
        """
        # Defaults first so explicitly supplied db_params override them.
        pool_params = {
            'connect_timeout': self._connection_timeout,
            # TCP keepalive settings so dead idle connections get detected
            'keepalives_idle': 600,      # start keepalives after 10 min idle
            'keepalives_interval': 30,   # send a keepalive every 30 seconds
            'keepalives_count': 3,       # drop the connection after 3 failed keepalives
            **self.db_params,
        }
        try:
            self.pool = ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **pool_params,
            )
        except Exception as e:
            logger.error("Failed to initialize connection pool: %s", e)
            raise
        logger.info(
            "Connection pool initialized with %s-%s connections",
            self.minconn,
            self.maxconn,
        )

    def _recreate_pool(self):
        """Close the current pool (best effort) and build a fresh one.

        ``self.pool`` is cleared before re-initialising, so if initialisation
        fails the next acquisition initialises a new pool instead of reusing
        the closed one.

        Raises:
            Exception: propagated from ``_initialize_pool``.
        """
        with self._pool_lock:
            old_pool, self.pool = self.pool, None
            if old_pool is not None:
                try:
                    old_pool.closeall()
                except Exception as e:
                    # Best effort: the old pool may already be closed.
                    logger.debug("Error closing old connection pool: %s", e)

            logger.warning("Recreating connection pool...")
            self._initialize_pool()

    def _validate_connection(self, conn):
        """Return True if ``conn`` is open and answers ``SELECT 1``.

        Any leftover transaction (including an aborted one, in which
        ``SELECT 1`` would fail) is rolled back first. On success the
        connection is left with no open transaction. Never raises; failures
        are logged and reported as False.
        """
        if conn is None or conn.closed:
            return False
        idle = psycopg2.extensions.TRANSACTION_STATUS_IDLE
        try:
            if conn.info.transaction_status != idle:
                conn.rollback()

            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()

            # SELECT 1 may have opened a transaction (non-autocommit mode);
            # end it so the connection is handed out idle.
            if conn.info.transaction_status != idle:
                conn.rollback()
            return True

        except (psycopg2.OperationalError, psycopg2.InterfaceError,
                psycopg2.DatabaseError) as e:
            logger.debug("Connection validation failed: %s", e)
            return False
        except Exception as e:
            logger.warning("Unexpected error during connection validation: %s", e)
            return False

    def _health_check(self):
        """Probe the pool at most once per ``_health_check_interval`` seconds.

        If the probe says the pool is unusable (or the probe itself errors),
        the pool is recreated. Failures are logged rather than raised.
        """
        # Cheap unlocked pre-check so the common path never takes the lock.
        if time.time() - self._last_health_check < self._health_check_interval:
            return

        with self._pool_lock:
            now = time.time()
            if now - self._last_health_check < self._health_check_interval:
                return  # another thread ran the check while we waited
            # Claim this interval up front: a failing probe must not make every
            # later caller recreate the pool again (recreation closes all
            # connections, including ones other threads are using).
            self._last_health_check = now

            try:
                if self._probe_pool():
                    return
                logger.warning("Pool health check failed, recreating pool")
            except Exception as e:
                logger.error("Health check failed: %s", e)

            try:
                self._recreate_pool()
            except Exception as recreate_error:
                logger.error(
                    "Failed to recreate pool during health check: %s",
                    recreate_error,
                )

    def _probe_pool(self):
        """Check out one connection and validate it; caller holds ``_pool_lock``.

        Returns:
            True if the pool looks usable (including when it is merely
            exhausted), False if it should be recreated.

        Raises:
            Exception: errors from ``getconn``/``putconn`` other than pool
                exhaustion (e.g. failing to open a new connection).
        """
        pool = self.pool
        if pool is None or pool.closed:
            return False
        try:
            probe = pool.getconn()
        except psycopg2.pool.PoolError as e:
            # The pool is open (checked above), so this is exhaustion: every
            # connection is checked out. That is load, not ill health.
            logger.debug("Skipping health probe, pool is busy: %s", e)
            return True

        if self._validate_connection(probe):
            pool.putconn(probe)
            return True
        self._discard(probe, pool)
        return False

    def _acquire_validated(self, max_retries):
        """Check out a validated connection, retrying up to ``max_retries`` times.

        Pool errors back off exponentially (1s, 2s, 4s, ... capped at 10s)
        between attempts; connections that fail validation are discarded and
        the next attempt starts immediately.

        Returns:
            ``(conn, pool)`` where ``pool`` is the pool ``conn`` came from.

        Raises:
            ValueError: if ``max_retries`` is less than 1.
            ConnectionError: if every attempt failed. The pool is recreated
                (best effort) before raising.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        last_error = None
        for attempt in range(max_retries):
            try:
                with self._pool_lock:
                    if not self.pool:
                        self._initialize_pool()
                    pool = self.pool
                    conn = pool.getconn()
            except Exception as e:
                last_error = e
                logger.warning("Connection attempt %d failed: %s", attempt + 1, e)
                if attempt < max_retries - 1:
                    time.sleep(min(2 ** attempt, 10))
                continue

            if self._validate_connection(conn):
                return conn, pool
            # Stale or broken: let the pool close it so its slot is freed.
            self._discard(conn, pool)

        try:
            self._recreate_pool()
        except Exception as recreate_error:
            logger.error("Failed to recreate pool: %s", recreate_error)
        raise ConnectionError(
            f"Failed to get valid connection after {max_retries} attempts"
        ) from last_error

    def _release(self, conn, pool):
        """Roll back any open transaction on ``conn`` and return it to ``pool``.

        If that fails, the connection is discarded instead. Never raises.
        """
        try:
            if (not conn.closed and conn.info.transaction_status
                    != psycopg2.extensions.TRANSACTION_STATUS_IDLE):
                conn.rollback()
            pool.putconn(conn)
        except Exception as e:
            logger.warning("Error returning connection to pool: %s", e)
            self._discard(conn, pool)

    def _discard(self, conn, pool):
        """Close ``conn`` via ``pool.putconn(close=True)``, else close it directly.

        Going through the pool matters: closing a checked-out connection
        behind the pool's back leaves it counted as in use, and the pool
        eventually reports itself exhausted. Never raises.
        """
        try:
            pool.putconn(conn, close=True)
        except Exception as e:
            # e.g. the pool was closed/recreated and no longer knows this conn
            logger.debug("Pool could not discard connection, closing directly: %s", e)
            try:
                conn.close()
            except Exception as close_error:
                logger.debug("Error closing connection: %s", close_error)

    @contextmanager
    def get_connection(self, max_retries=3):
        """Yield a validated connection and return it to the pool on exit.

        Exceptions raised inside the ``with`` body propagate unchanged after
        the connection has been released. They are not treated as connection
        failures and never trigger a retry.

        Args:
            max_retries: Number of attempts to obtain a valid connection.

        Raises:
            ConnectionError: if no valid connection could be obtained.
            ValueError: if ``max_retries`` is less than 1.
        """
        self._health_check()
        conn, pool = self._acquire_validated(max_retries)
        try:
            yield conn
        finally:
            # Release to the pool the connection came from, not whatever
            # self.pool refers to now.
            self._release(conn, pool)

    def get_valid_conn(self):
        """Legacy API: return a validated connection (up to 3 attempts).

        The caller is responsible for handing it back with ``put_conn``.

        Raises:
            ConnectionError: if no valid connection could be obtained.
        """
        conn, _pool = self._acquire_validated(3)
        return conn

    def put_conn(self, conn):
        """Legacy API: return ``conn`` to the current pool; ``None`` is ignored.

        Any open transaction is rolled back first. If that or the return
        fails, the connection is discarded. Errors are logged, not raised.
        """
        if conn is None:
            return
        self._release(conn, self.pool)

    def close_all(self):
        """Close every connection handled by the pool and drop the pool.

        ``self.pool`` is reset to None even if closing raises, so the next
        acquisition initialises a fresh pool.
        """
        with self._pool_lock:
            if self.pool:
                try:
                    self.pool.closeall()
                finally:
                    self.pool = None