-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsfss.h
More file actions
2246 lines (1846 loc) · 71 KB
/
Copy pathsfss.h
File metadata and controls
2246 lines (1846 loc) · 71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#ifndef SFSS_HEADER
#define SFSS_HEADER
#include "util.h"
#include "field.h"
// the key class for randomized DPF
template <size_t N = 64>
class rdpf_key_class
{
public:
block seed;
bool party;
std::vector<block> s_CW;
std::vector<bool> t_L_CW;
std::vector<bool> t_R_CW;
rdpf_key_class()
{
s_CW.resize(N);
t_L_CW.resize(N);
t_R_CW.resize(N);
}
// void set_depth(size_t d) { s_CW.resize(d); t_L_CW.resize(d); t_R_CW.resize(d); }
// size_t get_depth() const { return D; }
block get_seed() const { return seed; }
void set_seed(const block v) { seed = v; }
bool get_party() const { return party; }
void set_party(bool v) { party = v; }
block get_s_CW(uint32_t i) const { return s_CW.at(i - 1); }
void set_s_CW(uint32_t i, block w) { s_CW[i - 1] = w; }
bool get_t_L_CW(uint32_t i) const { return t_L_CW.at(i - 1); }
void set_t_L_CW(uint32_t i, bool v) { t_L_CW[i - 1] = v; }
bool get_t_R_CW(uint32_t i) const { return t_R_CW.at(i - 1); }
void set_t_R_CW(uint32_t i, bool v) { t_R_CW[i - 1] = v; }
// 计算序列化所需字节数
size_t get_serialized_size() const
{
size_t size = 0;
// size += sizeof(size_t); // depth
size += sizeof(block); // seed
size += sizeof(bool); // party
// size += sizeof(size_t); // s_CW size
size += N * sizeof(block);
// size += sizeof(size_t); // t_L_CW size
size += N * sizeof(bool);
// size += sizeof(size_t); // t_R_CW size
size += N * sizeof(bool);
return size;
}
// 序列化到buf
size_t serialize(char *buf) const
{
// std::cout<< " depth: " << depth << ", s_CW.size(): " << s_CW.size() << std::endl;
char *p = buf;
// std::memcpy(p, &depth, sizeof(size_t));
// p += sizeof(size_t);
std::memcpy(p, &seed, sizeof(block));
p += sizeof(block);
std::memcpy(p, &party, sizeof(bool));
p += sizeof(bool);
// s_CW
// std::memcpy(p, &depth, sizeof(size_t));
// p += sizeof(size_t);
std::memcpy(p, s_CW.data(), N * sizeof(block));
p += N * sizeof(block);
assert(s_CW.size() == N);
// t_L_CW
// size_t t_L_CW_size = t_L_CW.size();
// std::memcpy(p, &t_L_CW_size, sizeof(size_t));
// p += sizeof(size_t);
for (size_t i = 0; i < N; ++i)
{
*p = t_L_CW[i] ? 1 : 0;
++p;
}
assert(t_L_CW.size() == N);
// t_R_CW
// size_t t_R_CW_size = t_R_CW.size();
// std::memcpy(p, &t_R_CW_size, sizeof(size_t));
// p += sizeof(size_t);
for (size_t i = 0; i < N; ++i)
{
*p = t_R_CW[i] ? 1 : 0;
++p;
}
assert(t_R_CW.size() == N);
return (p - buf);
}
size_t deserialize(const char *buf)
{
const char *p = buf; // const denote that we cannot modify memory pointed by p.
// std::memcpy(&depth, p, sizeof(size_t));
// p += sizeof(size_t);
std::memcpy(&seed, p, sizeof(block));
p += sizeof(block);
std::memcpy(&party, p, sizeof(bool));
p += sizeof(bool);
// s_CW
// std::memcpy(&depth, p, sizeof(size_t));
// p += sizeof(size_t);
s_CW.resize(N);
std::memcpy(s_CW.data(), p, N * sizeof(block));
p += N * sizeof(block);
// t_L_CW
// size_t t_L_CW_size;
// std::memcpy(&t_L_CW_size, p, sizeof(size_t));
// p += sizeof(size_t);
t_L_CW.resize(N);
for (size_t i = 0; i < N; ++i)
{
t_L_CW[i] = (*p != 0);
++p;
}
// t_R_CW
// size_t t_R_CW_size;
// std::memcpy(&t_R_CW_size, p, sizeof(size_t));
// p += sizeof(size_t);
t_R_CW.resize(N);
for (size_t i = 0; i < N; ++i)
{
t_R_CW[i] = (*p != 0);
++p;
}
return (p - buf);
}
};
// The key class for DPF: an RDPF key plus a single output correction word v_CW in G.
// Serialized layout: the rdpf_key_class<N> encoding followed by v_CW (via the util.h helpers).
template <typename G, size_t N>
class dpf_key_class : public rdpf_key_class<N>
{
public:
    G v_CW{};
    dpf_key_class() : rdpf_key_class<N>() {}
    G get_v_CW() const { return v_CW; }
    void set_v_CW(G w) { v_CW = w; }
    // Compute the number of bytes serialize() writes.
    size_t get_serialized_size() const
    {
        return rdpf_key_class<N>::get_serialized_size() + get_serialized_size_helper(v_CW); // util.h helper
    }
    // Serialize to buffer; returns the number of bytes written.
    size_t serialize(char *buf) const
    {
        // Advance by what the base actually wrote, so this offset cannot drift from the base encoding.
        char *p = buf + rdpf_key_class<N>::serialize(buf);
        p += serialize_helper(v_CW, p); // util.h helper
        return static_cast<size_t>(p - buf);
    }
    // Deserialize from buffer; returns the number of bytes consumed.
    size_t deserialize(const char *buf)
    {
        const char *p = buf + rdpf_key_class<N>::deserialize(buf);
        p += deserialize_helper(v_CW, p); // util.h helper
        return static_cast<size_t>(p - buf);
    }
};
// The key class for DCF: an RDPF key plus one value correction word v_CW per level (1-indexed).
// Serialized layout: the rdpf_key_class<N> encoding followed by v_CW[0..N-1] (via util.h helpers).
template <typename G, size_t N>
class dcf_key_class : public rdpf_key_class<N>
{
public:
    std::vector<G> v_CW;
    dcf_key_class() : rdpf_key_class<N>(), v_CW(N) {}
    // Level accessors (1-indexed); both throw std::out_of_range for i == 0 or i > v_CW.size().
    G get_v_CW(uint32_t i) const { return v_CW.at(i - 1); }
    void set_v_CW(uint32_t i, G w) { v_CW.at(i - 1) = w; }
    // Compute the number of bytes serialize() writes.
    size_t get_serialized_size() const
    {
        assert(v_CW.size() == N);
        size_t size = rdpf_key_class<N>::get_serialized_size();
        for (const auto &v : v_CW)
        {
            size += get_serialized_size_helper(v); // util.h helper
        }
        return size;
    }
    // Serialize to buffer; returns the number of bytes written.
    size_t serialize(char *buf) const
    {
        assert(v_CW.size() == N);
        // Advance by what the base actually wrote, so this offset cannot drift from the base encoding.
        char *p = buf + rdpf_key_class<N>::serialize(buf);
        for (const auto &v : v_CW)
        {
            p += serialize_helper(v, p); // util.h helper
        }
        return static_cast<size_t>(p - buf);
    }
    // Deserialize from buffer; v_CW is resized to N (like the base class does for its vectors).
    // Returns the number of bytes consumed.
    size_t deserialize(const char *buf)
    {
        const char *p = buf + rdpf_key_class<N>::deserialize(buf);
        v_CW.resize(N);
        for (size_t i = 0; i < N; ++i)
        {
            p += deserialize_helper(v_CW[i], p); // util.h helper
        }
        return static_cast<size_t>(p - buf);
    }
};
// Currently, only DPF with beta = 1 is supported.
// template <typename G,size_t N = 32>
// void DCF_gen(dcf_key_class<G> &key0, dcf_key_class<G> &key1, const index_type index)
// {
// // call the RDPFGen to setup key0 key1. Note that now we only set the rdpf_key_class memory
// RDPF_gen<N>(key0, key1, index);
// TwoKeyPRP prp(zero_block, makeBlock(0, 1));
// block s0 = key0.get_seed();
// block s1 = key1.get_seed();
// bool t0 = key0.get_party();
// bool t1 = key1.get_party();
// block s_CW;
// bool t_L_CW, t_R_CW;
// for (uint32_t i = 1; i <= N; i++)
// {
// s_CW = key.get_s_CW(i);
// t_L_CW = key.get_t_L_CW(i);
// t_R_CW = key.get_t_R_CW(i);
// block children[2];
// prp.node_expand_1to2(children, s);
// block s_L, s_R;
// bool t_L, t_R;
// s_L = (t == true ? children[0] ^ s_CW : children[0]);
// s_R = (t == true ? children[1] ^ s_CW : children[1]);
// t_L = emp::getLSB(children[0]) ^ (t & t_L_CW);
// t_R = emp::getLSB(children[1]) ^ (t & t_R_CW);
// bool bit = get_bit<N>(index, i);
// s = (bit == true ? s_R : s_L);
// t = (bit == true ? t_R : t_L);
// }
// }
/// @brief Key class for the streaming DPF.
/// It currently adds no fields to rdpf_key_class<N>, so its size and (de)serialization
/// are exactly those of the base key.
template <size_t N = 64>
class sdpf_key_class : public rdpf_key_class<N>
{
public:
    /// @return Number of bytes serialize() writes (same as the base key).
    size_t get_serialized_size() const
    {
        return rdpf_key_class<N>::get_serialized_size();
    }
    /// @brief Write the key into buf.
    /// @return Number of bytes written.
    size_t serialize(char *buf) const
    {
        const size_t written = rdpf_key_class<N>::serialize(buf);
        return written;
    }
    /// @brief Read the key from buf.
    /// @return Number of bytes consumed.
    size_t deserialize(const char *buf)
    {
        const size_t consumed = rdpf_key_class<N>::deserialize(buf);
        return consumed;
    }
};
/// @brief Key class for the streaming DCF: a DCF key followed by a 64-bit counter.
/// Serialized layout: dcf_key_class<G, N> encoding, then ctr_ as raw sizeof(uint64_t) bytes.
template <typename G = uint64_t, size_t N = 64>
class sdcf_key_class : public dcf_key_class<G, N>
{
public:
    uint64_t ctr_ = 0;
    sdcf_key_class() : dcf_key_class<G, N>() {}
    void set_ctr(uint64_t counter) { ctr_ = counter; }
    uint64_t get_ctr() const { return ctr_; }
    /// @return Number of bytes serialize() writes.
    size_t get_serialized_size() const
    {
        const size_t base_bytes = dcf_key_class<G, N>::get_serialized_size();
        return base_bytes + sizeof(uint64_t);
    }
    /// @brief Write the base key and then the counter.
    /// @return Number of bytes written.
    size_t serialize(char *buf) const
    {
        char *tail = buf + dcf_key_class<G, N>::serialize(buf);
        std::memcpy(tail, &ctr_, sizeof(uint64_t));
        return static_cast<size_t>(tail - buf) + sizeof(uint64_t);
    }
    /// @brief Read the base key and then the counter.
    /// @return Number of bytes consumed.
    size_t deserialize(const char *buf)
    {
        const char *tail = buf + dcf_key_class<G, N>::deserialize(buf);
        std::memcpy(&ctr_, tail, sizeof(uint64_t));
        return static_cast<size_t>(tail - buf) + sizeof(uint64_t);
    }
};
template <typename K, typename G, size_t D = 512>
class khPRF // prf(K, index_type) -> G
{
public:
khPRF() : key(0) {}
khPRF(RingVec<K, D> k) : key(k) {}
khPRF(const khPRF<K, G, D> &other) : key(other.key) {}
khPRF &operator=(const khPRF<K, G, D> &other)
{
if (this != &other)
{
key = other.key;
}
return *this;
}
void set_random_key()
{
// use emp prg to generate D blocks and use them to set the key vector
emp::PRG prg;
emp::block random_block[D];
prg.random_block(random_block, D);
for (size_t i = 0; i < D; i++)
{
key.set(i, K(random_block[i]));
}
}
void set_key(RingVec<K, D> key) { this->key = key; }
RingVec<K, D> get_key() const { return key; }
// the key-homomorphic PRF function
G eval(const index_type input) const
{
//TIMEIT_START(khHash);
RingVec<K, D> hash = khHash(input);
//TIMEIT_END(khHash);
//TIMEIT_START(inner_product);
K tmp = key.inner_product(hash); // the multiplication is over K
//TIMEIT_END(inner_product);
//TIMEIT_START(K2G);
G tmpG = K2G(tmp); // convert K to G
//TIMEIT_END(K2G);
return tmpG;
}
G K2G(const K &k) const
{
// convert K -> G
// FIXME: currecntly only support K and G are MyInteger with MOD = 2^k
size_t K_BITS = K::get_BITS();
size_t G_BITS = G::get_BITS();
assert(K_BITS > G_BITS);
return G(k.get_value() >> (K_BITS - G_BITS));
}
private:
RingVec<K, D> key; // the secret key for LWR-based key-homomorphic PRF
RingVec<K, D> khHash(const index_type &input) const
{ // hash: D -> RingVec<K, D>
// hash the input to a RingVec<K, D> using emp::CRH
RingVec<K, D> result;
emp::CRH hash;
block in[D], out[D];
for (size_t i = 0; i < D; i++)
{
in[i] = emp::makeBlock(input, i); // use the index as the first part of the block
}
// get D hash out blocks
hash.Hn(out, in, D); // hash the input blocks
// convert the hash output to RingVec<K, D>
for (size_t i = 0; i < D; i++)
{
result.set(i, K(out[i])); // convert emp::block to K
}
// std::cout << "Hash result: " << result << std::endl;
return result;
}
};
/// @brief Streaming ciphertext: a 64-bit counter paired with a ciphertext value in G.
/// Serialized layout: ctr as raw sizeof(uint64_t) bytes, then ctx via the util.h helpers.
template <typename G = uint64_t>
struct stream_ctx_type
{
    uint64_t ctr;
    G ctx;
    stream_ctx_type() : ctr(0), ctx(0) {}
    stream_ctx_type(const uint64_t c, const G g) : ctr(c), ctx(g) {}
    void set_ctr(uint64_t c) { ctr = c; }
    uint64_t get_ctr() const { return ctr; }
    void set_ctx(G g) { ctx = g; }
    G get_ctx() const { return ctx; }
    /// @return Number of bytes serialize() writes.
    size_t get_serialized_size() const
    {
        const size_t ctr_bytes = sizeof(uint64_t);
        return ctr_bytes + get_serialized_size_helper(ctx); // util.h helper
    }
    /// @brief Write ctr then ctx into buf. @return Number of bytes written.
    size_t serialize(char *buf) const
    {
        std::memcpy(buf, &ctr, sizeof(uint64_t));
        const size_t offset = sizeof(uint64_t);
        return offset + serialize_helper(ctx, buf + offset); // util.h helper
    }
    /// @brief Read ctr then ctx from buf. @return Number of bytes consumed.
    size_t deserialize(const char *buf)
    {
        std::memcpy(&ctr, buf, sizeof(uint64_t));
        const size_t offset = sizeof(uint64_t);
        return offset + deserialize_helper(ctx, buf + offset); // util.h helper
    }
};
/// @brief Result of one RDPF evaluation step: the control bit t and the seed s of the node reached.
struct rdpf_out_type
{
    bool t;  // control bit of the node
    block s; // seed of the node
    /// @brief Print as "(t, s)".
    friend std::ostream &operator<<(std::ostream &os, const rdpf_out_type &out)
    {
        return os << "(" << out.t << ", " << out.s << ")";
    }
};
/// @brief Output of a DPF evaluation: a single value v in the group G.
/// @tparam G The output group type.
template <class G = uint64_t>
struct dpf_out_type
{
    G v; // DPF output value
};
/// @brief State carried through DCF evaluation: the RDPF node state (t, s) inherited from
/// rdpf_out_type, plus the output value v accumulated so far.
/// @tparam G The group type for the output value and the correction words.
template <class G = uint64_t>
struct dcf_out_type : rdpf_out_type
{
    G v; // DCF output value accumulated up to the current level
};
/// @brief Streaming-DPF key material: a PRF key block, a control bit t_last, and a counter.
/// Getters are const-qualified so they can be used on const keys.
struct sdpf_stream_key
{
    block key_prf{};    // PRF key (value-initialized so a default key is never indeterminate)
    bool t_last = false;
    uint64_t ctr = 0;
    void set_key_prf(block k) { key_prf = k; }
    block get_key_prf() const { return key_prf; }
    void set_t_last(bool t) { t_last = t; }
    bool get_t_last() const { return t_last; }
    uint64_t get_ctr() const { return ctr; }
    void set_ctr(uint64_t c) { ctr = c; }
};
/// @brief Streaming-DCF key material: a PRF key of type K and a 64-bit counter.
/// Getters are const-qualified so they can be used on const keys.
/// @tparam K The PRF key type (must be constructible from 0).
template <typename K>
struct sdcf_stream_key
{
    K key; // the key for the PRF
    uint64_t ctr = 0;
    sdcf_stream_key() : key(0), ctr(0) {}
    sdcf_stream_key(K k, uint64_t c) : key(k), ctr(c) {}
    void set_key(K k) { key = k; }
    K get_key() const { return key; }
    void set_ctr(uint64_t c) { ctr = c; }
    uint64_t get_ctr() const { return ctr; }
};
/// @brief Joint RDPF key-generation state for one level: both parties' seeds and control bits
/// on the path selected by alpha, plus the correction words produced for this level.
struct rdpf_level_state
{
    block s0;    // party 0's seed
    block s1;    // party 1's seed
    bool t0;     // party 0's control bit
    bool t1;     // party 1's control bit
    block s_CW;  // seed correction word
    bool t_L_CW; // left-child control-bit correction word
    bool t_R_CW; // right-child control-bit correction word
};
/// @brief Joint DCF key-generation state for one level: the RDPF level state plus a value
/// correction word. Following the incremental-DPF idea, correction words are encoded along
/// the accepting path; no correction word is needed for the root node.
/// @tparam G The group type for the correction words.
template <class G = uint64_t>
struct dcf_level_state : rdpf_level_state
{
    G v_CW; // value correction word
};
/// @brief A subtree of a depth-N binary tree, written as a pattern of N characters.
/// @details list contains {'0', '1', '*', '-'}, where '*' denotes the wildcard and '-' marks an empty tree.
/// For example, in a depth-2 tree the leaves are {'00', '01', '10', '11'}; the subtree '0*'
/// denotes {'00', '01'}, while '0-' denotes an invalid (empty) tree.
/// `empty` is true whenever the pattern denotes the empty set.
template <size_t N = 64>
struct subtree
{
    std::vector<char> list;
    bool empty = false; // always initialized: previously left indeterminate by two constructors

