from flask import Blueprint, request, jsonify, stream_with_context, Response, send_file, render_template, session, redirect, url_for
from flask_cors import CORS, cross_origin
from utils.extractor_functions import cleanup_directories, merge_pages, cleanup_temp_files
from models.mcq import MCQ as final_mcq
from utils.pipeline_functions import *
import logging
from pathlib import Path
import json
import os
import time
import uuid
from threading import Thread
from datetime import datetime, timedelta
from dotenv import load_dotenv
import redis
from celery import chain
from celery_app import celery
from mongoengine import connect
from bson import ObjectId
from urllib.parse import unquote
from mongoengine import get_db
from flask import current_app

# Load variables from a .env file (if present) before the os.getenv() calls below.
load_dotenv()

# Blueprint holding every route in this module; CORS is enabled for all of them.
extractor_bp = Blueprint('extractor', __name__)
CORS(extractor_bp)

# Redis client used to store per-task progress/result state under "task:<task_id>".
# Connection settings come from the environment, defaulting to host "redis",
# port 6379 and database 0.
# NOTE(review): values read from the environment arrive as strings, while the
# defaults are ints -- confirm the redis client accepts both forms.
redis_client = redis.Redis(
    host=os.getenv('REDIS_HOST', 'redis'),
    port=os.getenv('REDIS_PORT', 6379),
    db=os.getenv('REDIS_DB', 0)
)

# Root directory for per-task working files. This is a relative path, so it is
# resolved against the working directory of the running process.
TMP_DIR = Path("public/tmp")

# Directory path constants, plus the Poppler location (None when POPPLER_PATH is unset).
CROPS_OUT_DIR = "public/images/crops_out"
DEBUG_OUTPUTS_DIR = "public/debug_outputs"
POPPLER_PATH= os.getenv("POPPLER_PATH")


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

@celery.task(name='mcq_extraction.extract_text_from_pdf')
def extract_text_from_pdf_task(pdf_path: str, task_id: str, dpi: int = 200, max_workers: int = 8):
    """
    Run the full PDF -> MCQ pipeline for one uploaded file.

    Steps:
      1) Upload the PDF to Mistral and run OCR
      2) Save the OCR markdown and extract its images
      3) Extract structured, page-level MCQs via OpenAI
      4) Validate the flattened question list via OpenAI (concurrently)
      5) Save the final MCQ JSON and publish it in the Redis task state

    Progress is written to Redis under ``task:<task_id>``; intermediate and
    final files are written under ``TMP_DIR/<task_id>/``.

    Args:
        pdf_path: Path of the uploaded PDF. The file is deleted once the
            page-level extraction has finished.
        task_id: Identifier used for the Redis key and the working directory.
        dpi: Accepted for interface compatibility; not used by this task.
        max_workers: Worker count passed to the concurrent validation step.

    Returns:
        ``task_id``, so a chained task can pick up the stored results.

    Raises:
        Any exception from the pipeline is logged, recorded in Redis as an
        ``error`` state (best effort) and re-raised.
    """
    try:
        # sanity
        if not isinstance(pdf_path, (str, Path)):
            raise TypeError(f"Invalid pdf_path type: {type(pdf_path)}. Expected str or Path.")
        tmp_task_dir = Path(TMP_DIR) / task_id
        tmp_task_dir.mkdir(parents=True, exist_ok=True)

        mistral = get_mistral_client()
        openai_client = get_openai_client()

        # Metadata stored by the upload endpoint; carried into every progress update.
        prev_state = get_progress(task_id) or {}
        metadata = {
            key: prev_state.get(key, "")
            for key in ("subject", "filename", "exam", "sub_exam", "topic", "chapter")
        }

        def set_progress(page: int, total: int, status: str = "processing", extra: dict | None = None):
            """Write a progress snapshot to Redis; failures are logged, never raised."""
            percent = round((page / total) * 100, 2) if total else 0
            progress_data = {
                "page": page,
                "total": total,
                "percent": percent,
                "status": status,
                **metadata,
            }
            if extra:
                progress_data.update(extra)
            try:
                redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps(progress_data, ensure_ascii=False))
            except Exception as e:
                logger.debug("Failed to write progress to redis: %s", e)

        # ---------- 1) Upload to Mistral ----------
        set_progress(0, 1, status="starting", extra={"stage": "uploading"})
        uploaded = upload_pdf_for_ocr(Path(pdf_path), mistral)

        # ---------- 2) Run OCR ----------
        set_progress(0, 1, status="processing", extra={"stage": "ocr"})
        ocr_resp = run_ocr_on_uploaded_file(uploaded, mistral)

        # ---------- 3) Save OCR markdown + images ----------
        ocr_md_path = tmp_task_dir / f"{task_id}_ocr.md"
        save_ocr_markdown(ocr_resp, ocr_md_path)

        # Called for its side effect of writing the images; the return value is not needed.
        extract_images_from_ocr(ocr_resp, tmp_task_dir / "images")

        # ---------- 4) Split pages and extract via OpenAI ----------
        md_text = ocr_md_path.read_text(encoding="utf-8")
        pages = split_pages_from_markdown(md_text)
        total_pages = len(pages) or 1  # avoid a zero total for an empty document

        structured_pages_for_intermediate_save = []
        all_questions_for_validation = []  # flat list aggregated across all pages

        for idx, ptext in enumerate(pages, start=1):
            page_extra = {"stage": "extracting", "substage": f"page_{idx}"}
            set_progress(idx - 1, total_pages, status="processing", extra=page_extra)

            page_result = call_extraction_llm(openai_client, ptext, SYSTEM_PROMPT_EXTRACT)

            # Keep the raw page result for optional debugging/reloading
            structured_pages_for_intermediate_save.append({"page": idx, "content": page_result})

            # The model output is untrusted: tolerate a non-dict page result
            # instead of failing the whole task on one malformed page.
            if isinstance(page_result, dict):
                questions = (
                    page_result.get("Questions")
                    or page_result.get("questions")
                    or page_result.get("mcqs")
                    or []
                )
            else:
                logger.warning("Page %d: extraction result is not an object; skipping", idx)
                questions = []

            if isinstance(questions, list):
                for q in questions:
                    if not isinstance(q, dict):
                        logger.warning("Page %d: skipping non-object question entry", idx)
                        continue
                    # Record which page the question came from
                    q["Page number"] = idx
                    all_questions_for_validation.append(q)

            set_progress(idx, total_pages, status="processing", extra=page_extra)

        # Save structured intermediate JSON (page-level structure)
        structured_path = tmp_task_dir / f"{task_id}_ocr_structured.json"
        structured_path.write_text(
            json.dumps(structured_pages_for_intermediate_save, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

        # Remove the original PDF to save space (best effort)
        try:
            p = Path(pdf_path)
            if p.exists():
                p.unlink()
        except Exception:
            logger.debug("Could not remove uploaded pdf", exc_info=True)

        # ---------- 5) Validate questions (concurrent) ----------
        set_progress(0, total_pages, status="processing", extra={"stage": "validating"})

        if all_questions_for_validation:
            logger.info("Validating %d total questions across all pages.", len(all_questions_for_validation))
            final_validated_mcqs = process_questions_concurrently(
                all_questions_for_validation, openai_client, VALIDATION_PROMPT, max_workers=max_workers
            )
        else:
            final_validated_mcqs = []

        # progress update with final validated count
        set_progress(
            total_pages,
            total_pages,
            status="processing",
            extra={"stage": "validating", "substage": "final", "extracted_count": len(final_validated_mcqs)},
        )

        # ---------- 6) Save final flattened MCQ JSON ----------
        out_path = tmp_task_dir / f"{task_id}_mcqs.json"
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(final_validated_mcqs, f, ensure_ascii=False, indent=2)

        # Directory cleanup is best effort and must not fail the task
        try:
            cleanup_directories(final_validated_mcqs)
        except Exception:
            logger.debug("cleanup_directories failed or not applicable", exc_info=True)

        # ---------- Final Redis update ----------
        # Written directly (not via set_progress) so a Redis failure here is
        # reported as a task error instead of being swallowed.
        final_progress = {
            "page": total_pages,
            "total": total_pages,
            "percent": 100,
            "status": "done",
            **metadata,
            "output": str(out_path),
            "results": final_validated_mcqs,  # flattened list, read by the results/save steps
        }
        redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps(final_progress, ensure_ascii=False))

        return task_id

    except Exception as e:
        logger.exception("Error in full extraction task: %s", e)
        try:
            redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps({"status": "error", "error": str(e)}))
        except Exception:
            # Redis itself may be the failure; the original error is re-raised below.
            logger.debug("Could not record error state in redis", exc_info=True)
        raise



