/************************************************************************
* blackjack.c
*
* Epsilon-soft, on-policy Monte Carlo control for Blackjack (Example 5.1 in
*   Sutton & Barto, "Reinforcement Learning: An Introduction")
*
* 
* 
* NOTES:
* 
* - Episodes in which the player is dealt a natural or a sum below
*   twelve are skipped, because the player's action in those cases is
*   obvious (a natural needs no decision, and below twelve a hit cannot bust).
*
**************************************************************************/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>



/* Show episodes on screen? */
#define SHOW_EPISODES 0 

/* How many sample episode to simulate */
#define MAX_EPISODES 5000000

/* How "soft" the policy should be */
#define EPSILON 0.1

/* What is what */
#define HIT     0
#define STICK   1

/* 
 * Let  p = player_sum, 
 *      d = dealer_showing (1 = ace), 
 *      u = usable_ace (0 or 1),
 *      a = action (HIT or STICK).
 * 
 * Q[p][d][u][a] represents the 400 action values Q(s,a), where every 
 * state s is defined by a triple (p,d,u): 10 player sums x 10 dealer
 * cards x 2 ace flags = 200 states, times 2 actions.
 * Only p = 12..21 and d = 1..10 are used; the arrays are oversized so
 * that sums and card values can be used directly as indices.
 */
double Q[22][11][2][2];

/*
 * POL[p][d][u][a] represents the probability of taking action a
 * at state s(p,d,u) under the current policy (same index ranges as Q).
 */
double POL[22][11][2][2];

/*
 * W[p][d][u][a] holds the number of times a state-action pair (s,a)
 * has been encountered in the simulation (same index ranges as Q).
 * Needed to incrementally calculate the average of Q(s,a). 
 * See Sections 2.4 (p.31) and 5.6 (p.109) in the 2nd edition for details.
 */
int W[22][11][2][2];

/*
 * One node of a singly linked list recording the state-action pairs
 * visited during an episode.  main() pushes new nodes at the head, so
 * the list holds the pairs in reverse order of occurrence.
 */
struct sa_pair{
        int p;                  /* player sum, 12..21 */
        int d;                  /* dealer's showing card, 1..10 (1 = ace) */
        int u;                  /* 1 if the player has a usable ace, else 0 */
        int a;                  /* action taken: HIT or STICK */
        struct sa_pair *next;   /* earlier pair in the episode, or NULL */
};



/*
 * Forward declarations.  The explicit (void) parameter lists make these
 * real prototypes, so a call with the wrong number of arguments is
 * diagnosed (before C23 an empty list declares a function with
 * unspecified parameters).
 */
void initialize(void);
int request_card(void);
int select_action(int p, int d, int u);
void show_POL(void);
void update_card_sum(int *current, int *new, int *usable);


/*
 * push_pair()
 *
 * Prepend a new node holding the state-action pair (p, d, u, a) to the
 * list *list.  Exits with EXIT_FAILURE if the node cannot be allocated.
 */
static void push_pair(struct sa_pair **list, int p, int d, int u, int a){

   struct sa_pair *node = malloc(sizeof *node);

   if (node == NULL) {
      fprintf(stderr, "blackjack: out of memory\n");
      exit(EXIT_FAILURE);
   }
   node->p = p;
   node->d = d;
   node->u = u;
   node->a = a;
   node->next = *list;
   *list = node;
}


/*
 * improve_policy()
 *
 * Make the policy at state (p, d, u) epsilon-greedy with respect to the
 * current action values: the greedy action gets probability
 * 1 - EPSILON + EPSILON/2 and the other one EPSILON/2.
 * Ties (equal action values) favor STICK.
 */
static void improve_policy(int p, int d, int u){

   if (Q[p][d][u][HIT] > Q[p][d][u][STICK]) {
      POL[p][d][u][HIT] = 1 - EPSILON + EPSILON/2;
      POL[p][d][u][STICK] = EPSILON/2;
   }
   else {
      POL[p][d][u][STICK] = 1 - EPSILON + EPSILON/2;
      POL[p][d][u][HIT] = EPSILON/2;
   }
}


/*
 * main()
 *
 * Run MAX_EPISODES episodes of epsilon-soft, on-policy Monte Carlo
 * control and print the resulting policy.  For every episode:
 *
 *   1. Deal two cards to the player and one showing card to the dealer
 *      (episodes starting below 12 or with a natural are skipped).
 *   2. Let the player follow the current policy, recording every visited
 *      state-action pair, then let the dealer hit until reaching 17.
 *   3. Average the episode's return (+1 win, 0 draw, -1 loss) into
 *      Q(s,a) for every recorded pair, and make the policy epsilon-greedy
 *      at each visited state.
 *
 * Only the states visited in the episode have their policy improved, as
 * in the textbook algorithm; this avoids re-sweeping all 200 states after
 * every episode.  Unvisited states keep their current policy.
 */
