Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
version 1
  • Loading branch information
yuz12012 committed Apr 29, 2017
1 parent 7655ee0 commit 6cd1212
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions parallel/para_gibbs.cu
Expand Up @@ -118,6 +118,11 @@ int main(int argc, char **argv){

load_data(argc, argv, &K, &y, &n);

if(K<nBlocks*nThreads){
printf("Too many threads for too little data!!!");
return 1;
}

/* starting values of hyperparameters */
a = 20;
b = 1;
Expand Down Expand Up @@ -215,7 +220,7 @@ __global__ void mergePosterior(int trials, float *dev_a_out,float *dev_b_out,int
h[i] = powf(i,(-1/(4+2)));
for (int m=0; m < M; m++) {
int *cDot = (int*) malloc(M*sizeof(int));
printf("%d\n",m);
//printf("%d\n",m);

memcpy(cDot, tDot, sizeof(int) * M);
cDot[m] = (curand(&state[id]) % (trials-1)) + 1;
Expand All @@ -231,24 +236,31 @@ __global__ void mergePosterior(int trials, float *dev_a_out,float *dev_b_out,int
float posterior_mean_a = posteriorMean(tDot, dev_a_out, M, trials);
float posterior_mean_b = posteriorMean(tDot, dev_b_out, M, trials);
float variance = (h[i]*h[i])/M;
printf("%d\n",trials);
// printf("%d\n",trials);
printf("%f, %f\n" , general_normoral(state, posterior_mean_a, variance), general_normoral(state, posterior_mean_b, variance));
}
}


__device__ float general_normoral(curandState *state, float mean, float variance) {
int id = threadIdx.x + blockIdx.x * blockDim.x;
return curand_normal(&state[id]) * variance + mean;
float n = curand_normal(&state[id]) * variance + mean;
// printf("normal is:%f\n",n);
return n;
}

__device__ float posteriorMean(int *tDot, float *dev_x_out, int M, int nThreads) {
float sum = 0;
for (int i=0; i < M; i++) {
int index = tDot[i] + i * nThreads; // trial m of posterior m (note: i = blockId)
sum += dev_x_out[index]; // posterior_m_tm
// printf("data is %f\n",dev_x_out[index]);

// printf("sum is %f\n",sum);

}
float mean = sum/M;
// printf("mean is %f\n",mean);
return mean;
}

Expand All @@ -269,7 +281,9 @@ __device__ float computeW(int *tDot, float *dev_x_out, float *dev_y_out, int M,
}

__device__ float normPDF(float x, float mean, float variance) {

float denominator = sqrtf(2*PI*(variance*variance));
//printf("the denominator is %f\n",PI);
float numerator = expf( -1 * (x-mean)*(x-mean) / (2*variance*variance) );
return numerator/denominator;
}
Expand Down Expand Up @@ -349,6 +363,8 @@ __global__ void seqMetroProcess(int K, int nBlocks, int *y, float *n, curandStat
float sum_logs=0;
for(int ii=0;ii<lengthPerBlock;ii++){
//*ptr refers to the value at address
// printf("sLogTheta:%f \n",*sLogTheta);

sum_logs = sum_logs + *sLogTheta;
sLogTheta++;
}
Expand All @@ -360,7 +376,7 @@ __global__ void seqMetroProcess(int K, int nBlocks, int *y, float *n, curandStat
b = sample_b(state,id, a, lengthPerBlock, flat_sum);

/* print hyperparameters. */
//printf("%f, %f\n", a, b);
// printf("%f, %f\n", a, b);
////////////save new output to the global array
*subA = a;
subA++;
Expand Down Expand Up @@ -415,6 +431,7 @@ __host__ void load_data(int argc, char **argv, int *K, int **y, float **n){
*/

__device__ float sample_a(curandState *state,int id, float a, float b, int K, float sum_logs){
//id = threadIdx.x + blockIdx.x * blockDim.x;

float sigma = 2.0;
float norm = curand_normal(&state[id]);
Expand All @@ -423,11 +440,13 @@ __device__ float sample_a(curandState *state,int id, float a, float b, int K, fl
if(proposal <= 0)
return a;

// printf("sum_logs:%f \n",sum_logs);

log_acceptance_ratio = (proposal - a) * sum_logs +
K * (proposal - a) * log(b) -
K * (lgamma(proposal) - lgamma(a));

U = curand_normal(&state[id]);
U = curand_uniform(&state[id]);
//printf("log_acceptance result:%f \n",log_acceptance_ratio);
//printf("log U result:%f \n",log(U));

Expand All @@ -446,6 +465,7 @@ __device__ float sample_a(curandState *state,int id, float a, float b, int K, fl
*/

__device__ float sample_b(curandState *state,int id, float a, int K, float flat_sum){
//id = threadIdx.x + blockIdx.x * blockDim.x;

float hyperA = K * a + 1;
float hyperB = flat_sum;
Expand Down

0 comments on commit 6cd1212

Please sign in to comment.