Skip to content

Commit 6fe5595

Browse files
committed
transformer implementation
1 parent 5ca7e3f commit 6fe5595

File tree

141 files changed

+1028
-332
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

141 files changed

+1028
-332
lines changed

data.py

Lines changed: 116 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,117 @@
1-
import itertools
2-
from midi_handler import upperBound, lowerBound
31
import numpy as np
4-
5-
6-
def startSentinel():
7-
def noteSentinel(note):
8-
position = note
9-
part_position = [position]
10-
11-
pitchclass = (note + lowerBound) % 12
12-
part_pitchclass = [int(i == pitchclass) for i in range(12)]
13-
14-
return part_position + part_pitchclass + [0]*66 + [1]
15-
return [noteSentinel(note) for note in range(upperBound-lowerBound)]
16-
17-
18-
def getOrDefault(l, i, d):
19-
try:
20-
return l[i]
21-
except IndexError:
22-
return d
23-
24-
25-
def buildContext(state):
26-
context = [0]*12
27-
for note, notestate in enumerate(state):
28-
if notestate[0] == 1:
29-
pitchclass = (note + lowerBound) % 12
30-
context[pitchclass] += 1
31-
return context
32-
33-
34-
def buildBeat(time):
35-
return [2*x-1 for x in [time % 2,
36-
(time//2) % 2,
37-
(time//4) % 2,
38-
(time//8) % 2]]
39-
40-
41-
def noteInputForm(note, state, context, beat):
42-
position = note
43-
part_position = [position]
44-
45-
pitchclass = (note + lowerBound) % 12
46-
part_pitchclass = [int(i == pitchclass) for i in range(12)]
47-
# Concatenate the note states for the previous vicinity
48-
part_prev_vicinity = list(itertools.chain.from_iterable(
49-
(getOrDefault(state, note+i, [0, 0]) for i in range(-12, 13))))
50-
51-
part_context = context[pitchclass:] + context[:pitchclass]
52-
53-
return part_position \
54-
+ part_pitchclass \
55-
+ part_prev_vicinity \
56-
+ part_context \
57-
+ beat \
58-
+ [0]
59-
60-
61-
def noteStateSingleToInputForm(state, time):
62-
beat = buildBeat(time)
63-
context = buildContext(state)
64-
return [noteInputForm(note,
65-
state,
66-
context,
67-
beat) for note in range(len(state))]
68-
69-
70-
def noteStateMatrixToInputForm(statematrix):
71-
# NOTE: May have to transpose this or transform it
72-
# in some way to make Theano like it
73-
# [startSentinel()]
74-
75-
inputform = [noteStateSingleToInputForm(state, time)
76-
for time, state in enumerate(statematrix)]
77-
# print(np.array(inputform).shape)
78-
return inputform
2+
import pickle
3+
import hyper_params as hp
4+
5+
6+
def load_pieces(dirpath):
7+
# loads piano roll
8+
file = open(dirpath, "rb")
9+
pieces = pickle.load(file, encoding="latin1")
10+
pieces = clean_pieces(pieces)
11+
pieces, seqlens = pad_pieces_to_max(pieces)
12+
return pieces, seqlens
13+
14+
15+
def clean_pieces(pieces):
16+
def pad(chord):
17+
# pad to 4 voices
18+
padded = np.array(list(chord))
19+
while len(padded) < 4:
20+
padded = np.append([hp.REST], padded)
21+
return padded
22+
23+
def clean_piece(piece):
24+
# pad and serialize
25+
return np.array([pad(chord) for chord in piece]).flatten()
26+
pieces["train"] = np.array([clean_piece(p) for p in pieces["train"]])
27+
pieces["test"] = np.array([clean_piece(p) for p in pieces["test"]])
28+
pieces["valid"] = np.array([clean_piece(p) for p in pieces["valid"]])
29+
return pieces
30+
31+
32+
def pad_pieces_to_max(pieces):
33+
def pad_piece_to_max(piece):
34+
while len(piece) < hp.MAX_LEN + 1:
35+
piece = np.append([hp.PAD], piece)
36+
return piece
37+
38+
def seperate_long_piece(pieces_list):
39+
new_pieces = []
40+
for i in range(len(pieces_list)):
41+
piece = np.append(pieces_list[i], [hp.STOP])
42+
if len(pieces_list[i]) > hp.MAX_LEN + 1:
43+
new_pieces = new_pieces \
44+
+ [piece[j:j+hp.MAX_LEN+1] for j in
45+
range(0, len(piece), hp.SEPERATION)]
46+
else:
47+
new_pieces.append(pad_piece_to_max(piece))
48+
new_pieces = list(
49+
filter(lambda x: len(x) == hp.MAX_LEN + 1, new_pieces)
50+
)
51+
return np.array(new_pieces)
52+
53+
pieces["train"] = seperate_long_piece(pieces["train"])
54+
pieces["test"] = seperate_long_piece(pieces["test"])
55+
pieces["valid"] = seperate_long_piece(pieces["valid"])
56+
57+
seqlens = {
58+
"train": np.zeros(len(pieces["train"])),
59+
"test": np.zeros(len(pieces["test"])),
60+
"valid": np.zeros(len(pieces["valid"]))
61+
}
62+
63+
for i in range(len(pieces["train"])):
64+
seqlens["train"][i] = len(pieces["train"][i])
65+
for i in range(len(pieces["test"])):
66+
seqlens["test"][i] = len(pieces["test"][i])
67+
for i in range(len(pieces["valid"])):
68+
seqlens["valid"][i] = len(pieces["valid"][i])
69+
return pieces, seqlens
70+
71+
72+
def build_vocab(pieces):
73+
total_notes = np.hstack((pieces["train"].astype(int).flatten(),
74+
pieces["test"].astype(int).flatten(),
75+
pieces["valid"].astype(int).flatten()))
76+
vocabs = set(total_notes)
77+
vocabs.remove(hp.PAD)
78+
idx2token = {i+1: w for i, w in enumerate(vocabs)}
79+
token2idx = {w: i+1 for i, w in enumerate(vocabs)}
80+
idx2token[0] = hp.PAD
81+
token2idx[hp.PAD] = 0
82+
83+
# print(idx2token[0])
84+
# print(token2idx[hp.PAD])
85+
86+
return token2idx, idx2token
87+
88+
89+
def tokenize(pieces, token2idx, idx2token):
90+
pieces["train"] = np.array(
91+
[[token2idx[w] for w in p] for p in pieces["train"]]
92+
)
93+
pieces["test"] = np.array(
94+
[[token2idx[w] for w in p] for p in pieces["test"]]
95+
)
96+
pieces["valid"] = np.array(
97+
[[token2idx[w] for w in p] for p in pieces["valid"]]
98+
)
99+
return pieces
100+
101+
102+
def get_batch(pieces):
103+
batch_indices = np.random.choice(len(pieces["train"]),
104+
size=hp.BATCH_SIZE,
105+
replace=True)
106+
x = pieces["train"][batch_indices][:, :-1]
107+
y = pieces["train"][batch_indices][:, -1]
108+
# seqlens = seqlens["train"][batch_indices]
109+
# print(x[:, -5:])
110+
# print(y)
111+
return x.astype(int), y.astype(int)
112+
113+
114+
# pieces, seqlens = load_pieces("data/roll/jsb16.pkl")
115+
# get_batch(pieces, seqlens)
116+
# token2idx, idx2token = build_vocab(load_pieces("data/roll/jsb16.pkl")[0])
117+
# print(tokenize(load_pieces("data/roll/jsb16.pkl")[0], token2idx, idx2token)["train"])

data/midi/327/chpn_op66.mid

-36.8 KB
Binary file not shown.

data/roll/Piano-midi.de.pickle

7.07 MB
Binary file not shown.

data/roll/jsb16.pkl

8.12 MB
Binary file not shown.

data/roll/jsb4.pkl

1.96 MB
Binary file not shown.

data/roll/jsb8.pkl

4.06 MB
Binary file not shown.

hyper_params.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
BATCH_SIZE = 5 # size of batches
2+
NOTE_LEN = 78 # length of note sequence
3+
DIVISION_LEN = 16 # 1 bar of music
4+
5+
EPS = 1e-10 # episilon for log math
6+
DROPOUT = 0.3 # dropout rate
7+
8+
HIDDEN_SIZE = 512 # size of encoder decoder hidden dimension
9+
FF_SIZE = 2048 # size of feed forward dimension
10+
NUM_HEADS = 8 # number of attention head
11+
NUM_BLOCKS = 6 # number of encoder decoder blocks
12+
VOCAB_SIZE = 49
13+
14+
REST = 420
15+
STOP = 690
16+
PAD = 999
17+
MAX_LEN = 512
18+
SEPERATION = 16

0 commit comments

Comments
 (0)