Skip to content

Commit 2664da2

Browse files
Add grouped QSL c API
1 parent 1f4ae66 commit 2664da2

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

loadgen/bindings/c_api.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,83 @@ void DestroyQSL(void* qsl) {
129129
delete qsl_cast;
130130
}
131131

132+
namespace {
133+
134+
//
135+
class GroupedQuerySampleLibraryTrampoline : public QuerySampleLibrary {
136+
public:
137+
GroupedQuerySampleLibraryTrampoline(
138+
ClientData client_data,
139+
std::string name,
140+
size_t performance_sample_count,
141+
LoadSamplesToRamCallback load_samples_to_ram_cb,
142+
UnloadSamplesFromRamCallback unload_samples_from_ram_cb,
143+
std::vector<size_t>& group_sizes)
144+
: name_(std::move(name)),
145+
performance_sample_count_(performance_sample_count),
146+
load_samples_to_ram_cb_(load_samples_to_ram_cb),
147+
unload_samples_from_ram_cb_(unload_samples_from_ram_cb) {
148+
149+
total_sample_count_ = 0;
150+
151+
for(ssize_t i = 0; i < group_sizes.size(); i++){
152+
group_sizes_.push_back(group_sizes[i]);
153+
total_sample_count_ += group_sizes[i];
154+
for(size_t j = 0; j < group_sizes[i]; j++){
155+
group_idx_.push_back(i);
156+
}
157+
}
158+
}
159+
~GroupedQuerySampleLibraryTrampoline() override = default;
160+
161+
const std::string& Name() override { return name_; }
162+
size_t TotalSampleCount() override { return total_sample_count_; }
163+
size_t PerformanceSampleCount() override { return performance_sample_count_; }
164+
size_t GroupSize(size_t i) override { return group_sizes_[i]; }
165+
size_t GroupOf(size_t i) override { return group_idx_[i]; }
166+
size_t NumberOfGroups() override { return group_sizes_.size(); }
167+
168+
void LoadSamplesToRam(const std::vector<QuerySampleIndex>& samples) override {
169+
(*load_samples_to_ram_cb_)(client_data_, samples.data(), samples.size());
170+
}
171+
void UnloadSamplesFromRam(
172+
const std::vector<QuerySampleIndex>& samples) override {
173+
(*unload_samples_from_ram_cb_)(client_data_, samples.data(),
174+
samples.size());
175+
}
176+
177+
private:
178+
std::string name_;
179+
ClientData client_data_;
180+
std::vector<size_t> group_sizes_;
181+
std::vector<size_t> group_idx_;
182+
size_t total_sample_count_;
183+
size_t performance_sample_count_;
184+
LoadSamplesToRamCallback load_samples_to_ram_cb_;
185+
UnloadSamplesFromRamCallback unload_samples_from_ram_cb_;
186+
};
187+
188+
} // namespace
189+
190+
void* ConstructGroupedQSL(ClientData client_data, const char* name, size_t name_length,
191+
size_t total_sample_count, size_t performance_sample_count,
192+
LoadSamplesToRamCallback load_samples_to_ram_cb,
193+
UnloadSamplesFromRamCallback unload_samples_from_ram_cb,
194+
std::vector<size_t>& group_sizes) {
195+
GroupedQuerySampleLibraryTrampoline* qsl = new GroupedQuerySampleLibraryTrampoline(
196+
client_data, std::string(name, name_length), total_sample_count,
197+
performance_sample_count, load_samples_to_ram_cb,
198+
unload_samples_from_ram_cb, group_sizes);
199+
return reinterpret_cast<void*>(qsl);
200+
}
201+
202+
void DestroyGroupedQSL(void* qsl) {
203+
GroupedQuerySampleLibraryTrampoline* qsl_cast =
204+
reinterpret_cast<GroupedQuerySampleLibraryTrampoline*>(qsl);
205+
delete qsl_cast;
206+
}
207+
208+
132209
// mlperf::c::StartTest just forwards to mlperf::StartTest after doing the
133210
// proper cast.
134211
void StartTest(void* sut, void* qsl, const TestSettings& settings,
@@ -142,6 +219,18 @@ void StartTest(void* sut, void* qsl, const TestSettings& settings,
142219
audit_config_filename);
143220
}
144221

222+
void StartTestWithGroupedQSL(void* sut, void* qsl, const TestSettings& settings,
223+
const std::string& audit_config_filename = "audit.config") {
224+
SystemUnderTestTrampoline* sut_cast =
225+
reinterpret_cast<SystemUnderTestTrampoline*>(sut);
226+
GroupedQuerySampleLibraryTrampoline* qsl_cast =
227+
reinterpret_cast<GroupedQuerySampleLibraryTrampoline*>(qsl);
228+
assert(settings.use_grouped_qsl);
229+
LogSettings default_log_settings;
230+
mlperf::StartTest(sut_cast, qsl_cast, settings, default_log_settings,
231+
audit_config_filename);
232+
}
233+
145234
void QuerySamplesComplete(QuerySampleResponse* responses,
146235
size_t response_count) {
147236
mlperf::QuerySamplesComplete(responses, response_count);

loadgen/bindings/python_api.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ void StartTestWithGroupedQSL(
340340
GroupedQuerySampleLibraryTrampoline* qsl_cast =
341341
reinterpret_cast<GroupedQuerySampleLibraryTrampoline*>(qsl);
342342
LogSettings default_log_settings;
343-
assert(TestSettings.use_grouped_qsl);
343+
assert(test_settings.use_grouped_qsl);
344344
mlperf::StartTest(sut_cast, qsl_cast, test_settings, default_log_settings,
345345
audit_config_filename);
346346
}

0 commit comments

Comments
 (0)