This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy pathTensor.swift
880 lines (798 loc) · 28.5 KB
/
Tensor.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import CTensorFlow
import Foundation
import _Differentiation
// Element-wise equality operator; `Tensor.==` below reduces its result with `all()`.
infix operator .==: ComparisonPrecedence
// Element-wise inequality operator; `Tensor.!=` below reduces its result with `any()`.
infix operator .!=: ComparisonPrecedence
/// Special protocol for calling tensorflow operations that take heterogeneous arrays as input.
///
/// Conforming types expose their raw TensorFlow handle and element data type without exposing
/// the Swift `Scalar` type, so tensors with different scalar types can be passed together.
public protocol AnyTensor {
/// The raw C handle of the underlying TensorFlow tensor.
var _rawTensorHandle: CTensorHandle { get }
/// The TensorFlow data type of the tensor's elements.
var _tensorFlowDataType: TensorDataType { get }
}
/// A multidimensional array of elements that is a generalization of vectors and matrices to
/// potentially higher dimensions.
///
/// The generic parameter `Scalar` describes the type of scalars in the tensor (such as `Int32`,
/// `Float`, etc).
@frozen
public struct Tensor<Scalar: TensorFlowScalar> {
/// The underlying `TensorHandle`.
/// - Note: `handle` is public to allow user defined ops, but should not normally be used.
public let handle: TensorHandle<Scalar>
/// Creates a tensor that stores the given `TensorHandle` directly.
///
/// - Parameter handle: The handle backing the new tensor.
@inlinable
public init(handle: TensorHandle<Scalar>) {
self.handle = handle
}
}
extension Tensor: AnyTensor {
  /// The raw C handle held by this tensor's `TensorHandle`.
  public var _rawTensorHandle: CTensorHandle {
    handle._cTensorHandle
  }

  /// The TensorFlow data type that corresponds to `Scalar`.
  public var _tensorFlowDataType: TensorDataType {
    Scalar.tensorFlowDataType
  }
}
//===------------------------------------------------------------------------------------------===//
// Tensor Properties
//===------------------------------------------------------------------------------------------===//
extension Tensor {
  /// The number of dimensions of the `Tensor`.
  public var rank: Int {
    @_semantics("autodiff.nonvarying")
    get { return handle.rank }
  }

  /// The shape of the `Tensor`.
  public var shape: TensorShape {
    @_semantics("autodiff.nonvarying")
    get { return handle.shape }
  }

  #if USING_X10_BACKEND
    /// The number of scalars in the `Tensor`: the product of the shape's dimensions.
    @inlinable
    public var scalarCount: Int {
      @_semantics("autodiff.nonvarying")
      get { return shape.contiguousSize }
    }
  #else
    /// The number of scalars in the `Tensor`, as reported by the TensorFlow eager runtime.
    @inlinable
    public var scalarCount: Int {
      @_semantics("autodiff.nonvarying")
      get {
        let status = _ExecutionContext.global.status
        let elementCount = TFE_TensorHandleNumElements(handle._cTensorHandle, status)
        checkOk(status)
        return Int(elementCount)
      }
    }
  #endif

  /// The rank of the tensor, represented as a `Tensor<Int32>`.
  @inlinable
  public var rankTensor: Tensor<Int32> {
    @_semantics("autodiff.nonvarying")
    get { _Raw.rank(self) }
  }

  /// The dimensions of the tensor, represented as a `Tensor<Int32>`.
  @inlinable
  public var shapeTensor: Tensor<Int32> {
    @_semantics("autodiff.nonvarying")
    get { _Raw.shape(self) }
  }

  /// The number of scalars in the tensor, represented as a `Tensor<Int32>`.
  @inlinable
  public var scalarCountTensor: Tensor<Int32> {
    @_semantics("autodiff.nonvarying")
    get { _Raw.size(self) }
  }
}
//===------------------------------------------------------------------------------------------===//
// Scalar Conversion
//===------------------------------------------------------------------------------------------===//
extension Tensor {
  /// Whether the tensor is a scalar, i.e. its `rank` is 0.
  @inlinable
  public var isScalar: Bool {
    rank == 0
  }

