diff --git a/src/pufferlib.cu b/src/pufferlib.cu index df6faa57f4..a6c8c5957a 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -2091,6 +2091,11 @@ std::unique_ptr create_pufferl_impl(HypersT& hypers, } cudaMemset(pufferl->rng_offset_puf.data, 0, numel(pufferl->rng_offset_puf.shape) * sizeof(long)); + for (int bank = 0; bank < 1 + pufferl->num_frozen_banks; bank++) { + PrecisionTensor* bs = (bank == 0) ? pufferl->buffer_states + : pufferl->frozen_banks[bank - 1].buffer_states; + for (int b = 0; b < num_buffers; b++) puf_zero(&bs[b], pufferl->default_stream); + } cudaDeviceSynchronize(); pufferl->epoch = 0;