00001
00009 #include "party.h"
00010
00011
/**
    Weighted column means of a response matrix. \n
    For every column j of y this stores
    sum_i weights[i] * y[i, j] / sweights in ans[j].
    *\param y matrix with n rows and q columns, stored column by column
    *\param n number of rows of y (and number of entries read from weights)
    *\param q number of columns of y
    *\param weights one weight per row of y
    *\param sweights divisor applied to every weighted column sum
    *\param ans return value; the first q entries are overwritten
*/

void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    const double *column = y;
    int i, j;

    for (j = 0; j < q; j++) {
        double total = 0.0;

        for (i = 0; i < n; i++)
            total += weights[i] * column[i];
        ans[j] = total / sweights;

        /* advance by pointer instead of indexing with j * n: that product
           is computed in int and overflows once the matrix holds more
           than INT_MAX elements */
        column += n;
    }
}
00035
00036
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights,
00049 SEXP fitmem, SEXP controls, int TERMINAL) {
00050
00051 int nobs, ninputs, jselect, yORDERED, q, j, k, i;
00052 double mincriterion, sweights, *dprediction;
00053 double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054 double *standstat, *splitstat;
00055 SEXP responses, inputs, y, x, expcovinf, thisweights, linexpcov;
00056 SEXP varctrl, splitctrl, gtctrl, tgctrl, split, joint;
00057 double *dxtransf, *dweights;
00058 int *itable;
00059
00060 nobs = get_nobs(learnsample);
00061 ninputs = get_ninputs(learnsample);
00062 varctrl = get_varctrl(controls);
00063 splitctrl = get_splitctrl(controls);
00064 gtctrl = get_gtctrl(controls);
00065 tgctrl = get_tgctrl(controls);
00066 mincriterion = get_mincriterion(gtctrl);
00067 responses = GET_SLOT(learnsample, PL2_responsesSym);
00068 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069 yORDERED = is_ordinal(responses, 1);
00070 y = get_transformation(responses, 1);
00071 q = ncol(y);
00072 joint = GET_SLOT(responses, PL2_jointtransfSym);
00073
00074
00075
00076
00077 C_GlobalTest(learnsample, weights, fitmem, varctrl,
00078 gtctrl, get_minsplit(splitctrl),
00079 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00080
00081
00082 sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym),
00083 PL2_sumweightsSym))[0];
00084
00085
00086 dprediction = REAL(S3get_prediction(node));
00087
00088
00089
00090 C_prediction(REAL(joint), nobs, ncol(joint), REAL(weights),
00091 sweights, dprediction);
00092
00093
00094 teststat = REAL(S3get_teststat(node));
00095 pvalue = REAL(S3get_criterion(node));
00096
00097
00098
00099
00100 for (j = 0; j < 2; j++) {
00101
00102 smax = C_max(pvalue, ninputs);
00103 REAL(S3get_maxcriterion(node))[0] = smax;
00104
00105
00106 if (smax > mincriterion && !TERMINAL) {
00107
00108
00109 jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00110
00111
00112 x = get_variable(inputs, jselect);
00113 if (has_missings(inputs, jselect)) {
00114 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect),
00115 PL2_expcovinfSym);
00116 thisweights = get_weights(fitmem, jselect);
00117 } else {
00118 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00119 thisweights = weights;
00120 }
00121
00122
00123 if (!is_nominal(inputs, jselect)) {
00124
00125
00126 split = S3get_primarysplit(node);
00127
00128
00129
00130 if (get_savesplitstats(tgctrl)) {
00131 C_init_orderedsplit(split, nobs);
00132 splitstat = REAL(S3get_splitstatistics(split));
00133 } else {
00134 C_init_orderedsplit(split, 0);
00135 splitstat = REAL(get_splitstatistics(fitmem));
00136 }
00137
00138 C_split(REAL(x), 1, REAL(y), q, REAL(weights), nobs,
00139 INTEGER(get_ordering(inputs, jselect)),
00140 REAL(VECTOR_ELT(GET_SLOT(responses, PL2_scoresSym), 0)),
00141 yORDERED, splitctrl,
00142 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00143 expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00144 splitstat);
00145 S3set_variableID(split, jselect);
00146 } else {
00147
00148
00149 split = S3get_primarysplit(node);
00150
00151
00152
00153 if (get_savesplitstats(tgctrl)) {
00154 C_init_nominalsplit(split,
00155 LENGTH(get_levels(inputs, jselect)),
00156 nobs);
00157 splitstat = REAL(S3get_splitstatistics(split));
00158 } else {
00159 C_init_nominalsplit(split,
00160 LENGTH(get_levels(inputs, jselect)),
00161 0);
00162 splitstat = REAL(get_splitstatistics(fitmem));
00163 }
00164
00165 linexpcov = get_varmemory(fitmem, jselect);
00166 standstat = Calloc(get_dimension(linexpcov), double);
00167 C_standardize(REAL(GET_SLOT(linexpcov,
00168 PL2_linearstatisticSym)),
00169 REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00170 REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00171 get_dimension(linexpcov), get_tol(splitctrl),
00172 standstat);
00173
00174 C_splitcategorical(INTEGER(x),
00175 LENGTH(get_levels(inputs, jselect)),
00176 REAL(y), q, REAL(weights),
00177 nobs, REAL(VECTOR_ELT(GET_SLOT(responses,
00178 PL2_scoresSym), 0)),
00179 yORDERED, standstat, splitctrl,
00180 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00181 expcovinf, &cutpoint,
00182 INTEGER(S3get_splitpoint(split)),
00183 &maxstat, splitstat);
00184
00185
00186
00187
00188
00189 itable = INTEGER(S3get_table(split));
00190 dxtransf = REAL(get_transformation(inputs, jselect));
00191 dweights = REAL(thisweights);
00192 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00193 itable[k] = 0;
00194 for (i = 0; i < nobs; i++) {
00195 if (dxtransf[k * nobs + i] * dweights[i] > 0) {
00196 itable[k] = 1;
00197 continue;
00198 }
00199 }
00200 }
00201
00202 Free(standstat);
00203 }
00204 if (maxstat == 0) {
00205 warning("no admissible split found\n");
00206
00207 if (j == 1) {
00208 S3set_nodeterminal(node);
00209 } else {
00210
00211 pvalue[jselect - 1] = 0.0;
00212 }
00213 } else {
00214 S3set_variableID(split, jselect);
00215 break;
00216 }
00217 } else {
00218 S3set_nodeterminal(node);
00219 break;
00220 }
00221 }
00222 }
00223
00224
00233 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00234
00235 SEXP ans;
00236
00237 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00238 C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample),
00239 get_maxsurrogate(get_splitctrl(controls)),
00240 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00241 PL2_jointtransfSym)));
00242
00243 C_Node(ans, learnsample, weights, fitmem, controls, 0);
00244 UNPROTECT(1);
00245 return(ans);
00246 }