Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Example/Example/Tests/Helpers/CompletionSerializer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

class CompletionSerializer {

typealias CompletableFunction = (@escaping () -> ()) -> ()
Expand Down
16 changes: 13 additions & 3 deletions Example/Example/Tests/Helpers/Float.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@

import Foundation

extension Array where Element == Float {
/// A type that can produce a pseudo-random value of itself.
/// Conforming types (e.g. `Float` below) supply their own source of randomness.
public protocol Randomizable {
/// Returns a new pseudo-random value of the conforming type.
static func random() -> Self
}

/// `Float` conformance to `Randomizable`.
extension Float : Randomizable {
/// Returns a pseudo-random `Float` in [0, 1]: a uniform `UInt32` from
/// `arc4random()` scaled by `UInt32.max`.
/// NOTE(review): on Swift 4.2+ the stdlib `Float.random(in: 0...1)` could
/// replace this — confirm the project's minimum toolchain before switching.
public static func random() -> Float {
return Float(arc4random()) / Float(UInt32.max)
}
}

extension Array where Element : Randomizable {

static func random(count: Int) -> [Float] {
return (0..<count).map { _ in Float(arc4random()) / Float(UINT32_MAX) }
static func random(count: Int) -> [Element] {
return (0..<count).map { _ in Element.random() }
}

}
8 changes: 4 additions & 4 deletions Example/Example/Tests/Helpers/Texture.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ public class Texture {
data[y * width + x][z] = newValue
}
}

var width: Int {
return size.w
}

var height: Int {
return size.h
}

var depth: Int {
return size.f
}
Expand Down Expand Up @@ -164,5 +164,5 @@ extension Texture {
}
return texture
}

}
6 changes: 3 additions & 3 deletions Example/Example/Tests/InstanceNormTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class InstanceNormTest: BenderTest {

func test(texture: Texture, completion: @escaping () -> ()) {
let styleNet = Network(inputSize: texture.size)
let weights = [Float].init(repeating: Float(arc4random()) / Float(UINT32_MAX), count: texture.depth)
let bias = [Float].init(repeating: Float(arc4random()) / Float(UINT32_MAX), count: texture.depth)
let weights = [Float].init(repeating: Float.random(), count: texture.depth)
let bias = [Float].init(repeating: Float.random(), count: texture.depth)
let scale = Data.init(bytes: weights, count: texture.totalCount * MemoryLayout<Float>.stride)
let shift = Data.init(bytes: bias, count: texture.totalCount * MemoryLayout<Float>.stride)
styleNet.start ->> InstanceNorm(scale: scale, shift: shift)
Expand Down Expand Up @@ -100,5 +100,5 @@ class InstanceNormTest: BenderTest {
// }
return output
}

}
2 changes: 1 addition & 1 deletion Example/Example/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import UIKit
class TestViewController: UIViewController {

let testRunner = BenderTestRunner()

override func viewDidLoad() {
super.viewDidLoad()
testRunner.run()
Expand Down
6 changes: 3 additions & 3 deletions Example/MNISTTestController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MNISTTestController: UIViewController, ExampleViewController {

// If you want to create it from scratch comment the line above and uncomment:
// createMNISTNetwork()

var me = self
me.setPixelBufferPool()
setupMetalView()
Expand Down Expand Up @@ -197,7 +197,7 @@ class MNISTTestController: UIViewController, ExampleViewController {
}

}

}

// MARK: - AVCaptureVideoDataOutputSampleBufferDelegate
Expand Down Expand Up @@ -231,7 +231,7 @@ extension MNISTTestController: AVCaptureVideoDataOutputSampleBufferDelegate {
} else {
debugPrint("samplebuffer is nil \(sampleBuffer)")
}

return nil
}

Expand Down
7 changes: 1 addition & 6 deletions Example/RandomLoader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// Created by Mathias Claassen on 5/24/17.
//
//

import Foundation
import MetalBender

Expand All @@ -24,13 +23,9 @@ class RandomParameterLoader: ParameterLoader {
free(pointer)
}

func random() -> Float {
return Float(Double(arc4random()) / Double(UINT32_MAX))
}

func uniformRandom(_ x: UnsafeMutablePointer<Float>, count: Int, scale: Float) {
for i in 0..<count {
x[i] = (random()*2 - 1) * scale
x[i] = (Float.random()*2 - 1) * scale
}
}

Expand Down
6 changes: 2 additions & 4 deletions Sources/Adapters/TFOptimizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
//
//

import Foundation

/// Processes a graph imported from TensorFlow applying some optimizations/simplifications
public protocol TFOptimizer {

/// Optimize a graph imported from TensorFlow. Nodes that are to be removed should be left without adjacencies
func optimize(graph: TFGraph)

}

public extension TFOptimizer {
Expand All @@ -22,7 +20,7 @@ public extension TFOptimizer {
/// This information can later be used by the 'activationNeuron' function
func addNeuronIfThere(node: TFNode) {
let outgoing = node.outgoingNodes()
if outgoing.count == 1, let next = (outgoing.first as? TFNode),
if let next = (outgoing.first as? TFNode),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't want to remove this check. There could be more than one outgoingNodes

next.nodeDef.isTFReLuOp || next.nodeDef.isTFTanhOp || next.nodeDef.isTFSigmoidOp {
var neuron = Tensorflow_AttrValue()
neuron.value = Tensorflow_AttrValue.OneOf_Value.s(next.nodeDef.op.data(using: .utf8)!)
Expand Down
2 changes: 1 addition & 1 deletion Sources/Adapters/Tensorflow/ProtoExtensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extension Tensorflow_NodeDef {
let formatString = String(data: dataFormat, encoding: .utf8) else {
return (Int(strides[1]), Int(strides[2]))
}

let strideX = formatString == "NHWC" ? strides[2] : strides[3]
let strideY = formatString == "NHWC" ? strides[1] : strides[2]
return (Int(strideX), Int(strideY))
Expand Down
2 changes: 0 additions & 2 deletions Sources/Adapters/Tensorflow/String+TFParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

extension Tensorflow_NodeDef {

var isTFAddOp: Bool { return op == Constants.Ops.Add }
Expand Down
4 changes: 1 addition & 3 deletions Sources/Adapters/Tensorflow/TFConvOptimizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// Combines Conv2d with BiasAdd. Should be executed after Variable Processor
class TFConvOptimizer: TFOptimizer {

Expand Down Expand Up @@ -39,5 +37,5 @@ class TFConvOptimizer: TFOptimizer {
}
}
}

}
2 changes: 1 addition & 1 deletion Sources/Adapters/Tensorflow/TFConverter+Mappers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,5 @@ public extension TFConverter {
}
mappers[Constants.Ops.BatchNormGlobal] = batchnormMapper
}

}
4 changes: 2 additions & 2 deletions Sources/Adapters/Tensorflow/TFConverter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public enum ProtoFileType {

case binary
case text

}

/// Converts a TFNode to a NetworkLayer of Bender
Expand Down Expand Up @@ -105,7 +105,7 @@ open class TFConverter: Converter {
}

/// Runs the mappers through all the nodes in the `graph`.
/// Ops that cannot be mapped are discarded. If these ops are in the main path of the graph then the resulting graph will be disconnected.
/// Ops that cannot be mapped are discarded. If these ops are in the main path of the graph then the resulting graph will be disconnected.
///
/// - Parameter graph: The TFGraph to be mapped
/// - Returns: An array of mapped operations as NetworkLayer's
Expand Down
9 changes: 4 additions & 5 deletions Sources/Adapters/Tensorflow/TFDeleteOptimizers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
//
//

import Foundation

/// Strips common nodes that are used in training but not in evaluating/testing

public class TFStripTrainingOps: TFOptimizer {

public var regexes: [Regex] = [TFDeleteSave().regex, TFDeleteRegularizer().regex, TFDeleteInitializer().regex]
Expand All @@ -27,7 +26,7 @@ public class TFIgnoredOpsDeleter: TFOptimizer {

let ops = ["NoOp", "ExpandDims", "Cast", "Squeeze", "StopGradient", "CheckNumerics", "Assert", "Equal", "All",
"Dequantize", "RequantizationRange", "Requantize", "PlaceholderWithDefault", "Identity"]

public func optimize(graph: TFGraph) {
for node in graph.nodes {
if ops.contains(node.nodeDef.op) {
Expand All @@ -42,7 +41,7 @@ public class TFIgnoredOpsDeleter: TFOptimizer {
public class TFDeleteSave: TFDeleteSubgraphOptimizer {

public var regex: Regex = try! Regex("save(_\\d+)?/")

}

/// Deletes 'Initializer' subgraphs
Expand Down Expand Up @@ -80,5 +79,5 @@ fileprivate extension String {
let regex = try! Regex("dropout(_\\d+)?/mul")
return regex.test(self)
}

}
8 changes: 3 additions & 5 deletions Sources/Adapters/Tensorflow/TFDeleteSubgraph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
//
//

import Foundation

/// Deletes a specific subgraph from a TFGraph
public protocol TFDeleteSubgraphOptimizer: TFOptimizer {

/// This Regex tells if a node is in a subgraph to be deleted or not.
/// This Regex tells if a node is in a subgraph to be deleted or not.
/// If the node's name has a match for this regex then it will be considered as belonging to a subgraph
var regex: Regex { get set }

Expand All @@ -37,7 +35,7 @@ public extension TFDeleteSubgraphOptimizer {
/// Returns an identifier for a node in this graph
func id(for node: TFNode) -> String {
let match = regex.match(node.nodeDef.name)
return (node.nodeDef.name as NSString).substring(to: match.location + match.length)
return String(node.nodeDef.name.prefix(match.location + match.length))
}

/// Returns if the node has incoming connections to nodes outside of the subgraph
Expand Down Expand Up @@ -95,5 +93,5 @@ public extension TFDeleteSubgraphOptimizer {
// wire together
rewire(mappings: mappings)
}

}
4 changes: 1 addition & 3 deletions Sources/Adapters/Tensorflow/TFDenseSubstitution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// Transforms a MatMul and a BiasAdd into a FullyConnected. Should be executed after Variable Processor
/// Does not work with embedded weights. Transposing of weights must be done previously on Python side.
public class TFDenseSubstitution: TFOptimizer {
Expand All @@ -16,7 +14,7 @@ public class TFDenseSubstitution: TFOptimizer {
MatMul --> BiasAdd [--> Neuron]
^ ^
Variable Variable

Returns:
Variable -> BiasAdd(+add-ons) <- Variable

Expand Down
6 changes: 2 additions & 4 deletions Sources/Adapters/Tensorflow/TFGraph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// Represents a graph that is imported from a TensorFlow model
public class TFGraph: GraphProtocol {

Expand All @@ -34,7 +32,7 @@ public class TFGraph: GraphProtocol {
for node in nodes {
// Filter TF control inputs
let filtered = node.nodeDef.input.filter { $0.first != "^" }

for input in filtered {
if let inputNode = nodesByName[input] {
node.addIncomingEdge(from: inputNode)
Expand All @@ -48,5 +46,5 @@ public class TFGraph: GraphProtocol {
var me = self
me.sortNodes()
}

}
6 changes: 2 additions & 4 deletions Sources/Adapters/Tensorflow/TFInstanceNormOptimizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// Creates an Instance norm from a series of nodes. Must be executed after Variable processor
/// This implementation works with the one presented in https://github.com/lengstrom/fast-style-transfer/blob/master/src/transform.py#L49 (29/05/2017)
/// If you implement InstanceNorm differently you might have to create your own parser.
Expand All @@ -22,7 +20,7 @@ public class TFInstanceNormOptimizer: TFDeleteSubgraphOptimizer {
Input --> InstanceNormAdd --> Output
^ ^
Variable -> InstanceNormMul | Variable

Set_of_nodes is ([Add -> Pow, Sub] -> RealDiv)

*/
Expand Down Expand Up @@ -81,5 +79,5 @@ public class TFInstanceNormOptimizer: TFDeleteSubgraphOptimizer {
}
}
}

}
9 changes: 4 additions & 5 deletions Sources/Adapters/Tensorflow/TFNode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// A node of a graph imported from TensorFlow
public class TFNode: Node {

Expand All @@ -24,9 +22,10 @@ public class TFNode: Node {
public func isEqual(to other: Node) -> Bool {
return self.nodeDef == (other as? TFNode)?.nodeDef
}

}

public func ==(lhs: Tensorflow_NodeDef, rhs: Tensorflow_NodeDef) -> Bool {
return lhs.name == rhs.name
extension Tensorflow_NodeDef : Equatable {
public static func ==(lhs: Tensorflow_NodeDef, rhs: Tensorflow_NodeDef) -> Bool {
return lhs.name == rhs.name
}
}
2 changes: 0 additions & 2 deletions Sources/Adapters/Tensorflow/TFReshapeOptimizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//

import Foundation

/// Removes Reshape nodes
public class TFReshapeOptimizer: TFOptimizer {

Expand Down
Loading