Nah_kagz1092 commited on
Commit
e2692aa
·
verified ·
1 Parent(s): 10f2732

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -6
model.py CHANGED
@@ -3,14 +3,10 @@ import tensorflow as tf
3
  class TransformerModel(tf.keras.Model):
4
  def __init__(self, config):
5
  super(TransformerModel, self).__init__()
6
- self.encoder = tf.keras.layers.TransformerEncoder(config["encoder_layers"])
7
- self.decoder = tf.keras.layers.TransformerDecoder(config["decoder_layers"])
8
 
9
  def call(self, inputs, targets):
10
  encoder_output = self.encoder(inputs)
11
  decoder_output = self.decoder(targets, encoder_output)
12
  return decoder_output
13
-
14
- def predict(self, input_data):
15
- # Code để dự đoán đầu ra dựa trên đầu vào
16
- pass
 
3
  class TransformerModel(tf.keras.Model):
4
  def __init__(self, config):
5
  super(TransformerModel, self).__init__()
6
+ self.encoder = tf.keras.layers.Transformer(**config["encoder_params"])
7
+ self.decoder = tf.keras.layers.Transformer(**config["decoder_params"])
8
 
9
  def call(self, inputs, targets):
10
  encoder_output = self.encoder(inputs)
11
  decoder_output = self.decoder(targets, encoder_output)
12
  return decoder_output