  /// The tensor's only element when `rank` is 0; `nil` for any other rank.
  @inlinable
  public var scalar: Scalar? {
    guard isScalar else { return nil }
    return scalars[0]
  }

  /// Reshape to scalar.
  ///
  /// - Returns: The single scalar held by the tensor, regardless of its rank.
  /// - Precondition: The tensor has exactly one scalar.
  @inlinable
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public func scalarized() -> Scalar {
    let elementCount = shape.contiguousSize
    precondition(
      elementCount == 1,
      "This tensor must have exactly one scalar but contains \(elementCount).")
    return scalars[0]
  }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Derivative of `scalarized()`: the pullback turns the scalar cotangent back into a 0-D
  /// tensor on the same device as `self`.
  @inlinable
  @derivative(of: scalarized)
  func _vjpScalarized() -> (value: Scalar, pullback: (Scalar) -> Tensor) {
    let sourceDevice = self.device
    let value = scalarized()
    return (value: value, pullback: { cotangent in Tensor(cotangent, on: sourceDevice) })
  }
}
extension TensorFlowScalar {
  /// Extracts the value of a 0-D tensor.
  ///
  /// - Parameter tensor: The tensor to read.
  /// - Returns: `nil` when `tensor` is not a scalar (its `scalar` property is `nil`).
  @inlinable
  public init?(_ tensor: Tensor<Self>) {
    if let value = tensor.scalar {
      self = value
    } else {
      return nil
    }
  }
}
//===------------------------------------------------------------------------------------------===//
// Array Conversion
//===------------------------------------------------------------------------------------------===//
extension Tensor {
  /// A host copy of the tensor's contents as a `ShapedArray`.
  @inlinable
  public var array: ShapedArray<Scalar> {
    debugLog("Returning a host copy of array.")
    #if USING_X10_BACKEND
      // XLA tensors are materialized from their fetched scalars rather than a host-copied handle.
      guard handle.backend != .XLA else {
        return ShapedArray<Scalar>(shape: shape.dimensions, scalars: scalars)
      }
    #endif
    return handle.makeHostCopy()
  }

  /// The tensor's scalars, flattened in row-major order.
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public var scalars: [Scalar] {
    #if USING_X10_BACKEND
      if handle.backend == .XLA {
        return xlaTensor.fetchTensorValues(Scalar.self).0
      }
    #endif
    return array.scalars
  }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Derivative of `scalars`: the pullback reassembles the flat array cotangent into a tensor
  /// with the original shape, on the original device.
  @inlinable
  @derivative(of: scalars)
  func _vjpScalars() -> (value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor) {
    let flat = scalars
    let originalShape = self.shape
    let originalDevice = self.device
    return (
      value: flat,
      pullback: { cotangent in
        Tensor(shape: originalShape, scalars: cotangent.base, on: originalDevice)
      }
    )
  }
}
//===------------------------------------------------------------------------------------------===//
// Initialization
//===------------------------------------------------------------------------------------------===//
extension Tensor {
  /// Creates a 0-D tensor from a scalar value.
  ///
  /// - Parameters:
  ///   - value: The single scalar stored in the tensor.
  ///   - device: The device on which the tensor is created.
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public init(_ value: Scalar, on device: Device = .default) {
    #if USING_X10_BACKEND
      if case .XLA = device.backend {
        self.init(_xla: XLATensor.make(value, on: device))
        return
      }
    #endif
    self.init(shape: [], scalars: [value], on: device)
  }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Derivative of the scalar initializer: the pullback collapses the 0-D cotangent back into a
  /// scalar via `scalarized()`.
  @inlinable
  @derivative(of: init(_:on:))
  static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = .default) -> (
    value: Tensor, pullback: (Tensor) -> Scalar
  ) {
    let result = Tensor(value, on: device)
    return (value: result, pullback: { cotangent in cotangent.scalarized() })
  }
}
extension Tensor {
  /// Creates a 1D tensor from scalars.
  ///
  /// - Parameters:
  ///   - scalars: The scalar contents of the tensor. The resulting shape is `[scalars.count]`.
  ///   - device: The device on which the tensor is created.
  @inlinable
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public init(_ scalars: [Scalar], on device: Device = .default) {
    self.init(shape: [scalars.count], scalars: scalars, on: device)
  }

