13.9 RL with Generalization

Usually, there are too many states to reason about explicitly. The alternative to reasoning explicitly in terms of states is to reason in terms of features, which can either be provided explicitly or learned.

1: controller SARSA_with_Generalization(Learner, γ)
2:   Inputs
3:    Learner with operations Learner.add(x,y) and Learner.predict(x)
4:    γ ∈ [0,1]: discount factor
5:   observe current state s
6:   select action a
7:   repeat
8:    do(a)
9:    observe reward r and state s′
10:    select action a′ based on Learner.predict((s′,a′))
11:    Learner.add((s,a), r + γ·Learner.predict((s′,a′)))
12:    s := s′
13:    a := a′
14:   until termination
Figure 13.7: SARSA with generalization

Figure 13.7 shows a generic on-policy reinforcement learner that incorporates a supervised learner. This assumes the learner can carry out the operations

  • add(x,y) which adds a new example to the dataset, with input x and target value y

  • predict(x) which gives a point prediction for the target for an example with input x.

In SARSA_with_Generalization, the input x for the learner is a state–action pair, and the target for pair (s,a) is an estimate of Q(s,a).
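
To make this concrete, the following is a minimal Python sketch of the controller in Figure 13.7. The environment interface (env.state, and env.do(a) returning a reward and the next state) and the select_action argument are assumptions of the sketch, not part of the pseudocode; any learner that implements add and predict can be plugged in.

```python
def sarsa_with_generalization(env, learner, gamma, select_action, steps=10000):
    """Generic on-policy SARSA loop built around a supervised Learner.

    Assumes env.do(a) carries out action a and returns (reward, next_state),
    and select_action(state, learner) chooses an action for a state.
    """
    s = env.state                          # observe current state s
    a = select_action(s, learner)          # select action a
    for _ in range(steps):                 # step budget stands in for "until termination"
        r, s1 = env.do(a)                  # observe reward r and state s'
        a1 = select_action(s1, learner)    # select a' based on the learner's predictions
        learner.add((s, a), r + gamma * learner.predict((s1, a1)))
        s, a = s1, a1
    return learner
```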

The only difference from the learners considered in Chapters 7 and 8 is that the learner must be able to incrementally add examples, and make predictions based on the examples it currently has. Newer examples are often better approximations than old examples, and the algorithms might need to take this into account.

Selecting the next action a′ on line 10 with pure exploitation means selecting an a′ that maximizes Learner.predict((s′,a′)); exploration can be carried out using one of the exploration techniques of Section 13.5.
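
For example, ε-greedy selection (one of the techniques of Section 13.5) could be supplied as the select_action argument of the sketch above; the actions list is assumed to come from the domain description.

```python
import random

def epsilon_greedy(actions, epsilon=0.1):
    """Build a select_action function: with probability epsilon explore by
    choosing a random action, otherwise exploit by maximizing the learner's
    prediction for (state, action)."""
    def select_action(state, learner):
        if random.random() < epsilon:
            return random.choice(actions)
        return max(actions, key=lambda a: learner.predict((state, a)))
    return select_action
```

A call such as sarsa_with_generalization(env, learner, 0.9, epsilon_greedy(env.actions)) then runs the controller with this exploration strategy.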

Generalization in this algorithm occurs by the learner generalizing. The learner could be, for example, a linear function (see next section), a decision tree learner, or a neural network. SARSA is an instance of this algorithm where the learner memorizes, but does not generalize.
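
To see how plain SARSA fits this pattern, consider a learner that memorizes one estimate per (s,a) input and moves it toward each new target. This is a sketch; the step size α and the default initial value are assumptions.

```python
class TabularLearner:
    """A Learner that memorizes but does not generalize: one estimate per
    (state, action) input. Plugged into the loop above, this gives SARSA."""

    def __init__(self, alpha=0.1, default=0.0):
        self.q = {}              # maps (state, action) to its current estimate
        self.alpha = alpha       # step size toward each new target
        self.default = default   # estimate for inputs never seen before

    def predict(self, x):
        return self.q.get(x, self.default)

    def add(self, x, y):
        # move the stored estimate for x a fraction alpha toward the target y
        self.q[x] = self.predict(x) + self.alpha * (y - self.predict(x))
```

The step size α gives more weight to recent examples, which addresses the point above that newer examples are often better approximations than older ones.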

In deep reinforcement learning, a deep neural network is used as the learner. In particular, a neural network can be used to represent the Q-function, the value function, and/or the policy. Deep learning requires a large amount of data, and many iterations to learn, and can be sensitive to the architecture provided. While it has been very successful in games such as Go or Chess (see Section 14.7.3), it is notoriously difficult to make it work, and it is very computationally intensive. A linear function is usually better for smaller problems.

13.9.1 SARSA with Linear Function Approximation

Consider an instance of SARSA with generalization (Figure 13.7) in which the learner is a linear function of features of the state and the action. While there are more complicated alternatives, such as a decision tree or a neural network, the linear function often works well, but requires feature engineering.

The feature-based learners require more information about the domain than the reinforcement-learning methods considered so far. Whereas the previous reinforcement learners were provided only with the states and the possible actions, the feature-based learners require extra domain knowledge in terms of features. This approach requires careful selection of the features; the designer should find features adequate to represent the Q-function.

The algorithm SARSA with linear function approximation, SARSA_LFA, uses a linear function of features to approximate the Q-function. It is based on incremental gradient descent, a variant of stochastic gradient descent that updates the parameters after every example. Suppose F1, …, Fn are numerical features of the state and the action. Fi(s,a) provides the value for the ith feature for state s and action a. These features will be used to represent the linear Q-function

$$Q_{\overline{w}}(s,a) = w_0 + w_1 F_1(s,a) + \cdots + w_n F_n(s,a)$$

for some tuple of weights w̄ = ⟨w0, w1, …, wn⟩ that have to be learned. Assume that there is an extra feature F0(s,a) whose value is always 1, so that w0 is not a special case.
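
In code, with the features given as a list of functions of (s,a), Qw̄ is just a dot product; the leading 1 in the value list plays the role of F0. This is a minimal sketch matching the notation above, not the AIPython implementation.

```python
def linear_q(weights, features, s, a):
    """Compute Qw(s,a) = w0 + w1*F1(s,a) + ... + wn*Fn(s,a).

    weights has length len(features) + 1; weights[0] multiplies the
    constant feature F0(s,a) = 1."""
    values = [1.0] + [f(s, a) for f in features]
    return sum(w * v for w, v in zip(weights, values))
```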

An experience in SARSA of the form ⟨s, a, r, s′, a′⟩ (the agent was in state s, did action a, received reward r, and ended up in state s′, in which it decided to do action a′) provides the new estimate r + γ·Qw̄(s′,a′) to update Qw̄(s,a). This experience can be used as a data point for linear regression. Let δ = Qw̄(s,a) − (r + γ·Qw̄(s′,a′)). Using Equation 7.4, weight wi is updated by

$$w_i := w_i - \eta\,\delta\,F_i(s,a).$$
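
To see why this is a gradient-descent step, treat the target r + γ·Qw̄(s′,a′) as a constant and differentiate the squared error, written with the conventional factor of 1/2 so that the 2 from the derivative cancels:

$$\frac{\partial}{\partial w_i}\,\tfrac{1}{2}\,\delta^{2} \;=\; \delta\,\frac{\partial Q_{\overline{w}}(s,a)}{\partial w_i} \;=\; \delta\,F_i(s,a),$$

so moving wi against this gradient with step size η gives exactly the update above.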

This update can then be incorporated into SARSA, giving the algorithm shown in Figure 13.8.

1: controller SARSA_LFA(F̄, γ, η)
2:   Inputs
3:    F̄ = ⟨F1,…,Fn⟩: a set of features. Define F0(s,a) = 1.
4:    γ ∈ [0,1]: discount factor
5:    η > 0: step size for gradient descent
6:   Local
7:    weights w̄ = ⟨w0,…,wn⟩, initialized arbitrarily
8:   observe current state s
9:   select action a
10:   repeat
11:    do(a)
12:    observe reward r and state s′
13:    select action a′ (using a policy based on Qw̄)
14:    δ := Qw̄(s,a) − (r + γ·Qw̄(s′,a′))
15:    for i = 0 to n do
16:      wi := wi − η·δ·Fi(s,a)
17:    s := s′
18:    a := a′
19:   until termination
Figure 13.8: SARSA_LFA: SARSA with linear function approximation

Although this program is simple to implement, feature engineering – choosing what features to include – is non-trivial. The linear function must not only convey the best action to carry out, it must also convey the information about what future states are useful.
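
The following is a minimal Python sketch of Figure 13.8. The environment interface (env.state, env.actions, env.do) and the ε-greedy exploration are the same assumptions as in the earlier sketches and are not part of the algorithm itself.

```python
import random

def sarsa_lfa(env, features, gamma, eta, epsilon=0.1, steps=10000):
    """SARSA with linear function approximation (Figure 13.8).

    features is a list of functions F1..Fn of (state, action); the constant
    feature F0 = 1 is handled internally. Returns the learned weights."""
    w = [0.0] * (len(features) + 1)        # w0..wn; here initialized to zero

    def feature_values(s, a):
        return [1.0] + [f(s, a) for f in features]

    def q(s, a):
        return sum(wi * v for wi, v in zip(w, feature_values(s, a)))

    def select_action(s):
        if random.random() < epsilon:
            return random.choice(env.actions)
        return max(env.actions, key=lambda a: q(s, a))

    s = env.state
    a = select_action(s)
    for _ in range(steps):
        r, s1 = env.do(a)                  # observe reward r and state s'
        a1 = select_action(s1)             # a' from a policy based on Qw
        delta = q(s, a) - (r + gamma * q(s1, a1))
        for i, fi in enumerate(feature_values(s, a)):
            w[i] -= eta * delta * fi       # wi := wi - eta*delta*Fi(s,a)
        s, a = s1, a1
    return w
```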

Example 13.6.

Consider the monster game of Example 13.2. From understanding the domain, and not just treating it as a black box, some possible features that can be computed and might be useful are

  • F1(s,a) has value 1 if action a would most likely take the agent from state s into a location where a monster could appear and has value 0 otherwise.

  • F2(s,a) has value 1 if action a would most likely take the agent into a wall and has value 0 otherwise.

  • F3(s,a) has value 1 if step a would most likely take the agent toward a prize.

  • F4(s,a) has value 1 if the agent is damaged in state s and action a takes it toward the repair station.

  • F5(s,a) has value 1 if the agent is damaged and action a would most likely take the agent into a location where a monster could appear and has value 0 otherwise. That is, it is the same as F1(s,a) but is only applicable when the agent is damaged.

  • F6(s,a) has value 1 if the agent is damaged in state s and has value 0 otherwise.

  • F7(s,a) has value 1 if the agent is not damaged in state s and has value 0 otherwise.

  • F8(s,a) has value 1 if the agent is damaged and there is a prize ahead in direction a.

  • F9(s,a) has value 1 if the agent is not damaged and there is a prize ahead in direction a.

  • F10(s,a) has the value of the x-value in state s if there is a prize at location P0 in state s. That is, it is the distance from the left wall if there is a prize at location P0.

  • F11(s,a) has the value 4−x, where x is the horizontal position in state s if there is a prize at location P0 in state s. That is, it is the distance from the right wall if there is a prize at location P0.

  • F12(s,a) to F29(s,a) are like F10 and F11 for different combinations of the prize location and the distance from each of the four walls. For the case where the prize is at location P0, the y-distance could take into account the wall.

An example linear function is

$$\begin{aligned}
Q(s,a) = 2.0 &- 1.0\,F_1(s,a) - 0.4\,F_2(s,a) - 1.3\,F_3(s,a) \\
             &- 0.5\,F_4(s,a) - 1.2\,F_5(s,a) - 1.6\,F_6(s,a) + 3.5\,F_7(s,a) \\
             &+ 0.6\,F_8(s,a) + 0.6\,F_9(s,a) - 0.0\,F_{10}(s,a) + 1.0\,F_{11}(s,a) + \cdots
\end{aligned}$$

These are the learned values (to one decimal place) for one run of the SARSA_LFA algorithm in Figure 13.8.
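
Features like these are just Python functions of a state and an action. The sketch below shows how a few of them might be encoded; the state representation (a namedtuple with x, y, damaged, and prize_at_p0 attributes) is hypothetical, chosen for illustration, and does not match the AIPython representation of the game.

```python
from collections import namedtuple

# Hypothetical state representation for illustration only.
State = namedtuple("State", ["x", "y", "damaged", "prize_at_p0"])

def F6(s, a):
    # 1 if the agent is damaged in state s, 0 otherwise
    return 1 if s.damaged else 0

def F7(s, a):
    # 1 if the agent is not damaged in state s, 0 otherwise
    return 1 - F6(s, a)

def F10(s, a):
    # distance from the left wall (the x coordinate) when a prize is at P0
    return s.x if s.prize_at_p0 else 0

def F11(s, a):
    # distance from the right wall when a prize is at P0 (assuming x in 0..4)
    return (4 - s.x) if s.prize_at_p0 else 0
```

A list such as [F6, F7, F10, F11], extended with the remaining features, can be passed directly as the features argument of the sarsa_lfa sketch above.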

AIPython (aipython.org) has an open-source Python implementation of this algorithm for this monster game. Experiment with stepping through the algorithm for individual steps, trying to understand how each step updates each parameter. Now run it for a number of steps. Consider the performance using the evaluation measures of Section 13.6. Try to make sense of the values of the parameters learned.

This algorithm tends to overfit to recent experiences and to forget older ones, so that when it returns to a part of the state space it has not visited recently, it has to relearn all over again. This is known as catastrophic forgetting. One modification is to remember old experiences ((s, a, r, s′) tuples) and to carry out some steps of experience replay, doing some weight updates based on randomly chosen previous experiences. Updating the weights requires the next action a′, which should be chosen according to the current policy, not the policy that was in effect when the experience was collected. When memory size becomes an issue, some of the old experiences can be discarded.
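
A sketch of this modification is shown below. It assumes access to the weight list w and to feature_values, q, and select_action of the same kind as in the sarsa_lfa sketch (there they are local, so a real implementation would restructure slightly); the buffer size and the number of replayed experiences per step are arbitrary choices.

```python
import random
from collections import deque

# Replay buffer of (s, a, r, s') tuples; the oldest experiences are discarded
# automatically once the buffer is full.
buffer = deque(maxlen=100_000)

def replay_updates(w, feature_values, q, select_action, gamma, eta,
                   num_replays=10):
    """Extra weight updates from randomly chosen stored experiences.

    The next action a' is chosen with the current policy, not the policy
    that was in effect when the experience was stored."""
    for s, a, r, s1 in random.sample(list(buffer), min(num_replays, len(buffer))):
        a1 = select_action(s1)
        delta = q(s, a) - (r + gamma * q(s1, a1))
        values = feature_values(s, a)
        for i in range(len(w)):
            w[i] -= eta * delta * values[i]
```

In the main loop, each experience would be stored with buffer.append((s, a, r, s1)), and replay_updates called every step or every few steps.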

13.9.2 Escaping Local Optima

State-based MDPs and state-based reinforcement learning algorithms such as Q-learning, SARSA, and the model-based reinforcement learner have no local maxima that are not global maxima. This is because each state can be optimized separately; improving a policy for one state cannot negatively impact another state.

However, when there is generalization, improving the prediction for one state can make other states worse. This means that the algorithms can converge to local optima whose value is not the best possible. They can work better when there is some way to escape local optima. A standard way to escape local maxima is to use randomized algorithms, for example population-based methods, similar to particle filtering, where multiple initializations are run in parallel and the best policy is chosen. There have been some notable, arguably creative, solutions found using evolutionary algorithms, where the individual runs are combined using a genetic algorithm.
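
As a simple illustration of the restart idea, several independent runs of the sarsa_lfa sketch above can be carried out and the best resulting policy kept; the runs differ because of the random exploration and could also use different initial weights. The make_env and evaluate functions (for example, average discounted return over some held-out episodes) are assumptions of the sketch.

```python
def best_of_runs(make_env, features, gamma, eta, evaluate, num_runs=10):
    """Run SARSA_LFA several times and keep the weights whose greedy policy
    scores highest under evaluate(weights)."""
    best_w, best_score = None, float("-inf")
    for _ in range(num_runs):
        w = sarsa_lfa(make_env(), features, gamma, eta)
        score = evaluate(w)            # e.g., average return over test episodes
        if score > best_score:
            best_w, best_score = w, score
    return best_w
```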