actions in qtable changed to objects

master
Ugo Finnendahl 5 years ago
parent 8bd9c482ef
commit eb13a10c3a
  1. 7
      js/controls.js
  2. 63
      js/rl.js
  3. 3
      js/view.js

7
js/controls.js vendored

@ -1,8 +1,7 @@
function dir_to_action(dir){ function dir_to_action(dir){
let actions = [...Array(machine.q_table[machine.state].length).keys()].map((a) => maze.get_direction(machine.state,a)); let actions = [...Object.keys(machine.q_table[machine.state])];
var action = actions.indexOf(dir); if (actions.indexOf(dir) > -1){
if (action>-1){ return dir;
return action;
} }
return undefined; return undefined;
} }

@ -8,7 +8,7 @@ class RL_machine {
learning_rate, learning_rate,
discount_factor, discount_factor,
epsilon=0) { epsilon=0) {
this.q_table = actions_per_state.map((c) => Array(c).fill(0)); this.q_table = actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{}));
this.transactions = transactions; this.transactions = transactions;
this.rewards = rewards; this.rewards = rewards;
this.lr = learning_rate; this.lr = learning_rate;
@ -34,13 +34,13 @@ class RL_machine {
} }
auto_step(){ auto_step(){
if (Math.random() < this.epsilon){ if (Math.random() < this.epsilon){
return this.step(Math.floor(Math.random() * this.q_table[this.state].length)); return this.step(choose(Object.keys(this.q_table[this.state])));
} else{ } else{
return this.greedy_step(); return this.greedy_step();
} }
} }
greedy_step(){ greedy_step(){
return this.step(argMax(this.q_table[this.state])); return this.step(keyMax(this.q_table[this.state]));
} }
step(action){ step(action){
this.state = this.update_q_table(this.state, action); this.state = this.update_q_table(this.state, action);
@ -53,7 +53,7 @@ class RL_machine {
} }
update_q_table(state, action){ update_q_table(state, action){
let new_state = this.transactions(state, action); let new_state = this.transactions(state, action);
this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...this.q_table[new_state])); this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...Object.values(this.q_table[new_state])));
this.score += this.rewards[new_state]; this.score += this.rewards[new_state];
return new_state; return new_state;
} }
@ -69,9 +69,15 @@ class RL_machine {
} }
} }
function keyMax(obj) {
return Object.entries(obj).reduce((r, a) => (a[1] > r[1] ? a : r),[0,Number.MIN_SAFE_INTEGER])[0];
}
function argMax(array) { function argMax(array) {
return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1];
} }
function choose(array) {
return array[array.length * Math.random() << 0];
};
// ------------------ maze stuff -------------------------------------------- // ------------------ maze stuff --------------------------------------------
const tile = { const tile = {
@ -82,10 +88,10 @@ const tile = {
}; };
const dir = { const dir = {
UP: 0, UP: "UP",
RIGHT: 1, RIGHT: "RIGHT",
DOWN: 2, DOWN: "DOWN",
LEFT: 3, LEFT: "LEFT",
}; };
class Maze { class Maze {
@ -113,54 +119,29 @@ class Maze {
get_actions() { get_actions() {
var actions = []; var actions = [];
for (let idy=0; idy<this.map.length; idy++){ for (let idy=0; idy<this.map.length; idy++){
let y_actions = 0; var y_actions = [];
if (idy != 0){ if (idy != 0){
y_actions++; y_actions.push(dir.UP);
} }
if (idy != this.map.length-1){ if (idy != this.map.length-1){
y_actions++; y_actions.push(dir.DOWN);
} }
for (let idx=0; idx<this.map[0].length; idx++){ for (let idx=0; idx<this.map[0].length; idx++){
let actions_sum = y_actions; var x_actions = [];
if (idx != 0){ if (idx != 0){
actions_sum++; x_actions.push(dir.LEFT);
} }
if (idx != this.map[0].length-1){ if (idx != this.map[0].length-1){
actions_sum++; x_actions.push(dir.RIGHT);
} }
actions.push(actions_sum); actions.push([...y_actions,...x_actions]);
} }
} }
return actions; return actions;
} }
get_direction(state, action){
const y = Math.floor(state/this.width);
const x = state%this.width;
var h_flip = dir.RIGHT;
var flex = dir.DOWN
if (x == this.width-1){
h_flip = dir.LEFT;
}
if (y == this.height-1){
flex = dir.LEFT;
} else if (y == 0){
action++;
}
switch (action) {
case 0:
return dir.UP;
case 1:
return h_flip;
case 2:
return flex;
case 3:
return dir.LEFT;
}
}
get_transactions(){ get_transactions(){
return function(state, action){ return function(state, action){
var kk = this.get_direction(state, action); switch (action) {
switch (kk) {
case dir.UP: case dir.UP:
return state-this.width; return state-this.width;
case dir.RIGHT: case dir.RIGHT:

@ -83,11 +83,10 @@ function draw_agent(state) {
map_layer.draw(); map_layer.draw();
}, },
do_action(action, animate=false){ do_action(action, animate=false){
let d = maze.get_direction(machine.state, action);
var y = this.y; var y = this.y;
var x = this.x; var x = this.x;
var fun; var fun;
switch (d) { switch (action) {
case dir.UP: case dir.UP:
y--; y--;
fun = function () {return [this.x,this.y-1]}.bind(this); fun = function () {return [this.x,this.y-1]}.bind(this);

Loading…
Cancel
Save