Spaces:
Runtime error
Runtime error
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| import collections | |
| import random | |
class ReplayBuffer:
    """Bounded FIFO store of transitions with random mini-batch sampling.

    Each stored item is the 5-tuple
    ``(state, action, reward, next_state, done)`` exactly as passed to
    :meth:`add`.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        # A deque with ``maxlen`` discards its oldest entry when a new one
        # is appended to a full buffer.
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the end of the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions at random.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``
            whose entries are aligned by position. ``states``, ``actions``
            and ``next_states`` are converted with ``np.array``; ``rewards``
            and ``dones`` are returned as plain tuples.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            rewards,
            np.array(next_states),
            dones,
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)