master
Ugo Finnendahl 5 years ago
commit cfe6fe5734
  1. 19
      css/src/style.scss
  2. 1
      css/style.min.css
  3. 22
      index.html
  4. 50
      js/controls.js
  5. 220
      js/rl.js
  6. 170
      js/view.js

@ -0,0 +1,19 @@
// Global reset: strip the default margin and padding from every element.
*{
margin: 0;
padding: 0;
}
// Placeholder rule: its only declaration is commented out, so it emits no styles.
body{
// background-color: ;
}
// Full-viewport wrapper; position: relative makes it the containing block
// for the absolutely positioned nav.
#container{
height: 100vh;
position: relative;
}
// Drawing surface; stretches to the full height of #container.
#canvas{
height: 100%;
}
// Button bar pinned 10px from the top-left corner of its positioned ancestor.
nav{
position: absolute;
top: 10px;
left: 10px;
}

1
css/style.min.css vendored

@ -0,0 +1 @@
/* Minified stylesheet; its rules mirror css/src/style.scss (presumably generated from it - TODO confirm before editing by hand). */
*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px}

@ -0,0 +1,22 @@
<!DOCTYPE html>
<html lang="en" dir="ltr">
  <head>
    <meta charset="utf-8">
    <!-- Konva supplies the canvas stage and shapes used by js/view.js. -->
    <script src="https://unpkg.com/konva@4.0.0/konva.min.js"></script>
    <title>RL exhibit - prototype</title>
    <link rel="stylesheet" href="css/style.min.css">
  </head>
  <body>
    <div id="container">
      <!-- js/view.js mounts the Konva stage into this element. -->
      <div id="canvas"></div>
      <nav>
        <!-- run() finishes every episode by resetting the machine to its start
             state, so redraw afterwards; otherwise the agent stays painted at
             whatever position it had before the run. -->
        <button class="button" onclick="machine.run(100);draw_map(map, machine.state);">run 100 episodes!</button>
        <button class="button" onclick="machine.auto_step();draw_map(map, machine.state);">auto step!</button>
        <button class="button" onclick="machine.greedy_step();draw_map(map, machine.state);">greedy step!</button>
      </nav>
    </div>
    <!-- Load order matters: rl.js defines map/maze/machine, view.js defines
         draw_map, and controls.js uses all of them. -->
    <script src="js/rl.js"></script>
    <script src="js/view.js"></script>
    <script src="js/controls.js"></script>
  </body>
</html>

50
js/controls.js vendored

@ -0,0 +1,50 @@
/**
 * Translate a movement direction into the action index that produces it in
 * the machine's current state.
 *
 * @param {number} dir - One of the direction codes returned by maze.get_direction.
 * @returns {number|undefined} The matching action index, or undefined when no
 *   action of the current state moves in that direction.
 */
function dir_to_action(dir){
  const state = machine.state;
  const action_count = machine.q_table[state].length;
  // Probe each action of the current state and stop at the first match.
  for (let candidate = 0; candidate < action_count; candidate++) {
    if (maze.get_direction(state, candidate) === dir) {
      return candidate;
    }
  }
  return undefined;
}
// True while show_path is replaying a solution; keyboard input is ignored meanwhile.
var animate = false;

/**
 * Keydown handler: moves the agent one step in the direction of the pressed
 * arrow key (when the current state offers such an action) and redraws the map.
 *
 * The key is identified through `e.key` first and through the deprecated
 * `e.keyCode` (37-40) as a fallback, so both current and legacy events work.
 *
 * @param {KeyboardEvent} e - The keydown event.
 */