# Helper function to get progress from Redis
def get_progress(task_id):
    """Return the JSON-decoded state stored under ``task:<task_id>``, or None if nothing is stored."""
    raw = redis_client.get(f"task:{task_id}")
    return json.loads(raw) if raw else None

# --- API route: SSE progress stream ---
@extractor_bp.route('/progress/<task_id>')
@cross_origin()
def progress(task_id):
    """
    Stream the task state to the client as Server-Sent Events.

    The Redis state is polled once per second. An event is sent whenever the
    stored state changes. The stream ends after sending a ``done`` or
    ``error`` state, or a ``not_found`` event when no state exists.

    Returns:
        A ``text/event-stream`` response, or a 400 JSON error for a missing
        or literal ``'undefined'`` task id.
    """
    if not task_id or task_id == 'undefined':
        return jsonify({"error": "Invalid task ID"}), 400

    def event_stream():
        last_payload = None
        while True:
            state = get_progress(task_id)
            if not state:
                yield f"data: {json.dumps({'status':'not_found'})}\n\n"
                break

            payload = json.dumps(state)
            if state.get("status") in ("done", "error"):
                yield f"data: {payload}\n\n"
                break

            # Compare the whole state, not only the percentage, so stage or
            # status changes at an unchanged percent still reach the client.
            if payload != last_payload:
                yield f"data: {payload}\n\n"
                last_payload = payload
            time.sleep(1)

    return Response(
        stream_with_context(event_stream()),
        mimetype="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # ask nginx-style proxies not to buffer the stream
        }
    )

# --- Celery task: persist extracted MCQs (read via get_progress()) to the database ---
@celery.task(name='mcq_extraction.save_to_database')
def save_to_database(task_id):
    """
    Celery task: load the extraction results from Redis and save them to the DB.

    Each result item must be an object with ``Question`` and ``Options`` keys
    and non-empty options A and B. Items that fail these checks are skipped
    and reported in ``errors``.

    Args:
        task_id: Task whose Redis state (``task:<task_id>``) holds the results.

    Returns:
        A plain dict (not a Flask Response) with ``message``, ``saved_count``,
        ``skipped_count`` and ``errors``; or ``{"error": ..., "saved_count": 0}``
        when no usable results are stored.
    """
    connect(host=os.getenv('MONGO_URI'), alias='default')
    logger.info("Starting save_to_database_task for task_id=%s", task_id)
    state = get_progress(task_id)
    if not state:
        return {"error": "No results found", "saved_count": 0}

    # The state is normally a dict carrying a 'results' list; a bare list is also accepted.
    if isinstance(state, dict):
        meta = state
        mcq_data = state.get("results")
    else:
        meta = {}
        mcq_data = state
    if not isinstance(mcq_data, list):
        return {"error": "Invalid data format", "saved_count": 0}

    # Task-level metadata is identical for every question, so read it once.
    doc_meta = {
        "filename": meta.get('filename', ''),
        "subject": meta.get('subject', 'general'),
        "exam": meta.get('exam', ''),
        "sub_exam": meta.get('sub_exam', ''),
        "topic": meta.get('topic', ''),
        "chapter": meta.get('chapter', ''),
    }

    def _text(value):
        """Normalize a possibly missing or non-string value to stripped text."""
        return str(value or "").strip()

    saved_count = 0
    skipped_count = 0
    errors = []

    for index, mcq_item in enumerate(mcq_data):
        # Validate fields
        if not mcq_item or not isinstance(mcq_item, dict):
            errors.append(f"Item {index}: Not an object")
            skipped_count += 1
            continue

        # New format: Question, Options
        if 'Question' not in mcq_item or 'Options' not in mcq_item:
            errors.append(f"Item {index}: Missing required fields")
            skipped_count += 1
            continue

        options = mcq_item.get('Options') or {}
        if not isinstance(options, dict):
            # A list or string here would otherwise crash the whole batch on .get()
            errors.append(f"Item {index}: Options is not an object")
            skipped_count += 1
            continue

        question = _text(mcq_item.get('Question'))
        option_a = _text(options.get('A') or options.get('a'))
        option_b = _text(options.get('B') or options.get('b'))
        option_c = _text(options.get('C') or options.get('c'))
        option_d = _text(options.get('D') or options.get('d'))

        if not option_a or not option_b:
            errors.append(f"Item {index}: Insufficient options")
            skipped_count += 1
            continue

        # Figure references are stored as a JSON string, or None when there are none
        figure_refs = mcq_item.get('FigureRefs') or mcq_item.get('Figures') or []
        figure_ref = json.dumps(figure_refs) if isinstance(figure_refs, list) and figure_refs else None

        # optional fields from new format
        qnum = mcq_item.get('QuestionNumber') or mcq_item.get('QNumber') or None
        page_no = mcq_item.get('Page number') or mcq_item.get('page') or mcq_item.get('Page') or None

        mcq_doc = final_mcq(
            task_id=task_id,
            question_number=qnum,
            page_number=page_no,
            question=question,
            option_a=option_a,
            option_b=option_b,
            option_c=option_c,
            option_d=option_d,
            figure_ref=figure_ref,
            correct_answer="",   # new format has no answer
            explanation="",
            created_at=datetime.utcnow(),
            **doc_meta,
        )
        mcq_doc.save()
        saved_count += 1

    # Optional cleanup after saving; a failure must not fail the task
    try:
        cleanup_temp_files(task_id)
    except Exception:
        logger.debug("cleanup_temp_files failed for task_id=%s", task_id, exc_info=True)

    return {
        "message": f"Saved {saved_count} questions to database",
        "saved_count": saved_count,
        "skipped_count": skipped_count,
        "errors": errors or None
    }


@extractor_bp.route('/process', methods=['POST'])
def process_pdf():
    """
    Accept a PDF upload and start the OCR + MCQ extraction pipeline.

    Steps:
      1. Validate the upload and save it under ``TMP_DIR/<task_id>/``.
      2. Store the task metadata in Redis.
      3. Start the Celery chain (extraction, then save-to-database).

    Returns:
        202 with the new ``task_id`` on success; 400 when the file part is
        missing or is not a ``.pdf``.
    """
    logger.debug("request.files keys = %s", list(request.files.keys()))
    logger.debug("request.form keys = %s", list(request.form.keys()))
    logger.debug("Content-Type = %s", request.content_type)

    # --- Validate file ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part", "debug_keys": list(request.files.keys())}), 400
    upload = request.files['file']
    logger.debug("RECEIVED FILE: %s", upload.filename)

    # The filename is client-controlled: drop any directory components (either
    # separator style) so the upload cannot be written outside the task directory.
    safe_name = os.path.basename((upload.filename or '').replace('\\', '/'))
    if not safe_name or not safe_name.lower().endswith('.pdf'):
        return jsonify({"error": "Invalid file type, only .pdf allowed"}), 400

    subjects = [s for s in request.form.getlist('subject') if s and s.strip()]

    exam = request.form.get('exam', '').strip().lower()
    sub_exam = request.form.get('sub_exam', '').strip().lower()
    topic = request.form.get('topic', '').strip().lower()
    chapter = request.form.get('chapter', '').strip().lower()

    # --- Prepare paths ---
    task_id = str(uuid.uuid4())
    temp_dir = Path(TMP_DIR) / task_id
    temp_dir.mkdir(parents=True, exist_ok=True)

    # os.path.join returns a plain str, which is what the Celery task receives
    temp_pdf_path = os.path.join(temp_dir, safe_name)
    upload.save(temp_pdf_path)

    # --- Initialize Redis progress ---
    progress_data = {
        "percent": 0,
        "status": "starting",
        "subject": subjects,
        "filename": safe_name.lower(),
        "exam": exam,
        "sub_exam": sub_exam,
        "topic": topic,
        "chapter": chapter
    }
    redis_client.setex(f"task:{task_id}", 3600, json.dumps(progress_data))

    # --- Paths reported back to the client ---
    # NOTE(review): the extraction task in this module names its files after
    # task_id, not after the upload's stem -- confirm whether any client
    # relies on these two values.
    stem = Path(safe_name).stem
    ocr_json_path = os.path.join(temp_dir, f"{stem}_ocr.json")
    mcq_json_path = os.path.join(temp_dir, f"{stem}_mcqs.json")

    # --- Start Celery chain: the extraction result (task_id) feeds the save step ---
    process_task = extract_text_from_pdf_task.s(temp_pdf_path, task_id).set(queue='video_serial')
    save_db_task = save_to_database.s().set(queue='video_serial')
    chain(process_task, save_db_task).apply_async()

    logger.info("Started combined OCR + MCQ pipeline for %s, task_id=%s", safe_name, task_id)
    return jsonify({
        "message": "Processing started",
        "task_id": task_id,
        "ocr_json": ocr_json_path,
        "mcq_json": mcq_json_path
    }), 202

# The routes below (download, results, ...) read task state from Redis
# via get_progress().

@extractor_bp.route('/download/<task_id>', methods=['GET'])
def download(task_id):
    """Send the finished task's output JSON as an attachment, or a 404 JSON error."""
    task_state = get_progress(task_id)
    if task_state and task_state.get("status") == "done":
        result_path = task_state.get("output")
        if result_path and os.path.exists(result_path):
            return send_file(
                result_path,
                as_attachment=True,
                download_name=f"extracted_mcqs_{task_id}.json",
            )
        return jsonify({"error": "Output file not found"}), 404

    return jsonify({"error": "Task not found or not completed"}), 404

@extractor_bp.route('/results/<task_id>', methods=['GET'])
@cross_origin()
def get_results(task_id):
    """
    Return the extracted MCQs for a finished task.

    Results are served from the Redis task state when present. Otherwise they
    are read from the task's output file and written back into Redis.

    Returns:
        200 with the results; 202 while the task is not done; 404 when the
        task or its output file is missing; 500 on read/parse failures.
    """
    try:
        state = get_progress(task_id)
        if not state:
            return jsonify({"error": "Task not found"}), 404

        if state.get("status") != "done":
            return jsonify({"error": "Task not completed yet", "status": state.get("status")}), 202

        # Return the results directly from Redis if available
        if "results" in state:
            return jsonify(state["results"])

        # Fallback: read from file if results not in Redis
        output_file = state.get("output")
        if not output_file:
            return jsonify({"error": "Output file path not specified"}), 404

        output_path = Path(output_file)
        if not output_path.exists():
            return jsonify({"error": "Output file not found"}), 404

        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                results = json.load(f)

            # Store the results in Redis for later requests.
            # NOTE(review): this resets the key's expiry to one hour, whereas
            # the extraction task stores the state for 24 hours -- confirm.
            state["results"] = results
            redis_client.setex(f"task:{task_id}", 3600, json.dumps(state))

            return jsonify(results)

        except json.JSONDecodeError as e:
            logger.error("Failed to parse output file %s: %s", output_file, e)
            return jsonify({"error": "Invalid JSON in output file"}), 500
        except Exception as e:
            logger.exception("Error reading output file %s: %s", output_file, e)
            return jsonify({"error": "Failed to read output file"}), 500

    except Exception as e:
        logger.exception("Error in get_results for task %s: %s", task_id, e)
        return jsonify({"error": "Internal server error"}), 500

# ---------- CHECK IF FILE ALREADY EXISTS ----------
@extractor_bp.route('/check_file_exists', methods=['POST'])
@cross_origin()
def check_file_exists():
    """
    Report whether an MCQ document already exists for a given filename.

    Request JSON: { "filename": "some-file.pdf" }
    Response JSON:
      { "exists": bool, "id": "<mcq_doc_id>" | None,
        "user_name": "<uploader>" | None,
        "created_at": "<iso-8601>" | None }

    The filename is matched case-insensitively. On an unexpected error the
    "not found" payload is returned with status 500.
    """
    not_found = {"exists": False, "id": None, "user_name": None, "created_at": None}

    try:
        payload = request.get_json(silent=True) or {}
        requested_name = (payload.get("filename") or "").strip()
        if not requested_name:
            return jsonify(not_found)

        # Make sure a DB connection exists; a failure here is ignored because
        # the app may already be connected.
        try:
            connect(host=os.getenv('MONGO_URI'), alias='default')
        except Exception:
            pass

        # First document whose filename matches, ignoring case
        match = (
            final_mcq.objects(filename__iexact=requested_name)
            .only('id', 'user_name', 'created_at')
            .first()
        )
        if not match:
            return jsonify(not_found)

        # Small metadata so the frontend can show context
        created_iso = None
        try:
            if getattr(match, "created_at", None):
                created_iso = match.created_at.isoformat()
        except Exception:
            created_iso = None

        doc_id = str(match.id) if getattr(match, "id", None) else None
        return jsonify({
            "exists": True,
            "id": doc_id,
            "user_name": getattr(match, "user_name", None),
            "created_at": created_iso
        })

    except Exception as e:
        logger.exception("Error in check_file_exists: %s", e)
        # Same payload as "not found", but flagged with a 500 status
        return jsonify(not_found), 500
# ---------------------------------------------------


@extractor_bp.route('/extractor', methods=['GET'])
def extraction():
    """Render the MCQ extractor page for a logged-in user; otherwise redirect to login."""
    if 'user' in session:
        return render_template('user/mcq-extractor.html')
    return redirect(url_for('user.user_login'))


