00001
00009 #include "party.h"
00010
00011
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022
00023 SEXP weights, leftnode, rightnode, split;
00024 SEXP responses, inputs, whichNA;
00025 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026 double sleft = 0.0, sright = 0.0;
00027 int *ix, *levelset, *iwhichNA;
00028 int nobs, i, nna;
00029
00030 weights = S3get_nodeweights(node);
00031 dweights = REAL(weights);
00032 responses = GET_SLOT(learnsample, PL2_responsesSym);
00033 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034 nobs = get_nobs(learnsample);
00035
00036
00037 SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038 C_init_node(leftnode, nobs,
00039 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040 ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00041 leftweights = REAL(S3get_nodeweights(leftnode));
00042
00043
00044 SET_VECTOR_ELT(node, S3_RIGHT,
00045 rightnode = allocVector(VECSXP, NODE_LENGTH));
00046 C_init_node(rightnode, nobs,
00047 get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048 ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00049 rightweights = REAL(S3get_nodeweights(rightnode));
00050
00051
00052 split = S3get_primarysplit(node);
00053 if (has_missings(inputs, S3get_variableID(split))) {
00054 whichNA = get_missings(inputs, S3get_variableID(split));
00055 iwhichNA = INTEGER(whichNA);
00056 nna = LENGTH(whichNA);
00057 } else {
00058 nna = 0;
00059 whichNA = R_NilValue;
00060 iwhichNA = NULL;
00061 }
00062
00063 if (S3is_ordered(split)) {
00064 cutpoint = REAL(S3get_splitpoint(split))[0];
00065 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066 for (i = 0; i < nobs; i++) {
00067 if (nna > 0) {
00068 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069 }
00070 if (dx[i] <= cutpoint)
00071 leftweights[i] = dweights[i];
00072 else
00073 leftweights[i] = 0.0;
00074 rightweights[i] = dweights[i] - leftweights[i];
00075 sleft += leftweights[i];
00076 sright += rightweights[i];
00077 }
00078 } else {
00079 levelset = INTEGER(S3get_splitpoint(split));
00080 ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081
00082 for (i = 0; i < nobs; i++) {
00083 if (nna > 0) {
00084 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085 }
00086 if (levelset[ix[i] - 1])
00087 leftweights[i] = dweights[i];
00088 else
00089 leftweights[i] = 0.0;
00090 rightweights[i] = dweights[i] - leftweights[i];
00091 sleft += leftweights[i];
00092 sright += rightweights[i];
00093 }
00094 }
00095
00096
00097 if (nna > 0) {
00098 for (i = 0; i < nna; i++) {
00099 if (sleft > sright) {
00100 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101 rightweights[iwhichNA[i] - 1] = 0.0;
00102 } else {
00103 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104 leftweights[iwhichNA[i] - 1] = 0.0;
00105 }
00106 }
00107 }
00108 }
00109
00110
00120 SEXP C_get_node(SEXP subtree, SEXP newinputs,
00121 double mincriterion, int numobs) {
00122
00123 SEXP split, whichNA, weights, ssplit, surrsplit;
00124 double cutpoint, x, *dweights, swleft, swright;
00125 int level, *levelset, i, ns;
00126
00127 if (S3get_nodeterminal(subtree) ||
00128 REAL(S3get_maxcriterion(subtree))[0] < mincriterion)
00129 return(subtree);
00130
00131 split = S3get_primarysplit(subtree);
00132
00133
00134
00135 if (has_missings(newinputs, S3get_variableID(split))) {
00136 whichNA = get_missings(newinputs, S3get_variableID(split));
00137
00138 if (C_i_in_set(numobs, whichNA)) {
00139
00140 surrsplit = S3get_surrogatesplits(subtree);
00141 ns = 0;
00142 i = numobs;
00143
00144
00145 while(TRUE) {
00146
00147 if (ns >= LENGTH(surrsplit)) break;
00148
00149 ssplit = VECTOR_ELT(surrsplit, ns);
00150 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00151 if (INTEGER(get_missings(newinputs,
00152 S3get_variableID(ssplit)))[i]) {
00153 ns++;
00154 continue;
00155 }
00156 }
00157
00158 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00159 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00160
00161 if (S3get_toleft(ssplit)) {
00162 if (x <= cutpoint) {
00163 return(C_get_node(S3get_leftnode(subtree),
00164 newinputs, mincriterion, numobs));
00165 } else {
00166 return(C_get_node(S3get_rightnode(subtree),
00167 newinputs, mincriterion, numobs));
00168 }
00169 } else {
00170 if (x <= cutpoint) {
00171 return(C_get_node(S3get_rightnode(subtree),
00172 newinputs, mincriterion, numobs));
00173 } else {
00174 return(C_get_node(S3get_leftnode(subtree),
00175 newinputs, mincriterion, numobs));
00176 }
00177 }
00178 break;
00179 }
00180
00181
00182 weights = S3get_nodeweights(S3get_leftnode(subtree));
00183 dweights = REAL(weights);
00184 swleft = 0.0;
00185 for (i = 0; i < LENGTH(weights); i++)
00186 swleft += dweights[i];
00187 weights = S3get_nodeweights(S3get_rightnode(subtree));
00188 dweights = REAL(weights);
00189 swright = 0.0;
00190 for (i = 0; i < LENGTH(weights); i++)
00191 swright += dweights[i];
00192 if (swleft > swright) {
00193 return(C_get_node(S3get_leftnode(subtree),
00194 newinputs, mincriterion, numobs));
00195 } else {
00196 return(C_get_node(S3get_rightnode(subtree),
00197 newinputs, mincriterion, numobs));
00198 }
00199 }
00200 }
00201
00202 if (S3is_ordered(split)) {
00203 cutpoint = REAL(S3get_splitpoint(split))[0];
00204 x = REAL(get_variable(newinputs,
00205 S3get_variableID(split)))[numobs];
00206 if (x <= cutpoint) {
00207 return(C_get_node(S3get_leftnode(subtree),
00208 newinputs, mincriterion, numobs));
00209 } else {
00210 return(C_get_node(S3get_rightnode(subtree),
00211 newinputs, mincriterion, numobs));
00212 }
00213 } else {
00214 levelset = INTEGER(S3get_splitpoint(split));
00215 level = INTEGER(get_variable(newinputs,
00216 S3get_variableID(split)))[numobs];
00217
00218 if (levelset[level - 1]) {
00219 return(C_get_node(S3get_leftnode(subtree), newinputs,
00220 mincriterion, numobs));
00221 } else {
00222 return(C_get_node(S3get_rightnode(subtree), newinputs,
00223 mincriterion, numobs));
00224 }
00225 }
00226 }
00227
00228
00237 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion,
00238 SEXP numobs) {
00239 return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00240 INTEGER(numobs)[0] - 1));
00241 }
00242
00243
00250 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00251
00252 if (nodenum == S3get_nodeID(subtree)) return(subtree);
00253
00254 if (S3get_nodeterminal(subtree))
00255 error("no node with number %d\n", nodenum);
00256
00257 if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00258 return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00259 } else {
00260 return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00261 }
00262 }
00263
00264
00271 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00272 return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00273 }
00274
00275
00284 SEXP C_get_prediction(SEXP subtree, SEXP newinputs,
00285 double mincriterion, int numobs) {
00286 return(S3get_prediction(C_get_node(subtree, newinputs,
00287 mincriterion, numobs)));
00288 }
00289
00290
00299 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs,
00300 double mincriterion, int numobs) {
00301 return(S3get_nodeweights(C_get_node(subtree, newinputs,
00302 mincriterion, numobs)));
00303 }
00304
00305
00314 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00315 double mincriterion, int numobs) {
00316 return(S3get_nodeID(C_get_node(subtree, newinputs,
00317 mincriterion, numobs)));
00318 }
00319
00320
00328 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00329
00330 SEXP ans;
00331 int nobs, i, *dans;
00332
00333 nobs = get_nobs(newinputs);
00334 PROTECT(ans = allocVector(INTSXP, nobs));
00335 dans = INTEGER(ans);
00336 for (i = 0; i < nobs; i++)
00337 dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00338 UNPROTECT(1);
00339 return(ans);
00340 }
00341
00342
00351 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00352
00353 int nobs, i;
00354
00355 nobs = get_nobs(newinputs);
00356 if (LENGTH(ans) != nobs)
00357 error("ans is not of length %d\n", nobs);
00358
00359 for (i = 0; i < nobs; i++)
00360 SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs,
00361 mincriterion, i));
00362 }
00363
00364
00372 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00373
00374 SEXP ans;
00375 int nobs;
00376
00377 nobs = get_nobs(newinputs);
00378 PROTECT(ans = allocVector(VECSXP, nobs));
00379 C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00380 UNPROTECT(1);
00381 return(ans);
00382 }
00383
00384
00392 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00393
00394 int nobs, i, *iwhere;
00395
00396 nobs = LENGTH(where);
00397 iwhere = INTEGER(where);
00398 if (LENGTH(ans) != nobs)
00399 error("ans is not of length %d\n", nobs);
00400
00401 for (i = 0; i < nobs; i++)
00402 SET_VECTOR_ELT(ans, i, S3get_prediction(
00403 C_get_nodebynum(tree, iwhere[i])));
00404 }
00405
00406
00413 SEXP R_getpredictions(SEXP tree, SEXP where) {
00414
00415 SEXP ans;
00416 int nobs;
00417
00418 nobs = LENGTH(where);
00419 PROTECT(ans = allocVector(VECSXP, nobs));
00420 C_getpredictions(tree, where, ans);
00421 UNPROTECT(1);
00422 return(ans);
00423 }
00424
00425
00433 void C_getweights(SEXP tree, SEXP where, SEXP ans) {
00434
00435 int nobs, i, *iwhere;
00436
00437 nobs = LENGTH(where);
00438 iwhere = INTEGER(where);
00439 if (LENGTH(ans) != nobs)
00440 error("ans is not of length %d\n", nobs);
00441
00442 for (i = 0; i < nobs; i++)
00443 SET_VECTOR_ELT(ans, i, S3get_nodeweights(
00444 C_get_nodebynum(tree, iwhere[i])));
00445 }
00446
00447
00454 SEXP R_getweights(SEXP tree, SEXP where) {
00455
00456 SEXP ans;
00457 int nobs;
00458
00459 nobs = LENGTH(where);
00460 PROTECT(ans = allocVector(VECSXP, nobs));
00461 C_getweights(tree, where, ans);
00462 UNPROTECT(1);
00463 return(ans);
00464 }
00465
00466
00475 void C_weights(SEXP tree, SEXP newinputs,
00476 double mincriterion, SEXP ans) {
00477
00478 int nobs, i;
00479
00480 nobs = get_nobs(newinputs);
00481 if (LENGTH(ans) != nobs)
00482 error("ans is not of length %d\n", nobs);
00483
00484 for (i = 0; i < nobs; i++)
00485 SET_VECTOR_ELT(ans, i, C_get_nodeweights(tree, newinputs,
00486 mincriterion, i));
00487 }
00488
00489
00497 SEXP R_weights(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00498
00499 SEXP ans;
00500 int nobs;
00501
00502 nobs = get_nobs(newinputs);
00503 PROTECT(ans = allocVector(VECSXP, nobs));
00504 C_weights(tree, newinputs, REAL(mincriterion)[0], ans);
00505 UNPROTECT(1);
00506 return(ans);
00507 }
00508
00509
00518 SEXP R_predictRF(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00519
00520 SEXP ans, tmp, tree;
00521 int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0;
00522
00523 if (LOGICAL(oobpred)[0]) oob = 1;
00524
00525 nobs = get_nobs(newinputs);
00526 ntrees = LENGTH(forest);
00527 q = LENGTH(S3get_prediction(
00528 C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00529
00530 if (oob) {
00531 if (LENGTH(S3get_nodeweights(
00532 C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00533 error("number of observations don't match");
00534 }
00535
00536 PROTECT(ans = allocVector(VECSXP, nobs));
00537
00538 for (i = 0; i < nobs; i++) {
00539 count = 0;
00540 SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00541 for (j = 0; j < q; j++)
00542 REAL(VECTOR_ELT(ans, i))[j] = 0.0;
00543 for (b = 0; b < ntrees; b++) {
00544 tree = VECTOR_ELT(forest, b);
00545
00546 if (oob &&
00547 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00548 continue;
00549
00550 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00551 tmp = S3get_prediction(C_get_nodebynum(tree, iwhere));
00552 for (j = 0; j < q; j++)
00553 REAL(VECTOR_ELT(ans, i))[j] += REAL(tmp)[j];
00554 count++;
00555 }
00556 if (count == 0)
00557 error("cannot compute out-of-bag predictions for obs ", i + 1);
00558 for (j = 0; j < q; j++)
00559 REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / count;
00560 }
00561 UNPROTECT(1);
00562 return(ans);
00563 }
00564
00574 SEXP R_predictRF2(SEXP forest, SEXP response, SEXP newinputs,
00575 SEXP mincriterion, SEXP oobpred) {
00576
00577 SEXP ans, tmp, tree, w;
00578 int ntrees, nobs, i, b, j, q, n, iwhere, oob = 0;
00579 double *dtmp, *dw, sumw = 0.0;
00580
00581 if (LOGICAL(oobpred)[0]) oob = 1;
00582
00583 nobs = get_nobs(newinputs);
00584 ntrees = LENGTH(forest);
00585 n = nrow(response);
00586 q = ncol(response);
00587
00588 if (oob) {
00589 if (n != nobs)
00590 error("number of observations don't match");
00591 }
00592
00593 PROTECT(ans = allocVector(VECSXP, nobs));
00594 PROTECT(w = allocMatrix(REALSXP, 1, n));
00595 dw = REAL(w);
00596
00597 for (i = 0; i < nobs; i++) {
00598
00599 SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00600 for (j = 0; j < n; j++)
00601 dw[j] = 0.0;
00602
00603 for (b = 0; b < ntrees; b++) {
00604 tree = VECTOR_ELT(forest, b);
00605
00606 if (oob &&
00607 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00608 continue;
00609
00610 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00611 tmp = S3get_nodeweights(C_get_nodebynum(tree, iwhere));
00612 dtmp = REAL(tmp);
00613
00614 for (j = 0; j < n; j++)
00615 dw[j] += dtmp[j];
00616 }
00617
00618 C_matprod(dw, 1, n, REAL(response), n, q, REAL(VECTOR_ELT(ans, i)));
00619
00620 sumw = 0.0;
00621 for (j = 0; j < n; j++)
00622 sumw += dw[j];
00623
00624 for (j = 0; j < q; j++)
00625 REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / sumw;
00626 }
00627 UNPROTECT(2);
00628 return(ans);
00629 }
00630
00639 SEXP R_predictRF_weights(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00640
00641 SEXP ans, tree, bw;
00642 int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00643 double *dtmp;
00644
00645 if (LOGICAL(oobpred)[0]) oob = 1;
00646
00647 nobs = get_nobs(newinputs);
00648 ntrees = LENGTH(forest);
00649 q = LENGTH(S3get_prediction(
00650 C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00651
00652 if (oob) {
00653 if (LENGTH(S3get_nodeweights(
00654 C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00655 error("number of observations don't match");
00656 }
00657
00658 tree = VECTOR_ELT(forest, 0);
00659 ntrain = LENGTH(S3get_nodeweights(C_get_nodebynum(tree, 1)));
00660
00661 PROTECT(ans = allocVector(VECSXP, nobs));
00662
00663 for (i = 0; i < nobs; i++) {
00664 count = 0;
00665 SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00666 for (j = 0; j < ntrain; j++)
00667 REAL(bw)[j] = 0.0;
00668 for (b = 0; b < ntrees; b++) {
00669 tree = VECTOR_ELT(forest, b);
00670
00671 if (oob &&
00672 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0)
00673 continue;
00674
00675 iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00676 dtmp = REAL(S3get_nodeweights(C_get_nodebynum(tree, iwhere)));
00677 for (j = 0; j < ntrain; j++)
00678 REAL(bw)[j] += dtmp[j];
00679 count++;
00680 }
00681 if (count == 0)
00682 error("cannot compute out-of-bag predictions for obs ", i + 1);
00683 }
00684 UNPROTECT(1);
00685 return(ans);
00686 }