MIDI-160: refactoring configuration #6

Open · wants to merge 47 commits into base: master

47 commits
6da26c9
remove model_args, add block_size to hydra configs, make vocab_size a…
WojciechMat Jan 28, 2025
f039d8b
rename sequence_length and block_size to , do not concatenate dicts i…
WojciechMat Jan 29, 2025
f713204
Minor cleanup
roszcz Jan 30, 2025
690502f
Move LR scheduler definition out of the training flow
roszcz Jan 30, 2025
a33c9bb
rename tokenizer in checkpoint to tokenizer_desc, remove tokenizer fi…
WojciechMat Jan 30, 2025
eb590f3
make pretraining and finetuning debugging configs, do not concatenate…
WojciechMat Jan 31, 2025
1545cd8
eval should work for pretraining too
WojciechMat Jan 31, 2025
6bc7b39
eval should work for pretraining too
WojciechMat Jan 31, 2025
7b4b980
bump requirements versions and make them explicit, remove .pt suffix …
WojciechMat Feb 1, 2025
b55fb80
make less validation dataloader processes - the same number as traini…
WojciechMat Feb 1, 2025
ec87e5a
use run stage for evaluation, add tokenized_every dataset config for …
WojciechMat Feb 4, 2025
fe5283f
Change `ParametricTaskManager` to `PianoTaskManager`
SamuelJanas Feb 5, 2025
df3bf32
change roszcz/ -> epr-labs/ for all dataset occurrences
SamuelJanas Feb 5, 2025
67ac227
change naming: sustain-v2 -> augmented
SamuelJanas Feb 5, 2025
b6839f7
Decrease onboarding friction
roszcz Feb 5, 2025
67ae0b6
Training loop and configuration docs, TODO comment regarding gradient…
WojciechMat Feb 6, 2025
86ebb63
external evaluation documentation
WojciechMat Feb 6, 2025
2017129
fix eval samplers, control splits used for evaluation
roszcz Feb 7, 2025
969a163
fix eval script samplers container
roszcz Feb 7, 2025
8861e2a
fix wandb for piano metrics
roszcz Feb 7, 2025
d6c7ae6
cleanup eval config
roszcz Feb 7, 2025
ff5e733
start organizing bigger datasets
roszcz Feb 8, 2025
8a94b24
Switch to private hf repos
roszcz Feb 8, 2025
517c0f1
Finetuning logging config is named subsequence - not finetuning
WojciechMat Feb 8, 2025
ec68c94
Simplify the pretraining config
roszcz Feb 8, 2025
5d3518b
unlock chorales dataset
roszcz Feb 8, 2025
bb5859e
Simplify the record lengths management for next token dataset
roszcz Feb 8, 2025
53d2b5b
add well named configs
roszcz Feb 8, 2025
7931f9d
Add PIAST dataset
roszcz Feb 8, 2025
220c52e
improve logs
roszcz Feb 9, 2025
a6d5247
Optimize PIANO dataset record lengths management
roszcz Feb 9, 2025
240bcfc
clean run name, clean lr config
roszcz Feb 9, 2025
c78fd99
Update the augmentation script
roszcz Feb 9, 2025
a8a9709
Augmentation helpers
roszcz Feb 9, 2025
9150583
Increase number of shards to speedup dataset build
roszcz Feb 9, 2025
276cfe7
load dataset speedup :rocket:
roszcz Feb 9, 2025
a26d95e
Use all augmented datasets
roszcz Feb 9, 2025
4e733d9
It was in fact not that ugly
WojciechMat Feb 10, 2025
d28780d
load piano tasks config with hydra
WojciechMat Feb 10, 2025
d3a8e00
load piano tasks in eval script, from training
WojciechMat Feb 10, 2025
613287b
Update upload script
roszcz Feb 10, 2025
2769ff6
make piano tasks config be a list
WojciechMat Feb 11, 2025
3d4d5bb
Remove indirectness from configs
roszcz Feb 11, 2025
988057b
Support multiple LR scheduler types
roszcz Feb 12, 2025
c7bf95d
refactor the device setup phase
roszcz Feb 14, 2025
c9e9e4b
Refactor management of tokens related to music knowledge
roszcz Feb 14, 2025
195ab5e
Artifacts cleanup in PianoDataset
roszcz Feb 14, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
rev: 23.7.0
hooks:
- id: black
args: [--line-length=130]
args: [--line-length=120]
additional_dependencies: ['click==8.0.4']
- repo: https://github.com/pycqa/isort
rev: 5.12.0
33 changes: 25 additions & 8 deletions README.md
@@ -1,5 +1,23 @@
# Piano-GPT: MIDI Piano Music Generation

## Quickstart

Train a 10M model:

```sh
# This will create checkpoints in ./tmp/checkpoints and logs in wandb
python -m gpt2.train dataset=small model=gpt2_10M

# No wandb, small memory footprint
python -m gpt2.train dataset=small model=gpt2_10M data.batch_size=2 logging.wandb_log=false
```

Calculate PIANO metrics:

```sh
python -m gpt2.high_level_piano_eval init_from=<checkpoint path>
```

## Overview

Piano-GPT is a project leveraging the GPT-2 architecture for generating and processing MIDI piano music. It introduces the PIANO (Performance Inference And Note Orchestration) dataset, a multi-task benchmark for voice and dynamic reconstruction in MIDI piano rolls.
@@ -28,8 +46,7 @@ The PIANO dataset is designed to standardize approaches and provide a benchmark

## Project Structure

- `artifacts.py`: Utility functions and constants
- `checkpoints/`: Saved model checkpoints
- `tmp/checkpoints/`: Saved model checkpoints
- `dashboards/`: Streamlit dashboards for data visualization
- `data/`: Dataset handling and preprocessing modules
- `database/`: Database connection and management utilities
@@ -66,8 +83,8 @@ gpt2/train.py --config-name=gpt2_pretraining \
data.batch_size=32 \
optimizer.gradient_accumulation_steps=8 \
optimizer.max_iters=30000 \
data.sequence_length=4096 \
dataset.extra_datasets="['roszcz/maestro-sustain-v2', 'roszcz/giant-midi-sustain-v2', 'roszcz/pianofor-ai-sustain-v2']" \
data.context_size=4096 \
dataset.extra_datasets="['epr-labs/maestro-sustain-v2', 'epr-labs/giant-midi-sustain-v2', 'epr-labs/pianofor-ai-sustain-v2']" \
dataset.augmentation.max_pitch_shift=5 \
"dataset.augmentation.speed_change_factors=[0.975, 0.95, 1.025, 1.05]" \
lr.warmup_iters=1000 \
@@ -88,9 +105,9 @@ tasks = subsequence \
data.batch_size=64 \
optimizer.gradient_accumulation_steps=4 \
optimizer.max_iters=30000 \
data.sequence_length=1024 \
data.context_size=1024 \
data.notes_per_record=128 \
dataset.extra_datasets="['roszcz/maestro-sustain-v2', 'roszcz/giant-midi-sustain-v2', 'roszcz/pianofor-ai-sustain-v2']" \
dataset.extra_datasets="['epr-labs/maestro-sustain-v2', 'epr-labs/giant-midi-sustain-v2', 'epr-labs/pianofor-ai-sustain-v2']" \
dataset.augmentation.max_pitch_shift=5 \
dataset.augmentation.speed_change_factors="[0.95, 1.05]" \
lr.learning_rate=8e-5 \
@@ -115,9 +132,9 @@ system.data_workers=124 \
optimizer.gradient_accumulation_steps=4 \
task=next_token_prediction_with_composer \
eval_iters=200 eval_interval=1000 \
"dataset.extra_datasets=['roszcz/maestro-sustain-v2', 'roszcz/giant-midi-sustain-v2', 'roszcz/pianofor-ai-sustain-v2']" \
"dataset.extra_datasets=['epr-labs/maestro-sustain-v2', 'epr-labs/giant-midi-sustain-v2', 'epr-labs/pianofor-ai-sustain-v2']" \
data.batch_size=20 \
data.sequence_length=4096 \
data.context_size=4096 \
logging.wandb_run_name_suffix=huge-pretraining-4096-ctx \
tokenizer=awesome \
logging.wandb_project=piano-awesome-gpt
14 changes: 7 additions & 7 deletions dashboards/piano_dataset_review.py
@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from datasets import Dataset, load_dataset
from midi_tokenizers import ExponentialTimeTokenizer
from piano_dataset.piano_tasks import ParametricTaskManager
from piano_dataset.piano_tasks import PianoTaskManager

from data.piano_dataset import PianoDataset
from artifacts import dataset_tokens, composer_tokens
@@ -58,7 +58,7 @@ def load_piano_dataset(
config: dict,
dataset_name: str,
dataset_split: str,
sequence_length: int,
context_size: int,
notes_per_record: int,
loss_masking: str,
selected_composers: list[str],
@@ -78,13 +78,13 @@ def filter_dataset(record):
return composer_match and title_match

filtered_dataset = dataset.filter(filter_dataset)
parametric_task_manager = ParametricTaskManager.load_default()
parametric_task_manager = PianoTaskManager.load_default()

tokenizer = ExponentialTimeTokenizer(**tokenizer_parameters)
piano_dataset = PianoDataset(
dataset=filtered_dataset,
tokenizer=tokenizer,
sequence_length=sequence_length,
context_size=context_size,
notes_per_record=notes_per_record,
piano_task_manager=parametric_task_manager,
loss_masking=loss_masking,
@@ -110,7 +110,7 @@ def main():
value=256,
)
with col2:
sequence_length = st.number_input(
context_size = st.number_input(
label="Sequence Length",
min_value=1,
value=2048,
@@ -148,7 +148,7 @@

st.form_submit_button(label="Update Tokenizer")

parametric_task_manager = ParametricTaskManager.load_default()
parametric_task_manager = PianoTaskManager.load_default()

config = {
"base_dataset_name": base_dataset_name,
@@ -186,7 +186,7 @@ def main():
config=config,
dataset_name=dataset_name,
dataset_split=dataset_split,
sequence_length=sequence_length,
context_size=context_size,
notes_per_record=notes_per_record,
loss_masking=loss_masking,
selected_composers=selected_composers,
3 changes: 0 additions & 3 deletions data/dataset.py
@@ -1,4 +1,3 @@
from typing import Literal
from abc import abstractmethod

from datasets import Dataset as HuggingFaceDataset
@@ -19,7 +18,6 @@ def __init__(
self,
dataset: HuggingFaceDataset,
tokenizer: ExponentialTimeTokenizer | AwesomeMidiTokenizer,
loss_masking: Literal["finetuning", "pretraining"] = "pretraining",
):
"""
Initialize the MidiDataset.
@@ -32,7 +30,6 @@ def __init__(

# MidiTokenizer which was used during creation of the dataset
self.tokenizer = tokenizer
self.loss_masking = loss_masking

# Dataset with tokenized MIDI data
self.dataset = dataset
93 changes: 93 additions & 0 deletions data/musicality.py
@@ -0,0 +1,93 @@
import re


class MusicManager:
dataset_tokens = [
"<MAESTRO>",
"<PIJAMA>",
"<VGMIDI>",
"<MUSIC-NET>",
"<PIANO-MIDI-DE>",
"<LAKH-LMD-FULL>",
"<GIANT-MIDI>",
"<IMSLP>",
"<ATEPP-1.1>",
"<PIANO_FOR_AI>",
]
composer_tokens = [
"<SCRIABIN>",
"<FRANCK>",
"<MOZART>",
"<CHOPIN>",
"<MENDELSSON>",
"<LISZT>",
"<SCHUBERT>",
"<BRAHMS>",
"<HAYDN>",
"<BEETHOVEN>",
"<BALAKIREV>",
"<SCHUMANN>",
"<RACHMANIOFF>",
"<UNKNOWN_COMPOSER>",
"<BACH>",
]

composer_token_map: dict[str, str] = {
"Alexander Scriabin": "<SCRIABIN>",
"César Franck": "<FRANCK>",
"Wolfgang Amadeus Mozart": "<MOZART>",
"Frédéric Chopin": "<CHOPIN>",
"Felix Mendelssohn": "<MENDELSSON>",
"Franz Liszt": "<LISZT>",
"Franz Schubert": "<SCHUBERT>",
"Johannes Brahms": "<BRAHMS>",
"Joseph Haydn": "<HAYDN>",
"Ludwig van Beethoven": "<BEETHOVEN>",
"Mily Balakirev": "<BALAKIREV>",
"Robert Schumann": "<SCHUMANN>",
"Sergei Rachmaninoff": "<RACHMANIOFF>",
"Johann Sebastian Bach": "<BACH>",
}

def __init__(self):
self.composer_regex_map = self.create_composer_regex_map()

@property
def tokens(self) -> list[str]:
return self.dataset_tokens + self.composer_tokens

def create_composer_regex_map(self) -> dict[re.Pattern, str]:
regex_map: dict[re.Pattern, str] = {}
for full_name, token in self.composer_token_map.items():
names = full_name.split()
surname = names[-1]
pattern = re.compile(rf"\b{re.escape(surname)}\b", re.IGNORECASE)
regex_map[pattern] = token
return regex_map

def get_dataset_token(self, piece_source: dict) -> str:
dataset_name = piece_source.get("dataset")

for dataset_token in self.dataset_tokens:
dataset_token_name = dataset_token[1:-1]
if dataset_token_name.lower() == dataset_name.lower():
return dataset_token

# FIXME Our internal dataset is the only one without the name
# stored as part of the source. This should change with the next
# dataset version, then we can add <UNKNOWN_DATASET> here
return "<PIANO_FOR_AI>"

def get_composer_token(self, composer: str) -> str:
# TODO This should be more refined - we know that composer
# information is stored in many ways across different datasets
# and we should use that knowledge:
# def get_composer_token(dataset_name: str, piece_source: dict): ...
matches: list[tuple[re.Match, str]] = [
(match, token) for pattern, token in self.composer_regex_map.items() if (match := pattern.search(composer))
]

if len(matches) == 1:
return matches[0][1]

return "<UNKNOWN_COMPOSER>"
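The added `MusicManager` maps free-form composer strings to tokens by matching each composer's surname as a case-insensitive, word-bounded regex, and falls back to `<UNKNOWN_COMPOSER>` unless exactly one surname matches. A minimal standalone sketch of that matching logic (reimplemented here for illustration with a trimmed token map, not imported from the PR):

```python
import re

# Trimmed composer map for illustration; the PR's MusicManager carries the full set
composer_token_map = {
    "Frédéric Chopin": "<CHOPIN>",
    "Johann Sebastian Bach": "<BACH>",
    "Franz Liszt": "<LISZT>",
}

# One case-insensitive, word-boundary pattern per surname,
# mirroring create_composer_regex_map in the diff above
regex_map = {
    re.compile(rf"\b{re.escape(name.split()[-1])}\b", re.IGNORECASE): token
    for name, token in composer_token_map.items()
}

def get_composer_token(composer: str) -> str:
    # Emit a token only when exactly one surname matches; ambiguous or
    # unrecognized strings fall back to the unknown-composer token
    matches = [token for pattern, token in regex_map.items() if pattern.search(composer)]
    if len(matches) == 1:
        return matches[0]
    return "<UNKNOWN_COMPOSER>"

print(get_composer_token("F. Chopin"))         # <CHOPIN>
print(get_composer_token("chopin, frédéric"))  # <CHOPIN>
print(get_composer_token("Various artists"))   # <UNKNOWN_COMPOSER>
```

The word-boundary anchor avoids substring false positives, and the exactly-one-match rule keeps strings naming two known composers (e.g. a duet arrangement) from being tagged arbitrarily.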