example.py
import gym
from trainers import PPOMultiAgentTrainer
from utils import *  # provides the ProtoMLP and ProtoLSTMNet architecture specs used below
import pickle

# Environment configuration: number of agents and how far each agent can see around itself.
numAgents = 2
agentViewRadius = 5

# Map sketches: '@' cells mark apple patches, spaces are empty ground.
smallMap = [
    list(' '),
    list(' @ '),
    list(' @@@ '),
    list(' @@@ '),
    list(' @ '),
    list(' '),
    list(' '),
    list(' '),
]

mediumMap = [
    list(' @ @ @ @ @ '),
    list('@@@ @@@ @@@ @@@ @@@ '),
    list(' @ @ @ @ @ '),
    list(' '),
    list(' @ @ @ @ '),
    list(' @@@ @@@ @@@ @@@ '),
    list(' @ @ @ @ '),
]

# Build the multi-agent environment from the chosen map sketch.
env = gym.make('CommonsGame:CommonsGame-v0', numAgents=numAgents,
               visualRadius=agentViewRadius, mapSketch=mediumMap)

# Policy architecture specs: a 256-unit ReLU MLP feeding a 128-unit LSTM.
archSpecs = [ProtoMLP([256], ['relu'], useBias=True), ProtoLSTMNet([128])]

# Training schedule, logging, and checkpointing settings.
maxEpisodes = 10000
maxEpisodeLength = 1000
updatePeriod = 2000
logPeriod = 20
savePeriod = 100
logPath = 'system_{}_agents.data'.format(numAgents)
loadModel = False  # set to True to load and evaluate a saved system from logPath


def main():
    if loadModel:
        # Restore a previously trained system and run it in evaluation mode.
        trainer = PPOMultiAgentTrainer(env, modelPath=logPath)
        trainer.test(maxEpisodeLength)
    else:
        # Train a fresh system with the architecture specified above.
        trainer = PPOMultiAgentTrainer(env, neuralNetSpecs=archSpecs, learningRate=0.002)
        trainer.train(maxEpisodes, maxEpisodeLength, logPeriod, savePeriod, logPath)


if __name__ == '__main__':
    main()
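Running `python example.py` as-is trains the two-agent PPO system from scratch on `mediumMap`; setting `loadModel = True` instead restores the system saved at `logPath` and runs it in evaluation mode via `trainer.test` (the training run presumably writes its checkpoints to the same `logPath`, since that is what the evaluation branch reloads). Note that the `'CommonsGame:CommonsGame-v0'` id tells Gym to import the `CommonsGame` package before looking up the environment, so that package must be installed for the environment to register.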