Loading Weights from PyTorch into a Metal Performance Shaders (MPS) Convolution Layer

Metal Performance Shaders (MPS) is a high-performance library built on Metal that provides ready-to-use neural network layers with direct access to the GPU.
It lets you skip pre- and post-processing of your data, or speed it up significantly.
In my task I ran into a mismatch between the convolution weight format used by PyTorch layers and the one expected by MPS convolution layers.
In MPS, a convolution layer is implemented by the MPSCNNConvolution or MPSCNNConvolutionNode class.
On initialization, both classes take another object that must implement the MPSCNNConvolutionDataSource protocol. That object is responsible for loading and managing the weights and biases.
Main Problem
Weights format/ordering in MPSCNNConvolutionDataSource:
weight[Output Channels][Kernel Height][Kernel Width][Input Channels]
Weights format/ordering in PyTorch torch.nn.Conv2d:
weight[Output Channels][Input Channels][Kernel Height][Kernel Width]
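To make the difference concrete, here is a minimal sketch of how the same logical weight lands at different flat offsets in the two layouts. The function names and the 3-channel, 3x3 layer are hypothetical, chosen only for illustration.

```python
def pytorch_offset(o, i, kh, kw, in_channels, kernel_h, kernel_w):
    """Flat offset of weight[o][i][kh][kw] in PyTorch order."""
    return ((o * in_channels + i) * kernel_h + kh) * kernel_w + kw


def mps_offset(o, kh, kw, i, in_channels, kernel_h, kernel_w):
    """Flat offset of weight[o][kh][kw][i] in MPS order."""
    return ((o * kernel_h + kh) * kernel_w + kw) * in_channels + i


# Hypothetical layer: 3 input channels, 3x3 kernel.
# Same logical weight: output channel 1, input channel 2, row 0, column 1.
print(pytorch_offset(1, 2, 0, 1, 3, 3, 3))  # 46
print(mps_offset(1, 0, 1, 2, 3, 3, 3))      # 32
```

The value is identical, but it sits at a different position in the buffer, so a file written in PyTorch order cannot be handed to MPS unchanged.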
Export Weights Data from PyTorch:
For Metal Performance Shaders, dump the weights directly to a file:
model.conv1.weight.flatten().detach().numpy().tofile("conv1_weights.bin")
This script converts the PyTorch tensor into a NumPy float32 array and writes it to a raw binary file as an array of 32-bit floats.
Here conv1 is an instance of the Conv2d class, which is a 2D convolution layer.
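Before moving the file into the app bundle, it can be worth checking that it holds as many floats as the layer shape implies. The sketch below uses only the Python standard library. The helper names are hypothetical, and the demo writes a stand-in file instead of a real PyTorch export. It assumes the file holds native-endian 32-bit floats.

```python
import array
import os
import tempfile


def read_float32_file(path):
    """Read a raw file of native-endian 32-bit floats into a list."""
    count = os.path.getsize(path) // 4  # 4 bytes per float32
    values = array.array("f")
    with open(path, "rb") as f:
        values.fromfile(f, count)
    return values.tolist()


def check_weight_count(values, out_channels, in_channels, kernel_h, kernel_w):
    """Return True if the number of values matches the layer's weight shape."""
    return len(values) == out_channels * in_channels * kernel_h * kernel_w


# Stand-in for an exported file: 2 output channels, 3 input channels, 3x3 kernel.
demo_path = os.path.join(tempfile.mkdtemp(), "conv1_weights.bin")
with open(demo_path, "wb") as f:
    array.array("f", [0.5] * (2 * 3 * 3 * 3)).tofile(f)

demo_values = read_float32_file(demo_path)
print(len(demo_values), check_weight_count(demo_values, 2, 3, 3, 3))
```

If the count does not match, the tensor was probably exported with a different dtype or from a different layer than expected.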
Solution
To fix this problem, the 4D weight tensor needs to be reordered from the PyTorch layout to the MPS layout.
// Load the raw float32 weights exported from PyTorch.
let data = Bundle.main.url(forResource: name, withExtension: "bin").flatMap { try? Data(contentsOf: $0) }
let floatData = data?.withUnsafeBytes { Array($0.bindMemory(to: Float32.self)) } ?? [Float32]()
// Destination buffer for the weights in MPS order.
var newFloatData = Array<Float32>(repeating: 0.0, count: floatData.count)
let Cf = self.inputChannels
let M = self.outputChannels
let kH = self.kernelHeight
let kW = self.kernelWidth
for m in 0..<M {
    for c in 0..<Cf {
        for kh in 0..<kH {
            for kw in 0..<kW {
                // Destination: [m][kh][kw][c] (MPS). Source: [m][c][kh][kw] (PyTorch).
                newFloatData[m * kH * kW * Cf + kh * kW * Cf + kw * Cf + c] = floatData[m * Cf * kH * kW + c * kH * kW + kh * kW + kw]
            }
        }
    }
}
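The same reordering can be sketched in plain Python to sanity-check the index arithmetic on a tiny tensor. The function name and the sample values are hypothetical. The operation amounts to permuting the axes from (0, 1, 2, 3) to (0, 2, 3, 1), so an alternative would be to call permute(0, 2, 3, 1).contiguous() on the tensor in PyTorch before exporting. That variant is not shown or tested here.

```python
def oihw_to_ohwi(flat, out_channels, in_channels, kernel_h, kernel_w):
    """Reorder flat weights from PyTorch order [O][I][H][W] to MPS order [O][H][W][I]."""
    expected = out_channels * in_channels * kernel_h * kernel_w
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    result = [0.0] * expected
    for o in range(out_channels):
        for i in range(in_channels):
            for kh in range(kernel_h):
                for kw in range(kernel_w):
                    src = ((o * in_channels + i) * kernel_h + kh) * kernel_w + kw
                    dst = ((o * kernel_h + kh) * kernel_w + kw) * in_channels + i
                    result[dst] = flat[src]
    return result


# Tiny hypothetical layer: 1 output channel, 2 input channels, 1x2 kernel.
# PyTorch order: channel 0 (columns 0, 1), then channel 1 (columns 0, 1).
weights = [10.0, 11.0, 20.0, 21.0]
print(oihw_to_ohwi(weights, 1, 2, 1, 2))  # [10.0, 20.0, 11.0, 21.0]
```

In the output, the input channels are interleaved per kernel position, which is the layout MPSCNNConvolutionDataSource expects according to the format described above.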
PS: Thanks to the Caffe2 leftovers in the PyTorch sources, which really helped me understand the problem. (link)
You can find a full example of MPSCNNConvolutionDataSource at: https://github.com/dhrebeniuk/RealTimeFastStyleTransfer/blob/master/RealTimeFastStyleTransfer/Sources/Rendering/Filters/StyleTransfer/Convolution2dDataSource.swift