function key_callback(e) {
  if (animate){
    return;
  }
  // Built per call because `dir` is a global that is resolved at call time.
  const arrow_directions = new Map([
    ['ArrowLeft', dir.LEFT], [37, dir.LEFT],
    ['ArrowUp', dir.UP], [38, dir.UP],
    ['ArrowRight', dir.RIGHT], [39, dir.RIGHT],
    ['ArrowDown', dir.DOWN], [40, dir.DOWN],
  ]);
  let direction = arrow_directions.get(e.key);
  if (direction === undefined) {
    direction = arrow_directions.get(e.keyCode);
  }
  if (direction !== undefined) {
    // Action index 0 is valid, so test against undefined rather than truthiness.
    const action = dir_to_action(direction);
    if (action !== undefined) {
      machine.step(action);
    }
  }
  // Redraw on every handled key press, as before, so the view follows the machine.
  draw_map(map, machine.state);
}
// Route every keydown on the page through key_callback, which turns arrow keys into machine steps.
document.addEventListener('keydown', key_callback);
/**
 * Replay the machine's current greedy solution on the map, starting from the
 * first state of the path.
 */
function show_solution() {
  const solution = machine.current_solution();
  // Block keyboard input until show_path reaches the end of the path.
  animate = true;
  const first_index = 0;
  show_path(solution.states, first_index);
}
/**
 * Animate the agent along `path`, one state per second.
 *
 * @param {number[]} path - States to visit in order.
 * @param {number} i - Index of the state to show now; when it equals
 *   path.length the animation ends and keyboard input is re-enabled.
 */
function show_path(path, i){
  const finished = (i == path.length);
  if (finished) {
    animate = false;
    return;
  }
  agent.set_state(path[i]);
  // Schedule the following state; each hop is shown for 1000 ms.
  const next_index = i + 1;
  window.setTimeout(() => show_path(path, next_index), 1000);
}

@ -0,0 +1,220 @@
/**
 * Tabular Q-learning agent over integer-indexed states.
 *
 * The Q-table holds one row per state and one entry per action available in
 * that state. An episode ends when the agent enters one of `end_states` or
 * when the accumulated score drops below `end_score`.
 */
class RL_machine {
  /**
   * @param {number[]} actions_per_state - Number of actions available in each state.
   * @param {function(number, number): number} transactions - Transition
   *   function: (state, action) => next state.
   * @param {number[]} rewards - Reward received for entering each state.
   * @param {number} start_state - State every episode starts in.
   * @param {number[]} end_states - States that terminate an episode.
   * @param {number} end_score - An episode is aborted once its score is below this value.
   * @param {number} learning_rate - Weight of the new estimate in the Q update.
   * @param {number} discount_factor - Weight of the best next-state Q-value.
   * @param {number} [epsilon=0] - Probability that auto_step picks a random action.
   */
  constructor(actions_per_state,
              transactions,
              rewards,
              start_state,
              end_states,
              end_score,
              learning_rate,
              discount_factor,
              epsilon=0) {
    this.q_table = actions_per_state.map((count) => Array(count).fill(0));
    this.transactions = transactions;
    this.rewards = rewards;
    this.lr = learning_rate;
    this.df = discount_factor;
    this.state = start_state;
    this.start_state = start_state;
    this.end_score = end_score;
    this.end_states = end_states;
    this.episode = 0;
    this.epsilon = epsilon;
    this.score = 0;
  }

  /**
   * Forget everything learned: zero the Q-table (keeping its shape), and reset
   * the episode counter, the score and the current state.
   */
  reset_machine(){
    // Rows hold plain numbers, so rebuild each row instead of calling fill on the entries.
    this.q_table = this.q_table.map((row) => row.map(() => 0));
    this.episode = 0;
    this.state = this.start_state;
    this.score = 0;
  }

  /** Start the next episode: bump the counter and return to the start state with score 0. */
  new_episode(){
    // add_new_episode_callback
    this.episode++;
    this.state = this.start_state;
    this.score = 0;
  }

  /**
   * Take one epsilon-greedy step: a uniformly random action with probability
   * `epsilon`, otherwise the greedy action.
   *
   * @returns {number} Same codes as step().
   */
  auto_step(){
    if (Math.random() < this.epsilon){
      const action_count = this.q_table[this.state].length;
      return this.step(Math.floor(Math.random() * action_count));
    }
    return this.greedy_step();
  }

