/******************************************************************************
 *
 * gambler.c
 *
 * Policy Iteration for the Gambler problem (example 4.3)
 *
 * Huy N Pham
 *
 ******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/*
 * Prototypes
 */
void show_V(void);
void show_POL(void);
void initialize(void);
void evaluate_policy(void);
double update_state_val(int s);
void improve_policy(void);
int argmax(int s);
int min(int a, int b);

/* How close the approximation should be */
#define THETA 0.000001

/* How far-sighted should we be */
#define GAMMA 1 // <-- Very

/* Probability for the coin to come up heads (i.e., win) */
#define ODDVAL 0.4

/* V[s] represents the values of 99 normal states and 2 terminal states */
double V[101]; // <-- s = {0..100}, where V[0] and V[100] are ALWAYS 0

/* P[s][a][s'] represents the prob of ending up in s' from s by doing a */
double P[100][51][101]; // <-- s = {1..99}, a = {1..50}, s' = {0..100}

/* POL[s] represents the amount to bet at every state s */
int POL[101]; // <-- s = {1..99}; slots 0 and 100 exist only so the init loop stays in bounds

/* R[s][a][s'] represents the reward received for getting to s' from s by doing a */
double R[100][51][101]; // <-- s = {1..99}, a = {1..50}, s' = {0..100}

/* Is the current policy stable? */
int policy_stable; // <-- No if 0, Yes otherwise.

/*
 * main()
 *
 * Policy Iteration
 */
int main(int argc, char *argv[]){
  initialize();

  do{
    evaluate_policy();  // <== Policy Evaluation (for the current policy)
    improve_policy();   // <== Policy Improvement
  }while (!policy_stable);

  show_V();
  show_POL();

  exit(0);
}

/*
 * initialize()
 *
 * Data structures initialization
 */
void initialize(void){
  int i, j, k;

  /*
   * Info printout
   */
  printf("\n\n");
  printf("Policy Iteration for the Gambler problem in example 4.3\n");
  printf("by Huy N Pham \n");
  printf("\n");

  /*
   * Start with an all-zero state-value function and
   * an arbitrary initial policy (always bet 1 dollar)
   *
   * Note that terminal state values are also initialized to 0
   */
  for(i = 0; i < 101; i++){ // <-- Terminal states are included
    V[i] = 0;
    POL[i] = 1;
  }

  /*
   * Initialize P[s][a][s']
   */
  for(i = 1; i < 100; i++){                  // <== 99 possible s
    for(j = 1; j <= min(i, 100 - i); j++){   // <== up to 50 possible a
      for(k = 0; k < 101; k++){              // <== 101 possible s'
        if (k == (i + j))                    // <== the coin came up heads
          P[i][j][k] = ODDVAL;
        else if (k == (i - j))               // <== the coin came up tails
          P[i][j][k] = 1 - ODDVAL;
        else
          P[i][j][k] = 0;                    // <== no other case is possible
      }
    }
  }

  /*
   * Initialize R[s][a][s']
   */
  for(i = 1; i < 100; i++){                  // <== 99 possible s
    for(j = 1; j <= min(i, 100 - i); j++){   // <== up to 50 possible a
      for(k = 0; k < 101; k++){              // <== 101 possible s'
        if (k == 100)                        // <== We won the game
          R[i][j][k] = 1;
        // else if (k == 0)                  // <== damn!
        //   R[i][j][k] = -1;
        else
          R[i][j][k] = 0;
      }
    }
  }

  return;
}

/*
 * show_V()
 *
 * Print out the state values for all states
 */
void show_V(void){
  int i;

  printf("                              STATE VALUES\n");
  printf("     |     1     2     3     4     5     6     7     8     9    10\n");
  printf(" -----+------------------------------------------------------------\n");
  printf("   0 |");
  for(i = 1; i < 100; i++){
    printf(" %5.2f", V[i]);
    if ((i % 10) == 0){
      printf("\n %3d |", i);
    }
  }
  printf("\n");
  printf("\n");
}

/*
 * show_POL()
 *
 * Print out the policy
 */
void show_POL(void){
  int i;

  printf("                   POLICY\n");
  printf("     |   1   2   3   4   5   6   7   8   9  10\n");
  printf(" -----+----------------------------------------\n");
  printf("   0 |");
  for(i = 1; i < 100; i++){
    printf(" %3d", POL[i]);
    if ((i % 10) == 0){
      printf("\n %3d |", i);
    }
  }
  printf("\n");
  printf("\n");
}

/*
 * evaluate_policy()
 *
 * Policy Evaluation: Iteratively update V(s) for all states s.
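 *
 * (Added commentary, not part of the original header.) Each sweep below
 * applies the Bellman expectation backup for the current policy POL, which
 * for this problem reduces to the two-term update
 *
 *   V(s) <- P[s][a][s+a] * (R[s][a][s+a] + GAMMA * V(s+a))
 *         + P[s][a][s-a] * (R[s][a][s-a] + GAMMA * V(s-a)),   with a = POL[s],
 *
 * and full sweeps repeat until the largest per-state change (delta) drops
 * below THETA.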
 */
void evaluate_policy(void){
  int i;
  double delta = 0;
  double v;

  /*
   * Iterate until V(s) converges
   */
  do{
    delta = 0;

    /*
     * Iteratively update each state in the state space
     *
     * Note: Since terminal states should always have values
     * of zero (once we are in these states, no more reward
     * is possible), they are excluded here and won't
     * get updated.
     */
    for(i = 1; i < 100; i++){ // <-- terminal states excluded
      v = V[i];
      V[i] = update_state_val(i);
      delta = (delta > fabs(v - V[i]) ? delta : fabs(v - V[i]));
    }
  }while (delta >= THETA);

  return;
}

/*
 * update_state_val()
 *
 * Update a state value using the Bellman equation
 */
double update_state_val(int s){
  double v;
  int a;

  /*
   * In this particular problem,
   * from any given state s, we have only two possible s':
   *
   *     s' = s + POL[s], with probability ODDVAL
   * OR: s' = s - POL[s], with probability (1 - ODDVAL)
   *
   * Therefore, our equation has 2 terms
   */
  a = POL[s];
  v = P[s][a][s+a] * (R[s][a][s+a] + (GAMMA * V[s+a])) +
      P[s][a][s-a] * (R[s][a][s-a] + (GAMMA * V[s-a]));

  return v;
}

/*
 * improve_policy()
 *
 * Use the current state-value function to select the best action for every state
 */
void improve_policy(void){
  int i;
  int old_action;

  policy_stable = 1;

  for(i = 1; i < 100; i++){ // <== 99 possible states
    old_action = POL[i];
    POL[i] = argmax(i);
    if (POL[i] != old_action)
      policy_stable = 0;
  }

  return;
}

/*
 * argmax()
 *
 * Return the best (greedy) action for s
 */
int argmax(int s){
  int i;
  int best_action = 1;
  float best_reward = 0;
  float scoreboard[51];

  /*
   * Try every possible action and record its corresponding expected return
   */
  for(i = 1; i <= min(s, 100 - s); i++){
    scoreboard[i] = ODDVAL * (R[s][i][s+i] + (GAMMA * V[s+i])) +
                    (1 - ODDVAL) * (R[s][i][s-i] + (GAMMA * V[s-i]));
  }

  /*
   * Which action was the best?
   */
  for(i = 1; i <= min(s, 100 - s); i++){
    if (scoreboard[i] > best_reward){
      best_reward = scoreboard[i];
      best_action = i;
    }
  }

  return best_action;
}

/*
 * min()
 *
 * Return the smaller of 2 integers
 */
int min(int a, int b){
  return ((a < b) ? a : b);
}
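
/*
 * Build-and-run sketch (added note, not part of the original source; the
 * exact compiler invocation is an assumption, any C compiler should work):
 *
 *   gcc -O2 -o gambler gambler.c -lm
 *   ./gambler
 *
 * -lm is linked because evaluate_policy() uses fabs() from <math.h>.
 * The program prints the converged state values and the greedy policy
 * (bet size for each capital level) once the policy stops changing.
 */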