Skip to content

Commit

Permalink
Automated sync from github.com/tensorflow/tensorflow (#2226)
Browse files Browse the repository at this point in the history
BUG=automated sync from upstream
NO_CHECK_TFLITE_FILES=automated sync from upstream
  • Loading branch information
TFLM-bot authored Sep 15, 2023
1 parent 7d1dee8 commit 2f2c744
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 11 deletions.
1 change: 1 addition & 0 deletions tensorflow/lite/builtin_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ typedef enum {
kTfLiteBuiltinStablehloGather = 201,
kTfLiteBuiltinStablehloTranspose = 202,
kTfLiteBuiltinDilate = 203,
kTfLiteBuiltinStablehloRngBitGenerator = 204,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
Expand Down
42 changes: 42 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ TfLiteMirrorPaddingMode ConvertMirrorPadding(MirrorPadMode padding) {
return kTfLiteMirrorPaddingUnknown;
}

// Maps the flatbuffer schema `RngAlgorithm` enum onto the TfLite C enum.
// Any schema value without an explicit mapping falls through to
// kTfLiteRngAlgorithmUnknown.
TfLiteRngAlgorithm ConvertRngAlgorithm(RngAlgorithm algorithm) {
  if (algorithm == RngAlgorithm_THREEFRY) {
    return kTfLiteRngAlgorithmThreefry;
  }
  if (algorithm == RngAlgorithm_PHILOX) {
    return kTfLiteRngAlgorithmPhilox;
  }
  if (algorithm == RngAlgorithm_DEFAULT) {
    return kTfLiteRngAlgorithmDefault;
  }
  return kTfLiteRngAlgorithmUnknown;
}

#ifndef TF_LITE_STATIC_MEMORY
TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
Expand Down Expand Up @@ -899,6 +911,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR: {
return ParseStablehloRngBitGenerator(op, error_reporter, allocator,
builtin_data);
}

// TODO: skip param parsing for now since ops below don't have kernels
case BuiltinOperator_STABLEHLO_SLICE:
Expand Down Expand Up @@ -2084,6 +2100,32 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
return kTfLiteOk;
}

// Parses the StablehloRngBitGeneratorOptions attached to `op` into a freshly
// allocated TfLiteStablehloRngBitGeneratorParams and hands ownership to the
// caller through `*builtin_data`. Returns kTfLiteOk on success; fails via
// TF_LITE_ENSURE if allocation fails.
TfLiteStatus ParseStablehloRngBitGenerator(const Operator* op,
                                           ErrorReporter* error_reporter,
                                           BuiltinDataAllocator* allocator,
                                           void** builtin_data) {
  // Validates that none of the four pointer arguments are null.
  CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

  // The unique_ptr + BuiltinDataDeleter pair guarantees the allocation is
  // released if we bail out before the `release()` below.
  SafeBuiltinDataAllocator safe_allocator(allocator);
  std::unique_ptr<TfLiteStablehloRngBitGeneratorParams,
                  SafeBuiltinDataAllocator::BuiltinDataDeleter>
      params = safe_allocator.Allocate<TfLiteStablehloRngBitGeneratorParams>();
  TF_LITE_ENSURE(error_reporter, params != nullptr);

  const StablehloRngBitGeneratorOptions* schema_params =
      op->builtin_options_2_as_StablehloRngBitGeneratorOptions();
  if (schema_params != nullptr) {
    // Translate the schema enum into the TfLite C enum; unmapped values
    // become kTfLiteRngAlgorithmUnknown.
    params->algorithm = ConvertRngAlgorithm(schema_params->algorithm());
  } else {
    // TODO(b/157480169): We should either return kTfLiteError or fill in some
    // reasonable defaults in the params struct. We are not doing so until we
    // better understand the ramifications of changing the legacy behavior.
  }

  // Transfer ownership of the params struct to the caller.
  *builtin_data = params.release();
  return kTfLiteOk;
}

// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ TfLiteStatus ParseRightShift(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

// Parses the StablehloRngBitGeneratorOptions of `op` into a newly allocated
// TfLiteStablehloRngBitGeneratorParams, returned through `*builtin_data`
// (caller takes ownership). Returns kTfLiteOk on success.
TfLiteStatus ParseStablehloRngBitGenerator(const Operator* op,
                                           ErrorReporter* error_reporter,
                                           BuiltinDataAllocator* allocator,
                                           void** builtin_data);

} // namespace tflite

