using Flux

"""
    train_model(dataloader, num_epochs::Int, model, optimiser, val_data, val_labels, device)

Train `model` for `num_epochs` passes over `dataloader` using `optimiser`, printing the
mean training loss/accuracy and the validation loss/accuracy at the end of every epoch.

# Arguments
- `dataloader`: iterable yielding `(batch_data, batch_labels)` pairs.
- `num_epochs::Int`: number of full passes over `dataloader`.
- `model`: model whose `Flux.params` are updated in place.
- `optimiser`: optimiser accepted by `Flux.Optimise.update!`.
- `val_data`, `val_labels`: the whole validation set, evaluated once per epoch.
- `device`: function applied to every data/label array before use
  (presumably `gpu` or `cpu` — confirm at the call site).

Relies on functions `loss(x, y)` and `accuracy(x, y)` defined elsewhere.
NOTE(review): they do not receive `model`, so they presumably close over the same model
object passed in here — confirm at the call site.

Training metrics are averaged over the batches of the epoch and are measured with the
parameters as they were *before* each batch's update. Nothing is printed for an epoch
in which `dataloader` yields no batches.

Returns `nothing`; the model's parameters are mutated in place.
"""
function train_model(dataloader, num_epochs::Int, model, optimiser,
                     val_data, val_labels, device)
    # `Params` holds references to the model's arrays, which `update!` mutates in
    # place, so the collection only needs to be built once.
    ps = Flux.params(model)

    for epoch in 1:num_epochs
        epoch_loss = 0.0
        epoch_acc = 0.0
        nbatches = 0

        for (batch_data, batch_labels) in dataloader
            # move each batch to the device
            x, y = device(batch_data), device(batch_labels)

            # Keep the loss computed in the differentiated forward pass rather than
            # running a second forward pass after the update.
            local batch_loss
            gradients = Flux.gradient(ps) do
                batch_loss = loss(x, y)
                return batch_loss
            end
            # Measure accuracy with the same (pre-update) parameters as the loss.
            epoch_acc += accuracy(x, y)
            Flux.Optimise.update!(optimiser, ps, gradients)

            epoch_loss += batch_loss
            nbatches += 1
        end

        # No batches means nothing to average or report for this epoch.
        nbatches == 0 && continue

        println("Epoch: $(epoch)")
        # average loss/accuracy over the batches of this epoch
        print("Training Loss: $(round((epoch_loss / nbatches), digits=2)) ")
        print("Training Acc: $(round((epoch_acc / nbatches), digits=2)) ")

        # Evaluate on the whole validation set, transferring it to the device only once
        # per epoch and reusing it for both metrics.
        val_x, val_y = device(val_data), device(val_labels)
        print("Validation Loss: $(round(loss(val_x, val_y), digits=2)) ")
        print("Validation Acc: $(round(accuracy(val_x, val_y), digits=2))")
        println()
    end
    return nothing
end
# Hosted on Deepnote