-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
39 lines (29 loc) · 1.26 KB
/
main.py
File metadata and controls
39 lines (29 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from agent import DQN_Stable_Baselines_Agent
from agent import PPO_Stable_Baselines_Agent
from agent import PPO_LSTM_Stable_Baselines_Agent
from env.snake_env import SnakeEnv
from config import DQN_TRAIN_TIMESTEPS, DQN_TRAIN_EPOCHS, PPO_TRAIN_TIMESTEPS, PPO_TRAIN_EPOCHS, PPO_LSTM_TRAIN_EPOCHS, \
PPO_LSTM_TRAIN_TIMESTEPS
from ui.snake_renderer import SnakeRenderer
from config.generic import GRID_SIZE, CELL_SIZE, SPEED
from agent.human import HumanAgent
def main():
    """Build the renderer, agent and environment, then run the environment.

    The active configuration instantiates ``PPO_LSTM_Stable_Baselines_Agent``
    and calls ``env.run`` once, asking it to close the renderer afterwards.
    The commented-out sections are alternative setups (human play, other
    agents, a train/test epoch loop) kept for manual toggling.
    """
    renderer = SnakeRenderer(grid_size=GRID_SIZE, cell_size=CELL_SIZE, render_rate=SPEED)

    # Human play
    # agent = HumanAgent()
    # env = SnakeEnv(grid_size=GRID_SIZE, renderer=renderer, agent=agent)
    # env.run()

    # Stable Baselines 3 agents: leave exactly one `agent = ...` line active.
    # agent = DQN_Stable_Baselines_Agent()
    # agent = PPO_Stable_Baselines_Agent()
    agent = PPO_LSTM_Stable_Baselines_Agent()

    env = SnakeEnv(grid_size=GRID_SIZE, renderer=renderer, agent=agent)

    # Train/test loop (disabled). When re-enabling for a different agent,
    # switch to the matching *_TRAIN_EPOCHS / *_TRAIN_TIMESTEPS constants.
    # for epoch in range(PPO_LSTM_TRAIN_EPOCHS):
    #     agent.score = 0
    #     print(f"Training epoch {epoch}...")
    #     env.train_agent(timesteps=PPO_LSTM_TRAIN_TIMESTEPS)
    #
    #     print(f"Testing epoch {epoch}...")
    #     env.run(close_renderer_after_run=False)

    # Test
    # NOTE(review): with the training loop disabled, the agent is run without
    # any training in this script -- presumably it loads saved weights itself;
    # confirm against the agent implementation.
    env.run(close_renderer_after_run=True)


if __name__ == "__main__":
    main()