|
1 |
| -import itertools |
2 |
| -from midi_handler import upperBound, lowerBound |
3 | 1 | import numpy as np
|
4 |
| - |
5 |
| - |
6 |
| -def startSentinel(): |
7 |
| - def noteSentinel(note): |
8 |
| - position = note |
9 |
| - part_position = [position] |
10 |
| - |
11 |
| - pitchclass = (note + lowerBound) % 12 |
12 |
| - part_pitchclass = [int(i == pitchclass) for i in range(12)] |
13 |
| - |
14 |
| - return part_position + part_pitchclass + [0]*66 + [1] |
15 |
| - return [noteSentinel(note) for note in range(upperBound-lowerBound)] |
16 |
| - |
17 |
| - |
18 |
| -def getOrDefault(l, i, d): |
19 |
| - try: |
20 |
| - return l[i] |
21 |
| - except IndexError: |
22 |
| - return d |
23 |
| - |
24 |
| - |
25 |
| -def buildContext(state): |
26 |
| - context = [0]*12 |
27 |
| - for note, notestate in enumerate(state): |
28 |
| - if notestate[0] == 1: |
29 |
| - pitchclass = (note + lowerBound) % 12 |
30 |
| - context[pitchclass] += 1 |
31 |
| - return context |
32 |
| - |
33 |
| - |
34 |
| -def buildBeat(time): |
35 |
| - return [2*x-1 for x in [time % 2, |
36 |
| - (time//2) % 2, |
37 |
| - (time//4) % 2, |
38 |
| - (time//8) % 2]] |
39 |
| - |
40 |
| - |
41 |
| -def noteInputForm(note, state, context, beat): |
42 |
| - position = note |
43 |
| - part_position = [position] |
44 |
| - |
45 |
| - pitchclass = (note + lowerBound) % 12 |
46 |
| - part_pitchclass = [int(i == pitchclass) for i in range(12)] |
47 |
| - # Concatenate the note states for the previous vicinity |
48 |
| - part_prev_vicinity = list(itertools.chain.from_iterable( |
49 |
| - (getOrDefault(state, note+i, [0, 0]) for i in range(-12, 13)))) |
50 |
| - |
51 |
| - part_context = context[pitchclass:] + context[:pitchclass] |
52 |
| - |
53 |
| - return part_position \ |
54 |
| - + part_pitchclass \ |
55 |
| - + part_prev_vicinity \ |
56 |
| - + part_context \ |
57 |
| - + beat \ |
58 |
| - + [0] |
59 |
| - |
60 |
| - |
61 |
| -def noteStateSingleToInputForm(state, time): |
62 |
| - beat = buildBeat(time) |
63 |
| - context = buildContext(state) |
64 |
| - return [noteInputForm(note, |
65 |
| - state, |
66 |
| - context, |
67 |
| - beat) for note in range(len(state))] |
68 |
| - |
69 |
| - |
70 |
| -def noteStateMatrixToInputForm(statematrix): |
71 |
| - # NOTE: May have to transpose this or transform it |
72 |
| - # in some way to make Theano like it |
73 |
| - # [startSentinel()] |
74 |
| - |
75 |
| - inputform = [noteStateSingleToInputForm(state, time) |
76 |
| - for time, state in enumerate(statematrix)] |
77 |
| - # print(np.array(inputform).shape) |
78 |
| - return inputform |
| 2 | +import pickle |
| 3 | +import hyper_params as hp |
| 4 | + |
| 5 | + |
def load_pieces(dirpath):
    """Load a pickled piano-roll dataset and prepare it for training.

    Parameters
    ----------
    dirpath : str
        Path to a pickle file holding a dict with "train"/"test"/"valid"
        piece lists.

    Returns
    -------
    tuple
        (pieces, seqlens) as produced by pad_pieces_to_max.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code from the file --
    # only load data files you control.
    # Use a context manager so the handle is closed even on error
    # (the original opened the file and never closed it).
    with open(dirpath, "rb") as file:
        pieces = pickle.load(file, encoding="latin1")
    pieces = clean_pieces(pieces)
    pieces, seqlens = pad_pieces_to_max(pieces)
    return pieces, seqlens
| 13 | + |
| 14 | + |
def clean_pieces(pieces):
    """Normalize every piece to 4-voice chords and flatten it.

    Each chord is left-padded with hp.REST until it has 4 entries, then
    the piece is serialized into a single 1-D note sequence. Mutates and
    returns the same dict (keys "train", "test", "valid").
    """
    def _pad_chord(chord):
        # Prepend REST tokens until the chord holds exactly 4 voices.
        voices = np.array(list(chord))
        while len(voices) < 4:
            voices = np.append([hp.REST], voices)
        return voices

    def _serialize(piece):
        # Pad every chord, then flatten time-major into one sequence.
        return np.array([_pad_chord(c) for c in piece]).flatten()

    for split in ("train", "test", "valid"):
        pieces[split] = np.array([_serialize(p) for p in pieces[split]])
    return pieces
| 30 | + |
| 31 | + |
def pad_pieces_to_max(pieces):
    """Cut/pad every piece to exactly hp.MAX_LEN + 1 tokens.

    After appending hp.STOP, long pieces are split into windows of
    hp.MAX_LEN + 1 tokens taken every hp.SEPERATION steps; short pieces
    are left-padded with hp.PAD. Windows that come out shorter than
    hp.MAX_LEN + 1 (the tail of a long piece) are discarded.

    Returns
    -------
    tuple
        (pieces, seqlens): the mutated dict, plus a dict mapping each
        split to a float array of per-piece lengths (all equal to
        hp.MAX_LEN + 1 after this step).
    """
    target = hp.MAX_LEN + 1

    def _pad_to_max(piece):
        # Left-pad in one shot: the original looped np.append one token
        # at a time, which copies the array per iteration (O(n^2)).
        deficit = target - len(piece)
        if deficit > 0:
            piece = np.append([hp.PAD] * deficit, piece)
        return piece

    def _separate_long_pieces(pieces_list):
        # (Internal typo "seperate" fixed; external interface unchanged.)
        windows = []
        for original in pieces_list:
            piece = np.append(original, [hp.STOP])
            if len(original) > target:
                windows.extend(piece[j:j + target]
                               for j in range(0, len(piece), hp.SEPERATION))
            else:
                windows.append(_pad_to_max(piece))
        # Drop tail windows shorter than the target length.
        windows = [w for w in windows if len(w) == target]
        return np.array(windows)

    for split in ("train", "test", "valid"):
        pieces[split] = _separate_long_pieces(pieces[split])

    # dtype=float matches the original np.zeros-based construction.
    seqlens = {split: np.array([len(p) for p in pieces[split]], dtype=float)
               for split in ("train", "test", "valid")}
    return pieces, seqlens
| 70 | + |
| 71 | + |
def build_vocab(pieces):
    """Build token<->index vocabulary maps over all three splits.

    hp.PAD is reserved index 0; every other note token gets a positive
    index. Tokens are sorted so the mapping is reproducible across runs
    (plain set iteration order is an implementation detail).

    Returns
    -------
    tuple
        (token2idx, idx2token) dicts that are exact inverses.
    """
    total_notes = np.hstack((pieces["train"].astype(int).flatten(),
                             pieces["test"].astype(int).flatten(),
                             pieces["valid"].astype(int).flatten()))
    vocab = set(total_notes)
    # discard() is a no-op when PAD is absent; the original remove()
    # raised KeyError on data that happened to contain no padding.
    vocab.discard(hp.PAD)
    # Build both maps from ONE enumeration so they cannot drift apart
    # (the original enumerated the set twice and relied on stable order).
    idx2token = {0: hp.PAD}
    token2idx = {hp.PAD: 0}
    for idx, token in enumerate(sorted(vocab), start=1):
        idx2token[idx] = token
        token2idx[token] = idx
    return token2idx, idx2token
| 87 | + |
| 88 | + |
def tokenize(pieces, token2idx, idx2token):
    """Map every note token in each split to its vocabulary index.

    The idx2token argument is accepted for interface symmetry but is not
    used here. Mutates and returns the same pieces dict.
    """
    for split in ("train", "test", "valid"):
        encoded = [[token2idx[note] for note in piece]
                   for piece in pieces[split]]
        pieces[split] = np.array(encoded)
    return pieces
| 100 | + |
| 101 | + |
def get_batch(pieces):
    """Sample a random training batch of input/target pairs.

    Draws hp.BATCH_SIZE training sequences (with replacement) and splits
    each into an input prefix and a final-token target.

    Returns
    -------
    tuple
        (x, y) int arrays: x is every token but the last of each sampled
        sequence, y is the last token.
    """
    picks = np.random.choice(len(pieces["train"]),
                             size=hp.BATCH_SIZE,
                             replace=True)
    # Index once and reuse the sampled rows for both views.
    batch = pieces["train"][picks]
    inputs = batch[:, :-1]
    targets = batch[:, -1]
    return inputs.astype(int), targets.astype(int)
| 112 | + |
| 113 | + |
| 114 | +# pieces, seqlens = load_pieces("data/roll/jsb16.pkl") |
| 115 | +# get_batch(pieces, seqlens) |
| 116 | +# token2idx, idx2token = build_vocab(load_pieces("data/roll/jsb16.pkl")[0]) |
| 117 | +# print(tokenize(load_pieces("data/roll/jsb16.pkl")[0], token2idx, idx2token)["train"]) |
0 commit comments