The `layer_norm`, RMS norm, and softmax designs use the same parameter for both the number of elements to normalize over (batch size) and the buffer sizes on the cores (tile size).
We should rename `tile_size` to `matrix_columns`, `batch_size`, or something similar to avoid confusion.
Other kernels use `tile_size` as a parameter that affects only the data movement, not the output. For those kernels, you can tune `tile_size` to maximize L1 memory usage while still computing the same result. For the normalization kernels, on the other hand, changing `tile_size` changes the output.
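To illustrate the distinction (a minimal NumPy sketch, not the actual kernel code): for an elementwise operation, the tile boundary is invisible in the output, but for a normalization it determines which elements are reduced together, so a different tile size produces a different result.

```python
import numpy as np

x = np.arange(8, dtype=np.float64)

# Elementwise op: processing in two tiles of 4 gives the same result
# as a single pass over all 8 elements.
full = x * 2.0
tiled = np.concatenate([t * 2.0 for t in np.split(x, 2)])
assert np.allclose(full, tiled)

# Normalization (RMS norm here): normalizing each tile of 4 elements
# separately does NOT match normalizing over all 8 elements, because
# the reduction (mean of squares) now runs over a different span.
def rms_norm(v, eps=1e-6):
    return v / np.sqrt(np.mean(v * v) + eps)

full_norm = rms_norm(x)
tiled_norm = np.concatenate([rms_norm(t) for t in np.split(x, 2)])
assert not np.allclose(full_norm, tiled_norm)
```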
Additionally, we might want to add support for `batch_size < tile_size`, as this should be relatively simple: each kernel call would process N batches and maintain N means, variances, etc. `batch_size > tile_size` would be harder to implement, since it would require passing means, variances, etc. from one kernel call to the next, so we could simply raise an error in that case for now.
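The `batch_size < tile_size` case could look like the following hypothetical sketch (NumPy, not the real kernel): one kernel call receives a tile holding several small batches and keeps one mean and one variance per batch, and the unsupported `batch_size > tile_size` case errors out.

```python
import numpy as np

def layer_norm_small_batches(tile_data, tile_size, eps=1e-6):
    # tile_data: (n_batches, batch_size) -- several batches packed into
    # one tile, so n_batches * batch_size elements fit in tile_size.
    n_batches, batch_size = tile_data.shape
    if batch_size > tile_size:
        # Would need to carry partial means/variances across kernel
        # calls; unsupported for now, as proposed above.
        raise NotImplementedError("batch_size > tile_size not supported")
    mean = tile_data.mean(axis=1, keepdims=True)  # N means
    var = tile_data.var(axis=1, keepdims=True)    # N variances
    return (tile_data - mean) / np.sqrt(var + eps)

# Example: tile_size = 32 holding 4 batches of batch_size = 8.
tile = np.random.default_rng(0).normal(size=(4, 8))
out = layer_norm_small_batches(tile, tile_size=32)
# Each row (batch) is now normalized independently: zero mean, unit variance.
assert np.allclose(out.mean(axis=1), 0.0, atol=1e-9)
```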