commit
cfe6fe5734
@ -0,0 +1,19 @@ |
||||
/* Global reset: strip the default margin and padding from every element. */
*{
  margin: 0;
  padding: 0;
}
/* Placeholder rule: its only declaration is commented out. */
body{
  // background-color: ;
}
/* Wrapper spans the full viewport height; being positioned, it is the
   containing block for the absolutely positioned nav below. */
#container{
  height: 100vh;
  position: relative;
}
/* The drawing surface fills the wrapper's height. */
#canvas{
  height: 100%;
}
/* Control buttons are pinned 10px from the top-left corner. */
nav{
  position: absolute;
  top: 10px;
  left: 10px;
}
@ -0,0 +1 @@ |
||||
*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px} |
@ -0,0 +1,22 @@ |
||||
<!DOCTYPE html>
<html lang="en" dir="ltr">
  <head>
    <meta charset="utf-8">
    <script src="https://unpkg.com/konva@4.0.0/konva.min.js"></script>
    <title>RL exhibit - prototype</title>
    <link rel="stylesheet" href="css/style.min.css">
  </head>
  <body>
    <div id="container">
      <div id="canvas"></div>
      <nav>
        <!-- Every control redraws after acting so the canvas never shows a stale agent position. -->
        <button class="button" onclick="machine.run(100);draw_map(map, machine.state);">run 100 episodes!</button>
        <button class="button" onclick="machine.auto_step();draw_map(map, machine.state);">auto step!</button>
        <button class="button" onclick="machine.greedy_step();draw_map(map, machine.state);">greedy step!</button>
      </nav>
    </div>
    <!-- Plain (non-module) scripts run in order; the inline handlers above use the globals they define. -->
    <script src="js/rl.js"></script>
    <script src="js/view.js"></script>
    <script src="js/controls.js"></script>
  </body>
</html>
@ -0,0 +1,50 @@ |
||||
/**
 * Find which action of the machine's current state moves in a given direction.
 *
 * Relies on the globals `machine` (for the current state and its Q-table row)
 * and `maze` (for the action -> direction mapping).
 *
 * @param {number} dir - direction value, compared against maze.get_direction().
 * @returns {number|undefined} index of the matching action, or undefined when
 *   no action of the current state points that way.
 */
function dir_to_action(dir){
  const action_count = machine.q_table[machine.state].length;
  for (let candidate = 0; candidate < action_count; candidate++) {
    if (maze.get_direction(machine.state, candidate) === dir) {
      return candidate;
    }
  }
  return undefined;
}
||||
// True while show_path is replaying a solution; keyboard input is ignored meanwhile.
var animate = false;

// Arrow keys mapped to the name of the matching entry in `dir`.
// Both the standard KeyboardEvent.key strings and the deprecated numeric
// keyCode values are listed so older browsers keep working.
const ARROW_KEY_DIRECTIONS = new Map([
  ['ArrowLeft', 'LEFT'],
  ['ArrowUp', 'UP'],
  ['ArrowRight', 'RIGHT'],
  ['ArrowDown', 'DOWN'],
  [37, 'LEFT'],
  [38, 'UP'],
  [39, 'RIGHT'],
  [40, 'DOWN'],
]);

/**
 * keydown handler: moves the agent one step in the direction of the pressed
 * arrow key (when the current state has an action pointing that way) and
 * redraws the map.
 *
 * @param {KeyboardEvent} e - the keydown event.
 */
function key_callback(e) {
  if (animate){
    return;
  }
  let direction_name = ARROW_KEY_DIRECTIONS.get(e.key);
  if (direction_name === undefined) {
    // Fall back to the legacy numeric code when `key` is missing or unknown.
    direction_name = ARROW_KEY_DIRECTIONS.get(e.keyCode);
  }
  if (direction_name !== undefined) {
    const action = dir_to_action(dir[direction_name]);
    // undefined means the move is not available here (e.g. a wall at the border).
    if (action !== undefined) {
      machine.step(action);
    }
  }
  draw_map(map, machine.state);
}
document.addEventListener('keydown', key_callback);
||||
|
||||
/**
 * Animate the path the machine's current greedy policy takes from the start.
 *
 * Sets the global `animate` flag (cleared again by show_path when the replay
 * ends). Does nothing if a replay is already running, so two animations can
 * never fight over the agent marker.
 */
