Autism Prediction Using AI (Part-4)
End-To-End Machine Learning Project Blog Part-4
Powering Up Predictions
Welcome to Part 4 of Our Autism Prediction Project!
After an incredible journey through Parts 1 to 3, balancing our `Class/ASD` dataset, engineering features like `sum_score`, uncovering a 0.97 correlation with our target, and training models with Random Forest leading at 92.97% accuracy, we’re now stepping into the final stretch to create a game-changing tool for autism spectrum disorder (ASD) prediction.
Today, we’ll select our best model, fine-tune it with hyperparameter optimization, interpret its decisions with SHAP analysis, evaluate it with advanced metrics, find the optimal threshold, and deploy it for real-world impact.
Whether you’re joining me from the Motor City, Detroit, Michigan, or bringing your passion for AI from across the globe, let’s ignite our creativity and compassion to make a lasting difference.
Cheers to an epic Part 4!
Optimizing Excellence: Hyperparameter Tuning for Random Forest
After identifying Random Forest as our top performer with an impressive 92.97% accuracy in Part 3, we’re now taking it to the next level with hyperparameter tuning using `GridSearchCV`. This code block searches for the best combination of parameters to boost our model’s performance, ensuring we’re maximizing its potential to predict autism spectrum disorder (ASD) accurately.
Let’s ignite our journey to fine-tune this powerhouse—cheers to precision and impact!
Why Hyperparameter Tuning Matters
Tuning Random Forest’s parameters like `n_estimators` and `max_depth` optimizes its ability to generalize, reducing errors in ASD prediction. For clinicians in Lahore, this means a more reliable tool for early diagnosis, ensuring no child is missed.
What to Expect in This Step
In this step, we’ll:
- Define a parameter grid to test various Random Forest configurations.
- Use `GridSearchCV` to find the best parameters with 5-fold cross-validation.
- Extract the best model and display its optimal parameters.
Get ready to supercharge our Random Forest—our journey is reaching new heights!
Fun Fact:
Grid Search in AI!
Did you know `GridSearchCV`, part of scikit-learn since the library’s early releases around 2010, is a go-to for hyperparameter tuning? It’s perfect for ensuring our Random Forest is finely tuned for autism prediction!
Real-Life Example
Imagine you’re a data scientist refining an ASD screening tool. Tuning Random Forest ensures it achieves peak performance, giving local clinics a model they can trust for accurate predictions!
Quiz Time!
Let’s test your tuning skills, students!
1. What does `GridSearchCV` do?
a) Deletes the model
b) Tests multiple parameter combinations to find the best
c) Reduces dataset size
2. Why set `n_jobs=-1`?
a) To slow down the process
b) To use all available CPU cores for faster computation
c) To skip cross-validation
Drop your answers in the comments
Cheat Sheet:
Hyperparameter Tuning with `GridSearchCV`
- `param_grid`: Dictionary of parameters to test (e.g., `n_estimators`, `max_depth`).
- `GridSearchCV(estimator, param_grid, cv=5)`: Performs grid search with 5-fold cross-validation.
- `grid_search.best_params_`: Returns the best parameter combination.
Did You Know?
Random Forest’s `n_estimators` parameter, the number of trees, often impacts accuracy most—our grid search tests 100, 200, and 300 to find the sweet spot!
Pro Tip:
Can we make Random Forest even better? Let’s tune its parameters with `GridSearchCV`!
What’s Happening in This Code?
Let’s break it down like we’re perfecting a masterpiece:
- Imports: `from sklearn.model_selection import GridSearchCV` brings in the tool for hyperparameter tuning.
- Parameter Grid: `param_grid` defines the search space:
- `n_estimators`: Number of trees—[100, 200, 300].
- `max_depth`: Tree depth—[None (unlimited), 10, 20].
- `min_samples_split`: Minimum samples to split a node—[2, 5, 10].
- `min_samples_leaf`: Minimum samples per leaf—[1, 2, 4].
- `max_features`: Features considered per split—['sqrt', 'log2'].
- Grid Search: `GridSearchCV` configures the search:
- `estimator=RandomForestClassifier(random_state=42)`: Base model with fixed seed.
- `param_grid=param_grid`: Parameters to test.
- `cv=5`: 5-fold cross-validation (~204 samples per fold, 1022 total training rows).
- `n_jobs=-1`: Uses all CPU cores for speed.
- `scoring='accuracy'`: Optimizes for accuracy.
- Fit: `grid_search.fit(x_train_scaled, y_train)` trains and evaluates all combinations.
- Best Model: `best_rf = grid_search.best_estimator_` extracts the tuned model.
- Output: `print(f"Best Params: {grid_search.best_params_}")` shows the optimal parameters.
Hyperparameter Tuning for Random Forest
Here’s the code we’re working with:
from sklearn.model_selection import GridSearchCV

param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2']
}

grid_search = GridSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    param_grid=param_grid,
    cv=5,
    n_jobs=-1,
    scoring='accuracy'
)

grid_search.fit(x_train_scaled, y_train)
best_rf = grid_search.best_estimator_
print(f"Best Params: {grid_search.best_params_}")
The Output:
Best Params: {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}
Best Hyperparameters for Random Forest
Explanation of Features:
- Best Parameters:
- `max_depth: None`: No limit on tree depth, allowing full growth.
- `max_features: 'sqrt'`: Uses square root of features per split, a default that reduces overfitting.
- `min_samples_leaf: 1`: Allows leaves with at least 1 sample, maximizing flexibility.
- `min_samples_split: 2`: Splits nodes with at least 2 samples, also a default.
- `n_estimators: 100`: 100 trees, balancing performance and computation time.
Insight: These parameters align closely with Random Forest’s defaults, which often perform well, explaining its strong baseline accuracy (92.97%). The choice of `n_estimators=100` over 200 or 300 suggests that additional trees didn’t improve cross-validated accuracy significantly, saving computational resources. `max_depth=None` indicates our dataset benefits from deeper trees, likely due to complex patterns in features like `sum_score`. This tuned model (`best_rf`) should maintain or slightly improve our test accuracy—let’s evaluate it next to confirm!
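Before moving on, here’s a quick, hedged sanity check: a minimal sketch assuming `best_rf`, `x_test_scaled`, and `y_test` from our earlier steps are still in scope, comparing the tuned model’s test accuracy against the 92.97% baseline.
from sklearn.metrics import accuracy_score, classification_report

# Score the tuned Random Forest on the held-out test set
y_pred_tuned = best_rf.predict(x_test_scaled)
print(f"Tuned Random Forest test accuracy: {accuracy_score(y_test, y_pred_tuned):.4f}")
print(classification_report(y_test, y_pred_tuned, target_names=["No ASD", "ASD"]))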
Next Steps:
We’ve tuned Random Forest—fantastic optimization! Next, we’ll evaluate `best_rf` on the test set, compare its performance to the baseline, and dive into SHAP analysis to interpret its predictions.
Share your code block, and let’s keep this compassionate journey soaring.
Unveiling Precision and Power:
ROC and Feature Importance
After tuning our Random Forest model with `GridSearchCV` to optimize its parameters, we’re now diving into advanced model evaluation with an ROC curve and feature importance analysis. This code block visualizes the Receiver Operating Characteristic (ROC) curve to assess our tuned model’s ability to distinguish ASD from non-ASD cases, and plots feature importance to reveal which factors drive our predictions.
Let’s elevate our mission with cutting-edge insights—cheers to unlocking the full potential of our ASD prediction tool!
Why ROC and Feature Importance Matter
The ROC curve measures our model’s trade-off between true positive and false positive rates, with AUC indicating overall performance. Feature importance highlights key predictors like `sum_score`, aiding clinicians by focusing on the most impactful screening factors for early autism detection.
What to Expect in This Step
In this step, we’ll:
- Plot the ROC curve for the tuned Random Forest model to evaluate its classification performance.
- Calculate and visualize feature importance to identify the most influential predictors.
- Analyze the results to guide further optimization and interpretation.
Get ready to see our model’s strengths in action—our journey is reaching new depths of understanding!
Fun Fact:
ROC Curves in Action!
Did you know ROC curves, developed during World War II for radar signal detection, became a staple in machine learning by the 1990s? Our AUC of 0.99 showcases their power in autism prediction!
Real-Life Example
Imagine you’re a pediatric specialist using our tool. An ROC curve with AUC 0.99 and a feature importance plot highlighting `sum_score` ensures you focus on key behavioral scores, boosting your diagnostic confidence!
Quiz Time!
Let’s test your evaluation skills, students!
1. What does AUC in an ROC curve measure?
a) Model speed
b) Area under the curve, indicating classification ability
c) Number of features
2. Why is feature importance useful?
a) To delete features
b) To identify which features most affect predictions
c) To increase dataset size
Drop your answers in the comments.
Cheat Sheet:
ROC and Feature Importance
- `RocCurveDisplay.from_estimator(model, X, y)`: Plots ROC curve with AUC.
- `model.feature_importances_`: Extracts feature importance scores.
- Tip: Use `plt.xticks(rotation=90)` for readable feature names.
Did You Know?
Feature importance, a Random Forest hallmark since its 2001 inception, helps us explain why `sum_score` dominates—perfect for our autism insights!
Pro Tip:
How good is our tuned Random Forest? Let’s explore with an ROC curve and feature importance!
What’s Happening in This Code?
Let’s break it down like we’re uncovering a treasure map:
- ROC Curve:
- `RocCurveDisplay.from_estimator(best_rf, x_test_scaled, y_test)` generates the ROC curve using the tuned Random Forest model (`best_rf`), test features (`x_test_scaled`), and true labels (`y_test`).
- `plt.title('ROC Curve')` adds a title.
- `plt.show()` displays the plot.
- Feature Importance:
- `importances = best_rf.feature_importances_` extracts the importance of each feature based on the tuned model.
- `sorted_idx = importances.argsort()[::-1]` sorts indices in descending order of importance.
- `plt.figure(figsize=(12, 8))` sets a 12x8-inch plot.
- `plt.bar(range(x_train.shape[1]), importances[sorted_idx], align='center')` creates a bar plot of importance scores.
- `plt.xticks(range(x_train.shape[1]), x_train.columns[sorted_idx], rotation=90)` labels the x-axis with feature names, rotated for readability.
- `plt.title("Feature Importance")` adds a title.
- `plt.tight_layout()` adjusts spacing.
- `plt.show()` displays the plot.
ROC Curve and Feature Importance for Tuned Random Forest
Here’s the code we’re working with:
# ROC Curve
from sklearn.metrics import RocCurveDisplay  # ensure the display helper is imported

RocCurveDisplay.from_estimator(best_rf, x_test_scaled, y_test)
plt.title('ROC Curve')
plt.show()
# Feature Importance
importances = best_rf.feature_importances_
sorted_idx = importances.argsort()[::-1]
plt.figure(figsize=(12, 8))
plt.bar(range(x_train.shape[1]), importances[sorted_idx], align='center')
plt.xticks(range(x_train.shape[1]), x_train.columns[sorted_idx], rotation=90)
plt.title("Feature Importance")
plt.tight_layout()
plt.show()
The Output:
ROC Curve and Feature Importance
The output includes two plots:
- ROC Curve:
- X-Axis (False Positive Rate): Ranges from 0 to 1, showing the rate of incorrectly predicting ASD.
- Y-Axis (True Positive Rate): Ranges from 0 to 1, showing the rate of correctly predicting ASD.
- Curve: A steep rise near (0, 1), indicating excellent performance.
- AUC: 0.99 (labeled as RandomForestClassifier (AUC = 0.99)), reflecting near-perfect classification ability.
- Feature Importance:
- X-Axis: Feature names (e.g., `sum_score`, `A1_Score`, `result`, etc.), sorted by importance.
- Y-Axis: Importance scores, ranging from 0 to ~0.14.
- Bars:
- `sum_score` dominates at ~0.14, the most influential feature.
- `result` follows closely, likely redundant with `sum_score`.
- `A1_Score`, `A2_Score`, etc., show moderate importance (0.06-0.10).
- Other features (e.g., `age`, `contry_of_res`) have negligible impact (<0.02).
Insight: The ROC curve’s AUC of 0.99 confirms the tuned Random Forest’s exceptional ability to distinguish ASD (1) from no ASD (0), complementing the 92.97% baseline accuracy with near-perfect ranking of cases, likely thanks to the optimal parameters. The feature importance plot reinforces that `sum_score` (and `result`) drives predictions, aligning with the 0.97 correlation from Part 3. The A1-A10 scores contribute moderately, while demographic features like `age` and `contry_of_res` are less critical.
This suggests we might drop `result` to avoid redundancy and focus on `sum_score` and individual scores for interpretation—let’s refine our threshold next!
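If you want to try that refinement yourself, here is a hedged sketch, assuming the unscaled `x_train`/`x_test` DataFrames, `y_train`/`y_test`, and `grid_search` from the tuning step are available, and that a `StandardScaler`-style scaler was used originally (swap in whatever scaler Part 3 actually used):
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Drop the redundant 'result' column, which mirrors sum_score
x_train_slim = x_train.drop(columns=["result"])
x_test_slim = x_test.drop(columns=["result"])

# Re-scale the slimmed feature set (the scaler choice here is an assumption)
scaler = StandardScaler()
x_train_slim_scaled = scaler.fit_transform(x_train_slim)
x_test_slim_scaled = scaler.transform(x_test_slim)

# Retrain with the best parameters found by GridSearchCV
rf_slim = RandomForestClassifier(**grid_search.best_params_, random_state=42)
rf_slim.fit(x_train_slim_scaled, y_train)

print(f"Accuracy without 'result': {accuracy_score(y_test, rf_slim.predict(x_test_slim_scaled)):.4f}")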
Next Steps:
We’ve visualized our model’s prowess—amazing results! Next, we’ll determine the optimal threshold for `best_rf` to minimize false negatives, explore advanced metrics like precision and recall, and prepare for deployment.
So let’s keep this compassionate journey soaring. What stood out to you in the ROC or feature importance, viewers? Drop your thoughts in the comments, and let’s make this project a game-changer together!
Decoding Predictions with SHAP: Interpretability
After tuning our Random Forest with `GridSearchCV` and confirming its strength with a 0.99 ROC AUC and a `sum_score`-led feature importance plot, we’re now diving into SHAP analysis to interpret its predictions. This code block uses SHAP (SHapley Additive exPlanations) to reveal how features like `sum_score` contribute to our model’s autism spectrum disorder (ASD) predictions, both globally and for individual cases.
So are you ready to illuminate the black box of our model?
Cheers to transparency and impact!
Why SHAP Analysis Matters
SHAP provides feature importance and directionality, showing how each feature pushes predictions toward ASD or non-ASD. For clinicians around the world, this interpretability ensures trust in our model, highlighting key factors like `sum_score` for actionable autism screening insights.
What to Expect in This Step
In this step, we’ll:
- Use SHAP’s TreeExplainer to compute SHAP values for our Random Forest model.
- Create a summary plot to show global feature importance and impact direction.
- Generate a force plot to explain an individual prediction for the first test sample.
Get ready to uncover the “why” behind our predictions—our journey is becoming crystal clear!
Fun Fact:
SHAP in Machine Learning!
Did you know SHAP, introduced in 2017 by Lundberg and Lee, is rooted in game theory’s Shapley values? It’s a cutting-edge tool for making complex models like ours interpretable in autism prediction!
Real-Life Example
Imagine you’re a researcher, presenting our model to a clinic. A SHAP summary showing `sum_score` as the top driver convinces doctors to prioritize behavioral scores in ASD screening, boosting early diagnosis!
Quiz Time!
Let’s test your interpretability skills, students!
1. What does SHAP calculate?
a) Model accuracy
b) Feature contributions to predictions
c) Dataset size
2. What does a force plot show?
a) Overall model performance
b) How features impact a single prediction
c) Cross-validation scores
Drop your answers in the comments.
Cheat Sheet: SHAP Analysis
- `shap.TreeExplainer(model)`: Creates an explainer for tree-based models like Random Forest.
- `shap.summary_plot(shap_values, ..., plot_type="bar")`: Shows global feature importance.
- `shap.force_plot(...)`: Visualizes a single prediction’s feature contributions.
Did You Know?
SHAP’s TreeExplainer, optimized for tree models since 2018, makes Random Forest interpretation fast and accurate—perfect for our autism project!
Pro Tip:
What drives our ASD predictions? Let’s uncover it with SHAP analysis!
What’s Happening in This Code?
Let’s break it down like we’re unraveling a mystery:
- Imports: `import shap` brings in the SHAP library.
- Explainer Setup:
- `explainer = shap.TreeExplainer(best_rf)` creates a SHAP explainer for our tuned Random Forest.
- `shap_values = explainer.shap_values(x_test_scaled)` computes SHAP values for the test set, explaining how each feature contributes to predictions.
- Summary Plot:
- `shap.summary_plot(shap_values, x_test_scaled, feature_names=x_train.columns, plot_type="bar")` generates a bar plot showing the mean absolute SHAP values (global importance) for each feature.
- Force Plot:
- `sample_idx = 0` selects the first test sample.
- `shap.force_plot(explainer.expected_value[1], shap_values[1][sample_idx], x_test_scaled[sample_idx], feature_names=x_train.columns)` visualizes how features contribute to the prediction for this sample, compared to the base value.
SHAP Analysis for Tuned Random Forest
Here’s the code we’re working with:
import shap
# Create explainer
explainer = shap.TreeExplainer(best_rf)
shap_values = explainer.shap_values(x_test_scaled)
# Summary plot
shap.summary_plot(shap_values, x_test_scaled, feature_names=x_train.columns, plot_type="bar")
# Individual prediction explanation
sample_idx = 0
shap.force_plot(explainer.expected_value[1], shap_values[1][sample_idx], x_test_scaled[sample_idx], feature_names=x_train.columns)
The Output:
SHAP Summary and Force Plots
The output includes two plots:
- SHAP Summary Plot (Bar):
- X-Axis: Mean absolute SHAP value (e.g., 0 to ~0.5), indicating feature importance.
- Y-Axis: Features (e.g., `sum_score`, `result`, `A1_Score`, etc.), sorted by importance.
- Bars:
- `sum_score`: Highest impact (~0.5), dominating predictions.
- `result`: Nearly identical to `sum_score`, confirming redundancy.
- `A1_Score` to `A10_Score`: Moderate impact (~0.05-0.15).
- `austim`, `jaundice`, `age`, etc.: Low impact (<0.05).
- Insight: Aligns with our feature importance plot —`sum_score` drives predictions, followed by individual scores. `result`’s redundancy suggests we can drop it.
- SHAP Force Plot (Sample 0):
- Base Value: Expected prediction probability for class 1 (ASD) (~0.5, center of the plot).
- Prediction: Output probability (~0.92, labeled as f(x)), indicating a strong ASD prediction.
- Features:
- Red (push toward ASD): `sum_score` (+0.31), `A10_Score` (+0.06), `A1_Score` (+0.04).
- Blue (push against ASD): `A2_Score` (-0.02), `age` (-0.01).
- Insight: For this sample, a high `sum_score` (likely a large value, e.g., 8-10) pushes the prediction toward ASD, supported by specific scores like `A10_Score`. Minor features like `age` slightly reduce the probability, but the overall effect strongly favors ASD (0.92 probability).
Insight: SHAP confirms `sum_score` as the primary driver, consistent with its 0.97 correlation and high feature importance. The force plot for the first sample shows how a high `sum_score` (likely reflecting strong behavioral traits) leads to a confident ASD prediction, giving clinicians actionable insights. Dropping `result` and focusing on `sum_score` and A1-A10 scores will streamline our model for deployment without losing accuracy!
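To see not just how much each feature matters but which direction it pushes each prediction, you can also draw the default beeswarm summary plot. This is a small sketch reusing the `shap_values` computed above; the `[1]` indexing assumes an older SHAP version that returns one array per class, as in the code above, so adjust it if your version returns a single array.
# Beeswarm plot for the ASD class: each dot is a test sample,
# color shows the feature's value and position shows its push toward/away from ASD
shap.summary_plot(shap_values[1], x_test_scaled, feature_names=x_train.columns)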
Next Steps:
We’ve illuminated our model’s decisions—fantastic interpretability! Next, we’ll deploy our Random Forest model, serialize it for future use, and explore advanced evaluation metrics to finalize our ASD prediction tool.
Share your code block if you have any doubts, and let’s keep this compassionate journey soaring.
What did you learn from SHAP, viewers? Drop your thoughts in the comments, and let’s make this project a game-changer together!
Enhancing Sensitivity:
Optimal Threshold for 90% Recall
After mastering SHAP analysis to decode our tuned Random Forest’s predictions, we’re now refining its performance by setting an optimal threshold to achieve 90% recall—ensuring we catch at least 90% of ASD cases. This code block uses the precision-recall curve to identify the threshold, balancing sensitivity with precision for autism spectrum disorder (ASD) prediction.
So together, let’s elevate our mission with a focus on sensitivity—cheers to saving lives with AI!
Why 90% Recall Matters
Achieving 90% recall minimizes false negatives, ensuring most ASD cases are detected early—a game-changer for clinicians around the globe. This threshold adjustment prioritizes identifying children who need support, even at the cost of some precision.
What to Expect in This Step
In this step, we’ll:
- Compute the precision-recall curve for our tuned Random Forest model.
- Identify the threshold that achieves at least 90% recall.
- Print the optimal threshold for further evaluation.
Get ready to boost our model’s sensitivity—our journey is becoming even more impactful!
Fun Fact:
Precision-Recall Curves!
Did you know precision-recall curves, popularized in the 2000s for imbalanced data, are perfect for medical applications like ours? They help us focus on catching ASD cases effectively!
Real-Life Example
Imagine you’re a pediatrician anywhere in the world using our tool. A 90% recall threshold ensures you identify 90% of ASD children, giving families early intervention opportunities that could change their lives!
Quiz Time!
Let’s test your evaluation skills, students!
1. What does recall measure?
a) Proportion of correct positive predictions
b) Proportion of actual positives correctly identified
c) Overall accuracy
2. Why target 90% recall?
a) To maximize false positives
b) To ensure most ASD cases are detected
c) To reduce model complexity
Drop your answers in the comments.
Cheat Sheet:
Precision-Recall Curve
- `precision_recall_curve(y_true, probas)`: Computes precision, recall, and thresholds.
- `np.where(recall >= target_recall)[0][0]`: Finds the first index meeting the recall target.
- Tip: Check precision at the chosen threshold to assess trade-offs.
Did You Know?
The precision-recall curve is especially valuable for binary classification with imbalanced classes—and even with our balanced data (128:128), it’s a great lens for tuning ASD detection thresholds!
Pro Tip:
Can we catch 90% of ASD cases? Let’s find the perfect threshold with precision-recall!
You can also evaluate the new threshold.
What’s Happening in This Code?
Let’s break it down like we’re fine-tuning a compass:
- Imports: `from sklearn.metrics import precision_recall_curve` brings in the precision-recall tool.
- Probabilities: `probs = best_rf.predict_proba(x_test_scaled)[:, 1]` gets probabilities for class 1 (ASD) from the tuned Random Forest.
- Precision-Recall Data: `precision, recall, thresholds = precision_recall_curve(y_test, probs)` computes precision, recall values, and corresponding thresholds across the probability range.
- Optimal Threshold:
- `target_recall = 0.9` sets our goal of 90% recall.
- `idx = np.where(recall >= target_recall)[0][0]` finds the first index where recall meets or exceeds 0.9.
- `optimal_threshold = thresholds[idx]` extracts the threshold at that index.
- Output: `print(f"Optimal Threshold: {optimal_threshold:.3f}")` displays the result with 3 decimal places.
Optimal Threshold for 90% Recall
Here’s the code we’re working with:
from sklearn.metrics import precision_recall_curve
probs = best_rf.predict_proba(x_test_scaled)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_test, probs)
# Find threshold for 90% recall
target_recall = 0.9
idx = np.where(recall >= target_recall)[0][0]
optimal_threshold = thresholds[idx]
print(f"Optimal Threshold: {optimal_threshold:.3f}")
The Output:
Optimal Threshold: 0.000
Optimal Threshold for 90% Recall
Explanation of Features:
- Optimal Threshold: 0.000, the lowest possible threshold, meaning any probability above 0 predicts ASD.
- Context: Since `recall` reaches 1.0 (100%) at a threshold of 0 (all samples predicted as ASD), the first point where recall ≥ 0.9 is effectively 0.000 when rounded to 3 decimals. This suggests our model’s probability distribution is such that even a very low threshold captures 90% of ASD cases due to the balanced dataset and high AUC (0.99).
Insight: A threshold of 0.000 achieving 90% recall indicates that our tuned Random Forest assigns very high probabilities to most ASD cases (likely due to `sum_score`’s dominance), and the precision-recall curve drops sharply at low thresholds. However, this threshold is impractical, as it would predict nearly all samples as ASD (recall = 1.0, precision near 0.5 for a balanced dataset). This suggests our model is overly confident, possibly due to oversampling or feature dominance (e.g., `sum_score`’s 0.97 correlation). We should:
- Check the precision at this threshold (likely low, e.g., ~0.5).
- Consider a higher threshold (e.g., 0.1-0.3) where recall is still high (e.g., 0.9-0.95) but precision is reasonable (e.g., 0.85-0.9).
- Revisit the precision-recall curve to select a balanced point.
You can adjust your approach to find a practical threshold!
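Here is one way to make that adjustment: a hedged sketch reusing `precision`, `recall`, and `thresholds` from the code above, with the minimum-precision floor as an assumption you would tune to your clinical tolerance for false positives.
import numpy as np

min_precision = 0.85  # assumed floor; adjust to taste

# precision and recall have one more entry than thresholds, so align them with [:-1]
mask = (recall[:-1] >= 0.9) & (precision[:-1] >= min_precision)

if mask.any():
    # Among qualifying points, pick the highest threshold (the most conservative choice)
    practical_threshold = thresholds[mask].max()
    print(f"Practical threshold: {practical_threshold:.3f}")
else:
    print("No threshold meets both targets; relax min_precision or the recall goal.")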
Next Steps:
We’ve hit a recall target—great start with a tweak needed! Next, we’ll refine the threshold selection (e.g., balancing recall and precision), evaluate the new predictions with a classification report, and proceed toward deployment.
Share your code block if you're stuck somewhere, and let’s keep this compassionate journey soaring.
What do you think of this threshold, viewers?
Drop your thoughts in the comments, and let’s make this project a game-changer together!
Preserving Our Triumph:
Model Serialization and Versioning
After refining our Random Forest model with an optimal threshold and exploring its interpretability with SHAP, we’re now securing our hard work by serializing the model for future use. This code block saves our tuned Random Forest (`best_rf`) to a versioned file, along with metadata like features and metrics, ensuring our autism spectrum disorder (ASD) prediction tool is ready for deployment and reuse. Whether you’re joining me from Lahore’s bustling streets or coding with passion from across the globe, let’s lock in our success—cheers to a lasting legacy!
Why Model Serialization Matters
Saving the model and its metadata allows clinicians in Lahore to use our ASD prediction tool anytime, anywhere, with full context on features and performance. This step bridges the gap from development to real-world application.
What to Expect in This Step
In this step, we’ll:
- Create a directory to store our model files.
- Generate versioned file paths using the current date and time.
- Serialize the tuned Random Forest model and save its metadata.
- Confirm the save location with a printed message.
Get ready to preserve our model—our journey is nearing deployment!
Fun Fact:
Serialization in AI!
Did you know model serialization—using Python’s `pickle` module and the `joblib` library—is key for deploying machine learning models? It’s the magic that keeps our autism tool alive!
Real-Life Example
Imagine you’re a healthcare administrator in Lahore integrating our tool into a clinic. A serialized `autism_rf_v20250603_0650.pkl` file ensures you can load and use it instantly for ASD screening!
Quiz Time!
Let’s test your deployment skills, students!
1. What does `joblib.dump` do?
a) Deletes the model
b) Saves the model to a file
c) Loads the model
2. Why use versioned file names?
a) To confuse users
b) To track different model versions over time
c) To reduce file size
Drop your answers in the comments.
Cheat Sheet:
Model Serialization
- `os.makedirs(dir, exist_ok=True)`: Creates a directory if it doesn’t exist.
- `joblib.dump(model, path)`: Saves the model to a file.
- `datetime.datetime.now().strftime("%Y%m%d_%H%M")`: Generates a version string (e.g., 20250603_0650).
Did You Know?
The `.pkl` extension, from Python’s `pickle` module, is widely used for serializing machine learning models—and `joblib` builds on the same format with optimizations for large NumPy arrays!
Pro Tip:
Let’s save our ASD prediction model! How will serialization power its future use?
What’s Happening in This Code?
Let’s break it down like we’re securing a treasure:
- Imports: Brings in `os`, `json`, `joblib`, `datetime`, and `Path` for file handling and serialization.
- Directory Creation:
- `model_dir = '/kaggle/working/'` sets the directory (e.g., Kaggle environment).
- `os.makedirs(model_dir, exist_ok=True)` creates it if it doesn’t exist, avoiding errors.
- Versioned Paths:
- `version = datetime.datetime.now().strftime("%Y%m%d_%H%M")` generates a version string (e.g., 20250603_0650, reflecting when the cell runs).
- `model_path = f"{model_dir}autism_rf_v{version}.pkl"` and `metadata_path = f"{model_dir}metadata_v{version}.json"` define file names.
- Save Model: `joblib.dump(best_rf, model_path)` serializes the tuned Random Forest to a `.pkl` file.
- Print Confirmation: `print(f"Model saved to: {model_path}")` confirms the save location.
- Metadata Preparation:
- `metadata` dictionary includes:
- `model_version`: The version string.
- `features`: List of training feature names (e.g., `sum_score`, `A1_Score`, etc.).
- `target_classes`: Unique target values (e.g., [0, 1]).
- `metrics`: Accuracy from the original Random Forest (`rfacc`, e.g., 0.9296875) and the ROC AUC score computed on test data.
- Save Metadata: `json.dump(metadata, f, indent=2)` writes the dictionary to the versioned JSON file so it can be reloaded alongside the model.
Model Serialization for Tuned Random Forest
Here’s the code we’re working with:
import os
import json
import joblib
import datetime
from pathlib import Path
from sklearn.metrics import roc_auc_score  # used below when computing the metadata metrics
# 1. Create directory if it doesn't exist
model_dir = '/kaggle/working/'
os.makedirs(model_dir, exist_ok=True) # This won't raise error if dir exists
# 2. Versioned file paths
version = datetime.datetime.now().strftime("%Y%m%d_%H%M")
model_path = f"{model_dir}autism_rf_v{version}.pkl"
metadata_path = f"{model_dir}metadata_v{version}.json"
# 3. Save model
joblib.dump(best_rf, model_path)
print(f"Model saved to: {model_path}")
# 4. Save metadata
metadata = {
    "model_version": version,
    "features": list(x_train.columns),
    "target_classes": [int(c) for c in y_train.unique()],
    "metrics": {
        "accuracy": float(rfacc),
        "roc_auc": float(roc_auc_score(y_test, best_rf.predict_proba(x_test_scaled)[:, 1]))
    }
}
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)
The Output: Model Save Confirmation
Model saved to: /kaggle/working/autism_rf_v20250603_0650.pkl
Explanation of Features:
- Model Path: `/kaggle/working/autism_rf_v20250603_0650.pkl` is the saved model file, with `20250603_0650` encoding the timestamp of the run that produced it.
- Context: The version string reflects when the cell actually executed, in the Kaggle server’s time zone (typically UTC), so it will differ from your local clock and from whenever you rerun the notebook. Whatever version gets generated, just make sure the paths you load later point at the same files.
Insight: The model is successfully saved, preserving our tuned Random Forest for deployment. The metadata JSON captures critical details like the feature list and metrics, ensuring reproducibility. Next, we’ll load and deploy this model to test its real-world readiness!
Next Steps:
We’ve saved our model—fantastic preservation! Next, you can load the serialized model, test its deployment in a simulated environment, and finalize the ASD prediction tool for practical use.
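If you want to sanity-check the saved artifact right away, here is a minimal sketch, assuming `model_path` and `metadata_path` from the serialization code are still in scope and the metadata JSON was written as shown above:
import json
import joblib

# Reload the serialized model and its metadata
loaded_rf = joblib.load(model_path)
with open(metadata_path, 'r') as f:
    loaded_meta = json.load(f)

# Quick checks: the reloaded model matches the in-memory one, and the metadata looks sane
assert (loaded_rf.predict(x_test_scaled) == best_rf.predict(x_test_scaled)).all()
print(f"Reloaded model v{loaded_meta['model_version']} with {len(loaded_meta['features'])} features")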
Bringing It to Life: FastAPI Deployment in Part 4!
After serializing our tuned Random Forest model and exploring its interpretability with SHAP, we’re now deploying it as a production-ready API using FastAPI. This code block creates an app.py file to serve predictions and model information, allowing clinicians and other users to access our autism spectrum disorder (ASD) prediction tool in real time.
Let’s launch this life-changing solution—cheers to deployment success!
Why FastAPI Deployment Matters
A FastAPI API makes our model accessible via HTTP requests, enabling seamless integration into healthcare systems for ASD screening. This step transforms our project from a prototype to a practical tool for early diagnosis.
What to Expect in This Step
In this step, we’ll:
- Set up a FastAPI application to serve our model.
- Define a `PatientData` class to handle input features.
- Create endpoints for predictions (`/predict`) and model information (`/model-info`).
- Prepare the API for real-world use with our saved model.
Get ready to deploy our model—our journey is culminating in real impact!
Fun Fact: FastAPI’s Rise!
Did you know FastAPI, launched in 2018 by Sebastián Ramírez, is one of the fastest Python frameworks for building APIs? It’s perfect for our high-performance ASD prediction tool!
Real-Life Example
Imagine you’re a healthcare provider using our API. A POST request to `/predict` with patient data instantly returns an ASD prediction, streamlining your diagnostic process!
Quiz Time!
Let’s test your deployment skills, students!
1. What does FastAPI do?
a) Deletes the model
b) Creates a web API for serving predictions
c) Analyzes data
2. What is the purpose of `PatientData`?
a) To save the model
b) To define the structure of input data
c) To train the model
Cheat Sheet: FastAPI Deployment
- `FastAPI()`: Initializes the API application.
- `BaseModel`: Defines data validation schemas (e.g., `PatientData`).
- `@app.post("/predict")`: Creates a POST endpoint for predictions.
- `joblib.load(path)`: Loads the saved model.
Did You Know?
FastAPI’s automatic OpenAPI documentation, available at /docs, makes our API user-friendly—perfect for sharing with healthcare partners!
Pro Tip:
Ready to deploy our ASD model? Let’s build a FastAPI to serve predictions!
FastAPI Deployment for Tuned Random Forest
Here’s the code we’re working with:
# Create app.py:
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np

app = FastAPI()

# Load model
model = joblib.load("/kaggle/working/autism_rf_v20250603_0650.pkl")

class PatientData(BaseModel):
    feature1: float
    feature2: float
    # ... add all features with correct names

@app.post("/predict")
async def predict(data: PatientData):
    features = np.array([list(data.dict().values())])
    proba = model.predict_proba(features)[0]
    return {
        "prediction": int(model.predict(features)[0]),
        "probability": float(proba.max()),
        "class_probs": {str(i): float(p) for i, p in enumerate(proba)}
    }

@app.get("/model-info")
async def model_info():
    return metadata  # Load from saved JSON
What’s Happening in This Code?
Let’s break it down like we’re launching a rocket:
- Imports:
- `from fastapi import FastAPI` and `from pydantic import BaseModel` set up the API framework and data validation.
- `import joblib` and `import numpy as np` handle model loading and array operations.
- API Initialization: `app = FastAPI()` creates the FastAPI application.
- Model Loading: `model = joblib.load("/kaggle/working/autism_rf_v20250603_0650.pkl")` loads our saved Random Forest model.
- PatientData Class:
- `class PatientData(BaseModel)` defines a schema for input data.
- `feature1: float`, `feature2: float`, etc., need to be updated with all actual feature names (e.g., `sum_score`, `A1_Score`, `age`, etc.) from `x_train.columns`.
- Predict Endpoint:
- `@app.post("/predict")` defines a POST endpoint.
- `async def predict(data: PatientData)` asynchronously handles the input data.
- `features = np.array([list(data.dict().values())])` converts the input to a NumPy array.
- `proba = model.predict_proba(features)[0]` gets class probabilities.
- Returns a dictionary with:
- `"prediction"`: Class label (0 or 1).
- `"probability"`: Highest probability.
- `"class_probs"`: Dictionary of probabilities for all classes.
- Model Info Endpoint:
- `@app.get("/model-info")` defines a GET endpoint.
- `async def model_info()` returns the metadata dictionary (assumed to be loaded from the JSON file saved earlier).
The Output: Deployment Setup
No direct output is printed yet, as this sets up the API. To use it:
- Run `uvicorn app:app --host 0.0.0.0 --port 8000` to start the server.
- Send a POST request to `http://0.0.0.0:8000/predict` with JSON data matching `PatientData` (see the client sketch after the insight below).
- Access `http://0.0.0.0:8000/model-info` for metadata.
Insight: The API is ready, but `PatientData` needs completion with all features (e.g., `sum_score`, `A1_Score` through `A10_Score`, `age`, etc.) to match `x_train.columns`. The model path `/kaggle/working/autism_rf_v20250603_0650.pkl` works in Kaggle, but adjust it for local use (e.g., `./autism_rf_v20250603_0650.pkl`). The `metadata` variable should be loaded from the JSON file (e.g., `with open(metadata_path, 'r') as f: metadata = json.load(f)`). Once tested, this API will serve real-time ASD predictions!
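To try the `/predict` endpoint once the server is running, here is a hedged client sketch. It uses the placeholder fields exactly as declared above; once `PatientData` lists the real features, swap in those keys with realistic values. The `requests` library and the host/port (taken from the uvicorn command above) are assumptions.
import requests

# Placeholder payload matching the schema sketched above;
# replace the keys with the real feature names once PatientData is completed
payload = {"feature1": 1.0, "feature2": 0.0}

response = requests.post("http://0.0.0.0:8000/predict", json=payload)
print(response.status_code)
print(response.json())  # e.g., {"prediction": ..., "probability": ..., "class_probs": {...}}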
Next Steps
We’ve built our API—fantastic deployment setup! Next, you can test the API with sample data, verify its predictions, and finalize our ASD prediction tool for public use.
Ensuring Reliability:
Monitoring & Logging
After deploying our tuned Random Forest model as a FastAPI application, we’re now enhancing it with monitoring and logging to track its performance in production. This code block adds request logging and prediction monitoring using Prometheus, ensuring our autism spectrum disorder (ASD) prediction API is robust and reliable for real-world use.
Why Monitoring & Logging Matter
Logging tracks API requests and predictions, helping us debug issues, while Prometheus monitoring counts predictions by class, ensuring our model performs consistently. For healthcare providers in Lahore, this reliability means uninterrupted ASD screening with actionable insights.
What to Expect in This Step
In this step, we’ll:
- Add logging to capture API request details and prediction outcomes.
- Set up Prometheus monitoring to count predictions by class.
- Update our `/predict` endpoint to integrate these features.
Get ready to make our API production-ready—our journey is reaching its peak!
Fun Fact:
Prometheus in AI!
Did you know Prometheus, launched in 2012 by SoundCloud, is a leading tool for monitoring machine learning APIs? It’s perfect for tracking our ASD prediction counts!
Real-Life Example
Imagine you’re an IT specialist managing our API. Prometheus metrics showing balanced prediction counts and logs confirming smooth requests ensure the tool is ready for clinic integration!
Quiz Time!
Let’s test your monitoring skills, students!
1. What does `logger.info` do?
a) Deletes logs
b) Records informational messages
c) Stops the API
2. What does a Prometheus `Counter` track?
a) API speed
b) Number of events (e.g., predictions by class)
c) Model accuracy
Cheat Sheet:
Monitoring & Logging
- `logging.info(msg)`: Logs informational messages.
- `Counter('name', 'description', ['label'])`: Defines a Prometheus counter with labels.
- `start_http_server(port)`: Starts a Prometheus metrics server.
- `@app.middleware("http")`: Intercepts all HTTP requests for logging.
Did You Know?
FastAPI’s middleware, introduced with its 2018 release, makes request logging seamless—our setup ensures every API call is tracked for reliability!
Pro Tip
How do we ensure our ASD API stays reliable? Let’s add monitoring and logging!
Monitoring & Logging for FastAPI
Here’s the updated code we’re integrating into app.py:
# Add to your API app.py
import logging
from prometheus_client import start_http_server, Counter

PREDICTION_COUNTER = Counter('model_predictions', 'Prediction count by class', ['class'])

@app.middleware("http")
async def log_requests(request, call_next):
    logger.info(f"Request: {request.method} {request.url}")
    response = await call_next(request)
    return response

@app.post("/predict")
async def predict(data: PatientData):
    # ... existing code ...
    PREDICTION_COUNTER.labels(**{'class': str(prediction)}).inc()
    logger.info(f"Prediction: {prediction} | Probs: {proba}")
    return results
What’s Happening in This Code?
Let’s break it down like we’re setting up a control tower:
- Imports:
- `import logging` for request and prediction logging.
- `from prometheus_client import start_http_server, Counter` for monitoring.
- Prometheus Counter:
- `PREDICTION_COUNTER = Counter('model_predictions', 'Prediction count by class', ['class'])` defines a counter to track predictions, labeled by class (0 or 1).
- Request Logging Middleware:
- `@app.middleware("http")` intercepts all HTTP requests.
- `async def log_requests(request, call_next)` logs the request method and URL (e.g., “POST /predict”) using `logger.info`.
- `response = await call_next(request)` passes the request on to the endpoint.
- Updated Predict Endpoint:
- Assumes the existing code from the deployment block above is in place, with `features`, `proba`, `prediction`, and `results` defined as:
- `features = np.array([list(data.dict().values())])`
- `proba = model.predict_proba(features)[0]`
- `prediction = int(model.predict(features)[0])`
- `results = {"prediction": prediction, "probability": float(proba.max()), "class_probs": {str(i): float(p) for i, p in enumerate(proba)}}`
- `PREDICTION_COUNTER.labels(**{'class': str(prediction)}).inc()` increments the counter for the predicted class (0 or 1).
- `logger.info(f"Prediction: {prediction} | Probs: {proba}")` logs the prediction and probabilities.
- `return results` sends the response.
Full Updated app.py
To integrate this, here’s the complete app.py with monitoring and logging (including fixes for PatientData and metadata loading):
import logging
import json
from fastapi import FastAPI
from pydantic import BaseModel
from prometheus_client import start_http_server, Counter
import joblib
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Start Prometheus server
start_http_server(8001)  # Metrics available at localhost:8001

# Prometheus counter
PREDICTION_COUNTER = Counter('model_predictions', 'Prediction count by class', ['class'])

app = FastAPI()

# Load model
model = joblib.load("/kaggle/working/autism_rf_v20250603_0650.pkl")

# Load metadata
with open("/kaggle/working/metadata_v20250603_0650.json", 'r') as f:
    metadata = json.load(f)

# Define PatientData with all features from x_train.columns
class PatientData(BaseModel):
    A1_Score: float
    A2_Score: float
    A3_Score: float
    A4_Score: float
    A5_Score: float
    A6_Score: float
    A7_Score: float
    A8_Score: float
    A9_Score: float
    A10_Score: float
    age: float
    gender: float
    ethnicity: float
    jaundice: float
    austim: float
    contry_of_res: float
    result: float
    age_desc: float
    relation: float
    ageGroup: float
    sum_score: float
    Pak: float

@app.middleware("http")
async def log_requests(request, call_next):
    logger.info(f"Request: {request.method} {request.url}")
    response = await call_next(request)
    return response

@app.post("/predict")
async def predict(data: PatientData):
    features = np.array([list(data.dict().values())])
    proba = model.predict_proba(features)[0]
    prediction = int(model.predict(features)[0])
    results = {
        "prediction": prediction,
        "probability": float(proba.max()),
        "class_probs": {str(i): float(p) for i, p in enumerate(proba)}
    }
    PREDICTION_COUNTER.labels(**{'class': str(prediction)}).inc()
    logger.info(f"Prediction: {prediction} | Probs: {proba}")
    return results

@app.get("/model-info")
async def model_info():
    return metadata
The Output:
Deployment with Monitoring
No direct output is printed yet, as this updates the API. To use it:
- Run `uvicorn app:app --host 0.0.0.0 --port 8000`.
- Access metrics at `http://0.0.0.0:8001` (Prometheus server).
- Logs will appear in the console (e.g., `INFO: Request: POST http://0.0.0.0:8000/predict`).
Insight: The API now logs all requests (e.g., method, URL) and predictions (e.g., class, probabilities), ensuring traceability. Prometheus tracks prediction counts by class, helping us monitor if the model over-predicts one class (e.g., ASD vs. non-ASD). The PatientData class is updated with all features from x_train.columns, and metadata is loaded from the JSON file, making the API fully functional. We’re ready to test it in production!
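Once the server has handled a few requests, you can peek at the Prometheus counters yourself. Here is a small hedged sketch, assuming the `requests` library is installed and the metrics server started by `start_http_server(8001)` is reachable at the standard `/metrics` path:
import requests

# Fetch the raw Prometheus exposition text and keep only our prediction counter lines
metrics_text = requests.get("http://0.0.0.0:8001/metrics").text
for line in metrics_text.splitlines():
    if "model_predictions" in line and not line.startswith("#"):
        # e.g., model_predictions_total{class="1"} 3.0 (exact metric name may vary by prometheus_client version)
        print(line)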
Next Steps
We’ve fortified our API—amazing production readiness! Next, you can test the API with sample data, check the logs and Prometheus metrics, and finalize our ASD prediction tool for public release. Share your test data or next step, and let’s keep this compassionate journey soaring. What do you think of this monitoring setup, viewers?
A Journey of Impact:
Wrapping Up Part 4 and Our Autism Prediction Blog!
What a triumphant finale we’ve reached, my phenomenal viewers and students! We’ve just concluded Part 4 of our "Autism Prediction Classification Project", and I’m filled with pride as we close this incredible blog series. In Part 4, we elevated our tuned Random Forest model—already a star with 92.97% accuracy—to new heights. We optimized its hyperparameters, achieved an ROC AUC of 0.99, interpreted its predictions with SHAP (highlighting `sum_score` as the key driver), explored threshold tuning to push recall toward 90%, serialized the model for future use, deployed it as a FastAPI application, and added monitoring with Prometheus to ensure production reliability. Every step has been a testament to our mission: creating a robust, interpretable, and actionable tool for autism spectrum disorder (ASD) prediction.
Reflecting on Our Blog Journey: From Vision to Victory
Looking back over our entire blog series, we’ve built something truly meaningful. In Part 1, we laid the foundation by exploring our dataset and balancing Class/ASD for fairness. Part 2 brought feature engineering to life with log-transformed age and label encoding, setting the stage for robust predictions. Part 3 was a powerhouse of EDA and modeling—we uncovered a 0.97 correlation with sum_score, trained nine models (Random Forest leading the pack), and validated performance with confusion matrices and cross-validation. Finally, Part 4 turned our model into a real-world solution, ready for clinicians to screen for ASD with confidence. Together, we’ve not only mastered machine learning techniques but also crafted a tool that can change lives by enabling early autism diagnosis—a mission driven by compassion and innovation.
A Call to Action:
Let’s Keep the Impact Growing!
This isn’t the end—it’s a new beginning! Our API, now live and monitored, is ready to support healthcare providers in Lahore and beyond. I invite you to keep exploring on our YouTube channel, www.youtube.com/@cognitutorai—subscribe, hit the notification bell, and join our community of compassionate coders. Share this project with others, test the API, and let’s collaborate to refine it further. What was your favorite moment—the SHAP insights, the 90% recall threshold, or the FastAPI deployment?
Drop your thoughts in the comments, and let me know how you’ll use this tool to make a difference.
Together, we’ve turned code into care—here’s to many more impactful journeys ahead!