diff --git a/src/Tests/ApprovalFiles/NoPublicApiChanges.Run.approved.cs b/src/Tests/ApprovalFiles/NoPublicApiChanges.Run.approved.cs index b3131dc..5b70f6b 100644 --- a/src/Tests/ApprovalFiles/NoPublicApiChanges.Run.approved.cs +++ b/src/Tests/ApprovalFiles/NoPublicApiChanges.Run.approved.cs @@ -25,6 +25,7 @@ public static void PostgresqlDatabase(this DbUp.SupportedDatabasesForDropDatabas public static void PostgresqlDatabase(this DbUp.SupportedDatabasesForDropDatabase supported, string connectionString, DbUp.Engine.Output.IUpgradeLog logger, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) { } public static void PostgresqlDatabase(this DbUp.SupportedDatabasesForEnsureDatabase supported, string connectionString, DbUp.Engine.Output.IUpgradeLog logger, DbUp.Postgresql.PostgresqlConnectionOptions connectionOptions) { } public static void PostgresqlDatabase(this DbUp.SupportedDatabasesForEnsureDatabase supported, string connectionString, DbUp.Engine.Output.IUpgradeLog logger, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) { } + public static void PostgresqlDatabase(this DbUp.SupportedDatabasesForEnsureDatabase supported, string connectionString, DbUp.Engine.Output.IUpgradeLog logger, DbUp.Postgresql.PostgresqlConnectionOptions connectionOptions, string owner) { } } namespace DbUp.Postgresql { diff --git a/src/dbup-postgresql/PostgresqlExtensions.cs b/src/dbup-postgresql/PostgresqlExtensions.cs index 8c965ab..a20b10c 100644 --- a/src/dbup-postgresql/PostgresqlExtensions.cs +++ b/src/dbup-postgresql/PostgresqlExtensions.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Data; using System.Security.Cryptography.X509Certificates; using System.Text.RegularExpressions; @@ -16,7 +16,7 @@ /// public static class PostgresqlExtensions { - private static readonly string pattern= @"(?i)Search\s?Path=([^;]+)"; + private static readonly string pattern = @"(?i)Search\s?Path=([^;]+)"; /// /// Creates an upgrader for PostgreSQL databases. /// @@ -175,7 +175,7 @@ public static void PostgresqlDatabase(this SupportedDatabasesForEnsureDatabase s public static void PostgresqlDatabase(this SupportedDatabasesForEnsureDatabase supported, string connectionString, IUpgradeLog logger, X509Certificate2 certificate) { var options = new PostgresqlConnectionOptions - { + { ClientCertificate = certificate }; PostgresqlDatabase(supported, connectionString, logger, options); @@ -189,11 +189,30 @@ public static void PostgresqlDatabase(this SupportedDatabasesForEnsureDatabase s /// The used to record actions. /// Connection options to set SSL parameters public static void PostgresqlDatabase( - this SupportedDatabasesForEnsureDatabase supported, - string connectionString, - IUpgradeLog logger, + this SupportedDatabasesForEnsureDatabase supported, + string connectionString, + IUpgradeLog logger, PostgresqlConnectionOptions connectionOptions ) + { + PostgresqlDatabase(supported, connectionString, logger, connectionOptions, null); + } + + /// + /// Ensures that the database specified in the connection string exists, assigning an owner at creation time. + /// + /// Fluent helper type. + /// The connection string. + /// The used to record actions. + /// Connection options to set SSL parameters + /// Role to own the new database during creation (adds 'WITH OWNER = "role"'). + public static void PostgresqlDatabase( + this SupportedDatabasesForEnsureDatabase supported, + string connectionString, + IUpgradeLog logger, + PostgresqlConnectionOptions connectionOptions, + string owner + ) { if (supported == null) throw new ArgumentNullException("supported"); @@ -205,7 +224,7 @@ PostgresqlConnectionOptions connectionOptions if (logger == null) throw new ArgumentNullException("logger"); var masterConnectionStringBuilder = new NpgsqlConnectionStringBuilder(connectionString); - + var databaseName = masterConnectionStringBuilder.Database; if (string.IsNullOrEmpty(databaseName) || databaseName.Trim() == string.Empty) @@ -232,9 +251,9 @@ PostgresqlConnectionOptions connectionOptions // check to see if the database already exists.. using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + { + CommandType = CommandType.Text + }) { var results = Convert.ToInt32(command.ExecuteScalar()); @@ -245,18 +264,56 @@ PostgresqlConnectionOptions connectionOptions } } - sqlCommandText = $"create database \"{databaseName}\";"; - - // Create the database... - using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + if (string.IsNullOrEmpty(owner)) { - command.ExecuteNonQuery(); + sqlCommandText = $"create database \"{databaseName}\";"; + + // Create the database... + using (var command = new NpgsqlCommand(sqlCommandText, connection) + { + CommandType = CommandType.Text + }) + { + command.ExecuteNonQuery(); + } + + logger.LogInformation(@"Created database {0}", databaseName); } + else + { + sqlCommandText = "select exists (select 1 from pg_roles where rolname = @owner);"; + // check to see if the owner exists.. + using (var command = new NpgsqlCommand(sqlCommandText, connection) + { + CommandType = CommandType.Text + }) + { + command.Parameters.AddWithValue("@owner", owner); + + var roleExists = (bool)command.ExecuteScalar(); + // if the owner role does not exist, we throw an exception. + if (!roleExists) + { + throw new InvalidOperationException($"PostgreSQL role '{owner}' does not exist."); + } + } - logger.LogInformation(@"Created database {0}", databaseName); + using var formattedSql = new NpgsqlCommand("select format('create database %I with owner = %I', @databaseName, @owner);", connection); + formattedSql.Parameters.AddWithValue("databaseName", databaseName); + formattedSql.Parameters.AddWithValue("owner", owner); + sqlCommandText = (string)formattedSql.ExecuteScalar(); + + // Create the database.. + using (var command = new NpgsqlCommand(sqlCommandText, connection) + { + CommandType = CommandType.Text, + }) + { + command.ExecuteNonQuery(); + } + + logger.LogInformation(@"Created database {0} with owner {1}", databaseName, owner); + } } /// @@ -347,7 +404,7 @@ PostgresqlConnectionOptions connectionOptions var masterConnectionStringBuilder = new NpgsqlConnectionStringBuilder(connectionString); - var databaseName = masterConnectionStringBuilder.Database; + var databaseName = masterConnectionStringBuilder.Database; if (string.IsNullOrEmpty(databaseName) || databaseName.Trim() == string.Empty) { @@ -379,9 +436,9 @@ PostgresqlConnectionOptions connectionOptions // check to see if the database already exists.. using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + { + CommandType = CommandType.Text + }) { var results = Convert.ToInt32(command.ExecuteScalar()); @@ -396,9 +453,9 @@ PostgresqlConnectionOptions connectionOptions // prevent new connections to the database sqlCommandText = $"alter database \"{databaseName}\" with ALLOW_CONNECTIONS false;"; using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + { + CommandType = CommandType.Text + }) { command.ExecuteNonQuery(); } @@ -408,9 +465,9 @@ PostgresqlConnectionOptions connectionOptions // terminate all existing connections to the database sqlCommandText = $"select pg_terminate_backend(pg_stat_activity.pid) from pg_stat_activity where pg_stat_activity.datname = \'{databaseName}\';"; using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + { + CommandType = CommandType.Text + }) { command.ExecuteNonQuery(); } @@ -421,9 +478,9 @@ PostgresqlConnectionOptions connectionOptions // drop the database using (var command = new NpgsqlCommand(sqlCommandText, connection) - { - CommandType = CommandType.Text - }) + { + CommandType = CommandType.Text + }) { command.ExecuteNonQuery(); }