| """ | |
| API functions for sampling from anticipatory infilling models. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from anticipation import ops | |
| from anticipation.config import * | |
| from anticipation.vocab import * | |


def safe_logits(logits, idx):
    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
    logits[SPECIAL_OFFSET:] = -float('inf')                # don't generate special tokens

    # don't generate stuff in the wrong time slot
    if idx % 3 == 0:
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 1:
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 2:
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')

    return logits
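
# Note: events are encoded as triplets (arrival time, duration, note), so
# idx % 3 identifies which field is being sampled and safe_logits() masks
# every vocabulary range except the one that is legal at that position.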


def nucleus(logits, top_p):
    # adapted from the HuggingFace implementation of nucleus (top-p) sampling
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float("inf")

    return logits
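
# Toy example (illustrative only): with top_p = 0.5 on the distribution
# below, only the highest-probability entry survives the cumulative-mass
# cutoff; everything else is driven to -inf.
#
#   toy = torch.log(torch.tensor([0.6, 0.3, 0.1]))
#   filtered = nucleus(toy, top_p=0.5)   # tensor([-0.51, -inf, -inf])
#   F.softmax(filtered, dim=-1)          # tensor([1., 0., 0.])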


def future_logits(logits, curtime):
    """ don't sample events in the past """
    if curtime > 0:
        logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')

    return logits


def instr_logits(logits, full_history):
    """ don't sample more than 16 instruments """
    instrs = ops.get_instruments(full_history)
    if len(instrs) < 15:  # 16 - 1 to account for the reserved drum track
        return logits

    for instr in range(MAX_INSTR):
        if instr not in instrs:
            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')

    return logits


def add_token(model, z, tokens, top_p, current_time, debug=False):
    assert len(tokens) % 3 == 0

    history = tokens.copy()
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:]                           # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]]  # relativize time in the history buffer

    new_token = []
    with torch.no_grad():
        for i in range(3):
            input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
            logits = model(input_tokens).logits[0,-1]

            idx = input_tokens.shape[1]-1
            logits = safe_logits(logits, idx)
            if i == 0:
                logits = future_logits(logits, current_time - offset)
            elif i == 2:
                logits = instr_logits(logits, tokens)
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))

    new_token[0] += offset # revert to full sequence timing

    if debug:
        print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')

    return new_token
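
# Minimal sketch of how add_token() is driven (illustrative values; the
# sampling loops below wrap it with additional bookkeeping):
#
#   z = [AUTOREGRESS]      # global control code selecting autoregressive mode
#   tokens = []            # previously generated (time, duration, note) triplets
#   event = add_token(model, z, tokens, top_p=0.98, current_time=0)
#   tokens.extend(event)   # event is one new [time, duration, note] triplet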


def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)

    # treat events beyond start_time as controls
    future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(future)

    # clip controls that precede the sequence
    controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Controls')
        ops.print_tokens(controls)

    z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
    if debug:
        print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')

    # interleave the controls with the events
    tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))
    if debug:
        print('Prompt')
        ops.print_tokens(tokens)

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - ATIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            while current_time >= anticipated_time - delta:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - ANOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - ATIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time

            progress.update(dt)

    events, _ = ops.split(tokens)
    return ops.sort(ops.unpad(events) + future)
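
# Minimal usage sketch (the checkpoint name and MIDI helpers below are
# illustrative assumptions, not requirements of this module):
#
#   from transformers import AutoModelForCausalLM
#   from anticipation.sample import generate
#   from anticipation.convert import events_to_midi
#
#   model = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-medium-800k').cuda()
#   events = generate(model, start_time=0, end_time=10, top_p=0.98)
#   events_to_midi(events).save('generated.mid')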


def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []
    else:
        # treat controls as ordinary tokens
        controls = [token-CONTROL_OFFSET for token in controls]

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    inputs = ops.sort(inputs + controls)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
    if debug:
        print('Prompt')
        ops.print_tokens(prompt)

    # treat events beyond start_time as controls
    controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(controls)

    z = [AUTOREGRESS]
    if debug:
        print('AR Mode')

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    tokens = prompt
    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - TIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time

            # backfill anything that should have come before the new token
            while current_time >= anticipated_time:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - NOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - TIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            progress.update(dt)

    if anticipated_time != math.inf:
        tokens.extend([atime, adur, anote])

    return ops.sort(ops.unpad(tokens) + controls)
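
# Illustrative continuation example (the MIDI helper, file name, and parameter
# values are assumptions): generate_ar() extends an existing clip purely
# left-to-right, splicing events that fall after start_time back in as
# ordinary tokens rather than anticipated controls.
#
#   from anticipation.convert import midi_to_events
#
#   events = midi_to_events('prompt.mid')
#   continuation = generate_ar(model, start_time=5, end_time=20,
#                              inputs=events, top_p=0.95)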