# (c) cavaliba.com - data - tasks.py

import time

from celery import shared_task

import app_home.cache as cache
from app_data.data import Instance
from app_data.group import Group
from app_data.loader import load_broker
from app_data.models import DataInstance, DataTask
from app_data.pipeline import Pipeline
from app_data.task_manager import create_datatask, finish_task, is_aborted, update_progress
from app_data.user import User
from app_home.log import DEBUG, ERROR, log

# from app_data.task_manager import abort_task
# from app_data.task_manager import datatask_cleanup


# ------------------------------------------------------------------
# Example long task — template for real tasks
# ------------------------------------------------------------------


def submit_example_long(params=None, owner_type="auto", owner_id="system"):
    """Register and enqueue the example long task.

    A DataTask is created under the "example_long" singleton key, then the
    Celery task is queued with the task handle and the params (an empty dict
    when none were supplied).

    Returns the handle as a string, or None when create_datatask() did not
    return a task (presumably because the singleton is already active).
    """
    datatask = create_datatask(
        name="Example Long Task",
        params=params,
        singleton="example_long",
        owner_type=owner_type,
        owner_id=owner_id,
    )
    if datatask is None:
        return None

    handle = str(datatask.handle)
    task_example_long.delay(handle, params or {})
    return handle


@shared_task(ignore_result=True)
def task_example_long(handle, params):
    """Template long-running task: sleep through N steps, reporting progress.

    params may carry "steps" (default 10) and "duration" in seconds per step
    (default 0.1). Nothing happens when the DataTask is missing or already
    ABORTED. The loop stops, without finalizing, as soon as
    is_aborted(handle) is true; any exception marks the task FAILED and is
    logged.
    """
    try:
        datatask = DataTask.objects.get(handle=handle)
    except DataTask.DoesNotExist:
        return

    if datatask.state == "ABORTED":
        return

    datatask.state = "RUNNING"
    datatask.save(update_fields=["state"])

    # Falsy params (None, {}) fall back to the defaults.
    options = params if params else {}
    total = options.get("steps", 10)
    duration = options.get("duration", 0.1)

    try:
        for index in range(total):
            if is_aborted(handle):
                return
            time.sleep(duration)

            done = index + 1
            update_progress(
                handle,
                percent=int(done * 100 / total),
                count=done,
                total=total,
                message=f"step {done}/{total}",
            )

        finish_task(handle, state="DONE", output={"steps_done": total})

    except Exception as exc:
        reason = str(exc)
        finish_task(handle, state="FAILED", output={"error": reason})
        log(
            ERROR,
            aaa=None,
            app="data",
            view="task_example_long",
            action="run",
            status="FAILED",
            data=reason,
        )


# ------------------------------------------------------------------
# Pipeline task
# ------------------------------------------------------------------
# singleton by pipeline_name


def submit_pipeline(
    pipeline_name,
    schema_names,
    dryrun=False,
    aaa=None,
    owner_type="api",
    owner_id="system",
    sync=False,
):
    """Create a DataTask for a pipeline run and start it.

    The pipeline must exist. The number of instances to process is counted
    up front (User.count_all() for the "user" schema, a DataInstance count
    for any other) and stored in the task params as "total_instances". The
    DataTask uses the pipeline name as its singleton key.

    With sync=True the task body runs inline in the calling process,
    otherwise it is queued through Celery.

    Returns a (handle, error) tuple: (handle string, None) once the task has
    been started, or (None, message) when the pipeline is unknown or the
    DataTask could not be created.
    """
    if Pipeline.from_name(pipeline_name) is None:
        return None, "pipeline not found"

    total_instances = sum(
        User.count_all()
        if name == "user"
        else DataInstance.objects.filter(classname=name).count()
        for name in schema_names
    )

    params = {
        "pipeline_name": pipeline_name,
        "schema_names": schema_names,
        "dryrun": dryrun,
        "aaa": aaa,
        "total_instances": total_instances,
    }

    datatask = create_datatask(
        name=f"Pipeline {pipeline_name}",
        params=params,
        singleton=pipeline_name,
        owner_type=owner_type,
        owner_id=owner_id,
    )
    if datatask is None:
        return None, "datatask creation failed (probably single task allowed)"

    handle = str(datatask.handle)
    # Direct call runs in-process; .delay hands the work to a Celery worker.
    runner = task_pipeline if sync else task_pipeline.delay
    runner(handle, params)
    return handle, None


def _pipeline_process_dict(p, datadict, dryrun, aaa):
    """Apply pipeline p to one exported dict and save it unless dryrun.

    Returns an (outcome, errors) tuple where outcome is one of:
      - "discarded": p.apply() returned a non-None value
      - "ok": applied and (unless dryrun) saved through load_broker without errors
      - "error": apply raised, load_broker raised, or load_broker reported
        errors; the error messages are in the second tuple item.
    """
    try:
        result = p.apply(datadict)
    except Exception as exc:
        return "error", [f"apply error: {exc}"]

    if result is not None:
        return "discarded", []
    if dryrun:
        return "ok", []

    try:
        reply = load_broker(datalist=[datadict], aaa=aaa)
        reply_errors = reply.get("errors", [])
        if reply_errors:
            return "error", list(reply_errors)
    except Exception as exc:
        return "error", [f"save error: {exc}"]
    return "ok", []