function show_solution() {
  if (animate){
    return;
  }
  const sol = machine.current_solution();
  animate = true;
  show_path(sol.states, 0);
}
||||
|
||||
/**
 * Replay a list of states on the agent marker, one state per second.
 *
 * @param {number[]} path - states to visit, in order.
 * @param {number} i - index of the state to show now.
 *
 * Clears the global `animate` flag once the index reaches (or is already
 * past) the end of the path.
 */
function show_path(path, i){
  // `>=` rather than equality: an index past the end must stop, not loop forever.
  if (i >= path.length) {
    animate = false;
    return;
  }
  agent.set_state(path[i]);
  window.setTimeout(function(){ show_path(path, i + 1); }, 1000);
}
@ -0,0 +1,220 @@ |
||||
/**
 * Tabular Q-learning agent over integer-indexed states.
 *
 * The Q-table holds one row per state and one entry per action of that state.
 * An episode ends when the agent enters one of `end_states` or when the
 * accumulated score of the episode drops below `end_score`.
 * Greedy action selection uses the global helper `argMax`.
 */
class RL_machine {
  /**
   * @param {number[]} actions_per_state - number of actions per state; one zero-filled Q-table row is created per entry.
   * @param {function(number, number): number} transactions - maps (state, action) to the next state.
   * @param {number[]} rewards - reward collected on entering a state, indexed by state.
   * @param {number} start_state - state every episode starts in.
   * @param {number[]} end_states - states that end the episode.
   * @param {number} end_score - the episode also ends once the score falls below this value.
   * @param {number} learning_rate - weight of the new estimate in the Q update.
   * @param {number} discount_factor - weight of the best next-state value in the Q update.
   * @param {number} [epsilon=0] - probability that auto_step picks a random action.
   */
  constructor(actions_per_state,
              transactions,
              rewards,
              start_state,
              end_states,
              end_score,
              learning_rate,
              discount_factor,
              epsilon=0) {
    this.q_table = actions_per_state.map((c) => Array(c).fill(0));
    this.transactions = transactions;
    this.rewards = rewards;
    this.lr = learning_rate;
    this.df = discount_factor;
    this.state = start_state;
    this.start_state = start_state;
    this.end_score = end_score;
    this.end_states = end_states;
    this.episode = 0;
    this.epsilon = epsilon;
    this.score = 0;
  }

  /** Forget everything learned: zero the Q-table and restart at episode 0. */
  reset_machine(){
    // Rows hold plain numbers, so rebuild each row instead of calling fill on its entries.
    this.q_table = this.q_table.map((row) => row.map(() => 0));
    this.episode = 0;
    this.state = this.start_state;
    this.score = 0;
  }

  /** Start the next episode: bump the counter, return to the start, clear the score. */
  new_episode(){
    // add_new_episode_callback
    this.episode++;
    this.state = this.start_state;
    this.score = 0;
  }

  /**
   * Take one epsilon-greedy step: a uniformly random action with probability
   * `epsilon`, otherwise the greedy one.
   * @returns {number} see step().
   */
  auto_step(){
    if (Math.random() < this.epsilon){
      const action_count = this.q_table[this.state].length;
      return this.step(Math.floor(Math.random() * action_count));
    }
    return this.greedy_step();
  }

  /**
   * Take the action with the highest Q-value in the current state.
   * @returns {number} see step().
   */
  greedy_step(){
    return this.step(argMax(this.q_table[this.state]));
  }

  /**
   * Perform `action`, update the Q-table and move to the resulting state.
   * When the episode ends, a new one is started before returning.
   * @param {number} action - index into the current state's Q-table row.
   * @returns {number} 2 if this step ended the episode, 1 otherwise.
   */
  step(action){
    this.state = this.update_q_table(this.state, action);
    // add_new_step_callback
    if (this.end_states.includes(this.state) || this.score < this.end_score){
      this.new_episode();
      return 2;
    }
    return 1;
  }

  /**
   * Apply the Q-learning update for taking `action` in `state` and add the
   * collected reward to the episode score.
   * @param {number} state
   * @param {number} action
   * @returns {number} the state reached.
   */
  update_q_table(state, action){
    const new_state = this.transactions(state, action);
    const reward = this.rewards[new_state];
    const best_next = Math.max(...this.q_table[new_state]);
    this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(reward + this.df*best_next);
    this.score += reward;
    return new_state;
  }

