#include "mex.h"
#include <limits.h>
#include <stdio.h>
#include <math.h>
#include <memory.h>
#include "mexhead.h"
#include "Doc/mixNprob2d_docstring.h"
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
/* Argument-count limits enforced by mexFunction. */
#define NARGIN_MIN  3
#define NARGIN_MAX  4
#define NARGOUT_MIN 1
#define NARGOUT_MAX 1

/* Slots of the inputs within prhs[]. */
#define ARG_i1      0
#define ARG_i2      1
#define ARG_model   2
#define ARG_logmode 3

/* Slot of the single output within plhs[]. */
#define ARG_p 0

/* Program name: as a C string for messages, and as a bare token
   (PROGNAME is not referenced in the code shown here; presumably used
   by the optional tail hook header -- confirm). */
static const char *progname = "mixNprob2d";
#define PROGNAME mixNprob2d

/* Per-argument spec strings and names, handed to mexargparse() and
   mxt_PackSignature().  The spec grammar belongs to mexhead
   (presumably "RM" = real matrix, "IS(1)" = integer scalar -- confirm). */
static const char *in_specs[NARGIN_MAX]   = { "RM", "RM", "RM", "IS(1)" };
static const char *in_names[NARGIN_MAX]   = { "i1", "i2", "model", "logmode" };
static const char *out_names[NARGOUT_MAX] = { "p" };
00062
00063
/* Short module token; not referenced in the code shown here. */
#define SHORTNAME m2d

/* Row indices into one column (one mixture component) of the 6-row model
   matrix.  control() uses them as: weight (log(lambda)), means, variances
   (1/sqrt), and covariance (determinant). */
#define inx_lambda 0   /* component weight             */
#define inx_mu1    1   /* mean of the first coordinate  */
#define inx_mu2    2   /* mean of the second coordinate */
#define inx_sig11  3   /* variance of the first coord.  */
#define inx_sig22  4   /* variance of the second coord. */
#define inx_sig12  5   /* covariance of the two coords. */

/* True unless the sample is NaN (NaN marks a missing sample). */
#define ActiveData(x) (isnan(x) == 0)