  /// Creates a 1D tensor from scalars.
  ///
  /// - Parameters:
  ///   - vector: The scalar contents of the tensor. The resulting shape is `[vector.count]`.
  ///   - device: The device on which the tensor is created.
  @inlinable
  public init<C: Collection>(
    _ vector: C, on device: Device = .default
  ) where C.Element == Scalar {
    #if USING_X10_BACKEND
      self.init([Scalar](vector), on: device)
    #else
      // Write the elements straight into the new handle's storage, avoiding an intermediate
      // array. NOTE(review): `device` is not consulted on this path.
      let handle = TensorHandle<Scalar>(
        shape: [vector.count],
        scalarsInitializer: { addr in
          var currentAddr = addr
          for scalar in vector {
            currentAddr.initialize(to: scalar)
            currentAddr = currentAddr.advanced(by: 1)
          }
        })
      self.init(handle: handle)
    #endif
  }

  /// Creates a tensor with the specified shape and contiguous scalars in row-major order.
  ///
  /// - Parameters:
  ///   - shape: The shape of the tensor.
  ///   - scalars: The scalar contents of the tensor.
  ///   - device: The device on which the tensor is created.
  /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
  @inlinable
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public init(shape: TensorShape, scalars: [Scalar], on device: Device = .default) {
    precondition(
      shape.contiguousSize == scalars.count,
      """
      The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
      provided.
      """)
    self = scalars.withUnsafeBufferPointer { bufferPointer in
      Tensor(shape: shape, scalars: bufferPointer, on: device)
    }
  }

  /// Creates a tensor with the specified shape and contiguous scalars in row-major order.
  ///
  /// - Parameters:
  ///   - shape: The shape of the tensor.
  ///   - scalars: The scalar contents of the tensor. A zero-length buffer is accepted even when
  ///     its base address is `nil`.
  ///   - device: The device on which the tensor is created.
  /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
  public init(
    shape: TensorShape,
    scalars: UnsafeBufferPointer<Scalar>,
    on device: Device = .default
  ) {
    precondition(
      shape.contiguousSize == scalars.count,
      """
      The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
      provided.
      """)
    #if USING_X10_BACKEND
      switch device.backend {
      case .XLA:
        self.init(_xla: XLATensor.make(scalars, shape.dimensions, on: device))
      case .TF_EAGER:
        self.init(handle: Tensor._makeEagerHandle(shape: shape, scalars: scalars))
      }
    #else
      self.init(handle: Tensor._makeEagerHandle(shape: shape, scalars: scalars))
    #endif
  }

  /// Creates a TensorFlow eager handle with the given shape, copying its contents from `scalars`.
  ///
  /// - Precondition: `shape.contiguousSize == scalars.count` (checked by the callers).
  private static func _makeEagerHandle(
    shape: TensorShape, scalars: UnsafeBufferPointer<Scalar>
  ) -> TensorHandle<Scalar> {
    return TensorHandle<Scalar>(
      shape: shape.dimensions,
      scalarsInitializer: { address in
        // A zero-length `UnsafeBufferPointer` may legally have a `nil` base address. Given the
        // precondition there is nothing to copy in that case, so never force-unwrap it.
        guard let source = scalars.baseAddress else { return }
        address.initialize(from: source, count: shape.contiguousSize)
      })
  }

  #if USING_X10_BACKEND
    /// Creates a tensor with the specified shape and contiguous scalars in row-major order.
    ///
    /// - Parameters:
    ///   - shape: The shape of the tensor.
    ///   - scalars: The scalar contents of the tensor.
    ///   - toReducedPrecision: Whether the XLA tensor is created in reduced precision.
    ///   - device: The device on which the tensor is created.
    /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
    @inlinable
    public init(
      shape: TensorShape,
      scalars: [Scalar],
      toReducedPrecision: Bool,
      directlyOn device: Device
    ) {
      precondition(
        shape.contiguousSize == scalars.count,
        """
        The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
        provided.
        """)
      self = scalars.withUnsafeBufferPointer { bufferPointer in
        Tensor(
          shape: shape, scalars: bufferPointer, toReducedPrecision: toReducedPrecision,
          directlyOn: device)
      }
    }

