/******************************************************************************
* blackjack.c
*
* Epsilon-soft, on-policy Monte Carlo control for Blackjack
* (as described in example 5.1)
*
* 
* 
* NOTES:
* 
* - Deals in which the player has a natural, or a sum below twelve,
*   are skipped, because the action the player should take in those
*   cases is obvious.
*
* - The game described in this example does not treat a player who
*   reaches five cards without going bust as a special case.
*
******************************************************************************/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>



/*
 * Forward declarations. Functions taking no arguments are declared with
 * (void) so that the compiler checks every call against a real prototype.
 */
void initialize(void);
int request_card(void);
int select_action(int p, int d, int u);
void show_POL(void);
void update_card_sum(int *current, int *new, int *usable);



/* Print every simulated episode step by step? (0 = no, 1 = yes) */
#define SHOW_EPISODES 0 

/* How many episodes to simulate (skipped deals count towards this too) */
#define MAX_EPISODES 5000000

/*
 * How "soft" the policy is: the non-favoured action is taken with
 * probability EPSILON/2, the favoured one with 1 - EPSILON/2.
 */
#define EPSILON 0.1

/* The two actions; they are also the last index of Q, POL and W below */
enum { HIT = 0, STICK = 1 };

/*
 * Index convention shared by Q, POL and W. Let
 *
 *      p = player's current sum           (12..21)
 *      d = dealer's showing card          (1..10, 1 = ace)
 *      u = player holds a usable ace      (0 or 1)
 *      a = action                         (HIT or STICK)
 *
 * A state s is the triple (p, d, u). Only p = 12..21 and d = 1..10 are
 * used; the lower indices are allocated so the raw values index directly.
 */

/* Q[p][d][u][a]: current estimate of the action value Q(s,a) (400 used entries) */
double Q[22][11][2][2];

/* POL[p][d][u][a]: probability of taking action a in state s under the current policy */
double POL[22][11][2][2];

/*
 * W[p][d][u][a]: how many times (s,a) has been encountered so far. Needed
 * to keep Q(s,a) an incremental average of the observed returns (see
 * sections 2.5 and 5.7).
 */
int W[22][11][2][2];

/* Node of the singly linked list of state-action pairs seen in one episode */
struct sa_pair {
        int p;
        int d;
        int u;
        int a;
        struct sa_pair *next;
};



/*
 * push_pair()
 *
 * Prepend the state-action pair ((p, d, u), a) to the list *head.
 * Allocation failure is fatal: a message is printed and the program exits.
 */
static void push_pair(struct sa_pair **head, int p, int d, int u, int a){

        struct sa_pair *node = malloc(sizeof *node);

        if (node == NULL){
                fprintf(stderr, "Out of memory while recording a state-action pair.\n");
                exit(EXIT_FAILURE);
        }
        node->p = p;
        node->d = d;
        node->u = u;
        node->a = a;
        node->next = *head;
        *head = node;
}


/*
 * play_dealer()
 *
 * Play the dealer's hand out, starting from his face-up card `showing`
 * (1 = ace). The dealer hits while his sum is below 17.
 *
 * Returns the dealer's final sum (greater than 21 if he went bust) and
 * stores in *natural 1 if his first two cards were an ace and a
 * ten-valued card, 0 otherwise.
 */
static int play_dealer(int showing, int *natural){

        int sum, usable, card;
        int card_count = 1;

        *natural = 0;

        /*
         * An ace as the first and only card is a usable ace
         * and is counted as 11
         */
        if (showing == 1){
                sum = 11;
                usable = 1;
        }
        else{
                sum = showing;
                usable = 0;
        }

        while (sum < 17){ // <-- has to hit if less than 17

                card = request_card();
                update_card_sum(&sum, &card, &usable);
                card_count++;

                if (sum > 21){
                        if(SHOW_EPISODES) printf("Dealer hit: %d and went busted.\n", card);
                        break;
                }

                /*
                 * A natural is a two-card 21, i.e. an ace plus a ten-valued
                 * card. (Two aces make a soft 12, not a natural.)
                 */
                if ((card_count == 2) && (sum == 21)){
                        *natural = 1;
                        if(SHOW_EPISODES) printf("Dealer has a natural (%d, %d)\n", showing, card);
                        break;
                }

                if(SHOW_EPISODES) printf("Dealer hit: %d and becomes d=%d, dealer_u=%d\n", card, sum, usable);
        }

        return sum;
}


/*
 * update_values()
 *
 * Fold the return G of the episode just played into the running average
 * Q(s,a) of every state-action pair in `pairs` (see section 2.5), and
 * count the visit in W(s,a). The list is freed as it is consumed.
 */
