Skip to content

Commit 0e0ac3b

Browse files
Script use one connection (#44)
# Context We want to ensure that all commands are run on a single connection # Change Concatenete the commands to have them run at once in a connection
1 parent ab2f4fb commit 0e0ac3b

File tree

2 files changed

+101
-10
lines changed

2 files changed

+101
-10
lines changed

postgresql/resource_postgresql_script.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,27 @@ func executeCommands(ctx context.Context, db *DBConnection, commands []string, t
168168
func executeBatch(ctx context.Context, db *DBConnection, commands []string, timeout int) error {
169169
timeoutContext, timeoutCancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
170170
defer timeoutCancel()
171-
for _, command := range commands {
172-
log.Printf("[DEBUG] Executing %s", command)
173-
_, err := db.ExecContext(timeoutContext, command)
174-
log.Printf("[DEBUG] Result %s: %v", command, err)
175-
if err != nil {
176-
log.Println("[DEBUG] Error catched:", err)
177-
if _, rollbackError := db.Query("ROLLBACK"); rollbackError != nil {
178-
log.Println("[DEBUG] Rollback raised an error:", rollbackError)
179-
}
180-
return err
171+
172+
// Concatenate all commands into a single SQL statement to ensure they run on one connection
173+
concatenatedSQL := ""
174+
for i, command := range commands {
175+
if i > 0 {
176+
concatenatedSQL += "; "
181177
}
178+
concatenatedSQL += command
182179
}
180+
181+
log.Printf("[DEBUG] Executing concatenated SQL: %s", concatenatedSQL)
182+
_, err := db.ExecContext(timeoutContext, concatenatedSQL)
183+
log.Printf("[DEBUG] Result: %v", err)
184+
if err != nil {
185+
log.Println("[DEBUG] Error catched:", err)
186+
if _, rollbackError := db.Query("ROLLBACK"); rollbackError != nil {
187+
log.Println("[DEBUG] Rollback raised an error:", rollbackError)
188+
}
189+
return err
190+
}
191+
183192
return nil
184193
}
185194

postgresql/resource_postgresql_script_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,85 @@ func testAccCheckTableHasRecords(dbName, tableName string, expectedCount int) re
352352
return nil
353353
}
354354
}
355+
356+
func TestAccPostgresqlScript_setLocalRoleWorks(t *testing.T) {
357+
// This test demonstrates that SET LOCAL ROLE works across separate commands
358+
// because they are concatenated into a single SQL statement to be executed on a single connection
359+
config := `
360+
resource "postgresql_role" "test_role" {
361+
name = "test_owner_role"
362+
}
363+
364+
resource "postgresql_script" "grant_role" {
365+
commands = [
366+
"GRANT test_owner_role TO CURRENT_USER",
367+
"GRANT CREATE ON SCHEMA public TO test_owner_role"
368+
]
369+
depends_on = [postgresql_role.test_role]
370+
}
371+
372+
resource "postgresql_script" "test_with_set_local_separate" {
373+
commands = [
374+
"BEGIN",
375+
"SET LOCAL ROLE test_owner_role",
376+
"CREATE TABLE test_set_locals (id INT)",
377+
"COMMIT"
378+
]
379+
depends_on = [postgresql_script.grant_role]
380+
}
381+
`
382+
383+
resource.Test(t, resource.TestCase{
384+
PreCheck: func() { testAccPreCheck(t) },
385+
Providers: testAccProviders,
386+
CheckDestroy: testAccCheckSetLocalRoleTablesDestroyed,
387+
Steps: []resource.TestStep{
388+
{
389+
Config: config,
390+
Check: resource.ComposeTestCheckFunc(
391+
testAccCheckTableExistsInDatabase("postgres", "test_set_locals"),
392+
// Both commands should now work with SET LOCAL ROLE since commands are concatenated
393+
testAccCheckTableOwner("postgres", "test_set_locals", "test_owner_role"),
394+
),
395+
},
396+
},
397+
})
398+
}
399+
400+
func testAccCheckTableOwner(dbName, tableName, expectedOwner string) resource.TestCheckFunc {
401+
return func(s *terraform.State) error {
402+
client := testAccProvider.Meta().(*Client)
403+
dbClient := client.config.NewClient(dbName)
404+
db, err := dbClient.Connect()
405+
if err != nil {
406+
return fmt.Errorf("Error connecting to database %s: %s", dbName, err)
407+
}
408+
409+
var owner string
410+
query := `SELECT tableowner FROM pg_tables WHERE schemaname = 'public' AND tablename = $1`
411+
err = db.QueryRow(query, tableName).Scan(&owner)
412+
if err != nil {
413+
return fmt.Errorf("Error checking owner of table %s: %s", tableName, err)
414+
}
415+
416+
if owner != expectedOwner {
417+
return fmt.Errorf("Expected table %s to be owned by %s but got %s", tableName, expectedOwner, owner)
418+
}
419+
420+
return nil
421+
}
422+
}
423+
424+
func testAccCheckSetLocalRoleTablesDestroyed(s *terraform.State) error {
425+
client := testAccProvider.Meta().(*Client)
426+
db, err := client.Connect()
427+
if err != nil {
428+
return nil // Skip if we can't connect
429+
}
430+
431+
_, _ = db.Exec("DROP TABLE IF EXISTS test_set_local_separate")
432+
_, _ = db.Exec("DROP TABLE IF EXISTS test_set_local_single")
433+
_, _ = db.Exec("DROP ROLE IF EXISTS test_owner_role")
434+
435+
return nil
436+
}

0 commit comments

Comments
 (0)