diff --git a/cellx/networks/classifier.py b/cellx/networks/classifier.py index da745d1..7079051 100644 --- a/cellx/networks/classifier.py +++ b/cellx/networks/classifier.py @@ -1 +1,13 @@ -# this will be the classifier +# Building the CNN classifier: + +import tensorflow as tf +from tensorflow import keras as K + +from ..layers import Encoder2D + +input = K.Input(shape=(2,)) +encoder = layers.Encoder2D()(input) +dense = K.layers.Dense(512, activation="relu")(encoder) +output = K.layers.Dense(4, activation=tf.nn.softmax)(dense) + +classifier_model = K.Model(inputs=input, outputs=output)