  /**
   * Take the action with the highest Q-value in the current state.
   *
   * @returns {number} Same codes as step().
   */
  greedy_step(){
    return this.step(argMax(this.q_table[this.state]));
  }

  /**
   * Execute `action` in the current state and learn from the outcome.
   *
   * @param {number} action - Index into the current state's Q-table row.
   * @returns {number} 2 when the step ended the episode (a new episode has
   *   already been started), 1 otherwise.
   */
  step(action){
    this.state = this.update_q_table(this.state, action);
    // add_new_step_callback
    if (this.end_states.includes(this.state) || this.score < this.end_score){
      this.new_episode();
      return 2;
    }
    return 1;
  }

  /**
   * Apply the Q-learning update for taking `action` in `state` and add the
   * received reward to the episode score.
   *
   * @param {number} state - State the action is taken in.
   * @param {number} action - Action index within that state.
   * @returns {number} The state reached by the action.
   */
  update_q_table(state, action){
    const new_state = this.transactions(state, action);
    const reward = this.rewards[new_state];
    const best_next = Math.max(...this.q_table[new_state]);
    this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(reward + this.df*best_next);
    this.score += reward;
    return new_state;
  }

  /**
   * Train for a number of episodes using epsilon-greedy steps.
   *
   * @param {number} episodes - How many episodes to play.
   * @param {number} [max_steps_per_episode=10000] - Step budget after which an
   *   unfinished episode is cut off.
   */
  run(episodes, max_steps_per_episode=10000){
    for (let i = 0; i < episodes; i++) {
      let finished = false;
      for (let j = 0; j < max_steps_per_episode && !finished; j++) {
        finished = this.auto_step() === 2;
      }
      // step() already started a new episode when the episode finished on its
      // own; only a cut-off episode still needs the reset. Calling it
      // unconditionally counted every finished episode twice.
      if (!finished) {
        this.new_episode();
      }
    }
  }

  /**
   * Follow the greedy policy from the start state without learning and without
   * touching the machine's own state.
   *
   * @param {number} [max_steps_per_episode=10000] - Upper bound on the path length.
   * @returns {{actions: number[], scores: number[], states: number[]}} The
   *   actions taken, the reward of each step, and the visited states
   *   (beginning with the start state).
   */
  current_solution(max_steps_per_episode=10000){
    let temp_state = this.start_state;
    let score = 0;
    const states = [temp_state];
    const actions = [];
    const scores = [];
    for (let j = 0; j < max_steps_per_episode; j++) {
      const ac = argMax(this.q_table[temp_state]);
      temp_state = this.transactions(temp_state, ac);
      const sc = this.rewards[temp_state];
      score += sc;
      actions.push(ac);
      scores.push(sc);
      states.push(temp_state);
      if (this.end_states.includes(temp_state) || score < this.end_score){
        break;
      }
    }
    return {actions: actions, scores: scores, states: states};
  }
}
/**
 * Index of the largest element of `array`; on ties the earliest index wins.
 *
 * @param {number[]} array - Non-empty list of values.
 * @returns {number} Position of the maximum.
 * @throws {TypeError} When `array` is empty (reduce is called without a seed).
 */
function argMax(array) {
  const indices = [...array.keys()];
  // Strict ">" keeps the earlier index when two values are equal.
  return indices.reduce((best, index) => (array[index] > array[best] ? index : best));
}
// ------------------ maze stuff --------------------------------------------
// Cell codes used in the maze grid. Maze.get_states searches the grid for these
// values, and the `reward` table further down is keyed by them.
const tile = {
regular: 0,
start: 1,
end: 2,
dangerous: 4,
};
// Movement direction codes, as returned by Maze.get_direction and switched on
// by the transition function built in Maze.get_transactions.
const dir = {
UP: 0,
RIGHT: 1,
DOWN: 2,
LEFT: 3,
};
/**
 * Rectangular grid maze exposed as a state/action model.
 *
 * Cells are numbered row by row: state = row * width + column. Each state only
 * offers the moves that stay inside the grid, so the number of actions differs
 * between corner, border and interior cells.
 */
