-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Riley McDowell
committed
Feb 8, 2017
1 parent
2c072de
commit 76cd55a
Showing
12 changed files
with
395 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
*.pkl | ||
*.csv | ||
*.swp | ||
*.swo | ||
*.ini | ||
*egg-info* | ||
*build | ||
*__pycache__* | ||
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pygame>=1.9.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
from setuptools import setup, find_packages | ||
|
||
# Utility function to read the README file. | ||
def read(fname): | ||
return open(os.path.join(os.path.dirname(__file__), fname)).read() | ||
|
||
setup( | ||
name = "snakenet", | ||
version = "0.1", | ||
author = "Riley McDowell", | ||
author_email = "[email protected]", | ||
description = "A playable implementation of the 'Snake' game, complete" \ | ||
"complete with neuralnet player", | ||
license = "MIT", | ||
packages=find_packages(), | ||
long_description=read('README.md'), | ||
entry_points={'console_scripts': ['snake=snakenet.main:main']} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
BLACK = (0, 0, 0) | ||
WHITE = (255, 255, 255) | ||
GRAY = (128, 128, 128) | ||
RED = (255, 0, 0) | ||
GREEN = (0, 255, 0) | ||
BLUE = (0, 0, 255) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pygame | ||
|
||
from snakenet.colors import * | ||
from snakenet.game_constants import PIXEL_SIZE, PAD_ROWS, PAD_COLUMNS, RESOLUTION | ||
|
||
def _color_cell(pix_array, row, column, color): | ||
row_start = row*PIXEL_SIZE + PIXEL_SIZE*PAD_ROWS | ||
row_end = row_start + PIXEL_SIZE | ||
column_start = column*PIXEL_SIZE + PIXEL_SIZE*PAD_COLUMNS | ||
column_end = column_start + PIXEL_SIZE | ||
pix_array[column_start:column_end, row_start:row_end] = color | ||
|
||
def draw_plane(game): | ||
game.window_surface.fill(GRAY) | ||
|
||
pix_array = pygame.PixelArray(game.window_surface) | ||
|
||
# Black Board | ||
pix_array[PIXEL_SIZE:RESOLUTION[1]-PIXEL_SIZE, PIXEL_SIZE:RESOLUTION[0]-PIXEL_SIZE] = BLACK | ||
|
||
# Draw the snake. | ||
for (row, column) in game.state.snake_deque: | ||
_color_cell(pix_array, row, column, WHITE) | ||
|
||
# Draw the food | ||
_color_cell(pix_array, game.state.food_position[0], game.state.food_position[1], RED) | ||
|
||
# Do cleanup | ||
del pix_array | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pygame | ||
|
||
from snakenet.game_constants import RESOLUTION, RIGHT, LEFT, UP, DOWN | ||
from snakenet.game_state import GameState | ||
|
||
FLAGS = 0 # No flags | ||
COLOR_DEPTH = 32 # bits | ||
CAPTION = "Snake!" | ||
|
||
class Game(object): | ||
def __init__(self): | ||
self.window_surface = pygame.display.set_mode(RESOLUTION, FLAGS, COLOR_DEPTH) | ||
pygame.display.set_caption(CAPTION) | ||
self.state = GameState() | ||
|
||
def keypress(self, direction): | ||
if direction == UP: | ||
self.state.last_pressed = UP | ||
elif direction == DOWN: | ||
self.state.last_pressed = DOWN | ||
elif direction == RIGHT: | ||
self.state.last_pressed = RIGHT | ||
elif direction == LEFT: | ||
self.state.last_pressed = LEFT | ||
|
||
def move(self): | ||
self.state.process_move() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from collections import namedtuple | ||
|
||
# Screen Constants | ||
PIXEL_SIZE = 8 | ||
NUM_ROWS = 32 + 1 | ||
NUM_COLUMNS = 32 + 1 | ||
PAD_ROWS = 1 | ||
PAD_COLUMNS = 1 | ||
RESOLUTION = (PIXEL_SIZE * (NUM_ROWS+PAD_ROWS*2), PIXEL_SIZE * (NUM_COLUMNS+PAD_COLUMNS*2)) | ||
|
||
# Directions | ||
DOWN = 'd' | ||
UP = 'u' | ||
RIGHT = 'r' | ||
LEFT = 'l' | ||
|
||
# Snake Constants | ||
SNAKE_INITIALSIZE = 4 | ||
SNAKE_GROWBY = 3 | ||
|
||
# Plane Constants | ||
SNAKE_VALUE = 0 | ||
EMPTY_VALUE = 1 | ||
ITEM_VALUE = 2 | ||
|
||
Score = namedtuple('Score', ['food', 'moves']) | ||
|
||
# Exceptional States | ||
class LoseException(Exception): | ||
def __init__(self, food, moves): | ||
self.score = Score(food, moves) | ||
msg = "You Lose. Ate {} times. Moved {} times.".format(self.score.food, self.score.moves) | ||
super(LoseException, self).__init__(msg) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
|
||
from snakenet.game_constants import NUM_ROWS, NUM_COLUMNS, SNAKE_INITIALSIZE, SNAKE_VALUE, EMPTY_VALUE, LoseException, SNAKE_GROWBY | ||
from snakenet.game_constants import UP, DOWN, RIGHT, LEFT | ||
from collections import deque | ||
|
||
class GameState(object): | ||
def __init__(self): | ||
self.plane = np.full((NUM_ROWS, NUM_COLUMNS), fill_value=EMPTY_VALUE, dtype=np.uint8) | ||
self.snake_deque = deque(maxlen=SNAKE_INITIALSIZE) | ||
self.snake_position = None | ||
self.food_position = None | ||
self.last_pressed = None | ||
|
||
# Setup the initial game state. | ||
self.initialize() | ||
self.set_random_food_position() | ||
|
||
# Track score | ||
self.times_eaten = 0 | ||
self.moves = 0 | ||
|
||
def initialize(self): | ||
middle_row = NUM_ROWS // 2 | ||
middle_column = NUM_COLUMNS // 2 | ||
|
||
for i in range(SNAKE_INITIALSIZE): | ||
# End the last loop exactly in the middle. | ||
row = middle_row + SNAKE_INITIALSIZE - i - 1 | ||
self.set_snake(row, middle_column) | ||
|
||
def set_random_food_position(self): | ||
while True: | ||
row = np.random.choice(np.arange(0, NUM_ROWS, dtype=np.uint8)) | ||
column = np.random.choice(np.arange(0, NUM_COLUMNS, dtype=np.uint8)) | ||
if (row, column) in self.snake_deque: | ||
continue | ||
else: | ||
self.food_position = (row, column) | ||
break | ||
|
||
def eat_food(self): | ||
self.snake_deque = deque(self.snake_deque, maxlen=self.snake_deque.maxlen + SNAKE_GROWBY) | ||
self.set_random_food_position() | ||
self.times_eaten += 1 | ||
|
||
def lose(self): | ||
raise LoseException(self.times_eaten, self.moves) | ||
|
||
def set_snake(self, row, column): | ||
position = (row, column) | ||
if row < 0 or row >= NUM_ROWS: | ||
self.lose() | ||
if column < 0 or column >= NUM_COLUMNS: | ||
self.lose() | ||
if position in self.snake_deque: | ||
self.lose() | ||
|
||
self.plane[row, column] = SNAKE_VALUE | ||
self.snake_deque.appendleft((row, column)) | ||
self.snake_position = (row, column) | ||
|
||
if position == self.food_position: | ||
self.eat_food() | ||
|
||
@property | ||
def row(self): | ||
return self.snake_position[0] | ||
|
||
@property | ||
def column(self): | ||
return self.snake_position[1] | ||
|
||
def process_move(self): | ||
if self.last_pressed == UP: | ||
self.set_snake(self.row - 1, self.column) | ||
if self.last_pressed == DOWN: | ||
self.set_snake(self.row + 1, self.column) | ||
if self.last_pressed == RIGHT: | ||
self.set_snake(self.row, self.column + 1) | ||
if self.last_pressed == LEFT: | ||
self.set_snake(self.row, self.column - 1) | ||
self.moves += 1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
|
||
import sys | ||
import pygame | ||
import pygame.locals as pl | ||
|
||
from argparse import ArgumentParser | ||
|
||
from snakenet.draw import draw_plane | ||
from snakenet.game import Game | ||
from snakenet.game_constants import DOWN, UP, RIGHT, LEFT | ||
from snakenet.model_player import get_model_keypress | ||
|
||
QUIT = 'quit' | ||
TICK = pygame.USEREVENT + 1 | ||
|
||
def parse_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument('--input', choices=['user', 'model'], default='user') | ||
return parser.parse_args() | ||
|
||
ARGS = parse_args() | ||
|
||
def process_tick(game): | ||
game.move() | ||
|
||
def process_keypress(game, key): | ||
if key == pygame.K_UP: | ||
game.keypress(UP) | ||
if key == pygame.K_DOWN: | ||
game.keypress(DOWN) | ||
if key == pygame.K_RIGHT: | ||
game.keypress(RIGHT) | ||
if key == pygame.K_LEFT: | ||
game.keypress(LEFT) | ||
|
||
def process_event(game, event): | ||
if event.type == pl.QUIT: | ||
sys.exit(0) | ||
if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: | ||
sys.exit(0) | ||
|
||
if event.type == pygame.KEYDOWN: | ||
if ARGS.input == 'user': | ||
process_keypress(game, event.key) | ||
|
||
if event.type == TICK: | ||
if ARGS.input == 'model': | ||
key = get_model_keypress(game) | ||
process_keypress(game, key) | ||
# Update the game. | ||
process_tick(game) | ||
# Refresh the image. | ||
draw_plane(game) | ||
pygame.display.update() | ||
|
||
def process_events(game, clock): | ||
for event in pygame.event.get(): | ||
process_event(game, event) | ||
|
||
def mainloop(game): | ||
clock = pygame.time.Clock() | ||
pygame.time.set_timer(TICK, 100) | ||
while True: | ||
process_events(game, clock) | ||
clock.tick(60) | ||
|
||
def main(): | ||
pygame.init() | ||
game = Game() | ||
|
||
# Do an initial draw. | ||
draw_plane(game) | ||
pygame.display.update() | ||
|
||
# Enter the mainloop. | ||
try: | ||
mainloop(game) | ||
except (KeyboardInterrupt, SystemExit): | ||
pygame.quit() | ||
sys.exit() | ||
except Exception as e: | ||
import traceback | ||
traceback.print_exc() | ||
exit(1) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import pygame | ||
|
||
def get_model_keypress(game): | ||
return pygame.K_RIGHT |
Oops, something went wrong.