diff --git a/src/utils/tensor.js b/src/utils/tensor.js index ea822b6c6..2e7ddf2cf 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -212,7 +212,7 @@ export class Tensor { * @returns {Array} */ tolist() { - return reshape(this.data, this.dims) + return reshape(this.data, this.dims); } /** @@ -885,31 +885,54 @@ export class Tensor { */ function reshape(data, dimensions) { + const ndim = dimensions.length; + + if (ndim === 0) { + // Scalar + return data[0]; + } + const totalElements = data.length; - const dimensionSize = dimensions.reduce((a, b) => a * b); + const dimensionSize = dimensions.reduce((a, b) => a * b, 1); if (totalElements !== dimensionSize) { throw Error(`cannot reshape array of size ${totalElements} into shape (${dimensions})`); } - /** @type {any} */ - let reshapedArray = data; + if (ndim === 1) { + return Array.from(data); + } - for (let i = dimensions.length - 1; i >= 0; i--) { - reshapedArray = reshapedArray.reduce((acc, val) => { - let lastArray = acc[acc.length - 1]; + // Pre-compute strides for each dimension + const strides = new Array(ndim); + strides[ndim - 1] = 1; + for (let i = ndim - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * dimensions[i + 1]; + } - if (lastArray.length < dimensions[i]) { - lastArray.push(val); - } else { - acc.push([val]); + /** + * Recursively construct the nested array. + * @param {number} offset - Current offset in `data`. + * @param {number} axis - Current axis being processed. + * @returns {Array} + */ + function build(offset, axis) { + const size = dimensions[axis]; + const result = new Array(size); + if (axis === ndim - 1) { + for (let i = 0; i < size; i++) { + result[i] = data[offset + i]; } - - return acc; - }, [[]]); + } else { + const step = strides[axis]; + for (let i = 0; i < size; i++) { + result[i] = build(offset + i * step, axis + 1); + } + } + return result; } - return reshapedArray[0]; + return build(0, 0); } /** diff --git a/tests/utils/tensor.test.js b/tests/utils/tensor.test.js index 008f46abb..f4cd7c7aa 100644 --- a/tests/utils/tensor.test.js +++ b/tests/utils/tensor.test.js @@ -378,6 +378,56 @@ describe("Tensor operations", () => { [3, 4], ]); }); + + it("should return nested arrays for a 3D tensor", () => { + const t1 = new Tensor( + "float32", + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 2, 2], + ); + const arr = t1.tolist(); + compare(arr, [ + [ + [1, 2], + [3, 4], + ], + [ + [5, 6], + [7, 8], + ], + ]); + }); + + it("should return nested arrays for a 4D tensor", () => { + const t1 = new Tensor( + "float32", + Array.from({ length: 16 }, (_, i) => i + 1), + [2, 2, 2, 2], + ); + const arr = t1.tolist(); + compare(arr, [ + [ + [ + [1, 2], + [3, 4], + ], + [ + [5, 6], + [7, 8], + ], + ], + [ + [ + [9, 10], + [11, 12], + ], + [ + [13, 14], + [15, 16], + ], + ], + ]); + }); }); describe("mul", () => {