class Maze {
  /**
   * @param {number[][]} map - Grid of tile codes; all rows are expected to
   *   have the length of the first row.
   * @param {Object<number, number>} reward_map - Reward for entering a cell, keyed by tile code.
   */
  constructor(map, reward_map) {
    this.map = map;
    this.height = map.length;
    this.width = map[0].length;
    this.start_state = this.get_states(tile.start)[0];
    this.end_states = this.get_states(tile.end);
    this.actions = this.get_actions();
    this.transactions = this.get_transactions();
    this.rewards = this.get_rewards(reward_map);
  }

  /**
   * Find every cell holding a given tile code.
   *
   * @param {number} tile - Tile code to look for.
   * @returns {number[]} Matching states in row-major order.
   */
  get_states(tile) {
    const res = [];
    for (let idy = 0; idy < this.map.length; idy++) {
      for (let idx = 0; idx < this.map[idy].length; idx++) {
        // Loose equality kept on purpose so numeric strings in the grid still match.
        if (this.map[idy][idx] == tile) {
          res.push(idy * this.map[0].length + idx);
        }
      }
    }
    return res;
  }

  /**
   * Directions that keep the agent inside the grid, in action-index order.
   *
   * The order is UP, RIGHT, DOWN, LEFT with blocked moves dropped, except
   * that on the right-most column LEFT takes the slot RIGHT would have had.
   *
   * @param {number} state - Cell index.
   * @returns {number[]} Direction codes; element i is the direction of action i.
   */
  available_directions(state) {
    const y = Math.floor(state / this.width);
    const x = state % this.width;
    const has_left = x > 0;
    const has_right = x < this.width - 1;
    const directions = [];
    if (y > 0) {
      directions.push(dir.UP);
    }
    if (has_right) {
      directions.push(dir.RIGHT);
    } else if (has_left) {
      directions.push(dir.LEFT);
    }
    if (y < this.height - 1) {
      directions.push(dir.DOWN);
    }
    if (has_left && has_right) {
      directions.push(dir.LEFT);
    }
    return directions;
  }

  /**
   * Count the actions of every state.
   *
   * @returns {number[]} Number of in-grid moves per state, in state order.
   */
  get_actions() {
    const actions = [];
    const state_count = this.width * this.height;
    for (let state = 0; state < state_count; state++) {
      actions.push(this.available_directions(state).length);
    }
    return actions;
  }

  /**
   * Direction an action index stands for in a given state.
   *
   * @param {number} state - Cell index.
   * @param {number} action - Action index within that state.
   * @returns {number|undefined} Direction code, or undefined when the state
   *   has no such action.
   */
  get_direction(state, action){
    return this.available_directions(state)[action];
  }

  /**
   * Build the transition function handed to the learner.
   *
   * @returns {function(number, number): (number|undefined)} (state, action) =>
   *   neighbouring state, or undefined for an action the state does not have.
   */
  get_transactions(){
    return (state, action) => {
      switch (this.get_direction(state, action)) {
        case dir.UP:
          return state - this.width;
        case dir.RIGHT:
          return state + 1;
        case dir.DOWN:
          return state + this.width;
        case dir.LEFT:
          return state - 1;
        default:
          return undefined;
      }
    };
  }

