# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "ucimlrepo",
#   "pandas",
#   "requests",
# ]
# ///
import argparse
import os
import sys
import time
import json
import pandas as pd
import requests
from ucimlrepo import fetch_ucirepo

# --- Defaults ---------------------------------------------------------------
# Woodwide API root used when --base-url is not supplied on the command line.
DEFAULT_BASE_URL: str = "https://beta.woodwide.ai/"
# UCI ML Repository identifier of the Heart Disease dataset.
UCI_DATASET_ID: int = 45


def setup_args(argv=None):
    """Parse command-line options for the Woodwide clustering/anomaly demo.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``
            (useful for tests or programmatic use). ``None`` keeps the
            previous behaviour of reading the real command line.

    Returns:
        An ``argparse.Namespace`` with ``api_key``, ``model_name``,
        ``dataset_name``, ``output_file`` and ``base_url`` attributes.

    The API key may be omitted from the command line when the
    ``WOODWIDE_API_KEY`` environment variable is set; this keeps the secret
    out of shell history and process listings. An explicit ``-k`` still wins.
    """
    env_api_key = os.environ.get("WOODWIDE_API_KEY")
    parser = argparse.ArgumentParser(
        description="Test Woodwide API with UCI Heart Disease dataset for Clustering and Anomaly Detection"
    )
    parser.add_argument(
        "-k",
        "--api-key",
        default=env_api_key,
        # Only mandatory when no key is available from the environment.
        required=not env_api_key,
        help="Woodwide API Key (defaults to $WOODWIDE_API_KEY if set)",
    )
    parser.add_argument(
        "-m", "--model-name", default="heart_disease", help="Base name for the models"
    )
    parser.add_argument(
        "-d", "--dataset-name", default="heart_disease", help="Name for the dataset"
    )
    parser.add_argument(
        "-o", "--output-file", default="results.json", help="File path to save results"
    )
    parser.add_argument(
        "--base-url", default=DEFAULT_BASE_URL, help="Base URL for API"
    )
    return parser.parse_args(argv)


def fetch_and_prepare_data(dataset_path="heart_disease.csv", max_rows=None):
    """Download the UCI Heart Disease dataset and save features + targets as one CSV.

    Args:
        dataset_path: Where to write the CSV. Defaults to ``heart_disease.csv``
            in the current working directory; an existing file is overwritten.
        max_rows: If given, keep only the first ``max_rows`` rows, e.g. when
            the upload endpoint limits file size. ``None`` keeps every row.

    Returns:
        ``dataset_path``, the path of the CSV that was written.

    Raises:
        ValueError: if ``max_rows`` is negative.
    """
    if max_rows is not None and max_rows < 0:
        raise ValueError(f"max_rows must be non-negative, got {max_rows}")

    print(f"Fetching UCI dataset ID={UCI_DATASET_ID} (Heart Disease)...")
    dataset = fetch_ucirepo(id=UCI_DATASET_ID)

    features = dataset.data.features
    targets = dataset.data.targets
    if max_rows is not None:
        features = features.iloc[:max_rows]
        targets = targets.iloc[:max_rows]

    # Column-wise join: one row per patient, target column(s) placed last.
    df = pd.concat([features, targets], axis=1)
    df.to_csv(dataset_path, index=False)

    print(f"Data prepared and saved to {dataset_path}. Shape: {df.shape}")
    return dataset_path


