Update to handle the impatient-user test case

parent 5b5af3f3
......@@ -72,16 +72,6 @@ class AgentService:
def is_available(self) -> bool:
    """Report whether the wrapped OpenAI service says it is available."""
    service = self.openai_service
    return service.is_available()
def text_to_speech(self, text: str, language: str) -> bytes:
    """Convert text to speech after passing it through the tashkeel agent.

    Raises:
        HTTPException: 503 when no TTS service is configured or it reports
            itself unavailable.
    """
    tts = self.tts_service
    if not (tts and tts.is_available()):
        raise HTTPException(status_code=503, detail="TTS service is not available")

    # The tashkeel pass runs first so the TTS engine receives the marked-up text.
    marked_text = self.tashkeel_agent.apply_tashkeel(text)
    print(f"Tashkeel applied: {marked_text}")

    return tts.generate_speech(marked_text, language)
def generate_response(self, user_message: str, student_id: str, subject: str = "Science",
......@@ -95,195 +85,6 @@ class AgentService:
print(f"response: {response}")
return response
def search_similar(self, query_embedding: List[float], student_id: str,
                   subject: str = "chemistry", top_k: int = 3):
    """Search for content similar to an embedding, filtered for one student.

    The student's grade and language flag are looked up through
    ``db_service`` and forwarded to the PGVector search together with the
    requested subject.

    Args:
        query_embedding: Embedding vector to search with.
        student_id: ID used to look up the student's grade and language.
        subject: Subject to restrict the search to.
        top_k: Maximum number of results requested (passed as ``limit``).

    Returns:
        Whatever ``pgvector.search_with_curriculum_context`` returns.

    Raises:
        HTTPException: 400 when PGVector is not enabled, 404 when the
            student is unknown, 500 for any other failure during the search.
    """
    if not self.pgvector:
        raise HTTPException(status_code=400, detail="PGVector service not enabled")

    try:
        student_info = self.db_service.get_student_info(student_id)
        if not student_info:
            raise HTTPException(status_code=404, detail=f"Student with ID {student_id} not found")

        return self.pgvector.search_with_curriculum_context(
            query_embedding=query_embedding,
            grade=student_info['grade'],
            subject=subject,
            is_arabic=student_info['is_arabic'],
            limit=top_k
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must keep their own
        # status code instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error in search_similar: {e}")
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
def get_available_subjects(self, student_id: str) -> List[str]:
    """List the subjects PGVector reports for the student's grade and language.

    Returns an empty list when PGVector is not configured, the student is
    not found, or the lookup raises (the error is logged, not propagated).
    """
    if self.pgvector:
        try:
            info = self.db_service.get_student_info(student_id)
            if info:
                return self.pgvector.get_subjects_by_grade_and_language(
                    info['grade'], info['is_arabic']
                )
        except Exception as e:
            logger.error(f"Error getting available subjects: {e}")

    # Every non-success path falls through to the same empty answer.
    return []
def get_curriculum_overview(self, student_id: str, subject: str = "Science") -> Dict:
    """Build a curriculum overview for one student and subject.

    Returns a dict with the student's profile, the curriculum structure and
    the available units. Any failure is reported as ``{"error": ...}``
    rather than raised.
    """
    if not self.pgvector:
        return {"error": "PGVector service not available"}

    try:
        info = self.db_service.get_student_info(student_id)
        if not info:
            return {"error": "Student not found"}

        grade = info['grade']
        is_arabic = info['is_arabic']

        curriculum = self.pgvector.get_curriculum_structure(grade, is_arabic, subject)
        if not curriculum:
            return {"error": f"No curriculum found for Grade {grade}"}

        # Build the profile before querying units, so a malformed student
        # record fails before the extra PGVector call is made.
        profile = {
            "name": info['student_name'],
            "grade": info['grade'],
            "study_language": info['study_language'].value,
            "nationality": info['nationality']
        }
        units = self.pgvector.get_units_for_grade(grade, is_arabic, subject)

        return {
            "student_info": profile,
            "curriculum": curriculum,
            "available_units": units,
            "source": "JSON-based curriculum structure"
        }
    except Exception as e:
        logger.error(f"Error getting curriculum overview: {e}")
        return {"error": str(e)}
# Conversation management methods
def export_conversation(self, student_id: str) -> List[Dict[str, str]]:
    """Return the conversation history the response generator holds for a student."""
    history = self.response_generator.get_conversation_history(student_id)
    return history
def clear_conversation(self, student_id: str) -> Dict[str, str]:
    """Clear a student's stored history and report the outcome as a status dict.

    Failures are logged and returned as ``{"status": "error", ...}`` instead
    of being raised.
    """
    try:
        self.db_service.clear_history(student_id)
    except Exception as e:
        logger.error(f"Error clearing conversation: {e}")
        return {"status": "error", "message": f"Failed to clear conversation: {str(e)}"}
    else:
        return {"status": "success", "message": f"Conversation cleared for student {student_id}"}
def get_agent_stats(self, student_id: str) -> Dict:
    """Summarise a student's conversation, plus language and curriculum details.

    Message counts are always included. Language fields are added when the
    student record exists, and curriculum fields only when PGVector is also
    configured. Any failure yields ``{"student_id": ..., "error": ...}``.
    """
    try:
        history = self.response_generator.get_conversation_history(student_id)

        def count_role(role):
            # Number of history entries whose 'role' equals the given role.
            return sum(1 for entry in history if entry['role'] == role)

        user_count = count_role('user')
        assistant_count = count_role('assistant')
        system_count = count_role('system')

        info = self.db_service.get_student_info(student_id)

        # Optional sections, collected in the order they appear in the result.
        extras = {}
        if info:
            extras.update({
                "study_language": info['study_language'].value,
                "nationality": info['nationality'],
                "grade": info['grade']
            })

            if self.pgvector:
                curriculum = self.pgvector.get_curriculum_structure(
                    info['grade'], info['is_arabic']
                )
                extras.update({
                    "curriculum_available": curriculum is not None,
                    "curriculum_source": "JSON file" if curriculum else "None",
                    "available_subjects": self.pgvector.get_subjects_by_grade_and_language(
                        info['grade'], info['is_arabic']
                    ),
                    "available_units": len(curriculum.get('units', [])) if curriculum else 0
                })

        return {
            "student_id": student_id,
            "total_messages": len(history),
            "user_messages": user_count,
            "assistant_messages": assistant_count,
            "system_messages": system_count,
            "conversation_active": len(history) > 0,
            **extras
        }
    except Exception as e:
        logger.error(f"Error getting agent stats: {e}")
        return {"student_id": student_id, "error": str(e)}
def get_curriculum_structure_info(self, student_id: str, subject: str = "Science") -> Dict:
    """Describe the units and concepts of a student's curriculum for one subject.

    Returns a dict with basic student details, the curriculum title and a
    per-unit breakdown (each unit listing its concepts with lesson counts).
    Any failure is reported as ``{"error": ...}`` rather than raised.
    """
    if not self.pgvector:
        return {"error": "PGVector service not available"}

    def summarize_concept(concept):
        # Condensed view of one concept, including how many lessons it has.
        return {
            "number": concept.get('number'),
            "name": concept.get('name'),
            "description": concept.get('description', ''),
            "lessons_count": len(concept.get('lessons', []))
        }

    def summarize_unit(unit):
        # Condensed view of one unit together with its concept summaries.
        return {
            "number": unit.get('number'),
            "name": unit.get('name'),
            "description": unit.get('description', ''),
            "concepts_count": len(unit.get('concepts', [])),
            "concepts": [summarize_concept(item) for item in unit.get('concepts', [])]
        }

    try:
        info = self.db_service.get_student_info(student_id)
        if not info:
            return {"error": "Student not found"}

        curriculum = self.pgvector.get_curriculum_structure(
            info['grade'], info['is_arabic'], subject
        )
        if not curriculum:
            return {"error": "No curriculum structure found"}

        # A curriculum without a 'units' key simply yields no unit entries.
        raw_units = curriculum['units'] if 'units' in curriculum else []
        units_info = [summarize_unit(unit) for unit in raw_units]

        return {
            "student_info": {
                "grade": info['grade'],
                "language": "Arabic" if info['is_arabic'] else "English",
                "nationality": info['nationality']
            },
            "curriculum_title": curriculum.get('title', ''),
            "total_units": len(units_info),
            "units": units_info,
            "source": "JSON-based curriculum file"
        }
    except Exception as e:
        logger.error(f"Error getting curriculum structure info: {e}")
        return {"error": str(e)}
def close(self):
"""Close database connection pools"""
......
......@@ -31,15 +31,13 @@ class ChatService:
MessageType.TEXT: TextMessageHandler()
}
def process_message(self, student_id: str, file: Optional[UploadFile] = None, text: Optional[str] = None, game_context: Optional[str] = None):
"""Process message and generate text and audio response."""
self.response_manager.clear_response(student_id) # Clear any existing response
try:
if file and file.filename:
result = self.handlers[MessageType.AUDIO].handle(file=file)
user_message = result.get("transcription", "")
# Assuming handle method reads the file content
file_content = file.file.read()
user_message = self.handlers[MessageType.AUDIO].openai_service.transcribe_audio(file_content, file.filename)
elif text:
user_message = text
else:
......@@ -51,19 +49,20 @@ class ChatService:
final_message_for_agent = f"game context: {game_context}\nuser query: {user_message}"
agent_response_text = self.agent_service.generate_response(
user_message=final_message_for_agent, # <-- USE THE NEW VARIABLE HERE
user_message=final_message_for_agent,
student_id=student_id,
)
audio_data = self._generate_and_upload_audio(agent_response_text, student_id)
self.response_manager.store_response(
student_id = student_id,
student_id=student_id,
text=agent_response_text,
audio_filepath=audio_data.get("filepath"),
audio_bytes=audio_data.get("bytes")
)
print(f"Generated response for student {student_id}: {agent_response_text[:100]}...")
return {
......@@ -80,22 +79,13 @@ class ChatService:
def _generate_and_upload_audio(self, text: str, student_id: str) -> dict:
"""
Segments mixed-language text and generates TTS audio using the pluggable
AgentService, then uploads the final audio to MinIO.
Segments mixed-language text and generates TTS audio, then uploads to MinIO.
"""
try:
# 1. Segment the text into language-tagged parts
segments = self.segmentation_service.segment_text(text)
# 2. Generate a single, stitched audio file from the sequence
# This call will be routed correctly by the tts_manager
audio_bytes = self.agent_service.tts_service.generate_speech_from_sequence(segments)
# 3. Determine filename and upload (same as before)
provider = os.getenv("TTS_PROVIDER", "openai").lower()
file_extension = "wav"
content_type = "audio/wav"
timestamp = int(time.time())
filename = f"agent_response_{timestamp}_{student_id}.{file_extension}"
minio_file_path = f"audio/{filename}"
......@@ -111,7 +101,7 @@ class ChatService:
file_path=minio_file_path,
expires=3600 # 1 hour
)
print(f"Successfully generated and uploaded stitched TTS audio: {filename}")
print(f"Successfully generated and uploaded TTS audio: {filename}")
return {"bytes": audio_bytes, "filepath": full_url}
except Exception as e:
......
# services/response_manager.py
import json
import base64
from typing import Optional, Dict
......@@ -7,26 +8,28 @@ from .redis_client import redis_client
class ResponseManager:
"""
Manages response state in a central Redis store, keyed by student_id.
This solution is safe for multiple workers.
Manages response state in a central Redis store using a FIFO Queue (List)
for each student_id. This is the definitive, race-condition-safe solution
for a stateless client.
"""
def __init__(self):
"""Initializes by connecting to the shared Redis client."""
if redis_client is None:
raise ConnectionError("ResponseManager requires a valid Redis connection. Check your REDIS_HOST/PORT environment variables.")
raise ConnectionError("ResponseManager requires a valid Redis connection.")
self.redis = redis_client
self.ttl_seconds = 600 # Responses will expire after 10 minutes
self.ttl_seconds = 600 # A key will expire 10 mins after the LAST item is added
def _get_key(self, student_id: str) -> str:
"""Creates a consistent key for Redis to avoid conflicts."""
return f"student_response:{student_id}"
"""Creates a consistent key for the student's queue."""
return f"student_queue:{student_id}"
def store_response(self, student_id: str, text: str, audio_filepath: Optional[str] = None, audio_bytes: Optional[bytes] = None) -> None:
"""Stores a response for a specific student_id in Redis."""
"""
Adds a new response to the END of the queue for a specific student.
"""
key = self._get_key(student_id)
# Encode binary audio data into a string (Base64) to store it in JSON
encoded_audio = base64.b64encode(audio_bytes).decode('utf-8') if audio_bytes else None
payload = {
......@@ -35,24 +38,24 @@ class ResponseManager:
"audio_bytes_b64": encoded_audio
}
# Convert the dictionary to a JSON string and store it in Redis with an expiration time
self.redis.setex(key, self.ttl_seconds, json.dumps(payload))
# RPUSH adds the new item to the right (end) of the list.
self.redis.rpush(key, json.dumps(payload))
# Reset the expiration time for the whole queue each time a new item is added.
self.redis.expire(key, self.ttl_seconds)
def get_response(self, student_id: str) -> Dict:
"""
Gets the response for a student without deleting it.
This allows the client to safely retry the request if it fails.
The key will be cleaned up automatically by Redis when its TTL expires.
Atomically retrieves and removes the OLDEST response from the front of the queue.
"""
key = self._get_key(student_id)
# 1. Use a simple, non-destructive GET command. No pipeline needed.
json_value = self.redis.get(key)
# LPOP atomically retrieves and removes the item from the left (start) of the list.
json_value = self.redis.lpop(key)
if not json_value:
return {"text": None, "audio_filepath": None, "audio_bytes": None}
# 2. Decode the payload as before.
# Decode the payload.
payload = json.loads(json_value)
if payload.get("audio_bytes_b64"):
......@@ -66,18 +69,15 @@ class ResponseManager:
def clear_response(self, student_id: str) -> None:
    """
    Delete the Redis entry stored under this student's key, discarding
    whatever response data is currently held for them.
    """
    self.redis.delete(self._get_key(student_id))
def is_response_fresh(self, student_id: str) -> bool:
"""
Checks if a response exists in Redis for the given student.
This is much simpler and more reliable now.
Checks if there are any items in the student's queue.
"""
key = self._get_key(student_id)
# redis.exists() returns the number of keys that exist (0 or 1 in this case).
return self.redis.exists(key) > 0
\ No newline at end of file
# LLEN gets the length of the list. If it's > 0, there's a response ready.
return self.redis.llen(key) > 0
\ No newline at end of file
......@@ -18,7 +18,6 @@ class ResponseService:
self.response_manager = response_manager
self.audio_service = audio_service # Keep for now if used elsewhere
# --- REWRITTEN and IMPROVED ---
def get_agent_response(self, student_id: str):
"""
Gets the agent response from the manager and streams the raw audio bytes
......@@ -28,7 +27,6 @@ class ResponseService:
raise HTTPException(status_code=404, detail="Agent response not ready or expired.")
response_data = self.response_manager.get_response(student_id)
self.response_manager.clear_response(student_id) # Clear after getting it
text_response = response_data.get("text")
audio_bytes = response_data.get("audio_bytes")
......@@ -41,6 +39,10 @@ class ResponseService:
"text": text_response
}
if not response_data or not response_data.get("text"):
raise HTTPException(status_code=404, detail=f"Response for student {student_id} was already claimed or expired.")
# Determine content type based on filename extension
filename = response_data.get("audio_filename", "")
media_type = "audio/wav" if filename.endswith(".wav") else "audio/mpeg"
......
import requests
import threading
import time
import base64
import statistics
# --- Configuration ---
BASE_URL = "https://voice-agent.caprover.al-arcade.com"
# A challenging but reasonable concurrency level for a single container.
CONCURRENCY_LEVEL = 8
# Use 8 unique students to test different contexts.
STUDENTS_AND_QUESTIONS = {
"student_001": "What is photosynthesis and why is it important for life on Earth?",
"student_002": "Explain Newton's first law of motion with a simple example.",
"student_003": "How do polar bears adapt to survive in the freezing Arctic environment?",
"student_004": "What is the difference between a star and a planet?",
"student_005": "Can you explain the main stages of the water cycle?",
"student_006": "What is static electricity and how can I see it at home?",
"student_007": "Why do some animals, like birds, migrate every year?",
"student_008": "What is DNA and what does it do in our bodies?"
}
# Thread-safe dictionary to store results
final_results = {}
lock = threading.Lock()
def run_test_for_student(student_id, question, timeout=60):
    """
    Run one POST /chat -> GET /get-audio-response cycle for a student and
    record the outcome in the shared ``final_results`` dict.

    Args:
        student_id: ID sent with both requests; also the key under which the
            result is stored.
        question: Text sent as the chat message.
        timeout: Per-request timeout in seconds, so a stalled server cannot
            hang the thread indefinitely.

    The function never raises: any exception (HTTP error status, timeout,
    decoding failure, ...) is recorded as a "FAILED" result instead.
    """
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()
    try:
        # --- Step 1: POST the chat message ---
        chat_url = f"{BASE_URL}/chat"
        chat_payload = {'student_id': student_id, 'text': question}

        post_start = time.perf_counter()
        chat_response = requests.post(chat_url, data=chat_payload, timeout=timeout)
        post_end = time.perf_counter()
        chat_response.raise_for_status()  # Raises for 4xx or 5xx responses

        # --- Step 2: GET the audio response ---
        # Pass the ID via `params` so requests URL-encodes it properly.
        get_url = f"{BASE_URL}/get-audio-response"

        get_start = time.perf_counter()
        audio_response = requests.get(get_url, params={'student_id': student_id}, timeout=timeout)
        get_end = time.perf_counter()
        audio_response.raise_for_status()

        end_time = time.perf_counter()

        # The response text is read from a Base64-encoded response header.
        encoded_text = audio_response.headers.get('X-Response-Text', '')
        decoded_text = base64.b64decode(encoded_text).decode('utf-8') if encoded_text else 'NO RESPONSE TEXT'

        result = {
            "request": question,
            "response": decoded_text,
            "duration": end_time - start_time,
            "post_duration": post_end - post_start,
            "get_duration": get_end - get_start,
            "status": "SUCCESS"
        }
    except Exception as e:
        # Thread boundary: record the failure instead of killing the thread.
        result = {
            "request": question,
            "error": str(e),
            "duration": time.perf_counter() - start_time,
            "status": "FAILED"
        }

    with lock:
        final_results[student_id] = result
def _print_student_result(student_id, result):
    """Print the detailed report section for one student's recorded result."""
    succeeded = result['status'] == "SUCCESS"
    status_icon = "✅" if succeeded else "❌"
    print(f"{status_icon} Student ID: {student_id}")
    print(f" ▶️ Request: '{result.get('request', 'N/A')}'")
    if succeeded:
        print(f" ◀️ Response: '{result.get('response', 'N/A')[:80]}...'")
        print(f" ⏱️ Timing: Total={result['duration']:.2f}s (POST={result['post_duration']:.2f}s, GET={result['get_duration']:.2f}s)")
    else:
        print(f" ❌ ERROR: {result.get('error')}")
        print(f" ⏱️ Failed after {result['duration']:.2f}s")
    print("-" * 70)


def _print_summary(results, total_requests, total_time):
    """Print success/failure counts and timing statistics for the whole run."""
    durations = [res['duration'] for res in results.values() if res['status'] == 'SUCCESS']
    success_count = len(durations)
    # Anything that did not succeed (including a missing result) is a failure.
    failed_count = total_requests - success_count

    print("\n" + "="*70)
    print("📈 SUMMARY STATISTICS 📈")
    print("="*70)
    print(f"✅ Successful Requests: {success_count} / {total_requests}")
    print(f"❌ Failed Requests: {failed_count} / {total_requests}")
    print(f"⏱️ Total Test Runtime: {total_time:.2f}s")

    if durations:
        print(f"📊 Average Success Time: {statistics.mean(durations):.2f}s")
        print(f"🚀 Fastest Success Time: {min(durations):.2f}s")
        print(f"🐢 Slowest Success Time: {max(durations):.2f}s")
    print("="*70)


def main():
    """Launch one request thread per configured student, then print a report."""
    script_start_time = time.perf_counter()

    # Report the number of requests actually launched. CONCURRENCY_LEVEL does
    # not control the thread count, so it can drift from the student table.
    total_requests = len(STUDENTS_AND_QUESTIONS)

    print("="*70)
    print("🔥 REASONABLE CONCURRENCY TEST 🔥")
    print("="*70)
    print(f"Concurrency Level: {total_requests} simultaneous requests")
    print(f"Target URL: {BASE_URL}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("="*70 + "\n")

    threads = []
    # Create and start a thread for each student
    for sid, q in STUDENTS_AND_QUESTIONS.items():
        thread = threading.Thread(target=run_test_for_student, args=(sid, q))
        threads.append(thread)
        thread.start()
        # Stagger the start of each request slightly to simulate a more realistic load
        time.sleep(0.2)

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    total_time = time.perf_counter() - script_start_time

    # --- Analysis and Reporting ---
    print("\n" + "="*70)
    print("📊 DETAILED RESULTS & ANALYSIS 📊")
    print("="*70 + "\n")

    for student_id, result in sorted(final_results.items()):
        _print_student_result(student_id, result)

    # --- Summary ---
    _print_summary(final_results, total_requests, total_time)


if __name__ == "__main__":
    main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment