import fitz  # PyMuPDF
import docx
import pytesseract
from PIL import Image, ImageEnhance, ImageFilter
from pdf2image import convert_from_bytes
import os
import platform
import re
from collections import Counter
from Sastrawi.Stemmer.StemmerFactory import StemmerFactory
from webapp.models.koleksi import Categories
from .ml_model_loader import predict_label
import json

# ======== Setup Tesseract Path (Windows) ========
# On Windows the Tesseract binary and Poppler tools are expected to be bundled
# inside the project tree, so point pytesseract at the bundled exe and extend
# PATH so pdf2image can locate Poppler. On other OSes a system-wide install of
# both tools is assumed to already be on PATH.
if platform.system() == "Windows":
    # Project base: two directories above this module's file.
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    TESSERACT_PATH = os.path.join(BASE_DIR, 'tesseract', 'tesseract.exe')
    pytesseract.pytesseract.tesseract_cmd = TESSERACT_PATH
    POPPLER_BIN = os.path.join(BASE_DIR, 'tools', 'poppler', 'Library', 'bin')
    # pdf2image shells out to Poppler's pdftoppm, which is resolved via PATH.
    os.environ["PATH"] += os.pathsep + POPPLER_BIN

# ======== Initialize Sastrawi stemmer (Indonesian) ========
# Created once at import time and shared by clean_text(); construction is
# relatively expensive, so it must not be rebuilt per call.
factory = StemmerFactory()
stemmer = factory.create_stemmer()

# ============ 🔍 OCR Image Preprocessing ============
def preprocess_image_for_ocr(image):
    """Prepare a PIL image for OCR: grayscale, boost contrast 2x, denoise.

    Returns a new image; the input image is not modified in place.
    """
    grayscale = image.convert("L")
    boosted = ImageEnhance.Contrast(grayscale).enhance(2)
    # Median filter removes salt-and-pepper noise typical of scans.
    return boosted.filter(ImageFilter.MedianFilter())

# ============ 📄 Ekstrak Teks dari File ============
def extract_text_from_file(file, max_pages=4, dpi=200, threshold_chars=50):
    """
    Extract text from a PDF/DOCX/image upload.

    - PDF: read the embedded text layer first; if it yields fewer than
      ``threshold_chars`` characters, assume a scanned document and fall
      back to OCR.
    - OCR is limited to the first ``max_pages`` pages at the given ``dpi``
      to keep processing time bounded.

    Args:
        file: file-like object with ``name``, ``read()`` and ``seek()``
            (e.g. a Django UploadedFile).
        max_pages: number of leading pages to read / OCR.
        dpi: rasterization resolution for the PDF OCR fallback.
        threshold_chars: minimum embedded-text length before OCR kicks in.

    Returns:
        The extracted text (stripped for PDFs), the sentinel string
        "Dokumen tidak terbaca" when PDF OCR fails, or "" for unsupported
        extensions / unreadable images.
    """
    name = file.name.lower()

    if name.endswith(".pdf"):
        text = ""

        # --- 1. Extract the embedded text layer directly ---
        try:
            file.seek(0)
            # Context manager guarantees the document is closed even if
            # get_text() raises mid-iteration (the original leaked it).
            with fitz.open(stream=file.read(), filetype="pdf") as doc:
                for page in doc[:max_pages]:
                    text += page.get_text("text") + "\n"
        except Exception as e:
            print("[ERROR] Gagal baca teks PDF:", e)

        # --- 2. Too little text -> scanned PDF, fall back to OCR ---
        if len(text.strip()) < threshold_chars:
            print("[INFO] PDF teks minim, fallback OCR...")
            text = ""
            try:
                file.seek(0)
                images = convert_from_bytes(file.read(), dpi=dpi, first_page=1, last_page=max_pages)
                for img in images:
                    processed = preprocess_image_for_ocr(img)
                    text += pytesseract.image_to_string(processed, lang="ind") + "\n"
            except Exception as e:
                print("[ERROR] PDF OCR gagal:", e)
                return "Dokumen tidak terbaca"

        return text.strip()

    elif name.endswith(".docx"):
        # Rewind in case the caller already consumed the stream.
        file.seek(0)
        return extract_text_from_docx(file)

    elif name.endswith((".jpg", ".jpeg", ".png")):
        try:
            # Rewind in case the caller already consumed the stream.
            file.seek(0)
            image = Image.open(file)
            processed = preprocess_image_for_ocr(image)
            return pytesseract.image_to_string(processed, lang="ind")
        except Exception as e:
            print("[ERROR] Gagal OCR gambar:", e)
            return ""

    # Unsupported extension.
    return ""

