from django.db import connection


# ============================================================
# INSERT FULLTEXT
# ============================================================
def insert_fulltext(attachment_id, raw_text, cleaned_text, category_id=None, tags=None):
    """Insert one row into ``attachments_text`` and return its primary key.

    Table columns, as noted by the original author:
      pk
      pk_attachment
      raw_text
      cleaned_text
      created_at

    Only ``pk_attachment``, ``raw_text`` and ``cleaned_text`` are written by
    this statement. ``category_id`` and ``tags`` are accepted but not used
    anywhere in this function.

    Returns:
        The ``pk`` value produced by the ``RETURNING`` clause.
    """
    params = [attachment_id, raw_text, cleaned_text]
    statement = """
        INSERT INTO attachments_text (pk_attachment, raw_text, cleaned_text)
        VALUES (%s, %s, %s)
        RETURNING pk;
    """

    with connection.cursor() as cursor:
        cursor.execute(statement, params)
        # RETURNING pk yields a single row; its first column is the new key.
        new_pk = cursor.fetchone()[0]
        return new_pk
    

# ============================================================
# INSERT CHUNKS
# ============================================================
def insert_chunks(attachment_id, chunks):
    """Insert every chunk of an attachment and return the new primary keys.

    Table ``attachments_chunks`` columns, as noted by the original author:
      pk
      pk_attachment
      chunk_index
      chunk_content
      token_count

    ``chunk_index`` is the zero-based position of the chunk in ``chunks``.
    ``token_count`` is the number of whitespace-separated pieces in the chunk
    text (``len(chunk.split())``).

    Returns:
        A list of ``pk`` values, in the same order as ``chunks``.
    """
    statement = """
        INSERT INTO attachments_chunks
            (pk_attachment, chunk_index, chunk_content, token_count)
        VALUES (%s, %s, %s, %s)
        RETURNING pk;
    """

    def _insert_one(cursor, position, text):
        # Store a single chunk and hand back the pk reported by RETURNING.
        word_total = len(text.split())
        cursor.execute(statement, [attachment_id, position, text, word_total])
        return cursor.fetchone()[0]

    with connection.cursor() as cursor:
        new_pks = [
            _insert_one(cursor, position, text)
            for position, text in enumerate(chunks)
        ]

    return new_pks


# ============================================================
# INSERT EMBEDDINGS
# ============================================================
def insert_embeddings(chunk_ids, embeddings):
    """Insert one embedding row per chunk into ``attachments_embeddings``.

    Table columns, as noted by the original author:
      pk
      pk_chunk
      embedding vector(1024)
      created_at

    Only ``pk_chunk`` and ``embedding`` are written by this statement.

    Args:
        chunk_ids: Iterable of chunk primary keys.
        embeddings: Iterable of vectors, position-aligned with ``chunk_ids``.
            Each vector is either an object exposing ``tolist()`` (for
            example a NumPy array) or a plain iterable of numbers.

    Raises:
        ValueError: If ``chunk_ids`` and ``embeddings`` differ in length.
            Pairing them anyway would silently leave some chunks without
            an embedding (or drop embeddings), so nothing is written.
    """

    sql = """
        INSERT INTO attachments_embeddings (pk_chunk, embedding)
        VALUES (%s, %s);
    """

    # Materialize both inputs so their lengths can be compared even when a
    # caller passes a generator.
    chunk_ids = list(chunk_ids)
    embeddings = list(embeddings)

    if len(chunk_ids) != len(embeddings):
        raise ValueError(
            "chunk_ids and embeddings must have the same length "
            f"(got {len(chunk_ids)} and {len(embeddings)})"
        )

    if not chunk_ids:
        # Nothing to insert; do not open a cursor for an empty batch.
        return

    # Convert every vector before touching the database so that a bad
    # vector fails the call before any row has been written.
    rows = [
        [cid, emb.tolist() if hasattr(emb, "tolist") else list(emb)]
        for cid, emb in zip(chunk_ids, embeddings)
    ]

    with connection.cursor() as cur:
        for row in rows:
            cur.execute(sql, row)