int main(int argc, char *argv[]){

   struct sa_pair *appeared_pairs, *temp_ptr;
   int i, p, d, u, a, hit_result;
   int dealer_sum, dealer_u;
   int player_busted, dealer_busted;
   int episode_ending;

   (void) argc;   /* no command-line options */
   (void) argv;

   printf("\n");
   printf("Doing %d episodes.\n", MAX_EPISODES);
   printf("Change SHOW_EPISODES to 1 if you want to see them displayed.\n");
   printf("\n");

   initialize();

   for (i = 0; i < MAX_EPISODES; i++) {

      appeared_pairs = NULL;
      player_busted = 0;
      dealer_busted = 0;
      dealer_sum = 0;
      dealer_u = 0;
      p = 0;
      u = 0;

      /*
       * Give 2 cards to the player and 1 to the dealer.  The dealer's
       * card is only drawn for episodes that are actually played.
       */
      hit_result = request_card();              /* player's first card */
      update_card_sum(&p, &hit_result, &u);
      hit_result = request_card();              /* player's second card */
      update_card_sum(&p, &hit_result, &u);
      if ((p < 12) || (p == 21)) {              /* not considered, see NOTES */
         continue;
      }
      d = request_card();                       /* dealer's showing card */
      if (SHOW_EPISODES)
         printf("New episode: p=%d, d=%d, u=%d\n", p, d, u);

      /*
       * Player plays: record every (s, HIT) pair; stop on STICK or bust.
       */
      while (select_action(p, d, u) == HIT) {

         push_pair(&appeared_pairs, p, d, u, HIT);

         hit_result = request_card();
         update_card_sum(&p, &hit_result, &u);
         if (p > 21) {
            player_busted = 1;
            if (SHOW_EPISODES)
               printf("Player hit: %d and went busted.\n", hit_result);
            break;
         }

         if (SHOW_EPISODES)
            printf("Player hit: %d and becomes p=%d, d=%d, u=%d\n",
                   hit_result, p, d, u);
      }

      /*
       * If the player did not bust, the last action was STICK: record it,
       * then let the dealer play.
       */
      if (!player_busted) {

         push_pair(&appeared_pairs, p, d, u, STICK);

         /*
          * An ace as the first and only card is a usable ace and is
          * counted as 11.  The dealer's running total lives in dealer_sum
          * so that d keeps holding the showing card.
          */
         if (d == 1) {
            dealer_sum = 11;
            dealer_u = 1;
         }
         else {
            dealer_sum = d;
            dealer_u = 0;
         }

         while (dealer_sum < 17) {              /* has to hit below 17 */
            hit_result = request_card();
            update_card_sum(&dealer_sum, &hit_result, &dealer_u);
            if (dealer_sum > 21) {
               dealer_busted = 1;
               if (SHOW_EPISODES)
                  printf("Dealer hit: %d and went busted.\n", hit_result);
               break;
            }
            if (SHOW_EPISODES)
               printf("Dealer hit: %d and becomes d=%d, dealer_u=%d\n",
                      hit_result, dealer_sum, dealer_u);
         }
      }

      /*
       * Who won: +1 the player wins, 0 draw, -1 the player loses.
       */
      if (player_busted)
         episode_ending = -1;
      else if (dealer_busted)
         episode_ending = 1;
      else if (dealer_sum < p)
         episode_ending = 1;
      else if (dealer_sum == p)
         episode_ending = 0;
      else
         episode_ending = -1;

      if (SHOW_EPISODES)
         printf("---> Episode ended: %d\n", episode_ending);

      /*
       * Update the action value of every state-action pair that appeared
       * in this episode, improve the policy at that pair's state, and
       * release the list node.
       */
      if (SHOW_EPISODES)
         printf("---> Involved state-action pairs (in reverse order): \n");
      while (appeared_pairs != NULL) {

         p = appeared_pairs->p;
         d = appeared_pairs->d;
         u = appeared_pairs->u;
         a = appeared_pairs->a;
         if (SHOW_EPISODES)
            printf("        %d, %d, %d, %d\n", p, d, u, a);

         /* Incremental sample average: Q <- Q + (G - Q) / (n + 1) */
         Q[p][d][u][a] += (episode_ending - Q[p][d][u][a]) / (W[p][d][u][a] + 1);
         W[p][d][u][a]++;

         improve_policy(p, d, u);

         temp_ptr = appeared_pairs->next;
         free(appeared_pairs);
         appeared_pairs = temp_ptr;
      }
      if (SHOW_EPISODES)
         printf("\n");
   }

   show_POL();

   return 0;
}