def extract_text_from_docx(file):
    """Return the text of every paragraph in a .docx file, newline-joined."""
    document = docx.Document(file)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)

# ============ 🧹 Normalisasi + Stemming ============
def clean_text(text):
    """For TAGGING: lowercase, strip non-letters, then Sastrawi stemming."""
    lowered = text.lower()
    # Anything outside a-z / whitespace becomes a space before stemming.
    letters_only = re.sub(r"[^a-z\s]", " ", lowered)
    return stemmer.stem(letters_only)

def normalize_for_model(text):
    """For IndoBERT: lowercase + strip non-letters, NO stemming.

    Collapses all runs of whitespace to single spaces and trims the ends.
    """
    letters_only = re.sub(r"[^a-z\s]", " ", text.lower())
    # split()/join collapses any whitespace runs and trims, like \s+ -> " ".
    return " ".join(letters_only.split())

def clean_ocr_tokens(tokens):
    """Tidy a list of OCR tokens into a single space-joined string.

    Lowercases each token, then drops tokens of two or fewer characters
    and tokens containing anything but ASCII letters (digits, symbols).
    """
    kept = [
        word
        for word in (token.lower() for token in tokens)
        if len(word) > 2 and re.match("^[a-z]+$", word)
    ]
    return " ".join(kept)

# ============ 🧠 Prediksi Kategori ============
# Label mapping produced at model-training time; stored alongside the model
# weights in indo_model_final/. Loaded once at import time, so a missing or
# corrupt file fails fast when the module is first imported.
LABEL_MAP_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "indo_model_final",
    "label_mapping.json"
)
with open(LABEL_MAP_PATH, "r") as f:
    label_mapping = json.load(f)

# JSON object keys are always strings; predict_label() yields integer ids,
# so re-key the mapping as int -> category name.
label_mapping = {int(k): v for k, v in label_mapping.items()}
# Inverse lookup: category name -> integer label id.
reverse_label_map = {v: k for k, v in label_mapping.items()}

def get_categories_from_db():
    """Map every Categories row's primary key to its display text."""
    mapping = {}
    for category in Categories.objects.all():
        mapping[category.pk] = category.text
    return mapping

def predict_category(text):
    """Predict the Categories pk for a document's raw text.

    Normalizes the text for the IndoBERT model, maps the predicted label id
    to a category name, then looks up the first matching Categories row by
    case-insensitive containment.

    Returns the row's pk, or None when no row matches (caller falls back
    to a default category).
    """
    label_id = predict_label(normalize_for_model(text))
    predicted_kategori = label_mapping.get(label_id, "Dokumen Lainnya")

    match = Categories.objects.filter(text__icontains=predicted_kategori).first()
    if not match:
        print("[DEBUG] Tidak ditemukan, fallback ke default.")
        return None

    print("[DEBUG] Kategori terdeteksi:", predicted_kategori)
    return match.pk

# ============ 🏷️ Tag Suggestion (Unsupervised) ============
def suggest_tags(text, top_n=6):
    """Suggest up to ``top_n`` tags by word frequency (unsupervised).

    Pipeline: stem + clean the text, drop short/non-alphabetic OCR noise,
    filter Indonesian stopwords, then take the most frequent remaining words.

    Args:
        text: raw extracted document text.
        top_n: maximum number of tags to return.

    Returns:
        List of at most ``top_n`` lowercase stemmed words, most frequent first.
    """
    STOPWORDS = {
        "yang", "dalam", "akan", "dengan", "untuk", "pada", "dan", "atau",
        "saya", "bapak", "ibu", "tanggal", "hari", "tahun", "hal", "perihal"
    }

    stemmed = clean_text(text)
    # clean_text() always returns a string, so one split suffices. The
    # original re-checked isinstance(words, str) on the split result — a
    # dead branch that could never be taken; removed.
    cleaned = clean_ocr_tokens(stemmed.split())

    print("[DEBUG] Hasil OCR + stemmed cleaned:", cleaned)

    # Keep only words longer than 3 chars that aren't stopwords.
    filtered = [w for w in cleaned.split() if len(w) > 3 and w not in STOPWORDS]
    freq = Counter(filtered).most_common(top_n)

    tags = [w for w, _ in freq]
    print("[DEBUG] Suggested tags (unsupervised):", tags)
    return tags
