00001
00009 #include "party.h"
00010
00011
/**
    Weighted column means of the response matrix.\n
    For every column j of y the weighted sum over all observations is
    accumulated and divided by the sum of weights:
    ans[j] = sum_i weights[i] * y[j * n + i] / sweights.
    *\param y response values, read column by column (n values per column)
    *\param n number of observations (length of each column and of weights)
    *\param q number of columns of y
    *\param weights case weights, one per observation
    *\param sweights divisor applied to each weighted sum (the caller
                     passes the sum of the weights)
    *\param ans output vector of length q
*/

void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    const double *column = y;
    int col, obs;

    /* walk through y one column (n consecutive values) at a time */
    for (col = 0; col < q; col++, column += n) {
        double total = 0.0;

        for (obs = 0; obs < n; obs++) {
            total += weights[obs] * column[obs];
        }
        ans[col] = total / sweights;
    }
}
00035
00036
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights,
00049 SEXP fitmem, SEXP controls, int TERMINAL) {
00050
00051 int nobs, ninputs, jselect, q, j, k, i;
00052 double mincriterion, sweights, *dprediction;
00053 double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054 double *standstat, *splitstat;
00055 SEXP responses, inputs, x, expcovinf, thisweights, linexpcov;
00056 SEXP varctrl, splitctrl, gtctrl, tgctrl, split, jointy;
00057 double *dxtransf, *dweights;
00058 int *itable;
00059
00060 nobs = get_nobs(learnsample);
00061 ninputs = get_ninputs(learnsample);
00062 varctrl = get_varctrl(controls);
00063 splitctrl = get_splitctrl(controls);
00064 gtctrl = get_gtctrl(controls);
00065 tgctrl = get_tgctrl(controls);
00066 mincriterion = get_mincriterion(gtctrl);
00067 responses = GET_SLOT(learnsample, PL2_responsesSym);
00068 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069 jointy = get_jointtransf(responses);
00070 q = ncol(jointy);
00071
00072
00073
00074
00075 C_GlobalTest(learnsample, weights, fitmem, varctrl,
00076 gtctrl, get_minsplit(splitctrl),
00077 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00078
00079
00080 sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym),
00081 PL2_sumweightsSym))[0];
00082
00083
00084 dprediction = REAL(S3get_prediction(node));
00085
00086
00087
00088 C_prediction(REAL(jointy), nobs, q, REAL(weights),
00089 sweights, dprediction);
00090
00091
00092 teststat = REAL(S3get_teststat(node));
00093 pvalue = REAL(S3get_criterion(node));
00094
00095
00096
00097
00098 for (j = 0; j < 2; j++) {
00099
00100 smax = C_max(pvalue, ninputs);
00101 REAL(S3get_maxcriterion(node))[0] = smax;
00102
00103
00104 if (smax > mincriterion && !TERMINAL) {
00105
00106
00107 jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00108
00109
00110 x = get_variable(inputs, jselect);
00111 if (has_missings(inputs, jselect)) {
00112 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect),
00113 PL2_expcovinfSym);
00114 thisweights = get_weights(fitmem, jselect);
00115 } else {
00116 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00117 thisweights = weights;
00118 }
00119
00120
00121 if (!is_nominal(inputs, jselect)) {
00122
00123
00124 split = S3get_primarysplit(node);
00125
00126
00127
00128 if (get_savesplitstats(tgctrl)) {
00129 C_init_orderedsplit(split, nobs);
00130 splitstat = REAL(S3get_splitstatistics(split));
00131 } else {
00132 C_init_orderedsplit(split, 0);
00133 splitstat = REAL(get_splitstatistics(fitmem));
00134 }
00135
00136 C_split(REAL(x), 1, REAL(jointy), q, REAL(weights), nobs,
00137 INTEGER(get_ordering(inputs, jselect)), splitctrl,
00138 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00139 expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00140 splitstat);
00141 S3set_variableID(split, jselect);
00142 } else {
00143
00144
00145 split = S3get_primarysplit(node);
00146
00147
00148
00149 if (get_savesplitstats(tgctrl)) {
00150 C_init_nominalsplit(split,
00151 LENGTH(get_levels(inputs, jselect)),
00152 nobs);
00153 splitstat = REAL(S3get_splitstatistics(split));
00154 } else {
00155 C_init_nominalsplit(split,
00156 LENGTH(get_levels(inputs, jselect)),
00157 0);
00158 splitstat = REAL(get_splitstatistics(fitmem));
00159 }
00160
00161 linexpcov = get_varmemory(fitmem, jselect);
00162 standstat = Calloc(get_dimension(linexpcov), double);
00163 C_standardize(REAL(GET_SLOT(linexpcov,
00164 PL2_linearstatisticSym)),
00165 REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00166 REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00167 get_dimension(linexpcov), get_tol(splitctrl),
00168 standstat);
00169
00170 C_splitcategorical(INTEGER(x),
00171 LENGTH(get_levels(inputs, jselect)),
00172 REAL(jointy), q, REAL(weights),
00173 nobs, standstat, splitctrl,
00174 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00175 expcovinf, &cutpoint,
00176 INTEGER(S3get_splitpoint(split)),
00177 &maxstat, splitstat);
00178
00179
00180
00181
00182
00183 itable = INTEGER(S3get_table(split));
00184 dxtransf = REAL(get_transformation(inputs, jselect));
00185 dweights = REAL(thisweights);
00186 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00187 itable[k] = 0;
00188 for (i = 0; i < nobs; i++) {
00189 if (dxtransf[k * nobs + i] * dweights[i] > 0) {
00190 itable[k] = 1;
00191 continue;
00192 }
00193 }
00194 }
00195
00196 Free(standstat);
00197 }
00198 if (maxstat == 0) {
00199 warning("no admissible split found\n");
00200
00201 if (j == 1) {
00202 S3set_nodeterminal(node);
00203 } else {
00204
00205 pvalue[jselect - 1] = 0.0;
00206 }
00207 } else {
00208 S3set_variableID(split, jselect);
00209 break;
00210 }
00211 } else {
00212 S3set_nodeterminal(node);
00213 break;
00214 }
00215 }
00216 }
00217
00218
00227 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00228
00229 SEXP ans;
00230
00231 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00232 C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample),
00233 get_maxsurrogate(get_splitctrl(controls)),
00234 ncol(get_jointtransf(GET_SLOT(learnsample, PL2_responsesSym))));
00235
00236 C_Node(ans, learnsample, weights, fitmem, controls, 0);
00237 UNPROTECT(1);
00238 return(ans);
00239 }