@extractor_bp.route('/mcq/view-list', methods=['GET'])
def mcq_view_list():
    """
    Return stored MCQs as JSON, newest first, with optional filters.

    Query parameters: ``exam``, ``sub_exam``, ``subject``, ``chapter``,
    ``topic`` (exact match), ``file`` (case-insensitive substring of the
    filename) and ``uploaded_date`` (``YYYY-MM-DD``; an unparsable date is
    ignored).

    The exam / sub_exam / subject / chapter / topic values stored on each MCQ
    are replaced by human-readable names looked up in candidate collections.
    A value that cannot be resolved is returned unchanged.

    Returns:
        JSON ``{"count": <int>, "mcqs": [...]}``.
    """
    # --- Build filters from query params ---
    filters = {}
    for field in ('exam', 'sub_exam', 'subject', 'chapter', 'topic'):
        value = request.args.get(field)
        if value:
            filters[field] = value

    file_name = request.args.get('file')
    if file_name:
        filters['filename__icontains'] = file_name

    uploaded_date = request.args.get('uploaded_date')
    if uploaded_date:
        try:
            date_obj = datetime.strptime(uploaded_date, "%Y-%m-%d")
        except ValueError:
            pass
        else:
            # Match the whole calendar day: [date, date + 1 day)
            filters['created_at__gte'] = date_obj
            filters['created_at__lt'] = date_obj + timedelta(days=1)

    # --- Fetch MCQs ---
    mcqs = final_mcq.objects(**filters).order_by('-created_at')

    # If no mcqs, return empty quickly
    if not mcqs:
        return jsonify({"count": 0, "mcqs": []})

    # --- Candidate collections to try for each lookup, in order ---
    # Update these names if your DB uses other collection names
    COLLECTION_MAP = {
        'exam': ['exam_list','exams', 'exam', 'exam_master'],
        'sub_exam': ['sub_exams_list','sub_exams', 'subexam', 'sub_exam_master'],
        'subject': ['video_courses_subjects','subjects', 'subject', 'subject_master'],
        'chapter': ['chapters', 'chapter', 'chapter_master'],
        'topic': ['topics', 'topic', 'topic_master']
    }

    # Gather every referenced id/name first so each kind is resolved in one batch
    exam_ids = set()
    subexam_ids = set()
    subject_ids = set()
    chapter_ids = set()
    topic_ids = set()

    for m in mcqs:
        # values can be an ObjectId-like string or a plain name
        if getattr(m, 'exam', None):
            exam_ids.add(str(getattr(m, 'exam')))
        if getattr(m, 'sub_exam', None):
            subexam_ids.add(str(getattr(m, 'sub_exam')))
        # subject may be a list or a single value
        raw_subs = getattr(m, 'subject', None) or []
        if isinstance(raw_subs, (list, tuple)):
            for s in raw_subs:
                if s:
                    subject_ids.add(str(s))
        elif raw_subs:
            subject_ids.add(str(raw_subs))
        if getattr(m, 'chapter', None):
            chapter_ids.add(str(getattr(m, 'chapter')))
        if getattr(m, 'topic', None):
            topic_ids.add(str(getattr(m, 'topic')))

    # Raw pymongo database (mongoengine helper); lookups are skipped if unavailable
    try:
        db = get_db()
    except Exception:
        db = None

    def _batch_resolve(ids_set, candidate_collections):
        """
        ids_set: set of string values (could be ObjectId hex or a plain name)
        candidate_collections: list of collection names to try
        returns: dict mapping the original string -> resolved human name
        """
        if not ids_set:
            return {}

        # partition ids_set into objid-like (24 hex chars) and non-objid
        objid_list = []
        plain_list = []
        for v in ids_set:
            if not v:
                continue
            try:
                if len(v) == 24:
                    objid_list.append(ObjectId(v))
                else:
                    plain_list.append(v)
            except Exception:
                # 24 characters but not valid hex: treat as a plain name
                plain_list.append(v)

        resolved = {}
        # PyMongo Database objects do not support truth-value testing
        # (bool(db) raises NotImplementedError), so compare with None explicitly.
        if db is not None:
            for coll_name in candidate_collections:
                try:
                    coll = db[coll_name]
                except Exception:
                    continue

                # query by ObjectId _id for those that look like ObjectId
                if objid_list:
                    try:
                        docs = coll.find({"_id": {"$in": objid_list}})
                        for d in docs:
                            key = str(d.get('_id'))
                            resolved[key] = (
                                d.get('exam_title')
                                or d.get('sub_exam_title')
                                or d.get('subject_name')
                                or d.get('name')
                                or d.get('title')
                                or d.get('label')
                                or key
                            )
                    except Exception:
                        # best effort: a failing collection must not break the listing
                        pass

                # query by an explicit 'id' field, or by name/title
                if plain_list:
                    try:
                        docs = coll.find(
                            {
                                "$or": [
                                    {"id": {"$in": plain_list}},
                                    {"name": {"$in": plain_list}},
                                    {"title": {"$in": plain_list}}
                                ]
                            }
                        )
                        for d in docs:
                            # prefer the document's string 'id' as key, else its name/title
                            if d.get('id'):
                                key = str(d.get('id'))
                            else:
                                key = d.get('name') or d.get('title') or str(d.get('_id'))
                            resolved[key] = (
                                d.get('name') or d.get('title') or d.get('label') or key
                            )
                    except Exception:
                        pass

        # anything not resolved maps to the original string (fallback)
        return {v: resolved.get(v) or resolved.get(str(v)) or v for v in ids_set}

    # Build maps by querying candidate collections in order
    exam_map = _batch_resolve(exam_ids, COLLECTION_MAP['exam'])
    subexam_map = _batch_resolve(subexam_ids, COLLECTION_MAP['sub_exam'])
    subject_map = _batch_resolve(subject_ids, COLLECTION_MAP['subject'])
    chapter_map = _batch_resolve(chapter_ids, COLLECTION_MAP['chapter'])
    topic_map = _batch_resolve(topic_ids, COLLECTION_MAP['topic'])

    # --- Build final response array, replacing ids with resolved names ---
    mcq_list = []
    for m in mcqs:
        raw_exam = getattr(m, 'exam', '') or ''
        raw_subexam = getattr(m, 'sub_exam', '') or ''
        raw_chapter = getattr(m, 'chapter', '') or ''
        raw_topic = getattr(m, 'topic', '') or ''
        raw_subjects = getattr(m, 'subject', []) or []

        # resolve single values
        exam_name = exam_map.get(str(raw_exam), str(raw_exam)) if raw_exam else ''
        subexam_name = subexam_map.get(str(raw_subexam), str(raw_subexam)) if raw_subexam else ''
        chapter_name = chapter_map.get(str(raw_chapter), str(raw_chapter)) if raw_chapter else ''
        topic_name = topic_map.get(str(raw_topic), str(raw_topic)) if raw_topic else ''

        # resolve subject list
        resolved_subjects = []
        if isinstance(raw_subjects, (list, tuple)):
            for s in raw_subjects:
                if s is None or s == '':
                    continue
                resolved_subjects.append(subject_map.get(str(s), str(s)))
        elif raw_subjects:
            resolved_subjects.append(subject_map.get(str(raw_subjects), str(raw_subjects)))

        mcq_list.append({
            "id": str(m.id),
            "task_id": getattr(m, "task_id", "") or "",
            "filename": getattr(m, "filename", "") or "",
            "exam": exam_name,
            "sub_exam": subexam_name,
            "subject": resolved_subjects,      # front-end can join if needed
            "chapter": chapter_name,
            "topic": topic_name,
            "question": getattr(m, "question", "") or "",
            "options": {
                "A": getattr(m, "option_a", None),
                "B": getattr(m, "option_b", None),
                "C": getattr(m, "option_c", None),
                "D": getattr(m, "option_d", None),
            },
            "correct_answer": getattr(m, "correct_answer", None),
            "page_no": getattr(m, "page_number", None),
            "explanation": getattr(m, "explanation", None),
            "created_at": m.created_at.strftime("%Y-%m-%d %H:%M") if m.created_at else "",
        })

    return jsonify({
        "count": len(mcq_list),
        "mcqs": mcq_list
    })

