using MLUtils, Augmentor, OneHotArrays

"""
    augment_images_one_hot_labels(images, labels, image_height, image_width,
                                  augmentor, categories)

Augment every image in `images` with `augmentor` and one-hot encode `labels`.

# Arguments
- `images`: collection of images; it is iterated element by element and
  `MLUtils.numobs(images)` determines the size of the batch dimension.
- `labels`: labels forwarded to `OneHotArrays.onehotbatch`.
- `image_height::Int`, `image_width::Int`: spatial size used for the output buffer.
- `augmentor`: operation or pipeline forwarded to `Augmentor.augment`.
- `categories`: the label set forwarded to `OneHotArrays.onehotbatch`.

# Returns
A tuple `(aug_images, one_hot_labels)`, where `aug_images` is an `Array{Float32}` of
size `(image_width, image_height, 3, MLUtils.numobs(images))` and `one_hot_labels` is
the result of `OneHotArrays.onehotbatch(labels, categories)`.
"""
function augment_images_one_hot_labels(
    images, labels, image_height::Int, image_width::Int, augmentor, categories
)
    # The model consumes WHCN batches (width, height, channels, samples).
    batch_dims = (image_width, image_height, 3, MLUtils.numobs(images))
    aug_images = Array{Float32}(undef, batch_dims)

    # Each augmented image fills one slice along the sample dimension.
    # NOTE(review): assumes the pipeline output fits a (image_width, image_height, 3)
    # slice and is convertible to Float32 — confirm against the augmentor in use.
    foreach(enumerate(images)) do (slot, img)
        aug_images[:, :, :, slot] = Augmentor.augment(img, augmentor)
    end

    one_hot_labels = OneHotArrays.onehotbatch(labels, categories)
    return aug_images, one_hot_labels
end
# Hosted on Deepnote