    /// @brief The full tree: every position is a wildcard.
    subtree() : list(N, '*') {}

    /// @brief Build from an explicit pattern (size must be N, checked by assert only).
    /// Characters are not validated here; any '-' marks the subtree as empty.
    subtree(const std::vector<char> &bits) : list(bits)
    {
        assert(bits.size() == N);
        for (const char b : bits)
        {
            if (b == '-')
                empty = true;
        }
    }

    /// @brief Build from the first N characters of buf.
    /// @throws std::invalid_argument if a character is not one of '0', '1', '*', '-'.
    subtree(const char *buf) : list(N)
    {
        for (size_t i = 0; i < N; ++i)
        {
            const char c = buf[i];
            if (c != '0' && c != '1' && c != '*' && c != '-')
                throw std::invalid_argument("Invalid character in subtree representation");
            list[i] = c;
            if (c == '-')
                empty = true; // a '-' anywhere means this subtree is empty
        }
    }

    subtree(const subtree &other) = default;

    /// @brief Set intersection of two subtrees.
    /// @details Position-wise: equal characters are kept, a '*' yields the other character.
    /// Any conflict ('0' vs '1'), any '-' in either operand, or either operand already being
    /// empty makes the whole result empty: every position becomes '-' and `empty` is set.
    subtree operator&(const subtree &other) const
    {
        subtree result;
        bool disjoint = empty || other.empty;
        for (size_t i = 0; i < N && !disjoint; ++i)
        {
            const char a = list[i];
            const char b = other.list[i];
            if (a == '-' || b == '-' || (a != b && a != '*' && b != '*'))
                disjoint = true;
            else
                result.list[i] = (a == '*') ? b : a;
        }
        if (disjoint)
        {
            result.list.assign(N, '-');
            result.empty = true;
        }
        return result;
    }

