diff --git a/larray/src/main/scala/xerial/larray/LArray.scala b/larray/src/main/scala/xerial/larray/LArray.scala index bc4d1e7..affd1dc 100644 --- a/larray/src/main/scala/xerial/larray/LArray.scala +++ b/larray/src/main/scala/xerial/larray/LArray.scala @@ -274,6 +274,9 @@ trait LArray[A] extends LSeq[A] with WritableByteChannel { @inline def putLong(offset: Long, v: Long) = { unsafe.putLong(address + offset, v); v} @inline def putDouble(offset: Long, v: Double) = { unsafe.putDouble(address + offset, v); v } + @inline def casInt(offset: Long, exp: Int, v: Int) = { unsafe.compareAndSwapInt(null, address + offset, exp, v) } + @inline def casLong(offset: Long, exp: Long, v: Long) = { unsafe.compareAndSwapLong(null, address + offset, exp, v)} + } @@ -813,6 +816,10 @@ class LIntArray(val size: Long, private[larray] val m: Memory)(implicit val allo v } + def cas(i: Long, exp: Int, v: Int) = { + unsafe.compareAndSwapInt(null, m.address + (i << 2), exp, v) + } + /** * Byte size of an element. For example, if A is Int, its elementByteSize is 4 */ @@ -849,6 +856,10 @@ class LLongArray(val size: Long, private[larray] val m: Memory)(implicit val all v } + def cas(i: Long, exp: Long, v: Long) = { + unsafe.compareAndSwapLong(null, m.address + (i << 2), exp, v) + } + def view(from: Long, to: Long) = new LArrayView.LLongArrayView(this, from, to - from) } diff --git a/larray/src/test/scala/xerial/larray/LArrayTest.scala b/larray/src/test/scala/xerial/larray/LArrayTest.scala index 87e3806..401925b 100644 --- a/larray/src/test/scala/xerial/larray/LArrayTest.scala +++ b/larray/src/test/scala/xerial/larray/LArrayTest.scala @@ -98,6 +98,58 @@ class LArrayTest extends LArraySpec { } } + "implement cas correctly" in { + info("cas test") + + val l = new LIntArray(3) + try { + l(0) = 1 + l(1) = 2 + l(2) = 3 + l(0) should be (1) + l(1) should be (2) + l(2) should be (3) + + l.cas(0, 1, 2) should be (true) + l(0) should be (2) + + l.casInt(0, 2, 3) should be (true) + l(0) should be (3) + + l.cas(0, 1, 1) should be (false) + l(0) should be (3) + } + finally { + l.free + } + } + + "implement cas correctly for LLongArray" in { + info("cas test") + + val l = new LLongArray(3) + try { + l(0) = 1 + l(1) = 2 + l(2) = 3 + l(0) should be (1L) + l(1) should be (2L) + l(2) should be (3L) + + l.cas(0, 1L, 2L) should be (true) + l(0) should be (2L) + + l.casLong(0, 2, 3) should be (true) + l(0) should be (3L) + + l.cas(0, 1L, 1L) should be (false) + l(0) should be (3L) + } + finally { + l.free + } + } + "read/write data to Array[Byte]" taggedAs ("rw") in { val l = LArray(1, 3) val b = new Array[Byte](l.byteLength.toInt)