using Flux
# Function used to place the model on the compute device (here Flux's `cpu`).
calc_device = cpu

# Unpack the dimensions of the training array. The names follow Flux's WHCN image
# layout (width, height, channels, batch); assumes X_train_aug is 4-D — TODO confirm.
width, height, channels, records = size(X_train_aug)

# Small CNN: three conv + max-pool stages, a global average over the remaining
# spatial grid, then a two-layer dense head that emits 3 values per record.
model = calc_device(
    Flux.Chain(
        # 3×3 kernels with a padding of 1 keep the spatial size; each 2×2 max-pool
        # then halves it (rounding down).
        Flux.Conv((3, 3), channels => 32, selu; pad=(1, 1)),
        Flux.MaxPool((2, 2)),
        Flux.Conv((3, 3), 32 => 64, selu; pad=(1, 1)),
        Flux.MaxPool((2, 2)),
        Flux.Conv((3, 3), 64 => 64, selu; pad=(1, 1)),
        Flux.MaxPool((2, 2)),
        # Average over the spatial dimensions, leaving one value per channel, so the
        # dense head does not depend on the input image size.
        Flux.GlobalMeanPool(),
        Flux.flatten,
        Flux.Dense(64, 64, selu),
        # No activation on the last layer: the model returns raw scores.
        # NOTE(review): presumably 3 target classes used with a logit-based loss —
        # confirm against the training code.
        Flux.Dense(64, 3),
    ),
)