@shared_task(ignore_result=True)
def task_pipeline(handle, params):
    """Run a pipeline over every instance of the requested schemas.

    Marks the DataTask identified by handle as RUNNING, applies the pipeline
    to each exported instance and, unless dryrun is set, saves the result
    through load_broker. Progress is reported and abort requests are checked
    every BATCH_SIZE processed instances.

    Args:
        handle: DataTask handle (string).
        params: dict with the keys "pipeline_name", "schema_names", "dryrun",
            "aaa" and "total_instances", as built by submit_pipeline().

    Returns:
        None. Returns silently when no DataTask matches handle or when it is
        already ABORTED. Otherwise the outcome is recorded with finish_task():
        DONE with per-schema results, ABORTED with partial totals, or FAILED
        with an "error" message (pipeline load error, pipeline not found,
        permission denied, or an unexpected exception during the run).
    """
    try:
        dt = DataTask.objects.get(handle=handle)
    except DataTask.DoesNotExist:
        return

    if dt.state == "ABORTED":
        return

    dt.state = "RUNNING"
    dt.save(update_fields=["state"])

    params = params or {}
    pipeline_name = params.get("pipeline_name")
    schema_names = params.get("schema_names", [])
    dryrun = params.get("dryrun", False)
    aaa = params.get("aaa")
    total_instances = params.get("total_instances", 0)

    # Load and validate pipeline
    try:
        p = Pipeline.from_name(pipeline_name)
    except Exception as exc:
        finish_task(handle, state="FAILED", output={"error": f"pipeline load error: {exc}"})
        log(
            ERROR,
            aaa=None,
            app="data",
            view="task_pipeline",
            action="load",
            status="FAILED",
            data=str(exc),
        )
        return

    if p is None:
        finish_task(handle, state="FAILED", output={"error": "pipeline not found"})
        return

    # The permission is only enforced when an aaa context was supplied and
    # its "perms" entry is a list.
    if p.run_permission and aaa:
        perms = aaa.get("perms", [])
        if isinstance(perms, list) and p.run_permission not in perms:
            finish_task(
                handle, state="FAILED", output={"error": f"permission denied ({p.run_permission})"}
            )
            return

    BATCH_SIZE = 20  # instances between two abort checks / progress updates
    MAX_ERRORS = 50  # max error messages kept per schema
    total_ok = 0
    total_discarded = 0
    total_error_count = 0
    results = []
    global_count = 0

    # Any unexpected failure must finalize the task; otherwise it would stay
    # in RUNNING state forever.
    try:
        for schemaname in schema_names:
            count_ok = 0
            count_discarded = 0
            errors = []

            # v4.0 - user is a Schema
            for instance in Instance.iterate_classname(classname=schemaname):
                datadict = instance.get_dict_for_export()
                if datadict is None:
                    continue

                outcome, new_errors = _pipeline_process_dict(p, datadict, dryrun, aaa)
                if outcome == "ok":
                    count_ok += 1
                elif outcome == "discarded":
                    count_discarded += 1
                else:
                    # Never keep more than MAX_ERRORS messages per schema.
                    room = max(0, MAX_ERRORS - len(errors))
                    errors.extend(new_errors[:room])

                # Count every processed instance, failed ones included, so
                # the abort check and progress update below are never skipped.
                global_count += 1
                if global_count % BATCH_SIZE == 0:
                    if is_aborted(handle):
                        finish_task(
                            handle,
                            state="ABORTED",
                            output={
                                "pipeline": pipeline_name,
                                "total_ok": total_ok + count_ok,
                                "total_discarded": total_discarded + count_discarded,
                            },
                        )
                        return
                    update_progress(
                        handle,
                        percent=int(global_count * 100 / total_instances) if total_instances else 0,
                        count=global_count,
                        total=total_instances,
                        message=f"{schemaname}: {count_ok} ok, {count_discarded} discarded so far",
                    )

            total_ok += count_ok
            total_discarded += count_discarded
            total_error_count += len(errors)
            results.append(
                {
                    "schema": schemaname,
                    "count_ok": count_ok,
                    "count_discarded": count_discarded,
                    "errors": errors,
                }
            )

        finish_task(
            handle,
            state="DONE",
            output={
                "pipeline": pipeline_name,
                "dryrun": dryrun,
                "total_ok": total_ok,
                "total_discarded": total_discarded,
                "total_errors": total_error_count,
                "results": results,
            },
        )

    except Exception as exc:
        finish_task(handle, state="FAILED", output={"error": str(exc)})
        log(
            ERROR,
            aaa=None,
            app="data",
            view="task_pipeline",
            action="run",
            status="FAILED",
            data=str(exc),
        )


# ------------------------------------------------------------------
# IAM - autogroup update
# ------------------------------------------------------------------
@shared_task(ignore_result=True)
def task_autogroup_update():
    """Refresh autogroups and log how many groups were updated.

    Calls cache.init() first, then Group.autogroup_update(), and writes a
    DEBUG log entry carrying the value it returned.
    """
    cache.init()
    updated = Group.autogroup_update()
    summary = f"{updated} groups updated"
    log(
        DEBUG,
        aaa=None,
        app="iam",
        view="cron",
        action="autogroup_update",
        status="OK",
        data=summary,
    )
