#!/bin/bash
# GPU-Optimized Installation Script for FoundationsAI
# Supports: RTX 4090, RTX 5090, RTX PRO 6000 Blackwell, and other NVIDIA GPUs
#
# Preconditions: run inside an activated virtual environment, with
# NVIDIA drivers installed (nvidia-smi available).

set -e  # Exit on error

echo "========================================"
echo "GPU-Optimized Setup for FoundationsAI"
echo "========================================"
echo ""

# Guard python3 explicitly: under 'set -e' a missing binary inside the
# command substitution below would abort with a cryptic error.
if ! command -v python3 &> /dev/null; then
    echo "ERROR: python3 not found on PATH"
    exit 1
fi

# Check Python version
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
echo "Python version: $PYTHON_VERSION"

# Check if running in virtual environment.
# ${VIRTUAL_ENV:-} keeps this test safe even if 'set -u' is ever enabled.
if [ -z "${VIRTUAL_ENV:-}" ]; then
    echo "ERROR: Please activate your virtual environment first!"
    echo "Run: source venv/bin/activate"
    exit 1
fi

echo ""
echo "Step 1: Detecting GPU..."
if command -v nvidia-smi &> /dev/null; then
    # Query the first GPU only (head -n 1) on multi-GPU systems.
    GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -n 1)
    GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -n 1 | awk '{print $1}')
    COMPUTE_CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1)

    echo "Detected GPU: $GPU_NAME"
    echo "VRAM: ${GPU_MEMORY}MB"
    echo "Compute Capability: $COMPUTE_CAP"
    echo ""

    # Determine GPU type and installation strategy
    if [[ "$GPU_NAME" == *"4090"* ]]; then
        GPU_TYPE="RTX4090"
        ARCHITECTURE="Ada Lovelace (sm_89)"
        CUDA_VERSION="cu121"
        PYTORCH_INDEX="https://download.pytorch.org/whl/cu121"
        RECOMMENDED_MODEL="llama3.1:8b"
        echo "🎮 Detected: RTX 4090"
        echo "Architecture: $ARCHITECTURE"
        echo "Recommended model: $RECOMMENDED_MODEL"
    # "12."* already matches "12.0", so a single pattern suffices here.
    elif [[ "$GPU_NAME" == *"RTX PRO 6000"* ]] || [[ "$COMPUTE_CAP" == "12."* ]]; then
        GPU_TYPE="BLACKWELL"
        ARCHITECTURE="Blackwell (sm_120)"
        CUDA_VERSION="cu128"
        PYTORCH_INDEX="https://download.pytorch.org/whl/cu128"
        RECOMMENDED_MODEL="llama3.1:8b or llama3.3:70b"
        echo "🚀 Detected: Blackwell GPU"
        echo "Architecture: $ARCHITECTURE"
        echo "Recommended models: $RECOMMENDED_MODEL"
    else
        GPU_TYPE="GENERIC"
        # sm_ identifiers carry no dot (compute cap 8.6 -> sm_86), so
        # strip it instead of printing e.g. "sm_8.6".
        ARCHITECTURE="Unknown (sm_${COMPUTE_CAP/./})"
        CUDA_VERSION="cu121"
        PYTORCH_INDEX="https://download.pytorch.org/whl/cu121"
        RECOMMENDED_MODEL="llama3.1:8b or llama3.2:3b"
        echo "✓ Generic NVIDIA GPU"
        echo "Architecture: $ARCHITECTURE"
        echo "Recommended models: $RECOMMENDED_MODEL"
    fi
else
    echo "ERROR: nvidia-smi not found"
    echo "Please install NVIDIA drivers first"
    exit 1
fi

echo ""
echo "Step 2: Checking CUDA version..."
if command -v nvcc &> /dev/null; then
    # NOTE(review): this parse assumes the usual "Cuda compilation tools,
    # release X.Y, VX.Y.Z" line from nvcc — confirm against target toolkits.
    CUDA_INSTALLED=$(nvcc --version | grep "release" | awk '{print $5}' | cut -d',' -f1)
    echo "CUDA toolkit version: $CUDA_INSTALLED"

    # Extract major.minor with parameter expansion: no subshells and no
    # word-splitting of the previously unquoted variable.
    CUDA_MAJOR=${CUDA_INSTALLED%%.*}
    CUDA_MINOR=${CUDA_INSTALLED#*.}
    CUDA_MINOR=${CUDA_MINOR%%.*}

    # Guard against an unparseable version string: non-numeric operands
    # would make the -lt / -eq tests below fail with an error.
    if [[ "$CUDA_MAJOR" =~ ^[0-9]+$ && "$CUDA_MINOR" =~ ^[0-9]+$ ]]; then
        if [ "$GPU_TYPE" == "BLACKWELL" ]; then
            # { ...; } groups in the current shell; no subshell needed.
            if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
                echo "WARNING: CUDA 12.8+ recommended for Blackwell GPUs"
                echo "Current version: $CUDA_INSTALLED"
                echo "Consider upgrading CUDA toolkit for best performance"
            fi
        elif [ "$GPU_TYPE" == "RTX4090" ]; then
            if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 1 ]; }; then
                echo "WARNING: CUDA 12.1+ recommended for RTX 4090"
                echo "Current version: $CUDA_INSTALLED"
            fi
        fi
    else
        echo "WARNING: could not parse CUDA version from nvcc output: $CUDA_INSTALLED"
    fi
