diff --git a/js/controls.js b/js/controls.js index 0d57b76..6824b3e 100644 --- a/js/controls.js +++ b/js/controls.js @@ -1,8 +1,7 @@ function dir_to_action(dir){ - let actions = [...Array(machine.q_table[machine.state].length).keys()].map((a) => maze.get_direction(machine.state,a)); - var action = actions.indexOf(dir); - if (action>-1){ - return action; + let actions = [...Object.keys(machine.q_table[machine.state])]; + if (actions.indexOf(dir) > -1){ + return dir; } return undefined; } diff --git a/js/rl.js b/js/rl.js index 5b276f9..521926f 100644 --- a/js/rl.js +++ b/js/rl.js @@ -8,7 +8,7 @@ class RL_machine { learning_rate, discount_factor, epsilon=0) { - this.q_table = actions_per_state.map((c) => Array(c).fill(0)); + this.q_table = actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{})); this.transactions = transactions; this.rewards = rewards; this.lr = learning_rate; @@ -34,13 +34,13 @@ class RL_machine { } auto_step(){ if (Math.random() < this.epsilon){ - return this.step(Math.floor(Math.random() * this.q_table[this.state].length)); + return this.step(choose(Object.keys(this.q_table[this.state]))); } else{ return this.greedy_step(); } } greedy_step(){ - return this.step(argMax(this.q_table[this.state])); + return this.step(keyMax(this.q_table[this.state])); } step(action){ this.state = this.update_q_table(this.state, action); @@ -53,7 +53,7 @@ class RL_machine { } update_q_table(state, action){ let new_state = this.transactions(state, action); - this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...this.q_table[new_state])); + this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...Object.values(this.q_table[new_state]))); this.score += this.rewards[new_state]; return new_state; } @@ -69,9 +69,15 @@ class RL_machine { } } +function keyMax(obj) { + return Object.entries(obj).reduce((r, a) => (a[1] > r[1] ? a : r),[0,Number.MIN_SAFE_INTEGER])[0]; +} function argMax(array) { return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; } +function choose(array) { + return array[array.length * Math.random() << 0]; +}; // ------------------ maze stuff -------------------------------------------- const tile = { @@ -82,10 +88,10 @@ const tile = { }; const dir = { - UP: 0, - RIGHT: 1, - DOWN: 2, - LEFT: 3, + UP: "UP", + RIGHT: "RIGHT", + DOWN: "DOWN", + LEFT: "LEFT", }; class Maze { @@ -113,54 +119,29 @@ class Maze { get_actions() { var actions = []; for (let idy=0; idy