| """ | |
| Top-level functions for preprocessing data to be used for training. | |
| """ | |
| from tqdm import tqdm | |
| import numpy as np | |
| from anticipation import ops | |
| from anticipation.config import * | |
| from anticipation.vocab import * | |
| from anticipation.convert import compound_to_events, midi_to_interarrival | |
| def extract_spans(all_events, rate): | |
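    """Mark events inside randomly sampled spans as anticipated controls.

    Alternates between anticipated spans of DELTA seconds and gaps whose
    lengths are sampled from an exponential distribution with the given
    rate. Events inside a span are offset by CONTROL_OFFSET and returned
    as controls; all other events are returned unchanged.
    """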
    events = []
    controls = []

    span = True
    next_span = end_span = TIME_OFFSET+0
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet

        # end of an anticipated span; decide when to do it again (next_span)
        if span and time >= end_span:
            span = False
            next_span = time + int(TIME_RESOLUTION*np.random.exponential(1./rate))

        # anticipate a DELTA-second span
        if (not span) and time >= next_span:
            span = True
            end_span = time + DELTA*TIME_RESOLUTION

        if span:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls


ANTICIPATION_RATES = 10
def extract_random(all_events, rate):
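    """Mark a random subset of events as anticipated controls.

    Each (time, duration, note) triple is independently marked as a control
    (offset by CONTROL_OFFSET) with probability rate/ANTICIPATION_RATES.
    """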
    events = []
    controls = []

    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet

        if np.random.random() < rate/float(ANTICIPATION_RATES):
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls


def extract_instruments(all_events, instruments):
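    """Mark all events played by the given instruments as anticipated controls.

    Events whose instrument code falls in `instruments` are offset by
    CONTROL_OFFSET and returned as controls; all other events are returned
    unchanged.
    """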
    events = []
    controls = []

    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note < CONTROL_OFFSET          # shouldn't be in the sequence yet
        assert note not in [SEPARATOR, REST]  # these shouldn't either

        instr = (note-NOTE_OFFSET)//2**7
        if instr in instruments:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls


def maybe_tokenize(compound_tokens):
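    """Convert a compound token sequence to events, filtering unusable tracks.

    Returns (events, truncations, status): status 0 on success, 1 for a
    track that is too short, 2 for one that is too long, and 3 for one with
    more instruments than MIDI channels. On failure, events and truncations
    are None.
    """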
    # skip sequences with very few events
    if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
        return None, None, 1 # short track

    events, truncations = compound_to_events(compound_tokens, stats=True)
    end_time = ops.max_time(events, seconds=False)

    # don't want to deal with extremely short tracks
    if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
        return None, None, 1 # short track

    # don't want to deal with extremely long tracks
    if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
        return None, None, 2 # long track

    # skip sequences with more instruments than MIDI channels (16)
    if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
        return None, None, 3 # too many instruments

    return events, truncations, 0


def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
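    """Tokenize a list of pre-parsed MIDI files with the interarrival encoding.

    Re-tokenizes the MIDI file behind each .compound.txt datafile using
    midi_to_interarrival, concatenates the token streams, and writes
    fixed-length sequences of CONTEXT_SIZE tokens to the output file.
    Returns a tuple of summary statistics.
    """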
    assert augment_factor == 1 # can't augment interarrival-tokenized data

    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                _, _, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            filename = filename[:-len('.compound.txt')] # get the original MIDI

            # already parsed; shouldn't raise an exception
            tokens, truncations = midi_to_interarrival(filename, stats=True)

            tokens[0:0] = [MIDI_SEPARATOR]
            concatenated_tokens.extend(tokens)
            all_truncations += truncations

            # write out full sequences to file
            while len(concatenated_tokens) >= CONTEXT_SIZE:
                seq = concatenated_tokens[0:CONTEXT_SIZE]
                concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
                outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                seqcount += 1

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)


def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
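    """Tokenize a list of pre-parsed MIDI files with the arrival-time encoding.

    Applies augment_factor random augmentations per file (no augmentation,
    span, random-event, or instrument anticipation), interleaves the
    resulting controls via ops.anticipate, and writes fixed-length training
    sequences to the output file. Returns a tuple of summary statistics.
    """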
    tokens = []
    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            instruments = list(ops.get_instruments(all_events).keys())
            end_time = ops.max_time(all_events, seconds=False)

            # different random augmentations
            for k in range(augment_factor):
                if k % 10 == 0:
                    # no augmentation
                    events = all_events.copy()
                    controls = []
                elif k % 10 == 1:
                    # span augmentation
                    lmbda = .05
                    events, controls = extract_spans(all_events, lmbda)
                elif k % 10 < 6:
                    # random augmentation
                    r = np.random.randint(1, ANTICIPATION_RATES)
                    events, controls = extract_random(all_events, r)
                else:
                    if len(instruments) > 1:
                        # instrument augmentation: at least one, but not all instruments
                        u = 1+np.random.randint(len(instruments)-1)
                        subset = np.random.choice(instruments, u, replace=False)
                        events, controls = extract_instruments(all_events, subset)
                    else:
                        # no augmentation
                        events = all_events.copy()
                        controls = []

                if len(concatenated_tokens) == 0:
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

                all_truncations += truncations
                events = ops.pad(events, end_time)
                rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
                tokens, controls = ops.anticipate(events, controls)
                assert len(controls) == 0 # should have consumed all controls (because of padding)
                tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
                concatenated_tokens.extend(tokens)

                # write out full sequences to file
                while len(concatenated_tokens) >= EVENT_SIZE*M:
                    seq = concatenated_tokens[0:EVENT_SIZE*M]
                    concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]

                    # relativize time to the context
                    seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
                    assert ops.min_time(seq, seconds=False) == 0
                    if ops.max_time(seq, seconds=False) >= MAX_TIME:
                        stats[3] += 1
                        continue

                    # if seq contains SEPARATOR, global controls describe the first sequence
                    seq.insert(0, z)

                    outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                    seqcount += 1

                    # grab the current augmentation controls if we didn't already
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
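

if __name__ == '__main__':
    # Minimal usage sketch (paths here are hypothetical): in practice a
    # separate preprocessing script would shard datafiles across worker
    # processes, passing a distinct idx to each call so the tqdm progress
    # bars stack cleanly.
    import glob
    files = sorted(glob.glob('data/**/*.compound.txt', recursive=True))
    results = tokenize(files, 'train.tokens.txt', augment_factor=10, debug=True)
    print('sequences written:', results[0], 'truncations:', results[-1])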