commit
cfe6fe5734
@ -0,0 +1,19 @@ |
|||||||
|
*{ |
||||||
|
margin: 0; |
||||||
|
padding: 0; |
||||||
|
} |
||||||
|
body{ |
||||||
|
// background-color: ; |
||||||
|
} |
||||||
|
#container{ |
||||||
|
height: 100vh; |
||||||
|
position: relative; |
||||||
|
} |
||||||
|
#canvas{ |
||||||
|
height: 100%; |
||||||
|
} |
||||||
|
nav{ |
||||||
|
position: absolute; |
||||||
|
top: 10px; |
||||||
|
left: 10px; |
||||||
|
} |
@ -0,0 +1 @@ |
|||||||
|
*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px} |
@ -0,0 +1,22 @@ |
|||||||
|
<!DOCTYPE html> |
||||||
|
<html lang="en" dir="ltr"> |
||||||
|
<head> |
||||||
|
<meta charset="utf-8"> |
||||||
|
<script src="https://unpkg.com/konva@4.0.0/konva.min.js"></script> |
||||||
|
<title>RL exhibit - prototype</title> |
||||||
|
<link rel="stylesheet" href="css/style.min.css"> |
||||||
|
</head> |
||||||
|
<body> |
||||||
|
<div id="container"> |
||||||
|
<div id="canvas"></div> |
||||||
|
<nav> |
||||||
|
<button class="button" onclick="machine.run(100)">run 100 episodes!</button> |
||||||
|
<button class="button" onclick="machine.auto_step();draw_map(map, machine.state);">auto step!</button> |
||||||
|
<button class="button" onclick="machine.greedy_step();draw_map(map, machine.state);">greedy step!</button> |
||||||
|
</nav> |
||||||
|
</div> |
||||||
|
<script src="js/rl.js"></script> |
||||||
|
<script src="js/view.js"></script> |
||||||
|
<script src="js/controls.js"></script> |
||||||
|
</body> |
||||||
|
</html> |
@ -0,0 +1,50 @@ |
|||||||
|
function dir_to_action(dir){ |
||||||
|
let actions = [...Array(machine.q_table[machine.state].length).keys()].map((a) => maze.get_direction(machine.state,a)); |
||||||
|
var action = actions.indexOf(dir); |
||||||
|
if (action>-1){ |
||||||
|
return action; |
||||||
|
} |
||||||
|
return undefined; |
||||||
|
} |
||||||
|
var animate = false; |
||||||
|
|
||||||
|
function key_callback(e) { |
||||||
|
var tmp; |
||||||
|
if (animate){ |
||||||
|
return |
||||||
|
} |
||||||
|
switch (e.keyCode) { |
||||||
|
case 37: |
||||||
|
tmp = dir_to_action(dir.LEFT); |
||||||
|
break; |
||||||
|
case 38: |
||||||
|
tmp = dir_to_action(dir.UP); |
||||||
|
break; |
||||||
|
case 39: |
||||||
|
tmp = dir_to_action(dir.RIGHT); |
||||||
|
break; |
||||||
|
case 40: |
||||||
|
tmp = dir_to_action(dir.DOWN); |
||||||
|
break; |
||||||
|
} |
||||||
|
if (tmp != undefined){ |
||||||
|
machine.step(tmp) |
||||||
|
} |
||||||
|
draw_map(map, machine.state); |
||||||
|
} |
||||||
|
document.addEventListener('keydown', key_callback); |
||||||
|
|
||||||
|
function show_solution() { |
||||||
|
var sol = machine.current_solution(); |
||||||
|
animate = true; |
||||||
|
show_path(sol.states, 0); |
||||||
|
} |
||||||
|
|
||||||
|
function show_path(path, i){ |
||||||
|
if (path.length == i) { |
||||||
|
animate = false; |
||||||
|
return |
||||||
|
} |
||||||
|
agent.set_state(path[i]); |
||||||
|
window.setTimeout(function(){ show_path(path, ++i) }, 1000); |
||||||
|
} |
@ -0,0 +1,220 @@ |
|||||||
|
class RL_machine { |
||||||
|
constructor(actions_per_state, |
||||||
|
transactions, |
||||||
|
rewards, |
||||||
|
start_state, |
||||||
|
end_states, |
||||||
|
end_score, |
||||||
|
learning_rate, |
||||||
|
discount_factor, |
||||||
|
epsilon=0) { |
||||||
|
this.q_table = actions_per_state.map((c) => Array(c).fill(0)); |
||||||
|
this.transactions = transactions; |
||||||
|
this.rewards = rewards; |
||||||
|
this.lr = learning_rate; |
||||||
|
this.df = discount_factor; |
||||||
|
this.state = start_state; |
||||||
|
this.start_state = start_state; |
||||||
|
this.end_score = end_score; |
||||||
|
this.end_states = end_states; |
||||||
|
this.episode = 0; |
||||||
|
this.epsilon = epsilon; |
||||||
|
this.score = 0; |
||||||
|
} |
||||||
|
reset_machine(){ |
||||||
|
this.q_table = this.q_table.map((c) => c.map((a) => a.fill(0))); |
||||||
|
this.episode = 0; |
||||||
|
this.state = this.start_state; |
||||||
|
} |
||||||
|
new_episode(){ |
||||||
|
// add_new_episode_callback
|
||||||
|
this.episode++; |
||||||
|
this.state = this.start_state; |
||||||
|
this.score = 0; |
||||||
|
} |
||||||
|
auto_step(){ |
||||||
|
if (Math.random() < this.epsilon){ |
||||||
|
return this.step(Math.floor(Math.random() * this.q_table[this.state].length)); |
||||||
|
} else{ |
||||||
|
return this.greedy_step(); |
||||||
|
} |
||||||
|
} |
||||||
|
greedy_step(){ |
||||||
|
return this.step(argMax(this.q_table[this.state])); |
||||||
|
} |
||||||
|
step(action){ |
||||||
|
this.state = this.update_q_table(this.state, action); |
||||||
|
// add_new_step_callback
|
||||||
|
if (this.end_states.indexOf(this.state) >= 0 || this.score < this.end_score){ |
||||||
|
this.new_episode(); |
||||||
|
return 2 |
||||||
|
} |
||||||
|
return 1 |
||||||
|
} |
||||||
|
update_q_table(state, action){ |
||||||
|
let new_state = this.transactions(state, action); |
||||||
|
this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...this.q_table[new_state])); |
||||||
|
this.score += this.rewards[new_state]; |
||||||
|
return new_state; |
||||||
|
} |
||||||
|
run(episodes, max_steps_per_episode=10000){ |
||||||
|
for (var i = 0; i < episodes; i++) { |
||||||
|
for (var j = 0; j < max_steps_per_episode; j++) { |
||||||
|
if (this.auto_step() == 2) { |
||||||
|
break; |
||||||
|
} |
||||||
|
} |
||||||
|
this.new_episode(); |
||||||
|
} |
||||||
|
} |
||||||
|
current_solution(max_steps_per_episode=10000){ |
||||||
|
let temp_state = this.start_state; |
||||||
|
let score = 0; |
||||||
|
let states = [temp_state]; |
||||||
|
let actions = []; |
||||||
|
let scores = []; |
||||||
|
for (var j = 0; j < max_steps_per_episode; j++) { |
||||||
|
let ac = argMax(this.q_table[temp_state]); |
||||||
|
temp_state = this.transactions(temp_state, ac); |
||||||
|
let sc = this.rewards[temp_state]; |
||||||
|
score += sc; |
||||||
|
actions.push(ac); |
||||||
|
scores.push(sc); |
||||||
|
states.push(temp_state); |
||||||
|
if (this.end_states.indexOf(temp_state) >= 0 || score < this.end_score){ |
||||||
|
return {actions: actions, scores: scores, states: states} |
||||||
|
} |
||||||
|
} |
||||||
|
return {actions: actions, scores: scores, states: states} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
function argMax(array) { |
||||||
|
return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; |
||||||
|
} |
||||||
|
|
||||||
|
// ------------------ maze stuff --------------------------------------------
|
||||||
|
const tile = { |
||||||
|
regular: 0, |
||||||
|
start: 1, |
||||||
|
end: 2, |
||||||
|
dangerous: 4, |
||||||
|
}; |
||||||
|
|
||||||
|
const dir = { |
||||||
|
UP: 0, |
||||||
|
RIGHT: 1, |
||||||
|
DOWN: 2, |
||||||
|
LEFT: 3, |
||||||
|
}; |
||||||
|
|
||||||
|
class Maze { |
||||||
|
constructor(map, reward_map) { |
||||||
|
this.map = map |
||||||
|
this.height = map.length; |
||||||
|
this.width = map[0].length; |
||||||
|
this.start_state = this.get_states(tile.start)[0]; |
||||||
|
this.end_states = this.get_states(tile.end); |
||||||
|
this.actions = this.get_actions(); |
||||||
|
this.transactions = this.get_transactions(); |
||||||
|
this.rewards = this.get_rewards(reward_map); |
||||||
|
} |
||||||
|
get_states(tile) { |
||||||
|
var res = []; |
||||||
|
for (var idy = 0; idy < this.map.length; idy++) { |
||||||
|
for (var idx = 0; idx < this.map[idy].length; idx++) { |
||||||
|
if (this.map[idy][idx] == tile) { |
||||||
|
res.push(idy*this.map[0].length+idx); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
return res; |
||||||
|
} |
||||||
|
get_actions() { |
||||||
|
var actions = []; |
||||||
|
for (let idy=0; idy<this.map.length; idy++){ |
||||||
|
let y_actions = 0; |
||||||
|
if (idy != 0){ |
||||||
|
y_actions++; |
||||||
|
} |
||||||
|
if (idy != this.map.length-1){ |
||||||
|
y_actions++; |
||||||
|
} |
||||||
|
for (let idx=0; idx<this.map[0].length; idx++){ |
||||||
|
let actions_sum = y_actions; |
||||||
|
if (idx != 0){ |
||||||
|
actions_sum++; |
||||||
|
} |
||||||
|
if (idx != this.map[0].length-1){ |
||||||
|
actions_sum++; |
||||||
|
} |
||||||
|
actions.push(actions_sum); |
||||||
|
} |
||||||
|
} |
||||||
|
return actions; |
||||||
|
} |
||||||
|
get_direction(state, action){ |
||||||
|
const y = Math.floor(state/this.width); |
||||||
|
const x = state%this.width; |
||||||
|
var h_flip = dir.RIGHT; |
||||||
|
var flex = dir.DOWN |
||||||
|
if (x == this.width-1){ |
||||||
|
h_flip = dir.LEFT; |
||||||
|
} |
||||||
|
if (y == this.height-1){ |
||||||
|
flex = dir.LEFT; |
||||||
|
} else if (y == 0){ |
||||||
|
action++; |
||||||
|
} |
||||||
|
switch (action) { |
||||||
|
case 0: |
||||||
|
return dir.UP; |
||||||
|
case 1: |
||||||
|
return h_flip; |
||||||
|
case 2: |
||||||
|
return flex; |
||||||
|
case 3: |
||||||
|
return dir.LEFT; |
||||||
|
} |
||||||
|
} |
||||||
|
get_transactions(){ |
||||||
|
return function(state, action){ |
||||||
|
var kk = this.get_direction(state, action); |
||||||
|
switch (kk) { |
||||||
|
case dir.UP: |
||||||
|
return state-this.width; |
||||||
|
case dir.RIGHT: |
||||||
|
return state+1; |
||||||
|
case dir.DOWN: |
||||||
|
return state+this.width; |
||||||
|
case dir.LEFT: |
||||||
|
return state-1; |
||||||
|
} |
||||||
|
}.bind(this); |
||||||
|
} |
||||||
|
get_rewards(rewards){ |
||||||
|
rewards = []; |
||||||
|
for (let idy=0; idy<this.map.length; idy++){ |
||||||
|
for (let idx=0; idx<this.map[0].length; idx++){ |
||||||
|
rewards.push(reward[this.map[idy][idx]]); |
||||||
|
} |
||||||
|
} |
||||||
|
return rewards; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
var map = [ |
||||||
|
[0, 0, 4, 2, 0, 0, 0, 0], |
||||||
|
[0, 0, 4, 4, 4, 4, 0, 0], |
||||||
|
[4, 0, 0, 0, 0, 4, 0, 4], |
||||||
|
[0, 0, 4, 0, 0, 0, 0, 0], |
||||||
|
[1, 0, 4, 0, 4, 0, 0, 4] |
||||||
|
]; |
||||||
|
|
||||||
|
const reward = {[tile.regular]:-1,[tile.dangerous]:-1000,[tile.end]:1000,[tile.start]:-1}; |
||||||
|
var maze = new Maze(map, reward); |
||||||
|
|
||||||
|
var learning_rate = 1; |
||||||
|
var discount_factor = 1; |
||||||
|
|
||||||
|
var machine = new RL_machine(maze.actions, maze.transactions, maze.rewards, maze.start_state, maze.end_states, -999, learning_rate, discount_factor, 0.5); |
@ -0,0 +1,170 @@ |
|||||||
|
const canvas = document.getElementById("canvas"); |
||||||
|
|
||||||
|
var canvas_width = canvas.offsetWidth; |
||||||
|
var canvas_height = canvas.offsetHeight; |
||||||
|
|
||||||
|
function sort(array) { |
||||||
|
return array.sort(function(a, b) { |
||||||
|
return a - b; |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
const grid_line = { |
||||||
|
stroke: '#ddd', |
||||||
|
} |
||||||
|
|
||||||
|
// first we need to create a stage
|
||||||
|
var stage = new Konva.Stage({ |
||||||
|
container: canvas, |
||||||
|
}); |
||||||
|
|
||||||
|
var agent; |
||||||
|
var map_layer = new Konva.Layer(); |
||||||
|
var map_group = new Konva.Group(); |
||||||
|
var grid_group = new Konva.Group(); |
||||||
|
var tile_group = new Konva.Group(); |
||||||
|
var agent_group = new Konva.Group(); |
||||||
|
|
||||||
|
map_layer.add(map_group); |
||||||
|
stage.add(map_layer); |
||||||
|
|
||||||
|
function init_stage() { |
||||||
|
stage.width(canvas_width); |
||||||
|
stage.height(canvas_height); |
||||||
|
map_group.width(stage.width() * 0.6); |
||||||
|
map_group.height(stage.height() * 0.9); |
||||||
|
map_group.setX(stage.width() * 0.2); |
||||||
|
map_group.setY(stage.height() / 2 - (stage.height() * 0.45)); |
||||||
|
} |
||||||
|
|
||||||
|
function draw_map(map, state) { |
||||||
|
// cleanup
|
||||||
|
agent_group.remove(); |
||||||
|
agent_group = new Konva.Group(); |
||||||
|
grid_group.remove(); |
||||||
|
grid_group = new Konva.Group(); |
||||||
|
tile_group.remove(); |
||||||
|
tile_group = new Konva.Group(); |
||||||
|
|
||||||
|
init_stage(); |
||||||
|
|
||||||
|
map = map; |
||||||
|
var padding = Math.min(map_group.height() / map.length, map_group.width() / map[0].length); |
||||||
|
var strokeW = 16 / Math.max(map.length, map[0].length); |
||||||
|
const offset = strokeW / 2; |
||||||
|
// x
|
||||||
|
for (let i = 0; i < map[0].length + 1; i++) { |
||||||
|
grid_group.add(new Konva.Line({ |
||||||
|
points: [Math.round(i * padding), -offset, Math.round(i * padding), padding * map.length + offset], |
||||||
|
strokeWidth: strokeW, |
||||||
|
...grid_line |
||||||
|
})); |
||||||
|
grid_group.width(Math.round(i * padding)); |
||||||
|
} |
||||||
|
// y
|
||||||
|
for (let j = 0; j < map.length + 1; j++) { |
||||||
|
grid_group.add(new Konva.Line({ |
||||||
|
points: [-offset, Math.round(j * padding), padding * map[0].length + offset, Math.round(j * padding)], |
||||||
|
strokeWidth: strokeW, |
||||||
|
...grid_line |
||||||
|
})); |
||||||
|
grid_group.height(Math.round(j * padding)); |
||||||
|
} |
||||||
|
|
||||||
|
for (var idy in map) { |
||||||
|
for (var idx in map[idy]) { |
||||||
|
const layout = { |
||||||
|
x: padding * idx + offset, |
||||||
|
y: padding * idy + offset, |
||||||
|
width: padding - 2 * offset, |
||||||
|
height: padding - 2 * offset, |
||||||
|
// stroke: '#CF6412',
|
||||||
|
// strokeWidth: 4
|
||||||
|
}; |
||||||
|
if (map[idy][idx] == tile.dangerous) { |
||||||
|
tile_group.add(new Konva.Rect({ |
||||||
|
...layout, |
||||||
|
fill: '#FF7B17', |
||||||
|
opacity: 1, |
||||||
|
})) |
||||||
|
} else if (map[idy][idx] == tile.end) { |
||||||
|
tile_group.add(new Konva.Rect({ |
||||||
|
...layout, |
||||||
|
fill: '#0eb500', |
||||||
|
opacity: 0.5, |
||||||
|
})) |
||||||
|
} else if (map[idy][idx] == tile.start) { |
||||||
|
tile_group.add(new Konva.Rect({ |
||||||
|
...layout, |
||||||
|
fill: '#ffc908', |
||||||
|
opacity: 0.5, |
||||||
|
})) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
map_group.offset({ |
||||||
|
x: -(map_group.width() - grid_group.width()) / 2, |
||||||
|
y: -(map_group.height() - grid_group.height()) / 2, |
||||||
|
}); |
||||||
|
|
||||||
|
agent = { |
||||||
|
"konva": new Konva.RegularPolygon({ |
||||||
|
offset: { |
||||||
|
x: -padding / 2, |
||||||
|
y: -padding / 2 |
||||||
|
}, |
||||||
|
sides: 5, |
||||||
|
radius: padding / 3, |
||||||
|
fill: '#00D2FF', |
||||||
|
stroke: 'black', |
||||||
|
strokeWidth: 2 |
||||||
|
}), |
||||||
|
get x() { |
||||||
|
return Math.floor(this.konva.getX() / padding); |
||||||
|
}, |
||||||
|
set x(pos) { |
||||||
|
this.konva.setX(sort([0, pos, map[0].length - 1])[1] * padding); |
||||||
|
map_layer.draw(); |
||||||
|
}, |
||||||
|
get y() { |
||||||
|
return Math.floor(this.konva.getY() / padding); |
||||||
|
}, |
||||||
|
set y(pos) { |
||||||
|
this.konva.setY(sort([0, pos, map.length - 1])[1] * padding); |
||||||
|
map_layer.draw(); |
||||||
|
}, |
||||||
|
up() { |
||||||
|
this.y--; |
||||||
|
}, |
||||||
|
down() { |
||||||
|
this.y++; |
||||||
|
}, |
||||||
|
left() { |
||||||
|
this.x++; |
||||||
|
}, |
||||||
|
right() { |
||||||
|
this.x--; |
||||||
|
}, |
||||||
|
set_state(state){ |
||||||
|
this.y = Math.floor(state/map[0].length); |
||||||
|
this.x = state%map[0].length; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
agent.set_state(state); |
||||||
|
agent_group.add(agent.konva); |
||||||
|
|
||||||
|
map_group.add(grid_group); |
||||||
|
map_group.add(tile_group); |
||||||
|
map_group.add(agent_group); |
||||||
|
map_layer.draw(); |
||||||
|
} |
||||||
|
|
||||||
|
draw_map(map, 32); |
||||||
|
|
||||||
|
window.addEventListener('resize', function() { |
||||||
|
canvas_width = canvas.offsetWidth; |
||||||
|
canvas_height = canvas.offsetHeight; |
||||||
|
draw_map(map, machine.state); |
||||||
|
}); |
Loading…
Reference in new issue