diff --git a/dashboards/prompt_practice.py b/dashboards/prompt_practice.py
index eb8a810..42f26fa 100644
--- a/dashboards/prompt_practice.py
+++ b/dashboards/prompt_practice.py
@@ -185,13 +185,24 @@ def main():
     if not use_top_k:
         top_k = None
 
-    dataset_tokens = MusicManager().dataset_tokens
+    music_manager = MusicManager(
+        max_n_notes=run_config.training.max_notes_per_record,
+    )
+    dataset_tokens = music_manager.dataset_tokens
     dataset_token = st.selectbox(
         label="Select a dataset token:",
         options=dataset_tokens,
         help="Choose from available special tokens to add to your prompt",
     )
+    n_target_notes = st.number_input(
+        label="N target notes",
+        min_value=0,
+        max_value=music_manager.max_n_notes,
+        value=music_manager.max_n_notes // 3,
+    )
+    n_notes_token = music_manager.get_n_notes_token(n_target_notes)
+
     piano_task_manager: PianoTaskManager = model_setup["piano_task_manager"]
 
     # Using prompt path as a key to force st to restart this form whenever
@@ -227,7 +238,7 @@ def main():
     composer_tokens = ["", "", "", ""]
 
     for composer_token in composer_tokens:
-        pre_input_tokens = [dataset_token, composer_token] + piano_task.prefix_tokens
+        pre_input_tokens = [dataset_token, composer_token, n_notes_token] + piano_task.prefix_tokens
 
         st.write("Pre-input tokens:", pre_input_tokens)
@@ -239,8 +250,7 @@ def main():
         generation_setup = {
             "seed": local_seed,
             "temperature": temperature,
-            "dataset_token": dataset_token,
-            "composer_token": composer_token,
+            "pre_input_tokens": pre_input_tokens,
             "piano_task": piano_task.name,
             "top_k": top_k,
             "model_id": os.path.basename(checkpoint_path),
@@ -271,6 +281,7 @@ def main():
         generated_piece = ff.MidiPiece(generated_notes_df)
 
         streamlit_pianoroll.from_fortepyan(prompt_piece, generated_piece)
+        st.write("Generated notes:", generated_piece.size)
 
         unique_id = secrets.token_hex(10)
         if pianoroll_apikey:
diff --git a/gpt2/data/musicality.py b/gpt2/data/musicality.py
index 138861c..b5c94f4 100644
--- a/gpt2/data/musicality.py
+++ b/gpt2/data/musicality.py
@@ -49,12 +49,19 @@ class MusicManager:
         "Johann Sebastian Bach": "",
     }
 
-    def __init__(self):
+    def __init__(self, max_n_notes: int):
+        self.max_n_notes = max_n_notes
         self.composer_regex_map = self.create_composer_regex_map()
 
+    @property
+    def n_note_tokens(self) -> list[str]:
+        # NOTE: 0 is a valid number of notes
+        tokens = [self.get_n_notes_token(n_notes) for n_notes in range(self.max_n_notes + 1)]
+        return tokens
+
     @property
     def tokens(self) -> list[str]:
-        return self.dataset_tokens + self.composer_tokens
+        return self.dataset_tokens + self.composer_tokens + self.n_note_tokens
 
     def create_composer_regex_map(self) -> dict[re.Pattern, str]:
         regex_map: dict[re.Pattern, str] = {}
@@ -91,3 +98,6 @@ def get_composer_token(self, composer: str) -> str:
             return matches[0][1]
 
         return ""
+
+    def get_n_notes_token(self, n_notes: int) -> str:
+        return f"<n_notes_{n_notes}>"
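A quick sketch of what the new length-conditioning vocabulary looks like after this change. Both the `max_n_notes=128` value and the `<n_notes_K>` string format are illustrative assumptions (in training, the value comes from `cfg.training.max_notes_per_record`):

```python
# Minimal sketch; max_n_notes=128 is an illustrative value and the
# <n_notes_K> string format is an assumption, not a confirmed spec.
from gpt2.data.musicality import MusicManager

music_manager = MusicManager(max_n_notes=128)

# One special token per admissible target length, 0..max_n_notes inclusive,
# so MusicManager.tokens (and any tokenizer vocabulary built from it)
# grows by max_n_notes + 1 entries.
assert len(music_manager.n_note_tokens) == 129
assert music_manager.get_n_notes_token(64) == "<n_notes_64>"
```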
diff --git a/gpt2/data/piano_dataset.py b/gpt2/data/piano_dataset.py
index c91b8c8..eb098f4 100644
--- a/gpt2/data/piano_dataset.py
+++ b/gpt2/data/piano_dataset.py
@@ -71,13 +71,13 @@ def _build_records(self):
         record_lengths = np.array(self.dataset["n_notes"]) - self.max_notes_per_record + 1
 
         # For every record we can have that many different subsequences with different lengths
-        self.n_duration_options = self.max_notes_per_record - self.min_notes_per_record
+        self.n_record_durations = self.max_notes_per_record - self.min_notes_per_record
 
         # Records shorter than context are effectively discarded
         self.record_lengths = record_lengths.clip(min=0)
 
         # Calculate total dataset length
-        self.length = self.record_lengths.sum() * self.num_tasks * self.n_duration_options
+        self.length = self.record_lengths.sum() * self.num_tasks * self.n_record_durations
 
     def __len__(self):
         # Return the total length of the dataset
@@ -93,10 +93,10 @@ def _decode_piano_index(self, idx: int) -> PianoIndex:
         idx_bis = idx // self.num_tasks
 
         # ... and decode the number of notes for this record
-        n_notes = self.min_notes_per_record + (idx_bis % self.n_duration_options)
+        n_notes = self.min_notes_per_record + (idx_bis % self.n_record_durations)
 
         # ... and then decode the starting note idx
-        start_point = idx_bis // self.n_duration_options
+        start_point = idx_bis // self.n_record_durations
 
         for record_idx, record_length in enumerate(self.record_lengths):
             if start_point < record_length:
@@ -144,7 +144,11 @@ def __getitem__(self, idx: int) -> dict:
         dataset_token = self.music_manager.get_dataset_token(
             piece_source=piece_source,
         )
-        source_prefix_tokens = [dataset_token, composer_token] + piano_task.prefix_tokens
+        n_notes_token = self.music_manager.get_n_notes_token(
+            n_notes=piece_split.n_target_notes,
+        )
+        source_prefix_tokens = [dataset_token, composer_token, n_notes_token]
+        source_prefix_tokens += piano_task.prefix_tokens
         prefix_token_ids = self.tokenizer.encode_tokens(source_prefix_tokens)
 
         # ... and join into a single prompt sequence of token ids
@@ -207,6 +211,8 @@ def __getitem__(self, idx: int) -> dict:
             "task": piano_index.task_name,
             "target_mask": target_mask,
             "n_notes": piano_index.n_notes,
+            "n_source_notes": piece_split.n_source_notes,
+            "n_target_notes": piece_split.n_target_notes,
             "start_point": piano_index.start_point,
             "piece_source": json.dumps(piece_source),
             "source_token_ids": source_token_ids,
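The rename from `n_duration_options` to `n_record_durations` leaves the flat-index arithmetic unchanged. Here is a self-contained sketch of that decoding with illustrative values; the running subtraction across `record_lengths` is assumed from the loop the hunk truncates:

```python
# Self-contained sketch of the _decode_piano_index arithmetic.
# All values are illustrative; the subtraction across records is
# assumed from the loop logic the hunk cuts off.
num_tasks = 4
min_notes_per_record = 32
max_notes_per_record = 128
n_record_durations = max_notes_per_record - min_notes_per_record  # 96
record_lengths = [10, 0, 25]  # valid start offsets per record, clipped at 0

def decode(idx: int) -> tuple[int, int, int, int]:
    task_idx = idx % num_tasks  # task index varies fastest
    idx_bis = idx // num_tasks
    n_notes = min_notes_per_record + (idx_bis % n_record_durations)
    start_point = idx_bis // n_record_durations
    for record_idx, record_length in enumerate(record_lengths):
        if start_point < record_length:
            return task_idx, n_notes, record_idx, start_point
        start_point -= record_length
    raise IndexError(idx)

# Every (task, target length, start offset) triple maps to one flat index:
# sum(record_lengths) * num_tasks * n_record_durations == 35 * 4 * 96 == 13440
assert decode(0) == (0, 32, 0, 0)
```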
diff --git a/gpt2/setup/datasets.py b/gpt2/setup/datasets.py
index 4a636ad..eebc405 100644
--- a/gpt2/setup/datasets.py
+++ b/gpt2/setup/datasets.py
@@ -84,7 +84,9 @@ def next_token_prediction_setup(
         )
     )
 
-    music_manager = MusicManager()
+    music_manager = MusicManager(
+        max_n_notes=cfg.training.max_notes_per_record,
+    )
     if not tokenizer:
         tokenizer = load_tokenizer(
             cfg=cfg,
@@ -102,7 +104,7 @@ def next_token_prediction_setup(
     )
 
     train_dataset = datasets["train_split"]
-    print("Train dataset samples [M]:", len(train_dataset) / 1e6)
+    print("Train dataset samples [G]:", len(train_dataset) / 1e9)
     val_datasets = datasets["validation_splits"]
 
     train_loader, val_loaders = loaders_setup(
@@ -132,7 +134,9 @@ def piano_task_setup(
 ) -> DatasetsSetup:
     hf_dataset = create_augmented_dataset(cfg)
 
-    music_manager = MusicManager()
+    music_manager = MusicManager(
+        max_n_notes=cfg.training.max_notes_per_record,
+    )
     if not tokenizer:
         tokenizer = load_tokenizer(
             cfg=cfg,
@@ -155,7 +159,7 @@ def piano_task_setup(
     )
 
     train_dataset = datasets["train_split"]
-    print("Train dataset samples [M]:", len(train_dataset) / 1e6)
+    print("Train dataset samples [G]:", len(train_dataset) / 1e9)
     val_datasets = datasets["validation_splits"]
 
     train_loader, val_loaders = loaders_setup(
diff --git a/gpt2/train.py b/gpt2/train.py
index 4fbb640..b9a7c3b 100644
--- a/gpt2/train.py
+++ b/gpt2/train.py
@@ -63,7 +63,6 @@ def model_tuning(tune_cfg: DictConfig):
     model.to(device_setup.device)
 
     if run_cfg.system.compile:
-        print("compiling the model... (takes a ~minute)")
         model = torch.compile(model)
 
     backprop_setup = setup_backprop(
@@ -135,7 +134,6 @@ def resume_training(resume_cfg: DictConfig):
 
     # TODO Not sure if this is a "system" setting
     if run_cfg.system.compile:
-        print("compiling the model... (takes a ~minute)")
         model = torch.compile(model)
 
     backprop_setup = setup_backprop(
@@ -235,7 +233,6 @@ def training_from_scratch(cfg: DictConfig):
     )
 
     if cfg.system.compile:
-        print("compiling the model... (takes a ~minute)")
         model = torch.compile(model)
 
     milion_params = model.get_num_params() / 1e6
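The [M] to [G] switch in the sample-count prints is a direct consequence of the extra `n_record_durations` factor in the dataset length. A back-of-the-envelope check with assumed corpus numbers:

```python
# Back-of-the-envelope only; both inputs are assumed, not measured.
total_start_points = 2_000_000  # record_lengths.sum() across the corpus
num_tasks = 4
n_record_durations = 96         # max_notes_per_record - min_notes_per_record

n_samples = total_start_points * num_tasks * n_record_durations
print(f"Train dataset samples [G]: {n_samples / 1e9:.3f}")  # 0.768
```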