TreeGrow.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
00023 void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem, 
00024                 SEXP controls, int *where, int *nodenum, int depth) {
00025 
00026     SEXP weights;
00027     int nobs, i, stop;
00028     double *dweights;
00029     
00030     weights = S3get_nodeweights(node);
00031     
00032     /* stop if either stumps have been requested or 
00033        the maximum depth is exceeded */
00034     stop = (nodenum[0] == 2 || nodenum[0] == 3) && 
00035            get_stump(get_tgctrl(controls));
00036     stop = stop || !check_depth(get_tgctrl(controls), depth);
00037     
00038     if (stop)
00039         C_Node(node, learnsample, weights, fitmem, controls, 1);
00040     else
00041         C_Node(node, learnsample, weights, fitmem, controls, 0);
00042     
00043     S3set_nodeID(node, nodenum[0]);    
00044     
00045     if (!S3get_nodeterminal(node)) {
00046 
00047         C_splitnode(node, learnsample, controls);
00048 
00049         /* determine surrogate splits and split missing values */
00050         if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
00051             C_surrogates(node, learnsample, weights, controls, fitmem);
00052             C_splitsurrogate(node, learnsample);
00053         }
00054             
00055         nodenum[0] += 1;
00056         C_TreeGrow(S3get_leftnode(node), learnsample, fitmem, 
00057                    controls, where, nodenum, depth + 1);
00058 
00059         nodenum[0] += 1;                                      
00060         C_TreeGrow(S3get_rightnode(node), learnsample, fitmem, 
00061                    controls, where, nodenum, depth + 1);
00062     } else {
00063         dweights = REAL(weights);
00064         nobs = get_nobs(learnsample);
00065         for (i = 0; i < nobs; i++)
00066             if (dweights[i] > 0) where[i] = nodenum[0];
00067     } 
00068 }
00069 
00070 
00080 SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {
00081             
00082      SEXP ans, nweights;
00083      double *dnweights, *dweights;
00084      int nobs, i, nodenum = 1;
00085 
00086      GetRNGstate();
00087      
00088      nobs = get_nobs(learnsample);
00089      PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00090      C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
00091                  ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00092 
00093      nweights = S3get_nodeweights(ans);
00094      dnweights = REAL(nweights);
00095      dweights = REAL(weights);
00096      for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];
00097      
00098      C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1);
00099      
00100      PutRNGstate();
00101      
00102      UNPROTECT(1);
00103      return(ans);
00104 }

Generated on Mon Jan 22 17:37:53 2007 for party by  doxygen 1.4.6