  /**
   * Flatten the grid into a per-state reward list.
   *
   * @param {Object<number, number>} rewards - Reward keyed by tile code.
   * @returns {number[]} Reward for entering each state, in state order.
   */
  get_rewards(rewards){
    // Read from the table that was passed in; do not fall back to any global.
    const per_state = [];
    for (let idy = 0; idy < this.map.length; idy++){
      for (let idx = 0; idx < this.map[0].length; idx++){
        per_state.push(rewards[this.map[idy][idx]]);
      }
    }
    return per_state;
  }
}
// Maze layout as a grid of `tile` codes: 0 regular, 1 start, 2 end, 4 dangerous.
var map = [
[0, 0, 4, 2, 0, 0, 0, 0],
[0, 0, 4, 4, 4, 4, 0, 0],
[4, 0, 0, 0, 0, 4, 0, 4],
[0, 0, 4, 0, 0, 0, 0, 0],
[1, 0, 4, 0, 4, 0, 0, 4]
];
// Reward for entering a cell, keyed by its tile code.
const reward = {[tile.regular]:-1,[tile.dangerous]:-1000,[tile.end]:1000,[tile.start]:-1};
var maze = new Maze(map, reward);
// With a learning rate of 1 the (1 - lr) share of the old Q-value is zero, so each update replaces the entry.
var learning_rate = 1;
var discount_factor = 1;
// end_score -999: an episode is aborted once its score falls below -999.
// epsilon 0.5: auto_step picks a random action when Math.random() is below 0.5.
var machine = new RL_machine(maze.actions, maze.transactions, maze.rewards, maze.start_state, maze.end_states, -999, learning_rate, discount_factor, 0.5);

@ -0,0 +1,170 @@
// DOM element that is handed to Konva.Stage as its container.
const canvas = document.getElementById("canvas");
// Cached size of that element; refreshed by the window resize handler.
var canvas_width = canvas.offsetWidth;
var canvas_height = canvas.offsetHeight;
/**
 * Sort numbers in ascending order, in place.
 *
 * @param {number[]} array - Values to sort; the array itself is reordered.
 * @returns {number[]} The same array instance, now sorted.
 */
function sort(array) {
  const ascending = (left, right) => left - right;
  array.sort(ascending);
  return array;
}
// Shared style that draw_map spreads into every grid line.
const grid_line = {
stroke: '#ddd',
}
// Create the Konva stage inside the canvas element; init_stage sets its size.
var stage = new Konva.Stage({
container: canvas,
});
// Agent wrapper object; assigned by draw_map on every redraw.
var agent;
// A single layer holds map_group. The grid, tile and agent groups are
// replaced with fresh groups by draw_map and added to map_group there.
var map_layer = new Konva.Layer();
var map_group = new Konva.Group();
var grid_group = new Konva.Group();
var tile_group = new Konva.Group();
var agent_group = new Konva.Group();
map_layer.add(map_group);
stage.add(map_layer);
/**
 * Resize the stage to the cached canvas size and lay out the map area:
 * 60% of the stage width and 90% of its height, placed 20% from the left.
 */
function init_stage() {
  stage.width(canvas_width);
  stage.height(canvas_height);

  // Read the size back from the stage so the layout follows what it reports.
  const stage_w = stage.width();
  const stage_h = stage.height();

  map_group.width(stage_w * 0.6);
  map_group.height(stage_h * 0.9);
  map_group.setX(stage_w * 0.2);
  // Half the stage height minus half the map height (0.9 / 2 = 0.45).
  map_group.setY(stage_h / 2 - (stage_h * 0.45));
}
/**
 * Rebuild the whole map view: grid lines, coloured special tiles and the
 * agent marker, then redraw the layer.
 *
 * Also replaces the global `agent` with a fresh wrapper whose `x` / `y`
 * properties are grid coordinates clamped to the map.
 *
 * @param {number[][]} map - Grid of tile codes; sized after its first row.
 * @param {number} state - Cell index (row * columns + column) to put the agent on.
 */
