#!/bin/bash

# NVIDIA Container Toolkit Installation Script
# Source: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

# Strict mode: exit on any failing command (-e), on expansion of an unset
# variable (-u), and when ANY stage of a pipeline fails (pipefail). Without
# pipefail a pipeline's status is that of its last stage only, so e.g. a
# failed download piped into another command would go unnoticed.
set -euo pipefail

echo "=== NVIDIA Container Toolkit Installation Script ==="
echo ""

# The steps below install packages and reconfigure Docker, so root is required.
if (( EUID != 0 )); then
    echo "Please run this script with sudo: sudo ./cuda-container.sh" >&2
    exit 1
fi

# ============================================
# Pre-check: Verify NVIDIA driver is installed
# ============================================
echo "[Pre-check] Verifying NVIDIA driver..."

# Handle WSL2: nvidia-smi is in /usr/lib/wsl/lib/
if [ -f /usr/lib/wsl/lib/nvidia-smi ]; then
    export PATH="/usr/lib/wsl/lib:$PATH"
    echo "WSL2 detected, added /usr/lib/wsl/lib to PATH"
fi

if ! command -v nvidia-smi &> /dev/null; then
    echo "ERROR: nvidia-smi not found. Please install NVIDIA drivers first." >&2
    echo "       Visit: https://www.nvidia.com/Download/index.aspx" >&2
    exit 1
fi

# nvidia-smi may be on PATH yet fail to query the GPU. Report that explicitly
# instead of letting `set -e` abort with no context from this script.
if ! nvidia-smi --query-gpu=driver_version,name --format=csv,noheader; then
    echo "ERROR: nvidia-smi is installed but failed to query the GPU." >&2
    echo "       Check that the NVIDIA driver is loaded (a reboot may be needed after a driver install/upgrade)." >&2
    exit 1
fi
echo ""

# ============================================
# Pre-check: Verify Docker is installed
# ============================================
echo "[Pre-check] Verifying Docker installation..."
if ! command -v docker &> /dev/null; then
    echo "ERROR: Docker is not installed. Please run install-docker.sh first." >&2
    exit 1
fi

docker --version
echo ""

# ============================================
# STEP 1: Install prerequisites
# ============================================
echo "[1/4] Installing prerequisites..."

# Every install step in this script uses apt; fail early with a clear message
# on systems that do not provide it.
if ! command -v apt-get &> /dev/null; then
    echo "ERROR: apt-get not found. This script supports Debian/Ubuntu-based systems only." >&2
    exit 1
fi

apt-get update
apt-get install -y curl gnupg2

# ============================================
# STEP 2: Add NVIDIA Container Toolkit repository
# ============================================
echo ""
echo "[2/4] Adding NVIDIA Container Toolkit repository..."

readonly NVIDIA_REPO_BASE="https://nvidia.github.io/libnvidia-container"
readonly NVIDIA_KEYRING="/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg"

# Both downloads are captured into variables first. A failed `curl -f` inside
# `$(...)` makes the assignment fail, which `set -e` turns into an abort. In a
# plain `curl | ...` pipeline, the status of the last stage would hide the
# failure, and an empty or error-page body would be written to disk.

# Add GPG key: fetch the armored key and dearmor it into the keyring that the
# repository entry below references via signed-by=.
gpg_key=$(curl -fsSL "${NVIDIA_REPO_BASE}/gpgkey")
printf '%s\n' "$gpg_key" | gpg --dearmor -o "$NVIDIA_KEYRING" --yes

# Add repository: pin it to the keyring above by injecting signed-by=.
repo_list=$(curl -fsSL "${NVIDIA_REPO_BASE}/stable/deb/nvidia-container-toolkit.list")
if ! grep -q '^deb https://' <<<"$repo_list"; then
    echo "ERROR: Unexpected content in the NVIDIA repository list; refusing to write it." >&2
    exit 1
fi
printf '%s\n' "$repo_list" | \
    sed "s#deb https://#deb [signed-by=${NVIDIA_KEYRING}] https://#g" | \
    tee /etc/apt/sources.list.d/nvidia-container-toolkit.list > /dev/null

