diff --git a/css/src/style.scss b/css/src/style.scss
index 1b47f7c..619fd02 100644
--- a/css/src/style.scss
+++ b/css/src/style.scss
@@ -4,6 +4,7 @@
}
body{
// background-color: ;
+ font-family: sans-serif;
}
#container{
height: 100vh;
@@ -30,3 +31,11 @@ nav{
width: 20vw;
height: 10vw;
}
+
+.sliders{
+ position: absolute;
+ top: 7vh;
+ left: 2vw;
+ width: 20vw;
+ height: 10vw;
+}
diff --git a/css/style.min.css b/css/style.min.css
index c33b7f6..a983e74 100644
--- a/css/style.min.css
+++ b/css/style.min.css
@@ -1 +1 @@
-*{margin:0;padding:0}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px}.absolute{position:absolute;top:0;left:0}.plot{position:absolute;top:2vh;right:2vw;width:20vw;height:10vw}
+*{margin:0;padding:0}body{font-family:sans-serif}#container{height:100vh;position:relative}#canvas{height:100%}nav{position:absolute;top:10px;left:10px}.absolute{position:absolute;top:0;left:0}.plot{position:absolute;top:2vh;right:2vw;width:20vw;height:10vw}.sliders{position:absolute;top:7vh;left:2vw;width:20vw;height:10vw}
diff --git a/index.html b/index.html
index 1348449..dbf2b14 100644
--- a/index.html
+++ b/index.html
@@ -8,6 +8,13 @@
+
+
+
+
+
+
+
RL exhibit - prototype
@@ -27,25 +34,42 @@
-->
-
+
+
+
+
+
-
+
+
+
Learning Rate {{learning_rate}}
+
+
Discount Factor {{discount_factor}}
+
+
Epsilon {{epsilon}}
+
+
Current Score
+
{{score}}
+
+
diff --git a/js/rl.js b/js/rl.js
index 6ee35a5..b32da77 100644
--- a/js/rl.js
+++ b/js/rl.js
@@ -8,25 +8,29 @@ class RL_machine {
learning_rate,
discount_factor,
epsilon=0) {
- this.q_table = actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{}));
+ this.actions_per_state = actions_per_state;
this.transactions = transactions;
this.rewards = rewards;
this.lr = learning_rate;
this.df = discount_factor;
- this.state = start_state;
this.start_state = start_state;
this.end_score = end_score;
this.end_states = end_states;
- this.episode = 0;
this.epsilon = epsilon;
- this.score = 0;
- this.running = false;
- this.score_history = [];
+ this.q_table = this.actions_per_state.map((c) => c.reduce((o,n) => {o[n]=0; return o},{}));
+ this.reset_machine();
}
reset_machine(){
- this.q_table = this.q_table.map((c) => c.map((a) => a.fill(0)));
+ for (var q in this.q_table){ // zero every Q-value in place, so the existing q_table object (and any reference to it) is kept
+ for (var key in this.q_table[q]){ // each entry of q_table is an object keyed by action
+ this.q_table[q][key] = 0;
+ }
+ }
this.episode = 0;
this.state = this.start_state;
+ this.score = 0; // run-time bookkeeping is (re)initialised here; the constructor now calls reset_machine() instead of setting it itself
+ this.running = false;
+ this.score_history = []; // a new array is assigned rather than emptying the old one
}
new_episode(){
// add_new_episode_callback
@@ -127,6 +131,10 @@ class Maze {
for (let idy=0; idy index - N;
- })
- .reduce(
- function(last, current,index, arr){
- return (current/arr.length + last);
- },0);
- });
+ function(x2, i2) {
+ return i2 <= index && i2 > index - N;
+ })
+ .reduce(
+ function(last, current, index, arr) {
+ return (current / arr.length + last);
+ }, 0);
+ });
};
-Array.prototype.max=function() {
-return this.map(
- function(el,index, _arr) {
+Array.prototype.max = function() {
+ return this.map(
+ function(el, index, _arr) {
return _arr.filter(
- function(x2,i2) {
- return i2 <= index;
- })
- .reduce(
- function(last, current){
- return last > current ? last:current;
- },-1000000000);
- });
+ function(x2, i2) {
+ return i2 <= index;
+ })
+ .reduce(
+ function(last, current) {
+ return last > current ? last : current;
+ }, -1000000000);
+ });
};
app = new Vue({
el: '#app',
+ components: {
+ VueSlider: window['vue-slider-component']
+ },
data: {
width: 0,
height: 0,
q_table: machine.q_table,
maze: maze,
- state: {x:0,y:0},
+ state: {
+ x: 0,
+ y: 0
+ },
state_tween: new TimelineLite(),
score: machine.score,
score_history: machine.score_history,
labels: [],
+ learning_rate: machine.lr,
+ discount_factor: machine.df,
+ epsilon: machine.epsilon,
+ slider_config: {
+ min: 0,
+ max: 1,
+ duration: 0,
+ interval: 0.01,
+ tooltip: 'none'
+ }
},
created() {
// Resize handler
@@ -75,8 +93,13 @@ app = new Vue({
var $this = this;
this.state = this.s2p(s);
Object.defineProperty(machine, 'state', {
- get: function() { return this._state },
- set: function(ne) { this._state=ne; $this.handleState(this._state); }
+ get: function() {
+ return this._state
+ },
+ set: function(ne) {
+ this._state = ne;
+ $this.handleState(this._state);
+ }
});
machine.state = s;
// Score wrapper
@@ -84,8 +107,13 @@ app = new Vue({
var $this = this;
this.score = s;
Object.defineProperty(machine, 'score', {
- get: function() { return this._score },
- set: function(ne) { this._score=ne; $this.score=ne}
+ get: function() {
+ return this._score
+ },
+ set: function(ne) {
+ this._score = ne;
+ $this.score = ne
+ }
});
machine.score = s;
// Score history wrapper
@@ -93,8 +121,13 @@ app = new Vue({
var $this = this;
this.score_history = s;
Object.defineProperty(machine, 'score_history', {
- get: function() { return this._score_history },
- set: function(ne) { this._score_history=ne; $this.score_history=ne}
+ get: function() {
+ return this._score_history
+ },
+ set: function(ne) {
+ this._score_history = ne;
+ $this.score_history = ne
+ }
});
machine.score_history = s;
},
@@ -102,66 +135,84 @@ app = new Vue({
window.removeEventListener('resize', this.handleResize)
},
computed: {
- datacollection: function () {
+ datacollection: function() {
return {
labels: Array.from(Array(this.score_history.length).keys()),
- datasets: [
- {
+ datasets: [{
label: 'Data One',
backgroundColor: 'rgb(0,0,0,0)',
- data: this.score_history.simpleSMA(Math.round(50)),
+ data: this.score_history,//.simpleSMA(Math.round(50)),
fill: false,
borderColor: 'rgb(255, 159, 64)',
pointRadius: 1,
},
- {
- label: 'Data One',
- backgroundColor: 'rgb(0,0,0,0)',
- data: this.score_history.max(),
- fill: false,
- borderColor: 'rgb(64, 159, 255)',
- pointRadius: 1,
- },
+ // {
+ // label: 'Data One',
+ // backgroundColor: 'rgb(0,0,0,0)',
+ // data: this.score_history.max(),
+ // fill: false,
+ // borderColor: 'rgb(64, 159, 255)',
+ // pointRadius: 1,
+ // },
]
}
},
- stage_config: function () {
+ plot_options: function() {
+ var $this = this;
+ return {
+ responsive: true,
+ maintainAspectRatio: false,
+ scales: {
+ xAxes: [{
+ // type: 'linear',
+ ticks: {
+ maxTicksLimit: 8,
+ maxRotation: 0,
+ }
+ }]
+ },
+ legend: {
+ display: false
+ }
+ }
+ },
+ stage_config: function() {
return {
width: this.width,
height: this.height,
}
},
- mini_map_config: function () {
+ mini_map_config: function() {
return {
- x:this.width/2-(this.base_size*(this.maze.width)/2),
- y:this.height/2-(this.base_size*(this.maze.height)/2),
- scale:{
+ x: this.width / 2 - (this.base_size * (this.maze.width) / 2),
+ y: this.height / 2 - (this.base_size * (this.maze.height) / 2),
+ scale: {
x: 1,
y: 1
}
}
},
- local_layer: function () {
+ local_layer: function() {
return {
- x: this.width/2,
- y: this.height/2,
- scale:{
+ x: this.width / 2,
+ y: this.height / 2,
+ scale: {
x: 2,
y: 2
}
}
},
- map_config: function () {
+ map_config: function() {
return {
- x: this.base_size*(this.maze.width-this.state.x),
- y: this.base_size*(this.maze.height-this.state.y),
+ x: this.base_size * (this.maze.width - this.state.x),
+ y: this.base_size * (this.maze.height - this.state.y),
offset: {
- x: this.base_size*this.maze.width+this.base_size/2,
- y: this.base_size*this.maze.height+this.base_size/2,
+ x: this.base_size * this.maze.width + this.base_size / 2,
+ y: this.base_size * this.maze.height + this.base_size / 2,
}
}
},
- agent_config: function () {
+ agent_config: function() {
return {
sides: 5,
radius: this.base_size / 3,
@@ -172,43 +223,60 @@ app = new Vue({
x: -this.base_size / 2,
y: -this.base_size / 2
},
- x: this.base_size*this.state.x,
- y: this.base_size*this.state.y,
+ x: this.base_size * this.state.x,
+ y: this.base_size * this.state.y,
}
},
- base_size: function () {
- return Math.min(this.stage_config.height * 0.8 / this.maze.height, this.stage_config.width * 0.5 / this.maze.width);
+ base_size: function() {
+ return Math.min(this.stage_config.height * 0.8 / this.maze.height, this.stage_config.width * 0.5 / this.maze.width);
},
- strokeW: function () {
+ strokeW: function() {
return this.base_size / 50;
},
+ extreme_q_values: function(){
+ var max = -10*30;
+ var min = 10*30;
+ for (field in this.q_table) {
+ for (key in this.q_table[field]){
+ if (this.q_table[field][key]max){
+ max = this.q_table[field][key];
+ }
+ }
+ }
+ return {min:min,max:max};
+ }
},
methods: {
- s2p: function(state){
+ s2p: function(state) {
return {
- x: (state%this.maze.width),
- y: Math.floor(state/this.maze.width),
+ x: (state % this.maze.width),
+ y: Math.floor(state / this.maze.width),
}
},
- p2s: function(x,y){
- return x+y*this.maze.width;
+ p2s: function(x, y) {
+ return x + y * this.maze.width;
},
handleResize: function() {
this.width = window.innerWidth;
this.height = window.innerHeight;
},
handleState: function(s) {
- if (!machine.running){
- this.state_tween.to(this.state, 0.2, { x: this.s2p(s).x, y: this.s2p(s).y });
+ if (!machine.running) {
+ this.state_tween.to(this.state, 0.2, {
+ x: this.s2p(s).x,
+ y: this.s2p(s).y
+ });
} else {
this.state = this.s2p(s);
}
// this.hidden_state = s;
},
- get_grid_line_config: function (idx, y=false) {
- var offset = this.strokeW/2;
- if (y){
- var points = [-offset, Math.round(idx * this.base_size), this.base_size * this.maze.width + offset,Math.round(idx * this.base_size)];
+ get_grid_line_config: function(idx, y = false) {
+ var offset = this.strokeW / 2;
+ if (y) {
+ var points = [-offset, Math.round(idx * this.base_size), this.base_size * this.maze.width + offset, Math.round(idx * this.base_size)];
} else {
var points = [Math.round(idx * this.base_size), -offset, Math.round(idx * this.base_size), this.base_size * this.maze.height + offset];
}
@@ -218,34 +286,128 @@ app = new Vue({
strokeWidth: this.strokeW,
}
},
- get_tile_type: function (state){
+ get_tile_type: function(state) {
var pos = this.s2p(state);
- if (pos.y > maze.height){
+ if (pos.y > maze.height) {
return null;
- } else if (pos.x > maze.width){
+ } else if (pos.x > maze.width) {
return null;
} else {
return maze.map[pos.y][pos.x];
}
},
- in_plus: function (pos1, pos2) {
- if (Math.abs(pos1.x-pos2.x) + Math.abs(pos1.y-pos2.y) < 2) {
+ in_plus: function(pos1, pos2) {
+ if (Math.abs(pos1.x - pos2.x) + Math.abs(pos1.y - pos2.y) < 2) {
return true;
}
return false;
},
- get_tile_config: function (i, t_type, local=false) {
+ get_field_config: function(state) { // returns the {x, y} position used for the field belonging to a state index
+ var pos = this.s2p(state); // state index -> grid column/row
+ return {
+ x: this.base_size * pos.x+this.base_size/2, // half a base_size past the cell's origin, i.e. the middle of the cell
+ y: this.base_size * pos.y+this.base_size/2,
+ }
+ },
+ get_q_text_config: function (val, i) { // text config for one Q-value label; val is indexed by direction key, i (1..4) picks the direction
+ var off, key;
+ switch (i) { // 1 = UP, 2 = RIGHT, 3 = DOWN, 4 = LEFT; any other i leaves off and key undefined
+ case 1:
+ off = {
+ align: "center",
+ verticalAlign: "top",
+ };
+ key = dir.UP;
+ break;
+ case 2:
+ off = {
+ align: "right",
+ verticalAlign: "middle",
+ };
+ key = dir.RIGHT;
+ break;
+ case 3:
+ off = {
+ align: "center",
+ verticalAlign: "bottom",
+ };
+ key = dir.DOWN;
+ break;
+ case 4:
+ off = {
+ align: "left",
+ verticalAlign: "middle",
+ };
+ key = dir.LEFT;
+ break;
+ }
+ if (val[key] === undefined) { // no entry for this direction: return an empty config
+ return {}
+ }
+ return {
+ fontSize: this.base_size/7,
+ fontFamily: 'Calibri',
+ fill: 'black',
+ text: +val[key].toFixed(2)+'', // round to 2 decimals, unary + drops trailing zeros, +'' converts back to a string
+ width: this.base_size-5, // text box is 5 units smaller than a cell
+ height: this.base_size-5,
+ ...off, // alignment chosen in the switch above
+ offset: {
+ x: (this.base_size-5)/2, // offset by half the box size in each axis
+ y: (this.base_size-5)/2,
+ }
+ }
+ },
+ get_triangle_config: function(value, d) { // shape config for one triangle: rotation from direction d, fill colour from value
+ var rot = 0;
+ switch (d) { // the path below is drawn towards +x, so RIGHT needs no rotation
+ case dir.UP:
+ rot = -90;
+ break;
+ case dir.RIGHT:
+ rot = 0;
+ break;
+ case dir.DOWN:
+ rot = 90;
+ break;
+ case dir.LEFT:
+ rot = 180;
+ break;
+ }
+ var $this = this; // sceneFunc is a plain function, so the component is captured for use inside it
+ var norma_value = (value-this.extreme_q_values.min)/((this.extreme_q_values.max-this.extreme_q_values.min)||1); // scale value by the current min/max range; ||1 avoids dividing by zero when max == min
+ return {
+ sceneFunc: function(context, shape) { // custom path: from the origin to the two corners at x = base_size/2
+ context.beginPath();
+ context.moveTo(0, 0);
+ context.lineTo($this.base_size / 2, $this.base_size / 2);
+ context.lineTo($this.base_size / 2, -$this.base_size / 2);
+ context.lineTo(0, 0);
+ context.closePath();
+ // (!) Konva specific method, it is very important
+ context.fillStrokeShape(shape);
+ },
+ fill: palette[Math.round(norma_value*99)], // palette index 0..99 for a normalised value in [0, 1]; presumably palette has 100 entries - confirm
+ stroke: 'black',
+ strokeWidth: 0,
+ rotation: rot,
+ }
+ },
+ get_tile_config: function(i, t_type, local = false) {
var pos = this.s2p(i);
var over = {};
// not in plus
if (local) {
- if (!this.in_plus(this.s2p(i),{x:Math.round(this.state.x),y:Math.round(this.state.y)})) {
+ if (!this.in_plus(this.s2p(i), {
+ x: Math.round(this.state.x),
+ y: Math.round(this.state.y)
+ })) {
over = {
opacity: 0,
fill: "#eee"
};
- } else if (i != this.p2s(Math.round(this.state.x),Math.round(this.state.y))) {
+ } else if (i != this.p2s(Math.round(this.state.x), Math.round(this.state.y))) {
over = {
opacity: 1,
fill: "#eee"
@@ -299,4 +461,23 @@ app = new Vue({
}
}
},
+ watch: { // copies the app's learning_rate / discount_factor / epsilon data values onto the machine when they change
+ learning_rate: function(new_val) {
+ machine.lr = new_val;
+ render_latex(); // the rendered formula shows lr, so redraw it
+ },
+ discount_factor: function(new_val) {
+ machine.df = new_val;
+ render_latex(); // the rendered formula shows df, so redraw it
+ },
+ epsilon: function(new_val) {
+ machine.epsilon = new_val; // epsilon does not appear in the formula, so no re-render
+ }
+ }
})
+
+function render_latex() {
+ // (1-lr) * Q[state, action] + lr * (reward + gamma * np.max(Q[new_state, :])
+ katex.render(`Q(s,a)\\leftarrow${(1-machine.lr).toFixed(2)}Q(s,a)+${machine.lr.toFixed(2)}(reward + ${machine.df.toFixed(2)} * \\max_a(Q(s', a))`, document.getElementById('test'),{displayMode: true,});
+}
+render_latex();