/* x squared.  x is expanded twice, so pass only side-effect-free operands. */
#define sqr(x) ((x) * (x))
00079
00080 static
00081 void
00082 control(double **p,
00083 double **i1,
00084 double **i2,
00085 double **model,
00086 int log_mode,
00087 int K,
00088 int M,
00089 int N)
00090
00091 {
00092 void * mxCalloc();
00093 int k, m, n;
00094 double det, rho;
00095 double *normF, *sig1F, *sig2F, *crosF, *halfF;
00096 double *lprob;
00097 double lprob_max;
00098 const double lprob_max_init = -1E64;
00099 double NaN;
00100 double j1, j2;
00101 double sum;
00102 const double log2pi = log(2 * M_PI);
00103
00104 lprob = mxCalloc(K, sizeof(double));
00105
00106
00107
00108
00109 normF = mxCalloc(K, sizeof(double));
00110 sig1F = mxCalloc(K, sizeof(double));
00111 sig2F = mxCalloc(K, sizeof(double));
00112 crosF = mxCalloc(K, sizeof(double));
00113 halfF = mxCalloc(K, sizeof(double));
00114
00115 for (k = 0; k < K; k++) {
00116
00117
00118 det = model[k][inx_sig11] * model[k][inx_sig22] -
00119 sqr(model[k][inx_sig12]);
00120
00121 normF[k] = log(model[k][inx_lambda]) - log2pi - log(det)/2;
00122
00123 sig1F[k] = 1/sqrt(model[k][inx_sig11]);
00124 sig2F[k] = 1/sqrt(model[k][inx_sig22]);
00125
00126 rho = model[k][inx_sig12] * sig1F[k] * sig2F[k];
00127
00128 crosF[k] = -2*rho;
00129
00130 halfF[k] = -1/(2 * (1 - sqr(rho)));
00131 }
00132 NaN = (rho - rho)/(rho - rho);
00133 for (n = 0; n < N; n++)
00134 for (m = 0; m < M; m++) {
00135
00136 if (!ActiveData(i1[n][m]) || !ActiveData(i2[n][m])) {
00137 p[n][m] = NaN;
00138 continue;
00139 }
00140
00141 for (lprob_max = lprob_max_init, k = 0; k < K; k++) {
00142
00143 j1 = sig1F[k] * (i1[n][m] - model[k][inx_mu1]);
00144 j2 = sig2F[k] * (i2[n][m] - model[k][inx_mu2]);
00145
00146 lprob[k] =
00147 normF[k] + halfF[k] * (sqr(j1) + sqr(j2) + crosF[k] * j1 * j2);
00148 if (lprob[k] > lprob_max) lprob_max = lprob[k];
00149 }
00150
00151 for (sum = k = 0; k < K; k++)
00152 sum += exp(lprob[k] - lprob_max);
00153
00154 if (log_mode)
00155 p[n][m] = log(sum) + lprob_max;
00156 else
00157 p[n][m] = sum * exp(lprob_max);
00158 }
00159
00160 mxFree((char *) lprob);
00161 mxFree((char *) normF);
00162 mxFree((char *) sig1F);
00163 mxFree((char *) sig2F);
00164 mxFree((char *) crosF);
00165 mxFree((char *) halfF);
00166 }
00167
00168
00169
00170
00171
00172
00173
00174 #ifdef StaticP
00175 StaticP
00176 #endif
00177 void
00178 mexFunction(
00179 int nlhs,
00180 mxArray *plhs[],
00181 int nrhs,
00182 const mxArray *prhs[])
00183 {
00184 int log_mode;
00185 int m, n;
00186 int k;
00187 double **p_2d, **i1_2d, **i2_2d, **m_2d;
00188 char errstr[120];
00189
00190
00191 if (nrhs < 0) {
00192 plhs[0] = mxt_PackSignature((mxt_Signature) (-nrhs),
00193 NARGIN_MIN, NARGIN_MAX,
00194 NARGOUT_MIN, NARGOUT_MAX,
00195 in_names, in_specs, out_names, docstring);
00196 return;
00197 }
00198
00199
00200
00201 if ((nrhs < NARGIN_MIN) || (nrhs > NARGIN_MAX))
00202 mexErrMsgTxt((snprintf(errstr, sizeof(errstr),
00203 "Expect %d <= input args <= %d",
00204 NARGIN_MIN, NARGIN_MAX), errstr));
00205 if ((nlhs < NARGOUT_MIN) || (nlhs > NARGOUT_MAX))
00206 mexErrMsgTxt((snprintf(errstr, sizeof(errstr),
00207 "%s: Expect %d <= output args <= %d",
00208 progname, NARGOUT_MIN, NARGOUT_MAX), errstr));
00209 mexargparse(nrhs, prhs, in_names, in_specs, NULL, progname);
00210 start_sizechecking();
00211 sizeinit(prhs[ARG_i1]);
00212 sizeagree(prhs[ARG_i2]);
00213 sizecheck_msg(progname, in_names, ARG_i1);
00214 sizeinit(prhs[ARG_model]);
00215 sizeisM(6);
00216 sizecheck_msg(progname, in_names, ARG_model);
00217
00218
00219
00220
00221
00222 m = (int) mxGetM(prhs[ARG_i1]);
00223 n = (int) mxGetN(prhs[ARG_i1]);
00224 k = (int) mxGetN(prhs[ARG_model]);
00225
00226
00227 if (k <= 0)
00228 mexErrMsgTxt((snprintf(errstr, sizeof(errstr),
00229 "%s: Must have a non-empty model",
00230 progname), errstr));
00231
00232
00233
00234
00235
00236
00237 plhs[ARG_p] = mxCreateDoubleMatrix(m, n, mxREAL);
00238
00239
00240 if (nrhs > ARG_logmode)
00241 log_mode = (int) mxGetScalar(prhs[ARG_logmode]) > 0;
00242 else
00243 log_mode = 1;
00244
00245
00246
00247
00248
00249 p_2d = mxt_make_matrix2(plhs[ARG_p], -1, -1, 0.0);
00250 i1_2d = mxt_make_matrix2(prhs[ARG_i1], -1, -1, 0.0);
00251 i2_2d = mxt_make_matrix2(prhs[ARG_i2], -1, -1, 0.0);
00252 m_2d = mxt_make_matrix2(prhs[ARG_model], -1, -1, 0.0);
00253
00254 control(p_2d, i1_2d, i2_2d, m_2d, log_mode, k, m, n);
00255
00256 mxFree(p_2d);
00257 mxFree(i1_2d);
00258 mxFree(i2_2d);
00259 mxFree(m_2d);
00260 }
00261
00262
00263
00264 #ifdef MEX2C_TAIL_HOOK
00265 #include "mex2c_tail.h"
00266 #endif
00267