diff --git a/CHANGELOG.md b/CHANGELOG.md index 341ea9d58..14383e621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## [Unreleased] - Fix regression restricting the characters in an `elasticstack_elasticsearch_role_mapping` `name`. ([#1373](https://github.com/elastic/terraform-provider-elasticstack/pull/1373)) +- Add schema validations to require either (but not both) `index` and `data_view_id` is set for relevant Security Detection Rules ## [0.12.0] - 2025-10-15 diff --git a/internal/fleet/output/schema.go b/internal/fleet/output/schema.go index 590374c21..b0c25d986 100644 --- a/internal/fleet/output/schema.go +++ b/internal/fleet/output/schema.go @@ -152,10 +152,7 @@ func getSchema() schema.Schema { int64planmodifier.UseStateForUnknown(), }, Validators: []validator.Int64{ - validators.Int64ConditionalRequirement( - path.Root("kafka").AtName("compression"), - []string{"gzip"}, - ), + validators.AllowedIfDependentPathEquals(path.Root("kafka").AtName("compression"), "gzip"), }, }, "connection_type": schema.StringAttribute{ @@ -163,7 +160,7 @@ func getSchema() schema.Schema { Optional: true, Validators: []validator.String{ stringvalidator.OneOf("plaintext", "encryption"), - validators.StringConditionalRequirementSingle( + validators.AllowedIfDependentPathEquals( path.Root("kafka").AtName("auth_type"), "none", ), diff --git a/internal/kibana/security_detection_rule/acc_test.go b/internal/kibana/security_detection_rule/acc_test.go index 3f4a21f56..546fb8726 100644 --- a/internal/kibana/security_detection_rule/acc_test.go +++ b/internal/kibana/security_detection_rule/acc_test.go @@ -3,6 +3,7 @@ package security_detection_rule_test import ( "context" "fmt" + "regexp" "strings" "testing" @@ -73,7 +74,6 @@ func TestAccResourceSecurityDetectionRule_Query(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "severity", "medium"), resource.TestCheckResourceAttr(resourceName, "risk_score", "50"), resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "test-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "test-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Query Rule Name"), resource.TestCheckResourceAttr(resourceName, "timestamp_override", "@timestamp"), @@ -151,7 +151,6 @@ func TestAccResourceSecurityDetectionRule_Query(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "description", "Updated test query security detection rule"), resource.TestCheckResourceAttr(resourceName, "severity", "high"), resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "updated-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Query Rule Name"), resource.TestCheckResourceAttr(resourceName, "timestamp_override", "event.ingested"), @@ -268,7 +267,6 @@ func TestAccResourceSecurityDetectionRule_EQL(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "description", "Test EQL security detection rule"), resource.TestCheckResourceAttr(resourceName, "severity", "high"), resource.TestCheckResourceAttr(resourceName, "risk_score", "70"), - resource.TestCheckResourceAttr(resourceName, "index.0", "winlogbeat-*"), resource.TestCheckResourceAttr(resourceName, "tiebreaker_field", "@timestamp"), resource.TestCheckResourceAttr(resourceName, "data_view_id", "eql-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "eql-namespace"), @@ -728,7 +726,6 @@ func TestAccResourceSecurityDetectionRule_NewTerms(t *testing.T) { checkResourceJSONAttr(resourceName, "filters", `[{"bool": {"should": [{"wildcard": {"user.domain": "*.internal"}}, {"term": {"user.type": "service_account"}}]}}]`), resource.TestCheckResourceAttr(resourceName, "history_window_start", "now-14d"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "new-terms-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "new-terms-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom New Terms Rule Name"), resource.TestCheckResourceAttr(resourceName, "timestamp_override", "user.created"), @@ -868,7 +865,6 @@ func TestAccResourceSecurityDetectionRule_SavedQuery(t *testing.T) { // Check filters field checkResourceJSONAttr(resourceName, "filters", `[{"prefix": {"event.action": "user_"}}]`), - resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), resource.TestCheckResourceAttr(resourceName, "data_view_id", "saved-query-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "saved-query-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Saved Query Rule Name"), @@ -943,8 +939,6 @@ func TestAccResourceSecurityDetectionRule_SavedQuery(t *testing.T) { // Check filters field (updated values) checkResourceJSONAttr(resourceName, "filters", `[{"script": {"script": {"source": "doc['event.severity'].value > 2"}}}]`), - resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "index.1", "audit-*"), resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-saved-query-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "updated-saved-query-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Saved Query Rule Name"), @@ -1033,7 +1027,6 @@ func TestAccResourceSecurityDetectionRule_ThreatMatch(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "severity", "high"), resource.TestCheckResourceAttr(resourceName, "risk_score", "80"), resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "threat-match-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "threat-match-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Threat Match Rule Name"), resource.TestCheckResourceAttr(resourceName, "timestamp_override", "threat.indicator.first_seen"), @@ -1115,7 +1108,6 @@ func TestAccResourceSecurityDetectionRule_ThreatMatch(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "risk_score", "95"), resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), resource.TestCheckResourceAttr(resourceName, "index.1", "network-*"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-threat-match-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "updated-threat-match-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Threat Match Rule Name"), resource.TestCheckResourceAttr(resourceName, "timestamp_override", "threat.indicator.last_seen"), @@ -1206,7 +1198,6 @@ func TestAccResourceSecurityDetectionRule_Threshold(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "description", "Test threshold security detection rule"), resource.TestCheckResourceAttr(resourceName, "severity", "medium"), resource.TestCheckResourceAttr(resourceName, "risk_score", "55"), - resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), resource.TestCheckResourceAttr(resourceName, "data_view_id", "threshold-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "threshold-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Custom Threshold Rule Name"), @@ -1277,8 +1268,6 @@ func TestAccResourceSecurityDetectionRule_Threshold(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "description", "Updated test threshold security detection rule"), resource.TestCheckResourceAttr(resourceName, "severity", "high"), resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), - resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "index.1", "audit-*"), resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-threshold-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "updated-threshold-namespace"), resource.TestCheckResourceAttr(resourceName, "rule_name_override", "Updated Custom Threshold Rule Name"), @@ -1437,7 +1426,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "test-data-view-id" namespace = "test-namespace" rule_name_override = "Custom Query Rule Name" timestamp_override = "@timestamp" @@ -1555,7 +1543,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { author = ["Test Author"] tags = ["test", "automation"] license = "Elastic License v2" - data_view_id = "updated-data-view-id" namespace = "updated-namespace" rule_name_override = "Updated Custom Query Rule Name" timestamp_override = "event.ingested" @@ -1695,7 +1682,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["winlogbeat-*"] tiebreaker_field = "@timestamp" data_view_id = "eql-data-view-id" namespace = "eql-namespace" @@ -2273,7 +2259,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { index = ["logs-*"] new_terms_fields = ["user.name"] history_window_start = "now-14d" - data_view_id = "new-terms-data-view-id" namespace = "new-terms-namespace" rule_name_override = "Custom New Terms Rule Name" timestamp_override = "user.created" @@ -2492,7 +2477,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*"] saved_id = "test-saved-query-id" data_view_id = "saved-query-data-view-id" namespace = "saved-query-namespace" @@ -2588,7 +2572,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*", "audit-*"] saved_id = "test-saved-query-id-updated" data_view_id = "updated-saved-query-data-view-id" namespace = "updated-saved-query-namespace" @@ -2705,7 +2688,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "threat-match-data-view-id" namespace = "threat-match-namespace" rule_name_override = "Custom Threat Match Rule Name" timestamp_override = "threat.indicator.first_seen" @@ -2828,7 +2810,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*", "network-*"] - data_view_id = "updated-threat-match-data-view-id" namespace = "updated-threat-match-namespace" threat_index = ["threat-intel-*", "ioc-*"] threat_query = "threat.indicator.type:(ip OR domain)" @@ -2958,7 +2939,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*"] data_view_id = "threshold-data-view-id" namespace = "threshold-namespace" rule_name_override = "Custom Threshold Rule Name" @@ -3065,7 +3045,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*", "audit-*"] data_view_id = "updated-threshold-data-view-id" namespace = "updated-threshold-namespace" author = ["Test Author"] @@ -3206,7 +3185,6 @@ func TestAccResourceSecurityDetectionRule_WithConnectorAction(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "severity", "medium"), resource.TestCheckResourceAttr(resourceName, "risk_score", "50"), resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "connector-action-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "connector-action-namespace"), // Check risk score mapping @@ -3241,7 +3219,6 @@ func TestAccResourceSecurityDetectionRule_WithConnectorAction(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "description", "Updated test security detection rule with connector action"), resource.TestCheckResourceAttr(resourceName, "severity", "high"), resource.TestCheckResourceAttr(resourceName, "risk_score", "75"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "updated-connector-action-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "updated-connector-action-namespace"), resource.TestCheckResourceAttr(resourceName, "tags.#", "2"), resource.TestCheckResourceAttr(resourceName, "tags.0", "test"), @@ -3310,7 +3287,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "connector-action-data-view-id" namespace = "connector-action-namespace" risk_score_mapping = [ @@ -3381,7 +3357,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "updated-connector-action-data-view-id" namespace = "updated-connector-action-namespace" tags = ["test", "terraform"] @@ -3444,7 +3419,6 @@ func TestAccResourceSecurityDetectionRule_BuildingBlockType(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "severity", "low"), resource.TestCheckResourceAttr(resourceName, "risk_score", "21"), resource.TestCheckResourceAttr(resourceName, "index.0", "logs-*"), - resource.TestCheckResourceAttr(resourceName, "data_view_id", "building-block-data-view-id"), resource.TestCheckResourceAttr(resourceName, "namespace", "building-block-namespace"), resource.TestCheckResourceAttr(resourceName, "building_block_type", "default"), @@ -3503,7 +3477,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "building-block-data-view-id" namespace = "building-block-namespace" building_block_type = "default" } @@ -3528,7 +3501,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*"] data_view_id = "updated-building-block-data-view-id" namespace = "updated-building-block-namespace" building_block_type = "default" @@ -3557,7 +3529,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { from = "now-6m" to = "now" interval = "5m" - index = ["logs-*"] data_view_id = "no-building-block-data-view-id" namespace = "no-building-block-namespace" } @@ -3716,7 +3687,6 @@ resource "elasticstack_kibana_security_detection_rule" "test" { to = "now" interval = "5m" index = ["logs-*"] - data_view_id = "no-filters-data-view-id" namespace = "no-filters-namespace" # Note: No filters field specified - this tests removing filters from a rule @@ -4771,3 +4741,190 @@ resource "elasticstack_kibana_security_detection_rule" "test" { } `, name) } + +// TestAccResourceSecurityDetectionRule_ValidateConfig tests the ValidateConfig method +// to ensure proper validation of index vs data_view_id configuration +func TestAccResourceSecurityDetectionRule_ValidateConfig(t *testing.T) { + resource.Test(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ProtoV6ProviderFactories: acctest.Providers, + Steps: []resource.TestStep{ + // Test 1: Valid config with only index (should succeed) + { + Config: testAccSecurityDetectionRuleConfig_validationIndexOnly("test-validation-index-only"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "name", "test-validation-index-only"), + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "index.0", "logs-*"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "data_view_id"), + ), + }, + // Test 2: Valid config with only data_view_id (should succeed) + { + Config: testAccSecurityDetectionRuleConfig_validationDataViewOnly("test-validation-dataview-only"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "name", "test-validation-dataview-only"), + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "data_view_id", "test-data-view-id"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "index.0"), + ), + }, + // Test 3: Invalid config with both index and data_view_id (should fail) + { + Config: testAccSecurityDetectionRuleConfig_validationBothIndexAndDataView("test-validation-both"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + ExpectError: regexp.MustCompile("Both 'index' and 'data_view_id' cannot be set at the same time"), + PlanOnly: true, + }, + // Test 4: Invalid config with neither index nor data_view_id (should fail) + { + Config: testAccSecurityDetectionRuleConfig_validationNeither("test-validation-neither"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + ExpectError: regexp.MustCompile("One of 'index' or 'data_view_id' must be set"), + PlanOnly: true, + }, + // Test 5: ESQL rule type should skip validation (both index and data_view_id allowed to be unset) + { + Config: testAccSecurityDetectionRuleConfig_validationESQLType("test-validation-esql"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "name", "test-validation-esql"), + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "type", "esql"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "index.0"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "data_view_id"), + ), + }, + // Test 6: Machine learning rule type should skip validation (both index and data_view_id allowed to be unset) + { + Config: testAccSecurityDetectionRuleConfig_validationMLType("test-validation-ml"), + SkipFunc: versionutils.CheckIfVersionIsUnsupported(minVersionSupport), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "name", "test-validation-ml"), + resource.TestCheckResourceAttr("elasticstack_kibana_security_detection_rule.test", "type", "machine_learning"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "index.0"), + resource.TestCheckNoResourceAttr("elasticstack_kibana_security_detection_rule.test", "data_view_id"), + ), + }, + }, + }) +} + +// Helper function configurations for validation tests + +func testAccSecurityDetectionRuleConfig_validationIndexOnly(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test validation with index only" + severity = "medium" + risk_score = 50 + index = ["logs-*"] +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_validationDataViewOnly(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test validation with data_view_id only" + severity = "medium" + risk_score = 50 + data_view_id = "test-data-view-id" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_validationBothIndexAndDataView(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test validation with both index and data_view_id (should fail)" + severity = "medium" + risk_score = 50 + index = ["logs-*"] + data_view_id = "test-data-view-id" +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_validationNeither(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "query" + query = "*:*" + language = "kuery" + enabled = true + description = "Test validation with neither index nor data_view_id (should fail)" + severity = "medium" + risk_score = 50 +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_validationESQLType(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "esql" + query = "FROM logs-* | WHERE event.action == \"login\" | STATS count(*) BY user.name" + language = "esql" + enabled = true + description = "Test ESQL validation bypass - neither index nor data_view_id required" + severity = "medium" + risk_score = 50 +} +`, name) +} + +func testAccSecurityDetectionRuleConfig_validationMLType(name string) string { + return fmt.Sprintf(` +provider "elasticstack" { + kibana {} +} + +resource "elasticstack_kibana_security_detection_rule" "test" { + name = "%s" + type = "machine_learning" + enabled = true + description = "Test ML validation bypass - neither index nor data_view_id required" + severity = "medium" + risk_score = 50 + anomaly_threshold = 75 + machine_learning_job_id = ["test-ml-job"] +} +`, name) +} diff --git a/internal/kibana/security_detection_rule/models_esql.go b/internal/kibana/security_detection_rule/models_esql.go index a55eaeb0f..aa51a7997 100644 --- a/internal/kibana/security_detection_rule/models_esql.go +++ b/internal/kibana/security_detection_rule/models_esql.go @@ -9,7 +9,6 @@ import ( "github.com/google/uuid" "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/diag" - "github.com/hashicorp/terraform-plugin-framework/path" "github.com/hashicorp/terraform-plugin-framework/types" ) @@ -59,35 +58,10 @@ func (e EsqlRuleProcessor) ExtractId(response any) (string, diag.Diagnostics) { return value.Id.String(), diags } -// applyEsqlValidations validates that ESQL-specific constraints are met -func (d SecurityDetectionRuleData) applyEsqlValidations(diags *diag.Diagnostics) { - if utils.IsKnown(d.Index) { - diags.AddAttributeError( - path.Root("index"), - "Invalid attribute 'index'", - "ESQL rules do not use index patterns. Please remove the 'index' attribute.", - ) - } - - if utils.IsKnown(d.Filters) { - diags.AddAttributeError( - path.Root("filters"), - "Invalid attribute 'filters'", - "ESQL rules do not support filters. Please remove the 'filters' attribute.", - ) - } -} - func (d SecurityDetectionRuleData) toEsqlRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { var diags diag.Diagnostics var createProps kbapi.SecurityDetectionsAPIRuleCreateProps - // Apply ESQL-specific validations - d.applyEsqlValidations(&diags) - if diags.HasError() { - return createProps, diags - } - esqlRule := kbapi.SecurityDetectionsAPIEsqlRuleCreateProps{ Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), @@ -153,12 +127,6 @@ func (d SecurityDetectionRuleData) toEsqlRuleUpdateProps(ctx context.Context, cl var diags diag.Diagnostics var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps - // Apply ESQL-specific validations - d.applyEsqlValidations(&diags) - if diags.HasError() { - return updateProps, diags - } - // Parse ID to get space_id and rule_id compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) diags.Append(resourceIdDiags...) diff --git a/internal/kibana/security_detection_rule/models_machine_learning.go b/internal/kibana/security_detection_rule/models_machine_learning.go index 783adbbbe..7b3cf5c21 100644 --- a/internal/kibana/security_detection_rule/models_machine_learning.go +++ b/internal/kibana/security_detection_rule/models_machine_learning.go @@ -59,26 +59,10 @@ func (m MachineLearningRuleProcessor) ExtractId(response any) (string, diag.Diag return value.Id.String(), diags } -// applyMachineLearningValidations validates that Machine learning-specific constraints are met -func (d SecurityDetectionRuleData) applyMachineLearningValidations(diags *diag.Diagnostics) { - if !utils.IsKnown(d.AnomalyThreshold) { - diags.AddAttributeError( - path.Root("anomaly_threshold"), - "Missing attribute 'anomaly_threshold'", - "Machine learning rules require an 'anomaly_threshold' attribute.", - ) - } -} - func (d SecurityDetectionRuleData) toMachineLearningRuleCreateProps(ctx context.Context, client clients.MinVersionEnforceable) (kbapi.SecurityDetectionsAPIRuleCreateProps, diag.Diagnostics) { var diags diag.Diagnostics var createProps kbapi.SecurityDetectionsAPIRuleCreateProps - d.applyMachineLearningValidations(&diags) - if diags.HasError() { - return createProps, diags - } - mlRule := kbapi.SecurityDetectionsAPIMachineLearningRuleCreateProps{ Name: kbapi.SecurityDetectionsAPIRuleName(d.Name.ValueString()), Description: kbapi.SecurityDetectionsAPIRuleDescription(d.Description.ValueString()), @@ -156,11 +140,6 @@ func (d SecurityDetectionRuleData) toMachineLearningRuleUpdateProps(ctx context. var diags diag.Diagnostics var updateProps kbapi.SecurityDetectionsAPIRuleUpdateProps - d.applyMachineLearningValidations(&diags) - if diags.HasError() { - return updateProps, diags - } - // Parse ID to get space_id and rule_id compId, resourceIdDiags := clients.CompositeIdFromStrFw(d.Id.ValueString()) diags.Append(resourceIdDiags...) diff --git a/internal/kibana/security_detection_rule/schema.go b/internal/kibana/security_detection_rule/schema.go index c131f78fc..0dab25047 100644 --- a/internal/kibana/security_detection_rule/schema.go +++ b/internal/kibana/security_detection_rule/schema.go @@ -4,11 +4,14 @@ import ( "context" "regexp" + "github.com/elastic/terraform-provider-elasticstack/internal/utils" "github.com/elastic/terraform-provider-elasticstack/internal/utils/customtypes" + "github.com/elastic/terraform-provider-elasticstack/internal/utils/validators" "github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes" "github.com/hashicorp/terraform-plugin-framework-validators/int64validator" "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/path" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/resource/schema" "github.com/hashicorp/terraform-plugin-framework/resource/schema/booldefault" @@ -73,6 +76,13 @@ func GetSchema() schema.Schema { "data_view_id": schema.StringAttribute{ MarkdownDescription: "Data view ID for the rule. Not supported for esql and machine_learning rule types.", Optional: true, + Validators: []validator.String{ + // Enforce that data_view_id is not set if the rule type is ml or esql + validators.ForbiddenIfDependentPathOneOf( + path.Root("type"), + []string{"machine_learning", "esql"}, + ), + }, }, "namespace": schema.StringAttribute{ MarkdownDescription: "Alerts index namespace. Available for all rule types.", @@ -108,6 +118,13 @@ func GetSchema() schema.Schema { MarkdownDescription: "Indices on which the rule functions.", Optional: true, Computed: true, + Validators: []validator.List{ + // Enforce that index is not set if the rule type is ml or esql + validators.ForbiddenIfDependentPathOneOf( + path.Root("type"), + []string{"machine_learning", "esql"}, + ), + }, }, "enabled": schema.BoolAttribute{ MarkdownDescription: "Determines whether the rule is enabled.", @@ -302,6 +319,12 @@ func GetSchema() schema.Schema { MarkdownDescription: "Query and filter context array to define alert conditions as JSON. Supports complex filter structures including bool queries, term filters, range filters, etc. Available for all rule types.", Optional: true, CustomType: jsontypes.NormalizedType{}, + Validators: []validator.String{ + validators.ForbiddenIfDependentPathOneOf( + path.Root("type"), + []string{"machine_learning", "esql"}, + ), + }, }, "note": schema.StringAttribute{ MarkdownDescription: "Notes to help investigate alerts produced by the rule.", @@ -602,6 +625,10 @@ func GetSchema() schema.Schema { MarkdownDescription: "Anomaly score threshold above which the rule creates an alert. Valid values are from 0 to 100. Required for machine_learning rules.", Optional: true, Validators: []validator.Int64{ + validators.RequiredIfDependentPathEquals( + path.Root("type"), + "machine_learning", + ), int64validator.Between(0, 100), }, }, @@ -913,3 +940,41 @@ func getThreatSubtechniqueElementType() attr.Type { techniqueType := threatType.AttributeTypes()["technique"].(attr.TypeWithElementType).ElementType().(attr.TypeWithAttributeTypes) return techniqueType.AttributeTypes()["subtechnique"].(attr.TypeWithElementType).ElementType() } + +// ValidateConfig validates the configuration for a security detection rule resource. +// It ensures that the configuration meets the following requirements: +// +// - For rule types "esql" and "machine_learning", no additional validation is performed +// - For other rule types, exactly one of 'index' or 'data_view_id' must be specified +// - Both 'index' and 'data_view_id' cannot be set simultaneously +// +// The function adds appropriate error diagnostics if validation fails. +func (r securityDetectionRuleResource) ValidateConfig(ctx context.Context, req resource.ValidateConfigRequest, resp *resource.ValidateConfigResponse) { + var data SecurityDetectionRuleData + + resp.Diagnostics.Append(req.Config.Get(ctx, &data)...) + + if resp.Diagnostics.HasError() { + return + } + + if data.Type.ValueString() == "esql" || data.Type.ValueString() == "machine_learning" { + return + } + + if utils.IsKnown(data.Index) && utils.IsKnown(data.DataViewId) { + resp.Diagnostics.AddError( + "Invalid Configuration", + "Both 'index' and 'data_view_id' cannot be set at the same time.", + ) + + } + + if !utils.IsKnown(data.Index) && !utils.IsKnown(data.DataViewId) { + resp.Diagnostics.AddError( + "Invalid Configuration", + "One of 'index' or 'data_view_id' must be set.", + ) + + } +} diff --git a/internal/utils/validators/conditional.go b/internal/utils/validators/conditional.go index a13d400a5..df011533c 100644 --- a/internal/utils/validators/conditional.go +++ b/internal/utils/validators/conditional.go @@ -12,133 +12,322 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" ) -// conditionalRequirement represents a validator which ensures that an attribute -// can only be set if another attribute at a specified path equals one of the specified values. -// This is a shared implementation that can be used for both string and float64 validators. -type conditionalRequirement struct { +type valueValidator func(dependentFieldHasAllowedValue bool, dependentValueStr string, val attr.Value, p path.Path) diag.Diagnostics + +// condition represents a validation rule that enforces conditional requirements +// based on the value of a dependent field. It contains the path to the field +// that this condition depends on and a list of allowed values for that field. +// When the dependent field matches one of the allowed values, additional +// validation logic can be applied to the current field. +type condition struct { + description func() string dependentPath path.Path allowedValues []string + validateValue valueValidator } // Description describes the validation in plain text formatting. -func (v conditionalRequirement) Description(_ context.Context) string { - if len(v.allowedValues) == 1 { - return fmt.Sprintf("value can only be set when %s equals %q", v.dependentPath, v.allowedValues[0]) - } - return fmt.Sprintf("value can only be set when %s is one of %v", v.dependentPath, v.allowedValues) +func (v condition) Description(ctx context.Context) string { + return v.description() } // MarkdownDescription describes the validation in Markdown formatting. -func (v conditionalRequirement) MarkdownDescription(ctx context.Context) string { - return v.Description(ctx) +func (v condition) MarkdownDescription(ctx context.Context) string { + return v.description() } -func (v conditionalRequirement) validate(ctx context.Context, config tfsdk.Config, val attr.Value, p path.Path) diag.Diagnostics { - if val.IsNull() || val.IsUnknown() { - return nil - } - - // Get the value at the dependent path +// dependentFieldHasAllowedValue checks if the dependent field specified by the condition's +// dependentPath has a value that matches one of the allowed values defined in the condition. +// It retrieves the dependent field's value from the provided configuration context and +// compares it against the condition's allowedValues slice. +// +// The method returns three values: +// - bool: true if the dependent field has a non-null, non-unknown value that matches +// one of the allowed values; false otherwise +// - string: the string representation of the dependent field's current value +// - diag.Diagnostics: any diagnostics encountered while retrieving the field value +// +// If the dependent field is null, unknown, or its value doesn't match any of the +// allowed values, the condition is considered not met and the method returns false. +func (v condition) dependentFieldHasAllowedValue(ctx context.Context, config tfsdk.Config) (bool, string, diag.Diagnostics) { var dependentValue types.String diags := config.GetAttribute(ctx, v.dependentPath, &dependentValue) + if diags.HasError() { - return diags + return false, "", diags } - // If dependent value is null, unknown, or doesn't match any allowed values, - // then the current attribute should not be set dependentValueStr := dependentValue.ValueString() - isAllowed := false + dependentFieldHasAllowedValue := false if !dependentValue.IsNull() && !dependentValue.IsUnknown() { for _, allowedValue := range v.allowedValues { if dependentValueStr == allowedValue { - isAllowed = true + dependentFieldHasAllowedValue = true break } } } - if !isAllowed { - if len(v.allowedValues) == 1 { - diags.AddAttributeError(p, "Invalid Configuration", - fmt.Sprintf("Attribute %s can only be set when %s equals %q, but %s is %q", - p, - v.dependentPath, - v.allowedValues[0], - v.dependentPath, - dependentValueStr, - ), - ) - return diags - } else { - diags.AddAttributeError(p, "Invalid Configuration", - fmt.Sprintf("Attribute %s can only be set when %s is one of %v, but %s is %q", - p, - v.dependentPath, - v.allowedValues, - v.dependentPath, - dependentValueStr, - ), - ) - return diags - } + return dependentFieldHasAllowedValue, dependentValueStr, nil +} + +func (v condition) validate(ctx context.Context, config tfsdk.Config, val attr.Value, p path.Path) diag.Diagnostics { + dependentFieldHasAllowedValue, dependentValueStr, diags := v.dependentFieldHasAllowedValue(ctx, config) + if diags.HasError() { + return diags } - return nil + return v.validateValue(dependentFieldHasAllowedValue, dependentValueStr, val, p) } -// validateConditionalRequirement was an attempt at shared logic but is not used -// The validation logic is implemented directly in ValidateString and ValidateFloat64 methods +func (v condition) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) { + response.Diagnostics.Append(v.validate(ctx, request.Config, request.ConfigValue, request.Path)...) +} -// ValidateString performs the validation for string attributes. -func (v conditionalRequirement) ValidateString(ctx context.Context, request validator.StringRequest, response *validator.StringResponse) { +func (v condition) ValidateList(ctx context.Context, request validator.ListRequest, response *validator.ListResponse) { response.Diagnostics.Append(v.validate(ctx, request.Config, request.ConfigValue, request.Path)...) } -// ValidateInt64 performs the validation for int64 attributes. -func (v conditionalRequirement) ValidateInt64(ctx context.Context, request validator.Int64Request, response *validator.Int64Response) { +func (v condition) ValidateInt64(ctx context.Context, request validator.Int64Request, response *validator.Int64Response) { response.Diagnostics.Append(v.validate(ctx, request.Config, request.ConfigValue, request.Path)...) } -// StringConditionalRequirement returns a validator which ensures that a string attribute -// can only be set if another attribute at the specified path equals one of the specified values. +// DependantPathOneOf creates a condition that validates a dependent path's value is one of the allowed values. +// It returns a condition that checks if the value at dependentPath matches any of the provided allowedValues. +// If the dependent field does not have an allowed value, it generates a diagnostic error indicating +// which values are permitted and what the current value is. // -// The dependentPath parameter should use path.Root() to specify the attribute path. -// For example: path.Root("auth_type") +// Parameters: +// - dependentPath: The path to the attribute that must have one of the allowed values +// - allowedValues: A slice of strings representing the valid values for the dependent path // -// Example usage: +// Returns: +// - condition: A condition struct that can be used for validation +func DependantPathOneOf(dependentPath path.Path, allowedValues []string) condition { + return condition{ + dependentPath: dependentPath, + allowedValues: allowedValues, + description: func() string { + return fmt.Sprintf("Attribute '%v' is not one of %s", + dependentPath, + allowedValues, + ) + }, + validateValue: func(dependentFieldHasAllowedValue bool, dependentValueStr string, val attr.Value, p path.Path) diag.Diagnostics { + if !dependentFieldHasAllowedValue { + var diags diag.Diagnostics + diags.AddAttributeError(p, "Invalid Configuration", fmt.Sprintf("Attribute '%s' is not one of %v, %s is %q", + dependentPath, + allowedValues, + dependentPath, + dependentValueStr, + )) + + return diags + } + + return nil + }, + } +} + +// AllowedIfDependentPathOneOf creates a validation condition that allows the current attribute +// to be set only when a dependent attribute at the specified path has one of the allowed values. +// +// Parameters: +// - dependentPath: The path to the attribute that this validation depends on +// - allowedValues: A slice of string values that the dependent attribute must match +// +// Returns: +// - condition: A validation condition that can be used with conditional validators // -// "connection_type": schema.StringAttribute{ -// Optional: true, -// Validators: []validator.String{ -// validators.StringConditionalRequirement( -// path.Root("auth_type"), -// []string{"none"}, -// "connection_type can only be set when auth_type is 'none'", -// ), -// }, -// }, -func StringConditionalRequirement(dependentPath path.Path, allowedValues []string) validator.String { - return conditionalRequirement{ +// Example: +// +// // Only allow "ssl_cert" to be set when "protocol" is "https" +// AllowedIfDependentPathOneOf(path.Root("protocol"), []string{"https"}) +func AllowedIfDependentPathOneOf(dependentPath path.Path, allowedValues []string) condition { + return condition{ dependentPath: dependentPath, allowedValues: allowedValues, + description: func() string { + if len(allowedValues) == 1 { + return fmt.Sprintf("value can only be set when %s equals %q", dependentPath, allowedValues[0]) + } + return fmt.Sprintf("value can only be set when %s is one of %v", dependentPath, allowedValues) + }, + validateValue: func(dependentFieldHasAllowedValue bool, dependentValueStr string, val attr.Value, p path.Path) diag.Diagnostics { + var diags diag.Diagnostics + isEmpty := val.IsNull() || val.IsUnknown() + isSet := !isEmpty + + if dependentFieldHasAllowedValue { + return diags + } + + if isSet { + if len(allowedValues) == 1 { + diags.AddAttributeError(p, "Invalid Configuration", + fmt.Sprintf("Attribute %s can only be set when %s equals %q, but %s is %q", + p, + dependentPath, + allowedValues[0], + dependentPath, + dependentValueStr, + ), + ) + } else { + diags.AddAttributeError(p, "Invalid Configuration", + fmt.Sprintf("Attribute %s can only be set when %s is one of %v, but %s is %q", + p, + dependentPath, + allowedValues, + dependentPath, + dependentValueStr, + ), + ) + } + } + + return diags + }, } } -// StringConditionalRequirementSingle is a convenience function for when there's only one allowed value. -func StringConditionalRequirementSingle(dependentPath path.Path, requiredValue string) validator.String { - return StringConditionalRequirement(dependentPath, []string{requiredValue}) +// AllowedIfDependentPathEquals returns a condition that allows a field to be set +// only if the value at the specified dependent path equals the required value. +// This is a convenience function that wraps AllowedIfDependentPathOneOf with a +// single value slice. +// +// Parameters: +// - dependentPath: The path to the field whose value determines if this field is allowed +// - requiredValue: The exact string value that the dependent field must equal +// +// Returns: +// - condition: A validation condition that enforces the dependency rule +func AllowedIfDependentPathEquals(dependentPath path.Path, requiredValue string) condition { + return AllowedIfDependentPathOneOf(dependentPath, []string{requiredValue}) +} + +// RequiredIfDependentPathEquals returns a condition that makes a field required +// when the value at the specified dependent path equals the given required value. +// This is a convenience function that wraps RequiredIfDependentPathOneOf with +// a single value slice. +// +// Parameters: +// - dependentPath: The path to the field whose value will be checked +// - requiredValue: The value that, when present at dependentPath, makes this field required +// +// Returns: +// - condition: A validation condition function +func RequiredIfDependentPathEquals(dependentPath path.Path, requiredValue string) condition { + return RequiredIfDependentPathOneOf(dependentPath, []string{requiredValue}) } -func Int64ConditionalRequirement(dependentPath path.Path, allowedValues []string) validator.Int64 { - return conditionalRequirement{ +// RequiredIfDependentPathOneOf returns a condition that validates an attribute is required +// when a dependent attribute's value matches one of the specified allowed values. +// +// The condition checks if the dependent attribute (specified by dependentPath) has a value +// that is present in the allowedValues slice. If the dependent attribute matches any of +// the allowed values, then the attribute being validated must not be null or unknown. +// +// Parameters: +// - dependentPath: The path to the attribute whose value determines the requirement +// - allowedValues: A slice of string values that trigger the requirement when matched +// +// Returns: +// - condition: A validation condition that enforces the requirement rule +// +// Example usage: +// +// validator := RequiredIfDependentPathOneOf( +// path.Root("type"), +// []string{"custom", "advanced"}, +// ) +// // This would require the current attribute when "type" equals "custom" or "advanced" +func RequiredIfDependentPathOneOf(dependentPath path.Path, allowedValues []string) condition { + return condition{ dependentPath: dependentPath, allowedValues: allowedValues, + description: func() string { + if len(allowedValues) == 1 { + return fmt.Sprintf("value required when %s equals %q", dependentPath, allowedValues[0]) + } + return fmt.Sprintf("value required when %s is one of %v", dependentPath, allowedValues) + }, + validateValue: func(dependentFieldHasAllowedValue bool, dependentValueStr string, val attr.Value, p path.Path) diag.Diagnostics { + var diags diag.Diagnostics + isEmpty := val.IsNull() || val.IsUnknown() + + if !dependentFieldHasAllowedValue { + return diags + } + + if isEmpty { + diags.AddAttributeError(p, "Invalid Configuration", + fmt.Sprintf("Attribute %s must be set when %s equals %q", + p, + dependentPath, + allowedValues[0], + ), + ) + } + return diags + }, } } -// Int64ConditionalRequirementSingle is a convenience function for when there's only one allowed value. -func Int64ConditionalRequirementSingle(dependentPath path.Path, requiredValue string) validator.Int64 { - return Int64ConditionalRequirement(dependentPath, []string{requiredValue}) +// ForbiddenIfDependentPathOneOf creates a validation condition that forbids setting a value +// when a dependent field matches one of the specified allowed values. +// +// This validator is useful for creating mutually exclusive configuration scenarios where +// certain attributes should not be set when another attribute has specific values. +// +// Parameters: +// - dependentPath: The path to the field whose value determines the validation behavior +// - allowedValues: A slice of string values that, when matched by the dependent field, +// will trigger the forbidden condition +// +// Returns: +// - condition: A validation condition that will generate an error if the current field +// is set while the dependent field matches any of the allowed values +// +// Example usage: +// +// validator := ForbiddenIfDependentPathOneOf( +// path.Root("type"), +// []string{"basic", "simple"}, +// ) +// // This will prevent setting the current attribute when "type" equals "basic" or "simple" +func ForbiddenIfDependentPathOneOf(dependentPath path.Path, allowedValues []string) condition { + return condition{ + dependentPath: dependentPath, + allowedValues: allowedValues, + description: func() string { + if len(allowedValues) == 1 { + return fmt.Sprintf("value cannot be set when %s equals %q", dependentPath, allowedValues[0]) + } + return fmt.Sprintf("value cannot be set when %s is one of %v", dependentPath, allowedValues) + }, + validateValue: func(dependentFieldHasAllowedValue bool, dependentValueStr string, val attr.Value, p path.Path) diag.Diagnostics { + var diags diag.Diagnostics + + if !dependentFieldHasAllowedValue { + return diags + } + + isEmpty := val.IsNull() || val.IsUnknown() + isSet := !isEmpty + if isSet { + diags.AddAttributeError(p, "Invalid Configuration", + fmt.Sprintf("Attribute %s cannot be set when %s equals %q", + p, + dependentPath, + allowedValues[0], + ), + ) + } + return diags + }, + } } diff --git a/internal/utils/validators/conditional_test.go b/internal/utils/validators/conditional_test.go index faca2f8e5..f9e0c363e 100644 --- a/internal/utils/validators/conditional_test.go +++ b/internal/utils/validators/conditional_test.go @@ -10,9 +10,10 @@ import ( "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-go/tftypes" + "github.com/stretchr/testify/require" ) -func TestStringConditionalRequirement(t *testing.T) { +func TestAllowedIfDependentPathOneOf(t *testing.T) { t.Parallel() type testCase struct { @@ -109,7 +110,7 @@ func TestStringConditionalRequirement(t *testing.T) { } // Create validator - v := StringConditionalRequirement( + v := AllowedIfDependentPathOneOf( path.Root("auth_type"), []string{"none"}, ) @@ -139,8 +140,8 @@ func TestStringConditionalRequirement(t *testing.T) { } } -func TestStringConditionalRequirement_Description(t *testing.T) { - v := StringConditionalRequirement( +func TestAllowedIfDependentPathOneOf_Description(t *testing.T) { + v := AllowedIfDependentPathOneOf( path.Root("auth_type"), []string{"none"}, ) @@ -153,50 +154,376 @@ func TestStringConditionalRequirement_Description(t *testing.T) { } } -func TestInt64ConditionalRequirement(t *testing.T) { +func TestForbiddenIfDependentPathOneOf(t *testing.T) { t.Parallel() type testCase struct { name string - currentValue types.Int64 + currentValue types.String dependentValue types.String expectedError bool } testCases := []testCase{ { - name: "valid - current null, dependent any value", - currentValue: types.Int64Null(), - dependentValue: types.StringValue("none"), + name: "valid - current null, dependent matches forbidden value", + currentValue: types.StringNull(), + dependentValue: types.StringValue("https"), expectedError: false, }, { - name: "valid - current unknown, dependent any value", - currentValue: types.Int64Unknown(), + name: "valid - current unknown, dependent matches forbidden value", + currentValue: types.StringUnknown(), + dependentValue: types.StringValue("https"), + expectedError: false, + }, + { + name: "valid - current set, dependent doesn't match forbidden value", + currentValue: types.StringValue("custom_cert"), + dependentValue: types.StringValue("http"), + expectedError: false, + }, + { + name: "invalid - current set, dependent matches forbidden value", + currentValue: types.StringValue("custom_cert"), + dependentValue: types.StringValue("https"), + expectedError: true, + }, + { + name: "invalid - current set, dependent matches one of multiple forbidden values", + currentValue: types.StringValue("custom_cert"), + dependentValue: types.StringValue("tls"), + expectedError: true, + }, + { + name: "valid - current set, dependent is null", + currentValue: types.StringValue("custom_cert"), + dependentValue: types.StringNull(), + expectedError: false, + }, + { + name: "valid - current set, dependent is unknown", + currentValue: types.StringValue("custom_cert"), + dependentValue: types.StringUnknown(), + expectedError: false, + }, + { + name: "valid - current null, dependent is null", + currentValue: types.StringNull(), + dependentValue: types.StringNull(), + expectedError: false, + }, + { + name: "valid - current null, dependent is unknown", + currentValue: types.StringNull(), + dependentValue: types.StringUnknown(), + expectedError: false, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + // Create a simple schema for testing + testSchema := schema.Schema{ + Attributes: map[string]schema.Attribute{ + "custom_cert": schema.StringAttribute{ + Optional: true, + }, + "protocol": schema.StringAttribute{ + Optional: true, + }, + }, + } + + // Create raw config values + currentTfValue, err := testCase.currentValue.ToTerraformValue(context.Background()) + if err != nil { + t.Fatalf("Error converting current value: %v", err) + } + dependentTfValue, err := testCase.dependentValue.ToTerraformValue(context.Background()) + if err != nil { + t.Fatalf("Error converting dependent value: %v", err) + } + + rawConfigValues := map[string]tftypes.Value{ + "custom_cert": currentTfValue, + "protocol": dependentTfValue, + } + + rawConfig := tftypes.NewValue( + tftypes.Object{ + AttributeTypes: map[string]tftypes.Type{ + "custom_cert": tftypes.String, + "protocol": tftypes.String, + }, + }, + rawConfigValues, + ) + + config := tfsdk.Config{ + Raw: rawConfig, + Schema: testSchema, + } + + // Create validator - StringConditionalForbidden forbids the field when dependent matches forbidden values + v := ForbiddenIfDependentPathOneOf( + path.Root("protocol"), + []string{"https", "tls"}, + ) + + // Create validation request + request := validator.StringRequest{ + Path: path.Root("custom_cert"), + ConfigValue: testCase.currentValue, + Config: config, + } + + // Run validation + response := &validator.StringResponse{} + v.ValidateString(context.Background(), request, response) + + // Check result + if testCase.expectedError { + if !response.Diagnostics.HasError() { + t.Errorf("Expected validation error but got none") + } + } else { + if response.Diagnostics.HasError() { + t.Errorf("Expected no validation error but got: %v", response.Diagnostics.Errors()) + } + } + }) + } +} + +func TestForbiddenIfDependentPathOneOf_Description(t *testing.T) { + v := ForbiddenIfDependentPathOneOf( + path.Root("protocol"), + []string{"https", "tls"}, + ) + + description := v.Description(context.Background()) + // Note: Currently the Description method doesn't differentiate between allowed and forbidden + // This matches the current implementation behavior + require.Equal(t, "value cannot be set when protocol is one of [https tls]", description) +} + +func TestRequiredIfDependentPathOneOf(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + currentValue types.String + dependentValue types.String + expectedError bool + } + + testCases := []testCase{ + { + name: "valid - current set, dependent matches required value", + currentValue: types.StringValue("some_value"), + dependentValue: types.StringValue("ssl"), + expectedError: false, + }, + { + name: "valid - current null, dependent doesn't match required value", + currentValue: types.StringNull(), dependentValue: types.StringValue("none"), expectedError: false, }, + { + name: "valid - current unknown, dependent doesn't match required value", + currentValue: types.StringUnknown(), + dependentValue: types.StringValue("basic"), + expectedError: false, + }, + { + name: "valid - current set, dependent matches one of multiple allowed values", + currentValue: types.StringValue("certificate_path"), + dependentValue: types.StringValue("tls"), + expectedError: false, + }, + { + name: "invalid - current null, dependent matches required value", + currentValue: types.StringNull(), + dependentValue: types.StringValue("ssl"), + expectedError: true, + }, + { + name: "invalid - current unknown, dependent matches required value", + currentValue: types.StringUnknown(), + dependentValue: types.StringValue("tls"), + expectedError: true, + }, + { + name: "valid - current null, dependent is null", + currentValue: types.StringNull(), + dependentValue: types.StringNull(), + expectedError: false, + }, + { + name: "valid - current null, dependent is unknown", + currentValue: types.StringNull(), + dependentValue: types.StringUnknown(), + expectedError: false, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + // Create a simple schema for testing + testSchema := schema.Schema{ + Attributes: map[string]schema.Attribute{ + "ssl_cert": schema.StringAttribute{ + Optional: true, + }, + "security_mode": schema.StringAttribute{ + Optional: true, + }, + }, + } + + // Create raw config values + currentTfValue, err := testCase.currentValue.ToTerraformValue(context.Background()) + if err != nil { + t.Fatalf("Error converting current value: %v", err) + } + dependentTfValue, err := testCase.dependentValue.ToTerraformValue(context.Background()) + if err != nil { + t.Fatalf("Error converting dependent value: %v", err) + } + + rawConfigValues := map[string]tftypes.Value{ + "ssl_cert": currentTfValue, + "security_mode": dependentTfValue, + } + + rawConfig := tftypes.NewValue( + tftypes.Object{ + AttributeTypes: map[string]tftypes.Type{ + "ssl_cert": tftypes.String, + "security_mode": tftypes.String, + }, + }, + rawConfigValues, + ) + + config := tfsdk.Config{ + Raw: rawConfig, + Schema: testSchema, + } + + // Create validator - RequiredIfDependentPathOneOf requires the field when dependent matches allowed values + v := RequiredIfDependentPathOneOf( + path.Root("security_mode"), + []string{"ssl", "tls"}, + ) + + // Create validation request + request := validator.StringRequest{ + Path: path.Root("ssl_cert"), + ConfigValue: testCase.currentValue, + Config: config, + } + + // Run validation + response := &validator.StringResponse{} + v.ValidateString(context.Background(), request, response) + + // Check result + if testCase.expectedError { + if !response.Diagnostics.HasError() { + t.Errorf("Expected validation error but got none") + } + } else { + if response.Diagnostics.HasError() { + t.Errorf("Expected no validation error but got: %v", response.Diagnostics.Errors()) + } + } + }) + } +} + +func TestRequiredIfDependentPathOneOf_Description(t *testing.T) { + v := RequiredIfDependentPathOneOf( + path.Root("security_mode"), + []string{"ssl", "tls"}, + ) + + description := v.Description(context.Background()) + require.Equal(t, "value required when security_mode is one of [ssl tls]", description) +} + +func TestRequiredIfDependentPathEquals_Description(t *testing.T) { + v := RequiredIfDependentPathEquals( + path.Root("auth_type"), + "oauth", + ) + + description := v.Description(context.Background()) + expected := "value required when auth_type equals \"oauth\"" + + if description != expected { + t.Errorf("Expected description %q, got %q", expected, description) + } +} + +func TestDependantPathOneOf(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + currentValue types.String + dependentValue types.String + expectedError bool + } + + testCases := []testCase{ + { + name: "valid - current null, dependent matches allowed value", + currentValue: types.StringNull(), + dependentValue: types.StringValue("machine_learning"), + expectedError: false, + }, + { + name: "valid - current unknown, dependent matches allowed value", + currentValue: types.StringUnknown(), + dependentValue: types.StringValue("esql"), + expectedError: false, + }, { name: "valid - current set, dependent matches required value", - currentValue: types.Int64Value(6), - dependentValue: types.StringValue("gzip"), + currentValue: types.StringValue("some_value"), + dependentValue: types.StringValue("machine_learning"), expectedError: false, }, + { + name: "invalid - current null, dependent doesn't match required value", + currentValue: types.StringNull(), + dependentValue: types.StringValue("other_type"), + expectedError: true, + }, { name: "invalid - current set, dependent doesn't match required value", - currentValue: types.Int64Value(6), - dependentValue: types.StringValue("none"), + currentValue: types.StringValue("some_value"), + dependentValue: types.StringValue("other_type"), expectedError: true, }, { name: "invalid - current set, dependent is null", - currentValue: types.Int64Value(6), + currentValue: types.StringValue("some_value"), dependentValue: types.StringNull(), expectedError: true, }, { name: "invalid - current set, dependent is unknown", - currentValue: types.Int64Value(6), + currentValue: types.StringValue("some_value"), dependentValue: types.StringUnknown(), expectedError: true, }, @@ -210,10 +537,10 @@ func TestInt64ConditionalRequirement(t *testing.T) { // Create a simple schema for testing testSchema := schema.Schema{ Attributes: map[string]schema.Attribute{ - "compression_level": schema.Float64Attribute{ + "some_field": schema.StringAttribute{ Optional: true, }, - "compression": schema.StringAttribute{ + "type": schema.StringAttribute{ Optional: true, }, }, @@ -230,15 +557,15 @@ func TestInt64ConditionalRequirement(t *testing.T) { } rawConfigValues := map[string]tftypes.Value{ - "compression_level": currentTfValue, - "compression": dependentTfValue, + "some_field": currentTfValue, + "type": dependentTfValue, } rawConfig := tftypes.NewValue( tftypes.Object{ AttributeTypes: map[string]tftypes.Type{ - "compression_level": tftypes.Number, - "compression": tftypes.String, + "some_field": tftypes.String, + "type": tftypes.String, }, }, rawConfigValues, @@ -249,22 +576,22 @@ func TestInt64ConditionalRequirement(t *testing.T) { Schema: testSchema, } - // Create validator - v := Int64ConditionalRequirement( - path.Root("compression"), - []string{"gzip"}, + // Create validator - StringAssert validates that the dependent field matches allowed values + v := DependantPathOneOf( + path.Root("type"), + []string{"machine_learning", "esql"}, ) // Create validation request - request := validator.Int64Request{ - Path: path.Root("compression_level"), + request := validator.StringRequest{ + Path: path.Root("some_field"), ConfigValue: testCase.currentValue, Config: config, } // Run validation - response := &validator.Int64Response{} - v.ValidateInt64(context.Background(), request, response) + response := &validator.StringResponse{} + v.ValidateString(context.Background(), request, response) // Check result if testCase.expectedError { @@ -279,3 +606,17 @@ func TestInt64ConditionalRequirement(t *testing.T) { }) } } + +func TestDependantPathOneOf_Description(t *testing.T) { + v := DependantPathOneOf( + path.Root("type"), + []string{"machine_learning", "esql"}, + ) + + description := v.Description(context.Background()) + expected := "Attribute 'type' is not one of [machine_learning esql]" + + if description != expected { + t.Errorf("Expected description %q, got %q", expected, description) + } +}