Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/infinicore/ops/random_sample.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include "infinicore/tensor.hpp"

namespace infinicore::op {

class RandomSample {
public:
using schema = void (*)(Tensor, Tensor, float, float, int, float);
static void execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place API
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature);
// In-place API
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);

} // namespace infinicore::op


22 changes: 22 additions & 0 deletions python/infinicore/ops/random_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def random_sample(logits, random_val, topp, topk, temperature, *, out=None):
if out is None:
return Tensor(
_infinicore.random_sample(
logits._underlying, random_val, topp, topk, temperature
)
)

_infinicore.random_sample_(
out._underlying,
logits._underlying,
random_val,
topp,
topk,
temperature,
)


38 changes: 38 additions & 0 deletions src/infinicore/ops/random_sample/random_sample.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "infinicore/ops/random_sample.hpp"

namespace infinicore::op {

common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
static common::OpDispatcher<RandomSample::schema> dispatcher_;
return dispatcher_;
};

void RandomSample::execute(
Tensor indices, Tensor logits,
float random_val, float topp, int topk, float temperature) {
dispatcher().lookup(context::getDevice().getType())(
indices, logits, random_val, topp, topk, temperature);
}

Tensor random_sample(
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
auto indices = Tensor::empty({}, DataType::I32, logits->device());
random_sample_(indices, logits, random_val, topp, topk, temperature);
return indices;
}

void random_sample_(
Tensor indices,
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
RandomSample::execute(indices, logits, random_val, topp, topk, temperature);
}

} // namespace infinicore::op
66 changes: 66 additions & 0 deletions src/infinicore/ops/random_sample/random_sample_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/random_sample.hpp"
#include <infiniop.h>

namespace infinicore::op::random_sample_impl::infiniop_backend {

thread_local common::OpCache<size_t, infiniopRandomSampleDescriptor_t> caches(
100, // capacity
[](infiniopRandomSampleDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleDescriptor(desc));
desc = nullptr;
}
});

static void calculate(
Tensor indices,
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
// cache per (result desc + logits desc) on device
size_t seed = hash_combine(indices, logits);

auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();

auto &cache = caches.getCache(device_type, device_index);

auto desc_opt = cache.get(seed);
infiniopRandomSampleDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(), &desc,
indices->desc(), logits->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopRandomSample(
desc,
workspace->data(), workspace_size,
indices->data(), logits->data(),
random_val, topp, topk, temperature,
context::getStream()));
}

} // namespace infinicore::op::random_sample_impl::infiniop_backend

namespace infinicore::op {

static bool registered = []() {
RandomSample::dispatcher().registerAll(&random_sample_impl::infiniop_backend::calculate, false);
return true;
}();

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/silu.hpp"
Expand All @@ -19,6 +20,7 @@ inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
bind_matmul(m);
bind_rearrange(m);
bind_rms_norm(m);
Expand Down
32 changes: 32 additions & 0 deletions src/infinicore/pybind11/ops/random_sample.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/random_sample.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_random_sample(py::module &m) {
m.def("random_sample",
&op::random_sample,
py::arg("logits"),
py::arg("random_val"),
py::arg("topp"),
py::arg("topk"),
py::arg("temperature"),
R"doc(Random sampling: returns an int32 scalar index.)doc");

m.def("random_sample_",
&op::random_sample_,
py::arg("indices"),
py::arg("logits"),
py::arg("random_val"),
py::arg("topp"),
py::arg("topk"),
py::arg("temperature"),
R"doc(In-place random sampling into provided int32 scalar tensor.)doc");
}

} // namespace infinicore::ops
Loading