using Flux
"""
    train_model(dataloader, num_epochs, model, optimiser, val_data, val_labels, device)

Train `model` in place for `num_epochs` passes over `dataloader` using Flux's
implicit-parameter API. After each epoch, print the mean training loss and accuracy
followed by the validation loss and accuracy.

# Arguments
- `dataloader`: iterable of `(batch_data, batch_labels)` pairs (e.g. a `Flux.DataLoader`).
- `num_epochs::Integer`: number of full passes over `dataloader`.
- `model`: the model whose `Flux.params` are differentiated and updated.
- `optimiser`: optimiser passed to `Flux.Optimise.update!`.
- `val_data`, `val_labels`: validation inputs and targets, moved with `device` before use.
- `device`: function applied to every batch and to the validation set before evaluation
  (presumably `gpu` or `cpu` — confirm against the caller).

# Notes
Relies on `loss(x, y)` and `accuracy(x, y)` defined outside this function. The gradient is
taken only with respect to `Flux.params(model)`, so these presumably evaluate this same
`model` — confirm against the caller.

Each batch's training loss and accuracy are measured on the weights *before* that batch's
update; the loss is the value already computed for the gradient. The epoch figures are
unweighted means over batches, so a smaller final batch counts as much as a full one.
Validation metrics are computed with the weights at the end of the epoch.

Returns `nothing`.
"""
function train_model(dataloader, num_epochs::Integer, model, optimiser,
                     val_data, val_labels, device)
    # `Params` refers to the model's arrays, which `update!` mutates in place, so the
    # collection can be built once instead of twice per batch.
    ps = Flux.params(model)
    for epoch in 1:num_epochs
        loss_sum = 0.0
        acc_sum = 0.0
        nbatches = 0
        for (batch_data, batch_labels) in dataloader
            x, y = device(batch_data), device(batch_labels)
            # Keep the forward-pass loss from inside the gradient call rather than running
            # a second forward pass after the update just to log it.
            local train_loss
            grads = Flux.gradient(ps) do
                train_loss = loss(x, y)
                return train_loss
            end
            # Measure accuracy on the same (pre-update) weights the loss was computed with.
            acc_sum += accuracy(x, y)
            Flux.Optimise.update!(optimiser, ps, grads)
            loss_sum += train_loss
            nbatches += 1
        end
        # An epoch with no batches produces no report (and no division by zero).
        nbatches == 0 && continue
        # Move the validation set to the device once per report, not once per metric.
        val_x, val_y = device(val_data), device(val_labels)
        println("Epoch: $(epoch)")
        print("Training Loss: $(round(loss_sum / nbatches, digits=2)) ")
        print("Training Acc: $(round(acc_sum / nbatches, digits=2)) ")
        print("Validation Loss: $(round(loss(val_x, val_y), digits=2)) ")
        print("Validation Acc: $(round(accuracy(val_x, val_y), digits=2))")
        println()
    end
    return nothing
end