SurrogateSplits.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036 
00037     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00038         error("nodes does not have %d surrogate splits", maxsurr);
00039     if ((ninputs - 1 - maxsurr) < 1)
00040         error("cannot set up %d surrogate splits with only %d input variable(s)", 
00041               maxsurr, ninputs);
00042 
00043     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00044     jselect = S3get_variableID(S3get_primarysplit(node));
00045     y = S3get_nodeweights(VECTOR_ELT(node, 7));
00046 
00047     tweights = Calloc(nobs, double);
00048     dweights = REAL(weights);
00049     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00050     if (has_missings(inputs, jselect)) {
00051         thiswhichNA = get_missings(inputs, jselect);
00052         for (k = 0; k < LENGTH(thiswhichNA); k++)
00053             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00054     }
00055 
00056     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00057     C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00058     
00059     splitstat = REAL(get_splitstatistics(fitmem));
00060     /* <FIXME> extend `TreeFitMemory' to those as well ... */
00061     maxstat = Calloc(ninputs, double);
00062     cutpoint = Calloc(ninputs, double);
00063     order = Calloc(ninputs, int);
00064     /* <FIXME> */
00065     
00066     /* this is essentially an exhaustive search */
00067     /* <FIXME>: we don't want to do this for random forest like trees 
00068        </FIXME>
00069      */
00070     for (j = 0; j < ninputs; j++) {
00071     
00072          order[j] = j + 1;
00073          maxstat[j] = 0.0;
00074          cutpoint[j] = 0.0;
00075 
00076          /* ordered input variables only (for the moment) */
00077          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00078              continue;
00079 
00080          x = get_variable(inputs, j + 1);
00081 
00082          if (has_missings(inputs, j + 1)) {
00083 
00084              thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00085                  
00086              C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00087              
00088              C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00089                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00090                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00091                      expcovinf, &cp, &ms, splitstat);
00092          } else {
00093          
00094              C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00095              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00096              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00097              expcovinf, &cp, &ms, splitstat);
00098          }
00099 
00100          maxstat[j] = -ms;
00101          cutpoint[j] = cp;
00102     }
00103 
00104     /* order with respect to maximal statistic */
00105     rsort_with_index(maxstat, order, ninputs);
00106     
00107     twotab = Calloc(4, double);
00108     
00109     /* the best `maxsurr' ones are implemented */
00110     for (j = 0; j < maxsurr; j++) {
00111 
00112         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00113         cut = cutpoint[order[j] - 1];
00114         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00115                        split = allocVector(VECSXP, SPLIT_LENGTH));
00116         C_init_orderedsplit(split, 0);
00117         S3set_variableID(split, order[j]);
00118         REAL(S3get_splitpoint(split))[0] = cut;
00119         dx = REAL(get_variable(inputs, order[j]));
00120         dy = REAL(y);
00121 
00122         /* OK, this is a dirty hack: determine if the split 
00123            goes left or right by the Pearson residual of a 2x2 table.
00124            I don't want to use the big caliber here 
00125         */
00126         for (i = 0; i < nobs; i++) {
00127             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00128             twotab[1] += (dy[i] == 1) * tweights[i];
00129             twotab[2] += (dx[i] <= cut) * tweights[i];
00130             twotab[3] += tweights[i];
00131         }
00132         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00133                      twotab[3]) > 0);
00134     }
00135     
00136     Free(maxstat);
00137     Free(cutpoint);
00138     Free(order);
00139     Free(tweights);
00140     Free(twotab);
00141 }
00142 
00153 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00154                   SEXP fitmem) {
00155 
00156     C_surrogates(node, learnsample, weights, controls, fitmem);
00157     return(S3get_surrogatesplits(node));
00158     
00159 }
00160 
00168 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00169 
00170     SEXP weights, split, surrsplit;
00171     SEXP inputs, whichNA;
00172     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00173     int *iwhichNA, k;
00174     int nobs, i, nna, ns;
00175                     
00176     weights = S3get_nodeweights(node);
00177     dweights = REAL(weights);
00178     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00179     nobs = get_nobs(learnsample);
00180             
00181     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00182     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00183     surrsplit = S3get_surrogatesplits(node);
00184 
00185     /* if the primary split has any missings */
00186     split = S3get_primarysplit(node);
00187     if (has_missings(inputs, S3get_variableID(split))) {
00188 
00189         /* where are the missings? */
00190         whichNA = get_missings(inputs, S3get_variableID(split));
00191         iwhichNA = INTEGER(whichNA);
00192         nna = LENGTH(whichNA);
00193 
00194         /* for all missing values ... */
00195         for (k = 0; k < nna; k++) {
00196             ns = 0;
00197             i = iwhichNA[k] - 1;
00198             if (dweights[i] == 0) continue;
00199             
00200             /* loop over surrogate splits until an appropriate one is found */
00201             while(TRUE) {
00202             
00203                 if (ns >= LENGTH(surrsplit)) break;
00204             
00205                 split = VECTOR_ELT(surrsplit, ns);
00206                 if (has_missings(inputs, S3get_variableID(split))) {
00207                     if (INTEGER(get_missings(inputs, 
00208                             S3get_variableID(split)))[i]) {
00209                         ns++;
00210                         continue;
00211                     }
00212                 }
00213 
00214                 cutpoint = REAL(S3get_splitpoint(split))[0];
00215                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00216 
00217                 if (S3get_toleft(split)) {
00218                     if (dx[i] <= cutpoint) {
00219                         leftweights[i] = dweights[i];
00220                         rightweights[i] = 0.0;
00221                     } else {
00222                         rightweights[i] = dweights[i];
00223                         leftweights[i] = 0.0;
00224                     }
00225                 } else {
00226                     if (dx[i] <= cutpoint) {
00227                         rightweights[i] = dweights[i];
00228                         leftweights[i] = 0.0;
00229                     } else {
00230                         leftweights[i] = dweights[i];
00231                         rightweights[i] = 0.0;
00232                     }
00233                 }
00234                 break;
00235             }
00236         }
00237     }
00238 }

Generated on Thu Sep 27 15:50:56 2007 for party by  doxygen 1.4.6