# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "ucimlrepo",
#   "pandas",
#   "requests",
# ]
# ///
import argparse
import os
import sys
import time
import json
import pandas as pd
import requests
from ucimlrepo import fetch_ucirepo
# Defaults
DEFAULT_BASE_URL = "https://beta.woodwide.ai"
UCI_DATASET_ID = 45  # Heart Disease dataset
def setup_args():
    """Parse command-line options for the Woodwide heart-disease demo.

    Returns:
        argparse.Namespace with ``api_key``, ``model_name``,
        ``dataset_name``, ``output_file`` and ``base_url`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Test Woodwide API with UCI Heart Disease dataset for Clustering and Anomaly Detection"
    )
    # The API key is the only mandatory argument.
    parser.add_argument("-k", "--api-key", required=True, help="Woodwide API Key")
    # Optional flags share the same shape, so declare them as a table.
    optional_specs = (
        ("-m", "--model-name", "heart_disease", "Base name for the models"),
        ("-d", "--dataset-name", "heart_disease", "Name for the dataset"),
        ("-o", "--output-file", "results.json", "File path to save results"),
    )
    for short_flag, long_flag, default_value, help_text in optional_specs:
        parser.add_argument(short_flag, long_flag, default=default_value, help=help_text)
    parser.add_argument(
        "--base-url", default=DEFAULT_BASE_URL, help="Base URL for API"
    )
    return parser.parse_args()
def fetch_and_prepare_data():
    """Download the UCI Heart Disease dataset and write it to a local CSV.

    Returns:
        str: path of the CSV file holding features plus target column(s).
    """
    print(f"Fetching UCI dataset ID={UCI_DATASET_ID} (Heart Disease)...")
    repo = fetch_ucirepo(id=UCI_DATASET_ID)
    # Stitch the feature matrix and the target column(s) back into one table.
    combined = pd.concat([repo.data.features, repo.data.targets], axis=1)
    csv_path = "heart_disease.csv"
    combined.to_csv(csv_path, index=False)
    print(f"Data prepared and saved to {csv_path}. Shape: {combined.shape}")
    return csv_path
def main():
    """Run the end-to-end demo against the Woodwide API.

    Workflow: fetch the UCI heart-disease data, upload it, train a
    clustering model, segment the patients, then train and run an
    anomaly-detection model on the largest cluster.  Results are written
    as JSON to ``--output-file``.  Exits with status 1 on any API error.
    """
    args = setup_args()
    base_url = args.base_url.rstrip("/")
    headers = {
        "Authorization": f"Bearer {args.api_key}",
    }
    # 1. Fetch Data
    dataset_path = fetch_and_prepare_data()
    try:
        # 2. Upload Dataset
        print(f"Uploading {dataset_path} as '{args.dataset_name}'...")
        start_time = time.time()
        with open(dataset_path, "rb") as f:
            files = {"file": (os.path.basename(dataset_path), f, "text/csv")}
            # "overrides": "true" presumably asks the API to replace a
            # same-named dataset — TODO confirm against API docs.
            data = {"dataset_name": args.dataset_name, "overrides": "true"}
            response = requests.post(f"{base_url}/api/datasets", headers=headers, files=files, data=data)

        if response.status_code not in (200, 201):
            print(f"Error uploading dataset: {response.status_code}\n{response.text}")
            sys.exit(1)

        dataset_id = response.json().get("dataset_id")
        print(f"Upload took {time.time() - start_time:.2f}s. ID: {dataset_id}\n")
        # 3. Train Clustering Model
        clustering_model_name = f"{args.model_name}_clustering"
        print(f"Step 1: Training Clustering Model '{clustering_model_name}' to identify patient groups...")

        start_time = time.time()
        train_payload = {
            "model_name": clustering_model_name,
            "model_type": "clustering",
            "dataset_id": dataset_id
        }
        response = requests.post(
            f"{base_url}/api/models/train",
            json=train_payload,
            headers=headers,
        )
        if response.status_code not in (200, 202):
            print(f"Error starting clustering training: {response.status_code}\n{response.text}")
            sys.exit(1)
        clustering_model_id = response.json().get("model_id")
        print(f"Clustering training started. ID: {clustering_model_id}")
        # Wait for Clustering: poll every 2s until status is terminal.
        # NOTE(review): no overall timeout — this loop hangs forever if the
        # model never reaches "ready" or "failed".
        while True:
            response = requests.get(f"{base_url}/api/models/{clustering_model_id}", headers=headers)
            model_data = response.json()
            status = model_data.get("status")
            if status == "ready":
                print(f"Clustering complete! (Took {time.time() - start_time:.2f}s)\n")
                break
            elif status == "failed":
                print(f"Error: Clustering failed.\n{response.text}")
                sys.exit(1)
            time.sleep(2)
        # 4. Run Clustering Inference (Synchronous)
        print(f"Running clustering inference to segment patients...")
        with open(dataset_path, "rb") as f:
            files = {"file": (os.path.basename(dataset_path), f, "text/csv")}
            data = {"output_type": "json"}
            response = requests.post(
                f"{base_url}/api/models/{clustering_model_id}/infer",
                headers=headers,
                files=files,
                data=data
            )

        if response.status_code != 200:
            print(f"Error running clustering inference: {response.status_code}\n{response.text}")
            sys.exit(1)

        infer_result = response.json()
        clusters_data = infer_result.get("output", {})

        # Extract cluster labels.  The response appears to map row-index
        # strings ("0", "1", ...) to cluster labels — TODO confirm schema.
        cluster_labels = clusters_data.get("cluster_label", {})
        cluster_descriptions = clusters_data.get("cluster_descriptions", {})

        if cluster_labels:
            # Re-sort by numeric index so labels line up with CSV row order.
            sorted_indices = sorted([int(k) for k in cluster_labels.keys()])
            clusters_list = [cluster_labels[str(i)] for i in sorted_indices]
        else:
            print("Error: Could not find cluster labels in response.")
            sys.exit(1)
        # 5. Filter for the largest cluster for targeted anomaly detection
        # NOTE(review): mid-function import — conventionally this belongs at
        # the top of the file with the other imports.
        from collections import Counter
        cluster_counts = Counter(clusters_list)
        # max over Counter keys with key=get → the most populous cluster.
        target_cluster = max(cluster_counts, key=cluster_counts.get)
        print(f"Largest cluster identified: Cluster {target_cluster} with {cluster_counts[target_cluster]} patients.")
        target_cluster_desc = cluster_descriptions.get(str(target_cluster), "No description available.")
        print(f"Filtering patients in Cluster {target_cluster} for targeted anomaly detection...")
        print(f"Cluster Description: {target_cluster_desc}")

        # Load original data to filter.  Assumes clusters_list rows are in
        # the same order as the uploaded CSV — TODO confirm the infer
        # endpoint preserves row order.
        df = pd.read_csv(dataset_path)
        df['cluster'] = clusters_list
        cluster_df = df[df['cluster'] == target_cluster].drop(columns=['cluster'])

        cluster_dataset_path = f"heart_disease_cluster_{target_cluster}.csv"
        cluster_df.to_csv(cluster_dataset_path, index=False)
        cluster_dataset_name = f"{args.dataset_name}_cluster_{target_cluster}"

        print(f"Uploading Cluster {target_cluster} data ({len(cluster_df)} patients)...")
        with open(cluster_dataset_path, "rb") as f:
            files = {"file": (os.path.basename(cluster_dataset_path), f, "text/csv")}
            data = {"dataset_name": cluster_dataset_name, "overrides": "true"}
            response = requests.post(f"{base_url}/api/datasets", headers=headers, files=files, data=data)

        # NOTE(review): unlike the first upload, the status code is not
        # checked here — on an error response .json() may raise or
        # dataset_id may silently be None.
        cluster_dataset_id = response.json().get("dataset_id")
        # 6. Train Anomaly Detection Model on the specific cluster
        anomaly_model_name = f"{args.model_name}_cluster_{target_cluster}_anomaly"
        print(f"Step 2: Training Anomaly Detection Model '{anomaly_model_name}' for Cluster {target_cluster}...")

        start_time = time.time()
        train_payload = {
            "model_name": anomaly_model_name,
            "model_type": "anomaly",
            "dataset_id": cluster_dataset_id
        }
        response = requests.post(
            f"{base_url}/api/models/train",
            json=train_payload,
            headers=headers,
        )
        if response.status_code not in (200, 202):
            print(f"Error starting anomaly training: {response.status_code}\n{response.text}")
            sys.exit(1)
        anomaly_model_id = response.json().get("model_id")
        print(f"Anomaly detection training started. ID: {anomaly_model_id}")
        # Wait for Anomaly Detection (same 2s poll; same no-timeout caveat).
        while True:
            response = requests.get(f"{base_url}/api/models/{anomaly_model_id}", headers=headers)
            status = response.json().get("status")
            if status == "ready":
                print(f"Anomaly detection complete! (Took {time.time() - start_time:.2f}s)\n")
                break
            elif status == "failed":
                print(f"Error: Anomaly detection failed.\n{response.text}")
                sys.exit(1)
            time.sleep(2)
        # 7. Run Anomaly Inference (Synchronous)
        print(f"Running anomaly detection inference on Cluster {target_cluster}...")
        start_time = time.time()

        with open(cluster_dataset_path, "rb") as f:
            files = {"file": (os.path.basename(cluster_dataset_path), f, "text/csv")}
            data = {"output_type": "json"}
            response = requests.post(
                f"{base_url}/api/models/{anomaly_model_id}/infer",
                headers=headers,
                files=files,
                data=data
            )

        if response.status_code != 200:
            print(f"Error running anomaly inference: {response.status_code}\n{response.text}")
            sys.exit(1)
        anomaly_result = response.json()
        anomalies_data = anomaly_result.get("output", {})
        anomalous_ids = anomalies_data.get("anomalous_ids", [])

        print(f"Anomaly detection took {time.time() - start_time:.2f}s")
        # 8. Combine and Save Results
        final_results = {
            "target_cluster": target_cluster,
            "target_cluster_description": target_cluster_desc,
            "clustering_model_id": clustering_model_id,
            "anomaly_model_id": anomaly_model_id,
            "cluster_size": len(cluster_df),
            "anomalous_ids": anomalous_ids
        }
        # 9. Extract details for anomalous patients
        anomalous_details = []
        if anomalous_ids:
            # Restrict to columns that actually exist in the cluster frame.
            relevant_cols = ['age', 'sex', 'cp', 'trestbps', 'chol', 'thalach', 'num']
            display_cols = [c for c in relevant_cols if c in cluster_df.columns]

            # Sample at most 5 anomalies.  Assumes each id is a positional
            # row index into cluster_df (iloc) — TODO confirm the API
            # returns 0-based positions of the uploaded cluster CSV.
            for aid in anomalous_ids[:5]:
                idx = int(aid)
                if idx < len(cluster_df):
                    patient_data = cluster_df.iloc[idx][display_cols].to_dict()
                    patient_data['id'] = aid
                    anomalous_details.append(patient_data)
        final_results["anomalous_details_sample"] = anomalous_details
        with open(args.output_file, "w") as f:
            json.dump(final_results, f, indent=2)
        print(f"Results saved to {args.output_file}")
        print("\n" + "="*120)
        print(f"ANALYSIS SUMMARY FOR CLUSTER {target_cluster}")
        print(f"Cluster Description: {target_cluster_desc}")
        print("="*120)
        print(f"Identified {final_results['cluster_size']} patients in this group.")
        print(f"Found {len(anomalous_ids)} anomalous cases.")

        if anomalous_details:
            print("\nSample of Anomalous Patients:")
            # Fixed-width (7-char) columns for a simple aligned table.
            cols_to_print = ['id', 'age', 'sex', 'cp', 'trestbps', 'chol', 'thalach', 'num']
            header = " | ".join([f"{c.upper():<7}" for c in cols_to_print])
            print(header)
            print("-" * len(header))
            for p in anomalous_details:
                row = [f"{str(p.get(c, '')):<7}" for c in cols_to_print]
                print(" | ".join(row))
        print("="*120 + "\n")
    finally:
        # Best-effort cleanup of the temporary CSVs; cluster_dataset_path
        # only exists if clustering succeeded, hence the locals() guard.
        if os.path.exists(dataset_path): os.remove(dataset_path)
        if 'cluster_dataset_path' in locals() and os.path.exists(cluster_dataset_path):
            os.remove(cluster_dataset_path)
# Standard script entry guard: only run the workflow when executed directly.
if __name__ == "__main__":
    main()