@extractor_bp.route('/mcq-view-demo', methods=['GET'])
def mcq_view_demo():
    """Serve the demo MCQ viewer page (the view-mcqs template)."""
    page = render_template('user/view-mcqs.html')
    return page

@extractor_bp.route('/mcq-view', methods=['GET'])
def view_mcqs():
    """Serve the MCQ viewer page (the view-mcqs template)."""
    page = render_template('user/view-mcqs.html')
    return page

@extractor_bp.route('/mcq-edit/<mcq_id>', methods=['GET'])
def mcq_edit_page(mcq_id):
    """
    Render the edit page for one MCQ.

    Returns:
        The rendered edit template, or a plain-text 404 when the id is
        malformed or no matching MCQ exists.
    """
    # A malformed id can never match a document; answer 404 instead of
    # letting the query raise a validation error.
    if not ObjectId.is_valid(mcq_id):
        return "MCQ not found", 404

    mcq = final_mcq.objects(id=mcq_id).first()
    if not mcq:
        return "MCQ not found", 404

    return render_template('user/mcq-edit.html', mcq=mcq)

@extractor_bp.route('/mcq/update/<mcq_id>', methods=['PUT'])
def update_mcq(mcq_id):
    """
    Update an MCQ from a JSON body.

    Recognized keys: ``question``, ``option_a`` .. ``option_d``,
    ``correct_answer``, ``explanation``, ``page_no`` and ``subject`` (a
    string or a list). Keys that are absent leave the stored value unchanged.

    Returns:
        200 on success; 400 for a malformed id or a non-object body; 404
        when no MCQ has that id.
    """
    data = request.get_json() or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Reject malformed ids up front instead of failing inside ObjectId()
    if not ObjectId.is_valid(mcq_id):
        return jsonify({"error": "Invalid MCQ id"}), 400

    mcq = final_mcq.objects(id=ObjectId(mcq_id)).first()
    if not mcq:
        return jsonify({"error": "MCQ not found"}), 404

    # (document attribute, request key) pairs; a missing key keeps the current value
    field_keys = (
        ("question", "question"),
        ("option_a", "option_a"),
        ("option_b", "option_b"),
        ("option_c", "option_c"),
        ("option_d", "option_d"),
        ("correct_answer", "correct_answer"),
        ("explanation", "explanation"),
        ("page_number", "page_no"),
    )
    for attr, key in field_keys:
        setattr(mcq, attr, data.get(key, getattr(mcq, attr)))

    # Optional: if the frontend sends a subject, store it as a list
    subj_val = data.get("subject", None)
    if isinstance(subj_val, str):
        mcq.subject = [subj_val]
    elif isinstance(subj_val, (list, tuple)):
        mcq.subject = list(subj_val)

    # Normalize an existing subject so it is always a list
    if isinstance(mcq.subject, str):
        mcq.subject = [mcq.subject]
    elif mcq.subject is None:
        mcq.subject = []

    mcq.save()

    return jsonify({"message": "MCQ updated successfully"}), 200

# NOTE(review): GET is accepted for a destructive action, so a link prefetch or
# crawler could delete data. Kept for compatibility with existing callers;
# consider restricting this route to POST/DELETE.
@extractor_bp.route('/mcq/delete/<mcq_id>', methods=['POST', 'DELETE', 'GET'])
def delete_mcq(mcq_id):
    """Delete the MCQ document with the given id.

    Args:
        mcq_id: Document id taken from the URL.

    Returns:
        A JSON response with ``success`` and ``message`` keys: 200 when a
        document was deleted, 400 for a malformed id, 404 when no document
        matched, 500 when the delete itself raised.
    """
    if not ObjectId.is_valid(mcq_id):
        return jsonify({"success": False, "message": "Invalid MCQ id"}), 400

    try:
        deleted = final_mcq.objects(id=ObjectId(mcq_id)).delete()
    except Exception as e:
        current_app.logger.exception("mcq delete error: %s", e)
        return jsonify({"success": False, "message": str(e)}), 500

    # Only an explicit zero count means "nothing matched"; any other result
    # (including a missing count) is treated as a successful delete.
    if deleted == 0:
        return jsonify({"success": False, "message": "MCQ not found"}), 404

    return jsonify({"success": True, "message": "Deleted successfully"})
