From 1bfdc964ac243c3490fcc86b2de2efea4899a14d Mon Sep 17 00:00:00 2001 From: rogmann Date: Sun, 11 Aug 2024 00:04:44 +0200 Subject: [PATCH] Implementation of Q6_KFloatTensor --- Llama3.java | 475 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 472 insertions(+), 3 deletions(-) diff --git a/Llama3.java b/Llama3.java index 25a495f..7e019f6 100755 --- a/Llama3.java +++ b/Llama3.java @@ -254,6 +254,8 @@ public static void main(String[] args) throws IOException { Options options = Options.parseOptions(args); Llama model = ModelLoader.loadModel(options.modelPath(), options.maxTokens()); Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed()); + System.out.println(String.format("Llama4J: species_bits=%d, species_array=%s, species=%s", + FloatTensor.F_SPECIES_BITS, FloatTensor.F_SPECIES_ARRAY, FloatTensor.F_SPECIES)); if (options.interactive()) { runInteractive(model, sampler, options); } else { @@ -747,6 +749,7 @@ public static FloatTensor loadQuantized(GGMLTensorEntry entry) { return switch (ggmlType) { //case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case Q6_K -> new Q6_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; @@ -1432,9 +1435,16 @@ abstract class FloatTensor { static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); - // Preferred vector size for the fast multiplication routines. - // (Apple Silicon) NEON only supports up-to 128bit vectors. - static final VectorSpecies F_SPECIES = FloatVector.SPECIES_PREFERRED.vectorBitSize() == 128 ? FloatVector.SPECIES_128 : FloatVector.SPECIES_256; + /** true if an array might be used in the dot product computation (256 bits only) */ + static final boolean F_SPECIES_ARRAY = Boolean.getBoolean("llama.species.array"); + /** number of species bits to be used in VectorAPI, default is the preferred size (will be truncated to 256) */ + static final int F_SPECIES_BITS = Integer.parseInt(System.getProperty("llama.species.bits", Integer.toString(FloatVector.SPECIES_PREFERRED.vectorBitSize()))); + + /** + * Preferred vector size for the fast multiplication routines. + * (Apple Silicon) NEON only supports up-to 128bit vectors. + */ + static final VectorSpecies F_SPECIES = (F_SPECIES_BITS == 128) ? FloatVector.SPECIES_128 : FloatVector.SPECIES_256; abstract int size(); @@ -1701,6 +1711,465 @@ private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, FloatTensor } } +/** + * {@link FloatTensor} quantized in the {@link GGMLType#Q6_K} format. + *

+ * This tensor implementation is not compatible with {@link FloatTensor}, but + * {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when + * the second argument implements {@link FloatTensor}. + */ +final class Q6_KFloatTensor extends FloatTensor { + + final int size; + final MemorySegment memorySegment; + + public Q6_KFloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.Q6_K; + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockSize = GGMLType.Q6_K.getBlockSize(); + int blockIndex = index / blockSize; + int indexInBlock = index % blockSize; + int blockOffset = blockIndex * GGMLType.Q6_K.getTypeSize(); + + int superblockSize = GGMLType.QK_K / 16; + // Layout: + // uint8 ql[QK_K/2]: quants, lower 4 bits + // uint8 qh[QK_K/4]: quants, upper 2 bits + // int8 scales[QK_K/16]: scales, quantized with 8 bits + // fp16 d: super-block scale + int offsetQl = blockOffset; + int offsetQh = offsetQl + GGMLType.QK_K / 2; + int offsetScales = offsetQh + GGMLType.QK_K / 4; + int offsetD = offsetScales + superblockSize; + float scale = Float.float16ToFloat(memorySegment.get(JAVA_SHORT_LE, offsetD)); + int superblockIndex = indexInBlock / superblockSize; + int scale8 = memorySegment.get(ValueLayout.JAVA_BYTE, offsetScales + superblockIndex); + int blk128 = indexInBlock / 128; // blk128 in {0, 1} + int blk32 = (indexInBlock / 32) % 4; // blk32 in {0, 1, 2, 3} + int idx32 = indexInBlock % 32; + int ql; + int qh; + if (blk32 == 0) { + ql = memorySegment.get(ValueLayout.JAVA_BYTE, offsetQl + 64 * blk128 + idx32) & 0x0f; + qh = memorySegment.get(ValueLayout.JAVA_BYTE, offsetQh + 32 * blk128 + idx32) & 0b11; + } else if (blk32 == 1) { + ql = memorySegment.get(ValueLayout.JAVA_BYTE, offsetQl + 64 * blk128 + 32 + idx32) & 0x0f; + qh = (memorySegment.get(ValueLayout.JAVA_BYTE, offsetQh + 32 * blk128 + idx32) >> 2) & 0b11; + } else if (blk32 == 2) { + ql = (memorySegment.get(ValueLayout.JAVA_BYTE, offsetQl + 64 * blk128 + idx32) & 0xf0) >>> 4; + qh = (memorySegment.get(ValueLayout.JAVA_BYTE, offsetQh + 32 * blk128 + idx32) >> 4) & 0b11; + } else { + ql = (memorySegment.get(ValueLayout.JAVA_BYTE, offsetQl + 64 * blk128 + 32 + idx32) & 0xf0) >>> 4; + qh = (memorySegment.get(ValueLayout.JAVA_BYTE, offsetQh + 32 * blk128 + idx32) >> 6) & 0b11; + } + int quant = (qh << 4) + ql - 32; + return quant * scale8 * scale; + } + + @Override + public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { + if (FloatTensor.USE_VECTOR_API) { + return vectorDot(this, thisOffset, that, thatOffset, size); + } else { + return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); + } + } + + private static float vectorDot(Q6_KFloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + if (FloatVector.SPECIES_128.equals(F_SPECIES)) { + if (F_SPECIES_ARRAY) { + throw new UnsupportedOperationException("no array implementation for 128 bits"); + } + return vectorDot128Array(thiz, thisOffset, that, thatOffset, size); + } else if (F_SPECIES_BITS == 512) { + if (F_SPECIES_ARRAY) { + throw new UnsupportedOperationException("no array implementation for 512 bits"); + } + return vectorDot512(thiz, thisOffset, that, thatOffset, size); + } else if (FloatVector.SPECIES_256.equals(F_SPECIES) && F_SPECIES_ARRAY) { + return vectorDot256Array(thiz, thisOffset, that, thatOffset, size); + } else if (FloatVector.SPECIES_256.equals(F_SPECIES)) { + return vectorDot256(thiz, thisOffset, that, thatOffset, size); + } else { + throw new UnsupportedOperationException("Unexpected species " + F_SPECIES); + } + } + + private static float vectorDot512(Q6_KFloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + j to type().getBlockSize(). + int blockSize = GGMLType.Q6_K.getBlockSize(); + assert Integer.bitCount(blockSize) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (blockSize - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % blockSize == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); + int typeSize = GGMLType.Q6_K.getTypeSize(); + int superblockSize = GGMLType.QK_K / 16; + int blockOffset = (thisOffset + j) / blockSize * typeSize; + int upperBound = size / blockSize * blockSize; + var bSpecies128 = ByteVector.SPECIES_128; // 16 bytes + var bSpecies256 = ByteVector.SPECIES_256; // 32 bytes + var bSpecies512 = ByteVector.SPECIES_512; // 64 bytes + // Layout: + // uint8 ql[QK_K/2]: quants, lower 4 bits + // uint8 qh[QK_K/4]: quants, upper 2 bits + // int8 scales[QK_K/16]: scales, quantized with 8 bits + // fp16 d: super-block scale + for (; j < upperBound; j += blockSize, blockOffset += typeSize) { + int offsetQl = blockOffset; // size GK_K/2 = 128 bytes = 256 quants + int offsetQh = offsetQl + GGMLType.QK_K / 2; // size GK_K/4 = 64 bytes = 256 quants + int offsetScales = offsetQh + GGMLType.QK_K / 4; // size GK_K/16 = 16 bytes = 16 block-scales + int offsetD = offsetScales + superblockSize; + float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, offsetD)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + var wScale8 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetScales, ByteOrder.LITTLE_ENDIAN); + for (int blk128 = 0; blk128 < 2; blk128++) { + final ByteVector bytesQl = ByteVector.fromMemorySegment(bSpecies512, thiz.memorySegment, offsetQl + 64 * blk128, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQh = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQh + 32 * blk128, ByteOrder.LITTLE_ENDIAN); + + var ql0 = bytesQl.and((byte) 0x0f); + var ql1 = bytesQl.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0f); // index(ql1) = index(ql0) + QK_K/4 + var qh0 = bytesQh.lanewise(VectorOperators.LSHL, 4).and((byte) 0b110000); + var qh1 = bytesQh.lanewise(VectorOperators.LSHL, 2).and((byte) 0b110000); // index(qh1) = index(qh0) + QK_K/8 + var qh2 = bytesQh.and((byte) 0b110000); + var qh3 = bytesQh.lanewise(VectorOperators.LSHR, 2).and((byte) 0b110000); + var q0 = ql0.castShape(bSpecies256, 0).reinterpretAsBytes().or(qh0).sub((byte) 32); + var q1 = ql0.castShape(bSpecies256, 1).reinterpretAsBytes().or(qh1).sub((byte) 32); + var q2 = ql1.castShape(bSpecies256, 0).reinterpretAsBytes().or(qh2).sub((byte) 32); + var q3 = ql1.castShape(bSpecies256, 1).reinterpretAsBytes().or(qh3).sub((byte) 32); + final int blk128Idx = 8 * blk128; + if (F_SPECIES.vectorBitSize() == 512) { + // not yet implemented in the other classes + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 0) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 1) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 1)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 2) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 2)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 3) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 3)); + var sum4 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 4) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 4)); + var sum5 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 5) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 5)); + var sum6 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 6) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 6)); + var sum7 = that.getFloatVector(F_SPECIES, thatOffset + j + (blk128Idx + 7) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 7)); + val = sum0.add(sum1).add(sum2).add(sum3).add(sum4).add(sum5).add(sum6).add(sum7).fma(wScale, val); + } else if (F_SPECIES.vectorBitSize() == 256) { + var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 0) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 0)); + var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 1) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 0)); + var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 2) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 1)); + var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 3) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 1)); + var sum4 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 4) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 2)); + var sum5 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 5) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 2)); + var sum6 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 6) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 3)); + var sum7 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 7) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 3)); + var sum8 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 8) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 4)); + var sum9 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 9) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 4)); + var sum10 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 10) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 5)); + var sum11 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 11) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 5)); + var sum12 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 12) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 6)); + var sum13 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 13) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 6)); + var sum14 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 14) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 7)); + var sum15 = that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 15) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 7)); + val = sum0.add(sum1).add(sum2).add(sum3).add(sum4).add(sum5).add(sum6).add(sum7) + .add(sum8).add(sum9).add(sum10).add(sum11).add(sum12).add(sum13).add(sum14).add(sum15) + .fma(wScale, val); + } else { + throw new UnsupportedOperationException(String.format("Invalid call of vectorDot512 with %s", F_SPECIES)); + } + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. + if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } + + private static float vectorDot256(Q6_KFloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + j to type().getBlockSize(). + int blockSize = GGMLType.Q6_K.getBlockSize(); + assert Integer.bitCount(blockSize) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (blockSize - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % blockSize == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); // FloatVector.SPECIES_256 + int typeSize = GGMLType.Q6_K.getTypeSize(); + int superblockSize = GGMLType.QK_K / 16; + int blockOffset = (thisOffset + j) / blockSize * typeSize; + int upperBound = size / blockSize * blockSize; + var bSpecies128 = ByteVector.SPECIES_128; // 16 bytes + var bSpecies256 = ByteVector.SPECIES_256; // 32 bytes + var m4 = ByteVector.broadcast(ByteVector.SPECIES_256, (byte) 0x0f); + var m32s = ByteVector.broadcast(ByteVector.SPECIES_256, 32); + // Layout: + // uint8 ql[QK_K/2]: quants, lower 4 bits + // uint8 qh[QK_K/4]: quants, upper 2 bits + // int8 scales[QK_K/16]: scales, quantized with 8 bits + // fp16 d: super-block scale + for (; j < upperBound; j += blockSize, blockOffset += typeSize) { + int offsetQl = blockOffset; // size GK_K/2 = 128 bytes = 256 quants + int offsetQh = offsetQl + GGMLType.QK_K / 2; // size GK_K/4 = 64 bytes = 256 quants + int offsetScales = offsetQh + GGMLType.QK_K / 4; // size GK_K/16 = 16 bytes = 16 block-scales + int offsetD = offsetScales + superblockSize; + float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, offsetD)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + var wScale8 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetScales, ByteOrder.LITTLE_ENDIAN); + for (int blk128 = 0; blk128 < 2; blk128++) { + final ByteVector bytesQl0 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQl + 64 * blk128, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQl1 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQl + 64 * blk128 + 32, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQh0 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQh + 32 * blk128, ByteOrder.LITTLE_ENDIAN); + + var ql00 = bytesQl0.and((byte) 0x0f); + var ql01 = bytesQl1.and((byte) 0x0f); + //var ql00 = bytesQl0.and(m4); + //var ql01 = bytesQl1.and(m4); + var ql10 = bytesQl0.lanewise(VectorOperators.LSHR, 4); // index(ql1) = index(ql0) + QK_K/4 + var ql11 = bytesQl1.lanewise(VectorOperators.LSHR, 4); // index(ql1) = index(ql0) + QK_K/4 + var qh00 = bytesQh0.lanewise(VectorOperators.LSHL, 4).and((byte) 0b110000); + var qh10 = bytesQh0.lanewise(VectorOperators.LSHL, 2).and((byte) 0b110000); // index(qh1) = index(qh0) + QK_K/8 + var qh20 = bytesQh0.and((byte) 0b110000); + var qh30 = bytesQh0.lanewise(VectorOperators.LSHR, 2).and((byte) 0b110000); + var q0 = ql00.add(qh00).sub((byte) 32); + var q1 = ql01.add(qh10).sub((byte) 32); + var q2 = ql10.add(qh20).sub((byte) 32); + var q3 = ql11.add(qh30).sub((byte) 32); + final int blk128Idx = 8 * blk128; + FloatVector sum = FloatVector.broadcast(F_SPECIES, 0f); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 0) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 0))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 1) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 0))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 2) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 1))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 3) * F_SPECIES.length()).mul(q0.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 1))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 4) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 2))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 5) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 2))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 6) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 3))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 7) * F_SPECIES.length()).mul(q1.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 3))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 8) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 4))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 9) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 4))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 10) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 5))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 11) * F_SPECIES.length()).mul(q2.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 5))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 12) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 0)).mul(wScale8.lane(blk128Idx + 6))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 13) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 1)).mul(wScale8.lane(blk128Idx + 6))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 14) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 2)).mul(wScale8.lane(blk128Idx + 7))); + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + 15) * F_SPECIES.length()).mul(q3.castShape(F_SPECIES, 3)).mul(wScale8.lane(blk128Idx + 7))); + val = sum.fma(wScale, val); + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. + if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } + + private static float vectorDot256Array(Q6_KFloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + j to type().getBlockSize(). + int blockSize = GGMLType.Q6_K.getBlockSize(); + assert Integer.bitCount(blockSize) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (blockSize - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % blockSize == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); // FloatVector.SPECIES_256 + int typeSize = GGMLType.Q6_K.getTypeSize(); + int superblockSize = GGMLType.QK_K / 16; + int blockOffset = (thisOffset + j) / blockSize * typeSize; + int upperBound = size / blockSize * blockSize; + var bSpecies128 = ByteVector.SPECIES_128; // 16 bytes + var bSpecies256 = ByteVector.SPECIES_256; // 32 bytes + var m4 = ByteVector.broadcast(ByteVector.SPECIES_256, (byte) 0x0f); + var m32s = ByteVector.broadcast(ByteVector.SPECIES_256, 32); + // Layout: + // uint8 ql[QK_K/2]: quants, lower 4 bits + // uint8 qh[QK_K/4]: quants, upper 2 bits + // int8 scales[QK_K/16]: scales, quantized with 8 bits + // fp16 d: super-block scale + for (; j < upperBound; j += blockSize, blockOffset += typeSize) { + int offsetQl = blockOffset; // size GK_K/2 = 128 bytes = 256 quants + int offsetQh = offsetQl + GGMLType.QK_K / 2; // size GK_K/4 = 64 bytes = 256 quants + int offsetScales = offsetQh + GGMLType.QK_K / 4; // size GK_K/16 = 16 bytes = 16 block-scales + int offsetD = offsetScales + superblockSize; + float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, offsetD)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + var wScale8 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetScales, ByteOrder.LITTLE_ENDIAN); + for (int blk128 = 0; blk128 < 2; blk128++) { + final ByteVector bytesQl0 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQl + 64 * blk128, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQl1 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQl + 64 * blk128 + 32, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQh0 = ByteVector.fromMemorySegment(bSpecies256, thiz.memorySegment, offsetQh + 32 * blk128, ByteOrder.LITTLE_ENDIAN); + + var ql00 = bytesQl0.and((byte) 0x0f); + var ql01 = bytesQl1.and((byte) 0x0f); + //var ql00 = bytesQl0.and(m4); + //var ql01 = bytesQl1.and(m4); + var ql10 = bytesQl0.lanewise(VectorOperators.LSHR, 4); // index(ql1) = index(ql0) + QK_K/4 + var ql11 = bytesQl1.lanewise(VectorOperators.LSHR, 4); // index(ql1) = index(ql0) + QK_K/4 + var qh00 = bytesQh0.lanewise(VectorOperators.LSHL, 4).and((byte) 0b110000); + var qh10 = bytesQh0.lanewise(VectorOperators.LSHL, 2).and((byte) 0b110000); // index(qh1) = index(qh0) + QK_K/8 + var qh20 = bytesQh0.and((byte) 0b110000); + var qh30 = bytesQh0.lanewise(VectorOperators.LSHR, 2).and((byte) 0b110000); + ByteVector[] q = new ByteVector[] { + //ql00.add(qh00).sub(m32s), + //ql01.add(qh10).sub(m32s), + //ql10.add(qh20).sub(m32s), + //ql11.add(qh30).sub(m32s), + ql00.add(qh00).sub((byte) 32), + ql01.add(qh10).sub((byte) 32), + ql10.add(qh20).sub((byte) 32), + ql11.add(qh30).sub((byte) 32), + }; + final int blk128Idx = 8 * blk128; + FloatVector sum = FloatVector.broadcast(F_SPECIES, 0f); + for (int i = 0; i < 16; i++) { + final int idx2 = i / 2; + final int idx4 = i / 4; + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (2 * blk128Idx + i) * F_SPECIES.length()).mul(q[idx4].castShape(F_SPECIES, i % 4)).mul(wScale8.lane(blk128Idx + idx2))); + } + val = sum.fma(wScale, val); + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. + if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } + + private static float vectorDot128Array(Q6_KFloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + float result = 0f; + int j = 0; + + // Align thisOffset + j to type().getBlockSize(). + int blockSize = GGMLType.Q6_K.getBlockSize(); + assert Integer.bitCount(blockSize) == 1 : "power of 2"; + int alignmentBound = Math.min(size, -thisOffset & (blockSize - 1)); + if (alignmentBound > 0) { + result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); + j += alignmentBound; + } + assert (thisOffset + j) % blockSize == 0; + + FloatVector val = FloatVector.zero(F_SPECIES); // FloatVector.SPECIES_128 + int typeSize = GGMLType.Q6_K.getTypeSize(); + int superblockSize = GGMLType.QK_K / 16; + int blockOffset = (thisOffset + j) / blockSize * typeSize; + int upperBound = size / blockSize * blockSize; + var bSpecies128 = ByteVector.SPECIES_128; // 16 bytes + // Layout: + // uint8 ql[QK_K/2]: quants, lower 4 bits + // uint8 qh[QK_K/4]: quants, upper 2 bits + // int8 scales[QK_K/16]: scales, quantized with 8 bits + // fp16 d: super-block scale + for (; j < upperBound; j += blockSize, blockOffset += typeSize) { + int offsetQl = blockOffset; // size GK_K/2 = 128 bytes = 256 quants + int offsetQh = offsetQl + GGMLType.QK_K / 2; // size GK_K/4 = 64 bytes = 256 quants + int offsetScales = offsetQh + GGMLType.QK_K / 4; // size GK_K/16 = 16 bytes = 16 block-scales + int offsetD = offsetScales + superblockSize; + float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, offsetD)); + var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); + var wScale8 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetScales, ByteOrder.LITTLE_ENDIAN); + for (int blk128 = 0; blk128 < 2; blk128++) { + final ByteVector bytesQl0 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQl + 64 * blk128, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQl1 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQl + 64 * blk128 + 16, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQl2 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQl + 64 * blk128 + 32, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQl3 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQl + 64 * blk128 + 48, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQh0 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQh + 32 * blk128, ByteOrder.LITTLE_ENDIAN); + final ByteVector bytesQh1 = ByteVector.fromMemorySegment(bSpecies128, thiz.memorySegment, offsetQh + 32 * blk128 + 16, ByteOrder.LITTLE_ENDIAN); + + var ql00 = bytesQl0.and((byte) 0x0f); + var ql01 = bytesQl1.and((byte) 0x0f); + var ql02 = bytesQl2.and((byte) 0x0f); + var ql03 = bytesQl3.and((byte) 0x0f); + var ql10 = bytesQl0.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0f); // index(ql1) = index(ql0) + QK_K/4 + var ql11 = bytesQl1.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0f); // index(ql1) = index(ql0) + QK_K/4 + var ql12 = bytesQl2.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0f); // index(ql1) = index(ql0) + QK_K/4 + var ql13 = bytesQl3.lanewise(VectorOperators.LSHR, 4).and((byte) 0x0f); // index(ql1) = index(ql0) + QK_K/4 + var qh00 = bytesQh0.lanewise(VectorOperators.LSHL, 4).and((byte) 0b110000); + var qh01 = bytesQh1.lanewise(VectorOperators.LSHL, 4).and((byte) 0b110000); + var qh10 = bytesQh0.lanewise(VectorOperators.LSHL, 2).and((byte) 0b110000); // index(qh1) = index(qh0) + QK_K/8 + var qh11 = bytesQh1.lanewise(VectorOperators.LSHL, 2).and((byte) 0b110000); // index(qh1) = index(qh0) + QK_K/8 + var qh20 = bytesQh0.and((byte) 0b110000); + var qh21 = bytesQh1.and((byte) 0b110000); + var qh30 = bytesQh0.lanewise(VectorOperators.LSHR, 2).and((byte) 0b110000); + var qh31 = bytesQh1.lanewise(VectorOperators.LSHR, 2).and((byte) 0b110000); + ByteVector[] q = new ByteVector[] { + ql00.add(qh00).sub((byte) 32), + ql01.add(qh01).sub((byte) 32), + ql02.add(qh10).sub((byte) 32), + ql03.add(qh11).sub((byte) 32), + ql10.add(qh20).sub((byte) 32), + ql11.add(qh21).sub((byte) 32), + ql12.add(qh30).sub((byte) 32), + ql13.add(qh31).sub((byte) 32) + }; + final int blk128Idx = 8 * blk128; + FloatVector sum = FloatVector.broadcast(F_SPECIES, 0f); + for (int i = 0; i < 32; i++) { + final int idx4 = i / 4; + sum = sum.add(that.getFloatVector(F_SPECIES, thatOffset + j + (4 * blk128Idx + i) * F_SPECIES.length()).mul(q[idx4].castShape(F_SPECIES, i % 4)).mul(wScale8.lane(blk128Idx + idx4))); + } + val = sum.fma(wScale, val); + } + } + result += val.reduceLanes(VectorOperators.ADD); + + // Remaining entries. + if (j < size) { + result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); + } + + return result; + } +} + final class Q8_0FloatTensor extends FloatTensor { final int size;