diff --git a/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java b/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java index 2f4df1eba03b4..3d901881a5977 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java @@ -1600,13 +1600,13 @@ public int doFinal(ByteBuffer src, ByteBuffer dst) Arrays.fill(dst.array(), ofs, ofs + len, (byte) 0); } else { - NIO_ACCESS.acquireSession(dst); + int ticket = NIO_ACCESS.acquireSession(dst); try { Unsafe.getUnsafe().setMemory( NIO_ACCESS.getBufferAddress(dst), len + dst.position(), (byte) 0); } finally { - NIO_ACCESS.releaseSession(dst); + NIO_ACCESS.releaseSession(dst, ticket); } } } diff --git a/src/java.base/share/classes/java/lang/ClassLoader.java b/src/java.base/share/classes/java/lang/ClassLoader.java index fdb5a80e2a2a1..efad002af4827 100644 --- a/src/java.base/share/classes/java/lang/ClassLoader.java +++ b/src/java.base/share/classes/java/lang/ClassLoader.java @@ -1075,13 +1075,13 @@ protected final Class defineClass(String name, ByteBuffer b, private Class defineClass(String name, ByteBuffer b, int len, ProtectionDomain pd) { pd = preDefineClass(name, pd); String source = defineClassSourceLocation(pd); - SharedSecrets.getJavaNioAccess().acquireSession(b); + int ticket = SharedSecrets.getJavaNioAccess().acquireSession(b); try { Class c = defineClass2(this, name, b, b.position(), len, pd, source); postDefineClass(c, pd); return c; } finally { - SharedSecrets.getJavaNioAccess().releaseSession(b); + SharedSecrets.getJavaNioAccess().releaseSession(b, ticket); } } diff --git a/src/java.base/share/classes/java/nio/Buffer.java b/src/java.base/share/classes/java/nio/Buffer.java index c0c2cfe80e7f0..b0555708eed87 100644 --- a/src/java.base/share/classes/java/nio/Buffer.java +++ b/src/java.base/share/classes/java/nio/Buffer.java @@ -878,19 +878,20 @@ public MemorySegment bufferSegment(Buffer buffer) { } @Override - public void acquireSession(Buffer buffer) { + public int acquireSession(Buffer buffer) { var scope = buffer.session(); if (scope != null) { - scope.acquire0(); + return scope.acquire0(); } + return 0; } @Override - public void releaseSession(Buffer buffer) { + public void releaseSession(Buffer buffer, int ticket) { try { var scope = buffer.session(); if (scope != null) { - scope.release0(); + scope.release0(ticket); } } finally { Reference.reachabilityFence(buffer); diff --git a/src/java.base/share/classes/java/util/zip/Adler32.java b/src/java.base/share/classes/java/util/zip/Adler32.java index fa2fdb1f4c3a0..5c3525e2c5ed5 100644 --- a/src/java.base/share/classes/java/util/zip/Adler32.java +++ b/src/java.base/share/classes/java/util/zip/Adler32.java @@ -97,11 +97,11 @@ public void update(ByteBuffer buffer) { if (rem <= 0) return; if (buffer.isDirect()) { - NIO_ACCESS.acquireSession(buffer); + int ticket = NIO_ACCESS.acquireSession(buffer); try { adler = updateByteBuffer(adler, NIO_ACCESS.getBufferAddress(buffer), pos, rem); } finally { - NIO_ACCESS.releaseSession(buffer); + NIO_ACCESS.releaseSession(buffer, ticket); } } else if (buffer.hasArray()) { adler = updateBytes(adler, buffer.array(), pos + buffer.arrayOffset(), rem); diff --git a/src/java.base/share/classes/java/util/zip/CRC32.java b/src/java.base/share/classes/java/util/zip/CRC32.java index 80b4f05a3c1c6..bcf7c98fda481 100644 --- a/src/java.base/share/classes/java/util/zip/CRC32.java +++ b/src/java.base/share/classes/java/util/zip/CRC32.java @@ -96,11 +96,11 @@ public void update(ByteBuffer buffer) { if (rem <= 0) return; if (buffer.isDirect()) { - NIO_ACCESS.acquireSession(buffer); + int ticket = NIO_ACCESS.acquireSession(buffer); try { crc = updateByteBuffer(crc, NIO_ACCESS.getBufferAddress(buffer), pos, rem); } finally { - NIO_ACCESS.releaseSession(buffer); + NIO_ACCESS.releaseSession(buffer, ticket); } } else if (buffer.hasArray()) { crc = updateBytes(crc, buffer.array(), pos + buffer.arrayOffset(), rem); diff --git a/src/java.base/share/classes/java/util/zip/CRC32C.java b/src/java.base/share/classes/java/util/zip/CRC32C.java index 5b4b3597e9c39..d711a3f19c6b1 100644 --- a/src/java.base/share/classes/java/util/zip/CRC32C.java +++ b/src/java.base/share/classes/java/util/zip/CRC32C.java @@ -174,12 +174,12 @@ public void update(ByteBuffer buffer) { } if (buffer.isDirect()) { - NIO_ACCESS.acquireSession(buffer); + int ticket = NIO_ACCESS.acquireSession(buffer); try { crc = updateDirectByteBuffer(crc, NIO_ACCESS.getBufferAddress(buffer), pos, limit); } finally { - NIO_ACCESS.releaseSession(buffer); + NIO_ACCESS.releaseSession(buffer, ticket); } } else if (buffer.hasArray()) { crc = updateBytes(crc, buffer.array(), pos + buffer.arrayOffset(), diff --git a/src/java.base/share/classes/java/util/zip/Deflater.java b/src/java.base/share/classes/java/util/zip/Deflater.java index c3ce84263b29b..c0b578c74ebd1 100644 --- a/src/java.base/share/classes/java/util/zip/Deflater.java +++ b/src/java.base/share/classes/java/util/zip/Deflater.java @@ -319,12 +319,12 @@ public void setDictionary(ByteBuffer dictionary) { int remaining = Math.max(dictionary.limit() - position, 0); ensureOpen(); if (dictionary.isDirect()) { - NIO_ACCESS.acquireSession(dictionary); + int ticket = NIO_ACCESS.acquireSession(dictionary); try { long address = NIO_ACCESS.getBufferAddress(dictionary); setDictionaryBuffer(zsRef.address(), address + position, remaining); } finally { - NIO_ACCESS.releaseSession(dictionary); + NIO_ACCESS.releaseSession(dictionary, ticket); } } else { byte[] array = ZipUtils.getBufferArray(dictionary); @@ -574,7 +574,7 @@ public int deflate(byte[] output, int off, int len, int flush) { inputPos = input.position(); int inputRem = Math.max(input.limit() - inputPos, 0); if (input.isDirect()) { - NIO_ACCESS.acquireSession(input); + int ticket = NIO_ACCESS.acquireSession(input); try { long inputAddress = NIO_ACCESS.getBufferAddress(input); result = deflateBufferBytes(zsRef.address(), @@ -582,7 +582,7 @@ public int deflate(byte[] output, int off, int len, int flush) { output, off, len, flush, params); } finally { - NIO_ACCESS.releaseSession(input); + NIO_ACCESS.releaseSession(input, ticket); } } else { byte[] inputArray = ZipUtils.getBufferArray(input); @@ -698,7 +698,7 @@ public int deflate(ByteBuffer output, int flush) { if (input == null) { inputPos = this.inputPos; if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket = NIO_ACCESS.acquireSession(output); try { long outputAddress = NIO_ACCESS.getBufferAddress(output); result = deflateBytesBuffer(zsRef.address(), @@ -706,7 +706,7 @@ public int deflate(ByteBuffer output, int flush) { outputAddress + outputPos, outputRem, flush, params); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); @@ -720,11 +720,11 @@ public int deflate(ByteBuffer output, int flush) { inputPos = input.position(); int inputRem = Math.max(input.limit() - inputPos, 0); if (input.isDirect()) { - NIO_ACCESS.acquireSession(input); + int ticket = NIO_ACCESS.acquireSession(input); try { long inputAddress = NIO_ACCESS.getBufferAddress(input); if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket2 = NIO_ACCESS.acquireSession(output); try { long outputAddress = outputPos + NIO_ACCESS.getBufferAddress(output); result = deflateBufferBuffer(zsRef.address(), @@ -732,7 +732,7 @@ public int deflate(ByteBuffer output, int flush) { outputAddress, outputRem, flush, params); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket2); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); @@ -743,13 +743,13 @@ public int deflate(ByteBuffer output, int flush) { flush, params); } } finally { - NIO_ACCESS.releaseSession(input); + NIO_ACCESS.releaseSession(input, ticket); } } else { byte[] inputArray = ZipUtils.getBufferArray(input); int inputOffset = ZipUtils.getBufferOffset(input); if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket = NIO_ACCESS.acquireSession(output); try { long outputAddress = NIO_ACCESS.getBufferAddress(output); result = deflateBytesBuffer(zsRef.address(), @@ -757,7 +757,7 @@ public int deflate(ByteBuffer output, int flush) { outputAddress + outputPos, outputRem, flush, params); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); diff --git a/src/java.base/share/classes/java/util/zip/Inflater.java b/src/java.base/share/classes/java/util/zip/Inflater.java index 5ced5a4879a82..77aa07e9805a4 100644 --- a/src/java.base/share/classes/java/util/zip/Inflater.java +++ b/src/java.base/share/classes/java/util/zip/Inflater.java @@ -240,12 +240,12 @@ public void setDictionary(ByteBuffer dictionary) { int remaining = Math.max(dictionary.limit() - position, 0); ensureOpen(); if (dictionary.isDirect()) { - NIO_ACCESS.acquireSession(dictionary); + int ticket = NIO_ACCESS.acquireSession(dictionary); try { long address = NIO_ACCESS.getBufferAddress(dictionary); setDictionaryBuffer(zsRef.address(), address + position, remaining); } finally { - NIO_ACCESS.releaseSession(dictionary); + NIO_ACCESS.releaseSession(dictionary, ticket); } } else { byte[] array = ZipUtils.getBufferArray(dictionary); @@ -363,14 +363,14 @@ public int inflate(byte[] output, int off, int len) try { int inputRem = Math.max(input.limit() - inputPos, 0); if (input.isDirect()) { - NIO_ACCESS.acquireSession(input); + int ticket = NIO_ACCESS.acquireSession(input); try { long inputAddress = NIO_ACCESS.getBufferAddress(input); result = inflateBufferBytes(zsRef.address(), inputAddress + inputPos, inputRem, output, off, len); } finally { - NIO_ACCESS.releaseSession(input); + NIO_ACCESS.releaseSession(input, ticket); } } else { byte[] inputArray = ZipUtils.getBufferArray(input); @@ -500,14 +500,14 @@ public int inflate(ByteBuffer output) throws DataFormatException { inputPos = this.inputPos; try { if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket = NIO_ACCESS.acquireSession(output); try { long outputAddress = NIO_ACCESS.getBufferAddress(output); result = inflateBytesBuffer(zsRef.address(), inputArray, inputPos, inputLim - inputPos, outputAddress + outputPos, outputRem); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); @@ -525,18 +525,18 @@ public int inflate(ByteBuffer output) throws DataFormatException { int inputRem = Math.max(input.limit() - inputPos, 0); try { if (input.isDirect()) { - NIO_ACCESS.acquireSession(input); + int ticket = NIO_ACCESS.acquireSession(input); try { long inputAddress = NIO_ACCESS.getBufferAddress(input); if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket2 = NIO_ACCESS.acquireSession(output); try { long outputAddress = NIO_ACCESS.getBufferAddress(output); result = inflateBufferBuffer(zsRef.address(), inputAddress + inputPos, inputRem, outputAddress + outputPos, outputRem); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket2); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); @@ -546,20 +546,20 @@ public int inflate(ByteBuffer output) throws DataFormatException { outputArray, outputOffset + outputPos, outputRem); } } finally { - NIO_ACCESS.releaseSession(input); + NIO_ACCESS.releaseSession(input, ticket); } } else { byte[] inputArray = ZipUtils.getBufferArray(input); int inputOffset = ZipUtils.getBufferOffset(input); if (output.isDirect()) { - NIO_ACCESS.acquireSession(output); + int ticket = NIO_ACCESS.acquireSession(output); try { long outputAddress = NIO_ACCESS.getBufferAddress(output); result = inflateBytesBuffer(zsRef.address(), inputArray, inputOffset + inputPos, inputRem, outputAddress + outputPos, outputRem); } finally { - NIO_ACCESS.releaseSession(output); + NIO_ACCESS.releaseSession(output, ticket); } } else { byte[] outputArray = ZipUtils.getBufferArray(output); diff --git a/src/java.base/share/classes/jdk/internal/access/JavaNioAccess.java b/src/java.base/share/classes/jdk/internal/access/JavaNioAccess.java index 431ffbdab5378..c4180ab5712e2 100644 --- a/src/java.base/share/classes/jdk/internal/access/JavaNioAccess.java +++ b/src/java.base/share/classes/jdk/internal/access/JavaNioAccess.java @@ -96,19 +96,19 @@ public interface JavaNioAccess { * Used by operations to make a buffer's session non-closeable * (for the duration of the operation) by acquiring the session. * {@snippet lang = java: - * acquireSession(buffer); + * int ticket = acquireSession(buffer); * try { * performOperation(buffer); * } finally { - * releaseSession(buffer); + * releaseSession(buffer, ticket); * } *} * - * @see #releaseSession(Buffer) + * @see #releaseSession(Buffer, int) */ - void acquireSession(Buffer buffer); + int acquireSession(Buffer buffer); - void releaseSession(Buffer buffer); + void releaseSession(Buffer buffer, int ticket); boolean isThreadConfined(Buffer buffer); diff --git a/src/java.base/share/classes/jdk/internal/foreign/ConfinedSession.java b/src/java.base/share/classes/jdk/internal/foreign/ConfinedSession.java index 47dfc69e887dc..f83cfd39ffb80 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/ConfinedSession.java +++ b/src/java.base/share/classes/jdk/internal/foreign/ConfinedSession.java @@ -49,17 +49,18 @@ public ConfinedSession(Thread owner) { @Override @ForceInline - public void acquire0() { + public int acquire0() { checkValidState(); if (acquireCount == MAX_FORKS) { throw tooManyAcquires(); } acquireCount++; + return 0; } @Override @ForceInline - public void release0() { + public void release0(int ticket) { if (Thread.currentThread() == owner) { acquireCount--; } else { diff --git a/src/java.base/share/classes/jdk/internal/foreign/GlobalSession.java b/src/java.base/share/classes/jdk/internal/foreign/GlobalSession.java index 8d834b23eb2e7..b355e1cc81950 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/GlobalSession.java +++ b/src/java.base/share/classes/jdk/internal/foreign/GlobalSession.java @@ -43,14 +43,15 @@ public GlobalSession() { @Override @ForceInline - public void release0() { + public void release0(int ticket) { // do nothing } @Override @ForceInline - public void acquire0() { + public int acquire0() { // do nothing + return 0; } @Override diff --git a/src/java.base/share/classes/jdk/internal/foreign/ImplicitSession.java b/src/java.base/share/classes/jdk/internal/foreign/ImplicitSession.java index 93a899329b7af..f47aa3314d39c 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/ImplicitSession.java +++ b/src/java.base/share/classes/jdk/internal/foreign/ImplicitSession.java @@ -38,22 +38,23 @@ * {@link DirectBuffer#address()}, where obtaining an address of a buffer instance associated * with a potentially closeable session is forbidden. */ -final class ImplicitSession extends SharedSession { +final class ImplicitSession extends MemorySessionImpl { public ImplicitSession(Cleaner cleaner) { - super(); + super(null, new SharedResourceList()); this.state = NONCLOSEABLE; cleaner.register(this, resourceList); } @Override - public void release0() { + public void release0(int ticket) { Reference.reachabilityFence(this); } @Override - public void acquire0() { + public int acquire0() { // do nothing + return 0; } @Override diff --git a/src/java.base/share/classes/jdk/internal/foreign/MemorySessionImpl.java b/src/java.base/share/classes/jdk/internal/foreign/MemorySessionImpl.java index 2163146f1a93e..ecea59d42635c 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/MemorySessionImpl.java +++ b/src/java.base/share/classes/jdk/internal/foreign/MemorySessionImpl.java @@ -55,7 +55,7 @@ */ public abstract sealed class MemorySessionImpl implements Scope - permits ConfinedSession, GlobalSession, SharedSession { + permits ConfinedSession, GlobalSession, SharedSession, ImplicitSession { /** * The value of the {@code state} of a {@code MemorySessionImpl}. The only possible transition @@ -156,17 +156,17 @@ public static MemorySessionImpl createHeap(Object ref) { return new HeapSession(ref); } - public abstract void release0(); + public abstract void release0(int ticket); - public abstract void acquire0(); + public abstract int acquire0(); public void whileAlive(Runnable action) { Objects.requireNonNull(action); - acquire0(); + int ticket = acquire0(); try { action.run(); } finally { - release0(); + release0(ticket); } } @@ -227,7 +227,7 @@ protected Object clone() throws CloneNotSupportedException { throw new CloneNotSupportedException(); } - public final boolean isCloseable() { + public boolean isCloseable() { return state <= OPEN; } diff --git a/src/java.base/share/classes/jdk/internal/foreign/SharedResourceList.java b/src/java.base/share/classes/jdk/internal/foreign/SharedResourceList.java new file mode 100644 index 0000000000000..a0f8043285be2 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/foreign/SharedResourceList.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.foreign; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; + +import jdk.internal.foreign.MemorySessionImpl.ResourceList; +import jdk.internal.foreign.MemorySessionImpl.ResourceList.ResourceCleanup; +import jdk.internal.invoke.MhUtil; + + +/** + * A shared resource list; this implementation has to handle add vs. add races, as well as add vs. cleanup races. + */ +class SharedResourceList extends ResourceList { + + static final VarHandle FST = MhUtil.findVarHandle( + MethodHandles.lookup(), ResourceList.class, "fst", ResourceCleanup.class); + + @Override + void add(ResourceCleanup cleanup) { + while (true) { + ResourceCleanup prev = (ResourceCleanup) FST.getVolatile(this); + if (prev == ResourceCleanup.CLOSED_LIST) { + // too late + throw MemorySessionImpl.alreadyClosed(); + } + cleanup.next = prev; + if (FST.compareAndSet(this, prev, cleanup)) { + return; //victory + } + // keep trying + } + } + + void cleanup() { + // At this point we are only interested about add vs. close races - not close vs. close + // (because MemorySessionImpl::justClose ensured that this thread won the race to close the session). + // So, the only "bad" thing that could happen is that some other thread adds to this list + // while we're closing it. + if (FST.getAcquire(this) != ResourceCleanup.CLOSED_LIST) { + //ok now we're really closing down + ResourceCleanup prev = null; + while (true) { + prev = (ResourceCleanup) FST.getVolatile(this); + // no need to check for DUMMY, since only one thread can get here! + if (FST.compareAndSet(this, prev, ResourceCleanup.CLOSED_LIST)) { + break; + } + } + cleanup(prev); + } else { + throw MemorySessionImpl.alreadyClosed(); + } + } +} diff --git a/src/java.base/share/classes/jdk/internal/foreign/SharedSession.java b/src/java.base/share/classes/jdk/internal/foreign/SharedSession.java index b86bb7daee3e9..5e453bb0caeea 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/SharedSession.java +++ b/src/java.base/share/classes/jdk/internal/foreign/SharedSession.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -25,125 +26,229 @@ package jdk.internal.foreign; -import jdk.internal.invoke.MhUtil; import jdk.internal.misc.ScopedMemoryAccess; import jdk.internal.vm.annotation.ForceInline; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.VarHandle; +import java.util.concurrent.atomic.AtomicIntegerArray; /** * A shared session, which can be shared across multiple threads. Closing a shared session has to ensure that * (i) only one thread can successfully close a session (e.g. in a close vs. close race) and that * (ii) no other thread is accessing the memory associated with this session while the segment is being - * closed. To ensure the former condition, a CAS is performed on the liveness bit. Ensuring the latter - * is trickier, and require a complex synchronization protocol (see {@link jdk.internal.misc.ScopedMemoryAccess}). + * closed. To ensure the former condition, the method {@link #justClose() justClose} is synchronized. Ensuring + * the latter is trickier, using a number of counters to track how many threads are accessing the memory and + * requires a complex synchronization protocol (see {@link jdk.internal.misc.ScopedMemoryAccess}). * Since it is the responsibility of the closing thread to make sure that no concurrent access is possible, * checking the liveness bit upon access can be performed in plain mode, as in the confined case. */ -sealed class SharedSession extends MemorySessionImpl permits ImplicitSession { +final class SharedSession extends MemorySessionImpl { private static final ScopedMemoryAccess SCOPED_MEMORY_ACCESS = ScopedMemoryAccess.getScopedMemoryAccess(); - private static final int CLOSED_ACQUIRE_COUNT = -1; + final private AtomicIntegerArray counters; + + final static private int numCounters; + final static private int mask; + + // The number of ints per cacheline. + final static private int multiplier; + final static int CNT_CLOSING = -1; + final static int CNT_CLOSED = -2; + + private static final jdk.internal.misc.Unsafe UNSAFE = jdk.internal.misc.Unsafe.getUnsafe(); + + static { + int cpus = Runtime.getRuntime().availableProcessors(); + + if (cpus < 2) { + // Single CPU case. + cpus = 1; + mask = 0; + } else { + // Round up to next power of 2 CPUs. + // Cap at 1024 to avoid excessive size. + cpus = Integer.min(Integer.highestOneBit(cpus) << 1, 1024); + mask = cpus - 1; + } + numCounters = cpus; + + int cacheLineSize = UNSAFE.dataCacheLineFlushSize(); + + // Each counter is an integer on its own cacheline. + multiplier = ((cacheLineSize < Integer.BYTES) ? 64 : cacheLineSize) / Integer.BYTES; + } + SharedSession() { super(null, new SharedResourceList()); + counters = new AtomicIntegerArray(numCounters * multiplier); + } + + @ForceInline + private int getCounter() { + return Thread.currentThread().hashCode() & mask; + } + + @ForceInline + private int getAcquire(int index) { + assert numCounters > index; + return counters.getAcquire(index * multiplier); + } + + @ForceInline + private int compareAndExchange(int index, int expected, int value) { + assert numCounters > index; + return counters.compareAndExchange(index * multiplier, expected, value); } @Override @ForceInline - public void acquire0() { - int value; + public int acquire0() { + int ticket = getCounter(); + int value = 0; + int old = getAcquire(ticket); do { - value = (int) ACQUIRE_COUNT.getVolatile(this); - if (value < 0) { - //segment is not open! + value = old; + + if (value >= 0) { + if (value == MAX_FORKS) { + throw tooManyAcquires(); + } + old = compareAndExchange(ticket, value, value + 1); + } else if (value == CNT_CLOSED) { + // The following method will wait for the justClose() method to + // set STATE variable to CLOSED, after all counters have been set + // to CNT_CLOSED. throw sharedSessionAlreadyClosed(); - } else if (value == MAX_FORKS) { - //overflow - throw tooManyAcquires(); + } else if (value == CNT_CLOSING) { + // The closing thread will either succeed, changing this counter + // to CNT_CLOSED or fail and backout the counter state to "0". + do { + old = getAcquire(ticket); + Thread.onSpinWait(); + } while (old == CNT_CLOSING); + // On exit value is CNT_CLOSING and old is >=0 or CNT_CLOSED. + assert (old == CNT_CLOSED) || (old >= 0); } - } while (!ACQUIRE_COUNT.compareAndSet(this, value, value + 1)); + } while (old != value); + + return ticket; } @Override @ForceInline - public void release0() { - int value; + public void release0(int ticket) { + assert (ticket >= 0 && ticket < numCounters) : "Invalid ticket."; + + int value = 0; + int old = getAcquire(ticket); do { - value = (int) ACQUIRE_COUNT.getVolatile(this); - if (value <= 0) { - //cannot get here - we can't close segment twice - throw sharedSessionAlreadyClosed(); + value = old; + + if (value > 0) { + old = compareAndExchange(ticket, value, value - 1); + } else { + throw alreadyClosed(); } - } while (!ACQUIRE_COUNT.compareAndSet(this, value, value - 1)); + } while (old != value); } - void justClose() { - int acquireCount = (int) ACQUIRE_COUNT.compareAndExchange(this, 0, CLOSED_ACQUIRE_COUNT); - if (acquireCount < 0) { - throw sharedSessionAlreadyClosed(); - } else if (acquireCount > 0) { - throw alreadyAcquired(acquireCount); + synchronized void justClose() { + int value; + + if (state == CLOSED) { + throw alreadyClosed(); + } + + // Attempt to transition all counters to CNT_CLOSING state. + // Normally each counter should be 0. This method atomically changes them + // to CNT_CLOSING (-1) and if that succeeds, then changes them all to + // CNT_CLOSED (-2) then updates STATE and the SCOPED_MEMORY_ACCESS to + // match. + // Threads calling acquire0 will spin if CNT_CLOSING is acquired, and will + // either fail if this method succeeds, or pass if this method fails to close. + // If this method encounters a counter >0, counters that were set to + // CNT_CLOSING are set to 0 and this method fails. + for (int i = 0; i < numCounters; i++) { + value = compareAndExchange(i, 0, CNT_CLOSING); + + assert value != CNT_CLOSING; + + if (value == CNT_CLOSED) { + // It is already closed - throw an exception. + throw alreadyClosed(); + } + + if (value != 0) { + // Total the counters we haven't set to CNT_CLOSING. + // This might be inaccurate, but won't be zero. + int total = value; + for (int j = i + 1; j < numCounters; j++) { + int counter = counters.get(j * multiplier); + assert counter >= 0; + + total += counter; + } + + // Swapping from 0 to CNT_CLOSING failed, set back to 0. + // We can't set the current one, that's the one that failed. + for (int j = 0; j < i; j++) { + assert counters.getAcquire(j * multiplier) == CNT_CLOSING; + counters.setRelease(j * multiplier, 0); + } + + throw alreadyAcquired(total); + } + } + // Success, any threads acquiring will spin on CNT_CLOSING now, for this counter. + + // This causes threads that were spinning on CNT_CLOSING to throw alreadyClosed(). + for (int i = 0; i < numCounters; i++) { + assert counters.getAcquire(i * multiplier) == CNT_CLOSING; + + counters.setRelease(i * multiplier, CNT_CLOSED); } + // Set MemorySessionImpl.state to match the counters closed status. STATE.setVolatile(this, CLOSED); SCOPED_MEMORY_ACCESS.closeScope(this, ALREADY_CLOSED); } - private IllegalStateException sharedSessionAlreadyClosed() { - // To avoid the situation where a scope fails to be acquired or closed but still reports as - // alive afterward, we wait for the state to change before throwing the exception - while ((int) STATE.getVolatile(this) == OPEN) { - Thread.onSpinWait(); + @Override + public boolean isCloseable() { + if (state == CLOSED) { + return true; } - return alreadyClosed(); - } - /** - * A shared resource list; this implementation has to handle add vs. add races, as well as add vs. cleanup races. - */ - static class SharedResourceList extends ResourceList { - - static final VarHandle FST = MhUtil.findVarHandle( - MethodHandles.lookup(), ResourceList.class, "fst", ResourceCleanup.class); - - @Override - void add(ResourceCleanup cleanup) { - while (true) { - ResourceCleanup prev = (ResourceCleanup) FST.getVolatile(this); - if (prev == ResourceCleanup.CLOSED_LIST) { - // too late - throw alreadyClosed(); - } - cleanup.next = prev; - if (FST.compareAndSet(this, prev, cleanup)) { - return; //victory + for (int i = 0; i < numCounters; i++) { + int value = getAcquire(i); + + if (value == CNT_CLOSING) { + while ((value = getAcquire(i)) == CNT_CLOSING) { + Thread.onSpinWait(); } - // keep trying + + // Restart from first counter. + i = -1; + continue; } - } - void cleanup() { - // At this point we are only interested about add vs. close races - not close vs. close - // (because MemorySessionImpl::justClose ensured that this thread won the race to close the session). - // So, the only "bad" thing that could happen is that some other thread adds to this list - // while we're closing it. - if (FST.getAcquire(this) != ResourceCleanup.CLOSED_LIST) { - //ok now we're really closing down - ResourceCleanup prev = null; - while (true) { - prev = (ResourceCleanup) FST.getVolatile(this); - // no need to check for DUMMY, since only one thread can get here! - if (FST.compareAndSet(this, prev, ResourceCleanup.CLOSED_LIST)) { - break; - } - } - cleanup(prev); - } else { - throw alreadyClosed(); + if (value == CNT_CLOSED) { + return false; + } else if (value > 0) { + return false; } } + + return true; + } + + private IllegalStateException sharedSessionAlreadyClosed() { + // To avoid the situation where a scope fails to be acquired or closed but still reports as + // alive afterward, we wait for the state to change before throwing the exception + while ((int) STATE.getVolatile(this) == OPEN) { + Thread.onSpinWait(); + } + return alreadyClosed(); } } diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/BindingSpecializer.java b/src/java.base/share/classes/jdk/internal/foreign/abi/BindingSpecializer.java index 20ccec61fd2ed..f2c4684c7b184 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/BindingSpecializer.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/BindingSpecializer.java @@ -74,6 +74,7 @@ import java.util.List; import static java.lang.classfile.ClassFile.*; +import static java.lang.classfile.TypeKind.INT; import static java.lang.classfile.TypeKind.LONG; import static java.lang.classfile.TypeKind.REFERENCE; import static java.lang.constant.ConstantDescs.*; @@ -123,8 +124,8 @@ public class BindingSpecializer { private static final MethodTypeDesc MTD_LONG_TO_ADDRESS_SCOPE = MethodTypeDesc.of(CD_MemorySegment, CD_long, CD_long, CD_long, CD_MemorySessionImpl); private static final MethodTypeDesc MTD_ALLOCATE = MethodTypeDesc.of(CD_MemorySegment, CD_long, CD_long); private static final MethodTypeDesc MTD_HANDLE_UNCAUGHT_EXCEPTION = MethodTypeDesc.of(CD_void, CD_Throwable); - private static final MethodTypeDesc MTD_RELEASE0 = MTD_void; - private static final MethodTypeDesc MTD_ACQUIRE0 = MTD_void; + private static final MethodTypeDesc MTD_RELEASE0 = MethodTypeDesc.of(CD_void, CD_int); + private static final MethodTypeDesc MTD_ACQUIRE0 = MethodTypeDesc.of(CD_int); private static final MethodTypeDesc MTD_INTEGER_TO_UNSIGNED_LONG = MethodTypeDesc.of(CD_long, CD_int); private static final MethodTypeDesc MTD_SHORT_TO_UNSIGNED_LONG = MethodTypeDesc.of(CD_long, CD_short); private static final MethodTypeDesc MTD_BYTE_TO_UNSIGNED_LONG = MethodTypeDesc.of(CD_long, CD_byte); @@ -145,6 +146,7 @@ public class BindingSpecializer { private int[] leafArgSlots; private int[] scopeSlots; + private int[] sessionTickets; private int curScopeLocalIdx = -1; private int returnAllocatorIdx = -1; private int contextIdx = -1; @@ -290,6 +292,15 @@ private void specialize() { } } scopeSlots = Arrays.copyOf(initialScopeSlots, numScopes); // fit to size + + sessionTickets = new int[numScopes]; + for (int i = 0; i < numScopes; i++) { + int ticketLocal = cb.allocateLocal(INT); + sessionTickets[i] = ticketLocal; + cb.loadConstant(0) + .istore(ticketLocal); + } + curScopeLocalIdx = 0; // used from emitGetInput } @@ -530,9 +541,12 @@ private void emitAcquireScope() { // 1 scope to acquire on the stack cb.dup(); - int nextScopeLocal = scopeSlots[curScopeLocalIdx++]; + int nextScopeLocal = scopeSlots[curScopeLocalIdx]; + int nextSessionTicketLocal = sessionTickets[curScopeLocalIdx]; + curScopeLocalIdx++; // call acquire first here. So that if it fails, we don't call release cb.invokevirtual(CD_MemorySessionImpl, "acquire0", MTD_ACQUIRE0) // call acquire on the other + .istore(nextSessionTicketLocal) .astore(nextScopeLocal); // store off one to release later if (hasLookup) { // avoid ASM generating a bunch of nops for the dead code @@ -545,11 +559,14 @@ private void emitAcquireScope() { } private void emitReleaseScopes() { - for (int scopeLocal : scopeSlots) { + for (int i = 0; i < scopeSlots.length; i++) { + int scopeLocal = scopeSlots[i]; + int ticketLocal = sessionTickets[i]; cb.aload(scopeLocal) - .ifThen(Opcode.IFNONNULL, ifCb -> { + .ifThen(Opcode.IFNONNULL, ifCb -> { ifCb.aload(scopeLocal) - .invokevirtual(CD_MemorySessionImpl, "release0", MTD_RELEASE0); + .iload(ticketLocal) + .invokevirtual(CD_MemorySessionImpl, "release0", MTD_RELEASE0); }); } } diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/DowncallLinker.java b/src/java.base/share/classes/jdk/internal/foreign/abi/DowncallLinker.java index acdbef5822d59..ded94751e5cd6 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/DowncallLinker.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/DowncallLinker.java @@ -40,6 +40,7 @@ import java.lang.invoke.MethodType; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.stream.Stream; @@ -145,7 +146,9 @@ Object invokeInterpBindings(SegmentAllocator allocator, Object[] args, Invocatio Arena unboxArena = callingSequence.allocationSize() != 0 ? SharedUtils.newBoundedArena(callingSequence.allocationSize()) : SharedUtils.DUMMY_ARENA; - List acquiredScopes = new ArrayList<>(); + record ScopeAndTicket(MemorySessionImpl session, int ticket) {} + List acquiredScopes = new ArrayList<>(); + try (unboxArena) { MemorySegment returnBuffer = null; @@ -172,10 +175,10 @@ public void store(VMStorage storage, Object o) { if (callingSequence.functionDesc().argumentLayouts().get(i) instanceof AddressLayout) { MemorySessionImpl sessionImpl = ((AbstractMemorySegmentImpl) arg).sessionImpl(); if (!(callingSequence.needsReturnBuffer() && i == 0)) { // don't acquire unboxArena's scope - sessionImpl.acquire0(); + int ticket = sessionImpl.acquire0(); // add this scope _after_ we acquire, so we only release scopes we actually acquired // in case an exception occurs - acquiredScopes.add(sessionImpl); + acquiredScopes.add(new ScopeAndTicket(sessionImpl, ticket)); } } BindingInterpreter.unbox(arg, callingSequence.argumentBindings(i), storeFunc, unboxArena); @@ -205,8 +208,10 @@ public Object load(VMStorage storage, Class type) { allocator); } } finally { - for (MemorySessionImpl sessionImpl : acquiredScopes) { - sessionImpl.release0(); + for (ScopeAndTicket scope : acquiredScopes) { + int ticket = scope.ticket; + MemorySessionImpl sessionImpl = scope.session; + sessionImpl.release0(ticket); } } } diff --git a/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java b/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java index 7343a23436d95..cd72070d02ac2 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java +++ b/src/java.base/share/classes/jdk/internal/foreign/abi/fallback/FallbackLinker.java @@ -46,6 +46,7 @@ import java.lang.invoke.MethodType; import java.lang.ref.Reference; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -145,15 +146,15 @@ private record DowncallData(MemorySegment cif, MemoryLayout returnLayout, List acquiredSessions = new ArrayList<>(); + record ScopeAndTicket(MemorySessionImpl session, int ticket) {} + List acquiredSessions = new ArrayList<>(); try (Arena arena = Arena.ofConfined()) { int argStart = 0; Object[] heapBases = invData.allowsHeapAccess() ? new Object[args.length] : null; MemorySegment target = (MemorySegment) args[argStart++]; MemorySessionImpl targetImpl = ((AbstractMemorySegmentImpl) target).sessionImpl(); - targetImpl.acquire0(); - acquiredSessions.add(targetImpl); + acquiredSessions.add(new ScopeAndTicket(targetImpl, targetImpl.acquire0())); MemorySegment capturedState = null; Object captureStateHeapBase = null; @@ -165,8 +166,7 @@ private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args captureStateHeapBase = capturedState.heapBase().orElse(null); } MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl(); - capturedStateImpl.acquire0(); - acquiredSessions.add(capturedStateImpl); + acquiredSessions.add(new ScopeAndTicket(capturedStateImpl, capturedStateImpl.acquire0())); } List argLayouts = invData.argLayouts(); @@ -178,8 +178,7 @@ private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args if (layout instanceof AddressLayout) { AbstractMemorySegmentImpl ms = (AbstractMemorySegmentImpl) arg; MemorySessionImpl sessionImpl = ms.sessionImpl(); - sessionImpl.acquire0(); - acquiredSessions.add(sessionImpl); + acquiredSessions.add(new ScopeAndTicket(sessionImpl, sessionImpl.acquire0())); if (invData.allowsHeapAccess() && !ms.isNative()) { heapBases[i] = ms.unsafeGetBase(); // write the offset to the arg segment, add array ptr to it in native code @@ -206,8 +205,8 @@ private static Object doDowncall(SegmentAllocator returnAllocator, Object[] args return readValue(retSeg, invData.returnLayout()); } finally { - for (MemorySessionImpl session : acquiredSessions) { - session.release0(); + for (ScopeAndTicket scopeAndTicket : acquiredSessions) { + scopeAndTicket.session.release0(scopeAndTicket.ticket); } } } diff --git a/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java b/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java index afb312ed722c4..e78f96647b56d 100644 --- a/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java +++ b/src/java.base/share/classes/sun/nio/ch/DatagramChannelImpl.java @@ -724,7 +724,7 @@ private int receiveIntoNativeBuffer(ByteBuffer bb, int rem, int pos, boolean connected) throws IOException { - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { long bufAddress = NIO_ACCESS.getBufferAddress(bb); int n = receive0(fd, @@ -736,7 +736,7 @@ private int receiveIntoNativeBuffer(ByteBuffer bb, int rem, int pos, bb.position(pos + n); return n; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } } @@ -907,7 +907,7 @@ private int sendFromNativeBuffer(FileDescriptor fd, ByteBuffer bb, int rem = (pos <= lim ? lim - pos : 0); int written; - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { long bufAddress = NIO_ACCESS.getBufferAddress(bb); int addressLen = targetSocketAddress(target); @@ -921,7 +921,7 @@ private int sendFromNativeBuffer(FileDescriptor fd, ByteBuffer bb, throw pue; written = rem; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } if (written > 0) bb.position(pos + written); diff --git a/src/java.base/share/classes/sun/nio/ch/IOUtil.java b/src/java.base/share/classes/sun/nio/ch/IOUtil.java index 45f8cb2e588c3..8c565ca93618f 100644 --- a/src/java.base/share/classes/sun/nio/ch/IOUtil.java +++ b/src/java.base/share/classes/sun/nio/ch/IOUtil.java @@ -129,7 +129,7 @@ private static int writeFromNativeBuffer(FileDescriptor fd, ByteBuffer bb, int written = 0; if (rem == 0) return 0; - acquireScope(bb, async); + int ticket = acquireScope(bb, async); try { if (position != -1) { written = nd.pwrite(fd, bufferAddress(bb) + pos, rem, position); @@ -137,7 +137,7 @@ private static int writeFromNativeBuffer(FileDescriptor fd, ByteBuffer bb, written = nd.write(fd, bufferAddress(bb) + pos, rem); } } finally { - releaseScope(bb); + releaseScope(bb, ticket); } if (written > 0) bb.position(pos + written); @@ -182,9 +182,9 @@ static long write(FileDescriptor fd, ByteBuffer[] bufs, int offset, int length, int i = offset; while (i < count && iov_len < IOV_MAX && writevLen < WRITEV_MAX) { ByteBuffer buf = bufs[i]; - acquireScope(buf, async); + int ticket = acquireScope(buf, async); if (NIO_ACCESS.hasSession(buf)) { - handleReleasers = LinkedRunnable.of(Releaser.of(buf), handleReleasers); + handleReleasers = LinkedRunnable.of(Releaser.of(buf, ticket), handleReleasers); } int pos = buf.position(); int lim = buf.limit(); @@ -333,7 +333,7 @@ private static int readIntoNativeBuffer(FileDescriptor fd, ByteBuffer bb, if (rem == 0) return 0; int n = 0; - acquireScope(bb, async); + int ticket = acquireScope(bb, async); try { if (position != -1) { n = nd.pread(fd, bufferAddress(bb) + pos, rem, position); @@ -341,7 +341,7 @@ private static int readIntoNativeBuffer(FileDescriptor fd, ByteBuffer bb, n = nd.read(fd, bufferAddress(bb) + pos, rem); } } finally { - releaseScope(bb); + releaseScope(bb, ticket); } if (n > 0) bb.position(pos + n); @@ -395,9 +395,9 @@ static long read(FileDescriptor fd, ByteBuffer[] bufs, int offset, int length, ByteBuffer buf = bufs[i]; if (buf.isReadOnly()) throw new IllegalArgumentException("Read-only buffer"); - acquireScope(buf, async); + int ticket = acquireScope(buf, async); if (NIO_ACCESS.hasSession(buf)) { - handleReleasers = LinkedRunnable.of(Releaser.of(buf), handleReleasers); + handleReleasers = LinkedRunnable.of(Releaser.of(buf, ticket), handleReleasers); } int pos = buf.position(); int lim = buf.limit(); @@ -477,16 +477,16 @@ static long read(FileDescriptor fd, ByteBuffer[] bufs, int offset, int length, private static final JavaNioAccess NIO_ACCESS = SharedSecrets.getJavaNioAccess(); - static void acquireScope(ByteBuffer bb, boolean async) { + static int acquireScope(ByteBuffer bb, boolean async) { if (async && NIO_ACCESS.isThreadConfined(bb)) { throw new IllegalArgumentException("Buffer is thread confined"); } - NIO_ACCESS.acquireSession(bb); + return NIO_ACCESS.acquireSession(bb); } - static void releaseScope(ByteBuffer bb) { + static void releaseScope(ByteBuffer bb, int ticket) { try { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } catch (Exception e) { throw new IllegalStateException(e); } @@ -499,14 +499,14 @@ static Runnable acquireScopes(ByteBuffer[] buffers) { static Runnable acquireScopes(ByteBuffer buf, ByteBuffer[] buffers) { if (buffers == null) { assert buf != null; - IOUtil.acquireScope(buf, true); - return IOUtil.Releaser.of(buf); + int ticket = IOUtil.acquireScope(buf, true); + return IOUtil.Releaser.of(buf, ticket); } else { assert buf == null; Runnable handleReleasers = null; for (var b : buffers) { - IOUtil.acquireScope(b, true); - handleReleasers = IOUtil.LinkedRunnable.of(IOUtil.Releaser.of(b), handleReleasers); + int ticket = IOUtil.acquireScope(b, true); + handleReleasers = IOUtil.LinkedRunnable.of(IOUtil.Releaser.of(b, ticket), handleReleasers); } return handleReleasers; } @@ -537,22 +537,21 @@ static LinkedRunnable of(Runnable first, Runnable second) { } } - record Releaser(ByteBuffer bb) implements Runnable { + record Releaser(ByteBuffer bb, int ticket) implements Runnable { Releaser { Objects.requireNonNull(bb); } @Override public void run() { - releaseScope(bb); + releaseScope(bb, ticket); } - static Runnable of(ByteBuffer bb) { + static Runnable of(ByteBuffer bb, int ticket) { return NIO_ACCESS.hasSession(bb) - ? new Releaser(bb) + ? new Releaser(bb, ticket) : () -> {}; } - } static long bufferAddress(ByteBuffer buf) { diff --git a/src/java.base/share/classes/sun/nio/ch/SimpleAsynchronousFileChannelImpl.java b/src/java.base/share/classes/sun/nio/ch/SimpleAsynchronousFileChannelImpl.java index 13de40e9a2477..a69e09f18cf92 100644 --- a/src/java.base/share/classes/sun/nio/ch/SimpleAsynchronousFileChannelImpl.java +++ b/src/java.base/share/classes/sun/nio/ch/SimpleAsynchronousFileChannelImpl.java @@ -327,7 +327,7 @@ Future implRead(final ByteBuffer dst, return null; } - IOUtil.acquireScope(dst, true); + int ticket = IOUtil.acquireScope(dst, true); final PendingFuture result = (handler == null) ? new PendingFuture(this) : null; @@ -351,7 +351,7 @@ public void run() { } finally { end(); threads.remove(ti); - IOUtil.releaseScope(dst); + IOUtil.releaseScope(dst, ticket); } if (handler == null) { result.setResult(n, exc); @@ -384,7 +384,7 @@ Future implWrite(final ByteBuffer src, return null; } - IOUtil.acquireScope(src, true); + int ticket = IOUtil.acquireScope(src, true); final PendingFuture result = (handler == null) ? new PendingFuture(this) : null; @@ -408,7 +408,7 @@ public void run() { } finally { end(); threads.remove(ti); - IOUtil.releaseScope(src); + IOUtil.releaseScope(src, ticket); } if (handler == null) { result.setResult(n, exc); diff --git a/src/java.base/unix/classes/sun/nio/fs/UnixUserDefinedFileAttributeView.java b/src/java.base/unix/classes/sun/nio/fs/UnixUserDefinedFileAttributeView.java index 1214ba855684e..fd219afaaa59d 100644 --- a/src/java.base/unix/classes/sun/nio/fs/UnixUserDefinedFileAttributeView.java +++ b/src/java.base/unix/classes/sun/nio/fs/UnixUserDefinedFileAttributeView.java @@ -167,14 +167,14 @@ public int read(String name, ByteBuffer dst) throws IOException { int rem = (pos <= lim ? lim - pos : 0); if (dst.isDirect()) { - NIO_ACCESS.acquireSession(dst); + int ticket = NIO_ACCESS.acquireSession(dst); try { long address = NIO_ACCESS.getBufferAddress(dst) + pos; int n = read(name, address, rem); dst.position(pos + n); return n; } finally { - NIO_ACCESS.releaseSession(dst); + NIO_ACCESS.releaseSession(dst, ticket); } } else { try (NativeBuffer nb = NativeBuffers.getNativeBuffer(rem)) { @@ -226,14 +226,14 @@ public int write(String name, ByteBuffer src) throws IOException { int rem = (pos <= lim ? lim - pos : 0); if (src.isDirect()) { - NIO_ACCESS.acquireSession(src); + int ticket = NIO_ACCESS.acquireSession(src); try { long address = NIO_ACCESS.getBufferAddress(src) + pos; write(name, address, rem); src.position(pos + rem); return rem; } finally { - NIO_ACCESS.releaseSession(src); + NIO_ACCESS.releaseSession(src, ticket); } } else { try (NativeBuffer nb = NativeBuffers.getNativeBuffer(rem)) { diff --git a/src/java.base/windows/classes/sun/nio/ch/WindowsAsynchronousFileChannelImpl.java b/src/java.base/windows/classes/sun/nio/ch/WindowsAsynchronousFileChannelImpl.java index 5df8aef743126..8befb9c35e205 100644 --- a/src/java.base/windows/classes/sun/nio/ch/WindowsAsynchronousFileChannelImpl.java +++ b/src/java.base/windows/classes/sun/nio/ch/WindowsAsynchronousFileChannelImpl.java @@ -393,6 +393,7 @@ private class ReadTask implements Runnable, Iocp.ResultHandler { private final long position; // file position private final PendingFuture result; private volatile boolean released; + private int ticket; // to release buffer scope // set to dst if direct; otherwise set to substituted direct buffer private volatile ByteBuffer buf; @@ -412,7 +413,7 @@ private class ReadTask implements Runnable, Iocp.ResultHandler { void releaseScopeOrCacheSubstitute() { if (buf == dst) { - IOUtil.releaseScope(dst); + IOUtil.releaseScope(dst, ticket); } else if (RELEASED.compareAndSet(this, false, true)) { Util.releaseTemporaryDirectBuffer(buf); } @@ -448,7 +449,7 @@ public void run() { // Substitute a native buffer if not direct if (dst.isDirect()) { buf = dst; - IOUtil.acquireScope(dst, true); + ticket = IOUtil.acquireScope(dst, true); address = IOUtil.bufferAddress(dst) + pos; } else { buf = Util.getTemporaryDirectBuffer(rem); @@ -586,6 +587,7 @@ private class WriteTask implements Runnable, Iocp.ResultHandler { private final long position; // file position private final PendingFuture result; private volatile boolean released; + private int ticket; // to release buffer scope // set to src if direct; otherwise set to substituted direct buffer private volatile ByteBuffer buf; @@ -605,7 +607,7 @@ private class WriteTask implements Runnable, Iocp.ResultHandler { void releaseScopeOrCacheSubstitute() { if (buf == src) { - IOUtil.releaseScope(src); + IOUtil.releaseScope(src, ticket); } else if (RELEASED.compareAndSet(this, false, true)) { Util.releaseTemporaryDirectBuffer(buf); } @@ -631,7 +633,7 @@ public void run() { // Substitute a native buffer if not direct if (src.isDirect()) { buf = src; - IOUtil.acquireScope(src, true); + ticket = IOUtil.acquireScope(src, true); address = IOUtil.bufferAddress(src) + pos; } else { buf = Util.getTemporaryDirectBuffer(rem); diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11AEADCipher.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11AEADCipher.java index b41e939e9062f..48f91d7730672 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11AEADCipher.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11AEADCipher.java @@ -724,9 +724,9 @@ private int implDoFinal(ByteBuffer inBuffer, ByteBuffer outBuffer) } boolean doCancel = true; - NIO_ACCESS.acquireSession(inBuffer); + int ticket = NIO_ACCESS.acquireSession(inBuffer); try { - NIO_ACCESS.acquireSession(outBuffer); + int ticket2 = NIO_ACCESS.acquireSession(outBuffer); try { try { ensureInitialized(); @@ -808,10 +808,10 @@ private int implDoFinal(ByteBuffer inBuffer, ByteBuffer outBuffer) reset(doCancel); } } finally { - NIO_ACCESS.releaseSession(outBuffer); + NIO_ACCESS.releaseSession(outBuffer, ticket2); } } finally { - NIO_ACCESS.releaseSession(inBuffer); + NIO_ACCESS.releaseSession(inBuffer, ticket); } } diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Cipher.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Cipher.java index 9c923b7621508..c077fc85fc180 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Cipher.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Cipher.java @@ -733,9 +733,9 @@ private int implUpdate(ByteBuffer inBuffer, ByteBuffer outBuffer) throw new ShortBufferException(); } int origPos = inBuffer.position(); - NIO_ACCESS.acquireSession(inBuffer); + int ticket = NIO_ACCESS.acquireSession(inBuffer); try { - NIO_ACCESS.acquireSession(outBuffer); + int ticket2 = NIO_ACCESS.acquireSession(outBuffer); try { ensureInitialized(); @@ -896,10 +896,10 @@ private int implUpdate(ByteBuffer inBuffer, ByteBuffer outBuffer) reset(true); throw new ProviderException("update() failed", e); } finally { - NIO_ACCESS.releaseSession(outBuffer); + NIO_ACCESS.releaseSession(outBuffer, ticket2); } } finally { - NIO_ACCESS.releaseSession(inBuffer); + NIO_ACCESS.releaseSession(inBuffer, ticket); } } @@ -1005,7 +1005,7 @@ private int implDoFinal(ByteBuffer outBuffer) } boolean doCancel = true; - NIO_ACCESS.acquireSession(outBuffer); + int ticket = NIO_ACCESS.acquireSession(outBuffer); try { try { ensureInitialized(); @@ -1116,7 +1116,7 @@ private int implDoFinal(ByteBuffer outBuffer) reset(doCancel); } } finally { - NIO_ACCESS.releaseSession(outBuffer); + NIO_ACCESS.releaseSession(outBuffer, ticket); } } diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Digest.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Digest.java index 7f6a62e936389..3ad1a88acca01 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Digest.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Digest.java @@ -287,12 +287,12 @@ protected void engineUpdate(ByteBuffer byteBuffer) { token.p11.C_DigestUpdate(session.id(), 0, buffer, 0, bufOfs); bufOfs = 0; } - NIO_ACCESS.acquireSession(byteBuffer); + int ticket = NIO_ACCESS.acquireSession(byteBuffer); try { final long address = NIO_ACCESS.getBufferAddress(byteBuffer); token.p11.C_DigestUpdate(session.id(), address + ofs, null, 0, len); } finally { - NIO_ACCESS.releaseSession(byteBuffer); + NIO_ACCESS.releaseSession(byteBuffer, ticket); } byteBuffer.position(ofs + len); } catch (PKCS11Exception e) { diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11KeyWrapCipher.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11KeyWrapCipher.java index d761a0637ade4..97ba344bf1a6e 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11KeyWrapCipher.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11KeyWrapCipher.java @@ -556,9 +556,9 @@ private int implDoFinal(ByteBuffer inBuffer, ByteBuffer outBuffer) boolean doCancel = true; int k = 0; - NIO_ACCESS.acquireSession(inBuffer); + int ticket = NIO_ACCESS.acquireSession(inBuffer); try { - NIO_ACCESS.acquireSession(outBuffer); + int ticket2 = NIO_ACCESS.acquireSession(outBuffer); try { try { ensureInitialized(); @@ -634,10 +634,10 @@ private int implDoFinal(ByteBuffer inBuffer, ByteBuffer outBuffer) reset(doCancel); } } finally { - NIO_ACCESS.releaseSession(outBuffer); + NIO_ACCESS.releaseSession(outBuffer, ticket2); } } finally { - NIO_ACCESS.releaseSession(inBuffer); + NIO_ACCESS.releaseSession(inBuffer, ticket); } return k; } diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Mac.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Mac.java index 0af129f0ba41c..6e46cc63c3eee 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Mac.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Mac.java @@ -284,12 +284,12 @@ protected void engineUpdate(ByteBuffer byteBuffer) { return; } int ofs = byteBuffer.position(); - NIO_ACCESS.acquireSession(byteBuffer); + int ticket = NIO_ACCESS.acquireSession(byteBuffer); try { final long address = NIO_ACCESS.getBufferAddress(byteBuffer); token.p11.C_SignUpdate(session.id(), address + ofs, null, 0, len); } finally { - NIO_ACCESS.releaseSession(byteBuffer); + NIO_ACCESS.releaseSession(byteBuffer, ticket); } byteBuffer.position(ofs + len); } catch (PKCS11Exception e) { diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11PSSSignature.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11PSSSignature.java index 52d6ef72d3861..67ec87f2a0bdf 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11PSSSignature.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11PSSSignature.java @@ -620,7 +620,7 @@ protected void engineUpdate(ByteBuffer byteBuffer) { return; } int ofs = byteBuffer.position(); - NIO_ACCESS.acquireSession(byteBuffer); + int ticket = NIO_ACCESS.acquireSession(byteBuffer); try { long addr = NIO_ACCESS.getBufferAddress(byteBuffer); if (mode == M_SIGN) { @@ -638,7 +638,7 @@ protected void engineUpdate(ByteBuffer byteBuffer) { reset(false); throw new ProviderException("Update failed", e); } finally { - NIO_ACCESS.releaseSession(byteBuffer); + NIO_ACCESS.releaseSession(byteBuffer, ticket); } } case T_DIGEST -> { diff --git a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Signature.java b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Signature.java index ed7cd525283d3..acda313622d59 100644 --- a/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Signature.java +++ b/src/jdk.crypto.cryptoki/share/classes/sun/security/pkcs11/P11Signature.java @@ -585,7 +585,7 @@ protected void engineUpdate(ByteBuffer byteBuffer) { return; } int ofs = byteBuffer.position(); - NIO_ACCESS.acquireSession(byteBuffer); + int ticket = NIO_ACCESS.acquireSession(byteBuffer); try { long addr = NIO_ACCESS.getBufferAddress(byteBuffer); if (mode == M_SIGN) { @@ -601,7 +601,7 @@ protected void engineUpdate(ByteBuffer byteBuffer) { reset(false); throw new ProviderException("Update failed", e); } finally { - NIO_ACCESS.releaseSession(byteBuffer); + NIO_ACCESS.releaseSession(byteBuffer, ticket); } } case T_DIGEST -> { diff --git a/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpChannelImpl.java b/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpChannelImpl.java index 485ebf13f2c23..8afcae77a8b2e 100644 --- a/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpChannelImpl.java +++ b/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpChannelImpl.java @@ -828,7 +828,7 @@ private int receiveIntoNativeBuffer(int fd, boolean peek) throws IOException { - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { int n = receive0(fd, resultContainer, NIO_ACCESS.getBufferAddress(bb) + pos, rem, peek); @@ -836,7 +836,7 @@ private int receiveIntoNativeBuffer(int fd, bb.position(pos + n); return n; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } } @@ -1011,7 +1011,7 @@ private int sendFromNativeBuffer(int fd, assert (pos <= lim); int rem = (pos <= lim ? lim - pos : 0); - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { int written = send0(fd, NIO_ACCESS.getBufferAddress(bb) + pos, rem, addr, port, -1 /*121*/, streamNumber, unordered, ppid); @@ -1019,7 +1019,7 @@ private int sendFromNativeBuffer(int fd, bb.position(pos + written); return written; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } } diff --git a/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpMultiChannelImpl.java b/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpMultiChannelImpl.java index 31e83d72f96df..cd5807cc4f02e 100644 --- a/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpMultiChannelImpl.java +++ b/src/jdk.sctp/unix/classes/sun/nio/ch/sctp/SctpMultiChannelImpl.java @@ -560,14 +560,14 @@ private int receiveIntoNativeBuffer(int fd, int rem, int pos) throws IOException { - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { int n = receive0(fd, resultContainer, NIO_ACCESS.getBufferAddress(bb) + pos, rem); if (n > 0) bb.position(pos + n); return n; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } } @@ -869,7 +869,7 @@ private int sendFromNativeBuffer(int fd, assert (pos <= lim); int rem = (pos <= lim ? lim - pos : 0); - NIO_ACCESS.acquireSession(bb); + int ticket = NIO_ACCESS.acquireSession(bb); try { int written = send0(fd, NIO_ACCESS.getBufferAddress(bb) + pos, rem, addr, port, assocId, streamNumber, unordered, ppid); @@ -877,7 +877,7 @@ private int sendFromNativeBuffer(int fd, bb.position(pos + written); return written; } finally { - NIO_ACCESS.releaseSession(bb); + NIO_ACCESS.releaseSession(bb, ticket); } } diff --git a/test/jdk/java/foreign/TestMemorySession.java b/test/jdk/java/foreign/TestMemorySession.java index b06e2707c399d..de7dd3008eeed 100644 --- a/test/jdk/java/foreign/TestMemorySession.java +++ b/test/jdk/java/foreign/TestMemorySession.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -318,9 +319,8 @@ public void testIsCloseableBy(ArenaSupplier arenaSupplier) { assertEquals(sessionImpl.isCloseableBy(otherThread), isCloseableByOther); } - /** - * Test that a thread failing to acquire a scope will not observe it as alive afterwards. - */ + + // Test that a thread failing to acquire a scope will not observe it as alive afterwards. @Test public void testAcquireCloseRace() throws InterruptedException { int iteration = 1000; @@ -362,8 +362,9 @@ public void testAcquireCloseRace() throws InterruptedException { for (int i = 0; i < iteration;) { MemorySessionImpl scope = scopes[i]; while (true) { + int ticket = 0; try { - scope.acquire0(); + ticket = scope.acquire0(); } catch (IllegalStateException e) { // The scope has been closed, proceed to the next iteration if (scope.isAlive()) { @@ -372,7 +373,7 @@ public void testAcquireCloseRace() throws InterruptedException { break; } // Release and try again - scope.release0(); + scope.release0(ticket); } // Proceed to the next iteration i = lock.getAndAdd(1) + 1; @@ -386,6 +387,159 @@ public void testAcquireCloseRace() throws InterruptedException { assertFalse(result[0]); } + @Test + public void testTickets() { + Arena arena = Arena.ofShared(); + var sessionImpl = ((MemorySessionImpl) arena.scope()); + int ticket = sessionImpl.acquire0(); + + assertFalse(sessionImpl.isCloseable()); + + sessionImpl.release0(ticket); + assertTrue(sessionImpl.isCloseable()); + assertThrows(IllegalStateException.class, () -> sessionImpl.release0(ticket)); + + sessionImpl.close(); + assertThrows(IllegalStateException.class, () -> sessionImpl.acquire0()); + assertTrue(sessionImpl.isCloseable()); + } + + @Test + public void testTicketsCrossThreads() throws InterruptedException { + Arena arena = Arena.ofShared(); + var sessionImpl = ((MemorySessionImpl) arena.scope()); + int[] tickets = new int[N_THREADS]; + Thread[] threads = new Thread[N_THREADS]; + + for (int counter = 0; counter < N_THREADS; counter++) { + threads[counter] = new Thread(new AcquireWork(counter, sessionImpl, tickets)); + threads[counter].start(); + } + + for (int i = 0; i < N_THREADS; i++) { + threads[i].join(); + } + + assertFalse(sessionImpl.isCloseable()); + + try { + for (int i = 0; i < N_THREADS; i++) { + sessionImpl.release0(tickets[i]); + } + } catch (IllegalStateException e) { + fail(); + } + + assertTrue(sessionImpl.isCloseable()); + + try { + sessionImpl.close(); + } catch (IllegalStateException e) { + fail(); + } + } + + static class AcquireWork implements Runnable { + int counter; + MemorySessionImpl sessionImpl; + int tickets[]; + + AcquireWork(int counter, MemorySessionImpl sessionImpl, int[] tickets) { + this.counter = counter; + this.sessionImpl = sessionImpl; + this.tickets = tickets; + } + + public void run() { + tickets[counter] = sessionImpl.acquire0(); + assertFalse(sessionImpl.isCloseable()); + } + } + + // Check that a scope is either closed and an acquire fails, or the + // close fails and the acquire and release are successful. + @Test + public void testAcquireReleaseCloseRace() throws InterruptedException { + boolean acquireLose = false; + boolean closeLose = false; + + while (true) { + boolean closeLost = false; + + Arena arena = Arena.ofShared(); + MemorySessionImpl sessionImpl = ((MemorySessionImpl) arena.scope()); + + int nthreads = 4; + AcquireReleaseLoop[] arls = new AcquireReleaseLoop[nthreads]; + Thread[] threads = new Thread[nthreads]; + + for (int i = 0; i < nthreads; i++) { + arls[i] = new AcquireReleaseLoop(sessionImpl); + threads[i] = new Thread(arls[i]); + threads[i].start(); + } + Thread.sleep(100); + + try { + sessionImpl.close(); + assertFalse(sessionImpl.isAlive()); + } catch (IllegalStateException e) { + assertTrue(sessionImpl.isAlive()); + closeLost = true; + } + + closeLose |= closeLost; + for (int i = 0; i < nthreads; i++) { + arls[i].loop = false; + threads[i].join(); + + // Both cannot fail simultaneously. + if (arls[i].acquireLose) { + assertFalse(closeLost); + } else if (closeLost) { + assertFalse(arls[i].acquireLose); + } + + acquireLose |= arls[i].acquireLose; + } + + if (acquireLose && closeLose) { + break; + } + } + } + + static class AcquireReleaseLoop implements Runnable { + MemorySessionImpl sessionImpl; + boolean acquireLose; + volatile boolean loop = true; + String message = null; + + AcquireReleaseLoop(MemorySessionImpl sessionImpl) { + this.sessionImpl = sessionImpl; + } + + public void run() { + while (loop) { + int ticket = 0; + try { + ticket = sessionImpl.acquire0(); + } catch (IllegalStateException e) { + assertFalse(sessionImpl.isAlive()); + acquireLose = true; + break; + } + + try { + sessionImpl.release0(ticket); + } catch (IllegalStateException e) { + fail(); + break; + } + } + } + } + private void waitSomeTime() { try { Thread.sleep(10); @@ -412,8 +566,8 @@ static Object[][] drops() { private void keepAlive(Arena child, Arena parent) { MemorySessionImpl parentImpl = MemorySessionImpl.toMemorySession(parent); - parentImpl.acquire0(); - addCloseAction(child, parentImpl::release0); + final int ticket = parentImpl.acquire0(); + addCloseAction(child, () -> {parentImpl.release0(ticket); }); } private void addCloseAction(Arena session, Runnable action) {