-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dyna_q.py
118 lines (88 loc) · 3.57 KB
/
test_dyna_q.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from mdp.algorithms.dyna_q import DynaQ
from mdp.environment.env import Environment
import numpy as np
import pygame
pygame.init()
from env_config import grid, actions, rewards, gw, gh
# Arrow glyph drawn for each action index (the argmax over a cell's Q-values).
CONVERT_POLICY = {1: '↓', 0: '↑', 2: '→', 3: '←'}
# Set to False to only print the learned values without opening pygame windows.
DISPLAY_GRID = True
# Font sizes and (x, y) pixel offsets of the text drawn inside each grid cell.
UTILITY_FONT_SIZE, UTILITY_OFFSET = 15, (4, 14)
POLICY_FONT_SIZE, POLICY_OFFSET = 30, (17, 5)
# Scale factor applied to the font sizes and text offsets above.
ratio = 1

# Build the environment from env_config and learn a Q-table with Dyna-Q.
mdp = Environment(grid, actions, rewards, gw, gh)
# NOTE(review): the 6x6 size is hard-coded; presumably it must match the
# dimensions of env_config's grid — confirm against DynaQ / env_config.
dyna_q = DynaQ(n_w=6, n_h=6, n_actions=4)
q_table = dyna_q.solve(mdp=mdp)

# Dump the best action value of every cell, labelled as (column, row).
print('\n(Column, Row)')
for row_idx, row_values in enumerate(q_table):
    for col_idx, action_values in enumerate(row_values):
        print(f"{col_idx, row_idx}: {max(action_values)}")
# Display utility and policy plot
if DISPLAY_GRID:
    GREEN = (100, 200, 100)
    RED = (200, 100, 100)
    WHITE = (200, 200, 200)
    GREY = (50, 50, 50)
    # Background colour per grid-cell symbol; any other symbol is drawn WHITE.
    CELL_COLORS = {'W': GREY, 'G': GREEN, 'R': RED}
    FONT_PATH = "assets/seguisym.ttf"  # resolved relative to the working directory
    FPS = 30  # cap the redraw rate so the event loop does not spin a CPU core

    # Per-cell labels from the Q-table: the greedy action's arrow and the best
    # action value. Both are indexed [row][col], in the same layout as `grid`.
    directions = [[CONVERT_POLICY[np.argmax(cell)] for cell in row] for row in q_table]
    utilities = [["{:.3f}".format(np.max(cell)) for cell in row] for row in q_table]
    colors = [[CELL_COLORS.get(cell, WHITE) for cell in row] for row in grid]

    block_size = 50
    # Derive the window size from the grid instead of hard-coding 300x300, so
    # grids of any shape (including non-square ones) are drawn completely.
    width = block_size * max((len(grid_row) for grid_row in grid), default=0)
    height = block_size * len(grid)
    screen_dimensions = (width, height)
    screen_color = (0, 0, 0)
    policy_font = pygame.font.Font(FONT_PATH, int(POLICY_FONT_SIZE * ratio))
    utility_font = pygame.font.Font(FONT_PATH, int(UTILITY_FONT_SIZE * ratio))

    def _show_grid(labels, font, offset):
        """Draw the coloured grid with one text label per cell until the window is closed.

        Args:
            labels: 2-D list of strings; labels[row][col] is drawn in that cell.
            font: pygame font used to render the labels.
            offset: (x, y) pixel offset of a label inside its cell, scaled by ``ratio``.
        """
        screen = pygame.display.set_mode(screen_dimensions)
        pygame.display.set_caption('Dyna Q')
        clock = pygame.time.Clock()
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
            pygame.draw.rect(screen, screen_color, pygame.Rect(0, 0, width, height))
            # Iterate each row's own length so non-square grids are handled.
            for row, grid_row in enumerate(grid):
                for col, cell in enumerate(grid_row):
                    rect = pygame.Rect(col * block_size, row * block_size, block_size, block_size)
                    pygame.draw.rect(screen, colors[row][col], rect)
                    pygame.draw.rect(screen, (0, 0, 0), rect, 1)  # 1px black cell border
                    if cell == 'W':
                        continue  # 'W' cells are drawn grey and carry no label
                    message = font.render(labels[row][col], True, (0, 0, 0))
                    screen.blit(message, (col * block_size + offset[0] * ratio,
                                          row * block_size + offset[1] * ratio))
            pygame.display.update()
            clock.tick(FPS)

    # Display Policy
    _show_grid(directions, policy_font, POLICY_OFFSET)
    # Display Utilities
    _show_grid(utilities, utility_font, UTILITY_OFFSET)
    # Uninitialise pygame once both windows have been closed.
    pygame.quit()