|
|
@ -8,7 +8,7 @@ class RL_machine { |
|
|
|
learning_rate, |
|
|
|
learning_rate, |
|
|
|
discount_factor, |
|
|
|
discount_factor, |
|
|
|
epsilon=0) { |
|
|
|
epsilon=0) { |
|
|
|
this.q_table = actions_per_state.map((c) => Array(c).fill(0)); |
|
|
|
this.q_table = actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{})); |
|
|
|
this.transactions = transactions; |
|
|
|
this.transactions = transactions; |
|
|
|
this.rewards = rewards; |
|
|
|
this.rewards = rewards; |
|
|
|
this.lr = learning_rate; |
|
|
|
this.lr = learning_rate; |
|
|
@ -34,13 +34,13 @@ class RL_machine { |
|
|
|
} |
|
|
|
} |
|
|
|
auto_step(){ |
|
|
|
auto_step(){ |
|
|
|
if (Math.random() < this.epsilon){ |
|
|
|
if (Math.random() < this.epsilon){ |
|
|
|
return this.step(Math.floor(Math.random() * this.q_table[this.state].length)); |
|
|
|
return this.step(choose(Object.keys(this.q_table[this.state]))); |
|
|
|
} else{ |
|
|
|
} else{ |
|
|
|
return this.greedy_step(); |
|
|
|
return this.greedy_step(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
greedy_step(){ |
|
|
|
greedy_step(){ |
|
|
|
return this.step(argMax(this.q_table[this.state])); |
|
|
|
return this.step(keyMax(this.q_table[this.state])); |
|
|
|
} |
|
|
|
} |
|
|
|
step(action){ |
|
|
|
step(action){ |
|
|
|
this.state = this.update_q_table(this.state, action); |
|
|
|
this.state = this.update_q_table(this.state, action); |
|
|
@ -53,7 +53,7 @@ class RL_machine { |
|
|
|
} |
|
|
|
} |
|
|
|
update_q_table(state, action){ |
|
|
|
update_q_table(state, action){ |
|
|
|
let new_state = this.transactions(state, action); |
|
|
|
let new_state = this.transactions(state, action); |
|
|
|
this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...this.q_table[new_state])); |
|
|
|
this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...Object.values(this.q_table[new_state]))); |
|
|
|
this.score += this.rewards[new_state]; |
|
|
|
this.score += this.rewards[new_state]; |
|
|
|
return new_state; |
|
|
|
return new_state; |
|
|
|
} |
|
|
|
} |
|
|
@ -69,9 +69,15 @@ class RL_machine { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function keyMax(obj) { |
|
|
|
|
|
|
|
return Object.entries(obj).reduce((r, a) => (a[1] > r[1] ? a : r),[0,Number.MIN_SAFE_INTEGER])[0]; |
|
|
|
|
|
|
|
} |
|
|
|
function argMax(array) { |
|
|
|
function argMax(array) { |
|
|
|
return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; |
|
|
|
return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
function choose(array) { |
|
|
|
|
|
|
|
return array[array.length * Math.random() << 0]; |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
// ------------------ maze stuff --------------------------------------------
|
|
|
|
// ------------------ maze stuff --------------------------------------------
|
|
|
|
const tile = { |
|
|
|
const tile = { |
|
|
@ -82,10 +88,10 @@ const tile = { |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
const dir = { |
|
|
|
const dir = { |
|
|
|
UP: 0, |
|
|
|
UP: "UP", |
|
|
|
RIGHT: 1, |
|
|
|
RIGHT: "RIGHT", |
|
|
|
DOWN: 2, |
|
|
|
DOWN: "DOWN", |
|
|
|
LEFT: 3, |
|
|
|
LEFT: "LEFT", |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
class Maze { |
|
|
|
class Maze { |
|
|
@ -113,54 +119,29 @@ class Maze { |
|
|
|
get_actions() { |
|
|
|
get_actions() { |
|
|
|
var actions = []; |
|
|
|
var actions = []; |
|
|
|
for (let idy=0; idy<this.map.length; idy++){ |
|
|
|
for (let idy=0; idy<this.map.length; idy++){ |
|
|
|
let y_actions = 0; |
|
|
|
var y_actions = []; |
|
|
|
if (idy != 0){ |
|
|
|
if (idy != 0){ |
|
|
|
y_actions++; |
|
|
|
y_actions.push(dir.UP); |
|
|
|
} |
|
|
|
} |
|
|
|
if (idy != this.map.length-1){ |
|
|
|
if (idy != this.map.length-1){ |
|
|
|
y_actions++; |
|
|
|
y_actions.push(dir.DOWN); |
|
|
|
} |
|
|
|
} |
|
|
|
for (let idx=0; idx<this.map[0].length; idx++){ |
|
|
|
for (let idx=0; idx<this.map[0].length; idx++){ |
|
|
|
let actions_sum = y_actions; |
|
|
|
var x_actions = []; |
|
|
|
if (idx != 0){ |
|
|
|
if (idx != 0){ |
|
|
|
actions_sum++; |
|
|
|
x_actions.push(dir.LEFT); |
|
|
|
} |
|
|
|
} |
|
|
|
if (idx != this.map[0].length-1){ |
|
|
|
if (idx != this.map[0].length-1){ |
|
|
|
actions_sum++; |
|
|
|
x_actions.push(dir.RIGHT); |
|
|
|
} |
|
|
|
} |
|
|
|
actions.push(actions_sum); |
|
|
|
actions.push([...y_actions,...x_actions]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return actions; |
|
|
|
return actions; |
|
|
|
} |
|
|
|
} |
|
|
|
get_direction(state, action){ |
|
|
|
|
|
|
|
const y = Math.floor(state/this.width); |
|
|
|
|
|
|
|
const x = state%this.width; |
|
|
|
|
|
|
|
var h_flip = dir.RIGHT; |
|
|
|
|
|
|
|
var flex = dir.DOWN |
|
|
|
|
|
|
|
if (x == this.width-1){ |
|
|
|
|
|
|
|
h_flip = dir.LEFT; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (y == this.height-1){ |
|
|
|
|
|
|
|
flex = dir.LEFT; |
|
|
|
|
|
|
|
} else if (y == 0){ |
|
|
|
|
|
|
|
action++; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
switch (action) { |
|
|
|
|
|
|
|
case 0: |
|
|
|
|
|
|
|
return dir.UP; |
|
|
|
|
|
|
|
case 1: |
|
|
|
|
|
|
|
return h_flip; |
|
|
|
|
|
|
|
case 2: |
|
|
|
|
|
|
|
return flex; |
|
|
|
|
|
|
|
case 3: |
|
|
|
|
|
|
|
return dir.LEFT; |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
get_transactions(){ |
|
|
|
get_transactions(){ |
|
|
|
return function(state, action){ |
|
|
|
return function(state, action){ |
|
|
|
var kk = this.get_direction(state, action); |
|
|
|
switch (action) { |
|
|
|
switch (kk) { |
|
|
|
|
|
|
|
case dir.UP: |
|
|
|
case dir.UP: |
|
|
|
return state-this.width; |
|
|
|
return state-this.width; |
|
|
|
case dir.RIGHT: |
|
|
|
case dir.RIGHT: |
|
|
|