#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
18 changes: 18 additions & 0 deletions tensorflow/lite/core/c/builtin_op_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,24 @@ typedef struct {
int update_computation_subgraph_index;
} TfLiteStablehloScatterParams;

// RNG algorithm selector for the STABLEHLO_RNG_BIT_GENERATOR operator.
// Mirrors the `RngAlgorithm` enum in the TFLite flatbuffer schema, with an
// extra explicit "unknown" sentinel for unmapped schema values.
typedef enum {
  kTfLiteRngAlgorithmUnknown = 0,
  // An algorithm auto-selected by the system according to device type.
  kTfLiteRngAlgorithmDefault,
  // The Philox algorithm, as described in paper
  // ['Parallel Random Numbers: As Easy as 1, 2, 3']
  // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
  kTfLiteRngAlgorithmPhilox,
  // The ThreeFry algorithm, as described in paper
  // ['Parallel Random Numbers: As Easy as 1, 2, 3']
  // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
  kTfLiteRngAlgorithmThreefry,
} TfLiteRngAlgorithm;

// Builtin params for the STABLEHLO_RNG_BIT_GENERATOR operator.
typedef struct {
  TfLiteRngAlgorithm algorithm;
} TfLiteStablehloRngBitGeneratorParams;

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
90 changes: 89 additions & 1 deletion tensorflow/lite/python/schema_py_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,7 @@ class BuiltinOperator(object):
STABLEHLO_GATHER = 201
STABLEHLO_TRANSPOSE = 202
DILATE = 203
STABLEHLO_RNG_BIT_GENERATOR = 204
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: tflite
Expand Down Expand Up @@ -2007,6 +2008,7 @@ class BuiltinOptions2(object):
StablehloGatherOptions = 16
StablehloTransposeOptions = 17
DilateOptions = 18
StablehloRngBitGeneratorOptions = 19

def BuiltinOptions2Creator(unionType, table):
from flatbuffers.table import Table
Expand Down Expand Up @@ -2048,6 +2050,8 @@ def BuiltinOptions2Creator(unionType, table):
return StablehloTransposeOptionsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == BuiltinOptions2().DilateOptions:
return DilateOptionsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == BuiltinOptions2().StablehloRngBitGeneratorOptions:
return StablehloRngBitGeneratorOptionsT.InitFromBuf(table.Bytes, table.Pos)
return None
# automatically generated by the FlatBuffers compiler, do not modify

Expand Down Expand Up @@ -7869,7 +7873,7 @@ def __init__(self):
self.largeCustomOptionsOffset = 0 # type: int
self.largeCustomOptionsSize = 0 # type: int
self.builtinOptions2Type = 0 # type: int
self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT]
self.builtinOptions2 = None # type: Union[None, StablehloConcatenateOptionsT, StablehloBroadcastInDimOptionsT, StablehloSliceOptionsT, StablehloConvolutionOptionsT, StablehloCustomCallOptionsT, StablehloReduceOptionsT, StablehloScatterOptionsT, StablehloCompareOptionsT, StablehloDynamicSliceOptionsT, StablehloPadOptionsT, StablehloIotaOptionsT, StablehloDotGeneralOptionsT, StablehloReduceWindowOptionsT, StablehloSortOptionsT, StablehloWhileOptionsT, StablehloGatherOptionsT, StablehloTransposeOptionsT, DilateOptionsT, StablehloRngBitGeneratorOptionsT]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down Expand Up @@ -9969,6 +9973,14 @@ def Pack(self, builder):

# namespace: tflite

# Python mirror of the `RngAlgorithm` enum from the TFLite flatbuffer schema;
# the integer values must stay in sync with schema.fbs.
class RngAlgorithm(object):
    DEFAULT = 0
    PHILOX = 1
    THREEFRY = 2
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: tflite

from flatbuffers.compat import import_numpy
np = import_numpy()

Expand Down Expand Up @@ -14529,6 +14541,82 @@ def Pack(self, builder):
from flatbuffers.compat import import_numpy
np = import_numpy()

# Read-only flatbuffer accessor for the StablehloRngBitGeneratorOptions table.
class StablehloRngBitGeneratorOptions(object):
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        # Dereference the root uoffset at `offset` and wrap the table it
        # points at.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = StablehloRngBitGeneratorOptions()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsStablehloRngBitGeneratorOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    @classmethod
    def StablehloRngBitGeneratorOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x54\x46\x4C\x33" is the ASCII file identifier "TFL3".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)

    # StablehloRngBitGeneratorOptions
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # StablehloRngBitGeneratorOptions
    def Algorithm(self):
        # Field at vtable slot 4; stored as int8, defaulting to 0
        # (RngAlgorithm.DEFAULT) when absent from the buffer.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0

