| import random | |
| import numpy as np | |
| from scipy.misc import imresize | |
| class CatchEnv3: | |
| def __init__(self): | |
| self.size = 21 | |
| self.image = np.zeros((self.size, self.size)) | |
| self.state = [] | |
| self.fps = 4 | |
| self.output_shape = (84, 84) | |
| def reset_random(self): | |
| self.image.fill(0) | |
| self.pos = np.random.randint(2, self.size-2) | |
| self.vx = np.random.randint(5) - 2 | |
| self.vy = 1 | |
| self.ballx, self.bally = np.random.randint(self.size), 4 | |
| self.image[self.bally, self.ballx] = 1 | |
| self.image[-5, self.pos - 2:self.pos + 3] = np.ones(5) | |
| return self.step(2)[0] | |
| def step(self, action): | |
| def left(): | |
| if self.pos > 3: | |
| self.pos -= 2 | |
| def right(): | |
| if self.pos < 17: | |
| self.pos += 2 | |
| def noop(): | |
| pass | |
| {0: left, 1: right, 2: noop}[action]() | |
| self.image[self.bally, self.ballx] = 0 | |
| self.ballx += self.vx | |
| self.bally += self.vy | |
| if self.ballx > self.size - 1: | |
| self.ballx -= 2 * (self.ballx - (self.size-1)) | |
| self.vx *= -1 | |
| elif self.ballx < 0: | |
| self.ballx += 2 * (0 - self.ballx) | |
| self.vx *= -1 | |
| self.image[self.bally, self.ballx] = 1 | |
| self.image[-5].fill(0) | |
| self.image[-5, self.pos-2:self.pos+3] = np.ones(5) | |
| terminal = self.bally == self.size - 1 - 4 | |
| reward = int(self.pos - 2 <= self.ballx <= self.pos + 2) if terminal else 0 | |
| [self.state.append(imresize(self.image, (84, 84))) for _ in range(self.fps - len(self.state) + 1)] | |
| self.state = self.state[-self.fps:] | |
| self.state[0] = self.state[0][::-1,:] | |
| self.state[1] = self.state[1][::-1,:] | |
| self.state[2] = self.state[2][::-1,:] | |
| self.state[3] = self.state[3][::-1,:] | |
| return np.transpose(self.state, [1, 2, 0]), reward, terminal | |
| def get_num_actions(self): | |
| return 3 | |
| def reset(self): | |
| return self.reset_random() | |
| def state_shape(self): | |
| return (self.fps,) + self.output_shape | |
| def test(): | |
| env = CatchEnv2() | |
| i = 0 | |
| for ep in range(1): | |
| env.reset() | |
| state, reward, terminal = env.step(1) | |
| while not terminal: | |
| env.show_state(i) | |
| state, reward, terminal = env.step(1) | |
| state = np.squeeze(state) | |
| plt.imsave('image_'+str(i)+'.jpg', state) | |
| i += 1 | |
| if __name__ == "main": | |
| test() | |