    /// Creates a tensor with the specified shape and contiguous scalars in row-major order.
    ///
    /// - Parameters:
    ///   - shape: The shape of the tensor.
    ///   - scalars: The scalar contents of the tensor.
    ///   - toReducedPrecision: Whether the XLA tensor is created in reduced precision. Must be
    ///     `false` for TF eager devices.
    ///   - device: The device on which the tensor is created.
    /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
    public init(
      shape: TensorShape,
      scalars: UnsafeBufferPointer<Scalar>,
      toReducedPrecision: Bool,
      directlyOn device: Device
    ) {
      precondition(
        shape.contiguousSize == scalars.count,
        """
        The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
        provided.
        """)
      switch device.backend {
      case .XLA:
        self.init(
          _xla: XLATensor.make(
            scalars, shape.dimensions, toReducedPrecision: toReducedPrecision,
            directlyOn: device))
      case .TF_EAGER:
        // Reduced precision is only supported on the XLA backend.
        precondition(!toReducedPrecision)
        self = .init(shape: shape, scalars: scalars, on: device)
      }
    }
  #endif

  /// Creates a tensor with the specified shape and contiguous scalars in row-major order.
  ///
  /// - Parameters:
  ///   - shape: The shape of the tensor.
  ///   - scalars: The scalar contents of the tensor.
  ///   - device: The device on which the tensor is created.
  /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
  public init<C: Collection>(
    shape: TensorShape, scalars: C, on device: Device = .default
  ) where C.Element == Scalar {
    precondition(
      shape.contiguousSize == scalars.count,
      """
      The shape requires \(shape.contiguousSize) scalars but \(scalars.count) were \
      provided.
      """)
    #if USING_X10_BACKEND
      self.init(shape: shape, scalars: [Scalar](scalars), on: device)
    #else
      // Write the elements straight into the new handle's storage, avoiding an intermediate
      // array. NOTE(review): `device` is not consulted on this path.
      let handle = TensorHandle<Scalar>(
        shape: shape.dimensions,
        scalarsInitializer: { addr in
          var currentAddr = addr
          for scalar in scalars {
            currentAddr.initialize(to: scalar)
            currentAddr = currentAddr.advanced(by: 1)
          }
        })
      self.init(handle: handle)
    #endif
  }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Derivative of `init(_:on:)` for arrays: the pullback flattens the cotangent tensor back into
  /// an array tangent.
  @inlinable
  @derivative(of: init(_:on:))
  static func _vjpInit(_ scalars: [Scalar], on device: Device = .default) -> (
    value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector
  ) {
    let result = Tensor(scalars, on: device)
    return (
      value: result,
      pullback: { cotangent in Array<Scalar>.TangentVector(cotangent.scalars) }
    )
  }