    subtree &operator=(const subtree &other) = default;

    /// @brief Number of wildcard ('*') positions.
    size_t get_wildcard_size() const
    {
        size_t num = 0;
        for (const auto &c : list)
        {
            if (c == '*')
                num++;
        }
        return num;
    }

    /// @brief Print the pattern characters back to back.
    friend std::ostream &operator<<(std::ostream &os, const subtree &x)
    {
        for (const auto &bit : x.list)
        {
            os << bit;
        }
        return os;
    }
};
// forest contains a list of subtrees, organized in a vector. You can regard a forest as a union of subtrees.
template <size_t N = 64>
struct forest
{
std::vector<subtree<N>> subtree_vec;
forest() = default;
forest(const subtree<N> &single_tree)
{
subtree_vec.push_back(single_tree);
}
forest(const forest &other)
{
for (const auto &tree : other.subtree_vec)
{
subtree_vec.push_back(tree);
}
}
size_t size() const
{
return subtree_vec.size();
}
void append(const subtree<N> &tree)
{
subtree_vec.push_back(tree);
}
forest &operator=(const forest &other)
{
if (this != &other)
{
subtree_vec.clear();
for (const auto &tree : other.subtree_vec)
{
subtree_vec.push_back(tree);
}
}
return *this;
}
friend std::ostream &operator<<(std::ostream &os, const forest &x)
{
os << "{";
for (size_t i = 0; i < x.subtree_vec.size(); ++i)
{
os << x.subtree_vec[i];
if (i < x.subtree_vec.size() - 1)
os << ", ";
}
os << "}";
return os;
}
};
/// @brief Compute the correction words for one RDPF level and both parties' next-level state.
/// @details Both parties' seeds are expanded into (left, right) children. The child not on alpha's
/// path (the "lose" branch, index !alpha) gets a seed correction word equal to the XOR of the two
/// parties' lose-branch seeds. The control-bit correction words make the parties' bits agree on the
/// lose branch and differ on the keep branch (given t0 != t1 on entry; RDPF_gen starts with t0 = 0, t1 = 1).
/// Declared inline because this non-template function is defined in a header.
/// @param prp The TwoKeyPRP instance used for node expansion.
/// @param current_level The current level state (both parties' seeds and control bits).
/// @param alpha The bit of the special point at this level.
/// @return The next-level state together with this level's correction words.
inline rdpf_level_state next_level_state(TwoKeyPRP &prp, const rdpf_level_state &current_level, const bool alpha)
{
    block children0[2], children1[2];
    prp.node_expand_1to2(children0, current_level.s0);
    prp.node_expand_1to2(children1, current_level.s1);
    const bool t_L_0 = emp::getLSB(children0[0]);
    const bool t_R_0 = emp::getLSB(children0[1]);
    const bool t_L_1 = emp::getLSB(children1[0]);
    const bool t_R_1 = emp::getLSB(children1[1]);
    // Seed correction word: XOR of the two lose-branch seeds.
    const block s_CW = children0[1 - alpha] ^ children1[1 - alpha];
    // Control-bit correction words: force t0 ^ t1 = 1 on the keep branch and 0 on the lose branch.
    const bool t_L_CW = t_L_0 ^ t_L_1 ^ alpha ^ true;
    const bool t_R_CW = t_R_0 ^ t_R_1 ^ alpha;
    // Select the KEEP branch (the child on alpha's path).
    const block s_keep_0 = children0[alpha];
    const block s_keep_1 = children1[alpha];
    const bool t_keep_0 = (alpha == false ? t_L_0 : t_R_0);
    const bool t_keep_1 = (alpha == false ? t_L_1 : t_R_1);
    const bool t_keep_CW = (alpha == false ? t_L_CW : t_R_CW);
    // Apply the correction words for the party whose current control bit is set.
    const block s0 = s_keep_0 ^ (current_level.t0 == true ? s_CW : zero_block);
    const block s1 = s_keep_1 ^ (current_level.t1 == true ? s_CW : zero_block);
    // t_b^i = t_keep_b xor (t_b^{i-1} AND t_keep_CW)
    const bool t0 = t_keep_0 ^ (current_level.t0 & t_keep_CW);
    const bool t1 = t_keep_1 ^ (current_level.t1 & t_keep_CW);
    return {
        s0,
        s1,
        t0,
        t1,
        s_CW,
        t_L_CW,
        t_R_CW};
}
/// @brief Compute the correction words for one DCF level and both parties' next-level state.
/// @details Runs the RDPF step (the rdpf_level_state overload) to get this level's seed and
/// control-bit correction words. It then expands both parties' new seeds once more:
///   - child [1] of each expansion is converted to G and used to derive the value correction word;
///   - child [0] becomes the seed carried to the next level.
/// Following the incremental-DPF idea, the value correction word belongs to the node reached on
/// alpha's path; none is produced for the root.
/// @tparam G The group type for the correction words.
/// @param prp The TwoKeyPRP instance used for node expansion.
/// @param current_state The current level state (both parties' seeds and control bits).
/// @param alpha The bit of the special point at this level.
/// @param beta The value encoded into the correction word (defaults to 1).
/// @return The next-level state together with this level's correction words.
template <typename G = uint64_t>
dcf_level_state<G> next_level_state(TwoKeyPRP &prp, const dcf_level_state<G> &current_state, const bool alpha, const G &beta = G(1))
{
    // Run the RDPF step on the base subobject (reference cast: no slicing copy).
    const rdpf_level_state rdpf_next_state =
        next_level_state(prp, static_cast<const rdpf_level_state &>(current_state), alpha);
    block s0_extension[2], s1_extension[2];
    prp.node_expand_1to2(s0_extension, rdpf_next_state.s0);
    prp.node_expand_1to2(s1_extension, rdpf_next_state.s1);
    // Child [1] of each extension carries the value share.
    const G v0 = block_to_G<G>(s0_extension[1]);
    const G v1 = block_to_G<G>(s1_extension[1]);
    // CW_v = (-1)^{t1} * (beta - v0 + v1). In next_level<G>, party 0 adds (v + t*CW) and
    // party 1 subtracts it, so with t0 != t1 the reconstructed value is v0 - v1 + (t0 - t1)*CW_v = beta.
    const G CW_v = (G(1) - G(2) * G(rdpf_next_state.t1)) * (beta - v0 + v1);
    return {
        s0_extension[0],
        s1_extension[0],
        rdpf_next_state.t0,
        rdpf_next_state.t1,
        rdpf_next_state.s_CW,
        rdpf_next_state.t_L_CW,
        rdpf_next_state.t_R_CW,
        CW_v // the value correction word for the next level
    };
}
/// @brief Evaluate one RDPF level: expand the current node and follow the child selected by bit.
/// @details The seed and control-bit correction words are applied only when the current control
/// bit current.t is set. Declared inline because this non-template function is defined in a header.
/// @param prp The TwoKeyPRP instance used for node expansion.
/// @param current The current node (control bit and seed).
/// @param bit Which child to follow (false = left, true = right).
/// @param s_CW Seed correction word for this level.
/// @param t_L_CW Left control-bit correction word for this level.
/// @param t_R_CW Right control-bit correction word for this level.
/// @return The selected child's control bit and seed.
inline rdpf_out_type next_level(TwoKeyPRP &prp, const rdpf_out_type &current, const bool bit, const block &s_CW, const bool t_L_CW, const bool t_R_CW)
{
    block children[2];
    prp.node_expand_1to2(children, current.s);
    // children[bit] XOR (current.t ? s_CW : 0). Written as a ternary (as in next_level_full)
    // rather than `current.t * s_CW`, which relied on the compiler's vector-scalar broadcast.
    const block s = children[bit] ^ (current.t ? s_CW : zero_block);
    // LSB of the uncorrected child XOR (current.t AND the child's control-bit CW)
    const bool t = emp::getLSB(children[bit]) ^ (current.t & (bit == true ? t_R_CW : t_L_CW));
    return {t, s};
}
/// @brief Evaluate one RDPF level along both branches.
/// @details Same correction as next_level (applied only when current.t is set), computed for both
/// children from a single node expansion. Declared inline because this non-template function is
/// defined in a header.
/// @param prp The TwoKeyPRP instance used for node expansion.
/// @param current The current node (control bit and seed).
/// @param s_CW Seed correction word for this level.
/// @param t_L_CW Left control-bit correction word for this level.
/// @param t_R_CW Right control-bit correction word for this level.
/// @return Two entries: [0] is the left child, [1] the right child.
inline std::vector<rdpf_out_type> next_level_full(TwoKeyPRP &prp, const rdpf_out_type &current, const block &s_CW, const bool t_L_CW, const bool t_R_CW)
{
    block children[2];
    prp.node_expand_1to2(children, current.s);
    const block s_mask = (current.t == true ? s_CW : zero_block);
    const block s_L = children[0] ^ s_mask;
    const block s_R = children[1] ^ s_mask;
    const bool t_L = emp::getLSB(children[0]) ^ (current.t & t_L_CW);
    const bool t_R = emp::getLSB(children[1]) ^ (current.t & t_R_CW);
    return {rdpf_out_type{t_L, s_L}, rdpf_out_type{t_R, s_R}};
}
/// @brief Evaluate one DCF level and accumulate this level's contribution to the output share.
/// @details Follows the child selected by bit using the RDPF step, then expands that child: [0]
/// becomes the next seed. When bit is 0, the right sibling is also evaluated and its value share
/// is folded into v:
///   share = block_to_G(right_ext[1]) + (right.t ? v_CW : 0),
/// which party 0 adds and party 1 subtracts. With the correction word from next_level_state<G>,
/// the sibling contributes beta to the reconstructed output exactly when it lies on alpha's path.
/// @tparam G The group type for the output value and correction words.
/// @param party The evaluating party (0 adds shares, 1 subtracts them).
/// @param prp The TwoKeyPRP instance used for node expansion.
/// @param current The current node state and accumulated value.
/// @param bit Which child to follow (false = left, true = right).
/// @param s_CW Seed correction word for this level.
/// @param t_L_CW Left control-bit correction word for this level.
/// @param t_R_CW Right control-bit correction word for this level.
/// @param v_CW Value correction word for this level.
/// @return The next node state and the updated accumulated value.
template <typename G = uint64_t>
dcf_out_type<G> next_level(const bool party, TwoKeyPRP &prp, const dcf_out_type<G> &current, const bool bit, const block &s_CW, const bool t_L_CW, const bool t_R_CW, const G &v_CW)
{
    const rdpf_out_type parent{current.t, current.s};
    // Advance along the evaluation path.
    const rdpf_out_type next_state = next_level(prp, parent, bit, s_CW, t_L_CW, t_R_CW);
    block s_extension[2];
    prp.node_expand_1to2(s_extension, next_state.s);
    G v = current.v;
    if (bit == false)
    {
        // Going left: evaluate the right sibling and fold in its value share.
        const rdpf_out_type next_state_right = next_level(prp, parent, true, s_CW, t_L_CW, t_R_CW);
        block s_extension_right[2];
        prp.node_expand_1to2(s_extension_right, next_state_right.s);
        const G v0 = block_to_G<G>(s_extension_right[1]);
        const G v1 = (next_state_right.t == false ? G(0) : v_CW);
        if (party == 0)
        {
            v += v0 + v1;
        }
        else
        {
            v -= v0 + v1;
        }
    }
    return {next_state.t, s_extension[0], v};
}
template <size_t N = 32>
rdpf_level_state RDPF_gen(rdpf_key_class<N> &key0, rdpf_key_class<N> &key1, const index_type index)
{
PRG prg;
block s0, s1;
prg.random_block(&s0, 1);
prg.random_block(&s1, 1);
bool t0 = 0, t1 = 1;
// set the initial state
key0.set_seed(s0);
key0.set_party(t0);
key1.set_seed(s1);
key1.set_party(t1);
TwoKeyPRP prp(zero_block, makeBlock(0, 1));
rdpf_level_state state = {s0, s1, t0, t1, /*not used*/ zero_block, /*not used*/ false, /*not used*/ false};
for (uint32_t i = 1; i <= N; i++)
{
/*
block children0[2], children1[2];
prp.node_expand_1to2(children0, s0);
prp.node_expand_1to2(children1, s1);
bool t_L_0 = emp::getLSB(children0[0]);
bool t_R_0 = emp::getLSB(children0[1]);
bool t_L_1 = emp::getLSB(children1[0]);
bool t_R_1 = emp::getLSB(children1[1]);
// get the current bit
bool alpha = get_bit<N>(index, i);
block s_CW = children0[1 - alpha] ^ children1[1 - alpha];
bool t_L_CW = t_L_0 ^ t_L_1 ^ alpha ^ true;
bool t_R_CW = t_R_0 ^ t_R_1 ^ alpha;
// set CW = s_CW || t_L_CW || t_R_CW
key0.set_s_CW(i, s_CW);
key1.set_s_CW(i, s_CW);
key0.set_t_L_CW(i, t_L_CW);
key1.set_t_L_CW(i, t_L_CW);
key0.set_t_R_CW(i, t_R_CW);
key1.set_t_R_CW(i, t_R_CW);
// update the KEEP branch
block s_keep_0 = children0[alpha];
block s_keep_1 = children1[alpha];
bool t_keep_0 = (alpha == false ? t_L_0 : t_R_0);
bool t_keep_1 = (alpha == false ? t_L_1 : t_R_1);
bool t_keep_CW = (alpha == false ? t_L_CW : t_R_CW);
// prepare for next iteration
s0 = s_keep_0 ^ (t0 == true ? s_CW : zero_block);
s1 = s_keep_1 ^ (t1 == true ? s_CW : zero_block);
// t_b^i = t_b^{i-1} xor t_b^i * t_keep_CW
t0 = t_keep_0 ^ (t0 & t_keep_CW);
t1 = t_keep_1 ^ (t1 & t_keep_CW);
*/
state = next_level_state(prp, state, get_bit<N>(index, i));