/*
 * initialize()
 *
 * Reset the learning state for every tracked state (p = 12..21,
 * d = 1..10, u = 0..1): action values and visit counts start at zero,
 * and the policy starts as an arbitrary epsilon-soft "always hit" policy.
 */
void initialize(){

   int sum, shown, ace, act;

   for (sum = 12; sum <= 21; sum++) {
      for (shown = 1; shown <= 10; shown++) {
         for (ace = 0; ace <= 1; ace++) {
            for (act = 0; act <= 1; act++) {
               Q[sum][shown][ace][act] = 0;
               W[sum][shown][ace][act] = 0;
            }
            /* HIT is the favored (greedy) action everywhere */
            POL[sum][shown][ace][HIT] = 1 - EPSILON + EPSILON/2;
            POL[sum][shown][ace][STICK] = EPSILON/2;
         }
      }
   }
}


/*
 * request_card()
 *
 * Deal one card from an infinite deck: each of the 13 ranks
 * (ace = 1 ... king = 13) is equally likely on every draw.  Returns the
 * card's value in 1..10: ten, jack, queen and king all count as 10, and
 * an ace is returned as 1 (update_card_sum() decides whether it is 11).
 */
int request_card(){

   /*
    * Scale rand() into 13 buckets using its high-order bits rather than
    * rand() % 13, which depends on the low-order bits (see rand(3)).
    */
   const int rank = 1 + (int) (13.0*rand()/(RAND_MAX+1.0));

   return (rank >= 10) ? 10 : rank;
}


/*
 * select_action()
 *
 * Decide whether to HIT or STICK at state (p, d, u) under the current
 * epsilon-soft policy.  The action POL favors is returned with
 * probability 1 - EPSILON/2; the other action is returned with
 * probability EPSILON/2.  If POL rates both actions equally, HIT is
 * treated as the favored one.
 */
int select_action(int p, int d, int u){

   const double *probs = POL[p][d][u];
   int favored = (probs[HIT] < probs[STICK]) ? STICK : HIT;
   int explore = (rand()/(RAND_MAX+1.0)) < (EPSILON/2);

   if (!explore)
      return favored;

   /* exploratory move: take the minority action */
   return (favored == HIT) ? STICK : HIT;
}


/*
 * update_card_sum()
 *
 * Add the card value *new (1..10, ace = 1, as returned by request_card())
 * to the hand total *current and keep the usable-ace flag *usable current:
 *   - if the card fits without exceeding 21, an ace is counted as 11 (and
 *     becomes usable) whenever that still stays within 21;
 *   - if the card would exceed 21 and *usable is 1, the usable ace is
 *     recounted as 1 instead (total drops by 10, flag cleared);
 *   - otherwise the card is simply added, possibly busting (total > 21).
 * *new is written back unchanged.
 */
void update_card_sum(int *current, int *new, int *usable){

   const int card = *new;
   int sum = *current;
   int has_soft_ace = *usable;

   if (sum + card <= 21) {
      if (card == 1 && sum + 11 <= 21) {
         sum += 11;              /* just got a usable ace */
         has_soft_ace = 1;
      }
      else {
         sum += card;
      }
   }
   else if (has_soft_ace == 1) {
      sum += card - 10;          /* demote the usable ace from 11 to 1 */
      has_soft_ace = 0;
   }
   else {
      sum += card;               /* busted */
   }

   *current = sum;
   *new = card;
   *usable = has_soft_ace;
}


/*
 * print_policy_rows()
 *
 * Print one table row per player sum, from 21 down to 12.  For each
 * dealer showing card 1..10, the row shows the action POL favors at
 * (sum, card, u): STICK when it is strictly more probable, else HIT.
 */
static void print_policy_rows(int u){

   int sum, card;

   for (sum = 21; sum >= 12; sum--) {
      printf(" %3d |", sum);
      for (card = 1; card <= 10; card++) {
         const double *prob = POL[sum][card][u];
         printf("  %d", (prob[HIT] < prob[STICK]) ? STICK : HIT);
      }
      printf("\n");
   }
}


/*
 * show_POL()
 *
 * Print the favored action of the current soft policy as two tables,
 * with and without a usable ace.  Rows are player sums and columns are
 * the dealer's showing card; entries are 0 = HIT, 1 = STICK.
 */
void show_POL(){

   printf(" Policy for episodes that start with a usable ace\n");
   printf("     |  1  2  3  4  5  6  7  8  9 10\n");
   printf("-----+-------------------------------\n");
   print_policy_rows(1);
   printf("\n");

   printf(" Policy for episodes that start without a usable ace\n");
   printf("     |  1  2  3  4  5  6  7  8  9 10\n");
   printf("-----+-------------------------------\n");
   print_policy_rows(0);

   printf("\n");
}