From cfe6fe5734ee118ab5bbedcdcc71388efad64f4f Mon Sep 17 00:00:00 2001 From: Ugo Date: Thu, 26 Sep 2019 17:57:50 +0200 Subject: [PATCH] Framework --- css/src/style.scss | 19 ++++ css/style.min.css | 1 + index.html | 22 +++++ js/controls.js | 50 +++++++++++ js/rl.js | 220 +++++++++++++++++++++++++++++++++++++++++++++ js/view.js | 170 +++++++++++++++++++++++++++++++++++ 6 files changed, 482 insertions(+) create mode 100644 css/src/style.scss create mode 100644 css/style.min.css create mode 100644 index.html create mode 100644 js/controls.js create mode 100644 js/rl.js create mode 100644 js/view.js diff --git a/css/src/style.scss b/css/src/style.scss new file mode 100644 index 0000000..de9b36e --- /dev/null +++ b/css/src/style.scss @@ -0,0 +1,19 @@ +*{ + margin: 0; + padding: 0; +} +body{ + // background-color: ; +} +#container{ + height: 100vh; + position: relative; +} +#canvas{ + height: 100%; +} +nav{ + position: absolute; + top: 10px; + left: 10px; +} diff --git a/css/style.min.css b/css/style.min.css new file mode 100644 index 0000000..37d36d7 --- /dev/null +++ b/css/style.min.css @@ -0,0 +1 @@ +*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px} diff --git a/index.html b/index.html new file mode 100644 index 0000000..379de73 --- /dev/null +++ b/index.html @@ -0,0 +1,22 @@ + + + + + + RL exhibit - prototype + + + +
+
+ +
+ + + + + diff --git a/js/controls.js b/js/controls.js new file mode 100644 index 0000000..26b0f56 --- /dev/null +++ b/js/controls.js @@ -0,0 +1,50 @@ +function dir_to_action(dir){ + let actions = [...Array(machine.q_table[machine.state].length).keys()].map((a) => maze.get_direction(machine.state,a)); + var action = actions.indexOf(dir); + if (action>-1){ + return action; + } + return undefined; +} +var animate = false; + +function key_callback(e) { + var tmp; + if (animate){ + return + } + switch (e.keyCode) { + case 37: + tmp = dir_to_action(dir.LEFT); + break; + case 38: + tmp = dir_to_action(dir.UP); + break; + case 39: + tmp = dir_to_action(dir.RIGHT); + break; + case 40: + tmp = dir_to_action(dir.DOWN); + break; + } + if (tmp != undefined){ + machine.step(tmp) + } + draw_map(map, machine.state); +} +document.addEventListener('keydown', key_callback); + +function show_solution() { + var sol = machine.current_solution(); + animate = true; + show_path(sol.states, 0); +} + +function show_path(path, i){ + if (path.length == i) { + animate = false; + return + } + agent.set_state(path[i]); + window.setTimeout(function(){ show_path(path, ++i) }, 1000); +} diff --git a/js/rl.js b/js/rl.js new file mode 100644 index 0000000..c6e2698 --- /dev/null +++ b/js/rl.js @@ -0,0 +1,220 @@ +class RL_machine { + constructor(actions_per_state, + transactions, + rewards, + start_state, + end_states, + end_score, + learning_rate, + discount_factor, + epsilon=0) { + this.q_table = actions_per_state.map((c) => Array(c).fill(0)); + this.transactions = transactions; + this.rewards = rewards; + this.lr = learning_rate; + this.df = discount_factor; + this.state = start_state; + this.start_state = start_state; + this.end_score = end_score; + this.end_states = end_states; + this.episode = 0; + this.epsilon = epsilon; + this.score = 0; + } + reset_machine(){ + this.q_table = this.q_table.map((c) => c.map((a) => a.fill(0))); + this.episode = 0; + this.state = this.start_state; + } + new_episode(){ + // add_new_episode_callback + this.episode++; + this.state = this.start_state; + this.score = 0; + } + auto_step(){ + if (Math.random() < this.epsilon){ + return this.step(Math.floor(Math.random() * this.q_table[this.state].length)); + } else{ + return this.greedy_step(); + } + } + greedy_step(){ + return this.step(argMax(this.q_table[this.state])); + } + step(action){ + this.state = this.update_q_table(this.state, action); + // add_new_step_callback + if (this.end_states.indexOf(this.state) >= 0 || this.score < this.end_score){ + this.new_episode(); + return 2 + } + return 1 + } + update_q_table(state, action){ + let new_state = this.transactions(state, action); + this.q_table[state][action] = (1-this.lr)*this.q_table[state][action] + this.lr*(this.rewards[new_state] + this.df*Math.max(...this.q_table[new_state])); + this.score += this.rewards[new_state]; + return new_state; + } + run(episodes, max_steps_per_episode=10000){ + for (var i = 0; i < episodes; i++) { + for (var j = 0; j < max_steps_per_episode; j++) { + if (this.auto_step() == 2) { + break; + } + } + this.new_episode(); + } + } + current_solution(max_steps_per_episode=10000){ + let temp_state = this.start_state; + let score = 0; + let states = [temp_state]; + let actions = []; + let scores = []; + for (var j = 0; j < max_steps_per_episode; j++) { + let ac = argMax(this.q_table[temp_state]); + temp_state = this.transactions(temp_state, ac); + let sc = this.rewards[temp_state]; + score += sc; + actions.push(ac); + scores.push(sc); + states.push(temp_state); + if (this.end_states.indexOf(temp_state) >= 0 || score < this.end_score){ + return {actions: actions, scores: scores, states: states} + } + } + return {actions: actions, scores: scores, states: states} + } +} + +function argMax(array) { + return array.map((x, i) => [x, i]).reduce((r, a) => (a[0] > r[0] ? a : r))[1]; +} + +// ------------------ maze stuff -------------------------------------------- +const tile = { + regular: 0, + start: 1, + end: 2, + dangerous: 4, +}; + +const dir = { + UP: 0, + RIGHT: 1, + DOWN: 2, + LEFT: 3, +}; + +class Maze { + constructor(map, reward_map) { + this.map = map + this.height = map.length; + this.width = map[0].length; + this.start_state = this.get_states(tile.start)[0]; + this.end_states = this.get_states(tile.end); + this.actions = this.get_actions(); + this.transactions = this.get_transactions(); + this.rewards = this.get_rewards(reward_map); + } + get_states(tile) { + var res = []; + for (var idy = 0; idy < this.map.length; idy++) { + for (var idx = 0; idx < this.map[idy].length; idx++) { + if (this.map[idy][idx] == tile) { + res.push(idy*this.map[0].length+idx); + } + } + } + return res; + } + get_actions() { + var actions = []; + for (let idy=0; idy