-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #121 from NanFangDieDao/gs-midi-branch
Add Music Accompaniment Generator Project to ML-Nexus
- Loading branch information
Showing
18 changed files
with
751 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
2 changes: 2 additions & 0 deletions
2
Generative Models/Music_Accompaniment_Generator/.gitattributes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Auto detect text files and perform LF normalization | ||
* text=auto |
72 changes: 72 additions & 0 deletions
72
Generative Models/Music_Accompaniment_Generator/data_load.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import random | ||
|
||
import torch | ||
from miditok import REMI, TokenizerConfig | ||
from utils import midi_to_array | ||
from tqdm import tqdm | ||
from train_parameters import max_len | ||
|
||
# Tokenizer configuration. The same tokenizer instance is imported by the
# generation script, so training and inference share one vocabulary.
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": {(0, 4): 8, (4, 12): 4},
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS"],
    "use_chords": True,
    "use_rests": False,
    "use_tempos": True,
    "use_programs": True,
    # 60..250 inclusive spans 191 integer values.
    "num_tempos": 191,
    "tempo_range": (60, 250),
    "program_changes": True,
    "programs": [-1, 0, 24, 27, 30, 33, 36],
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Creates the tokenizer
tokenizer = REMI(config)

# NOTE(review): assumes tokenizer.vocab is a {token: id} dict, as documented
# for single-vocabulary MidiTok tokenizers -- confirm for the pinned version.
word2idx = tokenizer.vocab
# Invert using the stored ids instead of enumerate(): enumeration only gives
# the right answer when dict iteration order happens to equal the id values.
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_len = len(word2idx)
|
||
|
||
def data_load(data_type, split, data_len, x_folder, y_folder):
    """Tokenize numbered MIDI pairs from two folders into tensors.

    Files are addressed as ``f"{folder}{i}.mid"``, so both folder arguments
    must already end with a path separator.

    Args:
        data_type: ``"train"`` selects indices ``0 .. split - 1``; ``"test"``
            selects ``split .. data_len - 1``. Any other value falls back to
            the two indices ``0`` and ``1``.
        split: Index at which the train range ends and the test range starts.
        data_len: Total number of numbered files available.
        x_folder: Folder prefix of the input (source) MIDI files.
        y_folder: Folder prefix of the target MIDI files.

    Returns:
        A tuple ``(x, y)`` of tensors built with ``torch.tensor`` from the
        per-file results of ``midi_to_array`` (presumably one row of
        ``max_len`` token ids per file -- depends on ``midi_to_array``).
    """
    print("---Data Load Start!---")

    if data_type == "test":
        indices = range(split, data_len)
    elif data_type == "train":
        indices = range(0, split)
    else:
        indices = (0, 1)

    sources, targets = [], []
    for file_no in tqdm(indices, desc="Data Loading...", unit="data"):
        # Source first, then target, for every index (same order as before).
        for folder, bucket in ((x_folder, sources), (y_folder, targets)):
            encoded = midi_to_array(
                tokenizer=tokenizer,
                midifile=f"{folder}{file_no}.mid",
                max_len=max_len)
            bucket.append(encoded)

    x = torch.tensor(sources)
    y = torch.tensor(targets)
    print("---Data Load Completed!---")
    return x, y
|
||
|
||
def get_batch_indices(total_length, batch_size):
    """Yield shuffled mini-batches of sample indices covering one full pass.

    The indices ``0 .. total_length - 1`` are shuffled once, then handed out
    in consecutive slices of ``batch_size``. The final batch is shorter when
    ``total_length`` is not a multiple of ``batch_size``.

    Args:
        total_length: Number of samples in the dataset.
        batch_size: Number of indices per batch; must satisfy
            ``1 <= batch_size <= total_length``.

    Yields:
        ``(batch, offset)`` tuples, where ``batch`` is a list of indices and
        ``offset`` is the position of its first element in the shuffled order.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds
            ``total_length``. Raised on first iteration, because this is a
            generator.
    """
    # Explicit raises instead of ``assert``: asserts vanish under ``python -O``
    # and a non-positive batch size used to spin forever on empty batches.
    if batch_size < 1:
        raise ValueError('Batch size must be a positive integer.')
    if batch_size > total_length:
        raise ValueError('Batch size is larger than total data length. '
                         'Check your data or change batch size.')

    indexes = list(range(total_length))
    random.shuffle(indexes)
    for start in range(0, total_length, batch_size):
        yield indexes[start:start + batch_size], start
