/*
Copyright © 2025 cavaliba.com
*/
package cmd

import (
	"encoding/json"
	"fmt"
	"net/http"
	neturl "net/url"
	"os"
	"time"

	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)

// Values of the pipeline command's own flags; they are bound to the flag set
// in this file's init function.
var (
	// pipelineSchema holds the --schema value (comma-separated schema names).
	pipelineSchema string
	// pipelineDryrun holds the --dryrun switch.
	pipelineDryrun bool
	// pipelineNoWait holds the --no-wait switch (submit without polling).
	pipelineNoWait bool
	// pipelineInterval holds the --interval value, in seconds.
	pipelineInterval int
)

// terminalStates is the set of task states after which polling stops.
// Looking up any state not listed here yields false.
var terminalStates = map[string]bool{
	"ABORTED": true,
	"DONE":    true,
	"FAILED":  true,
}

// pipelineSubmitResponse is the JSON document decoded from the response body
// of a pipeline submit call. Fields map to the JSON keys "pipeline", "dryrun",
// "handle" and "task_url"; within this file only Handle is read, as the
// identifier handed to pollTask.
type pipelineSubmitResponse struct {
	Pipeline string `json:"pipeline"`
	Dryrun   bool   `json:"dryrun"`
	Handle   string `json:"handle"`
	TaskURL  string `json:"task_url"`
}

// taskProgressResponse is the JSON document decoded from the response body of
// a tasks/<handle>/ call. pollTask checks State against terminalStates, reads
// the "percent", "message", "count" and "total" keys of Progress when they are
// present with the expected type, and prints Output once the task is finished.
type taskProgressResponse struct {
	Handle   string                 `json:"handle"`
	State    string                 `json:"state"`
	Name     string                 `json:"name"`
	Progress map[string]interface{} `json:"progress"`
	Output   map[string]interface{} `json:"output"`
}

// pollTask repeatedly calls the tasks/<handle>/ endpoint until the reported
// state is one of terminalStates.
//
// A progress line is printed whenever the reported percent changes, and once
// more when the terminal state is reached. When the finished task carries an
// output document it is pretty-printed through PrintOutput.
//
// The delay between polls comes from pipelineInterval (seconds); values below
// 1 are raised to 1. Call and JSON parse errors are printed and retried after
// the same delay, so the function only returns once a terminal state is seen.
func pollTask(handle string) {
	// A zero or negative interval makes time.Sleep return immediately, which
	// would turn the loop below into a busy loop against the API.
	seconds := pipelineInterval
	if seconds < 1 {
		seconds = 1
	}
	interval := time.Duration(seconds) * time.Second
	endpoint := fmt.Sprintf("tasks/%s/", handle)

	// intField returns progress[name] truncated to int when it is a JSON
	// number, and 0 when the key is missing, has another type, or the map is nil.
	intField := func(progress map[string]interface{}, name string) int {
		f, _ := progress[name].(float64)
		return int(f)
	}
	// stringField returns progress[name] when it is a JSON string, else "".
	stringField := func(progress map[string]interface{}, name string) string {
		s, _ := progress[name].(string)
		return s
	}

	fmt.Printf("polling task %s every %ds ...\n", handle, seconds)

	// -1 can never match a reported percent, so the first poll always prints.
	lastPercent := -1

	for {
		// NOTE(review): AppendGlobalOptions is not applied to this target,
		// unlike the list call in pipelineCmd — confirm this is intended.
		target := APITarget{
			url:            viper.GetString("url") + endpoint,
			ssl_skipverify: viper.GetBool("ssl_skipverify"),
		}

		result, err := CallAPI(target)
		if err != nil {
			fmt.Printf("  poll error: %v\n", err)
			time.Sleep(interval)
			continue
		}

		var t taskProgressResponse
		if jsonErr := json.Unmarshal([]byte(result.body), &t); jsonErr != nil {
			fmt.Printf("  parse error: %v\n", jsonErr)
			time.Sleep(interval)
			continue
		}

		percent := intField(t.Progress, "percent")
		message := stringField(t.Progress, "message")
		count := intField(t.Progress, "count")
		total := intField(t.Progress, "total")
		finished := terminalStates[t.State]

		if finished || percent != lastPercent {
			switch {
			case message != "":
				fmt.Printf("  [%s] %3d%%  %s\n", t.State, percent, message)
			case total > 0:
				fmt.Printf("  [%s] %3d%%  %d/%d\n", t.State, percent, count, total)
			default:
				fmt.Printf("  [%s] %3d%%\n", t.State, percent)
			}
			lastPercent = percent
		}

		if finished {
			if t.Output != nil {
				outputBytes, marshalErr := json.MarshalIndent(t.Output, "  ", "  ")
				if marshalErr != nil {
					fmt.Printf("  output error: %v\n", marshalErr)
				} else {
					PrintOutput(fmt.Sprintf("  output: %s", string(outputBytes)))
				}
			}
			return
		}

		time.Sleep(interval)
	}
}

