MPSCNNHelloWorld/ViewController.swift
/*
    Copyright (C) 2016 Apple Inc. All Rights Reserved.
    See LICENSE.txt for this sample’s licensing information

    Abstract:
    View Controller for Metal Performance Shaders Sample Code.
*/
import UIKit
import MetalPerformanceShaders
class ViewController: UIViewController {

    // Properties used to control the app and store appropriate values.
    // We start with the simple single-layer network.
    var deep = false
    var commandQueue: MTLCommandQueue!
    var device: MTLDevice!

    // The networks we have
    var neuralNetwork: MNIST_Full_LayerNN? = nil
    var neuralNetworkDeep: MNIST_Deep_ConvNN? = nil
    var runningNet: MNIST_Full_LayerNN? = nil
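    // Note: MNIST_Deep_ConvNN is evidently a subclass of MNIST_Full_LayerNN in this
    // sample, since either network can be assigned to runningNet below.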
    // Loading the MNIST test set here
    let MNISTdata = GetMNISTData()

    // MNIST dataset image parameters
    let mnistInputWidth = 28
    let mnistInputHeight = 28
    let mnistInputNumPixels = 784

    // Outlets to labels and view
    @IBOutlet weak var digitView: DrawView!
    @IBOutlet weak var predictionLabel: UILabel!
    @IBOutlet weak var accuracyLabel: UILabel!
    override func viewDidLoad() {
        super.viewDidLoad()

        // Load default device.
        device = MTLCreateSystemDefaultDevice()

        // Make sure the current device supports MetalPerformanceShaders.
        guard MPSSupportsMTLDevice(device) else {
            print("Metal Performance Shaders not Supported on current Device")
            return
        }

        // Create new command queue.
        commandQueue = device!.makeCommandQueue()

        // Initialize the networks we shall use to detect digits.
        neuralNetwork = MNIST_Full_LayerNN(withCommandQueue: commandQueue)
        neuralNetworkDeep = MNIST_Deep_ConvNN(withCommandQueue: commandQueue)
        runningNet = neuralNetwork
    }
    @IBAction func tappedDeepButton(_ sender: UIButton) {
        // Switch the network to be used between the deep and the single-layer one.
        if deep {
            sender.setTitle("Use Deep Net", for: UIControlState.normal)
            runningNet = neuralNetwork
        }
        else {
            sender.setTitle("Use Single Layer", for: UIControlState.normal)
            runningNet = neuralNetworkDeep
        }
        deep = !deep
    }
    @IBAction func tappedClear(_ sender: UIButton) {
        // Clear the digit view.
        digitView.lines = []
        digitView.setNeedsDisplay()
        predictionLabel.isHidden = true
    }
    @IBAction func tappedTestSet(_ sender: UIButton) {
        // Counter for the number of correct detections on the test set.
        var correctDetections = Int32(0)
        let total = Float(10000)
        accuracyLabel.isHidden = false
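        // Reset the shared atomic counter (presumably a C helper bridged into Swift,
        // given the naming) that gets incremented for every correct prediction; it is
        // read back via __get_atomic_count() once the loop finishes.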
        __atomic_reset()

        // Validate that the neural network was initialized properly.
        assert(runningNet != nil)

        for i in 0..<Int(total) {
            inference(imageNum: i, correctLabel: UInt(MNISTdata.labels[i]))
            if i % 100 == 0 {
                accuracyLabel.text = "\(i/100)% Done"
                // Spin the run loop so the UI keeps updating during this long loop.
                RunLoop.current.run(mode: RunLoopMode.defaultRunLoopMode, before: Date.distantPast)
            }
        }

        // Display the accuracy of the network on the MNIST test set.
        correctDetections = __get_atomic_count()
        accuracyLabel.isHidden = false
        accuracyLabel.text = "Accuracy = \(Float(correctDetections * 100)/total)%"
    }
    @IBAction func tappedDetectDigit(_ sender: UIButton) {
        // Get the digitView context so we can read the pixel values from it to input to the network.
        let context = digitView.getViewContext()

        // Validate that the neural network was initialized properly.
        assert(runningNet != nil)
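        // The drawing is captured as one byte per pixel (8-bit grayscale), which is
        // why bytesPerRow below is simply the 28-pixel image width.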
        // Put the input pixels into the MTLTexture backing the MPSImage.
        runningNet?.srcImage.texture.replace(region: MTLRegion(origin: MTLOrigin(x: 0, y: 0, z: 0),
                                                               size: MTLSize(width: mnistInputWidth, height: mnistInputHeight, depth: 1)),
                                             mipmapLevel: 0,
                                             slice: 0,
                                             withBytes: context!.data!,
                                             bytesPerRow: mnistInputWidth,
                                             bytesPerImage: 0)

        // Run the network's forward pass.
        let label = (runningNet?.forward())!

        // Show the prediction.
        predictionLabel.text = "\(label)"
        predictionLabel.isHidden = false
    }
    /**
        This function runs the inference network on the test set.

        - Parameters:
            - imageNum: Index, between 0 and 9999, of which of the 10,000 test images is being evaluated
            - correctLabel: The correct label for the input image while testing

        - Returns:
            Void
    */
    func inference(imageNum: Int, correctLabel: UInt) {
        // Get the correct image pixels from the test set.
        var mnist_input_image = [UInt8]()
        mnist_input_image += MNISTdata.images[(imageNum*mnistInputNumPixels)..<((imageNum+1)*mnistInputNumPixels)]
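        // (MNISTdata.images is a flat buffer of 10,000 × 784 grayscale bytes, so this
        // slice is exactly one 28 × 28 image.)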
        // Create a source image for the network to forward.
        let inputImage = MPSImage(device: device, imageDescriptor: (runningNet?.sid)!)

        // Put the image in the source texture (input layer).
        inputImage.texture.replace(region: MTLRegion(origin: MTLOrigin(x: 0, y: 0, z: 0),
                                                     size: MTLSize(width: mnistInputWidth, height: mnistInputHeight, depth: 1)),
                                   mipmapLevel: 0,
                                   slice: 0,
                                   withBytes: mnist_input_image,
                                   bytesPerRow: mnistInputWidth,
                                   bytesPerImage: 0)

        // Run the network's forward pass.
        _ = runningNet!.forward(inputImage: inputImage, imageNum: imageNum, correctLabel: correctLabel)
    }
}