diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 05a849a7..10eb33a1 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -62,10 +62,26 @@ jobs: with: go-version: ${{ env.GO_VERSION }} - - name: Install go module dependencies + - name: Install Go tools run: | + # Install shfmt go install mvdan.cc/sh/v3/cmd/shfmt@latest + # Install goimports + go install golang.org/x/tools/cmd/goimports@latest + + # Install gocyclo + go install github.com/fzipp/gocyclo/cmd/gocyclo@latest + + # Install golangci-lint + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b "$(go env GOPATH)/bin" + + # Install gocritic + go install github.com/go-critic/go-critic/cmd/gocritic@latest + + # Add Go bin directory to PATH + echo "$(go env GOPATH)/bin" >> "$GITHUB_PATH" + - name: Setup go-task run: | sh -c "$(curl --location https://taskfile.dev/install.sh)" -- -d -b /usr/local/bin v${{ env.TASK_VERSION }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 00000000..4208578c --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,53 @@ +--- +name: Tests +on: + merge_group: + pull_request: + branches: + - main + types: + - opened + - synchronize + - reopened + push: + branches: + - main + workflow_dispatch: + +concurrency: + cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + +env: + GO_VERSION: "1.26.1" + +permissions: + actions: read + checks: write + contents: read + pull-requests: write + +jobs: + tests: + name: Run Go tests and determine code coverage + runs-on: ubuntu-latest + steps: + - name: Checkout git repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Go + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 + with: + go-version: ${{ env.GO_VERSION }} + check-latest: true + + - name: Generate the coverage output + run: | + bash .hooks/run-go-tests.sh coverage + + - name: Upload coverage artifact + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 + with: + name: coverage-report + path: cli/coverage-all.out + retention-days: 14 diff --git a/.gitignore b/.gitignore index 5bbbc153..8cbbbd14 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ __pycache__/ .task/ # Build artifacts +dreadgoad ansible/roles/adcs_templates/files/ADCSTemplate.zip ansible/roles/vulns_adcs_templates/files/ADCSTemplate.zip diff --git a/.hooks/go-no-replacement.sh b/.hooks/go-no-replacement.sh new file mode 100755 index 00000000..1050ac11 --- /dev/null +++ b/.hooks/go-no-replacement.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +REPO_ROOT=$(git rev-parse --show-toplevel 2> /dev/null) +GO_MOD="${REPO_ROOT}/cli/go.mod" + +if grep -q "^replace " "${GO_MOD}" 2> /dev/null; then + echo "ERROR: Don't commit a replacement in go.mod!" + exit 1 +fi diff --git a/.hooks/go-vet.sh b/.hooks/go-vet.sh new file mode 100755 index 00000000..743d7348 --- /dev/null +++ b/.hooks/go-vet.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +REPO_ROOT=$(git rev-parse --show-toplevel 2> /dev/null) +cd "${REPO_ROOT}/cli" + +pkgs=$(go list ./...) + +for pkg in $pkgs; do + dir="$(basename "$pkg")/" + if [[ "${dir}" != .*/ ]]; then + go vet "${pkg}" + fi +done diff --git a/.hooks/run-go-tests.sh b/.hooks/run-go-tests.sh new file mode 100755 index 00000000..182958c4 --- /dev/null +++ b/.hooks/run-go-tests.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +set -euo pipefail + +TESTS_TO_RUN=${1:-} +RETURN_CODE=0 +GITHUB_ACTIONS=${GITHUB_ACTIONS:-} +PROJECT_NAME=$(basename "$(git rev-parse --show-toplevel 2> /dev/null)" || echo "dreadgoad") + +TIMESTAMP=$(date +"%Y%m%d%H%M%S") +LOGFILE="/tmp/${PROJECT_NAME}-unit-test-results-$TIMESTAMP.log" + +# Go code lives in the cli/ subdirectory +REPO_ROOT=$(git rev-parse --show-toplevel 2> /dev/null) +GO_DIR="${REPO_ROOT}/cli" + +if [[ -z "${TESTS_TO_RUN}" ]]; then + echo "No tests input" | tee -a "$LOGFILE" + echo "Example - Run all shorter collection of tests: bash .hooks/run-go-tests.sh short" | tee -a "$LOGFILE" + echo "Example - Run all tests: bash .hooks/run-go-tests.sh all" | tee -a "$LOGFILE" + echo "Example - Run coverage for a specific version: bash .hooks/run-go-tests.sh coverage" | tee -a "$LOGFILE" + echo "Example - Run tests for modified files: bash .hooks/run-go-tests.sh modified" | tee -a "$LOGFILE" + exit 1 +fi + +run_tests() { + local coverage_file=${1:-} + pushd "${GO_DIR}" > /dev/null || exit + echo "Logging output to ${LOGFILE}" | tee -a "$LOGFILE" + echo "Running tests..." | tee -a "$LOGFILE" + + # Check if go.mod and go.sum exist + if [[ -f "go.mod" && -f "go.sum" ]]; then + MOD_TMP=$(mktemp) + SUM_TMP=$(mktemp) + cp go.mod "$MOD_TMP" + cp go.sum "$SUM_TMP" + go mod tidy + if ! cmp -s go.mod "$MOD_TMP" || ! cmp -s go.sum "$SUM_TMP"; then + echo "Running 'go mod tidy' to clean up module dependencies..." | tee -a "$LOGFILE" + go mod tidy 2>&1 | tee -a "$LOGFILE" + fi + rm "$MOD_TMP" "$SUM_TMP" + fi + + if [[ "${TESTS_TO_RUN}" == 'coverage' ]]; then + go test -v -race -failfast -coverprofile="${coverage_file}" ./... 2>&1 | tee -a "$LOGFILE" + elif [[ "${TESTS_TO_RUN}" == 'all' ]]; then + go test -v -race -failfast ./... 2>&1 | tee -a "$LOGFILE" + elif [[ "${TESTS_TO_RUN}" == 'short' ]] && [[ "${GITHUB_ACTIONS}" != "true" ]]; then + go test -v -short -failfast -race ./... 2>&1 | tee -a "$LOGFILE" + elif [[ "${TESTS_TO_RUN}" == 'modified' ]]; then + local modified_files=() + while IFS= read -r file; do + [[ -n "$file" ]] && modified_files+=("$file") + done < <(git diff --name-only --cached | grep '^cli/.*\.go$' | sed 's|^cli/||' || true) + + if [[ ${#modified_files[@]} -eq 0 ]]; then + echo "No modified Go files found to test" | tee -a "$LOGFILE" + popd > /dev/null + return 0 + fi + + local pkg_dirs=() + for file in "${modified_files[@]}"; do + local pkg_dir + pkg_dir=$(dirname "$file") + pkg_dirs+=("$pkg_dir") + done + + # Remove duplicate package directories + IFS=$'\n' read -r -a pkg_dirs <<< "$(printf '%s\n' "${pkg_dirs[@]}" | sort -u)" + unset IFS + + for dir in "${pkg_dirs[@]}"; do + if [[ -n "$dir" ]]; then + local pkg_list + pkg_list=$(go list "./$dir/..." 2>&1) + if [[ -n "$pkg_list" ]] && [[ ! "$pkg_list" =~ "matched no packages" ]]; then + go test -v -short -race -failfast "./$dir/..." 2>&1 | tee -a "$LOGFILE" + else + echo "Skipping $dir (no packages match current platform/build constraints)" | tee -a "$LOGFILE" + fi + fi + done + else + if [[ "${GITHUB_ACTIONS}" != 'true' ]]; then + go test -v -failfast -race "./.../${TESTS_TO_RUN}" 2>&1 | tee -a "$LOGFILE" + fi + fi + + RETURN_CODE=$? + popd > /dev/null +} + +if [[ "${TESTS_TO_RUN}" == 'coverage' ]]; then + run_tests 'coverage-all.out' +else + run_tests +fi + +if [[ "${RETURN_CODE}" -ne 0 ]]; then + echo "unit tests failed" | tee -a "$LOGFILE" + exit 1 +fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3bbc758d..08b1f11e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -89,3 +89,75 @@ repos: always_run: false files: ^ansible/(roles/|plugins/|playbooks/).* additional_dependencies: [] + + - id: go-fmt + name: Run gofmt + entry: bash -c 'cd cli && gofmt -l -w .' + language: system + files: '\.go$' + pass_filenames: false + + - id: go-imports + name: Run goimports + entry: bash -c 'cd cli && goimports -l -w .' + language: system + files: '\.go$' + pass_filenames: false + + - id: go-cyclo + name: Run gocyclo + entry: bash -c 'cd cli && gocyclo -over 15 .' + language: system + files: '\.go$' + pass_filenames: false + + - id: golangci-lint + name: Run golangci-lint + entry: bash -c 'cd cli && golangci-lint run --timeout=5m --enable=unused ./...' + language: system + files: '\.go$' + pass_filenames: false + + - id: go-critic + name: Run go-critic + entry: bash -c 'cd cli && gocritic check ./...' + language: system + files: '\.go$' + pass_filenames: false + + - id: go-build + name: Run go build + entry: bash -c 'cd cli && go build ./...' + language: system + files: '\.go$' + pass_filenames: false + + - id: go-mod-tidy + name: Run go mod tidy + entry: bash -c 'cd cli && go mod tidy' + language: system + files: '(go\.mod|go\.sum)$' + pass_filenames: false + + - id: go-no-replacement + name: Avoid committing a go module replacement + entry: .hooks/go-no-replacement.sh + language: script + files: go.mod + + - id: go-unit-tests + name: Go unit tests + language: script + entry: .hooks/run-go-tests.sh modified + files: '\.go$' + pass_filenames: true + + - id: go-vet + name: Run go vet + language: script + entry: .hooks/go-vet.sh + files: '\.go$' + always_run: true + pass_filenames: true + require_serial: true + log_file: /tmp/go-vet.log diff --git a/cli/cmd/config_cmd.go b/cli/cmd/config_cmd.go new file mode 100644 index 00000000..87364f71 --- /dev/null +++ b/cli/cmd/config_cmd.go @@ -0,0 +1,106 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var configCmd = &cobra.Command{ + Use: "config", + Short: "Manage CLI configuration", +} + +var configShowCmd = &cobra.Command{ + Use: "show", + Short: "Display current effective configuration", + RunE: func(cmd *cobra.Command, args []string) error { + cfg := config.Get() + fmt.Printf("Environment: %s\n", cfg.Env) + fmt.Printf("Region: %s\n", valueOrDefault(cfg.Region, "(from inventory)")) + fmt.Printf("Debug: %v\n", cfg.Debug) + fmt.Printf("Max Retries: %d\n", cfg.MaxRetries) + fmt.Printf("Retry Delay: %ds\n", cfg.RetryDelay) + fmt.Printf("Idle Timeout: %ds\n", cfg.IdleTimeout) + fmt.Printf("Log Dir: %s\n", cfg.LogDir) + fmt.Printf("Project Root: %s\n", cfg.ProjectRoot) + fmt.Printf("Inventory: %s\n", cfg.InventoryPath()) + fmt.Printf("Ansible Config: %s\n", cfg.AnsibleCfgPath()) + fmt.Printf("Playbooks: %s\n", strings.Join(cfg.Playbooks, ", ")) + + if cfgFile := viper.ConfigFileUsed(); cfgFile != "" { + fmt.Printf("\nConfig file: %s\n", cfgFile) + } else { + fmt.Println("\nNo config file found (using defaults)") + } + return nil + }, +} + +var configInitCmd = &cobra.Command{ + Use: "init", + Short: "Create default configuration file", + RunE: func(cmd *cobra.Command, args []string) error { + home, _ := os.UserHomeDir() + dir := filepath.Join(home, ".config", "dreadgoad") + _ = os.MkdirAll(dir, 0o755) + cfgPath := filepath.Join(dir, "dreadgoad.yaml") + + if _, err := os.Stat(cfgPath); err == nil { + return fmt.Errorf("config file already exists: %s", cfgPath) + } + + content := `# DreadGOAD CLI Configuration +env: dev +# region: us-west-2 # Override AWS region (default: from inventory) +debug: false +max_retries: 3 +retry_delay: 30 +idle_timeout: 1200 +# log_dir: ~/.ansible/logs/goad +# project_root: /path/to/DreadGOAD # Auto-detected if omitted +` + if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil { + return err + } + fmt.Printf("Created config: %s\n", cfgPath) + return nil + }, +} + +var configSetCmd = &cobra.Command{ + Use: "set ", + Short: "Set a configuration value", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + viper.Set(args[0], args[1]) + cfgFile := viper.ConfigFileUsed() + if cfgFile == "" { + return fmt.Errorf("no config file found. Run: dreadgoad config init") + } + if err := viper.WriteConfig(); err != nil { + return err + } + fmt.Printf("Set %s = %s in %s\n", args[0], args[1], cfgFile) + return nil + }, +} + +func init() { + rootCmd.AddCommand(configCmd) + configCmd.AddCommand(configShowCmd) + configCmd.AddCommand(configInitCmd) + configCmd.AddCommand(configSetCmd) +} + +func valueOrDefault(v, def string) string { + if v == "" { + return def + } + return v +} diff --git a/cli/cmd/doctor.go b/cli/cmd/doctor.go new file mode 100644 index 00000000..fa0f6645 --- /dev/null +++ b/cli/cmd/doctor.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/doctor" + "github.com/spf13/cobra" +) + +var doctorCmd = &cobra.Command{ + Use: "doctor", + Short: "Run pre-flight system checks", + Long: `Verifies that all required tools and configurations are in place: +ansible-core version, AWS CLI, Python, Ansible collections, credentials, and inventory.`, + RunE: func(cmd *cobra.Command, args []string) error { + cfg := config.Get() + results := doctor.RunChecks(cfg.InventoryPath(), cfg.ProjectRoot) + doctor.PrintResults(results) + + for _, r := range results { + if r.Status == "fail" { + return nil // non-zero shown by print, but don't error on doctor + } + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(doctorCmd) +} diff --git a/cli/cmd/inventory.go b/cli/cmd/inventory.go new file mode 100644 index 00000000..ce8461dd --- /dev/null +++ b/cli/cmd/inventory.go @@ -0,0 +1,245 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/dreadnode/dreadgoad/internal/config" + inv "github.com/dreadnode/dreadgoad/internal/inventory" + "github.com/spf13/cobra" +) + +var inventoryCmd = &cobra.Command{ + Use: "inventory", + Short: "Manage Ansible inventory", +} + +var inventorySyncCmd = &cobra.Command{ + Use: "sync", + Short: "Synchronize inventory with AWS instance IDs", + RunE: runInventorySync, +} + +var inventoryShowCmd = &cobra.Command{ + Use: "show", + Short: "Display current inventory", + RunE: runInventoryShow, +} + +var inventoryMappingCmd = &cobra.Command{ + Use: "mapping", + Short: "Generate instance-to-IP mapping for Ansible optimization", + RunE: runInventoryMapping, +} + +func init() { + rootCmd.AddCommand(inventoryCmd) + inventoryCmd.AddCommand(inventorySyncCmd) + inventoryCmd.AddCommand(inventoryShowCmd) + inventoryCmd.AddCommand(inventoryMappingCmd) + + inventorySyncCmd.Flags().Bool("backup", false, "Create backup before modifying") + inventorySyncCmd.Flags().String("json", "", "Path to JSON file with instance data") + inventoryMappingCmd.Flags().StringP("output", "o", "", "Output file path") +} + +type instanceInfo struct { + InstanceID string `json:"InstanceId"` + Name string `json:"Name"` +} + +func runInventorySync(cmd *cobra.Command, args []string) error { + cfg := config.Get() + invPath := cfg.InventoryPath() + + if _, err := os.Stat(invPath); os.IsNotExist(err) { + return fmt.Errorf("inventory file not found: %s", invPath) + } + + backup, _ := cmd.Flags().GetBool("backup") + if backup { + if err := backupInventory(invPath); err != nil { + return err + } + } + + if err := updateEnvField(invPath, cfg.Env); err != nil { + return err + } + + jsonFile, _ := cmd.Flags().GetString("json") + instances, err := loadInstances(context.Background(), jsonFile, invPath, cfg.Env) + if err != nil { + return err + } + + return applyInstanceUpdates(invPath, instances) +} + +func backupInventory(invPath string) error { + backupPath := invPath + ".bak." + time.Now().Format("20060102150405") + data, err := os.ReadFile(invPath) + if err != nil { + return fmt.Errorf("read inventory for backup: %w", err) + } + if err := os.WriteFile(backupPath, data, 0o644); err != nil { + return fmt.Errorf("write backup: %w", err) + } + fmt.Printf("Created backup: %s\n", backupPath) + return nil +} + +func updateEnvField(invPath, env string) error { + data, err := os.ReadFile(invPath) + if err != nil { + return err + } + re := regexp.MustCompile(`(?m)^(\s*env=).*$`) + updated := re.ReplaceAllString(string(data), "${1}"+env) + if err := os.WriteFile(invPath, []byte(updated), 0o644); err != nil { + return fmt.Errorf("write inventory: %w", err) + } + return nil +} + +func loadInstances(ctx context.Context, jsonFile, invPath, env string) ([]instanceInfo, error) { + if jsonFile != "" { + raw, err := os.ReadFile(jsonFile) + if err != nil { + return nil, fmt.Errorf("read JSON: %w", err) + } + var instances []instanceInfo + if err := json.Unmarshal(raw, &instances); err != nil { + return nil, fmt.Errorf("parse instance JSON: %w", err) + } + return instances, nil + } + + parsed, err := inv.Parse(invPath) + if err != nil { + return nil, err + } + client, err := daws.NewClient(ctx, parsed.Region()) + if err != nil { + return nil, err + } + awsInstances, err := client.DiscoverInstances(ctx, env) + if err != nil { + return nil, fmt.Errorf("discover instances: %w", err) + } + var instances []instanceInfo + for _, i := range awsInstances { + instances = append(instances, instanceInfo{InstanceID: i.InstanceID, Name: i.Name}) + } + return instances, nil +} + +func applyInstanceUpdates(invPath string, instances []instanceInfo) error { + content, err := os.ReadFile(invPath) + if err != nil { + return fmt.Errorf("read inventory: %w", err) + } + lines := string(content) + updates := 0 + + for _, inst := range instances { + if !strings.Contains(inst.Name, "dreadgoad-") { + continue + } + parts := strings.SplitN(inst.Name, "dreadgoad-", 2) + if len(parts) < 2 { + continue + } + hostname := strings.ToLower(parts[1]) + re := regexp.MustCompile(`(?mi)^(` + regexp.QuoteMeta(hostname) + `\s+ansible_host=)\S+`) + if re.MatchString(lines) { + newLines := re.ReplaceAllString(lines, "${1}"+inst.InstanceID) + if newLines != lines { + lines = newLines + fmt.Printf("Updated %s with instance ID: %s\n", hostname, inst.InstanceID) + updates++ + } + } + } + + if err := os.WriteFile(invPath, []byte(lines), 0o644); err != nil { + return fmt.Errorf("write updated inventory: %w", err) + } + + if updates == 0 { + fmt.Println("No instance ID updates needed. All IDs are current.") + } else { + fmt.Printf("Updated %d instance IDs in %s\n", updates, invPath) + } + return nil +} + +func runInventoryShow(cmd *cobra.Command, args []string) error { + cfg := config.Get() + + parsed, err := inv.Parse(cfg.InventoryPath()) + if err != nil { + return err + } + + fmt.Printf("Inventory: %s (env=%s, region=%s)\n\n", parsed.FilePath, cfg.Env, parsed.Region()) + fmt.Printf("%-8s %-24s %-10s %-10s %s\n", "HOST", "INSTANCE ID", "DICT_KEY", "DNS_DOMAIN", "GROUPS") + fmt.Println(strings.Repeat("-", 80)) + + for _, host := range parsed.Hosts { + groups := strings.Join(host.Groups, ",") + fmt.Printf("%-8s %-24s %-10s %-10s %s\n", + host.Name, host.InstanceID, host.DictKey, host.DNSDomain, groups) + } + return nil +} + +func runInventoryMapping(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + parsed, err := inv.Parse(cfg.InventoryPath()) + if err != nil { + return err + } + + outputPath, _ := cmd.Flags().GetString("output") + if outputPath == "" { + outputPath = filepath.Join(os.TempDir(), fmt.Sprintf("aws_instance_mapping_%s.json", cfg.Env)) + } + + client, err := daws.NewClient(ctx, parsed.Region()) + if err != nil { + return err + } + + instanceIDs := parsed.InstanceIDs() + fmt.Printf("Querying AWS for %d instance IPs...\n", len(instanceIDs)) + + mapping, err := client.GetInstancePrivateIPs(ctx, instanceIDs) + if err != nil { + return err + } + + output := map[string]interface{}{ + "instance_to_ip": mapping, + } + data, err := json.MarshalIndent(output, "", " ") + if err != nil { + return fmt.Errorf("marshal mapping: %w", err) + } + if err := os.WriteFile(outputPath, data, 0o644); err != nil { + return fmt.Errorf("write mapping: %w", err) + } + + fmt.Printf("Mapping generated: %s\n", outputPath) + fmt.Printf("Mapped %d instances\n", len(mapping)) + return nil +} diff --git a/cli/cmd/lab.go b/cli/cmd/lab.go new file mode 100644 index 00000000..8e4a13bd --- /dev/null +++ b/cli/cmd/lab.go @@ -0,0 +1,121 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/spf13/cobra" +) + +var labCmd = &cobra.Command{ + Use: "lab", + Short: "Manage GOAD lab lifecycle", +} + +var labStatusCmd = &cobra.Command{ + Use: "status", + Short: "Show lab instance states", + RunE: runLabStatus, +} + +var labStartCmd = &cobra.Command{ + Use: "start", + Short: "Start stopped lab instances", + RunE: runLabAction("start"), +} + +var labStopCmd = &cobra.Command{ + Use: "stop", + Short: "Stop running lab instances", + RunE: runLabAction("stop"), +} + +func init() { + rootCmd.AddCommand(labCmd) + labCmd.AddCommand(labStatusCmd) + labCmd.AddCommand(labStartCmd) + labCmd.AddCommand(labStopCmd) +} + +func runLabStatus(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + region := cfg.Region + if region == "" { + region = "us-west-1" + } + + client, err := daws.NewClient(ctx, region) + if err != nil { + return err + } + + instances, err := client.DiscoverInstances(ctx, cfg.Env) + if err != nil { + return err + } + + if len(instances) == 0 { + fmt.Printf("No GOAD instances found for env=%s\n", cfg.Env) + return nil + } + + fmt.Printf("GOAD Lab Status (%s)\n", cfg.Env) + fmt.Printf("%-40s %-24s %-15s %s\n", "NAME", "INSTANCE ID", "STATE", "PRIVATE IP") + fmt.Println(strings.Repeat("-", 95)) + + for _, inst := range instances { + fmt.Printf("%-40s %-24s %-15s %s\n", + inst.Name, inst.InstanceID, inst.State, inst.PrivateIP) + } + return nil +} + +func runLabAction(action string) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + region := cfg.Region + if region == "" { + region = "us-west-1" + } + + client, err := daws.NewClient(ctx, region) + if err != nil { + return err + } + + instances, err := client.DiscoverInstances(ctx, cfg.Env) + if err != nil { + return err + } + + if len(instances) == 0 { + return fmt.Errorf("no GOAD instances found for env=%s", cfg.Env) + } + + var ids []string + for _, inst := range instances { + ids = append(ids, inst.InstanceID) + fmt.Printf(" %s %s (%s)\n", action, inst.Name, inst.InstanceID) + } + + switch action { + case "start": + err = client.StartInstances(ctx, ids) + case "stop": + err = client.StopInstances(ctx, ids) + } + if err != nil { + return fmt.Errorf("%s instances: %w", action, err) + } + + fmt.Printf("\nSuccessfully initiated %s for %d instances\n", action, len(ids)) + return nil + } +} diff --git a/cli/cmd/provision.go b/cli/cmd/provision.go new file mode 100644 index 00000000..db0dfc8d --- /dev/null +++ b/cli/cmd/provision.go @@ -0,0 +1,134 @@ +package cmd + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/dreadnode/dreadgoad/internal/ansible" + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/doctor" + "github.com/spf13/cobra" +) + +var provisionCmd = &cobra.Command{ + Use: "provision", + Short: "Run GOAD provisioning playbooks with retry logic", + Long: `Runs Ansible playbooks to provision Active Directory infrastructure. + +Executes the full playbook sequence (or a subset) with error-specific +retry strategies, SSM session management, and idle timeout monitoring.`, + Example: ` dreadgoad provision + dreadgoad provision --plays build.yml,ad-servers.yml + dreadgoad provision --env staging --debug + dreadgoad provision --plays ad-data.yml --limit dc01 + dreadgoad provision --max-retries 5 --retry-delay 60`, + RunE: runProvision, +} + +var adUsersCmd = &cobra.Command{ + Use: "ad-users", + Short: "Ensure AD users exist (runs ad-data.yml)", + RunE: func(cmd *cobra.Command, args []string) error { + plays, _ := cmd.Flags().GetString("plays") + if plays == "" { + _ = cmd.Flags().Set("plays", "ad-data.yml") + } + return runProvision(cmd, args) + }, +} + +func init() { + rootCmd.AddCommand(provisionCmd) + rootCmd.AddCommand(adUsersCmd) + + provisionCmd.Flags().String("plays", "", "Comma-separated playbooks to run (default: all)") + provisionCmd.Flags().String("limit", "", "Limit execution to specific hosts") + provisionCmd.Flags().Int("max-retries", 0, "Max retry attempts (default: from config)") + provisionCmd.Flags().Int("retry-delay", 0, "Delay between retries in seconds (default: from config)") + + // ad-users inherits provision flags + adUsersCmd.Flags().String("plays", "ad-data.yml", "Playbooks to run") + adUsersCmd.Flags().String("limit", "", "Limit execution to specific hosts") + adUsersCmd.Flags().Int("max-retries", 0, "Max retry attempts") + adUsersCmd.Flags().Int("retry-delay", 0, "Delay between retries in seconds") +} + +func runProvision(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + // Determine playbooks + playsFlag, _ := cmd.Flags().GetString("plays") + var playbooks []string + if playsFlag != "" { + playbooks = strings.Split(playsFlag, ",") + } else { + playbooks = cfg.Playbooks + } + + limit, _ := cmd.Flags().GetString("limit") + maxRetries, _ := cmd.Flags().GetInt("max-retries") + retryDelay, _ := cmd.Flags().GetInt("retry-delay") + + // Ensure log directory + _ = os.MkdirAll(cfg.LogDir, 0o755) + logFile := filepath.Join(cfg.LogDir, fmt.Sprintf("%s-dreadgoad-%s.log", + cfg.Env, time.Now().Format("20060102_150405"))) + + // Pre-flight: verify ansible-core version compatibility + if err := doctor.CheckAnsibleCoreVersion(); err != nil { + return fmt.Errorf("ansible-core version check failed: %w", err) + } + + // Pre-flight: prepare ADCS zips + if err := ansible.PrepareADCSZips(cfg.ProjectRoot); err != nil { + slog.Warn("ADCS zip preparation failed", "error", err) + } + + // Log header + fmt.Println("===============================================") + fmt.Printf("DreadGOAD provisioning started at %s\n", time.Now().Format(time.RFC3339)) + fmt.Printf("Environment: %s\n", cfg.Env) + fmt.Printf("Log file: %s\n", logFile) + if limit != "" { + fmt.Printf("Limited to hosts: %s\n", limit) + } + fmt.Println("===============================================") + fmt.Println("\nPlaybooks to execute:") + for _, p := range playbooks { + fmt.Printf(" - ansible/playbooks/%s\n", p) + } + fmt.Println("-----------------------------------------------") + + // Run each playbook + for _, playbook := range playbooks { + opts := ansible.RetryOptions{ + Playbook: playbook, + Env: cfg.Env, + Limit: limit, + Debug: cfg.Debug, + LogFile: logFile, + } + if maxRetries > 0 { + opts.MaxRetries = maxRetries + } + if retryDelay > 0 { + opts.RetryDelay = time.Duration(retryDelay) * time.Second + } + + if err := ansible.RunPlaybookWithRetry(ctx, opts); err != nil { + return fmt.Errorf("provisioning failed at %s: %w", playbook, err) + } + } + + fmt.Println("===============================================") + fmt.Printf("All playbooks completed successfully at %s\n", time.Now().Format(time.RFC3339)) + fmt.Printf("Full log: %s\n", logFile) + fmt.Println("===============================================") + return nil +} diff --git a/cli/cmd/root.go b/cli/cmd/root.go new file mode 100644 index 00000000..8f9af0f7 --- /dev/null +++ b/cli/cmd/root.go @@ -0,0 +1,50 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/logging" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var rootCmd = &cobra.Command{ + Use: "dreadgoad", + Short: "DreadGOAD - Active Directory lab management CLI", + Long: `DreadGOAD orchestrates the deployment and management of intentionally +vulnerable Active Directory environments for security research and testing. + +It manages the full lifecycle: infrastructure provisioning via Terraform, +configuration via Ansible, validation of vulnerability configurations, +and operational tasks like SSM session management.`, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + cfg := config.Get() + logging.Init(cfg.Debug, cfg.LogDir, cfg.Env) + return nil + }, + SilenceUsage: true, + SilenceErrors: true, +} + +func Execute() error { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + return err + } + return nil +} + +func init() { + cobra.OnInitialize(config.Init) + + rootCmd.PersistentFlags().StringP("env", "e", "dev", "Target environment (dev, staging, prod)") + rootCmd.PersistentFlags().String("region", "", "AWS region (default: from inventory)") + rootCmd.PersistentFlags().Bool("debug", false, "Enable debug/verbose output") + rootCmd.PersistentFlags().String("config", "", "Config file path") + + _ = viper.BindPFlag("env", rootCmd.PersistentFlags().Lookup("env")) + _ = viper.BindPFlag("region", rootCmd.PersistentFlags().Lookup("region")) + _ = viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug")) +} diff --git a/cli/cmd/ssm.go b/cli/cmd/ssm.go new file mode 100644 index 00000000..3ac14b4d --- /dev/null +++ b/cli/cmd/ssm.go @@ -0,0 +1,245 @@ +package cmd + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "strings" + "time" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/inventory" + "github.com/spf13/cobra" +) + +var ssmCmd = &cobra.Command{ + Use: "ssm", + Short: "Manage AWS SSM sessions", +} + +var ssmStatusCmd = &cobra.Command{ + Use: "status", + Short: "Show active SSM sessions for environment instances", + RunE: runSSMStatus, +} + +var ssmCleanupCmd = &cobra.Command{ + Use: "cleanup", + Short: "Terminate stale SSM sessions", + RunE: runSSMCleanup, +} + +var ssmConnectCmd = &cobra.Command{ + Use: "connect ", + Short: "Start interactive SSM session to a host", + Args: cobra.ExactArgs(1), + RunE: runSSMConnect, +} + +var ssmRunCmd = &cobra.Command{ + Use: "run", + Short: "Run PowerShell commands across GOAD instances via SSM", + RunE: runSSMRun, +} + +func init() { + rootCmd.AddCommand(ssmCmd) + ssmCmd.AddCommand(ssmStatusCmd) + ssmCmd.AddCommand(ssmCleanupCmd) + ssmCmd.AddCommand(ssmConnectCmd) + ssmCmd.AddCommand(ssmRunCmd) + + ssmCleanupCmd.Flags().Int("max-age", 30, "Sessions older than this (minutes) are stale") + ssmCleanupCmd.Flags().Bool("dry-run", false, "Show what would be terminated") + + ssmRunCmd.Flags().String("hosts", "all", "Comma-separated host names or 'all'") + ssmRunCmd.Flags().StringP("cmd", "c", "", "PowerShell command to execute") + _ = ssmRunCmd.MarkFlagRequired("cmd") +} + +func runSSMStatus(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + inv, err := inventory.Parse(cfg.InventoryPath()) + if err != nil { + return fmt.Errorf("parse inventory: %w", err) + } + + client, err := daws.NewClient(ctx, inv.Region()) + if err != nil { + return err + } + + fmt.Printf("Active SSM sessions for %s environment:\n\n", cfg.Env) + + for _, host := range inv.Hosts { + if host.InstanceID == "" { + continue + } + + sessions, err := client.DescribeActiveSessions(ctx, host.InstanceID) + if err != nil { + fmt.Printf("[%s] %s: error: %v\n", host.Name, host.InstanceID, err) + continue + } + + if len(sessions) == 0 { + fmt.Printf("[%s] %s: No active sessions\n", host.Name, host.InstanceID) + } else { + fmt.Printf("[%s] %s: %d active session(s)\n", host.Name, host.InstanceID, len(sessions)) + for _, s := range sessions { + fmt.Printf(" - %s (%s, started: %s)\n", s.SessionID, s.Status, s.StartDate.Format(time.RFC3339)) + } + } + } + return nil +} + +func runSSMCleanup(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + maxAge, _ := cmd.Flags().GetInt("max-age") + dryRun, _ := cmd.Flags().GetBool("dry-run") + + inv, err := inventory.Parse(cfg.InventoryPath()) + if err != nil { + return fmt.Errorf("parse inventory: %w", err) + } + + client, err := daws.NewClient(ctx, inv.Region()) + if err != nil { + return err + } + + fmt.Printf("Checking for stale SSM sessions (older than %d minutes)...\n", maxAge) + + terminated, err := client.CleanupStaleSessions(ctx, inv.InstanceIDs(), + time.Duration(maxAge)*time.Minute, dryRun, slog.Default()) + if err != nil { + return err + } + + if dryRun { + fmt.Println("\nDry run complete. Use --dry-run=false to actually terminate.") + } else { + fmt.Printf("\nCleanup complete. Terminated %d stale session(s).\n", terminated) + } + return nil +} + +func runSSMConnect(cmd *cobra.Command, args []string) error { + cfg := config.Get() + + inv, err := inventory.Parse(cfg.InventoryPath()) + if err != nil { + return fmt.Errorf("parse inventory: %w", err) + } + + host := inv.HostByName(args[0]) + if host == nil || host.InstanceID == "" { + return fmt.Errorf("host %q not found in inventory", args[0]) + } + + region := inv.Region() + fmt.Printf("Starting SSM session to %s (%s) in %s...\n", host.Name, host.InstanceID, region) + + // Exec into aws ssm start-session + ssmCmd := exec.Command("aws", "ssm", "start-session", + "--target", host.InstanceID, + "--region", region) + ssmCmd.Stdin = os.Stdin + ssmCmd.Stdout = os.Stdout + ssmCmd.Stderr = os.Stderr + return ssmCmd.Run() +} + +func runSSMRun(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + hostsFlag, _ := cmd.Flags().GetString("hosts") + psCmd, _ := cmd.Flags().GetString("cmd") + + // Determine region - prefer flag, then inventory + region := cfg.Region + if region == "" { + region = "us-west-1" + } + + client, err := daws.NewClient(ctx, region) + if err != nil { + return err + } + + // Discover instances + instances, err := client.DiscoverInstances(ctx, cfg.Env) + if err != nil { + return fmt.Errorf("discover instances: %w", err) + } + + if len(instances) == 0 { + return fmt.Errorf("no running GOAD instances found for env=%s", cfg.Env) + } + + targetIDs, targetNames := filterInstances(instances, hostsFlag) + if len(targetIDs) == 0 { + return fmt.Errorf("no matching instances found") + } + + fmt.Printf("Running command on: %s\n", strings.Join(targetNames, ", ")) + fmt.Printf("Command: %s\n\n", psCmd) + + results, err := client.RunPowerShellOnMultiple(ctx, targetIDs, psCmd, 60*time.Second) + if err != nil { + return err + } + + for i, id := range targetIDs { + name := targetNames[i] + result := results[id] + fmt.Printf("=== %s (%s) ===\n", name, id) + if result != nil { + fmt.Printf("Status: %s\n", result.Status) + if result.Stdout != "" { + fmt.Println(result.Stdout) + } + if result.Stderr != "" { + fmt.Printf("STDERR: %s\n", result.Stderr) + } + } + fmt.Println() + } + return nil +} + +func filterInstances(instances []daws.Instance, hostsFlag string) ([]string, []string) { + var ids, names []string + if hostsFlag == "all" { + for _, inst := range instances { + ids = append(ids, inst.InstanceID) + names = append(names, inst.Name) + } + return ids, names + } + for _, hostName := range strings.Split(hostsFlag, ",") { + hostName = strings.TrimSpace(hostName) + found := false + for _, inst := range instances { + if strings.Contains(strings.ToUpper(inst.Name), strings.ToUpper(hostName)) { + ids = append(ids, inst.InstanceID) + names = append(names, inst.Name) + found = true + break + } + } + if !found { + fmt.Printf("WARNING: Host %q not found\n", hostName) + } + } + return ids, names +} diff --git a/cli/cmd/validate.go b/cli/cmd/validate.go new file mode 100644 index 00000000..05580501 --- /dev/null +++ b/cli/cmd/validate.go @@ -0,0 +1,102 @@ +package cmd + +import ( + "context" + "fmt" + "log/slog" + "time" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/validate" + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +var validateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate GOAD vulnerability configurations", + Long: `Validates that all GOAD vulnerabilities are properly configured by +running checks via SSM PowerShell commands against live instances. + +Checks credentials, Kerberos, SMB, delegation, MSSQL, ADCS, ACLs, trusts, and services.`, + Example: ` dreadgoad validate + dreadgoad validate --env staging --verbose + dreadgoad validate --format json --output /tmp/results.json + dreadgoad validate --no-fail`, + RunE: runValidate, +} + +func init() { + rootCmd.AddCommand(validateCmd) + + validateCmd.Flags().String("format", "table", "Output format: table or json") + validateCmd.Flags().String("output", "", "JSON report output path") + validateCmd.Flags().Bool("verbose", false, "Enable verbose output") + validateCmd.Flags().Bool("no-fail", false, "Don't exit with error on failed checks") +} + +func runValidate(cmd *cobra.Command, args []string) error { + cfg := config.Get() + ctx := context.Background() + + verbose, _ := cmd.Flags().GetBool("verbose") + outputPath, _ := cmd.Flags().GetString("output") + noFail, _ := cmd.Flags().GetBool("no-fail") + + // Determine region + region := cfg.Region + if region == "" { + region = "us-west-1" // validate default matches Taskfile + } + + client, err := daws.NewClient(ctx, region) + if err != nil { + return fmt.Errorf("create AWS client: %w", err) + } + + fmt.Println("==========================================") + fmt.Println("GOAD Vulnerability Validation") + fmt.Println("==========================================") + fmt.Printf("Environment: %s\n", cfg.Env) + fmt.Printf("Region: %s\n", region) + + v := validate.NewValidator(client, cfg.Env, verbose, slog.Default()) + + if err := v.DiscoverHosts(ctx); err != nil { + return fmt.Errorf("discover hosts: %w", err) + } + + v.RunAllChecks(ctx) + + report := v.GetReport() + + // Save JSON report + if outputPath == "" { + outputPath = fmt.Sprintf("/tmp/goad-validation-%s.json", time.Now().Format("20060102-150405")) + } + if err := v.SaveReport(outputPath); err != nil { + fmt.Printf("Warning: could not save report: %v\n", err) + } + + // Print summary + fmt.Println("\n==========================================") + fmt.Println("Validation Summary") + fmt.Println("==========================================") + fmt.Printf("Total Checks: %d\n", report.Total) + color.Green("Passed: %d", report.Passed) + color.Red("Failed: %d", report.Failed) + color.Yellow("Warnings: %d", report.Warnings) + + if report.Total > 0 { + pct := report.Passed * 100 / report.Total + fmt.Printf("\nSuccess Rate: %d%%\n", pct) + } + + fmt.Printf("\nResults saved to: %s\n", outputPath) + + if !noFail && report.Failed > 0 { + return fmt.Errorf("validation failed with %d errors", report.Failed) + } + return nil +} diff --git a/cli/go.mod b/cli/go.mod new file mode 100644 index 00000000..2a79391b --- /dev/null +++ b/cli/go.mod @@ -0,0 +1,43 @@ +module github.com/dreadnode/dreadgoad + +go 1.26.1 + +require ( + github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2/config v1.32.13 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 + github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4 + github.com/fatih/color v1.19.0 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect + github.com/aws/smithy-go v1.24.2 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.28.0 // indirect +) diff --git a/cli/go.sum b/cli/go.sum new file mode 100644 index 00000000..f2323420 --- /dev/null +++ b/cli/go.sum @@ -0,0 +1,93 @@ +github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= +github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg= +github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY= +github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 h1:Ytu50ChAxCiDsOlBcBq8jbczXy6+QLb07T65DBJASRs= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2/go.mod h1:R+2BNtUfTfhPY0RH18oL02q116bakeBWjanrbnVBqkM= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4 h1:5Wg8AAAnIWM2LE/0KFGqllZff96bm4dBs+uerYFfReE= +github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4/go.mod h1:nph0ypDLWm9D9iA9zOX39W/N+A4GqwzlxA13jzXVD4k= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= +github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cli/internal/ansible/errors.go b/cli/internal/ansible/errors.go new file mode 100644 index 00000000..4d7204c4 --- /dev/null +++ b/cli/internal/ansible/errors.go @@ -0,0 +1,105 @@ +package ansible + +import ( + "regexp" + "strings" +) + +// ErrorType classifies Ansible failures for error-specific retry strategies. +type ErrorType string + +const ( + ErrFactGathering ErrorType = "fact_gathering" + ErrNetworkAdapter ErrorType = "network_adapter" + ErrSSMTransfer ErrorType = "ssm_transfer_error" + ErrSSMReconnection ErrorType = "ssm_reconnection_needed" + ErrPowerShell ErrorType = "powershell_interactive" + ErrSSMUserAccount ErrorType = "ssm_user_account_issue" + ErrMSIInstaller ErrorType = "msi_installer_error" + ErrUnclassified ErrorType = "unclassified" +) + +var fatalMsgRe = regexp.MustCompile(`(?m)msg:|rc:|stderr:`) + +// DetectErrorType analyzes Ansible output and classifies the failure. +func DetectErrorType(output string) (ErrorType, string) { + switch { + case containsAny(output, + "FAILED! => .* setup", + "Invalid control character", + "modules failed to execute: ansible.legacy.setup", + "Module result deserialization failed"): + return ErrFactGathering, "fact gathering/module deserialization failure" + + case strings.Contains(output, "No MSFT_NetAdapter objects found with property 'Name' equal to 'Ethernet3'"): + return ErrNetworkAdapter, "network adapter Ethernet3 not found" + + case strings.Contains(output, "failed to transfer file"): + return ErrSSMTransfer, "SSM file transfer error" + + case containsAny(output, "TargetNotConnected", "is not connected", + "Timed out waiting for last boot time", "timeout waiting for system to reboot"): + return ErrSSMReconnection, "SSM target not connected / reboot timeout" + + case strings.Contains(output, "Windows PowerShell is in NonInteractive mode"): + return ErrPowerShell, "PowerShell interactive mode issue" + + case containsAny(output, "ssm-user.*disabled", "SSM.*account.*issue", "Windows Local SAM"): + return ErrSSMUserAccount, "SSM user account disabled/destroyed" + + case containsAny(output, "rc: 1603", "rc: 3010"): + return ErrMSIInstaller, "MSI installer error (rc 1603/3010)" + + default: + detail := extractFatalContext(output) + return ErrUnclassified, detail + } +} + +func containsAny(s string, patterns ...string) bool { + for _, p := range patterns { + if strings.Contains(p, ".*") || strings.Contains(p, "[") { + if re, err := regexp.Compile(p); err == nil && re.MatchString(s) { + return true + } + } else if strings.Contains(s, p) { + return true + } + } + return false +} + +func extractFatalContext(output string) string { + lines := strings.Split(output, "\n") + for i, line := range lines { + if strings.HasPrefix(line, "fatal:") { + // Grab up to 5 lines after for context + end := i + 6 + if end > len(lines) { + end = len(lines) + } + context := strings.Join(lines[i:end], "\n") + // Extract msg/rc/stderr lines + matches := fatalMsgRe.FindAllString(context, -1) + if len(matches) > 0 { + return strings.TrimSpace(context) + } + // Truncate to 120 chars + if len(line) > 120 { + return line[:120] + } + return line + } + } + // Fallback: last FAILED/fatal line + for i := len(lines) - 1; i >= 0; i-- { + if strings.Contains(lines[i], "FAILED") || strings.Contains(lines[i], "fatal") { + line := lines[i] + if len(line) > 120 { + return line[:120] + } + return line + } + } + return "unknown error" +} diff --git a/cli/internal/ansible/errors_test.go b/cli/internal/ansible/errors_test.go new file mode 100644 index 00000000..a9372e43 --- /dev/null +++ b/cli/internal/ansible/errors_test.go @@ -0,0 +1,229 @@ +package ansible + +import ( + "testing" +) + +func TestDetectErrorType(t *testing.T) { + tests := []struct { + name string + output string + wantType ErrorType + wantMsg string + }{ + { + name: "fact gathering with setup failure", + output: `FAILED! => {"msg": "MODULE FAILURE", "module_stdout": ""} setup`, + wantType: ErrFactGathering, + wantMsg: "fact gathering", + }, + { + name: "fact gathering with invalid control character", + output: "Invalid control character at: line 1 column 2", + wantType: ErrFactGathering, + wantMsg: "fact gathering", + }, + { + name: "fact gathering with module deserialization", + output: "Module result deserialization failed for some task", + wantType: ErrFactGathering, + wantMsg: "fact gathering", + }, + { + name: "fact gathering with legacy setup failure", + output: "modules failed to execute: ansible.legacy.setup", + wantType: ErrFactGathering, + wantMsg: "fact gathering", + }, + { + name: "network adapter not found", + output: "No MSFT_NetAdapter objects found with property 'Name' equal to 'Ethernet3'", + wantType: ErrNetworkAdapter, + wantMsg: "network adapter", + }, + { + name: "SSM file transfer error", + output: "failed to transfer file to remote host", + wantType: ErrSSMTransfer, + wantMsg: "SSM file transfer", + }, + { + name: "SSM target not connected", + output: "An error occurred (TargetNotConnected) when calling the SendCommand operation", + wantType: ErrSSMReconnection, + wantMsg: "SSM target not connected", + }, + { + name: "SSM host not connected", + output: "host is not connected to SSM", + wantType: ErrSSMReconnection, + wantMsg: "not connected", + }, + { + name: "reboot timeout", + output: "Timed out waiting for last boot time", + wantType: ErrSSMReconnection, + wantMsg: "reboot timeout", + }, + { + name: "timeout waiting for reboot", + output: "timeout waiting for system to reboot", + wantType: ErrSSMReconnection, + wantMsg: "reboot timeout", + }, + { + name: "PowerShell non-interactive mode", + output: "Windows PowerShell is in NonInteractive mode. Read and Prompt", + wantType: ErrPowerShell, + wantMsg: "PowerShell interactive", + }, + { + name: "SSM user disabled", + output: "The ssm-user account is disabled on this instance", + wantType: ErrSSMUserAccount, + wantMsg: "SSM user account", + }, + { + name: "MSI installer rc 1603", + output: "fatal: [DC01]: FAILED! => {\"changed\": true, \"rc: 1603\"}", + wantType: ErrMSIInstaller, + wantMsg: "MSI installer", + }, + { + name: "MSI installer rc 3010", + output: "fatal: [DC01]: FAILED! => {\"changed\": true, \"rc: 3010\"}", + wantType: ErrMSIInstaller, + wantMsg: "MSI installer", + }, + { + name: "unclassified error with fatal line", + output: "fatal: [DC01]: FAILED! => {\"msg\": \"some unknown error\"}", + wantType: ErrUnclassified, + }, + { + name: "unclassified error no fatal", + output: "something went wrong", + wantType: ErrUnclassified, + wantMsg: "unknown error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotType, gotMsg := DetectErrorType(tt.output) + if gotType != tt.wantType { + t.Errorf("DetectErrorType() type = %q, want %q", gotType, tt.wantType) + } + if tt.wantMsg != "" { + if len(gotMsg) == 0 { + t.Errorf("DetectErrorType() msg is empty, want to contain %q", tt.wantMsg) + } + } + }) + } +} + +func TestContainsAny(t *testing.T) { + tests := []struct { + name string + s string + patterns []string + want bool + }{ + { + name: "plain string match", + s: "hello world", + patterns: []string{"world"}, + want: true, + }, + { + name: "no match", + s: "hello world", + patterns: []string{"foo", "bar"}, + want: false, + }, + { + name: "regex pattern match", + s: "ssm-user is disabled", + patterns: []string{"ssm-user.*disabled"}, + want: true, + }, + { + name: "regex pattern no match", + s: "ssm-user is active", + patterns: []string{"ssm-user.*disabled"}, + want: false, + }, + { + name: "multiple patterns first matches", + s: "error: TargetNotConnected", + patterns: []string{"TargetNotConnected", "is not connected"}, + want: true, + }, + { + name: "multiple patterns second matches", + s: "host is not connected", + patterns: []string{"TargetNotConnected", "is not connected"}, + want: true, + }, + { + name: "empty input", + s: "", + patterns: []string{"foo"}, + want: false, + }, + { + name: "empty patterns", + s: "hello", + patterns: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := containsAny(tt.s, tt.patterns...) + if got != tt.want { + t.Errorf("containsAny() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExtractFatalContext(t *testing.T) { + tests := []struct { + name string + output string + want string + }{ + { + name: "no fatal line", + output: "TASK [some task]\nok: [DC01]\n", + want: "unknown error", + }, + { + name: "fatal line with msg", + output: "TASK [failing]\nfatal: [DC01]: FAILED! => {\"msg\": \"broken\"}\nrc: 1", + want: "fatal: [DC01]: FAILED! => {\"msg\": \"broken\"}\nrc: 1", + }, + { + name: "fatal line truncated to 120 chars", + output: "fatal: " + string(make([]byte, 200)), + want: "fatal: " + string(make([]byte, 113)), + }, + { + name: "FAILED in last line as fallback", + output: "some output\nTask FAILED with error", + want: "Task FAILED with error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFatalContext(tt.output) + if got != tt.want { + t.Errorf("extractFatalContext() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/cli/internal/ansible/logparser.go b/cli/internal/ansible/logparser.go new file mode 100644 index 00000000..4e4b433f --- /dev/null +++ b/cli/internal/ansible/logparser.go @@ -0,0 +1,62 @@ +package ansible + +import ( + "regexp" + "strings" +) + +var ( + failedRe = regexp.MustCompile(`failed=[1-9][0-9]*`) + unreachableRe = regexp.MustCompile(`unreachable=[1-9][0-9]*`) + failedHostRe = regexp.MustCompile(`(?m)^([a-zA-Z0-9_-]+)\s+:.*failed=[1-9]`) +) + +// CheckAnsibleSuccess analyzes Ansible output to determine if the run succeeded. +// Returns true if no failures detected. +func CheckAnsibleSuccess(output string) bool { + // Primary: check PLAY RECAP for failures + if idx := strings.Index(output, "PLAY RECAP"); idx >= 0 { + recap := output[idx:] + if failedRe.MatchString(recap) || unreachableRe.MatchString(recap) { + return false + } + } + + // Secondary: check for fatal errors not followed by "...ignoring" + lines := strings.Split(output, "\n") + for i, line := range lines { + if strings.HasPrefix(line, "fatal:") { + // Check next 10 lines for "...ignoring" + end := i + 11 + if end > len(lines) { + end = len(lines) + } + context := strings.Join(lines[i:end], "\n") + if !strings.Contains(context, "...ignoring") { + return false + } + } + } + + // Check for retry indicator + if strings.Contains(output, "to retry, use:") { + return false + } + + return true +} + +// ExtractFailedHosts parses PLAY RECAP to find hosts with failures. +func ExtractFailedHosts(output string) []string { + matches := failedHostRe.FindAllStringSubmatch(output, -1) + var hosts []string + seen := make(map[string]bool) + for _, m := range matches { + host := m[1] + if !seen[host] { + seen[host] = true + hosts = append(hosts, host) + } + } + return hosts +} diff --git a/cli/internal/ansible/logparser_test.go b/cli/internal/ansible/logparser_test.go new file mode 100644 index 00000000..22caf3ac --- /dev/null +++ b/cli/internal/ansible/logparser_test.go @@ -0,0 +1,142 @@ +package ansible + +import ( + "testing" +) + +func TestCheckAnsibleSuccess(t *testing.T) { + tests := []struct { + name string + output string + want bool + }{ + { + name: "all hosts ok", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=15 changed=3 unreachable=0 failed=0 skipped=2 rescued=0 ignored=0 +DC02 : ok=12 changed=1 unreachable=0 failed=0 skipped=1 rescued=0 ignored=0`, + want: true, + }, + { + name: "host with failures in recap", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=10 changed=2 unreachable=0 failed=3 skipped=1 rescued=0 ignored=0`, + want: false, + }, + { + name: "host unreachable in recap", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=0 changed=0 unreachable=1 failed=0 skipped=0 rescued=0 ignored=0`, + want: false, + }, + { + name: "fatal error followed by ignoring", + output: `TASK [some task] +fatal: [DC01]: FAILED! => {"msg": "non-critical error"} +...ignoring +PLAY RECAP ********************************************************************* +DC01 : ok=10 changed=2 unreachable=0 failed=0 skipped=1 rescued=0 ignored=1`, + want: true, + }, + { + name: "fatal error not ignored", + output: `TASK [some task] +fatal: [DC01]: FAILED! => {"msg": "critical error"} +NO MORE HOSTS LEFT *************************************************************`, + want: false, + }, + { + name: "retry indicator present", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=5 changed=0 unreachable=0 failed=0 skipped=0 rescued=0 ignored=0 +to retry, use: --limit @/path/to/retry.yml`, + want: false, + }, + { + name: "empty output", + output: "", + want: true, + }, + { + name: "no recap section", + output: "TASK [Gathering Facts]\nok: [DC01]", + want: true, + }, + { + name: "failed=10 double digits", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=5 changed=0 unreachable=0 failed=10 skipped=0 rescued=0 ignored=0`, + want: false, + }, + { + name: "failed=0 should pass", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=5 changed=0 unreachable=0 failed=0 skipped=0 rescued=0 ignored=0`, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CheckAnsibleSuccess(tt.output) + if got != tt.want { + t.Errorf("CheckAnsibleSuccess() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExtractFailedHosts(t *testing.T) { + tests := []struct { + name string + output string + want []string + }{ + { + name: "single failed host", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=10 changed=2 unreachable=0 failed=3 skipped=1 rescued=0 ignored=0 +DC02 : ok=15 changed=3 unreachable=0 failed=0 skipped=2 rescued=0 ignored=0`, + want: []string{"DC01"}, + }, + { + name: "multiple failed hosts", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=10 changed=2 unreachable=0 failed=3 skipped=1 rescued=0 ignored=0 +DC02 : ok=5 changed=1 unreachable=0 failed=1 skipped=0 rescued=0 ignored=0 +SRV01 : ok=15 changed=3 unreachable=0 failed=0 skipped=2 rescued=0 ignored=0`, + want: []string{"DC01", "DC02"}, + }, + { + name: "no failed hosts", + output: `PLAY RECAP ********************************************************************* +DC01 : ok=10 changed=2 unreachable=0 failed=0 skipped=1 rescued=0 ignored=0`, + want: nil, + }, + { + name: "empty output", + output: "", + want: nil, + }, + { + name: "deduplicated hosts", + output: `DC01 : ok=10 changed=2 unreachable=0 failed=3 skipped=1 rescued=0 ignored=0 +DC01 : ok=5 changed=0 unreachable=0 failed=1 skipped=0 rescued=0 ignored=0`, + want: []string{"DC01"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractFailedHosts(tt.output) + if len(got) != len(tt.want) { + t.Fatalf("ExtractFailedHosts() returned %d hosts %v, want %d hosts %v", len(got), got, len(tt.want), tt.want) + } + for i, host := range got { + if host != tt.want[i] { + t.Errorf("ExtractFailedHosts()[%d] = %q, want %q", i, host, tt.want[i]) + } + } + }) + } +} diff --git a/cli/internal/ansible/prepare.go b/cli/internal/ansible/prepare.go new file mode 100644 index 00000000..3c4373c2 --- /dev/null +++ b/cli/internal/ansible/prepare.go @@ -0,0 +1,40 @@ +package ansible + +import ( + "log/slog" + "os" + "os/exec" + "path/filepath" +) + +// PrepareADCSZips creates the ADCSTemplate.zip files needed by ADCS roles. +func PrepareADCSZips(projectRoot string) error { + dirs := []string{ + filepath.Join(projectRoot, "ansible", "roles", "adcs_templates", "files"), + filepath.Join(projectRoot, "ansible", "roles", "vulns_adcs_templates", "files"), + } + + for _, dir := range dirs { + zipPath := filepath.Join(dir, "ADCSTemplate.zip") + templateDir := filepath.Join(dir, "ADCSTemplate") + + // Skip if zip already exists + if _, err := os.Stat(zipPath); err == nil { + continue + } + + // Skip if template dir doesn't exist + if _, err := os.Stat(templateDir); os.IsNotExist(err) { + continue + } + + slog.Info("creating ADCS template zip", "dir", dir) + cmd := exec.Command("zip", "-r", "ADCSTemplate.zip", "ADCSTemplate/") + cmd.Dir = dir + if output, err := cmd.CombinedOutput(); err != nil { + slog.Warn("failed to create ADCS zip", "dir", dir, "error", err, "output", string(output)) + return err + } + } + return nil +} diff --git a/cli/internal/ansible/retry.go b/cli/internal/ansible/retry.go new file mode 100644 index 00000000..dbc5c3ba --- /dev/null +++ b/cli/internal/ansible/retry.go @@ -0,0 +1,300 @@ +package ansible + +import ( + "context" + "fmt" + "log/slog" + "os/exec" + "path/filepath" + "strings" + "time" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/dreadnode/dreadgoad/internal/config" + "github.com/dreadnode/dreadgoad/internal/inventory" +) + +// RetryOptions configures the retry behavior for playbook execution. +type RetryOptions struct { + Playbook string + Env string + Limit string + Debug bool + MaxRetries int + RetryDelay time.Duration + LogFile string + Log *slog.Logger // optional; falls back to slog.Default() +} + +func (o *RetryOptions) logger() *slog.Logger { + if o.Log != nil { + return o.Log + } + return slog.Default() +} + +// RunPlaybookWithRetry runs a playbook with error-specific retry logic. +func RunPlaybookWithRetry(ctx context.Context, opts RetryOptions) error { + cfg := config.Get() + log := opts.logger() + + if opts.MaxRetries == 0 { + opts.MaxRetries = cfg.MaxRetries + } + if opts.RetryDelay == 0 { + opts.RetryDelay = time.Duration(cfg.RetryDelay) * time.Second + } + + for attempt := range opts.MaxRetries { + if attempt > 0 { + log.Info("retry attempt", "attempt", attempt, "playbook", opts.Playbook) + log.Info("waiting before retry", "delay", opts.RetryDelay) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(opts.RetryDelay): + } + } + + log.Info("starting playbook", "playbook", opts.Playbook, "attempt", attempt+1, "max", opts.MaxRetries) + + result := RunPlaybook(ctx, RunOptions{ + Playbook: opts.Playbook, + Env: opts.Env, + Limit: opts.Limit, + Debug: opts.Debug, + LogFile: opts.LogFile, + }) + + if result.TimedOut { + log.Error("playbook timed out (idle timeout)", "playbook", opts.Playbook) + cleanupSSMSessions(ctx, opts.Env, log) + continue + } + + if result.Success { + log.Info("playbook completed successfully", "playbook", opts.Playbook) + return nil + } + + log.Warn("playbook failed", "playbook", opts.Playbook, + "error_type", result.ErrorType, "detail", result.ErrorDetail, + "failed_hosts", result.FailedHosts) + + retryResult := retryWithErrorStrategy(ctx, opts, result, log) + if retryResult != nil && retryResult.Success { + log.Info("playbook succeeded after error-specific retry", "playbook", opts.Playbook) + return nil + } + } + + return fmt.Errorf("playbook %s failed after %d attempts", opts.Playbook, opts.MaxRetries) +} + +func retryWithErrorStrategy(ctx context.Context, opts RetryOptions, failResult *RunResult, log *slog.Logger) *RunResult { + failedHostsStr := strings.Join(failResult.FailedHosts, ",") + limit := buildRetryLimit(opts.Limit, failedHostsStr) + + baseOpts := RunOptions{ + Playbook: opts.Playbook, + Env: opts.Env, + Limit: limit, + Debug: opts.Debug, + LogFile: opts.LogFile, + } + + switch failResult.ErrorType { + case ErrFactGathering: + log.Info("retrying with modified fact gathering settings") + baseOpts.Forks = 1 + baseOpts.ExtraVars = map[string]string{ + "ansible_facts_gathering_timeout": "60", + "gather_timeout": "60", + } + baseOpts.ExtraEnv = map[string]string{ + "ANSIBLE_GATHERING": "explicit", + } + return RunPlaybook(ctx, baseOpts) + + case ErrNetworkAdapter: + log.Info("retrying with network adapter fix") + baseOpts.ExtraVars = map[string]string{ + "skip_network_adapter_config": "true", + "bypass_ethernet3_check": "true", + } + return RunPlaybook(ctx, baseOpts) + + case ErrSSMTransfer: + log.Info("SSM transfer error - fixing ssm-user accounts") + cleanupSSMSessions(ctx, opts.Env, log) + fixSSMUsers(ctx, opts.Env, failResult.FailedHosts, log) + log.Info("waiting for SSM Agent to stabilize", "delay", "30s") + time.Sleep(30 * time.Second) + + baseOpts.Forks = 1 + baseOpts.ExtraVars = map[string]string{ + "ansible_aws_ssm_retries": "10", + "ansible_aws_ssm_retry_delay": "30", + "ansible_connection_timeout": "300", + "ansible_command_timeout": "300", + "ansible_aws_ssm_timeout": "300", + } + baseOpts.ExtraEnv = map[string]string{"ANSIBLE_TIMEOUT": "300"} + return RunPlaybook(ctx, baseOpts) + + case ErrSSMReconnection: + log.Info("SSM reconnection needed - waiting for systems to reboot") + cleanupSSMSessions(ctx, opts.Env, log) + log.Info("waiting for Windows reboot and SSM reconnection", "delay", "120s") + time.Sleep(120 * time.Second) + + fixSSMUsers(ctx, opts.Env, failResult.FailedHosts, log) + time.Sleep(10 * time.Second) + + baseOpts.Forks = 1 + baseOpts.ExtraVars = map[string]string{ + "ansible_connection_timeout": "180", + "ansible_timeout": "180", + "ansible_facts_gathering_timeout": "60", + } + baseOpts.ExtraEnv = map[string]string{"ANSIBLE_TIMEOUT": "180"} + return RunPlaybook(ctx, baseOpts) + + case ErrPowerShell: + log.Info("retrying with PowerShell interactive mode fix") + baseOpts.ExtraVars = map[string]string{ + "ansible_shell_type": "powershell", + "force_ps_module": "true", + "ansible_ps_version": "5.1", + } + return RunPlaybook(ctx, baseOpts) + + case ErrSSMUserAccount: + log.Info("SSM user account issue - recreating as domain account") + fixSSMUsers(ctx, opts.Env, failResult.FailedHosts, log) + log.Info("waiting for SSM Agent to stabilize", "delay", "30s") + time.Sleep(30 * time.Second) + + baseOpts.Forks = 1 + baseOpts.ExtraVars = map[string]string{ + "ansible_connection_timeout": "180", + "ansible_timeout": "180", + "ansible_aws_ssm_timeout": "300", + } + baseOpts.ExtraEnv = map[string]string{"ANSIBLE_TIMEOUT": "180"} + return RunPlaybook(ctx, baseOpts) + + case ErrMSIInstaller: + log.Info("MSI installer error - rebooting failed hosts before retry") + rebootFailedHosts(ctx, opts, log) + time.Sleep(30 * time.Second) + + baseOpts.Forks = 1 + return RunPlaybook(ctx, baseOpts) + + default: + log.Info("retrying with general robust settings") + baseOpts.Forks = 1 + baseOpts.ExtraEnv = map[string]string{ + "ANSIBLE_SSH_RETRIES": "5", + "ANSIBLE_TIMEOUT": "120", + } + return RunPlaybook(ctx, baseOpts) + } +} + +func buildRetryLimit(userLimit, failedHosts string) string { + switch { + case userLimit != "" && failedHosts != "": + return userLimit + "," + failedHosts + case userLimit != "": + return userLimit + default: + return failedHosts + } +} + +func cleanupSSMSessions(ctx context.Context, env string, log *slog.Logger) { + cfg := config.Get() + inv, err := inventory.Parse(cfg.InventoryPath()) + if err != nil { + log.Warn("could not parse inventory for SSM cleanup", "error", err) + return + } + + client, err := daws.NewClient(ctx, inv.Region()) + if err != nil { + log.Warn("could not create AWS client for SSM cleanup", "error", err) + return + } + + terminated, err := client.CleanupStaleSessions(ctx, inv.InstanceIDs(), 15*time.Minute, false, log) + if err != nil { + log.Warn("SSM cleanup error", "error", err) + } + if terminated > 0 { + log.Info("terminated stale SSM sessions", "count", terminated) + time.Sleep(5 * time.Second) + } +} + +func fixSSMUsers(ctx context.Context, env string, failedHosts []string, log *slog.Logger) { + if len(failedHosts) == 0 { + return + } + + cfg := config.Get() + inv, err := inventory.Parse(cfg.InventoryPath()) + if err != nil { + log.Warn("could not parse inventory for ssm-user fix", "error", err) + return + } + + client, err := daws.NewClient(ctx, inv.Region()) + if err != nil { + log.Warn("could not create AWS client for ssm-user fix", "error", err) + return + } + + for _, hostName := range failedHosts { + host := inv.HostByName(hostName) + if host == nil || host.InstanceID == "" { + log.Warn("host not found in inventory", "host", hostName) + continue + } + + log.Info("fixing ssm-user", "host", hostName, "instance", host.InstanceID) + + if err := client.EnableSSMUserLocal(ctx, host.InstanceID); err != nil { + log.Info("local enable failed, trying domain account fix", "host", hostName) + if err := client.FixSSMUserViaDomainAccount(ctx, host.InstanceID); err != nil { + log.Warn("ssm-user fix failed", "host", hostName, "error", err) + } + } + } +} + +func rebootFailedHosts(ctx context.Context, opts RetryOptions, log *slog.Logger) { + cfg := config.Get() + for _, host := range strings.Split(opts.Limit, ",") { + if host == "" { + continue + } + log.Info("rebooting host before retry", "host", host) + args := []string{ + host, "-i", filepath.Join(cfg.ProjectRoot, opts.Env+"-inventory"), + "-m", "ansible.windows.win_reboot", + "-a", "reboot_timeout=600 post_reboot_delay=60", + } + rebootCmd := execCommand(ctx, "ansible", args...) + rebootCmd.Dir = cfg.ProjectRoot + rebootCmd.Env = buildEnv(RunOptions{Env: opts.Env}, cfg) + if output, err := rebootCmd.CombinedOutput(); err != nil { + log.Warn("reboot failed", "host", host, "error", err, "output", string(output)) + } + } +} + +// execCommand is a variable for testability. +var execCommand = exec.CommandContext diff --git a/cli/internal/ansible/runner.go b/cli/internal/ansible/runner.go new file mode 100644 index 00000000..27dae0b8 --- /dev/null +++ b/cli/internal/ansible/runner.go @@ -0,0 +1,228 @@ +package ansible + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" + "sync/atomic" + "syscall" + "time" + + "github.com/dreadnode/dreadgoad/internal/config" +) + +// RunOptions configures a single ansible-playbook execution. +type RunOptions struct { + Playbook string + Env string + Limit string + Forks int + ExtraVars map[string]string + ExtraEnv map[string]string + Debug bool + IdleTimeout time.Duration + LogFile string +} + +// RunResult holds the outcome of an ansible-playbook execution. +type RunResult struct { + ExitCode int + Output string + Success bool + FailedHosts []string + ErrorType ErrorType + ErrorDetail string + TimedOut bool +} + +// RunPlaybook executes ansible-playbook with idle timeout monitoring. +func RunPlaybook(ctx context.Context, opts RunOptions) *RunResult { + cfg := config.Get() + result := &RunResult{} + + args := buildArgs(opts, cfg) + env := buildEnv(opts, cfg) + + slog.Info("running playbook", "playbook", opts.Playbook, "args", strings.Join(args, " ")) + + cmd := exec.CommandContext(ctx, "ansible-playbook", args...) + cmd.Env = env + cmd.Dir = cfg.ProjectRoot + // Set process group so we can kill the entire tree + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Capture output while streaming to stdout and log file + var outputBuf bytes.Buffer + writers := []io.Writer{&outputBuf, os.Stdout} + + if opts.LogFile != "" { + if f, err := os.OpenFile(opts.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644); err == nil { + writers = append(writers, f) + defer func() { _ = f.Close() }() + } + } + + multiW := io.MultiWriter(writers...) + + stdout, err := cmd.StdoutPipe() + if err != nil { + result.ExitCode = 1 + result.Output = fmt.Sprintf("failed to create stdout pipe: %v", err) + return result + } + cmd.Stderr = cmd.Stdout // merge stderr into stdout + + if err := cmd.Start(); err != nil { + result.ExitCode = 1 + result.Output = fmt.Sprintf("failed to start ansible-playbook: %v", err) + return result + } + + // Monitor output with idle timeout + var bytesWritten atomic.Int64 + idleTimeout := opts.IdleTimeout + if idleTimeout == 0 { + idleTimeout = time.Duration(cfg.IdleTimeout) * time.Second + } + + // Stream output in a goroutine + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + _, _ = fmt.Fprintln(multiW, line) + bytesWritten.Add(int64(len(line))) + } + }() + + timedOut := monitorIdleTimeout(ctx, &bytesWritten, idleTimeout, cmd.Process.Pid, doneCh) + + <-doneCh + err = cmd.Wait() + + output := outputBuf.String() + result.Output = output + result.TimedOut = *timedOut + + if *timedOut { + result.ExitCode = 124 + return result + } + + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + result.ExitCode = exitErr.ExitCode() + } else { + result.ExitCode = 1 + } + } + + result.Success = result.ExitCode == 0 && CheckAnsibleSuccess(output) + + if !result.Success { + result.FailedHosts = ExtractFailedHosts(output) + result.ErrorType, result.ErrorDetail = DetectErrorType(output) + } + + return result +} + +// monitorIdleTimeout watches for output stalls and kills the process if idle too long. +// Returns a pointer to a bool that is set to true if the process was killed. +func monitorIdleTimeout(ctx context.Context, bytesWritten *atomic.Int64, timeout time.Duration, pid int, doneCh <-chan struct{}) *bool { + timedOut := new(bool) + go func() { + lastBytes := bytesWritten.Load() + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + lastActivity := time.Now() + + for { + select { + case <-doneCh: + return + case <-ctx.Done(): + return + case <-ticker.C: + current := bytesWritten.Load() + if current > lastBytes { + lastBytes = current + lastActivity = time.Now() + } else if time.Since(lastActivity) > timeout { + slog.Error("idle timeout reached, killing playbook", + "timeout", timeout, "pid", pid) + *timedOut = true + killProcessGroup(pid) + return + } + } + } + }() + return timedOut +} + +func buildArgs(opts RunOptions, cfg *config.Config) []string { + inventoryPath := filepath.Join(cfg.ProjectRoot, opts.Env+"-inventory") + playbookPath := filepath.Join(cfg.ProjectRoot, "ansible", "playbooks", opts.Playbook) + + args := []string{ + "-i", inventoryPath, + "-e", "ansible_facts_gathering_timeout=60", + playbookPath, + } + + if opts.Debug { + args = append([]string{"-vvv"}, args...) + } + + if opts.Limit != "" { + args = append(args, "--limit", opts.Limit) + } + + if opts.Forks > 0 { + args = append(args, "--forks", fmt.Sprintf("%d", opts.Forks)) + } + + for k, v := range opts.ExtraVars { + args = append(args, "-e", k+"="+v) + } + + return args +} + +func buildEnv(opts RunOptions, cfg *config.Config) []string { + env := os.Environ() + + for k, v := range cfg.AnsibleEnv() { + env = append(env, k+"="+v) + } + + for k, v := range opts.ExtraEnv { + env = append(env, k+"="+v) + } + + return env +} + +func killProcessGroup(pid int) { + pgid, err := syscall.Getpgid(pid) + if err == nil { + _ = syscall.Kill(-pgid, syscall.SIGTERM) + time.Sleep(2 * time.Second) + _ = syscall.Kill(-pgid, syscall.SIGKILL) + } else { + _ = syscall.Kill(pid, syscall.SIGKILL) + } +} diff --git a/cli/internal/aws/client.go b/cli/internal/aws/client.go new file mode 100644 index 00000000..35c981ef --- /dev/null +++ b/cli/internal/aws/client.go @@ -0,0 +1,52 @@ +package aws + +import ( + "context" + "fmt" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ssm" +) + +// Client wraps AWS SDK clients for EC2 and SSM. +type Client struct { + EC2 *ec2.Client + SSM *ssm.Client + Region string +} + +var ( + clients = make(map[string]*Client) + mu sync.Mutex +) + +// NewClient creates or returns a cached AWS client for the given region. +func NewClient(ctx context.Context, region string) (*Client, error) { + mu.Lock() + defer mu.Unlock() + + if c, ok := clients[region]; ok { + return c, nil + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("load AWS config for %s: %w", region, err) + } + + c := &Client{ + EC2: ec2.NewFromConfig(cfg), + SSM: ssm.NewFromConfig(cfg), + Region: region, + } + clients[region] = c + return c, nil +} + +// Ptr returns a pointer to the given string (helper for AWS SDK). +func Ptr(s string) *string { + return aws.String(s) +} diff --git a/cli/internal/aws/ec2.go b/cli/internal/aws/ec2.go new file mode 100644 index 00000000..5cf5f577 --- /dev/null +++ b/cli/internal/aws/ec2.go @@ -0,0 +1,106 @@ +package aws + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +// Instance represents a discovered EC2 instance. +type Instance struct { + InstanceID string + Name string + PrivateIP string + State string +} + +// DiscoverInstances finds running GOAD instances by tag pattern. +func (c *Client) DiscoverInstances(ctx context.Context, env string) ([]Instance, error) { + pattern := fmt.Sprintf("*%s*dreadgoad*", env) + out, err := c.EC2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + Filters: []types.Filter{ + {Name: Ptr("tag:Name"), Values: []string{pattern}}, + {Name: Ptr("instance-state-name"), Values: []string{"running"}}, + }, + }) + if err != nil { + return nil, fmt.Errorf("describe instances: %w", err) + } + + var instances []Instance + for _, r := range out.Reservations { + for _, i := range r.Instances { + inst := Instance{ + InstanceID: deref(i.InstanceId), + PrivateIP: deref(i.PrivateIpAddress), + State: string(i.State.Name), + } + for _, t := range i.Tags { + if deref(t.Key) == "Name" { + inst.Name = deref(t.Value) + } + } + instances = append(instances, inst) + } + } + return instances, nil +} + +// GetInstancePrivateIPs queries EC2 for private IPs of the given instance IDs. +func (c *Client) GetInstancePrivateIPs(ctx context.Context, instanceIDs []string) (map[string]string, error) { + out, err := c.EC2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: instanceIDs, + }) + if err != nil { + return nil, fmt.Errorf("describe instances: %w", err) + } + + mapping := make(map[string]string) + for _, r := range out.Reservations { + for _, i := range r.Instances { + mapping[deref(i.InstanceId)] = deref(i.PrivateIpAddress) + } + } + return mapping, nil +} + +// StartInstances starts the given EC2 instances. +func (c *Client) StartInstances(ctx context.Context, instanceIDs []string) error { + _, err := c.EC2.StartInstances(ctx, &ec2.StartInstancesInput{ + InstanceIds: instanceIDs, + }) + return err +} + +// StopInstances stops the given EC2 instances. +func (c *Client) StopInstances(ctx context.Context, instanceIDs []string) error { + _, err := c.EC2.StopInstances(ctx, &ec2.StopInstancesInput{ + InstanceIds: instanceIDs, + }) + return err +} + +// FindInstanceByHostname finds an instance whose Name tag contains the hostname. +func (c *Client) FindInstanceByHostname(ctx context.Context, env, hostname string) (*Instance, error) { + instances, err := c.DiscoverInstances(ctx, env) + if err != nil { + return nil, err + } + hostname = strings.ToUpper(hostname) + for _, inst := range instances { + if strings.Contains(strings.ToUpper(inst.Name), hostname) { + return &inst, nil + } + } + return nil, fmt.Errorf("instance not found for hostname %s", hostname) +} + +func deref(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/cli/internal/aws/ssm.go b/cli/internal/aws/ssm.go new file mode 100644 index 00000000..c06a6589 --- /dev/null +++ b/cli/internal/aws/ssm.go @@ -0,0 +1,264 @@ +package aws + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// Session represents an active SSM session. +type Session struct { + SessionID string + InstanceID string + StartDate time.Time + Status string +} + +// CommandResult holds the output of an SSM command. +type CommandResult struct { + Status string + Stdout string + Stderr string +} + +// DescribeActiveSessions returns active SSM sessions for an instance. +func (c *Client) DescribeActiveSessions(ctx context.Context, instanceID string) ([]Session, error) { + out, err := c.SSM.DescribeSessions(ctx, &ssm.DescribeSessionsInput{ + State: ssmtypes.SessionStateActive, + Filters: []ssmtypes.SessionFilter{ + {Key: ssmtypes.SessionFilterKeyTargetId, Value: Ptr(instanceID)}, + }, + }) + if err != nil { + return nil, fmt.Errorf("describe sessions for %s: %w", instanceID, err) + } + + var sessions []Session + for _, s := range out.Sessions { + sess := Session{ + SessionID: deref(s.SessionId), + InstanceID: deref(s.Target), + Status: string(s.Status), + } + if s.StartDate != nil { + sess.StartDate = *s.StartDate + } + sessions = append(sessions, sess) + } + return sessions, nil +} + +// CleanupStaleSessions terminates SSM sessions older than maxAge for the given instances. +func (c *Client) CleanupStaleSessions(ctx context.Context, instanceIDs []string, maxAge time.Duration, dryRun bool, log *slog.Logger) (int, error) { + cutoff := time.Now().UTC().Add(-maxAge) + terminated := 0 + + for _, instanceID := range instanceIDs { + sessions, err := c.DescribeActiveSessions(ctx, instanceID) + if err != nil { + log.Warn("failed to describe sessions", "instance", instanceID, "error", err) + continue + } + + for _, s := range sessions { + if s.StartDate.Before(cutoff) { + if dryRun { + log.Info("would terminate session", "session", s.SessionID, "instance", instanceID, "started", s.StartDate) + } else { + if err := c.TerminateSession(ctx, s.SessionID); err != nil { + log.Warn("failed to terminate session", "session", s.SessionID, "error", err) + } else { + log.Info("terminated session", "session", s.SessionID, "instance", instanceID) + terminated++ + } + } + } + } + } + return terminated, nil +} + +// TerminateSession ends an SSM session. +func (c *Client) TerminateSession(ctx context.Context, sessionID string) error { + _, err := c.SSM.TerminateSession(ctx, &ssm.TerminateSessionInput{ + SessionId: Ptr(sessionID), + }) + return err +} + +// RunPowerShellCommand executes a PowerShell command on an instance via SSM and returns the result. +func (c *Client) RunPowerShellCommand(ctx context.Context, instanceID, command string, timeout time.Duration) (*CommandResult, error) { + timeoutSecs := int32(timeout.Seconds()) + if timeoutSecs == 0 { + timeoutSecs = 60 + } + + out, err := c.SSM.SendCommand(ctx, &ssm.SendCommandInput{ + InstanceIds: []string{instanceID}, + DocumentName: Ptr("AWS-RunPowerShellScript"), + Parameters: map[string][]string{"commands": {command}}, + TimeoutSeconds: aws.Int32(timeoutSecs), + }) + if err != nil { + return nil, fmt.Errorf("send command to %s: %w", instanceID, err) + } + + commandID := deref(out.Command.CommandId) + return c.waitForCommand(ctx, commandID, instanceID, timeout) +} + +// RunPowerShellOnMultiple executes a PowerShell command on multiple instances. +func (c *Client) RunPowerShellOnMultiple(ctx context.Context, instanceIDs []string, command string, timeout time.Duration) (map[string]*CommandResult, error) { + timeoutSecs := int32(timeout.Seconds()) + if timeoutSecs == 0 { + timeoutSecs = 60 + } + + out, err := c.SSM.SendCommand(ctx, &ssm.SendCommandInput{ + InstanceIds: instanceIDs, + DocumentName: Ptr("AWS-RunPowerShellScript"), + Parameters: map[string][]string{"commands": {command}}, + TimeoutSeconds: aws.Int32(timeoutSecs), + }) + if err != nil { + return nil, fmt.Errorf("send command: %w", err) + } + + commandID := deref(out.Command.CommandId) + results := make(map[string]*CommandResult, len(instanceIDs)) + + for _, id := range instanceIDs { + result, err := c.waitForCommand(ctx, commandID, id, timeout) + if err != nil { + results[id] = &CommandResult{Status: "Error", Stderr: err.Error()} + } else { + results[id] = result + } + } + return results, nil +} + +// EnableSSMUserLocal re-enables the local ssm-user account. +func (c *Client) EnableSSMUserLocal(ctx context.Context, instanceID string) error { + cmd := `try { Enable-LocalUser -Name ssm-user -ErrorAction Stop; Write-Output "ssm-user enabled" } catch { Write-Output "Failed: $_"; exit 1 }` + result, err := c.RunPowerShellCommand(ctx, instanceID, cmd, 60*time.Second) + if err != nil { + return err + } + if result.Status != "Success" { + return fmt.Errorf("enable ssm-user failed: %s", result.Stderr) + } + return nil +} + +// FixSSMUserViaDomainAccount creates ssm-user as a domain account on DCs. +func (c *Client) FixSSMUserViaDomainAccount(ctx context.Context, instanceID string) error { + script := `$ErrorActionPreference = "Continue" +$maxWait = 30 +$attempt = 0 + +$cs = Get-WmiObject Win32_ComputerSystem +if ($cs.DomainRole -lt 4) { + Write-Output "Not a DC (role=$($cs.DomainRole)), skipping domain ssm-user creation" + exit 0 +} + +Write-Output "Waiting for AD Web Services..." +while ($attempt -lt $maxWait) { + $adws = Get-Service ADWS -ErrorAction SilentlyContinue + if ($adws.Status -eq "Running") { + Write-Output "ADWS is running" + break + } + if ($adws.Status -eq "Stopped") { + Start-Service ADWS -ErrorAction SilentlyContinue + } + Start-Sleep -Seconds 10 + $attempt++ +} + +try { + Get-ADDomain -ErrorAction Stop | Out-Null + Write-Output "AD is accessible" +} catch { + Write-Output "ERROR: AD not accessible: $_" + exit 1 +} + +try { + $user = Get-ADUser -Identity ssm-user -ErrorAction Stop + Write-Output "ssm-user exists, ensuring enabled..." + Enable-ADAccount -Identity ssm-user + Set-ADUser -Identity ssm-user -PasswordNeverExpires $true +} catch { + Write-Output "Creating ssm-user domain account..." + $pwd = ConvertTo-SecureString "TempP@ss$(Get-Random)!" -AsPlainText -Force + New-ADUser -Name ssm-user -AccountPassword $pwd -Enabled $true -PasswordNeverExpires $true +} + +try { + Add-ADGroupMember -Identity "Domain Admins" -Members ssm-user -ErrorAction SilentlyContinue + Write-Output "ssm-user added to Domain Admins" +} catch { + Write-Output "ssm-user already in Domain Admins or error: $_" +} + +Restart-Service AmazonSSMAgent -Force +Write-Output "SSM Agent restarted - ssm-user fix complete"` + + result, err := c.RunPowerShellCommand(ctx, instanceID, script, 10*time.Minute) + if err != nil { + return err + } + if result.Status != "Success" { + return fmt.Errorf("fix ssm-user failed: %s %s", result.Stdout, result.Stderr) + } + return nil +} + +func (c *Client) waitForCommand(ctx context.Context, commandID, instanceID string, timeout time.Duration) (*CommandResult, error) { + deadline := time.Now().Add(timeout + 30*time.Second) + backoff := 2 * time.Second + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(backoff): + } + if backoff < 10*time.Second { + backoff = backoff * 3 / 2 + } + + out, err := c.SSM.GetCommandInvocation(ctx, &ssm.GetCommandInvocationInput{ + CommandId: Ptr(commandID), + InstanceId: Ptr(instanceID), + }) + if err != nil { + if strings.Contains(err.Error(), "InvocationDoesNotExist") { + continue + } + return nil, fmt.Errorf("get command invocation: %w", err) + } + + status := string(out.Status) + switch out.Status { + case ssmtypes.CommandInvocationStatusSuccess, + ssmtypes.CommandInvocationStatusFailed, + ssmtypes.CommandInvocationStatusTimedOut, + ssmtypes.CommandInvocationStatusCancelled: + return &CommandResult{ + Status: status, + Stdout: deref(out.StandardOutputContent), + Stderr: deref(out.StandardErrorContent), + }, nil + } + } + return nil, fmt.Errorf("command %s timed out waiting for result", commandID) +} diff --git a/cli/internal/config/config.go b/cli/internal/config/config.go new file mode 100644 index 00000000..16ec99d1 --- /dev/null +++ b/cli/internal/config/config.go @@ -0,0 +1,114 @@ +package config + +import ( + "os" + "path/filepath" + "sync" + + "github.com/spf13/viper" +) + +// Config holds all CLI configuration. +type Config struct { + Env string `mapstructure:"env"` + Region string `mapstructure:"region"` + Debug bool `mapstructure:"debug"` + MaxRetries int `mapstructure:"max_retries"` + RetryDelay int `mapstructure:"retry_delay"` + IdleTimeout int `mapstructure:"idle_timeout"` + LogDir string `mapstructure:"log_dir"` + Playbooks []string `mapstructure:"playbooks"` + ProjectRoot string `mapstructure:"project_root"` +} + +var ( + cfg *Config + once sync.Once +) + +// Init initializes Viper configuration. Called by cobra.OnInitialize. +func Init() { + if cfgFile := viper.GetString("config"); cfgFile != "" { + viper.SetConfigFile(cfgFile) + } else { + home, _ := os.UserHomeDir() + viper.AddConfigPath(filepath.Join(home, ".config", "dreadgoad")) + viper.AddConfigPath(".") + viper.SetConfigName("dreadgoad") + viper.SetConfigType("yaml") + } + + viper.SetEnvPrefix("DREADGOAD") + viper.AutomaticEnv() + + setDefaults() + + // Config file is optional + _ = viper.ReadInConfig() +} + +// Get returns the current configuration, loading it once. +func Get() *Config { + once.Do(func() { + cfg = &Config{} + _ = viper.Unmarshal(cfg) + + // Resolve project root (directory containing ansible/) + if cfg.ProjectRoot == "" { + cfg.ProjectRoot = findProjectRoot() + } + + // Expand log dir + if cfg.LogDir == "" { + home, _ := os.UserHomeDir() + cfg.LogDir = filepath.Join(home, ".ansible", "logs", "goad") + } + }) + return cfg +} + +// Reset clears the cached config (for testing). +func Reset() { + once = sync.Once{} + cfg = nil +} + +// InventoryPath returns the path to the inventory file for the current env. +func (c *Config) InventoryPath() string { + return filepath.Join(c.ProjectRoot, c.Env+"-inventory") +} + +// AnsibleCfgPath returns the path to the ansible.cfg file. +func (c *Config) AnsibleCfgPath() string { + return filepath.Join(c.ProjectRoot, "ansible", "ansible.cfg") +} + +// AnsibleEnv returns environment variables needed for ansible-playbook execution. +func (c *Config) AnsibleEnv() map[string]string { + home, _ := os.UserHomeDir() + return map[string]string{ + "ANSIBLE_CONFIG": c.AnsibleCfgPath(), + "ANSIBLE_CACHE_PLUGIN_CONNECTION": filepath.Join(home, ".ansible", "cache", c.Env+"_dreadgoad_facts"), + "ANSIBLE_HOST_KEY_CHECKING": "False", + "ANSIBLE_RETRY_FILES_ENABLED": "True", + "ANSIBLE_GATHER_TIMEOUT": "60", + } +} + +func findProjectRoot() string { + // Walk up from cwd looking for ansible/ directory + dir, _ := os.Getwd() + for { + if _, err := os.Stat(filepath.Join(dir, "ansible")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + // Fallback to cwd + cwd, _ := os.Getwd() + return cwd +} diff --git a/cli/internal/config/config_test.go b/cli/internal/config/config_test.go new file mode 100644 index 00000000..57b2c884 --- /dev/null +++ b/cli/internal/config/config_test.go @@ -0,0 +1,164 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestConfigInventoryPath(t *testing.T) { + c := &Config{ProjectRoot: "/opt/goad", Env: "dev"} + got := c.InventoryPath() + want := filepath.Join("/opt/goad", "dev-inventory") + if got != want { + t.Errorf("InventoryPath() = %q, want %q", got, want) + } +} + +func TestConfigAnsibleCfgPath(t *testing.T) { + c := &Config{ProjectRoot: "/opt/goad"} + got := c.AnsibleCfgPath() + want := filepath.Join("/opt/goad", "ansible", "ansible.cfg") + if got != want { + t.Errorf("AnsibleCfgPath() = %q, want %q", got, want) + } +} + +func TestConfigAnsibleEnv(t *testing.T) { + c := &Config{ProjectRoot: "/opt/goad", Env: "staging"} + + env := c.AnsibleEnv() + + if env["ANSIBLE_CONFIG"] != c.AnsibleCfgPath() { + t.Errorf("ANSIBLE_CONFIG = %q, want %q", env["ANSIBLE_CONFIG"], c.AnsibleCfgPath()) + } + if env["ANSIBLE_HOST_KEY_CHECKING"] != "False" { + t.Errorf("ANSIBLE_HOST_KEY_CHECKING = %q, want %q", env["ANSIBLE_HOST_KEY_CHECKING"], "False") + } + if env["ANSIBLE_RETRY_FILES_ENABLED"] != "True" { + t.Errorf("ANSIBLE_RETRY_FILES_ENABLED = %q, want %q", env["ANSIBLE_RETRY_FILES_ENABLED"], "True") + } + if env["ANSIBLE_GATHER_TIMEOUT"] != "60" { + t.Errorf("ANSIBLE_GATHER_TIMEOUT = %q, want %q", env["ANSIBLE_GATHER_TIMEOUT"], "60") + } + + cacheConn := env["ANSIBLE_CACHE_PLUGIN_CONNECTION"] + if !strings.Contains(cacheConn, "staging_dreadgoad_facts") { + t.Errorf("ANSIBLE_CACHE_PLUGIN_CONNECTION = %q, want to contain %q", cacheConn, "staging_dreadgoad_facts") + } +} + +func TestConfigInventoryPathDifferentEnvs(t *testing.T) { + tests := []struct { + env string + wantSufx string + }{ + {"dev", "dev-inventory"}, + {"staging", "staging-inventory"}, + {"prod", "prod-inventory"}, + } + + for _, tt := range tests { + t.Run(tt.env, func(t *testing.T) { + c := &Config{ProjectRoot: "/opt/goad", Env: tt.env} + got := c.InventoryPath() + if !strings.HasSuffix(got, tt.wantSufx) { + t.Errorf("InventoryPath() = %q, want suffix %q", got, tt.wantSufx) + } + }) + } +} + +func TestDefaultPlaybooks(t *testing.T) { + if len(DefaultPlaybooks) == 0 { + t.Fatal("DefaultPlaybooks is empty") + } + + // First playbook should be build.yml + if DefaultPlaybooks[0] != "build.yml" { + t.Errorf("first playbook = %q, want %q", DefaultPlaybooks[0], "build.yml") + } + + // Last playbook should be vulnerabilities.yml + last := DefaultPlaybooks[len(DefaultPlaybooks)-1] + if last != "vulnerabilities.yml" { + t.Errorf("last playbook = %q, want %q", last, "vulnerabilities.yml") + } + + // All playbooks should end in .yml + for _, p := range DefaultPlaybooks { + if !strings.HasSuffix(p, ".yml") { + t.Errorf("playbook %q does not end in .yml", p) + } + } +} + +func TestRebootPlaybooks(t *testing.T) { + if len(RebootPlaybooks) == 0 { + t.Fatal("RebootPlaybooks is empty") + } + + // All reboot playbooks should be in DefaultPlaybooks + defaultSet := make(map[string]bool) + for _, p := range DefaultPlaybooks { + defaultSet[p] = true + } + for _, p := range RebootPlaybooks { + if !defaultSet[p] { + t.Errorf("RebootPlaybook %q not in DefaultPlaybooks", p) + } + } +} + +// resolveSymlinks resolves symlinks so paths are comparable on macOS +// where TempDir returns /var/... but os.Getwd returns /private/var/... +func resolveSymlinks(t *testing.T, path string) string { + t.Helper() + resolved, err := filepath.EvalSymlinks(path) + if err != nil { + t.Fatal(err) + } + return resolved +} + +func TestFindProjectRoot(t *testing.T) { + t.Run("finds ansible directory", func(t *testing.T) { + dir := resolveSymlinks(t, t.TempDir()) + ansibleDir := filepath.Join(dir, "ansible") + if err := os.Mkdir(ansibleDir, 0o755); err != nil { + t.Fatal(err) + } + + // Change to a subdirectory + subDir := filepath.Join(dir, "sub", "deep") + if err := os.MkdirAll(subDir, 0o755); err != nil { + t.Fatal(err) + } + + origDir, _ := os.Getwd() + if err := os.Chdir(subDir); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + got := findProjectRoot() + if got != dir { + t.Errorf("findProjectRoot() = %q, want %q", got, dir) + } + }) + + t.Run("falls back to cwd when no ansible dir", func(t *testing.T) { + dir := resolveSymlinks(t, t.TempDir()) + origDir, _ := os.Getwd() + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + got := findProjectRoot() + if got != dir { + t.Errorf("findProjectRoot() = %q, want %q", got, dir) + } + }) +} diff --git a/cli/internal/config/defaults.go b/cli/internal/config/defaults.go new file mode 100644 index 00000000..75a843f6 --- /dev/null +++ b/cli/internal/config/defaults.go @@ -0,0 +1,41 @@ +package config + +import "github.com/spf13/viper" + +// DefaultPlaybooks is the ordered list of all GOAD playbooks. +var DefaultPlaybooks = []string{ + "build.yml", + "ad-servers.yml", + "ad-parent_domain.yml", + "ad-child_domain.yml", + "ad-members.yml", + "ad-trusts.yml", + "ad-data.yml", + "ad-gmsa.yml", + "laps.yml", + "ad-relations.yml", + "adcs.yml", + "ad-acl.yml", + "servers.yml", + "security.yml", + "vulnerabilities.yml", +} + +// RebootPlaybooks are playbooks that may trigger Windows reboots. +var RebootPlaybooks = []string{ + "ad-parent_domain.yml", + "ad-child_domain.yml", + "ad-members.yml", + "ad-trusts.yml", +} + +func setDefaults() { + viper.SetDefault("env", "dev") + viper.SetDefault("region", "") + viper.SetDefault("debug", false) + viper.SetDefault("max_retries", 3) + viper.SetDefault("retry_delay", 30) + viper.SetDefault("idle_timeout", 1200) + viper.SetDefault("log_dir", "") + viper.SetDefault("playbooks", DefaultPlaybooks) +} diff --git a/cli/internal/doctor/checks.go b/cli/internal/doctor/checks.go new file mode 100644 index 00000000..9c3cfa2a --- /dev/null +++ b/cli/internal/doctor/checks.go @@ -0,0 +1,183 @@ +package doctor + +import ( + "fmt" + "os" + "os/exec" + "regexp" + "strconv" + "strings" + + "github.com/fatih/color" +) + +// CheckResult holds a single doctor check result. +type CheckResult struct { + Name string + Status string // pass, fail, warn + Message string +} + +// RunChecks executes all pre-flight checks and returns results. +func RunChecks(inventoryPath, projectRoot string) []CheckResult { + var results []CheckResult + + results = append(results, checkAnsibleVersion()) + results = append(results, checkCommand("aws", "AWS CLI")) + results = append(results, checkCommand("python3", "Python 3")) + results = append(results, checkCommand("jq", "jq")) + results = append(results, checkCommand("zip", "zip")) + results = append(results, checkAWSCredentials()) + results = append(results, checkInventoryFile(inventoryPath)) + results = append(results, checkAnsibleCollections()...) + + return results +} + +// PrintResults displays check results with color. +func PrintResults(results []CheckResult) { + passed, failed, warned := 0, 0, 0 + + fmt.Println("DreadGOAD Pre-flight Checks") + fmt.Println(strings.Repeat("=", 40)) + + for _, r := range results { + switch r.Status { + case "pass": + color.Green(" [pass] %s: %s", r.Name, r.Message) + passed++ + case "fail": + color.Red(" [fail] %s: %s", r.Name, r.Message) + failed++ + case "warn": + color.Yellow(" [warn] %s: %s", r.Name, r.Message) + warned++ + } + } + + fmt.Println(strings.Repeat("=", 40)) + fmt.Printf("Results: %d passed, %d failed, %d warnings\n", passed, failed, warned) +} + +// CheckAnsibleCoreVersion verifies ansible-core is installed and within the +// compatible version range (<2.19). Returns an error if the version is +// incompatible or ansible-core is not found. This is used as a pre-flight +// gate before running playbooks. +func CheckAnsibleCoreVersion() error { + result := checkAnsibleVersion() + if result.Status == "fail" { + return fmt.Errorf("%s", result.Message) + } + return nil +} + +func checkAnsibleVersion() CheckResult { + out, err := exec.Command("ansible", "--version").CombinedOutput() + if err != nil { + return CheckResult{ + Name: "ansible-core", + Status: "fail", + Message: "ansible-core not found. Install: pip install 'ansible-core>=2.17.0,<2.18.0'", + } + } + + re := regexp.MustCompile(`(\d+)\.(\d+)\.(\d+)`) + m := re.FindStringSubmatch(string(out)) + if m == nil { + return CheckResult{Name: "ansible-core", Status: "fail", Message: "could not parse version"} + } + + major, _ := strconv.Atoi(m[1]) + minor, _ := strconv.Atoi(m[2]) + version := fmt.Sprintf("%s.%s.%s", m[1], m[2], m[3]) + + if major > 2 || (major == 2 && minor >= 19) { + return CheckResult{ + Name: "ansible-core", + Status: "fail", + Message: fmt.Sprintf("v%s detected. Versions >=2.19 break Windows SSM. "+ + "Fix: pip install 'ansible-core>=2.17.0,<2.18.0'", version), + } + } + + if major == 2 && minor >= 17 && minor < 19 { + return CheckResult{ + Name: "ansible-core", + Status: "pass", + Message: fmt.Sprintf("v%s (compatible)", version), + } + } + + return CheckResult{ + Name: "ansible-core", + Status: "warn", + Message: fmt.Sprintf("v%s (untested, recommend 2.17.x)", version), + } +} + +func checkCommand(name, label string) CheckResult { + path, err := exec.LookPath(name) + if err != nil { + return CheckResult{Name: label, Status: "fail", Message: "not found in PATH"} + } + return CheckResult{Name: label, Status: "pass", Message: path} +} + +func checkAWSCredentials() CheckResult { + out, err := exec.Command("aws", "sts", "get-caller-identity", "--query", "Account", "--output", "text").CombinedOutput() + if err != nil { + return CheckResult{ + Name: "AWS Credentials", + Status: "fail", + Message: "invalid or not configured. Run: aws configure", + } + } + return CheckResult{ + Name: "AWS Credentials", + Status: "pass", + Message: fmt.Sprintf("account %s", strings.TrimSpace(string(out))), + } +} + +func checkInventoryFile(path string) CheckResult { + if _, err := os.Stat(path); os.IsNotExist(err) { + return CheckResult{ + Name: "Inventory", + Status: "fail", + Message: fmt.Sprintf("%s not found", path), + } + } + return CheckResult{Name: "Inventory", Status: "pass", Message: path} +} + +func checkAnsibleCollections() []CheckResult { + required := []string{ + "ansible.windows", + "community.general", + "community.windows", + "amazon.aws", + "microsoft.ad", + "chocolatey.chocolatey", + } + + out, _ := exec.Command("ansible-galaxy", "collection", "list", "--format", "yaml").CombinedOutput() + output := string(out) + + var results []CheckResult + for _, col := range required { + if strings.Contains(output, col) { + results = append(results, CheckResult{ + Name: fmt.Sprintf("Collection: %s", col), + Status: "pass", + Message: "installed", + }) + } else { + results = append(results, CheckResult{ + Name: fmt.Sprintf("Collection: %s", col), + Status: "fail", + Message: "not installed. Run: ansible-galaxy install -r ansible/requirements.yml", + }) + } + } + return results +} diff --git a/cli/internal/inventory/parser.go b/cli/internal/inventory/parser.go new file mode 100644 index 00000000..000dae8b --- /dev/null +++ b/cli/internal/inventory/parser.go @@ -0,0 +1,158 @@ +package inventory + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" +) + +// Host represents a single host in the Ansible inventory. +type Host struct { + Name string + InstanceID string // ansible_host value (e.g. i-0e428dfc02f5007dd) + DictKey string + DNSDomain string + User string + Groups []string +} + +// Inventory represents a parsed Ansible inventory file. +type Inventory struct { + Hosts map[string]*Host + Groups map[string][]string // group name -> host names + Vars map[string]string // [all:vars] section + FilePath string +} + +var ( + sectionRe = regexp.MustCompile(`^\[([^\]]+)\]`) + hostLineRe = regexp.MustCompile(`^(\w[\w.-]+)\s+(.+)`) + varRe = regexp.MustCompile(`(\w+)=(\S+)`) +) + +// Parse reads and parses an Ansible INI-style inventory file. +func Parse(path string) (*Inventory, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open inventory %s: %w", path, err) + } + defer func() { _ = f.Close() }() + + inv := &Inventory{ + Hosts: make(map[string]*Host), + Groups: make(map[string][]string), + Vars: make(map[string]string), + FilePath: path, + } + + scanner := bufio.NewScanner(f) + currentSection := "" + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if line == "" || strings.HasPrefix(line, ";") || strings.HasPrefix(line, "#") { + continue + } + + if m := sectionRe.FindStringSubmatch(line); m != nil { + currentSection = m[1] + continue + } + + inv.parseLine(line, currentSection) + } + + return inv, scanner.Err() +} + +func (inv *Inventory) parseLine(line, section string) { + switch section { + case "all:vars": + if parts := strings.SplitN(line, "=", 2); len(parts) == 2 { + inv.Vars[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + case "default": + inv.parseHostDef(line) + case "": + // no section yet + default: + inv.parseGroupMembership(line, section) + } +} + +func (inv *Inventory) parseHostDef(line string) { + m := hostLineRe.FindStringSubmatch(line) + if m == nil { + return + } + host := &Host{Name: m[1]} + for _, vm := range varRe.FindAllStringSubmatch(m[2], -1) { + switch vm[1] { + case "ansible_host": + host.InstanceID = vm[2] + case "dict_key": + host.DictKey = vm[2] + case "dns_domain": + host.DNSDomain = vm[2] + case "ansible_user": + host.User = vm[2] + } + } + inv.Hosts[host.Name] = host + inv.Groups["default"] = append(inv.Groups["default"], host.Name) +} + +func (inv *Inventory) parseGroupMembership(line, section string) { + name := strings.Fields(line)[0] + if _, exists := inv.Hosts[name]; exists { + inv.Groups[section] = append(inv.Groups[section], name) + inv.Hosts[name].Groups = append(inv.Hosts[name].Groups, section) + } +} + +// InstanceIDs returns all unique instance IDs from the inventory. +func (inv *Inventory) InstanceIDs() []string { + seen := make(map[string]struct{}) + var ids []string + for _, h := range inv.Hosts { + if h.InstanceID != "" { + if _, exists := seen[h.InstanceID]; !exists { + seen[h.InstanceID] = struct{}{} + ids = append(ids, h.InstanceID) + } + } + } + return ids +} + +// Region returns the AWS SSM region from inventory vars. +func (inv *Inventory) Region() string { + if r, ok := inv.Vars["ansible_aws_ssm_region"]; ok { + return r + } + return "us-west-2" +} + +// HostByName returns a host by its name (case-insensitive). +func (inv *Inventory) HostByName(name string) *Host { + name = strings.ToLower(name) + for k, h := range inv.Hosts { + if strings.ToLower(k) == name { + return h + } + } + return nil +} + +// HostByInstanceID returns a host by its instance ID. +func (inv *Inventory) HostByInstanceID(id string) *Host { + for _, h := range inv.Hosts { + if h.InstanceID == id { + return h + } + } + return nil +} diff --git a/cli/internal/inventory/parser_test.go b/cli/internal/inventory/parser_test.go new file mode 100644 index 00000000..dacde940 --- /dev/null +++ b/cli/internal/inventory/parser_test.go @@ -0,0 +1,238 @@ +package inventory + +import ( + "os" + "path/filepath" + "testing" +) + +const testInventory = `; GOAD inventory - auto-generated +[default] +DC01 ansible_host=i-0e428dfc02f5007dd dict_key=dc01 dns_domain=sevenkingdoms.local ansible_user=vagrant +DC02 ansible_host=i-0abc123def456789a dict_key=dc02 dns_domain=north.sevenkingdoms.local ansible_user=vagrant +SRV01 ansible_host=i-0fff999888777666a dict_key=srv01 dns_domain=sevenkingdoms.local ansible_user=vagrant + +[all:vars] +ansible_aws_ssm_region=us-east-1 +ansible_connection=aws_ssm + +[dc] +DC01 +DC02 + +[server] +SRV01 + +[north] +DC02 +` + +func writeTestInventory(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "test-inventory") + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + return path +} + +func parseTestInventory(t *testing.T) *Inventory { + t.Helper() + path := writeTestInventory(t, testInventory) + inv, err := Parse(path) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + return inv +} + +func TestParse_HostCount(t *testing.T) { + inv := parseTestInventory(t) + if len(inv.Hosts) != 3 { + t.Errorf("got %d hosts, want 3", len(inv.Hosts)) + } +} + +func TestParse_HostAttributes(t *testing.T) { + inv := parseTestInventory(t) + dc01 := inv.Hosts["DC01"] + if dc01 == nil { + t.Fatal("DC01 not found") + } + if dc01.InstanceID != "i-0e428dfc02f5007dd" { + t.Errorf("InstanceID = %q, want %q", dc01.InstanceID, "i-0e428dfc02f5007dd") + } + if dc01.DictKey != "dc01" { + t.Errorf("DictKey = %q, want %q", dc01.DictKey, "dc01") + } + if dc01.DNSDomain != "sevenkingdoms.local" { + t.Errorf("DNSDomain = %q, want %q", dc01.DNSDomain, "sevenkingdoms.local") + } + if dc01.User != "vagrant" { + t.Errorf("User = %q, want %q", dc01.User, "vagrant") + } +} + +func TestParse_Vars(t *testing.T) { + inv := parseTestInventory(t) + if inv.Vars["ansible_aws_ssm_region"] != "us-east-1" { + t.Errorf("region = %q, want %q", inv.Vars["ansible_aws_ssm_region"], "us-east-1") + } + if inv.Vars["ansible_connection"] != "aws_ssm" { + t.Errorf("connection = %q, want %q", inv.Vars["ansible_connection"], "aws_ssm") + } +} + +func TestParse_Groups(t *testing.T) { + inv := parseTestInventory(t) + if len(inv.Groups["dc"]) != 2 { + t.Errorf("dc group has %d members, want 2", len(inv.Groups["dc"])) + } + if len(inv.Groups["server"]) != 1 { + t.Errorf("server group has %d members, want 1", len(inv.Groups["server"])) + } +} + +func TestParse_GroupMembership(t *testing.T) { + inv := parseTestInventory(t) + dc02 := inv.Hosts["DC02"] + if dc02 == nil { + t.Fatal("DC02 not found") + } + wantGroups := map[string]bool{"dc": false, "north": false} + for _, g := range dc02.Groups { + if _, ok := wantGroups[g]; ok { + wantGroups[g] = true + } + } + for g, found := range wantGroups { + if !found { + t.Errorf("DC02 missing group %q", g) + } + } +} + +func TestParse_FileNotFound(t *testing.T) { + _, err := Parse("/nonexistent/inventory") + if err == nil { + t.Fatal("expected error for missing file, got nil") + } +} + +func TestParse_EmptyFile(t *testing.T) { + path := writeTestInventory(t, "") + inv, err := Parse(path) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if len(inv.Hosts) != 0 { + t.Errorf("got %d hosts, want 0", len(inv.Hosts)) + } +} + +func TestParse_CommentsOnly(t *testing.T) { + path := writeTestInventory(t, "# comment\n; another comment\n") + inv, err := Parse(path) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if len(inv.Hosts) != 0 { + t.Errorf("got %d hosts, want 0", len(inv.Hosts)) + } +} + +func TestInstanceIDs(t *testing.T) { + path := writeTestInventory(t, testInventory) + inv, err := Parse(path) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + ids := inv.InstanceIDs() + if len(ids) != 3 { + t.Errorf("InstanceIDs() returned %d, want 3", len(ids)) + } + + // Verify all IDs are present (order is map-dependent) + idSet := make(map[string]bool) + for _, id := range ids { + idSet[id] = true + } + for _, want := range []string{"i-0e428dfc02f5007dd", "i-0abc123def456789a", "i-0fff999888777666a"} { + if !idSet[want] { + t.Errorf("InstanceIDs() missing %q", want) + } + } +} + +func TestRegion(t *testing.T) { + t.Run("from vars", func(t *testing.T) { + path := writeTestInventory(t, testInventory) + inv, _ := Parse(path) + if got := inv.Region(); got != "us-east-1" { + t.Errorf("Region() = %q, want %q", got, "us-east-1") + } + }) + + t.Run("default fallback", func(t *testing.T) { + path := writeTestInventory(t, "[default]\n") + inv, _ := Parse(path) + if got := inv.Region(); got != "us-west-2" { + t.Errorf("Region() = %q, want %q", got, "us-west-2") + } + }) +} + +func TestHostByName(t *testing.T) { + path := writeTestInventory(t, testInventory) + inv, _ := Parse(path) + + tests := []struct { + name string + query string + want string + }{ + {"exact match", "DC01", "DC01"}, + {"case insensitive", "dc01", "DC01"}, + {"mixed case", "Dc02", "DC02"}, + {"not found", "NONEXISTENT", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := inv.HostByName(tt.query) + if tt.want == "" { + if got != nil { + t.Errorf("HostByName(%q) = %v, want nil", tt.query, got) + } + return + } + if got == nil { + t.Fatalf("HostByName(%q) = nil, want %q", tt.query, tt.want) + } + if got.Name != tt.want { + t.Errorf("HostByName(%q).Name = %q, want %q", tt.query, got.Name, tt.want) + } + }) + } +} + +func TestHostByInstanceID(t *testing.T) { + path := writeTestInventory(t, testInventory) + inv, _ := Parse(path) + + t.Run("found", func(t *testing.T) { + got := inv.HostByInstanceID("i-0e428dfc02f5007dd") + if got == nil || got.Name != "DC01" { + t.Errorf("HostByInstanceID() = %v, want DC01", got) + } + }) + + t.Run("not found", func(t *testing.T) { + got := inv.HostByInstanceID("i-nonexistent") + if got != nil { + t.Errorf("HostByInstanceID() = %v, want nil", got) + } + }) +} diff --git a/cli/internal/logging/logger.go b/cli/internal/logging/logger.go new file mode 100644 index 00000000..51545b0a --- /dev/null +++ b/cli/internal/logging/logger.go @@ -0,0 +1,44 @@ +package logging + +import ( + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "time" +) + +var logger *slog.Logger + +// Init sets up the structured logger with console and optional file output. +func Init(debug bool, logDir, env string) { + level := slog.LevelInfo + if debug { + level = slog.LevelDebug + } + + var writers []io.Writer + writers = append(writers, os.Stdout) + + if logDir != "" { + _ = os.MkdirAll(logDir, 0o755) + logFile := filepath.Join(logDir, fmt.Sprintf("%s-dreadgoad-%s.log", + env, time.Now().Format("20060102_150405"))) + if f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644); err == nil { + writers = append(writers, f) + } + } + + w := io.MultiWriter(writers...) + logger = slog.New(slog.NewTextHandler(w, &slog.HandlerOptions{Level: level})) + slog.SetDefault(logger) +} + +// Get returns the configured logger. +func Get() *slog.Logger { + if logger == nil { + logger = slog.Default() + } + return logger +} diff --git a/cli/internal/validate/checks.go b/cli/internal/validate/checks.go new file mode 100644 index 00000000..c3e7958e --- /dev/null +++ b/cli/internal/validate/checks.go @@ -0,0 +1,244 @@ +package validate + +import ( + "context" + "fmt" + "strings" +) + +func (v *Validator) checkCredentialDiscovery(ctx context.Context) { + fmt.Println("\n== 1. Credential Discovery Vulnerabilities ==") + + output := v.runPS(ctx, "DC02", `Get-ADUser -Filter * -Properties Description | Where-Object {$_.Description -match 'password|heartsbane'} | Select-Object SamAccountName,Description | Format-Table -AutoSize | Out-String -Width 200`) + if strings.Contains(strings.ToLower(output), "samwell.tarly") { + v.addResult("PASS", "Credentials", "samwell.tarly has password in description", "") + } else { + v.addResult("FAIL", "Credentials", "samwell.tarly does NOT have password in description", "") + } +} + +func (v *Validator) checkKerberosAttacks(ctx context.Context) { + fmt.Println("\n== 2. Kerberos Attack Vectors ==") + + // AS-REP Roasting + output := v.runPS(ctx, "DC02", `Get-ADUser brandon.stark -Properties DoesNotRequirePreAuth | Select-Object SamAccountName,DoesNotRequirePreAuth | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "true") { + v.addResult("PASS", "Kerberos", "brandon.stark has DoesNotRequirePreAuth (AS-REP roastable)", "") + } else { + v.addResult("FAIL", "Kerberos", "brandon.stark does NOT have PreAuth disabled", "") + } + + output = v.runPS(ctx, "DC03", `Get-ADUser missandei -Properties DoesNotRequirePreAuth | Select-Object SamAccountName,DoesNotRequirePreAuth | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "true") { + v.addResult("PASS", "Kerberos", "missandei has DoesNotRequirePreAuth (AS-REP roastable)", "") + } else { + v.addResult("FAIL", "Kerberos", "missandei does NOT have PreAuth disabled", "") + } + + // Kerberoasting + output = v.runPS(ctx, "DC02", `Get-ADUser jon.snow -Properties ServicePrincipalName | Select-Object SamAccountName,ServicePrincipalName | Format-List | Out-String`) + if strings.Contains(strings.ToLower(output), "serviceprincipalname") { + v.addResult("PASS", "Kerberos", "jon.snow has SPNs configured (Kerberoastable)", "") + } else { + v.addResult("FAIL", "Kerberos", "jon.snow does NOT have SPNs configured", "") + } + + output = v.runPS(ctx, "DC02", `Get-ADUser sql_svc -Properties ServicePrincipalName | Select-Object SamAccountName,ServicePrincipalName | Format-List | Out-String`) + if strings.Contains(strings.ToLower(output), "serviceprincipalname") { + v.addResult("PASS", "Kerberos", "sql_svc has SPNs configured (Kerberoastable)", "") + } else { + v.addResult("FAIL", "Kerberos", "sql_svc does NOT have SPNs configured", "") + } +} + +func (v *Validator) checkNetworkMisconfigs(ctx context.Context) { + fmt.Println("\n== 3. Network-Level Misconfigurations ==") + + for _, host := range []string{"SRV02", "SRV03"} { + if !v.hasHost(host) { + continue + } + output := v.runPS(ctx, host, `Get-SmbServerConfiguration | Select-Object RequireSecuritySignature,EnableSecuritySignature | Format-Table -AutoSize | Out-String`) + lower := strings.ToLower(output) + hostLabel := map[string]string{"SRV02": "CASTELBLACK", "SRV03": "BRAAVOS"}[host] + + switch { + case strings.Contains(lower, "false") && strings.Count(lower, "false") >= 2: + v.addResult("PASS", "Network", fmt.Sprintf("%s has SMB signing disabled", hostLabel), "") + case strings.Contains(lower, "false"): + v.addResult("WARN", "Network", fmt.Sprintf("%s has SMB signing enabled but not required", hostLabel), "") + default: + v.addResult("FAIL", "Network", fmt.Sprintf("%s has SMB signing enforced", hostLabel), "") + } + } +} + +func (v *Validator) checkAnonymousSMB(ctx context.Context) { + fmt.Println("\n== 4. Anonymous/Guest SMB Enumeration ==") + + // RestrictAnonymous on DC02 + output := v.runPS(ctx, "DC02", `Get-ItemProperty -Path 'HKLM:\System\CurrentControlSet\Control\Lsa' -Name RestrictAnonymous -ErrorAction SilentlyContinue | Select-Object -ExpandProperty RestrictAnonymous`) + val := strings.TrimSpace(output) + if val == "0" { + v.addResult("PASS", "SMB", "RestrictAnonymous is 0 on WINTERFELL (NULL sessions enabled)", "") + } else { + v.addResult("FAIL", "SMB", fmt.Sprintf("RestrictAnonymous is %s on WINTERFELL (expected 0)", val), "") + } + + // RestrictAnonymousSAM + output = v.runPS(ctx, "DC02", `Get-ItemProperty -Path 'HKLM:\System\CurrentControlSet\Control\Lsa' -Name RestrictAnonymousSAM -ErrorAction SilentlyContinue | Select-Object -ExpandProperty RestrictAnonymousSAM`) + val = strings.TrimSpace(output) + if val == "0" { + v.addResult("PASS", "SMB", "RestrictAnonymousSAM is 0 on WINTERFELL (SAM enum enabled)", "") + } else { + v.addResult("FAIL", "SMB", fmt.Sprintf("RestrictAnonymousSAM is %s on WINTERFELL (expected 0)", val), "") + } + + // Guest accounts on member servers + for _, host := range []string{"SRV02", "SRV03"} { + if !v.hasHost(host) { + continue + } + hostLabel := map[string]string{"SRV02": "CASTELBLACK", "SRV03": "BRAAVOS"}[host] + output = v.runPS(ctx, host, `Get-LocalUser -Name Guest | Select-Object Name,Enabled | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "true") { + v.addResult("PASS", "SMB", fmt.Sprintf("Guest account enabled on %s", hostLabel), "") + } else { + v.addResult("FAIL", "SMB", fmt.Sprintf("Guest account NOT enabled on %s", hostLabel), "") + } + } + + // LmCompatibilityLevel on DC03 + output = v.runPS(ctx, "DC03", `Get-ItemProperty -Path 'HKLM:\System\CurrentControlSet\Control\Lsa' -Name LmCompatibilityLevel -ErrorAction SilentlyContinue | Select-Object -ExpandProperty LmCompatibilityLevel`) + val = strings.TrimSpace(output) + if val == "0" || val == "1" || val == "2" { + v.addResult("PASS", "SMB", fmt.Sprintf("LmCompatibilityLevel is %s on MEEREEN (NTLM downgrade vulnerable)", val), "") + } else { + v.addResult("FAIL", "SMB", fmt.Sprintf("LmCompatibilityLevel is %s on MEEREEN (expected 0-2)", val), "") + } +} + +func (v *Validator) checkDelegation(ctx context.Context) { + fmt.Println("\n== 5. Delegation Configurations ==") + + output := v.runPS(ctx, "DC02", `Get-ADUser sansa.stark -Properties TrustedForDelegation | Select-Object SamAccountName,TrustedForDelegation | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "true") { + v.addResult("PASS", "Delegation", "sansa.stark has unconstrained delegation", "") + } else { + v.addResult("FAIL", "Delegation", "sansa.stark does NOT have unconstrained delegation", "") + } + + output = v.runPS(ctx, "DC02", `Get-ADUser jon.snow -Properties msDS-AllowedToDelegateTo | Select-Object SamAccountName,msDS-AllowedToDelegateTo | Format-List | Out-String`) + if strings.Contains(strings.ToLower(output), "msds-allowedtodelegateto") { + v.addResult("PASS", "Delegation", "jon.snow has constrained delegation configured", "") + } else { + v.addResult("FAIL", "Delegation", "jon.snow does NOT have constrained delegation", "") + } +} + +func (v *Validator) checkMachineAccountQuota(ctx context.Context) { + fmt.Println("\n== 6. Machine Account Quota ==") + + output := v.runPS(ctx, "DC01", `Get-ADObject -Identity ((Get-ADDomain).distinguishedname) -Properties ms-DS-MachineAccountQuota | Select-Object -ExpandProperty ms-DS-MachineAccountQuota`) + val := strings.TrimSpace(output) + if val == "10" { + v.addResult("PASS", "MachineQuota", "Machine Account Quota is 10 (allows RBCD)", "") + } else { + v.addResult("WARN", "MachineQuota", fmt.Sprintf("Machine Account Quota is %s (expected 10)", val), "") + } +} + +func (v *Validator) checkMSSQL(ctx context.Context) { + fmt.Println("\n== 7. MSSQL Configurations ==") + + for _, host := range []string{"SRV02", "SRV03"} { + if !v.hasHost(host) { + continue + } + hostLabel := map[string]string{"SRV02": "CASTELBLACK", "SRV03": "BRAAVOS"}[host] + output := v.runPS(ctx, host, `Get-Service 'MSSQL$SQLEXPRESS' -ErrorAction SilentlyContinue | Select-Object Name,Status,StartType | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "running") { + v.addResult("PASS", "MSSQL", fmt.Sprintf("MSSQL running on %s", hostLabel), "") + } else { + v.addResult("FAIL", "MSSQL", fmt.Sprintf("MSSQL NOT running on %s", hostLabel), "") + } + } +} + +func (v *Validator) checkADCS(ctx context.Context) { + fmt.Println("\n== 8. ADCS Configuration ==") + + if !v.hasHost("SRV03") { + return + } + + output := v.runPS(ctx, "SRV03", `Get-WindowsFeature ADCS-Cert-Authority | Select-Object Name,InstallState | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "installed") { + v.addResult("PASS", "ADCS", "ADCS installed on BRAAVOS", "") + } else { + v.addResult("FAIL", "ADCS", "ADCS NOT installed on BRAAVOS", "") + } + + output = v.runPS(ctx, "SRV03", `Get-WindowsFeature ADCS-Web-Enrollment | Select-Object Name,InstallState | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "installed") { + v.addResult("PASS", "ADCS", "ADCS Web Enrollment installed (ESC8 possible)", "") + } else { + v.addResult("WARN", "ADCS", "ADCS Web Enrollment not installed", "") + } +} + +func (v *Validator) checkACLPermissions(ctx context.Context) { + fmt.Println("\n== 9. ACL Permissions ==") + + output := v.runPS(ctx, "DC01", `$user = Get-ADUser jaime.lannister -Properties nTSecurityDescriptor; $acl = $user.nTSecurityDescriptor.Access | Where-Object { $_.IdentityReference -like '*tywin*' }; if ($acl) { Write-Output 'ACL_FOUND' } else { Write-Output 'ACL_NOT_FOUND' }`) + switch { + case strings.Contains(output, "ACL_FOUND"): + v.addResult("PASS", "ACL", "tywin.lannister has ACL rights on jaime.lannister", "") + case strings.Contains(output, "ACL_NOT_FOUND"): + v.addResult("FAIL", "ACL", "tywin.lannister does NOT have ACL rights on jaime.lannister", "") + default: + v.addResult("WARN", "ACL", "Could not verify ACL: tywin -> jaime", "") + } +} + +func (v *Validator) checkDomainTrusts(ctx context.Context) { + fmt.Println("\n== 10. Domain Trusts ==") + + output := v.runPS(ctx, "DC02", `Get-ADTrust -Filter * | Select-Object Name,Direction,TrustType | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "sevenkingdoms") { + v.addResult("PASS", "Trusts", "Parent-child trust configured (north -> sevenkingdoms)", "") + } else { + v.addResult("FAIL", "Trusts", "Parent-child trust NOT found", "") + } + + output = v.runPS(ctx, "DC01", `Get-ADTrust -Filter * | Select-Object Name,Direction,TrustType | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "essos") { + v.addResult("PASS", "Trusts", "Forest trust configured (sevenkingdoms <-> essos)", "") + } else { + v.addResult("FAIL", "Trusts", "Forest trust NOT found", "") + } +} + +func (v *Validator) checkServices(ctx context.Context) { + fmt.Println("\n== 11. Additional Services ==") + + // Print Spooler on all DCs + for _, host := range []string{"DC01", "DC02", "DC03"} { + output := v.runPS(ctx, host, `Get-Service Spooler | Select-Object Status | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "running") { + v.addResult("PASS", "Services", fmt.Sprintf("Print Spooler running on %s (coercion possible)", host), "") + } else { + v.addResult("WARN", "Services", fmt.Sprintf("Print Spooler not running on %s", host), "") + } + } + + // IIS on SRV02 + if v.hasHost("SRV02") { + output := v.runPS(ctx, "SRV02", `Get-Service W3SVC -ErrorAction SilentlyContinue | Select-Object Name,Status | Format-Table -AutoSize | Out-String`) + if strings.Contains(strings.ToLower(output), "running") { + v.addResult("PASS", "Services", "IIS running on CASTELBLACK", "") + } else { + v.addResult("FAIL", "Services", "IIS NOT running on CASTELBLACK", "") + } + } +} diff --git a/cli/internal/validate/validator.go b/cli/internal/validate/validator.go new file mode 100644 index 00000000..5f024556 --- /dev/null +++ b/cli/internal/validate/validator.go @@ -0,0 +1,157 @@ +package validate + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "strings" + "time" + + daws "github.com/dreadnode/dreadgoad/internal/aws" + "github.com/fatih/color" +) + +// Result represents a single check result. +type Result struct { + Status string `json:"status"` // PASS, FAIL, WARN + Category string `json:"category"` + Name string `json:"name"` + Detail string `json:"detail,omitzero"` +} + +// Report holds all validation results. +type Report struct { + Date string `json:"validation_date"` + Env string `json:"environment"` + Total int `json:"total_checks"` + Passed int `json:"passed"` + Failed int `json:"failed"` + Warnings int `json:"warnings"` + Results []Result `json:"checks"` +} + +// Validator runs vulnerability checks against GOAD instances. +type Validator struct { + client *daws.Client + log *slog.Logger + env string + verbose bool + report Report + hosts map[string]string // hostname -> instance ID +} + +// NewValidator creates a new Validator. +func NewValidator(client *daws.Client, env string, verbose bool, log *slog.Logger) *Validator { + if log == nil { + log = slog.Default() + } + return &Validator{ + client: client, + log: log, + env: env, + verbose: verbose, + hosts: make(map[string]string), + report: Report{ + Date: time.Now().UTC().Format(time.RFC3339), + Env: env, + }, + } +} + +// DiscoverHosts finds GOAD instances and maps hostnames to instance IDs. +func (v *Validator) DiscoverHosts(ctx context.Context) error { + instances, err := v.client.DiscoverInstances(ctx, v.env) + if err != nil { + return fmt.Errorf("discover instances: %w", err) + } + + for _, inst := range instances { + name := strings.ToUpper(inst.Name) + for _, host := range []string{"DC01", "DC02", "DC03", "SRV02", "SRV03"} { + if strings.Contains(name, host) { + v.hosts[host] = inst.InstanceID + v.addResult("PASS", "Discovery", fmt.Sprintf("Found %s", host), inst.InstanceID) + } + } + } + + // Verify required hosts + for _, required := range []string{"DC01", "DC02", "DC03"} { + if _, ok := v.hosts[required]; !ok { + v.addResult("FAIL", "Discovery", fmt.Sprintf("Missing %s", required), "not found") + return fmt.Errorf("required host %s not found", required) + } + } + return nil +} + +// RunAllChecks executes all vulnerability validation checks. +func (v *Validator) RunAllChecks(ctx context.Context) { + v.checkCredentialDiscovery(ctx) + v.checkKerberosAttacks(ctx) + v.checkNetworkMisconfigs(ctx) + v.checkAnonymousSMB(ctx) + v.checkDelegation(ctx) + v.checkMachineAccountQuota(ctx) + v.checkMSSQL(ctx) + v.checkADCS(ctx) + v.checkACLPermissions(ctx) + v.checkDomainTrusts(ctx) + v.checkServices(ctx) +} + +// GetReport returns the current report. +func (v *Validator) GetReport() *Report { + v.report.Total = len(v.report.Results) + return &v.report +} + +// SaveReport writes the report to a JSON file. +func (v *Validator) SaveReport(path string) error { + data, err := json.MarshalIndent(v.GetReport(), "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func (v *Validator) runPS(ctx context.Context, host, command string) string { + instanceID, ok := v.hosts[host] + if !ok { + v.log.Warn("host not found", "host", host) + return "" + } + if v.verbose { + v.log.Debug("running PS command", "host", host, "command", command) + } + result, err := v.client.RunPowerShellCommand(ctx, instanceID, command, 60*time.Second) + if err != nil { + v.log.Warn("PS command failed", "host", host, "error", err) + return "" + } + return result.Stdout +} + +func (v *Validator) addResult(status, category, name, detail string) { + r := Result{Status: status, Category: category, Name: name, Detail: detail} + v.report.Results = append(v.report.Results, r) + + switch status { + case "PASS": + v.report.Passed++ + color.Green(" ✓ %s", name) + case "FAIL": + v.report.Failed++ + color.Red(" ✗ %s", name) + case "WARN": + v.report.Warnings++ + color.Yellow(" ⚠ %s", name) + } +} + +func (v *Validator) hasHost(host string) bool { + _, ok := v.hosts[host] + return ok +} diff --git a/cli/main.go b/cli/main.go new file mode 100644 index 00000000..03778f8a --- /dev/null +++ b/cli/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "os" + + "github.com/dreadnode/dreadgoad/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + os.Exit(1) + } +}