| | import os |
| | import sys |
| | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| | from utils import suppress_warnings |
| |
|
| | import numpy as np |
| | import tensorflow as tf |
| | import torch |
| | import torch.nn as nn |
| | import logging |
| | from pathlib import Path |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | class ModelLoader: |
| | def __init__(self, weights_dir="models/weights"): |
| | """Initialize model loader with paths to weight files""" |
| | self.weights_dir = Path(weights_dir) |
| | |
| | |
| | self.STGCN_WEIGHTS = self.weights_dir / "best_stgcn.weights.h5" |
| | self.TRANSFORMER_MODEL_PATH = self.weights_dir / "Transformer_12rel_4_bs16_sl32.keras" |
| | self.ANGLE_MODEL_PATH = self.weights_dir / "Transformer_12rel_4_angle3_branch_bs16_sl32.keras" |
| | self.SWIN3D_WEIGHTS = self.weights_dir / "best_swin3d_b_22k.pth" |
| | |
| | |
| | self.SEQ_LEN = 32 |
| | self.NUM_CLASSES = 22 |
| | self.ACTIONS = [ |
| | "barbell biceps curl","lateral raise","push-up","bench press", |
| | "chest fly machine","deadlift","decline bench press","hammer curl", |
| | "hip thrust","incline bench press","lat pulldown","leg extension", |
| | "leg raises","plank","pull Up","romanian deadlift","russian twist", |
| | "shoulder press","squat","t bar row","tricep Pushdown","tricep dips" |
| | ] |
| | |
| | |
| | self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| | |
| | |
| | self.models = {} |
| | self._load_all_models() |
| | |
| | def _build_stgcn_model(self): |
| | """Build ST-GCN model architecture""" |
| | from tensorflow.keras import layers, Model, regularizers |
| | |
| | |
| | class STGCNBlock(layers.Layer): |
| | def __init__(self, in_ch, out_ch, A_norm, stride=1, dropout=0.3, **kwargs): |
| | super().__init__(**kwargs) |
| | self.A_const = tf.constant(A_norm.astype(np.float32)) |
| | self.B = self.add_weight(shape=A_norm.shape, initializer="zeros", trainable=True) |
| | self.conv1x1 = layers.Conv2D(out_ch, 1, use_bias=False, |
| | kernel_regularizer=regularizers.l2(1e-5)) |
| | self.branches = [ |
| | layers.Conv2D(out_ch, (k,1), strides=(stride,1), padding="same", |
| | use_bias=False, kernel_regularizer=regularizers.l2(1e-5)) |
| | for k in (3,5,9) |
| | ] |
| | self.bn = layers.BatchNormalization() |
| | if in_ch == out_ch and stride == 1: |
| | self.res = lambda x, training: x |
| | else: |
| | self.res = tf.keras.Sequential([ |
| | layers.Conv2D(out_ch,1,strides=(stride,1),padding="same", |
| | use_bias=False, kernel_regularizer=regularizers.l2(1e-5)), |
| | layers.BatchNormalization() |
| | ]) |
| | self.act = layers.Activation("relu") |
| | self.drop = layers.SpatialDropout2D(dropout) |
| |
|
| | def call(self, x, training=False): |
| | A_mat = self.A_const + tf.nn.softmax(self.B, axis=1) |
| | x_sp = tf.einsum('ij,btjk->btik', A_mat, x) |
| | x_sp = self.conv1x1(x_sp) |
| | out = sum(branch(x_sp) for branch in self.branches) / len(self.branches) |
| | out = self.bn(out, training=training) |
| | r = self.res(x, training=training) if callable(self.res) else self.res(x) |
| | y = self.act(out + r) |
| | return self.drop(y, training=training) |
| | |
| | |
| | V = 31 |
| | connections = [ |
| | (0,1),(1,2),(2,3),(3,7),(0,4),(4,5),(5,6),(6,8),(9,10), |
| | (11,12),(11,13),(13,15),(15,17),(17,19),(19,21), |
| | (12,14),(14,16),(16,18),(18,20),(20,22), |
| | (11,23),(12,24),(23,24),(23,25),(25,27),(27,29),(27,31), |
| | (24,26),(26,28),(28,30) |
| | ] |
| | |
| | A = np.zeros((V, V), dtype=np.float32) |
| | for u, v in connections: |
| | if u < V and v < V: |
| | A[u, v] = A[v, u] = 1.0 |
| | |
| | |
| | D = A.sum(axis=1) |
| | D_inv = np.diag(1.0 / np.sqrt(D + 1e-6)) |
| | A_norm = D_inv @ A @ D_inv |
| | |
| | |
| | C = 4 |
| | inp = layers.Input((self.SEQ_LEN, V, C)) |
| | x = STGCNBlock(C, 64, A_norm, stride=1)(inp) |
| | x = STGCNBlock(64, 64, A_norm, stride=2)(x) |
| | x = STGCNBlock(64, 128, A_norm, stride=2)(x) |
| | x = STGCNBlock(128,256, A_norm, stride=2)(x) |
| | x = layers.GlobalAveragePooling2D()(x) |
| | out = layers.Dense(self.NUM_CLASSES, activation="softmax", |
| | kernel_regularizer=regularizers.l2(1e-5))(x) |
| | |
| | return Model(inp, out) |
| | |
| | def _build_swin3d_model(self): |
| | """Build Swin3D model""" |
| | from torchvision.models.video import swin3d_b, Swin3D_B_Weights |
| | |
| | model = swin3d_b(weights=Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1) |
| | model.head = nn.Linear(model.head.in_features, self.NUM_CLASSES) |
| | model = model.to(self.device) |
| | |
| | return model |
| | |
| | @tf.keras.utils.register_keras_serializable() |
| | class PositionalEncoding(tf.keras.layers.Layer): |
| | """Positional encoding for Transformer models""" |
| | def __init__(self, maxlen, dm, **kwargs): |
| | super().__init__(**kwargs) |
| | pos = np.arange(maxlen)[:, None] |
| | i = np.arange(dm)[None, :] |
| | angle = pos / np.power(10000, (2*(i//2))/dm) |
| | pe = np.zeros((maxlen, dm), dtype=np.float32) |
| | pe[:,0::2] = np.sin(angle[:,0::2]) |
| | pe[:,1::2] = np.cos(angle[:,1::2]) |
| | self.pe = tf.constant(pe[None,...]) |
| | |
| | def call(self, x): |
| | return x + self.pe[:, :tf.shape(x)[1],:] |
| | |
| | def get_config(self): |
| | cfg = super().get_config() |
| | cfg.update({"maxlen": int(self.pe.shape[1]), "dm": int(self.pe.shape[2])}) |
| | return cfg |
| | |
| | def _load_all_models(self): |
| | """Load all four models""" |
| | logger.info("Loading all models...") |
| | |
| | try: |
| | |
| | logger.info("Loading ST-GCN...") |
| | if self.STGCN_WEIGHTS.exists(): |
| | model_stgcn = self._build_stgcn_model() |
| | try: |
| | model_stgcn.load_weights(str(self.STGCN_WEIGHTS)) |
| | logger.info("ST-GCN weights loaded successfully.") |
| | except Exception as e: |
| | logger.warning(f"ST-GCN weight load error: {e}. Using skip_mismatch.") |
| | model_stgcn.load_weights(str(self.STGCN_WEIGHTS), skip_mismatch=True) |
| | self.models['stgcn'] = model_stgcn |
| | else: |
| | logger.warning(f"ST-GCN weights not found: {self.STGCN_WEIGHTS}") |
| | self.models['stgcn'] = None |
| | |
| | |
| | logger.info("Loading Transformer 12rel...") |
| | if self.TRANSFORMER_MODEL_PATH.exists(): |
| | model_transformer = tf.keras.models.load_model( |
| | str(self.TRANSFORMER_MODEL_PATH), |
| | custom_objects={'PositionalEncoding': self.PositionalEncoding} |
| | ) |
| | self.models['transformer_12rel'] = model_transformer |
| | logger.info("Transformer 12rel loaded successfully.") |
| | else: |
| | logger.warning(f"Transformer 12rel not found: {self.TRANSFORMER_MODEL_PATH}") |
| | self.models['transformer_12rel'] = None |
| | |
| | |
| | logger.info("Loading Transformer angle branch...") |
| | if self.ANGLE_MODEL_PATH.exists(): |
| | model_angle = tf.keras.models.load_model(str(self.ANGLE_MODEL_PATH)) |
| | self.models['transformer_angle'] = model_angle |
| | logger.info("Transformer angle branch loaded successfully.") |
| | else: |
| | logger.warning(f"Transformer angle not found: {self.ANGLE_MODEL_PATH}") |
| | self.models['transformer_angle'] = None |
| | |
| | |
| | logger.info("Loading Swin3D...") |
| | if self.SWIN3D_WEIGHTS.exists(): |
| | model_swin3d = self._build_swin3d_model() |
| | state = torch.load(str(self.SWIN3D_WEIGHTS), map_location=self.device) |
| | model_swin3d.load_state_dict(state) |
| | model_swin3d.eval() |
| | self.models['swin3d'] = model_swin3d |
| | logger.info("Swin3D loaded successfully.") |
| | else: |
| | logger.warning(f"Swin3D weights not found: {self.SWIN3D_WEIGHTS}") |
| | self.models['swin3d'] = None |
| | |
| | |
| | loaded_models = [name for name, model in self.models.items() if model is not None] |
| | if not loaded_models: |
| | logger.warning("No models could be loaded. App will run in demo mode with mock predictions.") |
| | |
| | else: |
| | logger.info(f"Successfully loaded models: {loaded_models}") |
| | |
| | except Exception as e: |
| | logger.error(f"Error loading models: {str(e)}") |
| | raise |
| | |
| | def get_model(self, model_name): |
| | """Get a specific model by name""" |
| | return self.models.get(model_name) |
| | |
| | def get_available_models(self): |
| | """Get list of available (loaded) models""" |
| | return [name for name, model in self.models.items() if model is not None] |
| | |
| | def predict_stgcn(self, X): |
| | """Predict using ST-GCN model""" |
| | if self.models['stgcn'] is None: |
| | return None |
| | return self.models['stgcn'].predict(X, batch_size=32, verbose=0) |
| | |
| | def predict_transformer_12rel(self, X): |
| | """Predict using Transformer 12rel model""" |
| | if self.models['transformer_12rel'] is None: |
| | return None |
| | return self.models['transformer_12rel'].predict(X, batch_size=32, verbose=0) |
| | |
| | def predict_transformer_angle(self, X_rel, X_ang): |
| | """Predict using Transformer angle branch model""" |
| | if self.models['transformer_angle'] is None: |
| | return None |
| | return self.models['transformer_angle'].predict([X_rel, X_ang], batch_size=32, verbose=0) |
| | |
| | def predict_swin3d(self, X): |
| | """Predict using Swin3D model""" |
| | if self.models['swin3d'] is None: |
| | return None |
| | |
| | with torch.no_grad(): |
| | probas = [] |
| | for x in X: |
| | x_batch = x.unsqueeze(0).to(self.device) |
| | logits = self.models['swin3d'](x_batch) |
| | proba = torch.softmax(logits, dim=1).cpu().numpy() |
| | probas.append(proba) |
| | |
| | return np.vstack(probas) |
| | |
| | def cleanup(self): |
| | """Clean up model resources""" |
| | logger.info("Cleaning up model resources...") |
| | for name, model in self.models.items(): |
| | if model is not None: |
| | if name == 'swin3d': |
| | |
| | del model |
| | torch.cuda.empty_cache() |
| | else: |
| | |
| | del model |
| | |
| | |
| | tf.keras.backend.clear_session() |