diff --git a/Package.swift b/Package.swift index b8c1e5f..abe5f54 100644 --- a/Package.swift +++ b/Package.swift @@ -95,7 +95,7 @@ let package = Package( .library(name: "CassandraClient", targets: ["CassandraClient"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio", .upToNextMajor(from: "2.41.1")), + .package(url: "https://github.com/apple/swift-nio", .upToNextMajor(from: "2.42.0")), .package(url: "https://github.com/apple/swift-nio-ssl", .upToNextMajor(from: "2.21.0")), .package(url: "https://github.com/apple/swift-atomics", from: "1.0.2"), .package(url: "https://github.com/apple/swift-log", .upToNextMajor(from: "1.0.0")), diff --git a/Sources/CassandraClient/CassandraClient.swift b/Sources/CassandraClient/CassandraClient.swift index eb43e26..ab2c6fa 100644 --- a/Sources/CassandraClient/CassandraClient.swift +++ b/Sources/CassandraClient/CassandraClient.swift @@ -21,7 +21,7 @@ import NIOConcurrencyHelpers /// `CassandraClient` is a wrapper around the [Datastax Cassandra C++ Driver](https://github.com/datastax/cpp-driver) /// and can be used to run queries against a Cassandra database. public class CassandraClient: CassandraSession { - private let eventLoopGroupContainer: EventLoopGroupConainer + private let eventLoopGroupContainer: EventLoopGroupContainer public var eventLoopGroup: EventLoopGroup { self.eventLoopGroupContainer.value } @@ -247,4 +247,4 @@ extension CassandraClient { } #endif -internal typealias EventLoopGroupConainer = (value: EventLoopGroup, managed: Bool) +internal typealias EventLoopGroupContainer = (value: EventLoopGroup, managed: Bool) diff --git a/Sources/CassandraClient/Session.swift b/Sources/CassandraClient/Session.swift index 51fa7cf..1842b99 100644 --- a/Sources/CassandraClient/Session.swift +++ b/Sources/CassandraClient/Session.swift @@ -190,15 +190,14 @@ extension CassandraSession { extension CassandraClient { internal final class Session: CassandraSession { - private let eventLoopGroupContainer: EventLoopGroupConainer + private let eventLoopGroupContainer: EventLoopGroupContainer public var eventLoopGroup: EventLoopGroup { self.eventLoopGroupContainer.value } private let configuration: Configuration private let logger: Logger - private var state = State.idle - private let lock = Lock() + private let stateStore = NIOLockedValueBox(State.idle) private let rawPointer: OpaquePointer @@ -212,7 +211,7 @@ extension CassandraClient { case disconnected } - internal init(configuration: Configuration, logger: Logger, eventLoopGroupContainer: EventLoopGroupConainer) { + internal init(configuration: Configuration, logger: Logger, eventLoopGroupContainer: EventLoopGroupContainer) { self.configuration = configuration self.logger = logger self.eventLoopGroupContainer = eventLoopGroupContainer @@ -220,23 +219,29 @@ extension CassandraClient { } deinit { - guard case .disconnected = (self.lock.withLock { self.state }) else { + let isDisconnected = self.stateStore.withLockedValue { state in + if case .disconnected = state { + return true + } + return false + } + guard isDisconnected else { preconditionFailure("Session not shut down before the deinit. Please call session.shutdown() when no longer needed.") } cass_session_free(self.rawPointer) } func shutdown() throws { - self.lock.lock() - defer { - self.state = .disconnected - self.lock.unlock() - } - switch self.state { - case .connected: - try self.disconect() - default: - break + try self.stateStore.withLockedValue { (state: inout State) in + defer { + state = .disconnected + } + switch state { + case .connected: + try self.disconect() + default: + break + } } } @@ -244,26 +249,30 @@ extension CassandraClient { let eventLoop = eventLoop ?? self.eventLoopGroup.next() let logger = logger ?? self.logger - self.lock.lock() - switch self.state { + let (startingState, future) = self.stateStore.withLockedValue { (state: inout State) -> (State, EventLoopFuture?) in + if case .idle = state { + let future = self.connect(on: eventLoop, logger: logger) + state = .connectingFuture(future) + return (.idle, future) + } else { + return (state, nil) + } + } + + switch startingState { case .idle: - let future = self.connect(on: eventLoop, logger: logger) - self.state = .connectingFuture(future) - self.lock.unlock() - return future.flatMap { _ in - self.lock.withLock { - self.state = .connected + return future!.flatMap { _ in + self.stateStore.withLockedValue { (state: inout State) in + state = .connected } return self.execute(statement: statement, on: eventLoop, logger: logger) } case .connectingFuture(let future): - self.lock.unlock() return future.flatMap { _ in self.execute(statement: statement, on: eventLoop, logger: logger) } #if compiler(>=5.5) && canImport(_Concurrency) case .connecting(let task): - self.lock.unlock() let promise = eventLoop.makePromise(of: Rows.self) if #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) { promise.completeWithTask { @@ -274,7 +283,6 @@ extension CassandraClient { return promise.futureResult #endif case .connected: - self.lock.unlock() logger.debug("executing: \(statement.query)") logger.trace("\(statement.parameters)") let promise = eventLoop.makePromise(of: Rows.self) @@ -284,7 +292,6 @@ extension CassandraClient { } return promise.futureResult case .disconnected: - self.lock.unlock() if self.eventLoopGroupContainer.managed { // eventloop *is* shutdown now preconditionFailure("client is disconnected") @@ -444,28 +451,30 @@ extension CassandraClient.Session { func execute(statement: CassandraClient.Statement, logger: Logger? = .none) async throws -> CassandraClient.Rows { let logger = logger ?? self.logger - lock.lock() - switch state { - case .idle: - let task = self.connect(logger: logger) - state = .connecting(ConnectionTask(task)) - lock.unlock() + let (startingState, task) = self.stateStore.withLockedValue { (state: inout State) -> (State, Task?) in + if case .idle = state { + let task = self.connect(logger: logger) + state = .connecting(ConnectionTask(task)) + return (.idle, task) + } else { + return (state, nil) + } + } - try await task.value - lock.withLock { - self.state = .connected + switch startingState { + case .idle: + try await task!.value + self.stateStore.withLockedValue { (state: inout State) in + state = .connected } return try await self.execute(statement: statement, logger: logger) case .connectingFuture(let future): - lock.unlock() try await future.get() return try await self.execute(statement: statement, logger: logger) case .connecting(let task): - lock.unlock() try await task.task.value return try await self.execute(statement: statement, logger: logger) case .connected: - lock.unlock() logger.debug("executing: \(statement.query)") logger.trace("\(statement.parameters)") let future = cass_session_execute(rawPointer, statement.rawPointer) @@ -475,7 +484,6 @@ extension CassandraClient.Session { } } case .disconnected: - lock.unlock() if eventLoopGroupContainer.managed { // eventloop *is* shutdown now preconditionFailure("client is disconnected")