-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_client.py
More file actions
89 lines (69 loc) · 2.47 KB
/
test_client.py
File metadata and controls
89 lines (69 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import typing
import requests
from gym import Env, spaces
class MouseAndCheese(Env):
    """Grid-world environment in which a mouse moves toward a piece of cheese.

    Positions are ``(row, col)`` tuples. The mouse starts at ``(0, 0)`` and
    the cheese at ``(4, 4)``. Every move is clamped so both coordinates stay
    within ``[0, max_position]``. Reaching the cheese ends the episode with
    reward 1; every other step gives reward 0.
    """

    mouse = (0, 0)
    cheese = (4, 4)
    number_of_actions = 4
    number_of_observations = 4
    steps = 0
    # Largest row/column index the mouse may occupy. It is also the divisor
    # that scales positions and distances into the [0, 1] observation range.
    max_position = 5

    def __init__(self):
        # One discrete action per direction; observations are four floats in [0, 1].
        self.action_space = spaces.Discrete(self.number_of_actions)
        self.observation_space = spaces.Box(0, 1, (self.number_of_observations,))

    def reset(self):
        """Return mouse and cheese to their start cells and give the first observation."""
        self.mouse = (0, 0)
        self.cheese = (4, 4)
        self.steps = 0
        return self.get_observation()

    def step(self, action):
        """Move the mouse one cell and return ``(observation, reward, done, info)``.

        Actions: 0 = up (row - 1), 1 = right (col + 1), 2 = down (row + 1),
        3 = left (col - 1). A move that would leave the grid keeps the mouse
        on the edge.

        Raises:
            ValueError: if ``action`` is not 0, 1, 2 or 3.
        """
        self.steps += 1
        row, col = self.mouse
        if action == 0:  # move up
            # Clamp at the top edge (row 0).
            row = max(row - 1, 0)
        elif action == 1:  # move right
            col = min(col + 1, self.max_position)
        elif action == 2:  # move down
            # Clamp at the bottom edge (row max_position).
            row = min(row + 1, self.max_position)
        elif action == 3:  # move left
            col = max(col - 1, 0)
        else:
            raise ValueError("Invalid action")
        # Only commit the move once the action is known to be valid.
        self.mouse = (row, col)
        return self.get_observation(), self.get_reward(), self.is_done(), {}

    def get_observation(self) -> typing.List[float]:
        """Return ``[row, col, |row gap|, |col gap|]``, each divided by ``max_position``.

        The gaps are the absolute distances between mouse and cheese per axis.
        """
        scale = float(self.max_position)
        return [
            self.mouse[0] / scale,
            self.mouse[1] / scale,
            abs(self.cheese[0] - self.mouse[0]) / scale,
            abs(self.cheese[1] - self.mouse[1]) / scale,
        ]

    def get_reward(self) -> float:
        """Return 1 when the mouse is on the cheese, otherwise 0."""
        return 1 if self.mouse == self.cheese else 0

    def is_done(self) -> bool:
        """Return True once the mouse has reached the cheese."""
        return self.mouse == self.cheese
# HTTP basic-auth credentials attached to every request the client sends.
# NOTE(review): hard-coded credentials — presumably the local dev server's
# defaults; confirm before pointing this client at anything else.
auth: typing.Tuple[str, str] = ("admin", "admin")

# The single environment instance the episode loop steps through.
env = MouseAndCheese()
def get_payload(obs, reward, done):
    """Package one transition as the JSON body sent to the collection endpoint.

    Args:
        obs: Indexable observation; elements 0-3 are read, in order, as the
            mouse row, mouse column, row distance and column distance.
        reward: Reward for the transition, passed through unchanged.
        done: Episode-finished flag, passed through unchanged.

    Returns:
        ``{"observation": {...}, "reward": reward, "done": done}``.

    Raises:
        IndexError: if ``obs`` has fewer than four elements.
    """
    field_names = ("mouse_row", "mouse_col", "mouse_row_dist", "mouse_col_dist")
    # Index positionally (not zip) so a short observation still raises.
    named_obs = {name: obs[idx] for idx, name in enumerate(field_names)}
    return {"observation": named_obs, "reward": reward, "done": done}
# Endpoint that records each transition and replies with the next action.
COLLECT_EXPERIENCE_URL = "http://localhost:8000/collect_experience/"
# Seconds to wait for the server; without a timeout requests may block forever.
REQUEST_TIMEOUT = 10


def _post_experience(obs, reward, done):
    """POST one transition to the collection endpoint and return the decoded JSON reply.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within REQUEST_TIMEOUT.
    """
    reply = requests.post(
        COLLECT_EXPERIENCE_URL,
        json=get_payload(obs, reward, done),
        auth=auth,
        timeout=REQUEST_TIMEOUT,
    )
    # Fail with the HTTP status rather than a confusing JSON/KeyError later.
    reply.raise_for_status()
    return reply.json()


for episode in range(100):
    obs = env.reset()
    reward = 0
    done = False

    while not done:
        # Report the current state and ask the server which action to take.
        response = _post_experience(obs, reward, done)
        actions = response.get("actions")
        if not actions:
            raise RuntimeError(f"Server reply has no 'actions': {response!r}")
        action = actions[0]
        obs, reward, done, info = env.step(action)

        if done:
            # Also report the terminal transition; its reply is not used.
            _post_experience(obs, reward, done)
            print(">>> Episode complete.")