else
    echo "WARNING: nvcc not found - CUDA toolkit may not be installed"
    echo "Install CUDA toolkit from: https://developer.nvidia.com/cuda-downloads"
fi

echo ""
echo "Step 3: Uninstalling old PyTorch (if exists)..."
# '|| true' keeps 'set -e' from aborting when nothing is installed yet.
python3 -m pip uninstall -y torch torchvision torchaudio 2>/dev/null || true

echo ""
echo "Step 4: Installing PyTorch with $CUDA_VERSION for $GPU_TYPE..."
echo "Using PyTorch index: $PYTORCH_INDEX"
echo "This may take several minutes..."
# 'python3 -m pip' pins the venv interpreter; bare 'pip'/'pip3' can
# resolve to a different Python on some systems. Index URL is quoted.
python3 -m pip install torch torchvision torchaudio --index-url "$PYTORCH_INDEX"

echo ""
echo "Step 5: Verifying PyTorch installation..."
# Run the verification inside 'if !' so a failure reaches the
# troubleshooting branch below. With the previous 'python3 <<EOF' followed
# by '[ $? -ne 0 ]', 'set -e' exited the script on failure first, making
# the entire troubleshooting block dead code.
# 'EOF' is quoted: the embedded Python needs no shell expansion.
if ! python3 << 'EOF'
import sys

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU 0: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    print(f"Compute capability: {props.major}.{props.minor}")
    print(f"Memory: {props.total_memory / 1e9:.2f} GB")

    # Smoke-test an actual kernel launch, not just device discovery.
    # A failure here is reported but (as before) does not fail the step.
    try:
        x = torch.randn(1000, 1000).cuda()
        y = x @ x.T
        print("✓ GPU tensor operations working!")
    except Exception as e:
        print(f"✗ GPU test failed: {e}")
else:
    print("✗ CUDA not available!")
    sys.exit(1)  # sys.exit, not the interactive-only exit() builtin
EOF
then
    echo ""
    echo "ERROR: PyTorch verification failed!"
    echo ""
    echo "Troubleshooting:"

    if [ "$GPU_TYPE" == "BLACKWELL" ]; then
        echo "1. Ensure CUDA 12.8+ is installed for Blackwell:"
        echo "   https://developer.nvidia.com/cuda-downloads"
        echo ""
        echo "2. Check NVIDIA driver version:"
        echo "   nvidia-smi (should be 545+ for Blackwell)"
        echo ""
        echo "3. Try PyTorch nightly build:"
        echo "   pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128"
    elif [ "$GPU_TYPE" == "RTX4090" ]; then
        echo "1. Ensure CUDA 12.1+ is installed for RTX 4090:"
        echo "   https://developer.nvidia.com/cuda-downloads"
        echo ""
        echo "2. Check NVIDIA driver version:"
        echo "   nvidia-smi (should be 525+ for RTX 4090)"
        echo ""
        echo "3. Try PyTorch stable release:"
        echo "   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121"
    else
        echo "1. Ensure CUDA 12.1+ is installed:"
        echo "   https://developer.nvidia.com/cuda-downloads"
        echo ""
        echo "2. Check NVIDIA driver version:"
        echo "   nvidia-smi"
        echo ""
        echo "3. Try PyTorch stable release:"
        echo "   pip install torch torchvision torchaudio --index-url $PYTORCH_INDEX"
    fi

    exit 1
fi

echo ""
echo "Step 6: Installing remaining dependencies..."
# NOTE(review): the best-effort --no-deps pre-pass presumably avoids the
# resolver touching the GPU-specific torch just installed before the full
# resolve in the second pass — confirm intent. '|| true' makes it
# non-fatal under 'set -e'.
python3 -m pip install -r requirements.txt --no-deps || true
# 'python3 -m pip' keeps the interpreter consistent with Steps 3-5.
python3 -m pip install -r requirements.txt

echo ""
echo "Step 7: Final verification with check_gpu.py..."
# Under 'set -e' a non-zero exit here aborts before the summary below.
python3 check_gpu.py

echo ""
echo "========================================"
echo "✓ Installation complete!"
echo "========================================"
echo ""

# Per-GPU closing summary; GPU_TYPE/GPU_NAME/RECOMMENDED_MODEL were set
# in Step 1 (the script exits earlier if no GPU was detected).
if [ "$GPU_TYPE" == "BLACKWELL" ]; then
    echo "Your RTX PRO 6000 Blackwell GPU is ready to use!"
    echo ""
    echo "Recommended model: llama3.1:8b or llama3.3:70b"
elif [ "$GPU_TYPE" == "RTX4090" ]; then
    echo "Your RTX 4090 GPU is ready to use!"
    echo ""
    echo "Recommended model: llama3.1:8b (llama3.3:70b requires 40GB VRAM)"
else
    echo "Your $GPU_NAME is ready to use!"
    echo ""
    echo "Recommended model: $RECOMMENDED_MODEL"
fi

echo ""
echo "To start the application:"
echo "  python3 LLMAPI.py"
echo ""
echo "To monitor GPU usage:"
echo "  nvidia-smi -l 1"
echo ""