# -------------------------------
# 🆕 New Routes for "View Files" Screen
# -------------------------------

@extractor_bp.route('/mcq-files', methods=['GET'])
def view_mcq_files():
    """Serve the MCQ file-list page (the first screen).

    Returns:
        The rendered ``user/view-files.html`` template.
    """
    template_name = 'user/view-files.html'
    return render_template(template_name)


@extractor_bp.route('/mcq/files', methods=['GET'])
def mcq_file_list():
    """Return the distinct MCQ file names matching the query-string filters.

    Supported query parameters:
        exam, sub_exam, subject, chapter, topic: exact match.
        user, status, file: case-insensitive substring match on
            ``user_name``, ``status`` and ``filename`` respectively.
        date: ``YYYY-MM-DD``; keeps documents with ``created_at`` on or after
            that date. An unparseable value is ignored.

    Returns:
        JSON ``{"files": [...]}`` where each item has ``file_name``,
        ``example_id`` (id of one MCQ document for that file, or null),
        ``user_name`` ("-" when unknown) and ``created_at`` (ISO-8601 string
        or null).
    """
    args = request.args
    filters = {}

    # Exact-match filters: the query parameter name is also the field name.
    for name in ('exam', 'sub_exam', 'subject', 'chapter', 'topic'):
        value = args.get(name)
        if value:
            filters[name] = value

    # Substring filters: query parameter -> query operator.
    partial_filters = (
        ('user', 'user_name__icontains'),
        ('status', 'status__icontains'),
        ('file', 'filename__icontains'),
    )
    for param, lookup in partial_filters:
        value = args.get(param)
        if value:
            filters[lookup] = value

    date = args.get('date')
    if date:
        try:
            filters['created_at__gte'] = datetime.strptime(date, "%Y-%m-%d")
        except ValueError:
            # Best effort: a malformed date simply does not filter.
            logger.debug("mcq_file_list: ignoring unparseable date filter %r", date)

    file_list = []
    for filename in final_mcq.objects(**filters).distinct('filename'):
        # One representative document per file supplies id / uploader / date.
        # This is one extra query per file name, and it is not restricted by
        # the filters above.
        first_entry = (
            final_mcq.objects(filename=filename)
            .only('id', 'user_name', 'created_at')
            .first()
        )

        example_id = None
        created_iso = None
        user_name = "-"
        if first_entry is not None:
            entry_id = getattr(first_entry, 'id', None)
            if entry_id:
                example_id = str(entry_id)

            created_at = getattr(first_entry, 'created_at', None)
            # Only values exposing isoformat() are serialised; anything else
            # is reported as null rather than raising.
            if created_at and hasattr(created_at, 'isoformat'):
                created_iso = created_at.isoformat()

            # A missing or empty uploader falls back to the "-" placeholder.
            user_name = getattr(first_entry, 'user_name', None) or "-"

        file_list.append({
            "file_name": filename,
            "example_id": example_id,
            "user_name": user_name,
            "created_at": created_iso,
        })

    return jsonify({"files": file_list})

# -------------------------------
# 🆕 Download Route (for View Files & View MCQs)
# -------------------------------

@extractor_bp.route('/download/<file_id>', methods=['GET'])
def download_file(file_id):
    """Download the original uploaded file referenced by an MCQ document.

    The path is taken from the document's ``file_path`` attribute, then the
    legacy ``path`` attribute, and finally falls back to
    ``public/uploads/<filename>``.

    Args:
        file_id: Id of an MCQ document that belongs to the file.

    Returns:
        The file as an attachment, or a JSON error: 404 when the id is
        malformed, the document is missing, or the file is not on disk;
        500 on any unexpected error.
    """
    try:
        # A malformed id cannot match any document; treat it like a miss.
        if not ObjectId.is_valid(file_id):
            return jsonify({"error": "File not found"}), 404

        mcq_entry = final_mcq.objects(id=file_id).first()
        if not mcq_entry:
            return jsonify({"error": "File not found"}), 404

        # Accept both file_path and the legacy path attribute.
        file_path = getattr(mcq_entry, "file_path", None) or getattr(mcq_entry, "path", None)
        file_name = getattr(mcq_entry, "filename", None) or getattr(mcq_entry, "file_name", None)

        # Fallback to public/uploads/<filename>
        if not file_path and file_name:
            file_path = os.path.join("public", "uploads", file_name)

        # Require a regular file: a directory at that path is not downloadable.
        if not file_path or not os.path.isfile(file_path):
            logger.info("download_file: file not available for id=%s resolved_path=%s", file_id, file_path)
            return jsonify({"error": "File not available on server"}), 404

        # Hand send_file an absolute path so it serves exactly the file that
        # was just checked, whatever base directory a relative path would
        # otherwise be resolved against.
        abs_path = os.path.abspath(file_path)
        return send_file(abs_path, as_attachment=True, download_name=os.path.basename(abs_path))

    except Exception as e:
        logger.exception("download_file error: %s", e)
        return jsonify({"error": str(e)}), 500
    

