import json
import base64
from typing import Optional, Dict

# Import the Redis client that all workers will share
from .redis_client import redis_client

class ResponseManager:
    """
    Manages per-student response state in a central Redis store.

    State lives in Redis rather than in process memory, so every worker that
    shares the same Redis client sees the same data. Each response is stored
    as a JSON document under the key ``student_response:<student_id>`` and is
    given an expiration of ``ttl_seconds``.
    """

    def __init__(self, ttl_seconds: int = 600):
        """
        Bind to the shared Redis client.

        Args:
            ttl_seconds: Expiration applied to every stored response, in
                seconds. Defaults to 600 (10 minutes).

        Raises:
            ConnectionError: If the shared Redis client is ``None``.
        """
        if redis_client is None:
            raise ConnectionError("ResponseManager requires a valid Redis connection. Check your REDIS_HOST/PORT environment variables.")
        self.redis = redis_client
        self.ttl_seconds = ttl_seconds  # Responses expire after this many seconds

    def _get_key(self, student_id: str) -> str:
        """Build the namespaced Redis key for *student_id*."""
        return f"student_response:{student_id}"

    def store_response(self, student_id: str, text: str, audio_filepath: Optional[str] = None, audio_bytes: Optional[bytes] = None) -> None:
        """
        Store a response for *student_id* in Redis, replacing any previous one.

        Args:
            student_id: Identifier used to build the Redis key.
            text: Response text.
            audio_filepath: Optional path to an audio file.
            audio_bytes: Optional raw audio. It is Base64-encoded so it can be
                embedded in the JSON payload. Empty or ``None`` is stored
                as ``null``.
        """
        key = self._get_key(student_id)

        # JSON cannot hold raw bytes, so binary audio is Base64-encoded.
        encoded_audio = base64.b64encode(audio_bytes).decode('utf-8') if audio_bytes else None

        payload = {
            "text": text,
            "audio_filepath": audio_filepath,
            "audio_bytes_b64": encoded_audio
        }

        # SETEX writes the value and its expiration in a single command.
        self.redis.setex(key, self.ttl_seconds, json.dumps(payload))

    def get_response(self, student_id: str) -> Dict:
        """
        Return the stored response for *student_id* without deleting it.

        The read is non-destructive, so a client can safely retry a failed
        request. The key goes away when its TTL expires or when
        ``clear_response`` is called.

        Returns:
            A dict that always contains ``text``, ``audio_filepath`` and
            ``audio_bytes`` (decoded bytes or ``None``). All three are
            ``None`` when no response is stored.
        """
        key = self._get_key(student_id)

        # A plain, non-destructive GET; Redis returns None for a missing key.
        json_value = self.redis.get(key)
        if not json_value:
            return {"text": None, "audio_filepath": None, "audio_bytes": None}

        payload = json.loads(json_value)

        # pop() rather than del: tolerate payloads written without this field
        # instead of raising KeyError.
        encoded_audio = payload.pop("audio_bytes_b64", None)
        payload["audio_bytes"] = base64.b64decode(encoded_audio) if encoded_audio else None

        # Keep the return shape identical to the "nothing stored" case.
        payload.setdefault("text", None)
        payload.setdefault("audio_filepath", None)

        return payload

    def clear_response(self, student_id: str) -> None:
        """
        Delete the stored response for *student_id* from Redis.

        Call this at the start of a new /chat request so that a stale response
        is invalidated immediately rather than waiting for its TTL.
        """
        key = self._get_key(student_id)
        self.redis.delete(key)

    def is_response_fresh(self, student_id: str) -> bool:
        """Return True if a response for *student_id* currently exists in Redis."""
        key = self._get_key(student_id)
        # exists() returns the count of matching keys (0 or 1 for a single key).
        return self.redis.exists(key) > 0