# Builder helpers for serializing a StablehloRngBitGeneratorOptions table
# (one int8 `algorithm` slot). The short `Start`/`AddAlgorithm`/`End` aliases
# are the newer flatc naming style.
def StablehloRngBitGeneratorOptionsStart(builder): builder.StartObject(1)
def Start(builder):
    return StablehloRngBitGeneratorOptionsStart(builder)
def StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm): builder.PrependInt8Slot(0, algorithm, 0)
def AddAlgorithm(builder, algorithm):
    return StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm)
def StablehloRngBitGeneratorOptionsEnd(builder): return builder.EndObject()
def End(builder):
    return StablehloRngBitGeneratorOptionsEnd(builder)

# Mutable object-API counterpart of StablehloRngBitGeneratorOptions:
# holds the decoded fields as plain attributes and can re-pack them.
class StablehloRngBitGeneratorOptionsT(object):

    # StablehloRngBitGeneratorOptionsT
    def __init__(self):
        self.algorithm = 0  # type: int

    @classmethod
    def InitFromBuf(cls, buf, pos):
        # Wrap the raw buffer in the accessor class, then unpack it.
        stablehloRngBitGeneratorOptions = StablehloRngBitGeneratorOptions()
        stablehloRngBitGeneratorOptions.Init(buf, pos)
        return cls.InitFromObj(stablehloRngBitGeneratorOptions)

    @classmethod
    def InitFromObj(cls, stablehloRngBitGeneratorOptions):
        x = StablehloRngBitGeneratorOptionsT()
        x._UnPack(stablehloRngBitGeneratorOptions)
        return x

    # StablehloRngBitGeneratorOptionsT
    def _UnPack(self, stablehloRngBitGeneratorOptions):
        # Copy each field out of the flatbuffer accessor; a None accessor
        # leaves the defaults in place.
        if stablehloRngBitGeneratorOptions is None:
            return
        self.algorithm = stablehloRngBitGeneratorOptions.Algorithm()

    # StablehloRngBitGeneratorOptionsT
    def Pack(self, builder):
        # Serialize this object back into `builder`, returning the table
        # offset.
        StablehloRngBitGeneratorOptionsStart(builder)
        StablehloRngBitGeneratorOptionsAddAlgorithm(builder, self.algorithm)
        stablehloRngBitGeneratorOptions = StablehloRngBitGeneratorOptionsEnd(builder)
        return stablehloRngBitGeneratorOptions
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: tflite

from flatbuffers.compat import import_numpy
np = import_numpy()

class StablehloScatterOptions(object):
__slots__ = ['_tab']

Expand Down
19 changes: 19 additions & 0 deletions tensorflow/lite/schema/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ enum BuiltinOperator : int32 {
STABLEHLO_GATHER = 201, // WARNING: No runtime support
STABLEHLO_TRANSPOSE = 202, // WARNING: No runtime support
DILATE = 203,
STABLEHLO_RNG_BIT_GENERATOR = 204,
}
// LINT.ThenChange(nnapi_linter/linter.proto)

Expand Down Expand Up @@ -623,6 +624,7 @@ union BuiltinOptions2{
StablehloGatherOptions,
StablehloTransposeOptions,
DilateOptions,
StablehloRngBitGeneratorOptions,
}

table StablehloGatherOptions{
Expand Down Expand Up @@ -767,6 +769,23 @@ table StablehloScatterOptions {
update_computation_subgraph_index: int;
}

enum RngAlgorithm : byte {
  // An algorithm auto-selected by the system according to device type.
  DEFAULT = 0,
  // The Philox algorithm, as described in paper
  // ['Parallel Random Numbers: As Easy as 1, 2, 3']
  // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
  PHILOX = 1,
  // The ThreeFry algorithm, as described in paper
  // ['Parallel Random Numbers: As Easy as 1, 2, 3']
  // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
  THREEFRY = 2,
}

// Options for the STABLEHLO_RNG_BIT_GENERATOR operator.
table StablehloRngBitGeneratorOptions {
  algorithm:RngAlgorithm;
}

// LINT.IfChange
enum Padding : byte { SAME, VALID }
// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
Expand Down
Loading

0 comments on commit 2f2c744

Please sign in to comment.