  /// Derivative of `init(shape:scalars:on:)`: the pullback flattens the cotangent tensor back
  /// into an array tangent.
  @inlinable
  @derivative(of: init(shape:scalars:on:))
  static func _vjpInit(
    shape: TensorShape, scalars: [Scalar], on device: Device = .default
  ) -> (value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector) {
    let result = Tensor(shape: shape, scalars: scalars, on: device)
    return (
      value: result,
      pullback: { cotangent in Array<Scalar>.TangentVector(cotangent.scalars) }
    )
  }
}
// Background story on `TensorElementLiteral` and why it's necessary:
//
// Very importantly, we want users to be able to implicitly convert an array
// literal to a tensor. At first glance, a straightforward implementation would
// be conforming `Tensor` to `ExpressibleByArrayLiteral` with
// `ExpressibleBy(Float|Int|Bool)Literal` as a base case. However, it is not
// that simple. We have binary operators that take `(Tensor, Scalar)`, `(Scalar,
// Tensor)` as well as `(Tensor, Tensor)`. When `Tensor`s are convertible from
// both a scalar and an array literal, a scalar-tensor binary operator like `+`
// will not type check.
//
// One way to work around it is to define all tensor-tensor operators in a
// protocol extension, and all tensor-scalar and scalar-tensor operators on
// concrete `Tensor`. Protocol extensions are less favorable than concrete
// implementations, so the compiler will prefer the concrete implementation for
// a scalar-tensor operation. However, this would cause enormous code bloat and
// is entirely a hack.
//
// To resolve ambiguity, `Tensor` should not be expressible by scalar literal.
// There's already a lightweight syntax for converting a scalar to a tensor:
// `Tensor(x)`, so there is no strong need for implicit conversion. But we need
// to find a way to give `ExpressibleByArrayLiteral` a base case: what would the
// `ArrayLiteralElement` be if we want to support both `[1,2,3]` and `[[[1,2],
// [1,2]]]`? In the first case the array literal element is an integer, while
// in the second case the array literal itself should be a tensor. Based on this
// observation, we come up with an intermediate type: `TensorElementLiteral` as
// the `ArrayLiteralElement` of `Tensor`. By making `TensorElementLiteral`
// expressible by both array literal and scalar literal, `Tensor` can now be
// converted from an arbitrary-dimensional array literal.
//
// Due to protocol requirements, `TensorElementLiteral` has to be
// public. It is never supposed to be used directly by any user, so the library
// convention is to prepend an underscore to its name, making it
// `_TensorElementLiteral`.
//
// Ideally this type will be removed once tensor-scalar/scalar-tensor operator
// ambiguity can be resolved systematically, either through an improved
// `ExpressibleBy*` literal-protocol model or through an attribute that tells the
// type checker which overload to prefer when an ambiguity occurs.
/// Represents a literal element for conversion to a `Tensor`.
///
/// - Note: Do not ever use this API directly. This is implicitly created
/// during the conversion from an array literal to a `Tensor`, and is purely
/// for implementation purposes.
@frozen
public struct _TensorElementLiteral<Scalar> where Scalar: TensorFlowScalar {
/// The tensor built from this literal: a 0-D tensor for a scalar literal, or the result of
/// `_Raw.pack` over the element tensors for a nested array literal.
@usableFromInline let tensor: Tensor<Scalar>
}
extension _TensorElementLiteral: ExpressibleByBooleanLiteral
where Scalar: ExpressibleByBooleanLiteral {
  public typealias BooleanLiteralType = Scalar.BooleanLiteralType

  /// Wraps a boolean literal in a 0-D tensor.
  @inlinable
  public init(booleanLiteral: BooleanLiteralType) {
    let value = Scalar(booleanLiteral: booleanLiteral)
    self.tensor = Tensor(value)
  }
}
extension _TensorElementLiteral: ExpressibleByIntegerLiteral
where Scalar: ExpressibleByIntegerLiteral {
  public typealias IntegerLiteralType = Scalar.IntegerLiteralType

  /// Wraps an integer literal in a 0-D tensor.
  @inlinable
  public init(integerLiteral: IntegerLiteralType) {
    let value = Scalar(integerLiteral: integerLiteral)
    self.tensor = Tensor(value)
  }
}
extension _TensorElementLiteral: ExpressibleByFloatLiteral
where Scalar: ExpressibleByFloatLiteral {
  public typealias FloatLiteralType = Scalar.FloatLiteralType

  /// Wraps a floating-point literal in a 0-D tensor.
  @inlinable
  public init(floatLiteral: FloatLiteralType) {
    let value = Scalar(floatLiteral: floatLiteral)
    self.tensor = Tensor(value)
  }
}
extension _TensorElementLiteral: ExpressibleByArrayLiteral {
  public typealias ArrayLiteralElement = _TensorElementLiteral<Scalar>

  /// Combines the tensors of the nested element literals with `_Raw.pack`.
  @inlinable
  public init(arrayLiteral elements: _TensorElementLiteral<Scalar>...) {
    let parts = elements.map { element in element.tensor }
    self.tensor = _Raw.pack(parts)
  }
}
extension Tensor: ExpressibleByArrayLiteral {
  /// The type of the elements of an array literal.
  public typealias ArrayLiteralElement = _TensorElementLiteral<Scalar>