29 changes: 29 additions & 0 deletions
29
Generative Models/Music_Accompaniment_Generator/midi_generate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from utils import merge_midi_tracks, generate_midi_v2 | ||
from data_load import tokenizer | ||
from transformer import Transformer | ||
from data_load import vocab_len | ||
import torch | ||
|
||
from train_parameters import (max_len, d_model, d_ff, n_layers, | ||
heads, dropout_rate, PAD_ID) | ||
|
||
if __name__ == '__main__':
    # Generate one accompaniment track per instrument for a single melody,
    # then merge the melody and every generated track into one MIDI file.
    import os

    instruments = ['Drum', 'Bass', 'Guitar', 'Piano']
    src_midi = "./HMuseData/Melody2Drum/Melody/69.mid"
    output_dir = "./MIDIs/output_MIDI"
    # The writers below are given paths inside this folder; create it up front.
    os.makedirs(output_dir, exist_ok=True)

    track_paths = []
    for instrument in instruments:
        print(f"-----Loading {instrument} model-----")
        model = Transformer(src_vocab_size=vocab_len, dst_vocab_size=vocab_len, pad_idx=PAD_ID, d_model=d_model,
                            d_ff=d_ff, n_layers=n_layers, heads=heads, dropout=dropout_rate, max_seq_len=max_len)
        # Keep loading the historical "<instrument>2.pth" checkpoint when it
        # exists; otherwise use the name train_model() actually writes.
        legacy_path = f"./models/model_{instrument}/model_{instrument}2.pth"
        trained_path = f"./models/model_{instrument}/model_{instrument}_{max_len}.pth"
        model_path = legacy_path if os.path.exists(legacy_path) else trained_path
        # Map to CPU so loading works on machines without Apple MPS.
        # load_state_dict() copies values into the model's existing
        # parameters (presumably on CPU, as the model is never moved here).
        model.load_state_dict(
            torch.load(
                model_path,
                map_location=torch.device('cpu')))
        # Inference only: switch off dropout.
        model.eval()
        print(f"-----{instrument} model loaded!-----")
        print(f"-----Generating {instrument} track-----")
        track_path = f"{output_dir}/{instrument}_track.mid"
        generate_midi_v2(model=model, tokenizer=tokenizer, src_midi=src_midi, max_len=max_len, PAD_ID=PAD_ID,
                         tgt_midi=track_path)
        track_paths.append(track_path)
        print(f"-----{instrument} track generated!-----")
    # Positional order follows ``instruments``: Drum, Bass, Guitar, Piano.
    merge_midi_tracks(src_midi, *track_paths,
                      tgt_dir=f"{output_dir}/generated_midi.mid")
79 changes: 79 additions & 0 deletions
79
...dels/Music_Accompaniment_Generator/readme/MIDI Music Accompaniment Generator.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# MIDI Music Accompaniment Generator | ||
|
||
In this project, I implemented an automatic music accompaniment model. This model can take a single-track melody MIDI file as input and output a corresponding accompaniment MIDI file featuring other instruments such as guitar, bass, and drums. To accomplish this, I also carried out several supporting pieces of work, which I describe below. | ||
|
||
## Rule-based music score alignment algorithm | ||
|
||
Among the numerous MIDI files collected in various ways for training purposes, a significant portion consists of "real" MIDI files generated from instrumental performances. The notes and events in these real performances contain many factors, such as the performer's emotions, and are often dynamic and uncertain. Therefore, there are timing errors in the performance beats, making it difficult to align precisely with the time axis corresponding to the beats. This results in the events recorded in these MIDI files being advanced or delayed relative to the standard musical score on the time axis, further generating inaccurate notes during the music conversion process. Specifically, this manifests in the score as the appearance of numerous sixteenth notes, thirty-second notes, dotted notes, and so forth. This will make it difficult for the model to learn the correct characteristics of the music and significantly reduce learning efficiency. To address this issue, I have developed a rule-based music score alignment algorithm, with the following effects: | ||
|
||
Before the algorithm: | ||
|
||
 | ||
