-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathCollatable.swift
76 lines (68 loc) · 2.77 KB
/
Collatable.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import TensorFlow
/// Underscored, implementation-detail protocol used to derive `Collatable`
/// conformance for `KeyPathIterable` types.
///
/// Because it has no `Self`-typed requirements, a leaf's type can be checked
/// dynamically with `as? _Collatable.Type` and asked to collate itself without
/// knowing its concrete type statically.
public protocol _Collatable {
  /// Collates the leaf reached through `rootKeyPath` across every element of
  /// `rootIn`. Implementations are expected to store the result at the same key
  /// path in `rootOut`.
  static func _collateLeaf<Root>(
    _ rootOut: inout Root,
    _ rootKeyPath: PartialKeyPath<Root>,
    _ rootIn: [Root]
  )
}
/// A type whose values can be combined ("collated") into a single value of the
/// same type. For example, several tensors can be stacked into one
/// higher-rank tensor, and a struct of tensors can be collated field by field.
public protocol Collatable: _Collatable {
  /// Creates one value by collating every element of `collating`.
  init(collating: [Self])
}
// Default `_Collatable` requirement for every `Collatable` type; this is what the
// key-path-driven (derived) conformance calls for each leaf.
extension Collatable {
  /// Gathers the `Self`-typed value found at `rootKeyPath` in each element of
  /// `rootIn`, collates them with `init(collating:)`, and writes the result into
  /// `rootOut` at that same key path.
  ///
  /// Traps if `rootKeyPath` cannot be cast to `WritableKeyPath<Root, Self>`.
  public static func _collateLeaf<Root>(
    _ rootOut: inout Root, _ rootKeyPath: PartialKeyPath<Root>, _ rootIn: [Root]
  ) {
    guard let leafPath = rootKeyPath as? WritableKeyPath<Root, Self> else {
      fatalError("Failed conversion from \(rootKeyPath) to 'WritableKeyPath<\(Root.self), \(Self.self)>'")
    }
    let leaves = rootIn.map { root in root[keyPath: leafPath] }
    rootOut[keyPath: leafPath] = Self(collating: leaves)
  }
}
// For derived conformance: walks a KeyPathIterable value and collates each leaf.
extension _KeyPathIterableBase {
  /// Recursively collates every leaf under `self` into `rootOut`.
  ///
  /// `self` is used only to enumerate structure (its type-erased key paths).
  /// In this file's callers it is the sub-value found at `rootKeyPath`. For
  /// each of its key paths:
  /// - if the value type is `_Collatable`, that leaf is collated across all
  ///   of `rootIn` and written into `rootOut`;
  /// - otherwise, if the value is itself `_KeyPathIterableBase`, the walk
  ///   recurses into it;
  /// - otherwise the function traps.
  ///
  /// - Parameters:
  ///   - rootOut: The value being built; collated leaves are written into it.
  ///   - rootKeyPath: Key path from `Root` to the value that `self` represents.
  ///   - rootIn: The values being collated.
  public func _collateAll<Root>(
    _ rootOut: inout Root, _ rootKeyPath: PartialKeyPath<Root>, _ rootIn: [Root]
  ) {
    for kp in _allKeyPathsTypeErased {
      // `appending` returns nil when the root type of `kp` (the dynamic type of
      // `self`) does not match the value type of `rootKeyPath`. Trap with a
      // descriptive message instead of a bare force-unwrap failure.
      guard let joinedKeyPath = rootKeyPath.appending(path: kp) else {
        fatalError("Cannot append key path \(kp) to \(rootKeyPath)")
      }
      if let valueType = type(of: joinedKeyPath).valueType as? _Collatable.Type {
        // Leaf: collate the values at `joinedKeyPath` across every root.
        valueType._collateLeaf(&rootOut, joinedKeyPath, rootIn)
      } else if let nested = self[keyPath: kp] as? _KeyPathIterableBase {
        // Aggregate: recurse into its own key paths.
        nested._collateAll(&rootOut, joinedKeyPath, rootIn)
      } else {
        fatalError("Key path \(kp) is not Collatable")
      }
    }
  }
}
// For derived conformance: any KeyPathIterable type gets `init(collating:)`.
extension KeyPathIterable {
  /// Creates a value by collating, leaf by leaf, every element of `roots`.
  ///
  /// The first element seeds `self` so that it is fully initialized. Every
  /// leaf reached through its key paths is then overwritten with the
  /// collation of that leaf across all of `roots`.
  ///
  /// - Parameter roots: The values to collate.
  /// - Precondition: `roots` is not empty.
  public init(collating roots: [Self]) {
    guard let first = roots.first else {
      preconditionFailure("Cannot collate an empty array of \(Self.self)")
    }
    self = first
    _collateAll(&self, \.self, roots)
  }
}
// Tensors are collated by stacking them.
extension Tensor: Collatable {
  /// Collates `tensors` with `Tensor(stacking:)`, presumably along a new leading
  /// axis (stacking's default); confirm against the TensorFlow docs.
  public init(collating tensors: [Self]) {
    self.init(stacking: tensors)
  }
}
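// Example: a struct can derive its conformance to Collatable just by also
// conforming to KeyPathIterable. This works as long as every property visited
// through KeyPathIterable is itself Collatable (such as a tensor) or a nested
// KeyPathIterable aggregate of Collatable values:
// struct Triple: Collatable, KeyPathIterable {
//   var first: Tensor<Float>
//   var second: Tensor<Float>
//   var third: Tensor<Float> = Tensor(5.0)
// }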