apt-get update

# ============================================
# STEP 3: Install NVIDIA Container Toolkit
# ============================================
# A blank separator line followed by the step banner, then the package
# install itself (--yes is the long form of -y: answer prompts automatically).
printf '\n%s\n' "[3/4] Installing NVIDIA Container Toolkit..."
apt-get install --yes nvidia-container-toolkit

# ============================================
# STEP 4: Configure Docker to use NVIDIA runtime
# ============================================
echo ""
echo "[4/4] Configuring Docker to use NVIDIA runtime..."
nvidia-ctk runtime configure --runtime=docker

# Restart the daemon so it picks up the new runtime.
# - systemctl only works when systemd is running; /run/systemd/system exists
#   in that case.
# - WSL2 distros without systemd, and other non-systemd setups, need a
#   different path. Try `service`, and otherwise warn rather than aborting
#   after the configuration has already been written.
if [ -d /run/systemd/system ]; then
    systemctl restart docker
elif command -v service &> /dev/null && service docker restart; then
    :
else
    echo "WARNING: Could not restart Docker automatically (systemd is not running)." >&2
    echo "         Restart the Docker daemon manually to apply the NVIDIA runtime configuration." >&2
fi

# ============================================
# DONE
# ============================================
echo ""
echo "=========================================="
echo "  NVIDIA Container Toolkit Installation Complete!"
echo "=========================================="
echo ""

# CUDA image tag suggested below. Override it by setting CUDA_IMAGE_VER in the
# script's environment, e.g. CUDA_IMAGE_VER=12.4.1.
CUDA_IMAGE_VER="${CUDA_IMAGE_VER:-12.9.0}"
CUDA_IMAGE_SERIES="${CUDA_IMAGE_VER%.*}"   # "12.9.0" -> "12.9"

# Detect the driver version and the max CUDA version it supports.
# - `|| true` keeps a failing nvidia-smi from aborting under `set -e`.
# - Empty results fall back to "unknown" via ${var:-...}. A trailing
#   `|| echo` inside the substitution would not work: the pipeline's status
#   comes from head/sed, which succeed even on empty input.
DRIVER_VER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -n 1) || true
CUDA_VER=$(nvidia-smi 2>/dev/null | sed -n 's/.*CUDA Version: *\([0-9][0-9.]*\).*/\1/p' | head -n 1) || true
DRIVER_VER="${DRIVER_VER:-unknown}"
CUDA_VER="${CUDA_VER:-unknown}"

echo "Detected NVIDIA Driver: $DRIVER_VER"
echo "Max supported CUDA Version: $CUDA_VER"

# Warn when the driver's max CUDA version is older than the suggested image's
# series. `sort -V` orders version strings numerically, so the first line is
# the lower of the two.
if [[ "$CUDA_VER" != "unknown" ]]; then
    min_ver=$(printf '%s\n%s\n' "$CUDA_VER" "$CUDA_IMAGE_SERIES" | sort -V | head -n 1)
    if [[ "$min_ver" != "$CUDA_IMAGE_SERIES" ]]; then
        echo "WARNING: The driver supports CUDA up to $CUDA_VER, older than the CUDA $CUDA_IMAGE_SERIES images below." >&2
        echo "         Choose an image tag <= $CUDA_VER or upgrade the NVIDIA driver." >&2
    fi
fi

echo ""
echo "To verify GPU support with CUDA ${CUDA_IMAGE_SERIES}, run:"
echo "  docker run --rm --gpus all nvidia/cuda:${CUDA_IMAGE_VER}-base-ubuntu22.04 nvidia-smi"
echo ""
echo "Alternative CUDA ${CUDA_IMAGE_SERIES} images:"
echo "  nvidia/cuda:${CUDA_IMAGE_VER}-base-ubuntu24.04"
echo "  nvidia/cuda:${CUDA_IMAGE_VER}-runtime-ubuntu22.04"
echo "  nvidia/cuda:${CUDA_IMAGE_VER}-devel-ubuntu22.04"
echo ""
echo "Browse all available tags: https://hub.docker.com/r/nvidia/cuda/tags"
echo ""
