diff --git a/cmd/cmd_utils.go b/cmd/cmd_utils.go index 2ae20cee7..51f08296b 100644 --- a/cmd/cmd_utils.go +++ b/cmd/cmd_utils.go @@ -50,6 +50,95 @@ func getAtmosCommandMaxDepth() int { return AtmosCommandDefaultDepth } +func detectCycle(commands []schema.Command) bool { + // Create a map of valid commands for quick lookup + validCommands := make(map[string]bool) + for _, cmd := range commands { + validCommands[cmd.Name] = true + } + // Build a command graph + graph := make(map[string][]string) + for _, cmd := range commands { + for _, step := range cmd.Steps { + cmdName := parseCommandName(step) + if cmdName != "" && !validCommands[cmdName] { + return true + } + graph[cmd.Name] = append(graph[cmd.Name], cmdName) + } + } + + // To track visited nodes and detect cycles + visited := make(map[string]bool) + recStack := make(map[string]bool) + + // Track the maximum recursion depth + maxDepth := getAtmosCommandMaxDepth() + + // Run DFS for each command to detect cycles and compute depth + for cmd := range graph { + if detectCycleUtil(cmd, graph, visited, recStack, 0, &maxDepth) { + return true // Cycle detected + } + } + + // Print or return the max depth if needed + fmt.Println("Maximum Recursion Depth:", maxDepth) + return false // No cycle detected +} + +func detectCycleUtil(command string, graph map[string][]string, visited, recStack map[string]bool, depth int, maxDepth *int) bool { + // Update the maximum recursion depth + if depth > *maxDepth { + *maxDepth = depth + } + + // If the current command is in the recursion stack, there's a cycle + if recStack[command] { + return true + } + + // If already visited, no need to explore again + if visited[command] { + return false + } + + // Mark as visited and add to recursion stack + visited[command] = true + recStack[command] = true + + // Recurse for all dependencies + for _, dep := range graph[command] { + if detectCycleUtil(dep, graph, visited, recStack, depth+1, maxDepth) { + return true + } + } + + // Remove from recursion stack before backtracking + recStack[command] = false + return false +} + +// Helper function to parse command name from the step +func parseCommandName(step string) string { + // Split the step into parts + parts := strings.Split(step, " ") + + // Check if the command starts with "atmos" and has additional parts + if len(parts) > 1 && parts[0] == "atmos" { + // Extract the actual command name, handling flags and arguments + cmdParts := []string{} + for _, part := range parts[1:] { + if strings.HasPrefix(part, "-") { + break + } + cmdParts = append(cmdParts, part) + } + return strings.Join(cmdParts, " ") + } + return "" +} + // processCustomCommands processes and executes custom commands func processCustomCommands( cliConfig schema.CliConfiguration, @@ -64,9 +153,9 @@ func processCustomCommands( existingTopLevelCommands = getTopLevelCommands() } - // Track the execution count for each command - visitCount := make(map[string]int) - maxIterations := getAtmosCommandMaxDepth() // Default to 10 or from environment + if detectCycle(commands) { + return fmt.Errorf("cycle detected in custom CLI commands - this could lead to infinite recursion") + } for _, commandCfg := range commands { // Clone the 'commandCfg' struct into a local variable because of the automatic closure in the `Run` function of the Cobra command. @@ -88,14 +177,6 @@ func processCustomCommands( preCustomCommand(cmd, args, parentCommand, commandConfig) }, Run: func(cmd *cobra.Command, args []string) { - // Increment the visit count for the command - visitCount[commandConfig.Name]++ - if visitCount[commandConfig.Name] > maxIterations { - u.LogWarning(cliConfig, fmt.Sprintf("Command '%s' reached max iteration limit (%d). Skipping further execution.\n", commandConfig.Name, maxIterations)) - return - } - - // Execute the command if under the limit executeCustomCommand(cliConfig, cmd, args, parentCommand, commandConfig) }, }