function draw_map(map, state) {
  // cleanup: detach the previous groups and start from empty ones
  agent_group.remove();
  agent_group = new Konva.Group();
  grid_group.remove();
  grid_group = new Konva.Group();
  tile_group.remove();
  tile_group = new Konva.Group();
  init_stage();

  const rows = map.length;
  const cols = map[0].length;
  // Square cells: the cell size is limited by whichever axis is tighter.
  const padding = Math.min(map_group.height() / rows, map_group.width() / cols);
  const strokeW = 16 / Math.max(rows, cols);
  const offset = strokeW / 2;

  // vertical grid lines
  for (let i = 0; i <= cols; i++) {
    const line_x = Math.round(i * padding);
    grid_group.add(new Konva.Line({
      points: [line_x, -offset, line_x, padding * rows + offset],
      strokeWidth: strokeW,
      ...grid_line
    }));
  }
  grid_group.width(Math.round(cols * padding));

  // horizontal grid lines
  for (let j = 0; j <= rows; j++) {
    const line_y = Math.round(j * padding);
    grid_group.add(new Konva.Line({
      points: [-offset, line_y, padding * cols + offset, line_y],
      strokeWidth: strokeW,
      ...grid_line
    }));
  }
  grid_group.height(Math.round(rows * padding));

  // Fill style of the special tiles; regular tiles are left blank.
  const style_for = (cell) => {
    if (cell == tile.dangerous) {
      return { fill: '#FF7B17', opacity: 1 };
    }
    if (cell == tile.end) {
      return { fill: '#0eb500', opacity: 0.5 };
    }
    if (cell == tile.start) {
      return { fill: '#ffc908', opacity: 0.5 };
    }
    return null;
  };

  // Numeric index loops: for...in would hand out string indices.
  for (let idy = 0; idy < rows; idy++) {
    for (let idx = 0; idx < map[idy].length; idx++) {
      const style = style_for(map[idy][idx]);
      if (style === null) {
        continue;
      }
      tile_group.add(new Konva.Rect({
        // Inset by half a stroke on every side so the grid lines stay visible.
        x: padding * idx + offset,
        y: padding * idy + offset,
        width: padding - 2 * offset,
        height: padding - 2 * offset,
        ...style,
      }));
    }
  }

  // Centre the grid inside the area reserved for the map.
  map_group.offset({
    x: -(map_group.width() - grid_group.width()) / 2,
    y: -(map_group.height() - grid_group.height()) / 2,
  });

  // Keep a grid coordinate inside [0, max].
  const clamp = (value, max) => Math.min(Math.max(value, 0), max);

  agent = {
    "konva": new Konva.RegularPolygon({
      // Negative offset shifts the marker to the centre of its cell.
      offset: {
        x: -padding / 2,
        y: -padding / 2
      },
      sides: 5,
      radius: padding / 3,
      fill: '#00D2FF',
      stroke: 'black',
      strokeWidth: 2
    }),
    get x() {
      return Math.floor(this.konva.getX() / padding);
    },
    set x(pos) {
      this.konva.setX(clamp(pos, cols - 1) * padding);
      map_layer.draw();
    },
    get y() {
      return Math.floor(this.konva.getY() / padding);
    },
    set y(pos) {
      this.konva.setY(clamp(pos, rows - 1) * padding);
      map_layer.draw();
    },
    // Row 0 is the top row, so moving up lowers y.
    up() {
      this.y--;
    },
    down() {
      this.y++;
    },
    // Column 0 is the left-most column, so moving left lowers x.
    left() {
      this.x--;
    },
    right() {
      this.x++;
    },
    set_state(state){
      this.y = Math.floor(state / cols);
      this.x = state % cols;
    }
  };
  agent.set_state(state);
  agent_group.add(agent.konva);

  map_group.add(grid_group);
  map_group.add(tile_group);
  map_group.add(agent_group);
  map_layer.draw();
}
// Initial render. State 32 is row 4, column 0 on the 8-column map declared in
// js/rl.js, which is that map's start tile.
// NOTE(review): hard-coded to that layout; machine.state would follow the map - confirm and replace.
draw_map(map, 32);
// Re-measure the canvas element and redraw at the machine's current state on every window resize.
window.addEventListener('resize', function() {
canvas_width = canvas.offsetWidth;
canvas_height = canvas.offsetHeight;
draw_map(map, machine.state);
});
Loading…
Cancel
Save