-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_bart.h
More file actions
52 lines (35 loc) · 1.23 KB
/
compute_bart.h
File metadata and controls
52 lines (35 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
//
// Created by busygin on 12/24/15.
//
#ifndef BART_COMPUTE_BART_H
#define BART_COMPUTE_BART_H
#include <cstddef>
#include <vector>

#include "info.h"
#include "tree.h"
/**
 * Bundles the data descriptors, parameter blocks and tree collection that
 * fit() and predict() operate on.
 *
 * Typical call sequence, as far as this header shows: register the data with
 * the set_*_matrix / set_*_target setters, configure with set_mcmc_params()
 * and set_run_params(), then call fit() and/or predict(). Both of those are
 * only declared here; their definitions live elsewhere.
 *
 * All pointers handed to the setters are stored as-is. Nothing in this header
 * copies or frees the buffers they point to.
 */
class compute_bart {
public:
    compute_bart() {}

    /**
     * Records the in-sample predictor matrix.
     *
     * @param n  stored in insample_data.n
     * @param p  stored in insample_data.p
     * @param x  stored in insample_data.x (pointer only, no copy)
     */
    void set_insample_matrix(size_t n, size_t p, double *x) {
        insample_data.n = n;
        insample_data.p = p;
        insample_data.x = x;
    }

    /**
     * Records the in-sample target pointer.
     *
     * @param n  accepted for signature symmetry with set_insample_matrix()
     *           but NOT stored: insample_data.n is only written by
     *           set_insample_matrix().
     * @param y  stored in insample_data.y (pointer only, no copy)
     */
    void set_insample_target(size_t n, double *y) {
        (void)n;  // intentionally unused; see the note on @param n
        insample_data.y = y;
    }

    /**
     * Records the out-of-sample predictor matrix.
     *
     * @param n  stored in outsample_data.n
     * @param p  stored in outsample_data.p
     * @param x  stored in outsample_data.x (pointer only, no copy)
     */
    void set_outsample_matrix(size_t n, size_t p, double *x) {
        outsample_data.n = n;
        outsample_data.p = p;
        outsample_data.x = x;
    }

    /**
     * Records the out-of-sample target pointer.
     *
     * @param n  accepted for signature symmetry with set_outsample_matrix()
     *           but NOT stored: outsample_data.n is only written by
     *           set_outsample_matrix().
     * @param y  stored in outsample_data.y (pointer only, no copy)
     */
    void set_outsample_target(size_t n, double *y) {
        (void)n;  // intentionally unused; see the note on @param n
        outsample_data.y = y;
    }

    /**
     * Forwards the arguments, in declaration order, to pinfo::init() on
     * mcmc_params. Every argument has a default, so the method can be called
     * with no arguments to apply the defaults listed in the signature.
     */
    void set_mcmc_params(double pb_ = 0.5, double alpha_ = 0.95, double beta_ = 2.0,
                         double tau_ = 1.0, double sigma_ = 1.0) {
        mcmc_params.init(pb_, alpha_, beta_, tau_, sigma_);
    }

    /**
     * Forwards the arguments to rinfo::init() on run_params.
     *
     * Note the ordering difference: regression_ is the FIRST parameter here
     * but is passed as the LAST argument of rinfo::init(); the remaining
     * arguments keep their relative order.
     */
    void set_run_params(bool regression_ = true, size_t nd_ = 1000, double lambda_ = 1.0,
                        size_t burn_ = 100, size_t m_ = 200, size_t nc_ = 100,
                        int nu_ = 3, double kfac_ = 2.0) {
        run_params.init(nd_, lambda_, burn_, m_, nc_, nu_, kfac_, regression_);
    }

    // Declared only; implemented outside this header.
    void fit();
    void predict();

    dinfo insample_data;   // written by set_insample_matrix() / set_insample_target()
    dinfo outsample_data;  // written by set_outsample_matrix() / set_outsample_target()
    pinfo mcmc_params;     // initialised through set_mcmc_params()
    rinfo run_params;      // initialised through set_run_params()
    xinfo xi;              // not touched by any inline method in this header
    std::vector<tree> t;   // tree collection; not touched by any inline method in this header
};
#endif //BART_COMPUTE_BART_H