MIDI-166: Make the number of notes to be generated explicit in the prompt #13

Merged · 2 commits · Mar 23, 2025
19 changes: 15 additions & 4 deletions dashboards/prompt_practice.py
@@ -185,13 +185,24 @@ def main():
if not use_top_k:
top_k = None

dataset_tokens = MusicManager().dataset_tokens
music_manager = MusicManager(
max_n_notes=run_config.training.max_notes_per_record,
)
dataset_tokens = music_manager.dataset_tokens
dataset_token = st.selectbox(
label="Select a dataset token:",
options=dataset_tokens,
help="Choose from available special tokens to add to your prompt",
)

n_target_notes = st.number_input(
label="N target notes",
min_value=0,
max_value=music_manager.max_n_notes,
value=music_manager.max_n_notes // 3,
)
n_notes_token = music_manager.get_n_notes_token(n_target_notes)

piano_task_manager: PianoTaskManager = model_setup["piano_task_manager"]

# Using prompt path as a key to force st to restart this form whenever
@@ -227,7 +238,7 @@ def main():

composer_tokens = ["<BACH>", "<MOZART>", "<CHOPIN>", "<UNKNOWN_COMPOSER>"]
for composer_token in composer_tokens:
pre_input_tokens = [dataset_token, composer_token] + piano_task.prefix_tokens
pre_input_tokens = [dataset_token, composer_token, n_notes_token] + piano_task.prefix_tokens

st.write("Pre-input tokens:", pre_input_tokens)

@@ -239,8 +250,7 @@
generation_setup = {
"seed": local_seed,
"temperature": temperature,
"dataset_token": dataset_token,
"composer_token": composer_token,
"pre_input_tokens": pre_input_tokens,
"piano_task": piano_task.name,
"top_k": top_k,
"model_id": os.path.basename(checkpoint_path),
@@ -271,6 +281,7 @@ def main():
generated_piece = ff.MidiPiece(generated_notes_df)

streamlit_pianoroll.from_fortepyan(prompt_piece, generated_piece)
st.write("Generated notes:", generated_piece.size)

unique_id = secrets.token_hex(10)
if pianoroll_apikey:
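For context, a minimal sketch of how the dashboard now assembles its prompt prefix with the note-count token, assuming the `MusicManager` interface added in this PR; the dataset token value and `max_n_notes` below are placeholders, and the task-specific prefix tokens are omitted:

```python
# Sketch only: not the dashboard code itself.
from gpt2.data.musicality import MusicManager

# In the dashboard, max_n_notes comes from run_config.training.max_notes_per_record;
# 256 here is a placeholder value.
music_manager = MusicManager(max_n_notes=256)

n_target_notes = 85  # value picked in the st.number_input widget
n_notes_token = music_manager.get_n_notes_token(n_target_notes)  # "<N_NOTES_85>"

# Placeholder dataset token; real values come from music_manager.dataset_tokens.
dataset_token = "<MAESTRO>"
composer_token = "<CHOPIN>"

# The prompt prefix now states the requested note count explicitly;
# piano_task.prefix_tokens would be appended after these.
pre_input_tokens = [dataset_token, composer_token, n_notes_token]
print(pre_input_tokens)  # ['<MAESTRO>', '<CHOPIN>', '<N_NOTES_85>']
```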
14 changes: 12 additions & 2 deletions gpt2/data/musicality.py
@@ -49,12 +49,19 @@ class MusicManager:
"Johann Sebastian Bach": "<BACH>",
}

def __init__(self):
def __init__(self, max_n_notes: int):
self.max_n_notes = max_n_notes
self.composer_regex_map = self.create_composer_regex_map()

@property
def n_note_tokens(self) -> list[str]:
# NOTE: 0 is a valid number of notes
tokens = [self.get_n_notes_token(n_notes) for n_notes in range(self.max_n_notes + 1)]
return tokens

@property
def tokens(self) -> list[str]:
return self.dataset_tokens + self.composer_tokens
return self.dataset_tokens + self.composer_tokens + self.n_note_tokens

def create_composer_regex_map(self) -> dict[re.Pattern, str]:
regex_map: dict[re.Pattern, str] = {}
@@ -91,3 +98,6 @@ def get_composer_token(self, composer: str) -> str:
return matches[0][1]

return "<UNKNOWN_COMPOSER>"

def get_n_notes_token(self, n_notes: int) -> str:
return f"<N_NOTES_{n_notes}>"
16 changes: 11 additions & 5 deletions gpt2/data/piano_dataset.py
@@ -71,13 +71,13 @@ def _build_records(self):
record_lengths = np.array(self.dataset["n_notes"]) - self.max_notes_per_record + 1

# For every record we can have that many different subsequences with different lengths
self.n_duration_options = self.max_notes_per_record - self.min_notes_per_record
self.n_record_durations = self.max_notes_per_record - self.min_notes_per_record

# Records shorter than context are effectively discarded
self.record_lengths = record_lengths.clip(min=0)

# Calculate total dataset length
self.length = self.record_lengths.sum() * self.num_tasks * self.n_duration_options
self.length = self.record_lengths.sum() * self.num_tasks * self.n_record_durations

def __len__(self):
# Return the total length of the dataset
@@ -93,10 +93,10 @@ def _decode_piano_index(self, idx: int) -> PianoIndex:
idx_bis = idx // self.num_tasks

# ... and decode the number of notes for this record
n_notes = self.min_notes_per_record + (idx_bis % self.n_duration_options)
n_notes = self.min_notes_per_record + (idx_bis % self.n_record_durations)

# ... and then decode the starting note idx
start_point = idx_bis // self.n_duration_options
start_point = idx_bis // self.n_record_durations

for record_idx, record_length in enumerate(self.record_lengths):
if start_point < record_length:
@@ -144,7 +144,11 @@ def __getitem__(self, idx: int) -> dict:
dataset_token = self.music_manager.get_dataset_token(
piece_source=piece_source,
)
source_prefix_tokens = [dataset_token, composer_token] + piano_task.prefix_tokens
n_notes_token = self.music_manager.get_n_notes_token(
n_notes=piece_split.n_target_notes,
)
source_prefix_tokens = [dataset_token, composer_token, n_notes_token]
source_prefix_tokens += piano_task.prefix_tokens
prefix_token_ids = self.tokenizer.encode_tokens(source_prefix_tokens)

# ... and join into a single promp sequence of token ids
@@ -207,6 +211,8 @@ def __getitem__(self, idx: int) -> dict:
"task": piano_index.task_name,
"target_mask": target_mask,
"n_notes": piano_index.n_notes,
"n_source_notes": piece_split.n_source_notes,
"n_target_notes": piece_split.n_target_notes,
"start_point": piano_index.start_point,
"piece_source": json.dumps(piece_source),
"source_token_ids": source_token_ids,
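The rename from `n_duration_options` to `n_record_durations` keeps the same mixed-radix decoding of a flat dataset index into a task, a target length, and a start position. A standalone sketch of that arithmetic, with hypothetical values for the record-length bounds:

```python
# Hypothetical bounds; the real ones come from the dataset config.
num_tasks = 4
min_notes_per_record = 64
max_notes_per_record = 256
n_record_durations = max_notes_per_record - min_notes_per_record  # 192

def decode_piano_index(idx: int) -> tuple[int, int, int]:
    # The task index cycles fastest ...
    task_idx = idx % num_tasks
    idx_bis = idx // num_tasks
    # ... then the number of notes drawn for this sample ...
    n_notes = min_notes_per_record + (idx_bis % n_record_durations)
    # ... and finally the starting note within the record pool.
    start_point = idx_bis // n_record_durations
    return task_idx, n_notes, start_point

print(decode_piano_index(0))        # (0, 64, 0)
print(decode_piano_index(4 * 192))  # (0, 64, 1)
```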
12 changes: 8 additions & 4 deletions gpt2/setup/datasets.py
@@ -84,7 +84,9 @@ def next_token_prediction_setup(
)
)

music_manager = MusicManager()
music_manager = MusicManager(
max_n_notes=cfg.training.max_notes_per_record,
)
if not tokenizer:
tokenizer = load_tokenizer(
cfg=cfg,
@@ -102,7 +104,7 @@
)

train_dataset = datasets["train_split"]
print("Train dataset samples [M]:", len(train_dataset) / 1e6)
print("Train dataset samples [G]:", len(train_dataset) / 1e9)
val_datasets = datasets["validation_splits"]

train_loader, val_loaders = loaders_setup(
@@ -132,7 +134,9 @@ def piano_task_setup(
) -> DatasetsSetup:
hf_dataset = create_augmented_dataset(cfg)

music_manager = MusicManager()
music_manager = MusicManager(
max_n_notes=cfg.training.max_notes_per_record,
)
if not tokenizer:
tokenizer = load_tokenizer(
cfg=cfg,
@@ -155,7 +159,7 @@
)

train_dataset = datasets["train_split"]
print("Train dataset samples [M]:", len(train_dataset) / 1e6)
print("Train dataset samples [G]:", len(train_dataset) / 1e9)
val_datasets = datasets["validation_splits"]

train_loader, val_loaders = loaders_setup(
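Both setup paths now construct `MusicManager` from the same config field that bounds record lengths, so every target length the dataset can produce has a matching note-count token. A small sketch of that wiring, assuming an OmegaConf config with a `training.max_notes_per_record` field:

```python
from omegaconf import OmegaConf

from gpt2.data.musicality import MusicManager

# Hypothetical run config; only the field used here is shown.
cfg = OmegaConf.create({"training": {"max_notes_per_record": 256}})

music_manager = MusicManager(max_n_notes=cfg.training.max_notes_per_record)

# The note-count vocabulary covers 0..max_notes_per_record inclusive.
assert len(music_manager.n_note_tokens) == cfg.training.max_notes_per_record + 1
```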
3 changes: 0 additions & 3 deletions gpt2/train.py
@@ -63,7 +63,6 @@ def model_tuning(tune_cfg: DictConfig):
model.to(device_setup.device)

if run_cfg.system.compile:
print("compiling the model... (takes a ~minute)")
model = torch.compile(model)

backprop_setup = setup_backprop(
@@ -135,7 +134,6 @@ def resume_training(resume_cfg: DictConfig):

# TODO Not sure if this is a "system" setting
if run_cfg.system.compile:
print("compiling the model... (takes a ~minute)")
model = torch.compile(model)

backprop_setup = setup_backprop(
@@ -235,7 +233,6 @@ def training_from_scratch(cfg: DictConfig):
)

if cfg.system.compile:
print("compiling the model... (takes a ~minute)")
model = torch.compile(model)

milion_params = model.get_num_params() / 1e6