|
||
After the algorithm: | ||
|
||
 | ||
|
||
## Tokenization | ||
|
||
Convert MIDI files into tensor arrays for input into the model. For example: | ||
|
||
Given a MIDI music file: | ||
|
||
 | ||
|
||
Tokenize it and turn it into a tensor array: | ||
|
||
 | ||
|
||
## Build the model | ||
|
||
Build a model based on the Transformer. The structure of the model is shown in the diagram below. | ||
|
||
 | ||
|
||
## Train and Evaluate | ||
|
||
Train the model and then evaluate its performance. The specific results are shown in the tables below. The values of the hyperparameters adjusted during the training process are also presented below. | ||
|
||
Hyperparameters: | ||
|
||
| Hyperparameters | Value | | ||
| :-------------: | :----: | | ||
| lr | 0.0001 | | ||
| d_model | 512 | | ||
| d_ff | 2048 | | ||
| n_layers | 6 | | ||
| heads | 8 | | ||
| dropout_rate | 0.2 | | ||
| n_epochs | 60 | | ||
|
||
Train accuracy: | ||
|
||
| Model | Accuracy | | ||
| :---------------: | :------: | | ||
| DrumTransformer | 88.7% | | ||
| PianoTransformer | 91.6% | | ||
| GuitarTransformer | 85.3% | | ||
| BassTransformer | 89.7% | | ||
|
||
Test accuracy: | ||
|
||
| Model | Accuracy | | ||
| :---------------: | :------: | | ||
| DrumTransformer | 75.3% | | ||
| PianoTransformer | 71.5% | | ||
| GuitarTransformer | 64.3% | | ||
| BassTransformer | 67.1% | | ||
|
||
## Result | ||
|
||
Finally, let's demonstrate the effectiveness of the model. | ||
|
||
For a given main melody MIDI music file: | ||
|
||
 | ||
|
||
The model can generate the auto_accompaniment MIDI file: | ||
|
||
 |
Binary file added
BIN
+263 KB
Generative Models/Music_Accompaniment_Generator/readme/accompaniment.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+87.7 KB
Generative Models/Music_Accompaniment_Generator/readme/main_melody.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+95.9 KB
Generative Models/Music_Accompaniment_Generator/readme/midi_music.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+87 KB
Generative Models/Music_Accompaniment_Generator/readme/tokenization.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions
5
Generative Models/Music_Accompaniment_Generator/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mido==1.2.10 | ||
torch==2.0.0 | ||
miditok==3.0.1 | ||
tqdm==4.62.3 | ||
pretty_midi==0.2.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from transformer import Transformer | ||
from data_load import vocab_len | ||
from train_model import train_model | ||
from train_parameters import (max_len, batch_size, lr, d_model, d_ff, n_layers, | ||
heads, dropout_rate, n_epochs, PAD_ID, device, | ||
print_interval, data_split_rate, len_Dataset) | ||
|
||
|
||
if __name__ == '__main__':
    # Train one freshly initialised Transformer per accompaniment instrument,
    # using the shared hyperparameters from train_parameters.
    instruments = ['Drum', 'Bass', 'Guitar', 'Piano']
    model_kwargs = dict(
        src_vocab_size=vocab_len,
        dst_vocab_size=vocab_len,
        pad_idx=PAD_ID,
        d_model=d_model,
        d_ff=d_ff,
        n_layers=n_layers,
        heads=heads,
        dropout=dropout_rate,
        max_seq_len=max_len,
    )
    for instrument in instruments:
        model = Transformer(**model_kwargs)
        # Dataset size differs per instrument, so it is looked up each time.
        train_model(model=model, data_split_rate=data_split_rate,
                    data_len=len_Dataset[instrument], batch_size=batch_size,
                    lr=lr, n_epochs=n_epochs, PAD_ID=PAD_ID, device=device,
                    print_interval=print_interval, instrument=instrument)