  /**
   * Train for a number of episodes.
   * @param {number} episodes - how many episodes to play.
   * @param {number} [max_steps_per_episode=10000] - step budget after which an episode is cut off.
   */
  run(episodes, max_steps_per_episode=10000){
    for (let i = 0; i < episodes; i++) {
      let finished = false;
      for (let j = 0; j < max_steps_per_episode && !finished; j++) {
        finished = this.auto_step() === 2;
      }
      // step() has already started a new episode when one finishes; only a
      // cut-off episode still needs the reset, otherwise it is counted twice.
      if (!finished) {
        this.new_episode();
      }
    }
  }

  /**
   * Follow the greedy policy from the start state without learning or
   * touching the machine's own state.
   * @param {number} [max_steps_per_episode=10000] - step budget.
   * @returns {{actions: number[], scores: number[], states: number[]}}
   *   the actions taken, the reward of each step, and every state visited
   *   (including the start state, so it is one entry longer).
   */
  current_solution(max_steps_per_episode=10000){
    let temp_state = this.start_state;
    let score = 0;
    const states = [temp_state];
    const actions = [];
    const scores = [];
    for (let j = 0; j < max_steps_per_episode; j++) {
      const ac = argMax(this.q_table[temp_state]);
      temp_state = this.transactions(temp_state, ac);
      const sc = this.rewards[temp_state];
      score += sc;
      actions.push(ac);
      scores.push(sc);
      states.push(temp_state);
      if (this.end_states.includes(temp_state) || score < this.end_score){
        break;
      }
    }
    return {actions: actions, scores: scores, states: states};
  }
}
||||
|
||||
/**
 * Index of the largest element; the first one wins on ties.
 * Throws a TypeError (from reduce) when the array is empty.
 *
 * @param {number[]} array
 * @returns {number} index of the maximum.
 */
function argMax(array) {
  const indices = array.map((_, position) => position);
  return indices.reduce((best, candidate) => {
    if (array[candidate] > array[best]) {
      return candidate;
    }
    return best;
  });
}
||||
|
||||
// ------------------ maze stuff --------------------------------------------
|
||||
// Cell codes used in maze maps. Frozen so a stray assignment cannot silently
// change the meaning of a map.
const tile = Object.freeze({
  regular: 0,
  start: 1,
  end: 2,
  dangerous: 4,
});

// Movement directions. Frozen for the same reason as `tile`.
const dir = Object.freeze({
  UP: 0,
  RIGHT: 1,
  DOWN: 2,
  LEFT: 3,
});
||||
|
||||
/**
 * Grid maze that exposes itself as a state machine.
 *
 * States are cell indices in row-major order (`row * width + column`).
 * Each cell only offers actions that stay inside the grid, so border cells
 * have fewer actions than interior ones. Uses the globals `tile` and `dir`.
 */
class Maze {
  /**
   * @param {number[][]} map - rows of `tile` codes; all rows are expected to have the length of the first.
   * @param {Object<number, number>} reward_map - reward for entering a cell, keyed by tile code.
   */
  constructor(map, reward_map) {
    this.map = map;
    this.height = map.length;
    this.width = map[0].length;
    this.start_state = this.get_states(tile.start)[0];
    this.end_states = this.get_states(tile.end);
    this.actions = this.get_actions();
    this.transactions = this.get_transactions();
    this.rewards = this.get_rewards(reward_map);
  }

  /**
   * @param {number} tile - tile code to look for.
   * @returns {number[]} states of all cells holding that code, in row-major order.
   */
  get_states(tile) {
    const res = [];
    for (let idy = 0; idy < this.map.length; idy++) {
      for (let idx = 0; idx < this.map[idy].length; idx++) {
        if (this.map[idy][idx] === tile) {
          res.push(idy*this.map[0].length+idx);
        }
      }
    }
    return res;
  }

  /**
   * @returns {number[]} number of available actions for every state, in state order.
   */
  get_actions() {
    const actions = [];
    for (let state = 0; state < this.height*this.width; state++){
      actions.push(this.get_directions(state).length);
    }
    return actions;
  }

  /**
   * Directions available in `state`; the position in the returned list is
   * the action index.
   *
   * The order is UP, RIGHT, DOWN, LEFT with unavailable moves left out,
   * except that LEFT takes over the first slot vacated by RIGHT or DOWN
   * (e.g. on the right border the order is UP, LEFT, DOWN).
   *
   * @param {number} state
   * @returns {number[]} `dir` values.
   */
  get_directions(state){
    const y = Math.floor(state/this.width);
    const x = state%this.width;
    const can_right = x < this.width-1;
    const can_down = y < this.height-1;
    let left_pending = x > 0;
    const directions = [];
    if (y > 0){
      directions.push(dir.UP);
    }
    if (can_right){
      directions.push(dir.RIGHT);
    } else if (left_pending){
      directions.push(dir.LEFT);
      left_pending = false;
    }
    if (can_down){
      directions.push(dir.DOWN);
    } else if (left_pending){
      directions.push(dir.LEFT);
      left_pending = false;
    }
    if (left_pending){
      directions.push(dir.LEFT);
    }
    return directions;
  }