# Per-request HTTP timeout as (connect, read) seconds; without one, requests may block forever.
_HTTP_TIMEOUT = (10, 600)
# Maximum seconds to wait for a training job, and the delay between status polls.
_TRAINING_TIMEOUT_S = 3000
_POLL_INTERVAL_S = 2
# Columns copied into each anomalous patient's detail record (kept only if present).
_DETAIL_COLUMNS = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'num']
# Columns shown in the printed summary table.
_SUMMARY_COLUMNS = ['id', 'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'num']
# Number of anomalous patients that get a detail record.
_DETAIL_SAMPLE_SIZE = 5


def _fail(message):
    """Print *message* and terminate the script with exit status 1."""
    print(message)
    sys.exit(1)


def _find_dataset_id(base_url, headers, name):
    """Return the id of the dataset called *name* via GET /api/datasets, or None."""
    response = requests.get(f"{base_url}/api/datasets", headers=headers, timeout=_HTTP_TIMEOUT)
    if response.status_code != 200:
        return None
    datasets = response.json()
    if not isinstance(datasets, list):
        return None
    return next(
        (d.get("id") for d in datasets if isinstance(d, dict) and d.get("name") == name),
        None,
    )


def _upload_dataset(base_url, headers, path, name):
    """Upload the CSV at *path* as dataset *name* (overwriting) and return its id.

    Exits the script on a non-200 response. If the response carries no id, the
    dataset is looked up by name; the result may still be None.
    """
    with open(path, "rb") as f:
        files = {"file": (os.path.basename(path), f, "text/csv")}
        data = {"name": name, "overwrite": "true"}
        response = requests.post(
            f"{base_url}/api/datasets", headers=headers, files=files, data=data, timeout=_HTTP_TIMEOUT
        )
    if response.status_code != 200:
        _fail(f"Error uploading dataset: {response.status_code}\n{response.text}")
    dataset_id = response.json().get("id")
    if dataset_id is None:
        dataset_id = _find_dataset_id(base_url, headers, name)
    return dataset_id


def _start_training(base_url, headers, kind, dataset_name, model_name):
    """Start training a *kind* ("clustering" / "anomaly") model and return its id.

    Exits the script on a non-200 response or when no model id is returned.
    """
    response = requests.post(
        f"{base_url}/api/models/{kind}/train",
        params={"dataset_name": dataset_name},
        data={"model_name": model_name, "overwrite": "true"},
        headers=headers,
        timeout=_HTTP_TIMEOUT,
    )
    if response.status_code != 200:
        _fail(f"Error starting {kind} training: {response.status_code}\n{response.text}")
    model_id = response.json().get("id")
    if model_id is None:
        _fail(f"Error starting {kind} training: response contained no model id.\n{response.text}")
    return model_id


def _wait_for_training(base_url, headers, model_id, label, start_time,
                       timeout=_TRAINING_TIMEOUT_S, poll_interval=_POLL_INTERVAL_S):
    """Poll GET /api/models/{model_id} until its training_status is COMPLETE.

    Exits the script when the status is FAILED, or when *timeout* seconds pass
    without completion. A non-200 poll is reported and retried until the deadline.
    *start_time* (a ``time.time()`` value) is only used for the elapsed-time message.
    """
    deadline = time.monotonic() + timeout
    while True:
        response = requests.get(f"{base_url}/api/models/{model_id}", headers=headers, timeout=_HTTP_TIMEOUT)
        if response.status_code == 200:
            status = response.json().get("training_status")
        else:
            status = None
            print(f"Warning: {label} status poll returned {response.status_code}; retrying...")
        if status == "COMPLETE":
            print(f"{label} complete! (Took {time.time() - start_time:.2f}s)\n")
            return
        if status == "FAILED":
            _fail(f"Error: {label} failed.\n{response.text}")
        if time.monotonic() >= deadline:
            _fail(f"Error: {label} did not complete within {timeout}s (last status: {status}).")
        time.sleep(poll_interval)


def _run_inference(base_url, headers, kind, model_id, dataset_id):
    """Run *kind* inference for *model_id* on *dataset_id* and return the decoded JSON.

    The streamed body is read fully before decoding. Exits the script on a
    non-200 response or when the body is not valid (UTF-8) JSON.
    """
    with requests.post(
        f"{base_url}/api/models/{kind}/{model_id}/infer",
        params={"dataset_id": dataset_id},
        headers=headers,
        stream=True,
        timeout=_HTTP_TIMEOUT,
    ) as response:
        if response.status_code != 200:
            _fail(f"Error running {kind} inference: {response.status_code}\n{response.text}")
        raw = b"".join(chunk for chunk in response.iter_content(chunk_size=None) if chunk)
    try:
        return json.loads(raw)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        print(f"Error decoding {kind} JSON: {e}")
        print(f"Raw response start: {raw[:500].decode('utf-8', errors='replace')}")
        sys.exit(1)


def _extract_cluster_labels(clusters):
    """Split a decoded clustering payload into ``(labels, descriptions)``.

    Accepted shapes:
      * a plain list of labels;
      * a dict whose "cluster_label" is a list of labels or a mapping of
        row index -> label (JSON object keys arrive as strings and are ordered
        numerically), with an optional "cluster_descriptions" dict.

    ``labels`` is None when the shape is not recognised; ``descriptions`` is
    always a dict (possibly empty).
    """
    descriptions = {}
    labels = clusters
    if isinstance(clusters, dict):
        raw_descriptions = clusters.get("cluster_descriptions")
        if isinstance(raw_descriptions, dict):
            descriptions = raw_descriptions
        raw_labels = clusters.get("cluster_label")
        if isinstance(raw_labels, list) and raw_labels:
            labels = raw_labels
        elif isinstance(raw_labels, dict) and raw_labels:
            try:
                order = sorted(raw_labels, key=int)
            except (TypeError, ValueError):
                order = None
            if order is not None:
                labels = [raw_labels[k] for k in order]
    return (labels if isinstance(labels, list) else None), descriptions


def _anomalous_details(df, anomalous_ids, limit=_DETAIL_SAMPLE_SIZE):
    """Return detail dicts for up to the first *limit* ids in *anomalous_ids*.

    Each id is treated as a 0-based row position in *df* (assumed to match the
    row order of the CSV that was uploaded for anomaly detection — TODO confirm
    against the API contract). Ids that are not integers or are out of range
    are skipped rather than aborting the whole sample.
    """
    display_cols = [c for c in _DETAIL_COLUMNS if c in df.columns]
    details = []
    for aid in anomalous_ids[:limit]:
        try:
            idx = int(aid)
        except (TypeError, ValueError):
            continue
        if 0 <= idx < len(df):
            patient = df.iloc[idx][display_cols].to_dict()
            patient['id'] = aid
            details.append(patient)
    return details


def _print_summary(results):
    """Print a human-readable report of the *results* dict built by main()."""
    target_cluster = results["target_cluster"]
    anomalous_ids = results["anomalous_ids"]
    details = results["anomalous_details_sample"]

    print("\n" + "=" * 120)
    print(f"ANALYSIS SUMMARY FOR CLUSTER {target_cluster}")
    print(f"Cluster {target_cluster} Description: {results['target_cluster_description']}")
    print("=" * 120)
    print(f"We identified a group of {results['cluster_size']} patients with similar clinical profiles.")
    print(f"Within this specific group, we found {len(anomalous_ids)} anomalous cases.")

    if details:
        print("\nSample of Anomalous Patients (Full Clinical Metrics):")
        header = " | ".join(f"{c.upper():<7}" for c in _SUMMARY_COLUMNS)
        print(header)
        print("-" * len(header))
        for patient in details:
            cells = []
            for col in _SUMMARY_COLUMNS:
                val = patient.get(col, '')
                cells.append(f"{val:<7.1f}" if isinstance(val, float) else f"{str(val):<7}")
            print(" | ".join(cells))
        if len(anomalous_ids) > _DETAIL_SAMPLE_SIZE:
            print(f"\n... and {len(anomalous_ids) - _DETAIL_SAMPLE_SIZE} more anomalous cases.")
    elif anomalous_ids:
        # Ids may be integers, so stringify before joining.
        shown = ', '.join(str(a) for a in anomalous_ids[:10])
        print(f"\nFound {len(anomalous_ids)} anomalous patient IDs: {shown}")

    print("=" * 120 + "\n")


def main():
    """Run the clustering -> targeted anomaly-detection demo end to end.

    Steps: fetch and upload the UCI Heart Disease data, train and run a
    clustering model, choose the largest cluster, upload that subset, train and
    run an anomaly-detection model on it, write a JSON summary to
    ``--output-file`` and print a short report. API errors terminate the
    process with exit status 1; the temporary per-cluster CSV is always removed.
    """
    from collections import Counter

    args = setup_args()
    base_url = args.base_url.rstrip("/")
    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "accept": "application/json",
    }

    # 1. Fetch Data
    dataset_path = fetch_and_prepare_data()
    # Temporary per-cluster CSV; set only once it is written, removed in `finally`.
    cluster_dataset_path = None

    try:
        # 2. Upload Dataset
        print(f"Uploading {dataset_path} as '{args.dataset_name}'...")
        start_time = time.time()
        dataset_id = _upload_dataset(base_url, headers, dataset_path, args.dataset_name)
        if dataset_id is None:
            _fail(f"Error: could not determine the id of dataset '{args.dataset_name}'.")
        print(f"Upload took {time.time() - start_time:.2f}s. ID: {dataset_id}\n")

        # 3. Train Clustering Model
        clustering_model_name = f"{args.model_name}_clustering"
        print(f"Step 1: Training Clustering Model '{clustering_model_name}' to identify patient groups...")
        start_time = time.time()
        clustering_model_id = _start_training(base_url, headers, "clustering", args.dataset_name, clustering_model_name)
        print(f"Clustering training started. ID: {clustering_model_id}")
        _wait_for_training(base_url, headers, clustering_model_id, "Clustering", start_time)

        # 4. Run Clustering Inference to get clusters
        print("Running clustering inference to segment patients...")
        clusters = _run_inference(base_url, headers, "clustering", clustering_model_id, dataset_id)
        clusters_list, cluster_descriptions = _extract_cluster_labels(clusters)

        # 5. Pick the largest cluster for targeted anomaly detection
        cluster_counts = Counter(clusters_list or [])
        if cluster_counts:
            target_cluster, target_count = cluster_counts.most_common(1)[0]
            print(f"Largest cluster identified: Cluster {target_cluster} with {target_count} patients.")
        else:
            target_cluster = 0
            print(f"Warning: Could not determine cluster counts. Defaulting to Cluster {target_cluster}.")

        target_cluster_desc = cluster_descriptions.get(str(target_cluster), "No description available.")
        print(f"Filtering patients in Cluster {target_cluster} for targeted anomaly detection...")
        print(f"Cluster Description: {target_cluster_desc}")

        df = pd.read_csv(dataset_path)
        # The data actually used for anomaly detection; starts as the full upload.
        analysis_df = df
        analysis_name = args.dataset_name
        analysis_id = dataset_id
        cluster_size = "unknown"

        # Assumes labels are in the same row order as the uploaded CSV.
        if clusters_list is not None and len(clusters_list) == len(df):
            cluster_df = df[pd.Series(clusters_list, index=df.index) == target_cluster]
            cluster_size = len(cluster_df)
            if cluster_df.empty:
                print(f"Warning: Cluster {target_cluster} is empty. Using full dataset instead.")
            else:
                cluster_dataset_path = f"heart_disease_cluster_{target_cluster}.csv"
                cluster_df.to_csv(cluster_dataset_path, index=False)
                analysis_name = f"{args.dataset_name}_cluster_{target_cluster}"
                print(f"Uploading Cluster {target_cluster} data ({len(cluster_df)} patients)...")
                analysis_id = _upload_dataset(base_url, headers, cluster_dataset_path, analysis_name)
                if analysis_id is None:
                    _fail(f"Error: could not determine the id of dataset '{analysis_name}'.")
                analysis_df = cluster_df
        else:
            print("Warning: Could not map clusters to rows. Using full dataset for anomaly detection.")

        # 6. Train Anomaly Detection Model on the specific cluster
        anomaly_model_name = f"{args.model_name}_cluster_{target_cluster}_anomaly"
        print(f"Step 2: Training Anomaly Detection Model '{anomaly_model_name}' for Cluster {target_cluster}...")
        start_time = time.time()
        anomaly_model_id = _start_training(base_url, headers, "anomaly", analysis_name, anomaly_model_name)
        print(f"Anomaly detection training started. ID: {anomaly_model_id}")
        _wait_for_training(base_url, headers, anomaly_model_id, "Anomaly detection", start_time)

        # 7. Run Anomaly Inference
        print(f"Running anomaly detection inference on Cluster {target_cluster}...")
        start_time = time.time()
        anomalies = _run_inference(base_url, headers, "anomaly", anomaly_model_id, analysis_id)
        print(f"Anomaly detection took {time.time() - start_time:.2f}s")

        if not isinstance(anomalies, dict):
            _fail(f"Error: unexpected anomaly response: {str(anomalies)[:200]}")
        anomalous_ids = anomalies.get("anomalous_ids") or []

        # 8./9. Combine results and attach details for a sample of anomalous patients
        final_results = {
            "target_cluster": target_cluster,
            "target_cluster_description": target_cluster_desc,
            "clustering_model_id": clustering_model_id,
            "anomaly_model_id": anomaly_model_id,
            "cluster_size": cluster_size,
            "anomalous_ids": anomalous_ids,
            "anomalous_details_sample": _anomalous_details(analysis_df, anomalous_ids),
        }

        formatted_result = json.dumps(final_results, indent=2)
        with open(args.output_file, "w", encoding="utf-8") as f:
            f.write(formatted_result)
        print(f"Results saved to {args.output_file}")

        _print_summary(final_results)

    finally:
        if (
            cluster_dataset_path is not None
            and cluster_dataset_path != dataset_path
            and os.path.exists(cluster_dataset_path)
        ):
            os.remove(cluster_dataset_path)



if __name__ == "__main__":
    # Entry point when run as a script; importing the module does nothing.
    # main() returns None, so the process exits with status 0 unless main()
    # itself calls sys.exit with a different code.
    sys.exit(main())