import os
import django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'webapp.settings')
django.setup()

from webapp.models.koleksi import Categories
import time
import re
import fitz
import pytesseract
import platform
import numpy as np
import cv2
import shutil
import hashlib
from pdf2image import convert_from_bytes, convert_from_path
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
from ftfy import fix_text
from unidecode import unidecode
from collections import Counter
from multiprocessing import Pool
from functools import partial
from Sastrawi.Stemmer.StemmerFactory import StemmerFactory
from Sastrawi.StopWordRemover.StopWordRemoverFactory import StopWordRemoverFactory
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# ======== Setup Tesseract Path (Windows) ========
if platform.system() == "Windows":
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    pytesseract.pytesseract.tesseract_cmd = os.path.join(BASE_DIR, 'tesseract', 'tesseract.exe')
    POPPLER_BIN = os.path.join(BASE_DIR, 'tools', 'poppler', 'Library', 'bin')
    os.environ["PATH"] += os.pathsep + POPPLER_BIN

# ======== NLP Tools ========
stemmer = StemmerFactory().create_stemmer()
stop_factory = StopWordRemoverFactory()
stopwords = set(stop_factory.get_stop_words())

# ======================================================
# 🧠 LOAD MACHINE LEARNING MODEL (Trained NLP)
# ======================================================
import joblib

MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trained_nlp_model")

try:
    model_path = os.path.join(MODEL_DIR, "model.pkl")
    vectorizer_path = os.path.join(MODEL_DIR, "vectorizer.pkl")
    label_encoder_path = os.path.join(MODEL_DIR, "label_encoder.pkl")

    ml_model = joblib.load(model_path)
    vectorizer = joblib.load(vectorizer_path)
    label_encoder = joblib.load(label_encoder_path)

    print("[INFO] Model NLP berhasil dimuat ✅")
except Exception as e:
    ml_model = None
    vectorizer = None
    label_encoder = None
    print(f"[WARN] Model NLP belum dimuat: {e}")


# --- PAGE CACHE TO AVOID REPEATED OCR ---
TURBO_PAGE_CACHE = {}

def hash_image_turbo(img: Image.Image):
    """Hash cepat halaman untuk caching."""
    return hashlib.md5(img.tobytes()).hexdigest()

# --- TURBO PREPROCESSING ---
def preprocess_image_turbo(img):
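    """Light preprocessing before OCR: grayscale, autocontrast, sharpen."""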
    img = img.convert("L")  # grayscale
    img = ImageOps.autocontrast(img)  # perbaiki kontras
    img = img.filter(ImageFilter.SHARPEN)  # tajamkan
    return img

# --- PROCESS PAGE (with preprocessing) ---
def process_page_turbo(img, lang='ind+eng'):
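    """OCR a single page image, caching the result keyed by the image hash."""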
    h = hash_image_turbo(img)
    if h in TURBO_PAGE_CACHE:
        return TURBO_PAGE_CACHE[h]

    img = preprocess_image_turbo(img)
    text = pytesseract.image_to_string(img, lang=lang)
    text = re.sub(r'\s+', ' ', text).strip()

    TURBO_PAGE_CACHE[h] = text
    return text

# --- PAGE FILTER (ink density) ---
def turbo_text_density(img):
    """Approximate ink density: fraction of dark pixels on a downscaled page."""
    img_small = img.resize((200, 280)).convert("L")
    pixels = img_small.histogram()
    dark_pixels = sum(pixels[:200])  # grayscale values below 200 count as ink
    return dark_pixels / len(img_small.getdata())


# ======================================================
# 🧼 CLEANING & NORMALIZATION
# ======================================================
def cleaning_ocr_text(text: str) -> str:
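    """Clean raw OCR text: repair mojibake (ftfy), transliterate to ASCII, strip
    OCR artifacts (repeated characters, stray separators, overlong tokens), and
    normalise numbers, punctuation, and whitespace."""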
    text = fix_text(text)
    text = unidecode(text)
    text = re.sub(r'[_\-–—=]+', ' ', text)
    text = re.sub(r'([a-z])\s*-\s*([a-z])', r'\1\2', text)
    text = re.sub(r'([a-z])\1{2,}', r'\1', text)
    text = re.sub(r'(ee|ae|oe|ie){2,}', ' ', text)
    text = re.sub(r'[^\w\s.,:()/\-]', ' ', text)
    text = re.sub(r'(\d)\s*[\.,]\s*(\d)', r'\1.\2', text)
    text = re.sub(r'(\d)\s*:\s*(\d)', r'\1:\2', text)
    text = re.sub(r'\s*r\s*p\s*', ' rp', text)
    text = re.sub(r'\s+[a-zA-Z0-9]\s+', ' ', text)
    text = re.sub(r'(\d):(\d)', r'\1\2', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\b[a-zA-Z]{25,}\b', '', text)
    text = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', text)
    text = re.sub(r'([a-z]{2,})([A-Z]{2,})', r'\1 \2', text)
    text = re.sub(r'\.{2,}', '. ', text)
    text = re.sub(r'[^a-zA-Z0-9.,;:()\-\n\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\s*\.\s*\.\s*', '. ', text)
    return text.strip()

def normalize_ocr_text(text: str) -> str:
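    """Lowercase, drop Indonesian stopwords, and stem tokens with Sastrawi."""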
    tokens = text.lower().split()
    tokens = [t for t in tokens if t not in stopwords]
    tokens = [stemmer.stem(t) for t in tokens]
    return ' '.join(tokens)

# ======================================================
# 🔍 OCR SUPPORT FUNCTIONS
# ======================================================
def fast_text_density(pil_img):
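    """Fraction of dark (ink) pixels in the page image, via OpenCV thresholding."""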
    img = np.array(pil_img.convert('L'))
    _, binary = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY_INV)
    text_pixels = np.count_nonzero(binary)
    total_pixels = binary.size
    return text_pixels / total_pixels

def process_page(page, lang='ind+eng'):
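    """Run Tesseract on a page image and collapse whitespace."""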
    text = pytesseract.image_to_string(page, lang=lang)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def ocr_pdf_fast(file, dpi=150, lang='ind+eng', max_pages=5, max_chars=30000):
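    """OCR a PDF file-like object in memory: render up to `max_pages` pages,
    skip image-heavy pages, OCR the rest in parallel, and cap the output at
    `max_chars` characters."""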

    # read the stream and convert pages straight to PIL images
    file.seek(0)
    pdf_bytes = file.read()
    pages = convert_from_bytes(pdf_bytes, dpi=dpi, fmt='jpeg', last_page=max_pages)

    # skip image-heavy pages (high ink density); the first page is always kept
    kept_pages = []
    for i, img in enumerate(pages):
        if i == 0:
            kept_pages.append(img)
            continue

        ratio = turbo_text_density(img)
        if ratio <= 0.09:
            kept_pages.append(img)

    # parallel OCR straight from the PIL images (no files written to disk)
    texts = []
    total_chars = 0

    # adaptive worker count (guard against os.cpu_count() returning None)
    workers = max(1, min((os.cpu_count() or 2) // 2, 3))

    with Pool(processes=workers) as pool:

        jobs = [pool.apply_async(process_page_turbo, (img, lang)) for img in kept_pages]

        for job in jobs:
            text = job.get()
            texts.append(text)

            # stop once the character budget is reached
            total_chars += len(text)
            if total_chars >= max_chars:
                break

    return "\n".join(texts)


def ocr_pdf_cached(file, cache_dir=None, dpi=150, lang='ind+eng', max_chars=30000):
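    """OCR a PDF through a temporary on-disk image cache: render the first pages
    to JPEG, drop image-heavy pages, OCR the remainder in parallel, then remove
    the cache directory and the temporary PDF."""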
    if cache_dir is None:
        cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ocr_cache")
    os.makedirs(cache_dir, exist_ok=True)

    temp_pdf_path = os.path.join(cache_dir, "temp_ocr_input.pdf")
    file.seek(0)
    with open(temp_pdf_path, "wb") as f:
        f.write(file.read())

    pdf_name = os.path.splitext(os.path.basename(temp_pdf_path))[0]
    cache_subdir = os.path.join(cache_dir, pdf_name)
    os.makedirs(cache_subdir, exist_ok=True)

    pages = convert_from_path(temp_pdf_path, dpi=dpi, fmt='jpeg', last_page=5, thread_count=4)
    for i, page in enumerate(pages):
        image_path = os.path.join(cache_subdir, f"page_{i+1:03d}.jpg")
        page.save(image_path, 'JPEG')

    cached_images = [os.path.join(cache_subdir, f) for f in sorted(os.listdir(cache_subdir)) if f.endswith('.jpg')]
    # always keep the first page; delete later pages that look image-heavy
    for cache_image in cached_images[1:]:
        with Image.open(cache_image) as img:
            ratio = fast_text_density(img)
        if ratio > 0.09:
            os.remove(cache_image)

    cached_images = [os.path.join(cache_subdir, f) for f in sorted(os.listdir(cache_subdir)) if f.endswith('.jpg')]
    cached_images = cached_images[:5]

    texts = []
    total_chars = 0
    process_func = partial(process_page, lang=lang)
    with Pool(processes=max(1, min(2, (os.cpu_count() or 2) // 2))) as pool:
        for text in pool.imap(process_func, cached_images):
            if total_chars >= max_chars:
                print("⚠️ Limit 50.000 karakter tercapai, hentikan OCR.")
                break
            total_chars += len(text)
            texts.append(text)

    shutil.rmtree(cache_subdir, ignore_errors=True)
    os.remove(temp_pdf_path)

    return "\n".join(texts)

def get_categories_from_db():
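    """Return a {pk: text} mapping of all Categories rows."""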
    return {cat.pk: cat.text for cat in Categories.objects.all()}


def quick_extract_text(path):
    """
    Extract cepat sesuai logika utama:
    - Gambar → OCR halaman 1
    - PDF scanned → OCR halaman 1 saja
    - PDF normal → extract text halaman 1 saja
    """

    # Expects an absolute path inside FILE_DIR; bail out if it does not exist
    if not os.path.exists(path):
        return ""

    ext = os.path.splitext(path)[1].lower()

    # ---------- IMAGE ----------
    if ext in [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]:
        try:
            img = Image.open(path)
            return pytesseract.image_to_string(img, lang="ind") or ""
        except Exception:
            return ""

    # ---------- PDF ----------
    if ext == ".pdf":
        try:
            doc = fitz.open(path)
            page = doc.load_page(0)

            # try the embedded text layer first
            text = page.get_text().strip()
            if text:
                return text

            # empty text layer → treat as scanned, OCR this page only
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            return pytesseract.image_to_string(img, lang="ind") or ""
        except Exception:
            return ""

    # ---------- OTHER FILE TYPES ----------
    return ""



def extract_text_from_file(file, max_chars=30000, log=True):
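    """Extract text from an uploaded file (.pdf, .docx, .jpg/.jpeg/.png).
    PDFs try the fitz text layer first, then fall back to OCR; PDF output is
    cleaned and truncated to `max_chars` characters."""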
    start_time = time.time()
    name = file.name.lower()
    hasil_text = ""

    if name.endswith('.pdf'):
        try:
            file.seek(0)
            doc = fitz.open(stream=file.read(), filetype="pdf")
            text = "\n".join([page.get_text() for page in doc])
            if len(text.strip()) > 50:
                hasil_text = cleaning_ocr_text(text)
            else:
                raise ValueError("Teks hasil fitz terlalu sedikit, gunakan OCR fallback.")
            if len(hasil_text) > max_chars:
                hasil_text = hasil_text[:max_chars]
        except Exception as e:
            if log: print(f"[WARN] PDF fitz gagal: {e}, fallback ke OCR (cached)...")
            file.seek(0)
            hasil_text = ocr_pdf_fast(file, max_chars=max_chars)

    elif name.endswith('.docx'):
        import docx
        doc = docx.Document(file)
        hasil_text = "\n".join([p.text for p in doc.paragraphs])
        hasil_text = cleaning_ocr_text(hasil_text)

    elif name.endswith(('.jpg', '.jpeg', '.png')):
        image = Image.open(file)
        text = pytesseract.image_to_string(image, lang='ind+eng')
        hasil_text = cleaning_ocr_text(text)

    else:
        if log: print(f"[WARN] Format file tidak didukung: {name}")
        return ""

    elapsed = time.time() - start_time
    if log:
        print(f"[INFO] Ekstraksi selesai dalam {elapsed:.2f} detik, jumlah karakter: {len(hasil_text)}")
        print(f"[DEBUG] Preview teks:\n{hasil_text[:500]}...")  # hanya 500 karakter awal
    return hasil_text.strip()


# ======================================================
# 🧠 AI-BASED CATEGORY PREDICTION (with timing)
# ======================================================
def predict_category(text, log=True):
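    """Predict a Categories primary key for `text` with the trained ML model,
    falling back to TF-IDF cosine similarity against the category names when
    the model is not loaded. Returns None on failure."""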
    start_time = time.time()
    global ml_model, vectorizer, label_encoder

    try:
        if ml_model is not None and vectorizer is not None and label_encoder is not None:
            cleaned = cleaning_ocr_text(text)
            vec = vectorizer.transform([cleaned])
            pred = ml_model.predict(vec)
            label = label_encoder.inverse_transform(pred)[0]

            cat = Categories.objects.filter(text__icontains=label).first()
            if cat:
                result = cat.pk
            else:
                fallback = Categories.objects.filter(text__icontains='dokumen lainnya').first()
                result = fallback.pk if fallback else None
        else:
            if log: print("[WARN] Model belum dimuat, fallback ke cosine similarity.")
            categories = get_categories_from_db()
            stemmed = normalize_ocr_text(text)
            corpus = list(categories.values()) + [stemmed]
            tfidf = TfidfVectorizer().fit_transform(corpus)
            sims = cosine_similarity(tfidf[-1], tfidf[:-1])
            scores = sims.flatten()
            best_idx = scores.argmax()
            result = list(categories.keys())[best_idx]

        elapsed = time.time() - start_time
        if log: print(f"[INFO] Prediksi kategori selesai dalam {elapsed:.2f} detik, kategori ID: {result}")
        return result

    except Exception as e:
        if log: print(f"[ERROR] Gagal prediksi kategori: {e}")
        return None
    

# ======================================================
# 🔎 LIGHTWEIGHT ENTITY EXTRACTION (FAST)
# ======================================================
def extract_entities_from_ocr(text: str):
    """
    Ekstrak entitas penting dari teks hasil OCR secara cepat.
    Mengambil: instansi, tahun, daerah, dan kategori dasar.
    """
    from datetime import datetime

    text_low = unidecode(text.lower())

    # Agency (dinas / badan / sekretariat)
    instansi_match = re.search(r'(dinas|badan|sekretariat)[a-z\s]{3,40}', text_low)
    instansi = instansi_match.group(0).strip() if instansi_match else None

    # Year (4 digits, 20xx)
    tahun_match = re.search(r'\b(20\d{2})\b', text_low)
    tahun = tahun_match.group(1) if tahun_match else None

    # Region (provinsi/kabupaten/kota + name)
    daerah_match = re.search(r'(provinsi|kabupaten|kota)\s+[a-z\s]{3,40}', text_low)
    daerah = daerah_match.group(0).strip() if daerah_match else None

    # Basic category (quick detection)
    kategori_match = re.search(
        r'(peraturan daerah|peraturan bupati|nota dinas|kontrak|spm|sp2d|surat keputusan|produk hukum|pengadaan)',
        text_low
    )
    kategori = kategori_match.group(0).strip() if kategori_match else None

    return {
        'instansi': instansi,
        'tahun': tahun,
        'daerah': daerah,
        'kategori': kategori,
    }


def suggest_tags(text: str, top_n: int = 4):
    """
    Analisis teks dokumen dan hasil entitas untuk menghasilkan tag otomatis.
    ⚡ Optimized: sangat cepat dan hasil maksimal 4 tag unik.
    """
    try:
        # --- STEP 1: entity extraction (if available)
        entities = extract_entities_from_ocr(text)
    except Exception:
        entities = {}

    kategori = entities.get('kategori')
    instansi = entities.get('instansi')
    tahun = entities.get('tahun')
    daerah = entities.get('daerah')

    # --- STEP 2: lightweight rule-based category detection
    kategori_map = {
        "produk hukum": ["peraturan daerah", "peraturan bupati", "produk hukum"],
        "peraturan daerah": ["peraturan daerah", "perda"],
        "peraturan bupati": ["peraturan bupati", "perbup"],
        "surat edaran": ["surat edaran"],
        "nota dinas": ["nota dinas"],
        "surat keputusan bupati": ["surat keputusan bupati", "sk bupati"],
        "surat keputusan skpd": ["surat keputusan kepala dinas", "skpd"],
        "kontrak": ["kontrak", "spk", "perjanjian kerja"],
        "pengadaan barang": ["pengadaan barang", "barang dan jasa"],
        "pengadaan kontruksi": ["kontruksi", "pembangunan"],
        "rumah dan kontrakan": ["kontrakan", "rumah dinas", "sewa rumah"],
        "jalan dan jembatan": ["jalan", "jembatan", "aspal"],
        "sarana dan prasarana lain": ["sarana", "prasarana"],
        "pengairan dan irigasi": ["pengairan", "irigasi"],
        "surat berharga": ["surat berharga", "obligasi"],
        "spp, spm, sp2d": ["spp", "spm", "sp2d"],
        "surat tagihan, nota, kwitansi": ["nota", "kwitansi", "tagihan"],
        "disposisi": ["disposisi"],
        "dokumen lainnya": ["dokumen lain", "lainnya"],
        "jasa konsultasi": ["konsultan", "konsultasi"],
        "jasa lainnya": ["jasa lainnya"]
    }

    text_lower = text.lower()
    kategori_detected = None
    for k, words in kategori_map.items():
        if any(w in text_lower for w in words):
            kategori_detected = k
            break

    if not kategori_detected and kategori:
        kategori_detected = kategori

    # --- STEP 3: quickly pick salient words
    # (no heavy normalisation or lemmatisation)
    text_basic = re.sub(r'[^a-zA-Z\s]', ' ', text_lower)
    tokens = [t for t in text_basic.split() if len(t) > 4]
    # use Counter directly with a small top_n (saves time)
    keywords = [w for w, _ in Counter(tokens).most_common(top_n)]

    # --- STEP 4: merge results, deduplicate, cap at 4
    hasil_tag = []
    for item in [kategori_detected, instansi, daerah, tahun]:
        if item and str(item).strip():
            hasil_tag.append(str(item).strip())
    hasil_tag.extend(keywords)

    # unique + title case + limit to 4
    hasil_tag = list(dict.fromkeys([t.title() for t in hasil_tag if len(t) > 2]))[:4]

    print(f"[AUTO TAG] Hasil tag: {hasil_tag}")
    return hasil_tag
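

# ------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the Django settings and
# database configured at the top of this module are reachable, and that
# "sample.pdf" is a hypothetical placeholder you replace with a real file;
# it is not part of the production pipeline.
# ------------------------------------------------------
if __name__ == "__main__":
    sample_path = "sample.pdf"  # hypothetical input file
    if os.path.exists(sample_path):
        with open(sample_path, "rb") as f:
            extracted = extract_text_from_file(f)
        category_id = predict_category(extracted)
        tags = suggest_tags(extracted)
        print(f"[DEMO] category={category_id}, tags={tags}")
    else:
        print("[DEMO] Place a PDF at 'sample.pdf' to try the pipeline.")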