  /// Creates a tensor by packing the tensors of the given element literals with `_Raw.pack`.
  /// - Note: This is for conversion from tensor element literals. This is a
  ///   separate method because `ShapedArray` initializers need to call it.
  @inlinable
  internal init(_tensorElementLiterals elements: [_TensorElementLiteral<Scalar>]) {
    let parts = elements.map { literal in literal.tensor }
    self = _Raw.pack(parts)
  }

  /// Creates a tensor initialized with the given elements.
  /// - Precondition: `elements` is not empty.
  @inlinable
  public init(arrayLiteral elements: _TensorElementLiteral<Scalar>...) {
    precondition(!elements.isEmpty, "Cannot create a 'Tensor' with no elements.")
    self.init(_tensorElementLiterals: elements)
  }
}
//===------------------------------------------------------------------------------------------===//
// Equatable
//===------------------------------------------------------------------------------------------===//
extension Tensor: Equatable where Scalar: Equatable {
  /// Returns `true` when both tensors have the same shape and every pair of corresponding
  /// scalars compares equal.
  @inlinable
  public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
    lhs.shape == rhs.shape && (lhs .== rhs).all()
  }

  /// Returns `true` when the shapes differ or any pair of corresponding scalars compares unequal.
  @inlinable
  public static func != (lhs: Tensor, rhs: Tensor) -> Bool {
    lhs.shape != rhs.shape || (lhs .!= rhs).any()
  }
}
//===------------------------------------------------------------------------------------------===//
// Description and Visualization
//===------------------------------------------------------------------------------------------===//
// String conversion.
extension Tensor: CustomStringConvertible {
  /// A textual representation of the tensor, taken from its host `array` copy.
  ///
  /// - Note: use `fullDescription` for a non-pretty-printed description showing all scalars.
  public var description: String {
    @_semantics("autodiff.nonvarying")
    get { array.description }
  }
}
extension Tensor {
  /// A textual representation of the tensor. Returns a summarized description if `summarize` is
  /// true and the element count exceeds twice the `edgeElementCount`.
  ///
  /// - Parameters:
  ///   - lineWidth: The max line width for printing. Used to determine number of scalars to print
  ///     per line.
  ///   - edgeElementCount: The maximum number of elements to print before and after summarization
  ///     via ellipses (`...`).
  ///   - summarizing: If true, summarize description if element count exceeds twice
  ///     `edgeElementCount`.
  public func description(
    lineWidth: Int = 80,
    edgeElementCount: Int = 3,
    summarizing: Bool = false
  ) -> String {
    let hostCopy = array
    return hostCopy.description(
      lineWidth: lineWidth, edgeElementCount: edgeElementCount, summarizing: summarizing)
  }

  /// A full, non-pretty-printed textual representation of the tensor, showing
  /// all scalars.
  public var fullDescription: String {
    @_semantics("autodiff.nonvarying")
    get { array.fullDescription }
  }

  #if USING_X10_BACKEND
    /// The IR text for the underlying `XLATensor`, as returned by `XLATensor.irText`.
    public var irText: String {
      return XLATensor.irText(xlaTensor)
    }
  #endif
}
// Xcode Playground display conversion.
extension Tensor: CustomPlaygroundDisplayConvertible {
  /// The value an Xcode playground displays for the tensor: its `description`.
  public var playgroundDescription: Any {
    @_semantics("autodiff.nonvarying")
    get { return description }
  }
}
// Mirror representation, used by debugger/REPL.
extension Tensor: CustomReflectable {
  /// A mirror that presents the tensor as a struct with no children.
  public var customMirror: Mirror {
    @_semantics("autodiff.nonvarying")
    get { Mirror(self, children: [], displayStyle: .struct) }
  }
}
//===------------------------------------------------------------------------------------------===//
// Codable Conformance
//===------------------------------------------------------------------------------------------===//
extension Tensor: Codable where Scalar: Codable {
  /// Encodes the tensor as a single value: its host `ShapedArray` copy.
  @inlinable
  public func encode(to encoder: Encoder) throws {
    var container = encoder.singleValueContainer()
    let hostArray = array
    try container.encode(hostArray)
  }