// pipelineCmd implements the "pipeline" command (alias "pipelines").
//
// Without --key it calls the pipelines/ endpoint and prints the response body.
// With --key it requires --schema, sends a POST to
// pipelines/<key>/?schema=<schema>[&dryrun=true], prints the response body and,
// unless --no-wait is set or the response code is not 202, polls the returned
// task handle with pollTask.
//
// The process exits with status 1 when global options cannot be applied, when
// --schema is missing, or when an API call fails.
var pipelineCmd = &cobra.Command{
	Use:     "pipeline",
	Short:   "list or run a pipeline",
	Long:    `Call cavaliba pipelines/ API. Without --key: list pipelines. With --key <name> and --schema <s1,s2>: submit a pipeline run and poll for completion. Use --no-wait to skip polling. Use --interval <sec> to set poll interval (default 2s).`,
	Aliases: []string{"pipelines"},
	Run: func(cmd *cobra.Command, args []string) {

		// No --key: list pipelines
		if !cmd.Flags().Changed("key") {
			target := APITarget{
				url:            viper.GetString("url") + "pipelines/",
				ssl_skipverify: viper.GetBool("ssl_skipverify"),
			}
			err := AppendGlobalOptions(&target)
			if err != nil {
				fmt.Println(err)
				os.Exit(1)
			}
			PrintVerboseTarget(target)
			result, err := CallAPI(target)
			if err != nil {
				PrintError(result, err)
				// A failed call must not report success to the calling shell.
				os.Exit(1)
			}
			PrintVerboseResult(result)
			PrintOutput(result.body)
			return
		}

		if pipelineSchema == "" {
			fmt.Println("error: --schema <schema1,schema2> is required to run a pipeline")
			os.Exit(1)
		}

		// key and pipelineSchema are user input: escape them so characters such
		// as '/', '?', '&', '#' or spaces cannot alter the path or inject extra
		// query parameters. key is declared elsewhere in the package.
		endpoint := fmt.Sprintf("pipelines/%s/", neturl.PathEscape(key))
		queryParams := "?schema=" + neturl.QueryEscape(pipelineSchema)
		if pipelineDryrun {
			queryParams += "&dryrun=true"
		}

		// NOTE(review): AppendGlobalOptions is applied on the list path above
		// but not here — confirm this is intended.
		target := APITarget{
			url:            viper.GetString("url") + endpoint + queryParams,
			method:         http.MethodPost,
			ssl_skipverify: viper.GetBool("ssl_skipverify"),
		}

		PrintVerboseTarget(target)

		result, err := CallAPI(target)

		// 202 Accepted is the expected success response for pipeline submit
		if err != nil && result.http_code != 202 {
			PrintError(result, err)
			// A failed submit must not report success to the calling shell.
			os.Exit(1)
		}

		PrintVerboseResult(result)
		fmt.Println(result.body)

		// Only a 202 response carries a task handle worth polling.
		if pipelineNoWait || result.http_code != 202 {
			return
		}

		var submitted pipelineSubmitResponse
		if jsonErr := json.Unmarshal([]byte(result.body), &submitted); jsonErr != nil || submitted.Handle == "" {
			fmt.Println("could not parse task handle from response, skipping poll")
			return
		}

		pollTask(submitted.Handle)
	},
}

// init binds the pipeline command's flags to their package-level variables and
// attaches pipelineCmd to rootCmd.
func init() {
	flags := pipelineCmd.Flags()

	flags.StringVar(&pipelineSchema, "schema", "", "Comma-separated schema names (required to run)")
	flags.BoolVar(&pipelineDryrun, "dryrun", false, "Dry-run mode (no changes applied)")
	flags.BoolVar(&pipelineNoWait, "no-wait", false, "Submit pipeline and return immediately without polling")
	flags.IntVar(&pipelineInterval, "interval", 2, "Poll interval in seconds")

	rootCmd.AddCommand(pipelineCmd)
}