def _excel_clean_latex(s: str) -> str:
    """Strip LaTeX markup from *s* so it reads as plain text in a cell."""
    import re

    if not s:
        return ""
    out = s.replace("$", "")
    out = re.sub(r'\\left|\\right', '', out)
    # Keep the content of \mathrm{...} and \text{...}, drop the wrapper.
    out = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', out)
    out = re.sub(r'\\text\{([^}]*)\}', r'\1', out)
    # Remove any remaining \command names, then stray braces.
    out = re.sub(r'\\[a-zA-Z]+', '', out)
    out = out.replace('{', '').replace('}', '')
    out = re.sub(r'\s+', ' ', out).strip()
    return out


def _excel_find_mcqs(identifier: str):
    """Resolve *identifier* (an MCQ id or a file name) to its MCQ documents.

    Returns a ``(mcqs, filename)`` tuple. ``filename`` is the name resolved
    from an MCQ id, or ``None`` when the identifier was used as a file name.
    """
    filename = None
    if ObjectId.is_valid(identifier):
        example = final_mcq.objects(id=ObjectId(identifier)).first()
        if example:
            filename = getattr(example, "filename", None)

    # Case-insensitive exact match first, substring match as a fallback.
    needle = filename or identifier
    mcqs = list(final_mcq.objects(filename__iexact=needle))
    if not mcqs:
        mcqs = list(final_mcq.objects(filename__icontains=needle))
    return mcqs, filename


def _excel_build_workbook(mcqs):
    """Build an openpyxl workbook with one header row and one row per MCQ."""
    from openpyxl import Workbook
    from openpyxl.styles import Alignment, Font

    wb = Workbook()
    ws = wb.active
    ws.title = "MCQs"

    headers = [
        "Sr No", "Original Q.No", "Question",
        "Option A", "Option B", "Option C", "Option D",
        "Correct Answer", "Page No", "Explanation",
        "Exam", "Sub Exam", "Subject", "Chapter", "Topic", "Created At"
    ]
    ws.append(headers)
    header_font = Font(bold=True)
    header_alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")
    for col_idx in range(1, len(headers) + 1):
        cell = ws.cell(row=1, column=col_idx)
        cell.font = header_font
        cell.alignment = header_alignment

    for idx, m in enumerate(mcqs, start=1):
        # subject may be stored as a list or as a single value.
        subjects = m.subject if isinstance(m.subject, list) else ([m.subject] if m.subject else [])
        subject_str = ", ".join(str(s) for s in subjects)

        created = ""
        created_at = getattr(m, "created_at", None)
        if created_at:
            try:
                created = created_at.strftime("%Y-%m-%d %H:%M")
            except (AttributeError, ValueError):
                # Not a formattable datetime: fall back to its string form.
                created = str(created_at)

        ws.append([
            idx,
            getattr(m, "original_qno", "") or getattr(m, "question_number", "") or "",
            _excel_clean_latex(getattr(m, "question", "") or ""),
            _excel_clean_latex(getattr(m, "option_a", "") or ""),
            _excel_clean_latex(getattr(m, "option_b", "") or ""),
            _excel_clean_latex(getattr(m, "option_c", "") or ""),
            _excel_clean_latex(getattr(m, "option_d", "") or ""),
            getattr(m, "correct_answer", "") or "",
            getattr(m, "page_number", "") or getattr(m, "page_no", "") or "",
            _excel_clean_latex(getattr(m, "explanation", "") or ""),
            getattr(m, "exam", "") or "",
            getattr(m, "sub_exam", "") or "",
            subject_str,
            getattr(m, "chapter", "") or "",
            getattr(m, "topic", "") or "",
            created,
        ])

    widths = {
        1: 8, 2: 14, 3: 60, 4: 24, 5: 24, 6: 24, 7: 24,
        8: 14, 9: 10, 10: 40, 11: 18, 12: 18, 13: 24, 14: 18, 15: 18, 16: 20
    }
    for col, w in widths.items():
        col_letter = ws.cell(row=1, column=col).column_letter
        ws.column_dimensions[col_letter].width = w

    body_alignment = Alignment(wrap_text=True, vertical="top")
    for row in ws.iter_rows(min_row=2, max_row=ws.max_row, min_col=1, max_col=len(headers)):
        for cell in row:
            cell.alignment = body_alignment
            # This export writes no formulas, so a formula-typed cell can only
            # be MCQ text that happens to start with "="; store it as text.
            if cell.data_type == "f":
                cell.data_type = "s"

    return wb


@extractor_bp.route('/download-excel/<identifier>', methods=['GET'])
def download_excel(identifier):
    """Generate an Excel export of every MCQ belonging to one file.

    Args:
        identifier: Either a file name or the id of an example MCQ document,
            in which case that document's ``filename`` selects the file.

    Returns:
        An ``.xlsx`` attachment, or a JSON error: 400 for a blank identifier,
        404 when no MCQs match, 500 on any unexpected error.
    """
    try:
        from io import BytesIO

        identifier = unquote(identifier or "").strip()
        # A blank identifier would substring-match every file name and export
        # the whole collection, so reject it explicitly.
        if not identifier:
            return jsonify({"error": "Missing file identifier"}), 400

        mcqs, filename_to_use = _excel_find_mcqs(identifier)
        if not mcqs:
            return jsonify({"error": "No MCQs found for this file"}), 404

        wb = _excel_build_workbook(mcqs)
        buf = BytesIO()
        wb.save(buf)
        buf.seek(0)

        # Drop only a trailing ".pdf" (any case), not occurrences elsewhere
        # in the name.
        dl_basename = filename_to_use or identifier
        if dl_basename.lower().endswith(".pdf"):
            dl_basename = dl_basename[:-len(".pdf")]
        dl_name = f"{dl_basename or 'mcqs'}.xlsx"

        return send_file(
            buf,
            as_attachment=True,
            download_name=dl_name,
            mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
        )

    except Exception as e:
        logger.exception("download_excel error: %s", e)
        return jsonify({"error": str(e)}), 500