  /**
   * @param {number} state
   * @param {number} action - action index within that state.
   * @returns {number|undefined} the `dir` value the action moves in, or
   *   undefined when the state has no such action.
   */
  get_direction(state, action){
    return this.get_directions(state)[action];
  }

  /**
   * @returns {function(number, number): (number|undefined)} transition
   *   function mapping (state, action) to the neighbouring state.
   */
  get_transactions(){
    return (state, action) => {
      switch (this.get_direction(state, action)) {
        case dir.UP:
          return state-this.width;
        case dir.RIGHT:
          return state+1;
        case dir.DOWN:
          return state+this.width;
        case dir.LEFT:
          return state-1;
        default:
          return undefined;
      }
    };
  }

  /**
   * @param {Object<number, number>} rewards - reward keyed by tile code.
   * @returns {number[]} reward for entering each state, in state order.
   */
  get_rewards(rewards){
    const per_state = [];
    for (let idy=0; idy<this.map.length; idy++){
      for (let idx=0; idx<this.map[0].length; idx++){
        per_state.push(rewards[this.map[idy][idx]]);
      }
    }
    return per_state;
  }
}
||||
|
||||
// Demo maze: 5 rows x 8 columns of `tile` codes
// (0 regular, 1 start, 2 end, 4 dangerous).
var map = [
  [0, 0, 4, 2, 0, 0, 0, 0],
  [0, 0, 4, 4, 4, 4, 0, 0],
  [4, 0, 0, 0, 0, 4, 0, 4],
  [0, 0, 4, 0, 0, 0, 0, 0],
  [1, 0, 4, 0, 4, 0, 0, 4]
];

// Reward for entering a cell, keyed by tile code.
const reward = {[tile.regular]:-1,[tile.dangerous]:-1000,[tile.end]:1000,[tile.start]:-1};
var maze = new Maze(map, reward);

var learning_rate = 1;
var discount_factor = 1;

// -999 is the end_score: an episode stops once its score drops below it, which
// a single dangerous tile (-1000) achieves. 0.5 is epsilon, the probability
// that auto_step picks a random action.
var machine = new RL_machine(maze.actions, maze.transactions, maze.rewards, maze.start_state, maze.end_states, -999, learning_rate, discount_factor, 0.5);
@ -0,0 +1,170 @@ |
||||
// Host element the Konva stage is mounted into.
const canvas = document.getElementById("canvas");

// Pixel size of the host element; refreshed by the window resize handler.
var canvas_width = canvas.offsetWidth;
var canvas_height = canvas.offsetHeight;
||||
|
||||
/**
 * Sort numbers in ascending order, in place.
 *
 * @param {number[]} array - the array to sort; it is modified.
 * @returns {number[]} the same array instance.
 */
function sort(array) {
  const ascending = (left, right) => left - right;
  array.sort(ascending);
  return array;
}
||||
|
||||
// Style properties spread into every grid line drawn by draw_map.
const grid_line = {
  stroke: '#ddd',
}

// first we need to create a stage
var stage = new Konva.Stage({
  container: canvas,
});

// Agent marker wrapper; assigned by draw_map.
var agent;
var map_layer = new Konva.Layer();
// map_group stays attached to the layer; the grid, tile and agent groups
// are replaced with fresh ones on every draw_map call.
var map_group = new Konva.Group();
var grid_group = new Konva.Group();
var tile_group = new Konva.Group();
var agent_group = new Konva.Group();

map_layer.add(map_group);
stage.add(map_layer);
||||
|
||||
/**
 * Resize the stage to the current canvas size and lay out the map area:
 * 60% of the stage width and 90% of its height, centred horizontally and
 * vertically.
 */
function init_stage() {
  stage.width(canvas_width);
  stage.height(canvas_height);

  const stage_w = stage.width();
  const stage_h = stage.height();

  map_group.width(stage_w * 0.6);
  map_group.height(stage_h * 0.9);
  map_group.setX(stage_w * 0.2);
  map_group.setY(stage_h / 2 - (stage_h * 0.45));
}
||||
|
||||
/**
 * Redraw the whole maze: grid lines, coloured special tiles and the agent.
 *
 * Replaces the global grid/tile/agent groups with fresh ones and assigns the
 * global `agent`, a small wrapper around the agent's Konva shape that works
 * in grid coordinates.
 *
 * @param {number[][]} map - rows of `tile` codes.
 * @param {number} state - cell index (row * columns + column) to place the agent on.
 */
