Animal Image Classification (TensorFlow & CNN)
"A complete end-to-end pipeline for loading, cleaning, and preprocessing image data, then training, evaluating, and exporting a deep CNN model for multi-class animal image classification."
The project is organized as a step-by-step walk through the full machine-learning workflow, from data validation to model evaluation and ROC curves.
Project Structure
| Component | Description |
|---|---|
| Data Loading | Reads and extracts the ZIP dataset from Google Drive |
| EDA | Class distribution, file integrity, image sizes, brightness, contrast, sample display |
| Preprocessing | Resizing, normalization, augmentation, hashing, cleaning corrupted files |
| Model | Deep custom CNN with BatchNorm, Dropout & L2 Regularization |
| Training | Adam optimizer, learning-rate reduction on plateau, early stopping |
| Evaluation | Confusion matrix, classification report, ROC curves |
| Export | Saves final .h5 model |
How to Run
1. Upload your dataset to Google Drive as a ZIP archive
Your dataset must be structured as:
```
Animals/
├── Cats/
├── Dogs/
└── Snakes/
```
2. Update the ZIP path
```python
zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'
```
3. Run the Notebook
When run, the notebook will:
- Mount Google Drive
- Extract images
- Build a DataFrame of paths
- Run EDA checks
- Clean and prepare images
- Train the CNN model
- Export results
Data Preparation & EDA
The project validates the dataset with the following checks:
Class Distribution
```python
class_count = df['class'].value_counts()
class_count.plot(kind='bar')
```
Image Size & Channel Properties
```python
image_df['Channels'].value_counts().plot(kind='bar')
```
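The properties plotted above first have to be collected from each file. A minimal sketch, assuming the df of file paths built during loading (with a full_path column); the collect_image_properties helper and the Width, Height and Mode column names are hypothetical, chosen to sit alongside the Channels column used above. PIL's Image.open only parses the file header, so this stays cheap on large datasets.

```python
import pandas as pd
from PIL import Image


# Hypothetical helper: collect per-image size and color properties
def collect_image_properties(paths):
    """Return a DataFrame with Width, Height, Mode and Channels for each path."""
    rows = []
    for path in paths:
        try:
            with Image.open(path) as img:  # lazy open: header is parsed, pixels are not decoded
                width, height = img.size
                mode = img.mode
                channels = len(img.getbands())  # 3 for RGB, 1 for L, 4 for RGBA
        except OSError:  # missing or unreadable file
            width = height = mode = channels = None
        rows.append({'full_path': path, 'Width': width, 'Height': height,
                     'Mode': mode, 'Channels': channels})
    return pd.DataFrame(rows)


# image_df = collect_image_properties(df['full_path'])
# image_df['Channels'].value_counts().plot(kind='bar')
```

Unreadable files get empty values instead of stopping the scan, so they can be reviewed alongside the corrupted-file check.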
Duplicate Image Detection
Using MD5 hashing:
```python
import hashlib

def get_hash(file_path):
    # MD5 of the raw bytes: byte-identical files share a digest
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()
```
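Hashing can also surface likely labeling mistakes: the same file stored under two different class folders. A sketch, assuming a df with class and full_path columns; the find_conflicting_duplicates helper is hypothetical, and get_hash is repeated so the block runs on its own. Duplicates that stay within a single class are not reported here; the ordinary drop_duplicates step handles them.

```python
import hashlib

import pandas as pd


def get_hash(file_path):
    # MD5 of the raw bytes: byte-identical files share a digest
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()


# Hypothetical helper: rows whose exact content also appears under another class
def find_conflicting_duplicates(df):
    df = df.assign(file_hash=df['full_path'].apply(get_hash))
    # Number of distinct class labels attached to each hash
    labels_per_hash = df.groupby('file_hash')['class'].transform('nunique')
    return df[labels_per_hash > 1].sort_values('file_hash')
```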
Brightness & Contrast Issues
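```python
from PIL import ImageStat

# img is a PIL.Image opened from one dataset file
stat = ImageStat.Stat(img.convert("L"))  # grayscale statistics
brightness = stat.mean[0]    # mean intensity, 0-255
contrast = stat.stddev[0]    # intensity spread (standard deviation)
```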
Auto‑fixing poor images
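Brightness and contrast are boosted by a factor of 1.5 with PIL's ImageEnhance:
```python
from PIL import ImageEnhance

img = ImageEnhance.Brightness(img).enhance(1.5)  # factor > 1 brightens
img = ImageEnhance.Contrast(img).enhance(1.5)    # factor > 1 increases contrast
```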
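Measuring and enhancing combine naturally: compute the statistics for each image, then enhance only the ones that fall below a threshold. A minimal sketch; the thresholds (mean below 50, standard deviation below 30 on the 0-255 grayscale scale) and the helper names are assumptions for illustration, not values taken from the project, so check them against the dataset's brightness and contrast distributions first.

```python
from PIL import Image, ImageEnhance, ImageStat

# Hypothetical thresholds on the 0-255 grayscale scale; tune them per dataset
MIN_BRIGHTNESS = 50
MIN_CONTRAST = 30


def brightness_contrast(img):
    """Mean (brightness) and standard deviation (contrast) of the grayscale image."""
    stat = ImageStat.Stat(img.convert("L"))
    return stat.mean[0], stat.stddev[0]


# Hypothetical helper: enhance only the images that need it
def fix_if_needed(img, factor=1.5):
    """Boost brightness and/or contrast only when they fall below the thresholds."""
    brightness, contrast = brightness_contrast(img)
    if brightness < MIN_BRIGHTNESS:
        img = ImageEnhance.Brightness(img).enhance(factor)
    if contrast < MIN_CONTRAST:
        img = ImageEnhance.Contrast(img).enhance(factor)
    return img  # unchanged object when both checks pass
```

Applying the enhancement selectively avoids washing out images that were already well exposed.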
Image Preprocessing
All images are resized to 256×256 and normalized to [0,1].
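```python
import tensorflow as tf

def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    # expand_animations=False always returns a 3-D tensor (first frame for GIFs),
    # which tf.image.resize needs inside tf.data pipelines
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32) / 255.0  # scale to [0, 1]
```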
Data Augmentation
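```python
import tensorflow as tf

# Random transforms; they are only active when called with training=True
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),   # up to ±10% of a full turn (±36°)
    tf.keras.layers.RandomZoom(0.1),       # zoom in or out by up to 10%
])
```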
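The preprocessing function and the augmentation model are usually wired together in a tf.data pipeline. A minimal sketch, assuming parallel lists of file paths and integer labels (for example from a label encoder); the make_dataset helper, the batch size of 32 and the full-size shuffle buffer are hypothetical choices. Augmentation runs only when training=True, so validation and test batches are left untouched; both functions are repeated so the block runs on its own.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])


def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32) / 255.0


# Hypothetical helper: batched, prefetched (image, label) dataset
def make_dataset(paths, labels, batch_size=32, training=False):
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        ds = ds.shuffle(buffer_size=len(paths), reshuffle_each_iteration=True)
    # Decode and resize images in parallel
    ds = ds.map(lambda p, y: (preprocess_image(p), y), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size)
    if training:
        # Augment whole batches; random layers only transform with training=True
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                    num_parallel_calls=AUTOTUNE)
    return ds.prefetch(AUTOTUNE)  # overlap input preparation with training


# train_ds = make_dataset(train_paths, train_labels, training=True)
# val_ds = make_dataset(val_paths, val_labels)
```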
CNN Model Architecture
Below is a simplified view of the model:
Conv2D (32) → BatchNorm → Conv2D (32) → BatchNorm → MaxPool → Dropout
Conv2D (64) → BatchNorm → Conv2D (64) → BatchNorm → MaxPool → Dropout
Conv2D (128) → BatchNorm → Conv2D (128) → BatchNorm → MaxPool → Dropout
Conv2D (256) → BatchNorm → Conv2D (256) → BatchNorm → MaxPool → Dropout
Flatten → Dense (softmax)
Example code:
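```python
from tensorflow.keras.layers import BatchNormalization, Conv2D, MaxPooling2D

# model is the tf.keras Sequential being built
model.add(Conv2D(32, (3,3), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2,2)))
```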
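Putting the four blocks together gives one possible full model. This sketch follows the block layout shown above with the dropout rates listed in the architecture table of section 8 (0.2 to 0.5); the L2 factor of 1e-4, the 256×256×3 input and the build_model name are assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers


# Hypothetical builder: four Conv-BN-Conv-BN-MaxPool-Dropout blocks + softmax head
def build_model(num_classes, input_shape=(256, 256, 3), l2=1e-4):
    model = tf.keras.Sequential([tf.keras.Input(shape=input_shape)])
    for filters, dropout in [(32, 0.2), (64, 0.3), (128, 0.4), (256, 0.5)]:
        for _ in range(2):  # two conv + batch-norm pairs per block
            model.add(layers.Conv2D(filters, (3, 3), activation='relu', padding='same',
                                    kernel_regularizer=regularizers.l2(l2)))
            model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D((2, 2)))  # halves height and width
        model.add(layers.Dropout(dropout))      # heavier dropout in deeper blocks
    model.add(layers.Flatten())
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model
```

With a 256×256 input, the four poolings leave a 16×16×256 feature map, so the Flatten layer feeds 65,536 values into the softmax layer.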
Training
```python
from tensorflow.keras.optimizers import Adam

epochs = 50
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',  # expects integer labels
              metrics=['accuracy'])
```
Callbacks
| Callback | Purpose |
|---|---|
| ReduceLROnPlateau | Auto‑reduce LR when val_loss stagnates |
| EarlyStopping | Stop training when the monitored validation metric stops improving |
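A sketch of how the two callbacks and the training call fit together. The monitored metric (val_loss) follows the table above; the factor, patience and min_lr values, the train helper, and the train_ds/val_ds names are assumptions, not the project's exact settings.

```python
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

callbacks = [
    # Assumed values: halve the LR after 3 epochs without val_loss improvement
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6, verbose=1),
    # Assumed values: stop after 8 stagnant epochs and keep the best weights seen
    EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
]


# Hypothetical helper: compile with the settings above and fit with both callbacks
def train(model, train_ds, val_ds, epochs=50):
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model.fit(train_ds, validation_data=val_ds, epochs=epochs,
                     callbacks=callbacks)


# history = train(model, train_ds, val_ds)
```

Setting EarlyStopping's patience higher than ReduceLROnPlateau's gives the reduced learning rate a few epochs to help before training stops.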
Model Evaluation
Accuracy
```python
test_loss, test_accuracy = model.evaluate(test_ds)
```
Classification Report
```python
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred, target_names=le.classes_))
```
Confusion Matrix
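```python
cm = confusion_matrix(y_true, y_pred)  # from sklearn.metrics
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  # fmt='d' prints integer counts
```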
ROC Curve (One-vs-Rest)
```python
# runs once per class index i; y_true_bin holds the one-hot (binarized) true labels
fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
```
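The roc_curve call above is evaluated once per class. A self-contained sketch of the full one-vs-rest loop plus a macro-average curve, assuming y_true holds integer labels and y_probs the softmax outputs (one column per class, in label-encoder order); the one_vs_rest_roc name and the 1000-point grid are hypothetical. The macro average interpolates each class's TPR onto a shared FPR grid and averages them, so every class counts equally regardless of its size.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve


# Hypothetical helper: per-class ROC curves/AUCs plus a macro-averaged curve
def one_vs_rest_roc(y_true, y_probs, grid_size=1000):
    y_true = np.asarray(y_true)
    n_classes = y_probs.shape[1]
    # One binary column per class (1 where the sample belongs to that class)
    y_true_bin = np.eye(n_classes, dtype=int)[y_true]
    curves = {}
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
        curves[i] = (fpr, tpr, auc(fpr, tpr))
    # Macro average: interpolate each class's TPR onto a shared FPR grid, then average
    fpr_grid = np.linspace(0.0, 1.0, grid_size)
    mean_tpr = np.mean([np.interp(fpr_grid, curves[i][0], curves[i][1])
                        for i in range(n_classes)], axis=0)
    curves['macro'] = (fpr_grid, mean_tpr, auc(fpr_grid, mean_tpr))
    return curves
```

Building the one-hot matrix with np.eye keeps one column per class even for two classes, where label_binarize would return a single column.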
Saving the Model
```python
model.save("Animal_Classification.h5")
```
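After saving, the model can be reloaded to classify a new image. A sketch, assuming the same 256×256, [0, 1] preprocessing used during training and a class_names list in label-encoder order (for example list(le.classes_)); the predict_image helper and the example file name are hypothetical.

```python
import numpy as np
import tensorflow as tf


# Hypothetical helper: (predicted class name, confidence) for one image file
def predict_image(model, path, class_names, target_size=(256, 256)):
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    img = tf.cast(img, tf.float32) / 255.0                      # same scaling as training
    probs = model.predict(img[tf.newaxis, ...], verbose=0)[0]  # add, then drop, batch axis
    idx = int(np.argmax(probs))
    return class_names[idx], float(probs[idx])


# model = tf.keras.models.load_model("Animal_Classification.h5")
# label, confidence = predict_image(model, "some_image.jpg", list(le.classes_))
```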
Full Code Organization (High-Level)
| Step | Description |
|---|---|
| 1 | Import libraries, mount drive |
| 2 | Extract ZIP |
| 3 | Build DataFrame |
| 4 | EDA & cleaning |
| 5 | Preprocessing & augmentation |
| 6 | Dataset pipeline (train/val/test) |
| 7 | CNN architecture |
| 8 | Training |
| 9 | Evaluation |
| 10 | Save model |
Animal Image Classification – Complete Pipeline (README)
"A clean dataset is half the model’s accuracy. The rest is just engineering."
This project presents a complete end-to-end deep learning pipeline for multi-class animal image classification using TensorFlow/Keras. It covers every stage from data extraction, cleaning, and analysis to model training, evaluation, and export.
Table of Contents
| Section | Description |
|---|---|
| 1. Project Overview | What this project does & architecture overview |
| 2. Features | Key capabilities of this pipeline |
| 3. Recommended Directory Structure | Recommended project layout |
| 4. Installation | Python packages to install |
| 5. Dataset Extraction & Loading | Unzipping the archive and indexing images in a DataFrame |
| 6. Exploratory Data Analysis | Visualizations & summary statistics |
| 7. Preprocessing & Augmentation | Data preparation logic |
| 8. CNN Model Architecture | Layers, blocks, hyperparameters |
| 9. Training | Optimizer, loss, and callbacks |
| 10. Evaluation Metrics | Reports, ROC curve, confusion matrix |
| 11. Model Export | Saving and downloading the model |
| 12. Example Snippets | Corrupted-file and duplicate checks |
1. Project Overview
This project builds a Convolutional Neural Network (CNN) to classify images of animals into multiple categories. The process includes:
- Dataset extraction from Google Drive
- Data validation (duplicates, corrupt files, mislabeled images)
- Image enhancement & cleaning
- Class distribution analysis
- Image size analysis and outlier detection
- Data augmentation
- CNN model training with regularization
- Performance evaluation using multiple metrics
- Model export to .h5 format
The pipeline is designed to be robust, explainable, and production-friendly.
2. Features
| Feature | Description |
|---|---|
| Automatic Dataset Extraction | Unzips and loads images from Google Drive |
| Image Validation | Detects duplicates, corrupted images, and mislabeled files |
| Data Cleaning | Brightness/contrast enhancements for dark or overexposed samples |
| EDA Visualizations | Class distribution, size, color modes, outliers |
| TensorFlow Dataset Pipeline | tf.data input pipeline with efficient batching & prefetching |
| Deep CNN Model | 32 → 64 → 128 → 256 filters with batch norm and dropout |
| Model Evaluation Dashboard | Confusion matrix, ROC curves, precision/recall/F1 |
| Model Export | Saves final model as Animal_Classification.h5 |
3. Recommended Directory Structure
```
Animal-Classification
┣ data
┃ ┗ Animals (extracted folders)
┣ notebooks
┣ src
┃ ┣ preprocessing.py
┃ ┣ model.py
┃ ┗ utils.py
┣ README.md
┗ requirements.txt
```
4. Installation
```bash
pip install tensorflow pandas matplotlib seaborn scikit-learn pillow tqdm
```
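If you run the project on Google Colab, it also uses google.colab.drive (to mount Google Drive) and google.colab.files (to download the trained model). Both ship with Colab and need no installation.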
5. Dataset Extraction & Loading
Example snippet:
```python
import zipfile

zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
```
Images are collected into a DataFrame:
```python
from pathlib import Path
import pandas as pd
paths = [(path.parts[-2], path.name, str(path)) for path in Path(extract_to).rglob('*.*') if path.is_file()]
df = pd.DataFrame(paths, columns=['class','image','full_path'])
```
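Before batching, the class names need integer labels (the evaluation code reads le.classes_) and the DataFrame needs a train/validation/test split. A sketch using scikit-learn's LabelEncoder and a stratified two-stage split so each class keeps its share in every subset; the 70/15/15 ratio, the seed and the split_dataframe name are assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


# Hypothetical helper: integer labels + stratified train/val/test split
def split_dataframe(df, val_size=0.15, test_size=0.15, seed=42):
    le = LabelEncoder()  # classes_ ends up sorted alphabetically
    df = df.assign(label=le.fit_transform(df['class']))
    # First split off the combined validation + test share
    train_df, temp_df = train_test_split(
        df, test_size=val_size + test_size, stratify=df['label'], random_state=seed)
    # Then divide the held-out part between validation and test
    val_df, test_df = train_test_split(
        temp_df, test_size=test_size / (val_size + test_size),
        stratify=temp_df['label'], random_state=seed)
    return train_df, val_df, test_df, le
```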
6. Exploratory Data Analysis
Examples of generated visualizations:
- Barplot of class distribution
- Pie chart of percentage per class
- Scatter plots of image width and height
- Image mode (RGB/Gray) distribution
Example:
```python
import matplotlib.pyplot as plt
class_count = df['class'].value_counts()
plt.figure(figsize=(32,16))
class_count.plot(kind='bar')
```
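One simple way to flag size outliers is the 1.5 × IQR rule on width and height (an assumption here, not necessarily the project's method). A sketch assuming an image_df with numeric Width and Height columns (hypothetical names); flagged rows are marked for inspection rather than dropped automatically.

```python
import pandas as pd


# Hypothetical helper: IQR-based size outlier flag (assumed method)
def flag_size_outliers(image_df, cols=('Width', 'Height'), k=1.5):
    """Mark rows whose width or height lies outside [Q1 - k*IQR, Q3 + k*IQR]."""
    is_outlier = pd.Series(False, index=image_df.index)
    for col in cols:
        q1, q3 = image_df[col].quantile([0.25, 0.75])
        iqr = q3 - q1
        is_outlier |= (image_df[col] < q1 - k * iqr) | (image_df[col] > q3 + k * iqr)
    return image_df.assign(size_outlier=is_outlier)
```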
7. Preprocessing & Augmentation
Preprocessing function
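```python
import tensorflow as tf

def preprocess_image(path, target_size=(256,256)):
    img = tf.io.read_file(path)
    # expand_animations=False always returns a 3-D tensor (first frame for GIFs),
    # which tf.image.resize needs inside tf.data pipelines
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32)/255.0  # scale to [0, 1]
```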
Augmentation
```python
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```
8. CNN Model Architecture
| Block | Layers |
|---|---|
| Block 1 | Conv2D(32) → BN → Conv2D(32) → BN → MaxPool → Dropout(0.2) |
| Block 2 | Conv2D(64) → BN → Conv2D(64) → BN → MaxPool → Dropout(0.3) |
| Block 3 | Conv2D(128) → BN → Conv2D(128) → BN → MaxPool → Dropout(0.4) |
| Block 4 | Conv2D(256) → BN → Conv2D(256) → BN → MaxPool → Dropout(0.5) |
| Output | Flatten → Dense(num_classes, softmax) |
Example snippet:
```python
from tensorflow.keras.layers import BatchNormalization, Conv2D

model.add(Conv2D(64,(3,3),activation='relu',padding='same'))
model.add(BatchNormalization())
```
9. Training
```python
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```
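Training uses two callbacks (their arguments are elided here):
```python
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

ReduceLROnPlateau(...)
EarlyStopping(...)
```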
10. Evaluation Metrics
This project computes:
- Precision, Recall, F1 (macro & per class)
- Confusion matrix (heatmap)
- ROC curves (one-vs-rest)
- Macro-average ROC
Example:
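```python
import seaborn as sns
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')  # fmt='d' prints integer counts
```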
11. Model Export
```python
from google.colab import files  # Colab only
model.save("Animal_Classification.h5")
files.download("Animal_Classification.h5")  # downloads the file to your machine
```
12. Example Snippets
Checking corrupted files
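```python
from PIL import Image

corrupted = []
for path in df['full_path']:
    try:
        with Image.open(path) as img:
            img.verify()  # integrity check without decoding the full image
    except Exception:  # e.g. UnidentifiedImageError, OSError, SyntaxError
        corrupted.append(path)
```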
Filtering duplicate images
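```python
# get_hash: the MD5 helper defined in the duplicate-detection section
df['file_hash'] = df['full_path'].apply(get_hash)
df_unique = df.drop_duplicates(subset='file_hash')  # keeps the first copy of each file
```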