diff --git a/css/src/style.scss b/css/src/style.scss index 1b47f7c..619fd02 100644 --- a/css/src/style.scss +++ b/css/src/style.scss @@ -4,6 +4,7 @@ } body{ // background-color: ; + font-family: sans-serif; } #container{ height: 100vh; @@ -30,3 +31,11 @@ nav{ width: 20vw; height: 10vw; } + +.sliders{ + position: absolute; + top: 7vh; + left: 2vw; + width: 20vw; + height: 10vw; +} diff --git a/css/style.min.css b/css/style.min.css index c33b7f6..a983e74 100644 --- a/css/style.min.css +++ b/css/style.min.css @@ -1 +1 @@ -*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px}.absolute{position:absolute;top:0;left:0}.plot{position:absolute;top:2vh;right:2vw;width:20vw;height:10vw} +*{margin:0;padding:0}body{font-family:sans-serif}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px}.absolute{position:absolute;top:0;left:0}.plot{position:absolute;top:2vh;right:2vw;width:20vw;height:10vw}.sliders{position:absolute;top:7vh;left:2vw;width:20vw;height:10vw} diff --git a/index.html b/index.html index 1348449..dbf2b14 100644 --- a/index.html +++ b/index.html @@ -8,6 +8,13 @@ + + + + + + + RL exhibit - prototype @@ -27,25 +34,42 @@ --> - + + + + + - + +
+

Learning Rate {{learning_rate}}

+ +

Discount Factor {{discount_factor}}

+ +

Epsilon {{epsilon}}

+ +

Current Score

+

{{score}}

+
+
diff --git a/js/rl.js b/js/rl.js index 6ee35a5..b32da77 100644 --- a/js/rl.js +++ b/js/rl.js @@ -8,25 +8,29 @@ class RL_machine { learning_rate, discount_factor, epsilon=0) { - this.q_table = actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{})); + this.actions_per_state = actions_per_state; this.transactions = transactions; this.rewards = rewards; this.lr = learning_rate; this.df = discount_factor; - this.state = start_state; this.start_state = start_state; this.end_score = end_score; this.end_states = end_states; - this.episode = 0; this.epsilon = epsilon; - this.score = 0; - this.running = false; - this.score_history = []; + this.q_table = this.actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{})); + this.reset_machine(); } reset_machine(){ - this.q_table = this.q_table.map((c) => c.map((a) => a.fill(0))); + for (var q in this.q_table){ + for (var key in this.q_table[q]){ + this.q_table[q][key] = 0; + } + } this.episode = 0; this.state = this.start_state; + this.score = 0; + this.running = false; + this.score_history = []; } new_episode(){ // add_new_episode_callback @@ -127,6 +131,10 @@ class Maze { for (let idy=0; idy index - N; - }) - .reduce( - function(last, current,index, arr){ - return (current/arr.length + last); - },0); - }); + function(x2, i2) { + return i2 <= index && i2 > index - N; + }) + .reduce( + function(last, current, index, arr) { + return (current / arr.length + last); + }, 0); + }); }; -Array.prototype.max=function() { -return this.map( - function(el,index, _arr) { +Array.prototype.max = function() { + return this.map( + function(el, index, _arr) { return _arr.filter( - function(x2,i2) { - return i2 <= index; - }) - .reduce( - function(last, current){ - return last > current ? last:current; - },-1000000000); - }); + function(x2, i2) { + return i2 <= index; + }) + .reduce( + function(last, current) { + return last > current ? last : current; + }, -1000000000); + }); }; app = new Vue({ el: '#app', + components: { + VueSlider: window['vue-slider-component'] + }, data: { width: 0, height: 0, q_table: machine.q_table, maze: maze, - state: {x:0,y:0}, + state: { + x: 0, + y: 0 + }, state_tween: new TimelineLite(), score: machine.score, score_history: machine.score_history, labels: [], + learning_rate: machine.lr, + discount_factor: machine.df, + epsilon: machine.epsilon, + slider_config: { + min: 0, + max: 1, + duration: 0, + interval: 0.01, + tooltip: 'none' + } }, created() { // Resize handler @@ -75,8 +93,13 @@ app = new Vue({ var $this = this; this.state = this.s2p(s); Object.defineProperty(machine, 'state', { - get: function() { return this._state }, - set: function(ne) { this._state=ne; $this.handleState(this._state); } + get: function() { + return this._state + }, + set: function(ne) { + this._state = ne; + $this.handleState(this._state); + } }); machine.state = s; // Score wrapper @@ -84,8 +107,13 @@ app = new Vue({ var $this = this; this.score = s; Object.defineProperty(machine, 'score', { - get: function() { return this._score }, - set: function(ne) { this._score=ne; $this.score=ne} + get: function() { + return this._score + }, + set: function(ne) { + this._score = ne; + $this.score = ne + } }); machine.score = s; // Score history wrapper @@ -93,8 +121,13 @@ app = new Vue({ var $this = this; this.score_history = s; Object.defineProperty(machine, 'score_history', { - get: function() { return this._score_history }, - set: function(ne) { this._score_history=ne; $this.score_history=ne} + get: function() { + return this._score_history + }, + set: function(ne) { + this._score_history = ne; + $this.score_history = ne + } }); machine.score_history = s; }, @@ -102,66 +135,84 @@ app = new Vue({ window.removeEventListener('resize', this.handleResize) }, computed: { - datacollection: function () { + datacollection: function() { return { labels: Array.from(Array(this.score_history.length).keys()), - datasets: [ - { + datasets: [{ label: 'Data One', backgroundColor: 'rgb(0,0,0,0)', - data: this.score_history.simpleSMA(Math.round(50)), + data: this.score_history,//.simpleSMA(Math.round(50)), fill: false, borderColor: 'rgb(255, 159, 64)', pointRadius: 1, }, - { - label: 'Data One', - backgroundColor: 'rgb(0,0,0,0)', - data: this.score_history.max(), - fill: false, - borderColor: 'rgb(64, 159, 255)', - pointRadius: 1, - }, + // { + // label: 'Data One', + // backgroundColor: 'rgb(0,0,0,0)', + // data: this.score_history.max(), + // fill: false, + // borderColor: 'rgb(64, 159, 255)', + // pointRadius: 1, + // }, ] } }, - stage_config: function () { + plot_options: function() { + var $this = this; + return { + responsive: true, + maintainAspectRatio: false, + scales: { + xAxes: [{ + // type: 'linear', + ticks: { + maxTicksLimit: 8, + maxRotation: 0, + } + }] + }, + legend: { + display: false + } + } + }, + stage_config: function() { return { width: this.width, height: this.height, } }, - mini_map_config: function () { + mini_map_config: function() { return { - x:this.width/2-(this.base_size*(this.maze.width)/2), - y:this.height/2-(this.base_size*(this.maze.height)/2), - scale:{ + x: this.width / 2 - (this.base_size * (this.maze.width) / 2), + y: this.height / 2 - (this.base_size * (this.maze.height) / 2), + scale: { x: 1, y: 1 } } }, - local_layer: function () { + local_layer: function() { return { - x: this.width/2, - y: this.height/2, - scale:{ + x: this.width / 2, + y: this.height / 2, + scale: { x: 2, y: 2 } } }, - map_config: function () { + map_config: function() { return { - x: this.base_size*(this.maze.width-this.state.x), - y: this.base_size*(this.maze.height-this.state.y), + x: this.base_size * (this.maze.width - this.state.x), + y: this.base_size * (this.maze.height - this.state.y), offset: { - x: this.base_size*this.maze.width+this.base_size/2, - y: this.base_size*this.maze.height+this.base_size/2, + x: this.base_size * this.maze.width + this.base_size / 2, + y: this.base_size * this.maze.height + this.base_size / 2, } } }, - agent_config: function () { + agent_config: function() { return { sides: 5, radius: this.base_size / 3, @@ -172,43 +223,60 @@ app = new Vue({ x: -this.base_size / 2, y: -this.base_size / 2 }, - x: this.base_size*this.state.x, - y: this.base_size*this.state.y, + x: this.base_size * this.state.x, + y: this.base_size * this.state.y, } }, - base_size: function () { - return Math.min(this.stage_config.height * 0.8 / this.maze.height, this.stage_config.width * 0.5 / this.maze.width); + base_size: function() { + return Math.min(this.stage_config.height * 0.8 / this.maze.height, this.stage_config.width * 0.5 / this.maze.width); }, - strokeW: function () { + strokeW: function() { return this.base_size / 50; }, + extreme_q_values: function(){ + var max = -10*30; + var min = 10*30; + for (field in this.q_table) { + for (key in this.q_table[field]){ + if (this.q_table[field][key]max){ + max = this.q_table[field][key]; + } + } + } + return {min:min,max:max}; + } }, methods: { - s2p: function(state){ + s2p: function(state) { return { - x: (state%this.maze.width), - y: Math.floor(state/this.maze.width), + x: (state % this.maze.width), + y: Math.floor(state / this.maze.width), } }, - p2s: function(x,y){ - return x+y*this.maze.width; + p2s: function(x, y) { + return x + y * this.maze.width; }, handleResize: function() { this.width = window.innerWidth; this.height = window.innerHeight; }, handleState: function(s) { - if (!machine.running){ - this.state_tween.to(this.state, 0.2, { x: this.s2p(s).x, y: this.s2p(s).y }); + if (!machine.running) { + this.state_tween.to(this.state, 0.2, { + x: this.s2p(s).x, + y: this.s2p(s).y + }); } else { this.state = this.s2p(s); } // this.hidden_state = s; }, - get_grid_line_config: function (idx, y=false) { - var offset = this.strokeW/2; - if (y){ - var points = [-offset, Math.round(idx * this.base_size), this.base_size * this.maze.width + offset,Math.round(idx * this.base_size)]; + get_grid_line_config: function(idx, y = false) { + var offset = this.strokeW / 2; + if (y) { + var points = [-offset, Math.round(idx * this.base_size), this.base_size * this.maze.width + offset, Math.round(idx * this.base_size)]; } else { var points = [Math.round(idx * this.base_size), -offset, Math.round(idx * this.base_size), this.base_size * this.maze.height + offset]; } @@ -218,34 +286,128 @@ app = new Vue({ strokeWidth: this.strokeW, } }, - get_tile_type: function (state){ + get_tile_type: function(state) { var pos = this.s2p(state); - if (pos.y > maze.height){ + if (pos.y > maze.height) { return null; - } else if (pos.x > maze.width){ + } else if (pos.x > maze.width) { return null; } else { return maze.map[pos.y][pos.x]; } }, - in_plus: function (pos1, pos2) { - if (Math.abs(pos1.x-pos2.x) + Math.abs(pos1.y-pos2.y) < 2) { + in_plus: function(pos1, pos2) { + if (Math.abs(pos1.x - pos2.x) + Math.abs(pos1.y - pos2.y) < 2) { return true; } return false; }, - get_tile_config: function (i, t_type, local=false) { + get_field_config: function(state) { + var pos = this.s2p(state); + return { + x: this.base_size * pos.x+this.base_size/2, + y: this.base_size * pos.y+this.base_size/2, + } + }, + get_q_text_config: function (val, i) { + var off, key; + switch (i) { + case 1: + off = { + align: "center", + verticalAlign: "top", + }; + key = dir.UP; + break; + case 2: + off = { + align: "right", + verticalAlign: "middle", + }; + key = dir.RIGHT; + break; + case 3: + off = { + align: "center", + verticalAlign: "bottom", + }; + key = dir.DOWN; + break; + case 4: + off = { + align: "left", + verticalAlign: "middle", + }; + key = dir.LEFT; + break; + } + if (val[key] === undefined) { + return {} + } + return { + fontSize: this.base_size/7, + fontFamily: 'Calibri', + fill: 'black', + text: +val[key].toFixed(2)+'', + width: this.base_size-5, + height: this.base_size-5, + ...off, + offset: { + x: (this.base_size-5)/2, + y: (this.base_size-5)/2, + } + } + }, + get_triangle_config: function(value, d) { + var rot = 0; + switch (d) { + case dir.UP: + rot = -90; + break; + case dir.RIGHT: + rot = 0; + break; + case dir.DOWN: + rot = 90; + break; + case dir.LEFT: + rot = 180; + break; + } + var $this = this; + var norma_value = (value-this.extreme_q_values.min)/((this.extreme_q_values.max-this.extreme_q_values.min)||1); + return { + sceneFunc: function(context, shape) { + context.beginPath(); + context.moveTo(0, 0); + context.lineTo($this.base_size / 2, $this.base_size / 2); + context.lineTo($this.base_size / 2, -$this.base_size / 2); + context.lineTo(0, 0); + context.closePath(); + // (!) Konva specific method, it is very important + context.fillStrokeShape(shape); + }, + fill: palette[Math.round(norma_value*99)], + stroke: 'black', + strokeWidth: 0, + rotation: rot, + } + }, + get_tile_config: function(i, t_type, local = false) { var pos = this.s2p(i); var over = {}; // not in plus if (local) { - if (!this.in_plus(this.s2p(i),{x:Math.round(this.state.x),y:Math.round(this.state.y)})) { + if (!this.in_plus(this.s2p(i), { + x: Math.round(this.state.x), + y: Math.round(this.state.y) + })) { over = { opacity: 0, fill: "#eee" }; - } else if (i != this.p2s(Math.round(this.state.x),Math.round(this.state.y))) { + } else if (i != this.p2s(Math.round(this.state.x), Math.round(this.state.y))) { over = { opacity: 1, fill: "#eee" @@ -299,4 +461,23 @@ app = new Vue({ } } }, + watch: { + learning_rate: function(new_val) { + machine.lr = new_val; + render_latex(); + }, + discount_factor: function(new_val) { + machine.df = new_val; + render_latex(); + }, + epsilon: function(new_val) { + machine.epsilon = new_val; + } + } }) + +function render_latex() { + // (1-lr) * Q[state, action] + lr * (reward + gamma * np.max(Q[new_state, :]) + katex.render(`Q(s,a)\\leftarrow${(1-machine.lr).toFixed(2)}Q(s,a)+${machine.lr.toFixed(2)}(reward + ${machine.df.toFixed(2)} * \\max_a(Q(s', a))`, document.getElementById('test'),{displayMode: true,}); +} +render_latex();