static void update_values(struct sa_pair *pairs, int G){

        struct sa_pair *next;
        int p, d, u, a;

        while (pairs != NULL){

                p = pairs->p;
                d = pairs->d;
                u = pairs->u;
                a = pairs->a;
                if(SHOW_EPISODES) printf("        %d, %d, %d, %d\n", p, d, u, a);

                /* Q_{k+1} = Q_k + (G - Q_k) / (k + 1), k = visits so far */
                Q[p][d][u][a] = Q[p][d][u][a] + (G - Q[p][d][u][a])/(W[p][d][u][a] + 1);
                W[p][d][u][a]++;

                next = pairs->next;
                free(pairs);
                pairs = next;
        }
}


/*
 * improve_policy()
 *
 * Make POL epsilon-greedy with respect to the current Q: in every state the
 * action with the larger value (STICK on ties) gets probability
 * 1 - EPSILON/2 and the other one EPSILON/2.
 */
static void improve_policy(void){

        int p, d, u;

        for(p = 12; p < 22; p++){
                for(d = 1; d < 11; d++){
                        for(u = 0; u < 2; u++){
                                if (Q[p][d][u][HIT] > Q[p][d][u][STICK]){
                                        POL[p][d][u][HIT] = 1 - EPSILON + EPSILON/2;
                                        POL[p][d][u][STICK] = EPSILON/2;
                                }
                                else{
                                        POL[p][d][u][STICK] = 1 - EPSILON + EPSILON/2;
                                        POL[p][d][u][HIT] = EPSILON/2;
                                }
                        }
                }
        }
}


/*
 * main()
 *
 * Run MAX_EPISODES iterations of on-policy Monte Carlo control, then print
 * the resulting policy. Each iteration:
 *
 *      1. simulates an episode with the current epsilon-soft policy
 *         (deals where the player starts below 12 or with 21 are skipped);
 *      2. updates the action values of every state-action pair that
 *         appeared in it;
 *      3. makes the policy epsilon-greedy with respect to the new values.
 */
int main(int argc, char *argv[]){

        struct sa_pair *appeared_pairs;
        int i, p, d, u, card;
        int player_busted, dealer_has_natural, dealer_sum;
        int episode_ending;

        (void) argc;
        (void) argv;

        printf("\n");
        printf("Doing %d episodes.\n", MAX_EPISODES);
        printf("Change SHOW_EPISODES to 1 if you want to see them displayed.\n");
        printf("\n");

        initialize();

        for(i = 0; i < MAX_EPISODES; i++){

                appeared_pairs = NULL;
                player_busted = 0;
                dealer_has_natural = 0;
                dealer_sum = 0;
                p = 0; u = 0;

                /*
                 * Give 2 cards to the player and 1 to the dealer
                 */
                card = request_card(); // <-- player's first card
                update_card_sum(&p, &card, &u);
                card = request_card(); // <-- player's second card
                update_card_sum(&p, &card, &u);
                if ((p < 12) || (p == 21)){ // <-- We don't consider these cases
                        continue;
                }
                d = request_card(); // <-- dealer's showing card
                if(SHOW_EPISODES) printf("New episode: p=%d, d=%d, u=%d\n", p, d, u);

                /*
                 * Player plays; every state in which he hits is recorded
                 */
                while(select_action(p, d, u) == HIT){

                        push_pair(&appeared_pairs, p, d, u, HIT);

                        card = request_card();
                        update_card_sum(&p, &card, &u);
                        if (p > 21){
                                player_busted = 1;
                                if(SHOW_EPISODES) printf("Player hit: %d and went busted.\n", card);
                                break;
                        }

                        if(SHOW_EPISODES) printf("Player hit: %d and becomes p=%d, d=%d, u=%d\n", card, p, d, u);
                }

                /*
                 * If the player stuck, record the final (s, STICK) pair
                 * and let the dealer play
                 */
                if (!player_busted){
                        push_pair(&appeared_pairs, p, d, u, STICK);
                        dealer_sum = play_dealer(d, &dealer_has_natural);
                }

                /*
                 * Who won (+1 player, 0 draw, -1 dealer). A dealer natural
                 * beats any player total, because player naturals were
                 * skipped above.
                 */
                if (player_busted || dealer_has_natural){
                        episode_ending = -1;
                }
                else if (dealer_sum > 21){
                        episode_ending = 1;
                }
                else if (dealer_sum < p){
                        episode_ending = 1;
                }
                else if (dealer_sum == p){
                        episode_ending = 0;
                }
                else{
                        episode_ending = -1;
                }
                if(SHOW_EPISODES) printf("---> Episode ended: %d\n", episode_ending);

                /*
                 * Update the action values of every state-action pair that
                 * appeared in this episode (this also frees the list)
                 */
                if(SHOW_EPISODES) printf("---> Involved state-action pairs (in reverse order): \n");
                update_values(appeared_pairs, episode_ending);
                if(SHOW_EPISODES) printf("\n");

                /*
                 * Finally, improve the policy according to the updated values
                 */
                improve_policy();
        }

        show_POL();

        return 0;
}


