John Q-Train

December 11th, 2023 4:45 PM

Abstract

We created John Q-Train, an AI jazz musician that can improvise a melody in real time while a backup plays chords. John Q-Train formulates improvisation as a Markov decision process (MDP) whose Q-function is an RNN Q-Network, trained by iterating through 30 jazz MIDI files and updating its parameters to minimize loss. We defined our loss function as the mean squared error between the target Q-value and the predicted Q-value. We were able to successfully decrease the loss through training and found that the quality of our model’s improvisation improved significantly over the course of the project.

Problem

Jazz improvisation embodies decision-making under uncertainty: musicians must listen to each other and decide in real time which notes to play, and at what tempo and volume, without a written composition to fall back on.

Prior Work

There have been multiple prior attempts at an AI jazz musician, but they have yet to achieve a subjective quality that consistently matches human performance. One example is “The Jazz Transformer on the Front Line”, which used a Transformer to model jazz lead sheets. Another example is “On the Adaptability of Recurrent Neural Networks for Real-Time Jazz Improvisation Accompaniment”, which implemented a jazz accompanist using a Recurrent Neural Network.

Approach

Data processing

First, we selected 30 jazz MIDI songs that were structurally simple and extracted the lead and backup tracks from them. Our training data came from the Kaggle dataset “Jazz ML ready MIDI”.

Next, we created our own data structures (Song and MusicInterval) so that the music could be operated on by our model. A Song contains a song’s title, key, and notes. The notes are an array of MusicInterval objects, each representing one 8th-note timestep. Each MusicInterval contains the note and velocity played by the lead and the notes and velocities played by the backup.
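
A minimal dataclass-style sketch of these structures (condensed; the full classes, with their helper methods, appear in the code listing at the end of this report):

from dataclasses import dataclass, field
from typing import List
import torch

@dataclass
class MusicInterval:
    # One 8th-note timestep.
    action: int = 128              # lead note: 0-127 = MIDI pitch, 128 = rest, 129 = hold previous note
    action_velocity: float = 0.0   # lead velocity, normalized to [0, 1]
    backup: torch.Tensor = field(default_factory=lambda: torch.zeros(128))  # velocity for each backup pitch

@dataclass
class Song:
    title: str
    key: torch.Tensor              # per-pitch weight for how likely that pitch is to be in the song's key
    notes: List[MusicInterval] = field(default_factory=list)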

Finally, we populated these data structures. To represent each song’s key, we counted the pitch classes in the song, scored each of the 12 major keys by how many of those pitches fall in its scale, and converted the scores into a probability distribution over the possible keys. Then, for each song, we mapped every MIDI note to the MusicInterval(s) it overlaps and filled in the corresponding note and velocity.
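
A condensed sketch of the key-estimation idea (the full version is the getSongKeyVector method in the DataLoader at the end, which uses a temperature of 0.1 and additionally converts the key distribution into a per-pitch membership vector):

import numpy as np

def key_probabilities(midi_pitches, temperature=0.1):
    # Count how often each of the 12 pitch classes occurs in the song.
    counts = np.zeros(12)
    for p in midi_pitches:
        counts[p % 12] += 1

    # Score each major key by how many played pitches lie in its scale
    # (C major = pitch classes {0, 2, 4, 5, 7, 9, 11}; other keys are rotations).
    major_scale = [0, 2, 4, 5, 7, 9, 11]
    scores = np.array([sum(counts[(root + step) % 12] for step in major_scale)
                       for root in range(12)])

    # Min-max normalize, then softmax into a probability distribution over the 12 keys.
    span = scores.max() - scores.min()
    scores = (scores - scores.min()) / span if span > 0 else np.zeros_like(scores)
    exp = np.exp((scores - scores.max()) / temperature)
    return exp / exp.sum()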

In the end, we had 30 populated Song objects that were ready to be processed by our model.

Model design

For this problem, we designed a neural network that takes the current state as input and outputs a Q-value for each possible action.

To simplify the problem, we restricted our model to the role of the lead, meaning it plays at most one note at each timestep. We also simplified the environment to a single backup accompaniment playing chord voicings.

MDP

  • The state space is all of the notes previously played by the lead and the backup.
  • The action space is the set of integers [0, 129], where each action represents what the lead plays at the current timestep (see the sketch after this list):
  • 0-127 are MIDI notes, 128 means play no note, and 129 means continue the previous note.
  • The reward function evaluates the quality of an action based on a combination of velocity, playing in key, playing in chord, and staying close to the previous note.
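
A minimal sketch of the action encoding and of how a single timestep is turned into a state vector (constant and function names here are illustrative; the full version is MusicInterval.getStateVector in the listing at the end, and the history of previous timesteps is carried by the recurrent model rather than by the vector itself):

import torch

NUM_NOTES = 128      # actions 0-127: play that MIDI pitch
REST = 128           # action 128: play no note this timestep
HOLD = 129           # action 129: continue the previous note
NUM_ACTIONS = 130

def timestep_state(action: int, velocity: float, backup: torch.Tensor) -> torch.Tensor:
    # The lead's action as a one-hot vector scaled by its velocity,
    # concatenated with the per-pitch velocities played by the backup.
    lead = torch.zeros(NUM_ACTIONS)
    lead[action] = velocity
    return torch.cat((lead, backup))   # shape (130 + 128,)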

Q-Function Model

The model starts with an input preprocessing layer that takes the dot product of the state’s notes with a learnable vector of shape (12), tiled along the length of the input and shifted by i places for each i in [0, 11]. We hoped this learnable vector would capture the key at the current timestep.
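
Condensed from the pre_process_state method in the listing at the end, the shifted dot products look roughly like this (a sketch with sizes hard-coded for readability; the model applies it separately to the lead and backup portions of the state):

import torch
import torch.nn.functional as F

def key_features(notes: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # notes: activations for the 128 MIDI pitches; scale: learnable 12-dim vector.
    tiled = scale.repeat(12)[:128 + 11]       # tile the scale and trim so all 12 shifts exist
    shifts = tiled.unfold(0, 128, 1)          # (12, 128): row i is the scale shifted by i places
    return F.softmax(shifts @ notes, dim=-1)  # soft score for each of the 12 possible shifts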

Because of the sequential, temporal nature of music, we used an LSTM to process the state. A final linear layer maps the LSTM’s hidden state to a vector the size of the action space.

Model training

To train the model, we took inspiration from the Q-learning paradigm, with the twist that the Q-function is a neural network. The loss we minimized is the mean squared error of the Q-learning objective:

loss = ( r_t + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) )^2

At each timestep, we fed one MusicInterval into the model to compute the Q-vector over all possible actions. We then took only the Q-value of the action that was actually played in the training data and computed the loss between that value and the reward plus the discounted maximum of the model’s output at the next timestep. We did not use Monte Carlo sampling; instead we accumulated the mean squared error over each entire song before taking a gradient step.
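
Schematically, a single timestep of this update looks like the following (variable names match the train function in the listing at the end; the real loop sums these losses over the whole song, divides by its length, and only then calls backward):

# q_values:      the model's output after the previous MusicInterval (the current state)
# next_q_values: the model's output after feeding in this MusicInterval (the next state)
# next_action:   the action actually taken at this timestep in the training data
# next_reward:   the reward that action earned
q_value = q_values[next_action]                                  # predicted Q(s, a)
actual_q = next_reward + K.discount * torch.max(next_q_values)  # r + gamma * max_a' Q(s', a')
loss += F.mse_loss(q_value, actual_q)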

The reward function used during training is a mixture of four components, combined as sketched after this list:

  • A velocity reward, which assumes that louder notes in the training data were better notes
  • A closeness reward, which incentivizes consecutive notes that are close together
  • A chord reward, which incentivizes lead notes that are present in the backup’s harmony
  • A scale reward, which incentivizes lead notes that are part of the song’s key.
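
Schematically, for a timestep where the lead plays a pitch (the coefficient names match the getReward method in the listing at the end; the helper name here is ours for illustration; the real getReward zeroes the closeness term when the note repeats and returns a fixed K.reward_no_note for rests):

def note_reward(velocity, in_key_weight, in_chord_weight, curr_note, prev_note):
    # velocity: normalized lead velocity; in_key_weight / in_chord_weight: how strongly
    # the pitch appears in the song's key vector and in the backup's recent notes.
    return (K.reward_velocity_coef  * velocity
          + K.reward_scale_coef     * in_key_weight
          + K.reward_chord_coef     * in_chord_weight
          + K.reward_closeness_coef / (abs(curr_note - prev_note) + 1))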

Qualitative Evaluation

We don’t evaluate music by looking at the score; we have to hear it played.

To generate a lead melody from the model, we fed in each MusicInterval from the test set and queried the model for the Q-values of taking each action from the current state. To sample a note, we used a temperature-scaled softmax (the probability of taking action a is proportional to \exp(\lambda Q(s, a))), so the melodies are different each time while still prioritizing the notes the model thinks will sound good.
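
The predict_action helper in the listing samples from a plain softmax over the Q-values; a temperature-scaled variant of the same idea might look like this (lam is the inverse-temperature hyperparameter, assumed here, with higher values making the distribution spikier):

import torch
import torch.nn.functional as F

def sample_action(q_values: torch.Tensor, lam: float = 1.0) -> int:
    # p(a | s) is proportional to exp(lam * Q(s, a)).
    probs = F.softmax(lam * q_values, dim=-1)
    return torch.multinomial(probs, 1).item()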

We were able to improve the output significantly just by tweaking hyperparameters and retraining. Here are the adjustments that produced the final result:

For training, we decreased the discount factor and increased the rewards for closeness and for playing in key.

For evaluation, we decreased the softmax temperature (making the distribution spikier) and increased the scale-lock.

Discussion

Ultimately, John Q-Train successfully learned melodic jazz improvisation through Q-learning. Here are a few key decisions that were most influential on its resulting performance.

  • Discretizing to 8th notes. First, we discretized all notes to 8th notes so that the state space stayed a tractable size. However, restricting the model to play in 8th notes took away some of the human feel that naturally exists in jazz.
  • Model design. Second, our decision to use an RNN for the Q-function was key to letting it capture the temporal relationships in jazz music. However, modeling jazz as an MDP inherently assumes there is an objective truth about what makes good jazz, which is far from the case. We also found it quite difficult to balance exploration and exploitation: the outputs were usually either terrible sounding or utterly boring.
  • Reward design and hyperparameter tuning. Finally, our iterative approach to designing the reward function, and to tuning that function and our sampling procedure with hyperparameters, ultimately gave us much more success than relying entirely on the model to learn the underlying structure of jazz. This human-in-the-loop approach gave us an improviser that, while far from a professional, was surprisingly competent given the small model and dataset.

Here is some selected code that highlights the primary functions of John Q-Train:

from __future__ import annotations  # lets these selected functions reference classes defined later in the file

import math
import random
from typing import List

import numpy as np
import pandas as pd
import pretty_midi
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# K is the project's constants module (note/action counts, reward coefficients,
# discount factor, lookback length, etc.), defined elsewhere in the repository.


class Q_RNN_model(nn.Module):
    def __init__(self, input_size, num_possible_actions, hidden_size=128, num_layers=2, batch_size=1):
        super(Q_RNN_model, self).__init__()

        self.input_size = input_size
        self.state_size = input_size + 12 + 12
        self.num_possible_actions = num_possible_actions
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size

        self.hidden_state = torch.zeros(self.num_layers, self.hidden_size)
        self.cell_state = torch.zeros(self.num_layers, self.hidden_size)

        # Learnable 12-dim scale vector used by the key-detection preprocessing step.
        self.ionian = nn.Parameter(torch.tensor(K.init_ionian).float())
        self.len_scale = len(K.init_ionian)

        self.lstm = nn.LSTM(self.state_size, hidden_size, batch_first=True, num_layers=num_layers)

        self.linear = nn.Linear(hidden_size, num_possible_actions)

    def forward(self, curr_state):
        curr_state = self.pre_process_state(curr_state).unsqueeze(0)
        # The hidden and cell states persist across the timesteps of a song and are
        # reset between songs via clear_hidden_state().
        lstm_out, (self.hidden_state, self.cell_state) = self.lstm(curr_state, (self.hidden_state, self.cell_state))
        last_hidden_state = lstm_out[-1, :]
        output = self.linear(last_hidden_state)
        return output

    def pre_process_state(self, curr_state):
        # Build a (12, num_possible_notes) matrix whose rows are the learnable scale
        # vector shifted by 0..11 places, then score the state against every shift
        # to get a soft estimate of the current key.
        scale = self.ionian.repeat(2 + int(K.num_possible_notes / self.len_scale))
        scale = scale[:K.num_possible_notes + 11]
        scale_matrix = scale.unfold(0, K.num_possible_notes, 1)
        leadscale = scale_matrix @ curr_state[K.num_possible_notes:2 * K.num_possible_notes]  # [c, c#, d, d#, e, ...]
        backupscale = scale_matrix @ curr_state[:K.num_possible_notes]  # [c, c#, d, d#, e, ...]
        return torch.cat((
            curr_state,
            F.softmax(leadscale, dim=-1),
            F.softmax(backupscale, dim=-1)
        ))

    def clear_hidden_state(self):
        self.hidden_state = torch.zeros(self.num_layers, self.hidden_size)
        self.cell_state = torch.zeros(self.num_layers, self.hidden_size)


def predict_action(net: Q_RNN_model, interval: MusicInterval):
    # Sample an action with probability proportional to the softmax of its Q-value.
    q_values = net(interval.getStateVector())
    q_probs = F.softmax(q_values, dim=-1)
    sampled_action = torch.multinomial(q_probs, 1).item()
    prob = q_probs[sampled_action]
    return sampled_action, prob


def train(epoch, net: Q_RNN_model, device, optimizer, songs: List[Song]):
    print('\nEpoch: %d' % epoch)
    net.train()
    display_loss = 0
    n_loss = 0
    with tqdm(total=len(songs)) as progress_bar:
        order = [i for i in range(len(songs))]
        random.shuffle(order)
        for i in order:
            song = songs[i]
            loss = 0.0
            song_len = song.len_song
            optimizer.zero_grad()
            x = song.getInterval(0).getStateVector().to(device)
            q_values = net(x)
            for t in range(1, song_len):
                x = song.getInterval(t).getStateVector().to(device)
                next_q_values = net(x)
                next_action = song.getInterval(t).getAction()
                next_reward, reward_types = song.getInterval(t).getReward(song=song)

                # TD error: the Q-value of the action actually taken in the training data
                # versus its reward plus the discounted max Q-value at the next timestep.
                q_value = q_values[next_action]
                actual_q = next_reward + K.discount * torch.max(next_q_values)
                local_loss = F.mse_loss(q_value, actual_q)
                loss += local_loss
                display_loss = (display_loss * n_loss + local_loss.item()) / (n_loss + 1)
                n_loss += 1

                q_values = next_q_values

            # Average the loss over the whole song before taking a gradient step.
            avg_loss = loss / song_len
            avg_loss.backward()
            optimizer.step()
            net.clear_hidden_state()
            progress_bar.update(1)


class MusicInterval(object):

    def __init__(self, index=None):
        # One 8th-note timestep. The default action is "no note" (K.num_possible_actions - 2 == 128).
        self.action = K.num_possible_actions - 2
        self.action_velocity = 0
        self.backup = torch.zeros(K.num_possible_notes)
        self.index = index

    def setAction(self, note, velocity=127, force_overwrite=False):
        assert velocity in range(K.num_possible_velocities), "velocity should be in range [0,127]"
        assert note in range(K.num_possible_actions), "note should be integer in range [0,129] (128 means no note) (129 means continue note)"
        # Normalize velocity to [0, 1] and keep the loudest note seen at this timestep,
        # unless we are forcing an overwrite or recording a "continue note" action.
        velocity /= (K.num_possible_velocities - 1)
        if force_overwrite or velocity > self.action_velocity or note == K.num_possible_actions - 1:
            self.action = note
            self.action_velocity = velocity if note != K.num_possible_actions - 1 else 0.1

    def setBackupNote(self, note, velocity=127):
        assert velocity in range(K.num_possible_velocities), "velocity should be in range [0,127]"
        assert note in range(K.num_possible_notes), "note should be integer in range [0,127]"

        velocity /= (K.num_possible_velocities - 1)

        self.backup[note] = max(velocity, self.backup[note])

    def getBackup(self):
        return self.backup

    def getAction(self):
        return self.action

    def getVelocity(self):
        return self.action_velocity

    def getReward(self, song):
        curr_note = self.action
        curr_velocity = 0
        prev_note = 60
        notes_played_by_backup = self.backup.clone()
        i = self.index

        # Resting gets a fixed reward.
        if curr_note == 128:
            return K.reward_no_note, (K.reward_no_note, 0, 0, 0)
        # A "continue note" action inherits the pitch and velocity of the note being held.
        if curr_note == 129:
            while i >= 0 and song.getInterval(i).getAction() == 129:
                i -= 1
            curr_note = 60 if i < 0 else song.getInterval(i).getAction()
            curr_velocity = 0.5 if i < 0 else song.getInterval(i).getVelocity()
        # Walk back to the most recent actual pitch to measure closeness against.
        i -= 1
        while i >= 0 and song.getInterval(i).getAction() >= K.num_possible_notes:
            i -= 1
        prev_note = 60 if i < 0 else song.getInterval(i).getAction()

        # Recency-weighted sum of the backup's notes over a short lookback window.
        for j in reversed(range(max(0, self.index - K.lookback_len), self.index)):
            notes_played_by_backup += song.getInterval(j).getBackup() / (self.index - j)

        max_backup_val = max(notes_played_by_backup)
        notes_played_by_backup = notes_played_by_backup / max_backup_val if max_backup_val > 0 else torch.zeros_like(self.backup)

        velocity_reward = curr_velocity * K.reward_velocity_coef
        scale_reward = song.getScale()[curr_note] * K.reward_scale_coef
        chord_reward = notes_played_by_backup[curr_note] * K.reward_chord_coef
        closeness_reward = K.reward_closeness_coef / (abs(curr_note - prev_note) + 1) if curr_note != prev_note else 0
        total_reward = velocity_reward + scale_reward + closeness_reward + chord_reward
        return total_reward, (velocity_reward, scale_reward.item(), closeness_reward, chord_reward.item())

    def getStateVector(self):
        # The lead's action as a one-hot vector scaled by its velocity, concatenated
        # with the per-pitch velocities played by the backup.
        action_vector = torch.zeros(K.num_possible_actions)
        action_vector[self.action] = self.action_velocity
        return torch.cat((
            action_vector,
            self.backup,
        ))


class Song(object):
    def __init__(self, title, key):
        self.title = title
        self.notes: List[MusicInterval] = []
        self.key = key  # length-K.num_possible_notes vector weighting how strongly each pitch belongs to the song's key
        self.len_song = 0

    def getInterval(self, index):
        # Lazily grow the song with empty intervals up to the requested index.
        if index >= self.len_song:
            self.notes += [MusicInterval(index=i) for i in range(self.len_song, index + 1)]
            self.len_song = len(self.notes)
        return self.notes[index]

    def getScale(self):
        return self.key


class DataLoader(object):

    def __init__(self, test_or_train):
        assert test_or_train in ["test", "train"], "test_or_train should be test or train"
        self.test_or_train = test_or_train
        self.songsInfo = []
        self.songs: List[Song] = []
        self.getMidiHeaders()
        self.populateSongObjects()

    def getMidiHeaders(self):
        # Read the dataset index CSV, then parse each referenced MIDI file.
        df = pd.read_csv("data/" + self.test_or_train + ".csv")
        self.vars = df.columns.tolist()
        self.data = df.to_numpy()
        print("loading midi data from " + str(self.data.shape[0]) + " files")
        for row in tqdm(self.data):
            try:
                midi_data = pretty_midi.PrettyMIDI("data/" + self.test_or_train + "/" + row[1])
                key_vector = self.getSongKeyVector(midi_data)
                tempo_change_times, tempos = midi_data.get_tempo_changes()
                song_info = {
                    "name": row[1],
                    "notes": row[2],
                    "len": row[3],
                    "uni_notes": row[4],
                    "len_uni_notes": row[5],
                    "midi_data": midi_data,
                    "key_vector": key_vector,
                    "tempo_info": list(zip(tempo_change_times, tempos))
                }
                self.songsInfo.append(song_info)
            except Exception as e:
                print("Error reading file \"" + row[1] + "\" -> " + str(e))

    def getSongKeyVector(self, midi_data):
        # Define major keys and their corresponding pitch classes
        major_keys = {
            'C': [0, 2, 4, 5, 7, 9, 11],
            'C#': [1, 3, 5, 6, 8, 10, 0],
            'D': [2, 4, 6, 7, 9, 11, 1],
            'Eb': [3, 5, 7, 8, 10, 0, 2],
            'E': [4, 6, 8, 9, 11, 1, 3],
            'F': [5, 7, 9, 10, 0, 2, 4],
            'F#': [6, 8, 10, 11, 1, 3, 5],
            'G': [7, 9, 11, 0, 2, 4, 6],
            'G#': [8, 10, 0, 1, 3, 5, 7],
            'A': [9, 11, 1, 2, 4, 6, 8],
            'Bb': [10, 0, 2, 3, 5, 7, 9],
            'B': [11, 1, 3, 4, 6, 8, 10],
        }

        def midi_data_to_notes(midi_data):
            midi_notes = []
            # Collect the pitch of every note in every instrument of the PrettyMIDI object
            for instrument in midi_data.instruments:
                for note in instrument.notes:
                    midi_notes.append(note.pitch)

            return midi_notes

        def midi_notes_to_key_counts(midi_notes):
            def calculate_matching_notes(pitch_distribution, key_pitches):
                """Calculate the sum of matching notes for a given major key."""
                matching_notes = 0
                for pitch in key_pitches:
                    matching_notes += pitch_distribution[pitch]

                return matching_notes

            # Convert MIDI notes to pitch classes
            pitches = [note % 12 for note in midi_notes]

            # Calculate the pitch-class distribution
            pitch_distribution = [0] * 12
            for pitch in pitches:
                pitch_distribution[pitch] += 1

            # Get counts of matching notes for each key
            key_counts = [(key, calculate_matching_notes(pitch_distribution, major_keys[key])) for key in major_keys]

            return key_counts

        def key_counts_to_key_probabilities(key_counts):
            def min_max_normalize(data):
                min_val = min(data)
                max_val = max(data)
                normalized = [(x - min_val) / (max_val - min_val) for x in data]
                return normalized

            def softmax(x, temperature=1.0):
                e_x = np.exp((x - np.max(x)) / temperature)
                return e_x / e_x.sum(axis=0)

            # Calculate the softmax probabilities for each key
            temperature = 0.1
            counts = [count for key, count in key_counts]
            normalized_counts = min_max_normalize(counts)
            key_probabilities = dict(zip(major_keys.keys(), softmax(normalized_counts, temperature)))

            return key_probabilities

        def key_probabilities_to_key_vector(key_probabilities):
            # Populate each pitch with its expected value: sum over keys of p(pitch in key) * p(key)
            pitch_vector = np.zeros(12)
            for pitch in range(12):
                pitch_probability = 0
                for key, key_pitches in major_keys.items():
                    if pitch in key_pitches:
                        pitch_probability += key_probabilities[key]
                pitch_vector[pitch] = pitch_probability

            # Extend the pitch vector to all 128 MIDI notes
            key_vector = np.tile(pitch_vector, 10)
            key_vector = np.concatenate((key_vector, pitch_vector[:8]))

            return key_vector

        midi_notes = midi_data_to_notes(midi_data)
        key_counts = midi_notes_to_key_counts(midi_notes)
        key_probabilities = key_counts_to_key_probabilities(key_counts)
        key_vector = key_probabilities_to_key_vector(key_probabilities)

        return key_vector

    def populateSongObjects(self):
        self.songs = [Song(self.songsInfo[i]["name"], self.songsInfo[i]["key_vector"]) for i in range(len(self.songsInfo))]
        for songInfo, song in zip(self.songsInfo, self.songs):
            midi_data, bpm_changes = songInfo["midi_data"], songInfo["tempo_info"]
            lead_notes, backup_notes = midi_data.instruments[0].notes, midi_data.instruments[1].notes

            # The first beat a lead note covers records its pitch; the remaining beats
            # record the "continue note" action (129).
            lead_notes_to_beats = self.get_notes_to_beats(lead_notes, bpm_changes)
            for lead_note in lead_notes:
                lead_beats = lead_notes_to_beats[lead_note]
                for i in range(len(lead_beats)):
                    beat = lead_beats[i]
                    if i == 0:
                        song.getInterval(beat).setAction(lead_note.pitch, lead_note.velocity)
                    else:
                        song.getInterval(beat).setAction(129, lead_note.velocity)

            backup_notes_to_beats = self.get_notes_to_beats(backup_notes, bpm_changes)
            for backup_note in backup_notes:
                backup_beats = backup_notes_to_beats[backup_note]
                for beat in backup_beats:
                    song.getInterval(beat).setBackupNote(backup_note.pitch, backup_note.velocity)

    def get_notes_to_beats(self, notes, bpm_changes):
        bpm_changes_copy = bpm_changes.copy()
        # Add an ending window that closes when the last note ends
        last_note_end_time = notes[len(notes) - 1].end
        bpm_changes_copy.append((last_note_end_time, 0))

        notes_to_beats = {}
        beats = 0
        window_start_time = 0.0
        previous_epm = bpm_changes_copy[0][1] * 2  # initial eighth-notes per minute

        # Map each note to the 8th-note beats it covers, one tempo window at a time
        for change_time, bpm in bpm_changes_copy[1:]:
            for note in notes:
                note_start, note_end = note.start, note.end
                if note_end <= change_time:
                    beat_position_start = math.floor(beats + (note_start - window_start_time) / 60 * previous_epm)
                    beat_position_end = math.floor(beats + (note_end - window_start_time) / 60 * previous_epm)
                    beat_positions = list(range(beat_position_start, beat_position_end + 1))
                    notes_to_beats[note] = beat_positions

            # Add the number of beats covered by the previous window
            beats += (change_time - window_start_time) / 60 * previous_epm

            # Shift the window
            window_start_time = change_time
            previous_epm = bpm * 2

        return notes_to_beats

    def getSongObjects(self):
        return self.songs