@@ -129,6 +129,83 @@ void DestroyQSL(void* qsl) {
129129 delete qsl_cast;
130130}
131131
132+ namespace {
133+
134+ //
135+ class GroupedQuerySampleLibraryTrampoline : public QuerySampleLibrary {
136+ public:
137+ GroupedQuerySampleLibraryTrampoline (
138+ ClientData client_data,
139+ std::string name,
140+ size_t performance_sample_count,
141+ LoadSamplesToRamCallback load_samples_to_ram_cb,
142+ UnloadSamplesFromRamCallback unload_samples_from_ram_cb,
143+ std::vector<size_t >& group_sizes)
144+ : name_(std::move(name)),
145+ performance_sample_count_ (performance_sample_count),
146+ load_samples_to_ram_cb_(load_samples_to_ram_cb),
147+ unload_samples_from_ram_cb_(unload_samples_from_ram_cb) {
148+
149+ total_sample_count_ = 0 ;
150+
151+ for (ssize_t i = 0 ; i < group_sizes.size (); i++){
152+ group_sizes_.push_back (group_sizes[i]);
153+ total_sample_count_ += group_sizes[i];
154+ for (size_t j = 0 ; j < group_sizes[i]; j++){
155+ group_idx_.push_back (i);
156+ }
157+ }
158+ }
159+ ~GroupedQuerySampleLibraryTrampoline () override = default ;
160+
161+ const std::string& Name () override { return name_; }
162+ size_t TotalSampleCount () override { return total_sample_count_; }
163+ size_t PerformanceSampleCount () override { return performance_sample_count_; }
164+ size_t GroupSize (size_t i) override { return group_sizes_[i]; }
165+ size_t GroupOf (size_t i) override { return group_idx_[i]; }
166+ size_t NumberOfGroups () override { return group_sizes_.size (); }
167+
168+ void LoadSamplesToRam (const std::vector<QuerySampleIndex>& samples) override {
169+ (*load_samples_to_ram_cb_)(client_data_, samples.data (), samples.size ());
170+ }
171+ void UnloadSamplesFromRam (
172+ const std::vector<QuerySampleIndex>& samples) override {
173+ (*unload_samples_from_ram_cb_)(client_data_, samples.data (),
174+ samples.size ());
175+ }
176+
177+ private:
178+ std::string name_;
179+ ClientData client_data_;
180+ std::vector<size_t > group_sizes_;
181+ std::vector<size_t > group_idx_;
182+ size_t total_sample_count_;
183+ size_t performance_sample_count_;
184+ LoadSamplesToRamCallback load_samples_to_ram_cb_;
185+ UnloadSamplesFromRamCallback unload_samples_from_ram_cb_;
186+ };
187+
188+ } // namespace
189+
190+ void * ConstructGroupedQSL (ClientData client_data, const char * name, size_t name_length,
191+ size_t total_sample_count, size_t performance_sample_count,
192+ LoadSamplesToRamCallback load_samples_to_ram_cb,
193+ UnloadSamplesFromRamCallback unload_samples_from_ram_cb,
194+ std::vector<size_t >& group_sizes) {
195+ GroupedQuerySampleLibraryTrampoline* qsl = new GroupedQuerySampleLibraryTrampoline (
196+ client_data, std::string (name, name_length), total_sample_count,
197+ performance_sample_count, load_samples_to_ram_cb,
198+ unload_samples_from_ram_cb, group_sizes);
199+ return reinterpret_cast <void *>(qsl);
200+ }
201+
202+ void DestroyGroupedQSL (void * qsl) {
203+ GroupedQuerySampleLibraryTrampoline* qsl_cast =
204+ reinterpret_cast <GroupedQuerySampleLibraryTrampoline*>(qsl);
205+ delete qsl_cast;
206+ }
207+
208+
132209// mlperf::c::StartTest just forwards to mlperf::StartTest after doing the
133210// proper cast.
134211void StartTest (void * sut, void * qsl, const TestSettings& settings,
@@ -142,6 +219,18 @@ void StartTest(void* sut, void* qsl, const TestSettings& settings,
142219 audit_config_filename);
143220}
144221
222+ void StartTestWithGroupedQSL (void * sut, void * qsl, const TestSettings& settings,
223+ const std::string& audit_config_filename = " audit.config" ) {
224+ SystemUnderTestTrampoline* sut_cast =
225+ reinterpret_cast <SystemUnderTestTrampoline*>(sut);
226+ GroupedQuerySampleLibraryTrampoline* qsl_cast =
227+ reinterpret_cast <GroupedQuerySampleLibraryTrampoline*>(qsl);
228+ assert (settings.use_grouped_qsl );
229+ LogSettings default_log_settings;
230+ mlperf::StartTest (sut_cast, qsl_cast, settings, default_log_settings,
231+ audit_config_filename);
232+ }
233+
145234void QuerySamplesComplete (QuerySampleResponse* responses,
146235 size_t response_count) {
147236 mlperf::QuerySamplesComplete (responses, response_count);
0 commit comments