  /// Decodes a single `ShapedArray` value and builds the tensor from it.
  @inlinable
  public init(from decoder: Decoder) throws {
    let decoded = try decoder.singleValueContainer().decode(ShapedArray<Scalar>.self)
    self.init(decoded)
  }
}
//===------------------------------------------------------------------------------------------===//
// Additive Group
//===------------------------------------------------------------------------------------------===//
extension Tensor: AdditiveArithmetic where Scalar: Numeric {
  #if USING_X10_BACKEND
    /// The scalar zero tensor, created on the thread-local current device and converted to
    /// reduced precision when the thread-local state requests it.
    public static var zero: Tensor {
      let zero = Tensor(0, on: _DeviceThreadLocalState.local.currentDevice)
      guard _DeviceThreadLocalState.local.isReducedPrecision else { return zero }
      return zero.toReducedPrecision
    }
  #else
    /// The scalar zero tensor.
    @inlinable
    public static var zero: Tensor { return Tensor(0) }
  #endif

  /// Adds two tensors and produces their sum.
  /// - Note: `+` supports broadcasting.
  @inlinable
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
    return _Raw.addV2(lhs, rhs)
  }

  /// Subtracts one tensor from another and produces their difference.
  /// - Note: `-` supports broadcasting.
  @inlinable
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
    return _Raw.sub(lhs, rhs)
  }
}
extension Tensor where Scalar: TensorFlowFloatingPoint {
  /// Derivative of `+`: both operands receive the incoming cotangent, passed through
  /// `BroadcastingPullback` for the two operands.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (
    value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
  ) {
    let sum = lhs + rhs
    let unbroadcast = BroadcastingPullback(lhs, rhs)
    return (sum, { cotangent in unbroadcast(cotangent, cotangent) })
  }

  /// Derivative of `-`: `lhs` receives the cotangent and `rhs` its negation, both passed through
  /// `BroadcastingPullback` for the two operands.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (
    value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
  ) {
    let difference = lhs - rhs
    let unbroadcast = BroadcastingPullback(lhs, rhs)
    return (difference, { cotangent in unbroadcast(cotangent, -cotangent) })
  }
}
//===------------------------------------------------------------------------------------------===//
// Multiplicative Group
//===------------------------------------------------------------------------------------------===//
extension Tensor: PointwiseMultiplicative where Scalar: Numeric {
  /// The scalar one tensor.
  @inlinable
  public static var one: Tensor { return Tensor(1) }

  /// The element-wise reciprocal of `self`, computed as `1 / self`.
  @inlinable
  public var reciprocal: Tensor { return 1 / self }

  /// Multiplies two tensors element-wise and produces their product.
  /// - Note: `.*` supports broadcasting.
  public static func .* (lhs: Tensor, rhs: Tensor) -> Tensor {
    lhs * rhs
  }
}
//===------------------------------------------------------------------------------------------===//
// Differentiable
//===------------------------------------------------------------------------------------------===//
extension Tensor: Differentiable & EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {
  public typealias TangentVector = Tensor

