Nah_kagz1092
commited on
Update model.py
Browse files
model.py
CHANGED
|
@@ -3,14 +3,10 @@ import tensorflow as tf
|
|
| 3 |
class TransformerModel(tf.keras.Model):
|
| 4 |
def __init__(self, config):
|
| 5 |
super(TransformerModel, self).__init__()
|
| 6 |
-
self.encoder = tf.keras.layers.
|
| 7 |
-
self.decoder = tf.keras.layers.
|
| 8 |
|
| 9 |
def call(self, inputs, targets):
|
| 10 |
encoder_output = self.encoder(inputs)
|
| 11 |
decoder_output = self.decoder(targets, encoder_output)
|
| 12 |
return decoder_output
|
| 13 |
-
|
| 14 |
-
def predict(self, input_data):
|
| 15 |
-
# Code để dự đoán đầu ra dựa trên đầu vào
|
| 16 |
-
pass
|
|
|
|
| 3 |
class TransformerModel(tf.keras.Model):
|
| 4 |
def __init__(self, config):
|
| 5 |
super(TransformerModel, self).__init__()
|
| 6 |
+
self.encoder = tf.keras.layers.Transformer(**config["encoder_params"])
|
| 7 |
+
self.decoder = tf.keras.layers.Transformer(**config["decoder_params"])
|
| 8 |
|
| 9 |
def call(self, inputs, targets):
|
| 10 |
encoder_output = self.encoder(inputs)
|
| 11 |
decoder_output = self.decoder(targets, encoder_output)
|
| 12 |
return decoder_output
|
|
|
|
|
|
|
|
|
|
|
|