Skip to content

Commit

Permalink
Fix wrong operator "is" in string comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
moehanabi committed Nov 2, 2024
1 parent 74ad9c3 commit db55f72
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tensorlayerx/backend/ops/paddle_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,10 @@ class Conv2D(object):

def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
if self.data_format is 'NHWC':
if self.data_format == 'NHWC':
self._stride = (strides[1], strides[2])
self._dilation = (dilations[1], dilations[2])
elif self.data_format is 'NCHW':
elif self.data_format == 'NCHW':
self._stride = (strides[2], strides[3])
self._dilation = (dilations[2], dilations[3])

Expand Down Expand Up @@ -537,10 +537,10 @@ def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None)
A Tensor. Has the same type as input.
"""
data_format, padding = preprocess_2d_format(data_format, padding)
if data_format is 'NHWC':
if data_format == 'NHWC':
_stride = (strides[1], strides[2])
_dilation = (dilations[1], dilations[2])
elif data_format is 'NCHW':
elif data_format == 'NCHW':
_stride = (strides[2], strides[3])
_dilation = (dilations[2], dilations[3])
outputs = F.conv2d(
Expand All @@ -553,10 +553,10 @@ class Conv3D(object):

def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
if self.data_format is 'NDHWC':
if self.data_format == 'NDHWC':
self._strides = (strides[1], strides[2], strides[3])
self._dilations = (dilations[1], dilations[2], dilations[3])
elif self.data_format is 'NCDHW':
elif self.data_format == 'NCDHW':
self._strides = (strides[2], strides[3], strides[4])
self._dilations = (dilations[2], dilations[3], dilations[4])

Expand Down Expand Up @@ -603,10 +603,10 @@ def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None
A Tensor. Has the same type as input.
"""
data_format, padding = preprocess_3d_format(data_format, padding)
if data_format is 'NDHWC':
if data_format == 'NDHWC':
_strides = (strides[1], strides[2], strides[3])
_dilations = (dilations[1], dilations[2], dilations[3])
elif data_format is 'NCDHW':
elif data_format == 'NCDHW':
_strides = (strides[2], strides[3], strides[4])
_dilations = (dilations[2], dilations[3], dilations[4])
outputs = F.conv3d(
Expand Down Expand Up @@ -1195,10 +1195,10 @@ def __init__(self, strides, padding, data_format, dilations, out_channel, k_size
self.k_size = k_size
self.groups = groups
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
if self.data_format is 'NHWC':
if self.data_format == 'NHWC':
self.strides = (strides[1], strides[2])
self.dilations = (dilations[1], dilations[2])
elif self.data_format is 'NCHW':
elif self.data_format == 'NCHW':
self.strides = (strides[2], strides[3])
self.dilations = (dilations[2], dilations[3])

Expand Down Expand Up @@ -1241,10 +1241,10 @@ def __init__(self, strides, padding, data_format, dilations, out_channel, k_size
self.in_channel = int(in_channel)
self.depth_multiplier = depth_multiplier
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
if self.data_format is 'NHWC':
if self.data_format == 'NHWC':
self.strides = (strides[1], strides[2])
self.dilations = (dilations[1], dilations[2])
elif self.data_format is 'NCHW':
elif self.data_format == 'NCHW':
self.strides = (strides[2], strides[3])
self.dilations = (dilations[2], dilations[3])

Expand Down

0 comments on commit db55f72

Please sign in to comment.