00001
00009 #include "party.h"
00010
00011
00021 SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00022
00023 SEXP nweights, tree, where, ans;
00024 double *dnweights, *dweights, sw = 0.0, *prob, fraction;
00025 int nobs, i, b, B , nodenum = 1, *iweights, *iweightstmp,
00026 *iwhere, replace;
00027
00028 B = get_ntree(controls);
00029 nobs = get_nobs(learnsample);
00030
00031 PROTECT(ans = allocVector(VECSXP, B));
00032
00033 iweights = Calloc(nobs, int);
00034 iweightstmp = Calloc(nobs, int);
00035 prob = Calloc(nobs, double);
00036 dweights = REAL(weights);
00037
00038 for (i = 0; i < nobs; i++)
00039 sw += dweights[i];
00040 for (i = 0; i < nobs; i++)
00041 prob[i] = dweights[i]/sw;
00042
00043 replace = get_replace(controls);
00044 fraction = get_fraction(controls) * nobs;
00045
00046 if (!replace) {
00047 if (fraction < 10)
00048 error("fraction of %f is too small", fraction);
00049 }
00050
00051
00052
00053 GetRNGstate();
00054
00055 for (b = 0; b < B; b++) {
00056 SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
00057 SET_VECTOR_ELT(tree, NODE_LENGTH, where = allocVector(INTSXP, nobs));
00058 iwhere = INTEGER(where);
00059 for (i = 0; i < nobs; i++) iwhere[i] = 0;
00060
00061 C_init_node(tree, nobs, get_ninputs(learnsample),
00062 get_maxsurrogate(get_splitctrl(controls)),
00063 ncol(get_jointtransf(GET_SLOT(learnsample,
00064 PL2_responsesSym))));
00065
00066
00067 if (replace) {
00068
00069 rmultinom((int) sw, prob, nobs, iweights);
00070 } else {
00071
00072 C_SampleNoReplace(iweightstmp, nobs, nobs, iweights);
00073 for (i = 0; i < nobs; i++) {
00074 if (iweights[i] < fraction) {
00075 iweights[i] = 1;
00076 } else {
00077 iweights[i] = 0;
00078 }
00079 }
00080 }
00081
00082 nweights = S3get_nodeweights(tree);
00083 dnweights = REAL(nweights);
00084 for (i = 0; i < nobs; i++) dnweights[i] = (double) iweights[i];
00085
00086 C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1);
00087 nodenum = 1;
00088 }
00089
00090 PutRNGstate();
00091
00092 Free(prob); Free(iweights); Free(iweightstmp);
00093 UNPROTECT(1);
00094 return(ans);
00095 }