  /// A closure producing a zero tangent vector with the same shape as `self`.
  ///
  /// The zero tangent is created on the same device as `self`, matching how the pullbacks in
  /// this file (e.g. `_vjpScalarized`, `_vjpScalars`) capture `device`. This way it can be
  /// combined with other tangents of `self` without mixing devices.
  public var zeroTangentVectorInitializer: () -> TangentVector {
    let shape = self.shape
    let device = self.device
    return { Tensor(repeating: 0, shape: shape, on: device) }
  }
}
//===------------------------------------------------------------------------------------------===//
// Multi-device support
//===------------------------------------------------------------------------------------------===//
#if USING_X10_BACKEND
  extension Tensor {
    /// The device on which `self` is allocated.
    ///
    /// XLA tensors report the device of their underlying `XLATensor`. TF eager tensors derive the
    /// kind and ordinal from the TensorFlow device name, expected to look like
    /// `/job:localhost/replica:0/task:0/device:CPU:0`. If the name is unavailable or does not
    /// match, the kind falls back to `.CPU` and the ordinal to `0`.
    public var device: Device {
      @_semantics("autodiff.nonvarying")
      get {
        if handle.backend == .XLA {
          return xlaTensor.device
        }
        let status = _ExecutionContext.global.status
        // Ask the TensorFlow runtime which device holds this handle.
        guard let deviceNameCString = TFE_TensorHandleDeviceName(handle._cTensorHandle, status)
        else {
          return Device(kind: .CPU, ordinal: 0, backend: .TF_EAGER)
        }
        checkOk(status)
        let deviceName = String(cString: deviceNameCString)
        // Capture the device type and ordinal from the trailing "device:<TYPE>:<N>" component.
        let regex = try! NSRegularExpression(pattern: ".+device:(.+):(\\d+)$")
        let fullRange = NSRange(deviceName.startIndex..., in: deviceName)
        guard let match = regex.firstMatch(in: deviceName, range: fullRange) else {
          return Device(kind: .CPU, ordinal: 0, backend: .TF_EAGER)
        }
        let knownKinds: [String: Device.Kind] = ["CPU": .CPU, "GPU": .GPU, "TPU": .TPU]
        let kind: Device.Kind =
          Range(match.range(at: 1), in: deviceName)
          .flatMap { knownKinds[deviceName[$0].uppercased()] } ?? .CPU
        let ordinal: Int =
          Range(match.range(at: 2), in: deviceName)
          .flatMap { Int(deviceName[$0]) } ?? 0
        return Device(kind: kind, ordinal: ordinal, backend: .TF_EAGER)
      }
    }
  }
#endif
//===------------------------------------------------------------------------------------------===//
// Annotations
//===------------------------------------------------------------------------------------------===//
/// A tensor-like type that exposes its scalar type, shape, and annotations.
public protocol TensorProtocol {
/// The element type of the tensor.
associatedtype Scalar: TensorFlowScalar
/// Creates a tensor of `shape` on `device` filled with `repeatedValue` (presumably; the
/// `Tensor` implementation is not part of this section — confirm).
init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
/// A textual description of the annotations attached to the tensor.
var annotations: String { get }
/// The shape of the tensor.
var shape: TensorShape { get }
/// A summary of the tensor; `Tensor` implements this as an alias for `annotations`.
var summary: String { get }
}
/// A differentiable `TensorProtocol` with floating-point scalars that supports annotation.
public protocol DifferentiableTensorProtocol:
TensorProtocol & Differentiable & EuclideanDifferentiable
where Scalar: TensorFlowFloatingPoint {
/// Returns `self` with `annotation` attached; differentiable with respect to `self`.
@differentiable(wrt: self)
func annotate(_ annotation: String) -> Self
}
extension Tensor: TensorProtocol {
  /// The annotations describing this tensor.
  ///
  /// XLA tensors report the annotations recorded on their `XLATensor`. TF eager tensors return
  /// `Device.defaultTFEager.annotationsAvailable` in X10 builds, or a fixed message otherwise.
  public var annotations: String {
    #if USING_X10_BACKEND
      guard handle.backend == .XLA else {
        return Device.defaultTFEager.annotationsAvailable
      }
      return XLATensor.annotations(xlaTensor)
    #else
      return "Annotations not available in TF_EAGER."
    #endif
  }

  /// An alias for `annotations`.
  public var summary: String {
    return annotations
  }
}
extension Tensor: DifferentiableTensorProtocol
where Scalar: TensorFlowFloatingPoint {
  /// Adds an annotation.
  ///
  /// Note: Only X10 is supported. For other backends, unmodified `self` is
  /// returned.
  ///
  /// - Parameter annotation: The annotation to be added.
  /// - Returns: The annotated tensor.
  @differentiable(wrt: self)
  public func annotate(_ annotation: String) -> Tensor<Scalar> {
    #if USING_X10_BACKEND
      if handle.backend == .XLA {
        return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
      }
    #endif
    return self
  }

  /// Derivative of `annotate(_:)`: the pullback is the identity, passing the cotangent through.
  @derivative(of: annotate)
  @usableFromInline
  func vjpAnnotate(_ annotation: String) -> (
    value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>
  ) {
    let annotated = annotate(annotation)
    return (value: annotated, pullback: { cotangent in cotangent })
  }
}