function draw_map(map, state) {
  // cleanup: destroy() rather than remove(), so the old shapes are released
  // instead of piling up on every redraw.
  agent_group.destroy();
  agent_group = new Konva.Group();
  grid_group.destroy();
  grid_group = new Konva.Group();
  tile_group.destroy();
  tile_group = new Konva.Group();

  init_stage();

  const rows = map.length;
  const cols = map[0].length;
  // Side length of one (square) cell: the largest that fits both ways.
  const padding = Math.min(map_group.height() / rows, map_group.width() / cols);
  const strokeW = 16 / Math.max(rows, cols);
  const offset = strokeW / 2;

  // vertical grid lines
  for (let i = 0; i <= cols; i++) {
    const x = Math.round(i * padding);
    grid_group.add(new Konva.Line({
      points: [x, -offset, x, padding * rows + offset],
      strokeWidth: strokeW,
      ...grid_line
    }));
  }
  grid_group.width(Math.round(cols * padding));

  // horizontal grid lines
  for (let j = 0; j <= rows; j++) {
    const y = Math.round(j * padding);
    grid_group.add(new Konva.Line({
      points: [-offset, y, padding * cols + offset, y],
      strokeWidth: strokeW,
      ...grid_line
    }));
  }
  grid_group.height(Math.round(rows * padding));

  // Only special tiles get a coloured rectangle; regular cells stay empty.
  const tile_styles = {
    [tile.dangerous]: { fill: '#FF7B17', opacity: 1 },
    [tile.end]: { fill: '#0eb500', opacity: 0.5 },
    [tile.start]: { fill: '#ffc908', opacity: 0.5 },
  };
  for (let idy = 0; idy < rows; idy++) {
    for (let idx = 0; idx < map[idy].length; idx++) {
      const cell = map[idy][idx];
      if (!Object.prototype.hasOwnProperty.call(tile_styles, cell)) {
        continue;
      }
      tile_group.add(new Konva.Rect({
        // inset by the line width so the grid stays visible around the tile
        x: padding * idx + offset,
        y: padding * idy + offset,
        width: padding - 2 * offset,
        height: padding - 2 * offset,
        ...tile_styles[cell],
      }));
    }
  }

  // Centre the grid inside the area reserved for the map.
  map_group.offset({
    x: -(map_group.width() - grid_group.width()) / 2,
    y: -(map_group.height() - grid_group.height()) / 2,
  });

  agent = {
    "konva": new Konva.RegularPolygon({
      // shift the shape to the middle of its cell
      offset: {
        x: -padding / 2,
        y: -padding / 2
      },
      sides: 5,
      radius: padding / 3,
      fill: '#00D2FF',
      stroke: 'black',
      strokeWidth: 2
    }),
    // Column / row of the agent; setters clamp to the grid and repaint.
    get x() {
      return Math.floor(this.konva.getX() / padding);
    },
    set x(pos) {
      // the median of [0, pos, max] is pos clamped to the grid
      this.konva.setX(sort([0, pos, cols - 1])[1] * padding);
      map_layer.draw();
    },
    get y() {
      return Math.floor(this.konva.getY() / padding);
    },
    set y(pos) {
      this.konva.setY(sort([0, pos, rows - 1])[1] * padding);
      map_layer.draw();
    },
    up() {
      this.y--;
    },
    down() {
      this.y++;
    },
    left() {
      this.x--;
    },
    right() {
      this.x++;
    },
    set_state(state){
      this.y = Math.floor(state/cols);
      this.x = state%cols;
    }
  };

  agent.set_state(state);
  agent_group.add(agent.konva);

  map_group.add(grid_group);
  map_group.add(tile_group);
  map_group.add(agent_group);
  map_layer.draw();
}
||||
|
||||
// Initial render. The agent is placed on the machine's current state (its
// start state at load time) instead of a cell index hard-coded for one map.
draw_map(map, machine.state);

// Re-measure the host element and redraw whenever the window size changes.
window.addEventListener('resize', function() {
  canvas_width = canvas.offsetWidth;
  canvas_height = canvas.offsetHeight;
  draw_map(map, machine.state);
});
Loading…
Reference in new issue