00001
00009 #include "party.h"
00010
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00022 SEXP fitmem) {
00023
00024 SEXP x, y, expcovinf;
00025 SEXP splitctrl, inputs;
00026 SEXP split, thiswhichNA;
00027 int nobs, ninputs, i, j, k, jselect, maxsurr, *order;
00028 double ms, cp, *thisweights, *cutpoint, *maxstat,
00029 *splitstat, *dweights, *tweights, *dx, *dy;
00030 double cut, *twotab;
00031
00032 nobs = get_nobs(learnsample);
00033 ninputs = get_ninputs(learnsample);
00034 splitctrl = get_splitctrl(controls);
00035 maxsurr = get_maxsurrogate(splitctrl);
00036
00037 if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00038 error("nodes does not have %d surrogate splits", maxsurr);
00039 if ((ninputs - 1 - maxsurr) < 1)
00040 error("cannot set up %d surrogate splits with only %d input variable(s)",
00041 maxsurr, ninputs);
00042
00043 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00044 jselect = S3get_variableID(S3get_primarysplit(node));
00045 y = S3get_nodeweights(VECTOR_ELT(node, 7));
00046
00047 tweights = Calloc(nobs, double);
00048 dweights = REAL(weights);
00049 for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00050 if (has_missings(inputs, jselect)) {
00051 thiswhichNA = get_missings(inputs, jselect);
00052 for (k = 0; k < LENGTH(thiswhichNA); k++)
00053 tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00054 }
00055
00056 expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00057 C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00058
00059 splitstat = REAL(get_splitstatistics(fitmem));
00060
00061 maxstat = Calloc(ninputs, double);
00062 cutpoint = Calloc(ninputs, double);
00063 order = Calloc(ninputs, int);
00064
00065
00066
00067
00068
00069
00070 for (j = 0; j < ninputs; j++) {
00071
00072 order[j] = j + 1;
00073 maxstat[j] = 0.0;
00074 cutpoint[j] = 0.0;
00075
00076
00077 if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00078 continue;
00079
00080 x = get_variable(inputs, j + 1);
00081
00082 if (has_missings(inputs, j + 1)) {
00083
00084 thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00085
00086 C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00087
00088 C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00089 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00090 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00091 expcovinf, &cp, &ms, splitstat);
00092 } else {
00093
00094 C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00095 INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00096 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00097 expcovinf, &cp, &ms, splitstat);
00098 }
00099
00100 maxstat[j] = -ms;
00101 cutpoint[j] = cp;
00102 }
00103
00104
00105 rsort_with_index(maxstat, order, ninputs);
00106
00107 twotab = Calloc(4, double);
00108
00109
00110 for (j = 0; j < maxsurr; j++) {
00111
00112 for (i = 0; i < 4; i++) twotab[i] = 0.0;
00113 cut = cutpoint[order[j] - 1];
00114 SET_VECTOR_ELT(S3get_surrogatesplits(node), j,
00115 split = allocVector(VECSXP, SPLIT_LENGTH));
00116 C_init_orderedsplit(split, 0);
00117 S3set_variableID(split, order[j]);
00118 REAL(S3get_splitpoint(split))[0] = cut;
00119 dx = REAL(get_variable(inputs, order[j]));
00120 dy = REAL(y);
00121
00122
00123
00124
00125
00126 for (i = 0; i < nobs; i++) {
00127 twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00128 twotab[1] += (dy[i] == 1) * tweights[i];
00129 twotab[2] += (dx[i] <= cut) * tweights[i];
00130 twotab[3] += tweights[i];
00131 }
00132 S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] /
00133 twotab[3]) > 0);
00134 }
00135
00136 Free(maxstat);
00137 Free(cutpoint);
00138 Free(order);
00139 Free(tweights);
00140 Free(twotab);
00141 }
00142
00153 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls,
00154 SEXP fitmem) {
00155
00156 C_surrogates(node, learnsample, weights, controls, fitmem);
00157 return(S3get_surrogatesplits(node));
00158
00159 }
00160
00168 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00169
00170 SEXP weights, split, surrsplit;
00171 SEXP inputs, whichNA;
00172 double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00173 int *iwhichNA, k;
00174 int nobs, i, nna, ns;
00175
00176 weights = S3get_nodeweights(node);
00177 dweights = REAL(weights);
00178 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00179 nobs = get_nobs(learnsample);
00180
00181 leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00182 rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00183 surrsplit = S3get_surrogatesplits(node);
00184
00185
00186 split = S3get_primarysplit(node);
00187 if (has_missings(inputs, S3get_variableID(split))) {
00188
00189
00190 whichNA = get_missings(inputs, S3get_variableID(split));
00191 iwhichNA = INTEGER(whichNA);
00192 nna = LENGTH(whichNA);
00193
00194
00195 for (k = 0; k < nna; k++) {
00196 ns = 0;
00197 i = iwhichNA[k] - 1;
00198 if (dweights[i] == 0) continue;
00199
00200
00201 while(TRUE) {
00202
00203 if (ns >= LENGTH(surrsplit)) break;
00204
00205 split = VECTOR_ELT(surrsplit, ns);
00206 if (has_missings(inputs, S3get_variableID(split))) {
00207 if (INTEGER(get_missings(inputs,
00208 S3get_variableID(split)))[i]) {
00209 ns++;
00210 continue;
00211 }
00212 }
00213
00214 cutpoint = REAL(S3get_splitpoint(split))[0];
00215 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00216
00217 if (S3get_toleft(split)) {
00218 if (dx[i] <= cutpoint) {
00219 leftweights[i] = dweights[i];
00220 rightweights[i] = 0.0;
00221 } else {
00222 rightweights[i] = dweights[i];
00223 leftweights[i] = 0.0;
00224 }
00225 } else {
00226 if (dx[i] <= cutpoint) {
00227 rightweights[i] = dweights[i];
00228 leftweights[i] = 0.0;
00229 } else {
00230 leftweights[i] = dweights[i];
00231 rightweights[i] = 0.0;
00232 }
00233 }
00234 break;
00235 }
00236 }
00237 }
00238 }