62 changes: 62 additions & 0 deletions
62
Generative Models/Music_Accompaniment_Generator/train_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from data_load import get_batch_indices, data_load | ||
from train_parameters import max_len | ||
from tqdm import tqdm | ||
import torch | ||
from torch import nn | ||
import time | ||
import os | ||
|
||
|
||
def train_model(model, data_split_rate, data_len, batch_size, lr,
                n_epochs, PAD_ID, device, print_interval, instrument):
    """Train ``model`` to map melodies to one instrument and save its weights.

    Args:
        model: Network called as ``model(source_tokens, target_input_tokens)``.
        data_split_rate: Fraction of ``data_len`` used as the training range.
        data_len: Number of numbered MIDI pairs available for ``instrument``.
        batch_size: Samples per optimisation step.
        lr: Learning rate passed to Adam.
        n_epochs: Number of passes over the training range.
        PAD_ID: Token id excluded from both the loss and the accuracy.
        device: Device the model and each batch are moved to.
        print_interval: Log loss/accuracy every this many steps.
        instrument: Instrument name; selects the data folders and the
            checkpoint file name.

    Side effects:
        Writes ``models/model_{instrument}/model_{instrument}_{max_len}.pth``
        once training has finished, and prints progress to stdout.
    """
    print(f"--------Train Model For {instrument} Start!--------")
    x_folder = f"./HMuseData/Melody2{instrument}/Melody/"
    y_folder = f"./HMuseData/Melody2{instrument}/{instrument}/"
    split = round(data_len * data_split_rate)
    x, y = data_load(data_type="train", split=split, data_len=data_len,
                     x_folder=x_folder, y_folder=y_folder)

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    started_at = time.time()
    step = 0
    for _epoch in range(n_epochs):
        batches = get_batch_indices(len(x), batch_size)
        for batch_ids, _offset in tqdm(batches, desc="Processing", unit="batches"):
            src = torch.LongTensor(x[batch_ids]).to(device)
            tgt = torch.LongTensor(y[batch_ids]).to(device)
            # Teacher forcing: feed tokens [0, T-1), predict tokens [1, T).
            decoder_in, target = tgt[:, :-1], tgt[:, 1:]
            logits = model(src, decoder_in)

            # Token accuracy over non-padding positions only.
            not_pad = target != PAD_ID
            hits = torch.argmax(logits, -1) == target
            acc = torch.sum(not_pad * hits) / torch.sum(not_pad)

            # Flatten batch and time so the loss sees one row per token.
            rows, cols = target.shape
            flat_logits = torch.reshape(logits, (rows * cols, -1))
            flat_target = torch.reshape(target, (rows * cols, ))
            loss = criterion(flat_logits, flat_target)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            if step % print_interval == 0:
                whole_minutes, rest = divmod(time.time() - started_at, 60)
                minutes = int(whole_minutes)
                seconds = int(rest)
                print(f'{step:08d} {minutes:02d}:{seconds:02d}'
                      f' loss: {loss.item()} acc: {acc.item()}')
            step += 1

    checkpoint_dir = f"models/model_{instrument}/"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = f"{checkpoint_dir}model_{instrument}_{max_len}.pth"
    torch.save(model.state_dict(), checkpoint_file)

    print(f'Model saved to {checkpoint_file}')
    print(f"--------Train Model For {instrument} Completed!--------")
18 changes: 18 additions & 0 deletions
18
Generative Models/Music_Accompaniment_Generator/train_parameters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Shared hyperparameters and dataset constants for training and generation.
import torch

# Optimisation
batch_size = 16
lr = 0.0001
n_epochs = 60
print_interval = 100  # log every N optimisation steps

# Model size
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2
max_len = 750  # sequence length passed to the tokenizer helper and the model

# Id treated as padding by the loss and the model.
# NOTE(review): presumably the id of the "PAD" special token -- confirm
# against the tokenizer vocabulary.
PAD_ID = 0

# Pick the best available backend instead of hard-coding Apple MPS, so the
# same file runs on CUDA machines and on CPU-only machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Fraction of each dataset used for training; the remainder is the test range.
data_split_rate = 0.99
# Number of numbered MIDI pairs available per instrument.
len_Dataset = {
    'Drum': 18621,
    'Bass': 14316,
    'Guitar': 20037,
    'Piano': 11684,
}
Oops, something went wrong.