# (c) cavaliba.com - data - tasks.py

import time

from app_data.data import Instance, load_instance
from app_data.models import DataInstance, DataTask
from app_data.pipeline import Pipeline
from app_data.task_manager import create_datatask, finish_task, is_aborted, update_progress
from app_home.log import ERROR, log
from app_user.models import SireneUser
from app_user.user import load_user, user_listdict_format
from celery import shared_task

# ------------------------------------------------------------------
# Example long task — template for real tasks
# ------------------------------------------------------------------

def submit_example_long(params=None, owner_type="auto", owner_id="system"):
    """Submit the example long task. Returns handle UUID or None if singleton active."""
    dt = create_datatask(
        name="Example Long Task",
        params=params,
        singleton="example_long",
        owner_type=owner_type,
        owner_id=owner_id,
    )
    if dt is None:
        return None
    task_example_long.delay(str(dt.handle), params or {})
    return str(dt.handle)
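
# Usage sketch (illustrative; not called anywhere in this module):
#   handle = submit_example_long(params={"steps": 5, "duration": 0.2})
#   if handle is None:
#       pass  # an "example_long" singleton task is already active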


@shared_task(ignore_result=True)
def task_example_long(handle, params):
    """Example long-running task. Iterates 10 steps, supports abort."""
    try:
        dt = DataTask.objects.get(handle=handle)
    except DataTask.DoesNotExist:
        return

    # the task may have been aborted before a worker picked it up
    if dt.state == "ABORTED":
        return

    dt.state = "RUNNING"
    dt.save(update_fields=["state"])

    params = params or {}
    total = params.get("steps", 10)
    duration = params.get("duration", 0.1)

    try:
        for i in range(total):
            if is_aborted(handle):
                return
            time.sleep(duration)
            update_progress(handle, percent=int((i + 1) * 100 / total),
                            count=i + 1, total=total, message=f"step {i+1}/{total}")

        finish_task(handle, state="DONE", output={"steps_done": total})

    except Exception as exc:
        finish_task(handle, state="FAILED", output={"error": str(exc)})
        log(ERROR, aaa=None, app="data", view="task_example_long", action="run",
            status="FAILED", data=str(exc))


# ------------------------------------------------------------------
# Pipeline task
# ------------------------------------------------------------------

def submit_pipeline(pipeline_name, schema_names, dryrun=False, aaa=None, owner_type="api", owner_id="system", sync=False):
    """Submit a pipeline run (singleton per pipeline_name).

    Return (handle, None) on success or (None, error_message) on failure.
    With sync=True the task runs inline instead of being queued on Celery.
    """
    p = Pipeline.from_name(pipeline_name)
    if p is None:
        return None, "pipeline not found"

    # pre-count items so the task can report percent progress
    total_instances = 0
    for schemaname in schema_names:
        if schemaname == "_user":
            total_instances += SireneUser.objects.count()
        else:
            total_instances += DataInstance.objects.filter(classname=schemaname).count()

    params = {
        "pipeline_name": pipeline_name,
        "schema_names": schema_names,
        "dryrun": dryrun,
        "aaa": aaa,
        "total_instances": total_instances,
    }
    dt = create_datatask(
        name=f"Pipeline {pipeline_name}",
        params=params,
        singleton=pipeline_name,
        owner_type=owner_type,
        owner_id=owner_id,
    )
    if dt is None:
        return None, "datatask creation failed (probably single task allowed)"

    if sync:
        task_pipeline(str(dt.handle), params)
    else:
        task_pipeline.delay(str(dt.handle), params)
    return str(dt.handle), None
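
# Usage sketch (illustrative; pipeline and schema names are hypothetical):
#   handle, err = submit_pipeline("my_pipeline", ["article", "_user"], dryrun=True)
#   if err:
#       pass  # "pipeline not found" or a singleton run is already active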


@shared_task(ignore_result=True)
def task_pipeline(handle, params):
    """Run a pipeline over the selected schemas, tracking progress in the DataTask."""
    try:
        dt = DataTask.objects.get(handle=handle)
    except DataTask.DoesNotExist:
        return

    if dt.state == "ABORTED":
        return

    dt.state = "RUNNING"
    dt.save(update_fields=["state"])

    pipeline_name = params.get("pipeline_name")
    schema_names = params.get("schema_names", [])
    dryrun = params.get("dryrun", False)
    aaa = params.get("aaa")
    total_instances = params.get("total_instances", 0)

    # Load and validate pipeline
    try:
        p = Pipeline.from_name(pipeline_name)
    except Exception as exc:
        finish_task(handle, state="FAILED", output={"error": f"pipeline load error: {exc}"})
        log(ERROR, aaa=None, app="data", view="task_pipeline", action="load", status="FAILED", data=str(exc))
        return

    if p is None:
        finish_task(handle, state="FAILED", output={"error": "pipeline not found"})
        return

    # enforce the pipeline's run permission when an AAA context is provided
    if p.run_permission and aaa:
        perms = aaa.get("perms", [])
        if isinstance(perms, list) and p.run_permission not in perms:
            finish_task(handle, state="FAILED", output={"error": f"permission denied ({p.run_permission})"})
            return

    BATCH_SIZE = 20   # check for abort / report progress every N items
    MAX_ERRORS = 50   # cap on stored error messages per schema
    total_ok = 0
    total_discarded = 0
    total_errors = []
    results = []
    global_count = 0

    for schemaname in schema_names:

        count_ok = 0
        count_discarded = 0
        errors = []

        # "_user" is a virtual schema: items come from SireneUser, not DataInstance
        if schemaname == "_user":
            items = (
                (user_listdict_format([user]) or [None])[0]
                for user in SireneUser.objects.all()
            )
        else:
            items = (
                instance.get_dict_for_export()
                for instance in Instance.iterate_classname(classname=schemaname)
            )

        for datadict in items:
            if datadict is None:
                continue

            try:
                result = p.apply(datadict)
            except Exception as exc:
                if len(errors) < MAX_ERRORS:
                    errors.append(f"apply error: {exc}")
            else:
                if result is not None:
                    # a non-None result means the pipeline discarded the item
                    count_discarded += 1
                elif dryrun:
                    count_ok += 1
                else:
                    try:
                        if schemaname == "_user":
                            err = load_user(datadict=datadict, aaa=aaa)
                        else:
                            err = load_instance(datadict=datadict, aaa=aaa)
                        if err:
                            if len(errors) < MAX_ERRORS:
                                errors.append(err)
                        else:
                            count_ok += 1
                    except Exception as exc:
                        if len(errors) < MAX_ERRORS:
                            errors.append(f"save error: {exc}")

            # count every processed item, including apply errors, so the
            # periodic abort check and progress update are never skipped
            global_count += 1
            if global_count % BATCH_SIZE == 0:
                if is_aborted(handle):
                    finish_task(handle, state="ABORTED", output={
                        "pipeline": pipeline_name,
                        "total_ok": total_ok + count_ok,
                        "total_discarded": total_discarded + count_discarded,
                    })
                    return
                update_progress(
                    handle,
                    percent=int(global_count * 100 / total_instances) if total_instances else 0,
                    count=global_count,
                    total=total_instances,
                    message=f"{schemaname}: {count_ok} ok, {count_discarded} discarded so far",
                )

        total_ok += count_ok
        total_discarded += count_discarded
        total_errors += errors
        results.append({
            "schema": schemaname,
            "count_ok": count_ok,
            "count_discarded": count_discarded,
            "errors": errors,
        })

    finish_task(handle, state="DONE", output={
        "pipeline": pipeline_name,
        "dryrun": dryrun,
        "total_ok": total_ok,
        "total_discarded": total_discarded,
        "total_errors": len(total_errors),
        "results": results,
    })
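

# Progress-polling sketch (illustrative; `handle` and `state` appear above,
# other DataTask field names are assumptions about the model):
#   dt = DataTask.objects.get(handle=handle)
#   dt.state   # e.g. RUNNING / DONE / FAILED / ABORTED
#   dt.output  # assumed field where finish_task() stores its output dict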

