Predict.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022 
00023     SEXP weights, leftnode, rightnode, split;
00024     SEXP responses, inputs, whichNA;
00025     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026     double sleft = 0.0, sright = 0.0;
00027     int *ix, *levelset, *iwhichNA;
00028     int nobs, i, nna;
00029                     
00030     weights = S3get_nodeweights(node);
00031     dweights = REAL(weights);
00032     responses = GET_SLOT(learnsample, PL2_responsesSym);
00033     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034     nobs = get_nobs(learnsample);
00035             
00036     /* set up memory for the left daughter */
00037     SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038     C_init_node(leftnode, nobs, 
00039         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040         ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00041     leftweights = REAL(S3get_nodeweights(leftnode));
00042 
00043     /* set up memory for the right daughter */
00044     SET_VECTOR_ELT(node, S3_RIGHT, 
00045                    rightnode = allocVector(VECSXP, NODE_LENGTH));
00046     C_init_node(rightnode, nobs, 
00047         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048         ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00049     rightweights = REAL(S3get_nodeweights(rightnode));
00050 
00051     /* split according to the primary split */
00052     split = S3get_primarysplit(node);
00053     if (has_missings(inputs, S3get_variableID(split))) {
00054         whichNA = get_missings(inputs, S3get_variableID(split));
00055         iwhichNA = INTEGER(whichNA);
00056         nna = LENGTH(whichNA);
00057     } else {
00058         nna = 0;
00059         whichNA = R_NilValue;
00060         iwhichNA = NULL;
00061     }
00062     
00063     if (S3is_ordered(split)) {
00064         cutpoint = REAL(S3get_splitpoint(split))[0];
00065         dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066         for (i = 0; i < nobs; i++) {
00067             if (nna > 0) {
00068                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069             }
00070             if (dx[i] <= cutpoint) 
00071                 leftweights[i] = dweights[i]; 
00072             else 
00073                 leftweights[i] = 0.0;
00074             rightweights[i] = dweights[i] - leftweights[i];
00075             sleft += leftweights[i];
00076             sright += rightweights[i];
00077         }
00078     } else {
00079         levelset = INTEGER(S3get_splitpoint(split));
00080         ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081 
00082         for (i = 0; i < nobs; i++) {
00083             if (nna > 0) {
00084                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085             }
00086             if (levelset[ix[i] - 1])
00087                 leftweights[i] = dweights[i];
00088             else 
00089                 leftweights[i] = 0.0;
00090             rightweights[i] = dweights[i] - leftweights[i];
00091             sleft += leftweights[i];
00092             sright += rightweights[i];
00093         }
00094     }
00095     
00096     /* for the moment: NA's go with majority */
00097     if (nna > 0) {
00098         for (i = 0; i < nna; i++) {
00099             if (sleft > sright) {
00100                 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101                 rightweights[iwhichNA[i] - 1] = 0.0;
00102             } else {
00103                 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104                 leftweights[iwhichNA[i] - 1] = 0.0;
00105             }
00106         }
00107     }
00108 }
00109 
00110 
00120 SEXP C_get_node(SEXP subtree, SEXP newinputs, 
00121                 double mincriterion, int numobs) {
00122 
00123     SEXP split, whichNA, weights, ssplit, surrsplit;
00124     double cutpoint, x, *dweights, swleft, swright;
00125     int level, *levelset, i, ns;
00126 
00127     if (S3get_nodeterminal(subtree) || 
00128         REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
00129         return(subtree);
00130     
00131     split = S3get_primarysplit(subtree);
00132 
00133     /* missing values. Maybe store the proportions left / 
00134        right in each node? */
00135     if (has_missings(newinputs, S3get_variableID(split))) {
00136         whichNA = get_missings(newinputs, S3get_variableID(split));
00137     
00138         if (C_i_in_set(numobs, whichNA)) {
00139         
00140             surrsplit = S3get_surrogatesplits(subtree);
00141             ns = 0;
00142             i = numobs;      
00143 
00144             /* try to find a surrogate split */
00145             while(TRUE) {
00146     
00147                 if (ns >= LENGTH(surrsplit)) break;
00148             
00149                 ssplit = VECTOR_ELT(surrsplit, ns);
00150                 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00151                     if (INTEGER(get_missings(newinputs, 
00152                                              S3get_variableID(ssplit)))[i]) {
00153                         ns++;
00154                         continue;
00155                     }
00156                 }
00157 
00158                 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00159                 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00160                      
00161                 if (S3get_toleft(ssplit)) {
00162                     if (x <= cutpoint) {
00163                         return(C_get_node(S3get_leftnode(subtree),
00164                                           newinputs, mincriterion, numobs));
00165                     } else {
00166                         return(C_get_node(S3get_rightnode(subtree),
00167                                newinputs, mincriterion, numobs));
00168                     }
00169                 } else {
00170                     if (x <= cutpoint) {
00171                         return(C_get_node(S3get_rightnode(subtree),
00172                                           newinputs, mincriterion, numobs));
00173                     } else {
00174                         return(C_get_node(S3get_leftnode(subtree),
00175                                newinputs, mincriterion, numobs));
00176                     }
00177                 }
00178                 break;
00179             }
00180 
00181             /* if this was not successful, we go with the majority */
00182             weights = S3get_nodeweights(S3get_leftnode(subtree));
00183             dweights = REAL(weights);
00184             swleft = 0.0;
00185             for (i = 0; i < LENGTH(weights); i++)
00186                 swleft += dweights[i];
00187             weights = S3get_nodeweights(S3get_rightnode(subtree));
00188             dweights = REAL(weights);
00189             swright = 0.0;
00190             for (i = 0; i < LENGTH(weights); i++)
00191                 swright += dweights[i];
00192             if (swleft > swright) {
00193                 return(C_get_node(S3get_leftnode(subtree), 
00194                                   newinputs, mincriterion, numobs));
00195             } else {
00196                 return(C_get_node(S3get_rightnode(subtree), 
00197                                   newinputs, mincriterion, numobs));
00198             }
00199         }
00200     }
00201     
00202     if (S3is_ordered(split)) {
00203         cutpoint = REAL(S3get_splitpoint(split))[0];
00204         x = REAL(get_variable(newinputs, 
00205                      S3get_variableID(split)))[numobs];
00206         if (x <= cutpoint) {
00207             return(C_get_node(S3get_leftnode(subtree), 
00208                               newinputs, mincriterion, numobs));
00209         } else {
00210             return(C_get_node(S3get_rightnode(subtree), 
00211                               newinputs, mincriterion, numobs));
00212         }
00213     } else {
00214         levelset = INTEGER(S3get_splitpoint(split));
00215         level = INTEGER(get_variable(newinputs, 
00216                             S3get_variableID(split)))[numobs];
00217         /* level is in 1, ..., K */
00218         if (levelset[level - 1]) {
00219             return(C_get_node(S3get_leftnode(subtree), newinputs, 
00220                               mincriterion, numobs));
00221         } else {
00222             return(C_get_node(S3get_rightnode(subtree), newinputs, 
00223                               mincriterion, numobs));
00224         }
00225     }
00226 }
00227 
00228 
00237 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
00238                 SEXP numobs) {
00239     return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00240                       INTEGER(numobs)[0] - 1));
00241 }
00242 
00243 
00250 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00251     
00252     if (nodenum == S3get_nodeID(subtree)) return(subtree);
00253 
00254     if (S3get_nodeterminal(subtree)) 
00255         error("no node with number %d\n", nodenum);
00256 
00257     if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00258         return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00259     } else {
00260         return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00261     }
00262 }
00263 
00264 
00271 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00272     return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00273 }
00274 
00275 
00284 SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
00285                       double mincriterion, int numobs) {
00286     return(S3get_prediction(C_get_node(subtree, newinputs, 
00287                             mincriterion, numobs)));
00288 }
00289 
00290 
00299 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
00300                        double mincriterion, int numobs) {
00301     return(S3get_nodeweights(C_get_node(subtree, newinputs, 
00302                              mincriterion, numobs)));
00303 }
00304 
00305 
00314 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00315                   double mincriterion, int numobs) {
00316      return(S3get_nodeID(C_get_node(subtree, newinputs, 
00317             mincriterion, numobs)));
00318 }
00319 
00320 
00328 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00329 
00330     SEXP ans;
00331     int nobs, i, *dans;
00332             
00333     nobs = get_nobs(newinputs);
00334     PROTECT(ans = allocVector(INTSXP, nobs));
00335     dans = INTEGER(ans);
00336     for (i = 0; i < nobs; i++)
00337          dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00338     UNPROTECT(1);
00339     return(ans);
00340 }
00341 
00342 
00351 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00352     
00353     int nobs, i;
00354     
00355     nobs = get_nobs(newinputs);    
00356     if (LENGTH(ans) != nobs) 
00357         error("ans is not of length %d\n", nobs);
00358         
00359     for (i = 0; i < nobs; i++)
00360         SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
00361                        mincriterion, i));
00362 }
00363 
00364 
00372 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00373 
00374     SEXP ans;
00375     int nobs;
00376     
00377     nobs = get_nobs(newinputs);
00378     PROTECT(ans = allocVector(VECSXP, nobs));
00379     C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00380     UNPROTECT(1);
00381     return(ans);
00382 }
00383 
00384 
00392 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00393 
00394     int nobs, i, *iwhere;
00395     
00396     nobs = LENGTH(where);
00397     iwhere = INTEGER(where);
00398     if (LENGTH(ans) != nobs)
00399         error("ans is not of length %d\n", nobs);
00400         
00401     for (i = 0; i < nobs; i++)
00402         SET_VECTOR_ELT(ans, i, S3get_prediction(
00403             C_get_nodebynum(tree, iwhere[i])));
00404 }
00405 
00406 
00413 SEXP R_getpredictions(SEXP tree, SEXP where) {
00414 
00415     SEXP ans;
00416     int nobs;
00417             
00418     nobs = LENGTH(where);
00419     PROTECT(ans = allocVector(VECSXP, nobs));
00420     C_getpredictions(tree, where, ans);
00421     UNPROTECT(1);
00422     return(ans);
00423 }                        
00424 
00425 
00433 void C_getweights(SEXP tree, SEXP where, SEXP ans) {
00434 
00435     int nobs, i, *iwhere;
00436     
00437     nobs = LENGTH(where);
00438     iwhere = INTEGER(where);
00439     if (LENGTH(ans) != nobs)
00440         error("ans is not of length %d\n", nobs);
00441         
00442     for (i = 0; i < nobs; i++)
00443         SET_VECTOR_ELT(ans, i, S3get_nodeweights(
00444             C_get_nodebynum(tree, iwhere[i])));
00445 }
00446 
00447 
00454 SEXP R_getweights(SEXP tree, SEXP where) {
00455 
00456     SEXP ans;
00457     int nobs;
00458             
00459     nobs = LENGTH(where);
00460     PROTECT(ans = allocVector(VECSXP, nobs));
00461     C_getweights(tree, where, ans);
00462     UNPROTECT(1);
00463     return(ans);
00464 }                        
00465 
00466 
00475 void C_weights(SEXP tree, SEXP newinputs, 
00476                double mincriterion, SEXP ans) {
00477     
00478     int nobs, i;
00479     
00480     nobs = get_nobs(newinputs);    
00481     if (LENGTH(ans) != nobs) 
00482         error("ans is not of length %d\n", nobs);
00483         
00484     for (i = 0; i < nobs; i++)
00485         SET_VECTOR_ELT(ans, i, C_get_nodeweights(tree, newinputs, 
00486                        mincriterion, i));
00487 }
00488 
00489 
00497 SEXP R_weights(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00498 
00499     SEXP ans;
00500     int nobs;
00501     
00502     nobs = get_nobs(newinputs);
00503     PROTECT(ans = allocVector(VECSXP, nobs));
00504     C_weights(tree, newinputs, REAL(mincriterion)[0], ans);
00505     UNPROTECT(1);
00506     return(ans);
00507 }
00508 
00509 
00518 SEXP R_predictRF(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00519 
00520     SEXP ans, tmp, tree;
00521     int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0;
00522     
00523     if (LOGICAL(oobpred)[0]) oob = 1;
00524     
00525     nobs = get_nobs(newinputs);
00526     ntrees = LENGTH(forest);
00527     q = LENGTH(S3get_prediction(
00528                    C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00529 
00530     if (oob) {
00531         if (LENGTH(S3get_nodeweights(
00532                        C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00533             error("number of observations don't match");
00534     }    
00535     
00536     PROTECT(ans = allocVector(VECSXP, nobs));
00537     
00538     for (i = 0; i < nobs; i++) {
00539         count = 0;
00540         SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00541         for (j = 0; j < q; j++)
00542                     REAL(VECTOR_ELT(ans, i))[j] = 0.0;
00543         for (b = 0; b < ntrees; b++) {
00544             tree = VECTOR_ELT(forest, b);
00545 
00546             if (oob && 
00547                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00548                 continue;
00549 
00550             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00551             tmp = S3get_prediction(C_get_nodebynum(tree, iwhere));
00552             for (j = 0; j < q; j++)
00553                 REAL(VECTOR_ELT(ans, i))[j] += REAL(tmp)[j];
00554             count++;
00555         }
00556         if (count == 0) 
00557             error("cannot compute out-of-bag predictions for obs ", i + 1);
00558         for (j = 0; j < q; j++)
00559             REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / count;
00560     }
00561     UNPROTECT(1);
00562     return(ans);
00563 }
00564 
00574 SEXP R_predictRF2(SEXP forest, SEXP response, SEXP newinputs, 
00575                   SEXP mincriterion, SEXP oobpred) {
00576 
00577     SEXP ans, tmp, tree, w;
00578     int ntrees, nobs, i, b, j, q, n, iwhere, oob = 0;
00579     double *dtmp, *dw, sumw = 0.0;
00580 
00581     if (LOGICAL(oobpred)[0]) oob = 1;
00582     
00583     nobs = get_nobs(newinputs);
00584     ntrees = LENGTH(forest);
00585     n = nrow(response);
00586     q = ncol(response);
00587 
00588     if (oob) {
00589         if (n != nobs)
00590             error("number of observations don't match");
00591     }    
00592     
00593     PROTECT(ans = allocVector(VECSXP, nobs));
00594     PROTECT(w = allocMatrix(REALSXP, 1, n));
00595     dw = REAL(w);
00596     
00597     for (i = 0; i < nobs; i++) {
00598 
00599         SET_VECTOR_ELT(ans, i, allocVector(REALSXP, q));
00600         for (j = 0; j < n; j++)
00601             dw[j] = 0.0;
00602 
00603         for (b = 0; b < ntrees; b++) {
00604             tree = VECTOR_ELT(forest, b);
00605 
00606             if (oob && 
00607                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00608                 continue;
00609 
00610             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00611             tmp = S3get_nodeweights(C_get_nodebynum(tree, iwhere));
00612             dtmp = REAL(tmp);
00613             
00614             for (j = 0; j < n; j++)
00615                 dw[j] += dtmp[j];
00616         }
00617         
00618         C_matprod(dw, 1, n, REAL(response), n, q, REAL(VECTOR_ELT(ans, i)));
00619 
00620         sumw = 0.0;
00621         for (j = 0; j < n; j++)
00622             sumw += dw[j];
00623 
00624         for (j = 0; j < q; j++)
00625             REAL(VECTOR_ELT(ans, i))[j] = REAL(VECTOR_ELT(ans, i))[j] / sumw;
00626     }
00627     UNPROTECT(2);
00628     return(ans);
00629 }
00630 
00639 SEXP R_predictRF_weights(SEXP forest, SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00640 
00641     SEXP ans, tree, bw;
00642     int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00643     double *dtmp;
00644     
00645     if (LOGICAL(oobpred)[0]) oob = 1;
00646     
00647     nobs = get_nobs(newinputs);
00648     ntrees = LENGTH(forest);
00649     q = LENGTH(S3get_prediction(
00650                    C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00651 
00652     if (oob) {
00653         if (LENGTH(S3get_nodeweights(
00654                        C_get_nodebynum(VECTOR_ELT(forest, 0), 1))) != nobs)
00655             error("number of observations don't match");
00656     }    
00657     
00658     tree = VECTOR_ELT(forest, 0);
00659     ntrain = LENGTH(S3get_nodeweights(C_get_nodebynum(tree, 1)));
00660     
00661     PROTECT(ans = allocVector(VECSXP, nobs));
00662     
00663     for (i = 0; i < nobs; i++) {
00664         count = 0;
00665         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00666         for (j = 0; j < ntrain; j++)
00667             REAL(bw)[j] = 0.0;
00668         for (b = 0; b < ntrees; b++) {
00669             tree = VECTOR_ELT(forest, b);
00670 
00671             if (oob && 
00672                 REAL(S3get_nodeweights(C_get_nodebynum(tree, 1)))[i] > 0.0) 
00673                 continue;
00674 
00675             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00676             dtmp = REAL(S3get_nodeweights(C_get_nodebynum(tree, iwhere)));
00677             for (j = 0; j < ntrain; j++)
00678                 REAL(bw)[j] += dtmp[j];
00679             count++;
00680         }
00681         if (count == 0) 
00682             error("cannot compute out-of-bag predictions for obs ", i + 1);
00683     }
00684     UNPROTECT(1);
00685     return(ans);
00686 }

Generated on Mon Jan 22 17:37:53 2007 for party by  doxygen 1.4.6