/*
 * initialize()
 *
 * Reset the learning state before the first episode, for every state
 * p = 12..21, d = 1..10, u = 0..1:
 *
 *   - all action values Q(s,a) and visit counts W(s,a) start at zero;
 *   - the policy POL starts as the epsilon-soft version of "always hit".
 *
 * srand() is not called, so rand() keeps its default seed (as if srand(1)
 * had been called) and every run replays the same episodes. Seed it here,
 * e.g. with time(NULL) from <time.h>, for runs that differ.
 */
void initialize(void){

        int p, d, u, a;

        for(p = 12; p <= 21; p++){
                for(d = 1; d <= 10; d++){
                        for(u = 0; u <= 1; u++){

                                /* Arbitrary (zero) action values, nothing visited yet */
                                for(a = 0; a <= 1; a++){
                                        Q[p][d][u][a] = 0;
                                        W[p][d][u][a] = 0;
                                }

                                /* Arbitrary starting policy: favour hitting */
                                POL[p][d][u][HIT] = 1 - EPSILON + EPSILON/2;
                                POL[p][d][u][STICK] = EPSILON/2;
                        }
                }
        }
}


/*
 * request_card()
 *
 *  
 */
int request_card(){

	int card;

	/*
	 * There are 13 different cards in each color
	 * Each should has an equal chance to appear
	 */
// 	card = (random() % 13) + 1;  // <-- Won't do. Type "man rand" to see why
 	card = 1 + (int) (13.0*rand()/(RAND_MAX+1.0)); 	// <-- {1..13}
        
	/*
	 * All the cards after 9 are counted as 10
	 */
	if (card > 9)
		return 10;
	else
		return card;
	

}


/*
 * select_action()
 *
 * Decide whether the player should "hit" or "stick" in state (p, d, u) by
 * sampling the epsilon-soft policy POL. The action POL currently favours
 * is returned with probability 1 - EPSILON/2, and the other action with
 * probability EPSILON/2. Consumes exactly one rand() value.
 */
int select_action(int p, int d, int u){

        double draw = rand()/(RAND_MAX+1.0);
        int favoured = (POL[p][d][u][HIT] < POL[p][d][u][STICK]) ? STICK : HIT;

        if (!(draw < (EPSILON/2)))
                return favoured;

        /* Exploration: take the non-favoured action */
        return (favoured == HIT) ? STICK : HIT;
}


/*
 * update_card_sum()
 *
 * Add the card *new (1 = ace, otherwise its value, at most 10) to a hand
 * whose sum is *current. *usable is 1 when the hand holds an ace that is
 * currently counted as 11.
 *
 *   - A new ace counts as 11 (and becomes the usable ace) whenever that
 *     keeps the sum at 21 or below.
 *   - If the card would push the sum over 21 and the hand has a usable
 *     ace, that ace drops to 1 (sum - 10) and *usable is cleared.
 *   - Otherwise the card is simply added; a sum above 21 means bust.
 *
 * *current and *usable are updated in place; *new is written back unchanged.
 */
void update_card_sum(int *current, int *new, int *usable){

        int total = *current;
        int card = *new;
        int soft = *usable;
        int plain = total + card;

        if ((plain <= 21) && (card == 1) && (total + 11 <= 21)){
                total += 11;            /* fresh ace counted high */
                soft = 1;
        }
        else if ((plain > 21) && (soft == 1)){
                total = plain - 10;     /* usable ace now counts as 1 */
                soft = 0;
        }
        else{
                total = plain;
        }

        *current = total;
        *new = card;
        *usable = soft;
}


/*
 * show_POL()
 *
 * Print the action currently favoured by the policy POL (0 = HIT,
 * 1 = STICK) as two tables. The first covers states in which the player
 * holds a usable ace (u = 1), the second states without one (u = 0).
 * Rows are player sums from 21 down to 12; columns are the dealer's
 * showing card from 1 (ace) to 10.
 *
 * NOTE(review): the table titles say "episodes that start with/without a
 * usable ace", but the tables are indexed by the state's u, not by how
 * the episode started.
 */
void show_POL(void){

        int u, p, d, favoured;

        for(u = 1; u >= 0; u--){

                if (u == 1)
                        printf(" Policy for episodes that start with a usable ace\n");
                else
                        printf(" Policy for episodes that start without a usable ace\n");
                printf("     |  1  2  3  4  5  6  7  8  9 10\n");
                printf("-----+-------------------------------\n");

                for(p = 21; p >= 12; p--){
                        printf(" %3d |", p);
                        for(d = 1; d <= 10; d++){
                                /* HIT wins ties, exactly as when compiling the soft policy */
                                favoured = (POL[p][d][u][HIT] < POL[p][d][u][STICK]) ? STICK : HIT;
                                printf("  %d", favoured);
                        }
                        printf("\n");
                }
                printf("\n");
        }
}