Description
```python
import os
os.environ['KERAS_BACKEND'] = 'torch'
os.environ['OPS_KERNAL'] = '1'
import keras
keras.config.set_floatx('bfloat16')  # default float dtype (weights and computation) becomes bfloat16
from keras import ops
import numpy as np

initial_dim = 2048
finally_dim = 64
z = ops.convert_to_tensor(np.random.random([1, 36, initial_dim]))
dense = keras.layers.Dense(finally_dim)
z1 = dense(z)         # all 36 positions
z2 = dense(z[:, :8])  # only the first 8 positions
# The first 8 rows of z1 should match z2 exactly
print(ops.isclose(z1[:, :8], z2).all())
```
The example code is above. In some configurations, z1[:, :8] and z2 fail the isclose check, even though in theory (and under fp32) they should pass it in every case. What is causing this, and how can it be fixed?
The same behavior also occurs with the TensorFlow and JAX backends, but not with the NumPy backend.
Passing cases:
- initial_dim = 2048, finally_dim = 2048
- initial_dim = 2048, finally_dim = 4096
- initial_dim = 1024, finally_dim = 2048

Failing cases:
- initial_dim = 2048, finally_dim = 64
- initial_dim = 2048, finally_dim = 1024
- initial_dim = 1024, finally_dim = 2047
However, we did not observe the same issue in pure PyTorch with the following code:
```python
import torch
import numpy as np

initial_dim = 4096
finally_dim = 32
z = torch.tensor(np.random.random([1, 36, initial_dim]), dtype=torch.bfloat16)
linear = torch.nn.Linear(initial_dim, finally_dim).bfloat16()
z1 = linear(z)         # all 36 positions
z2 = linear(z[:, :8])  # only the first 8 positions
print(torch.isclose(z1[:, :8], z2).all())
```
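If the mismatches are only rounding noise, one way to check bfloat16 outputs is to compare them with tolerances scaled to bfloat16 precision rather than the isclose defaults. bfloat16 keeps 8 significand bits, so the spacing of values just above 1.0 is 2**-7 = 0.0078125; PyTorch's `torch.testing.assert_close`, for instance, defaults to rtol=1.6e-2 and atol=1e-5 for bfloat16 tensors. The sketch below borrows those values in NumPy; `BF16_EPS`, `BF16_RTOL`, `BF16_ATOL` and `bf16_close` are hypothetical names chosen for illustration.

```python
import numpy as np

# Tolerances modelled on torch.testing.assert_close's bfloat16 defaults (an assumption; tune as needed).
BF16_EPS = 2.0**-7   # spacing of bfloat16 values just above 1.0
BF16_RTOL = 1.6e-2
BF16_ATOL = 1e-5


def bf16_close(a, b, rtol=BF16_RTOL, atol=BF16_ATOL):
    """Elementwise closeness test with tolerances suited to bfloat16 results."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    return np.isclose(a, b, rtol=rtol, atol=atol)


# Each pair is one bfloat16 rounding step apart (or identical).
a = np.array([1.0, 0.5, -2.0], dtype=np.float32)
b = a + np.array([BF16_EPS, 0.0, -2 * BF16_EPS], dtype=np.float32)

print(np.isclose(a, b).all())  # False with the default rtol=1e-5, atol=1e-8
print(bf16_close(a, b).all())  # True with bfloat16-scaled tolerances
```

The same tolerances can be passed to `torch.isclose` in the PyTorch snippet, e.g. `torch.isclose(z1[:, :8], z2, rtol=1.6e-2, atol=1e-5).all()`.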