import nibabel as nib
import numpy as np
import pmcxcl as pmcx
import napari
import time

# --- Load MINC volume ---
minc_file = r"C:\Users\subha\Downloads\phantom_1.0mm_normal_crisp.mnc\phantom_1.0mm_normal_crisp.mnc"
minc_img = nib.load(minc_file)
vol_data = minc_img.get_fdata()
print(f"Original volume shape: {vol_data.shape}")

# =====================================================
# CREATE SEGMENTATION FIRST, THEN PAD
# =====================================================
layer_names = ['Air', 'Scalp', 'Skull', 'CSF', 'Gray Matter', 'White Matter']

# CREATE ORIGINAL SEGMENTATION
volume_original = np.zeros(vol_data.shape, dtype=np.uint8)
volume_original[vol_data < 0.5] = 0
volume_original[(vol_data >= 0.5) & (vol_data < 2.0)] = 1
volume_original[(vol_data >= 2.0) & (vol_data < 4.5)] = 2
volume_original[(vol_data >= 4.5) & (vol_data < 5.5)] = 3
volume_original[(vol_data >= 5.5) & (vol_data < 7.5)] = 4
volume_original[vol_data >= 7.5] = 5

# PAD WITH AIR
pad_size = 30
volume = np.zeros((vol_data.shape[0] + 2*pad_size, 
                   vol_data.shape[1] + 2*pad_size, 
                   vol_data.shape[2] + 2*pad_size), dtype=np.uint8)
volume[pad_size:-pad_size, pad_size:-pad_size, pad_size:-pad_size] = volume_original

print(f"✅ Padded volume shape: {volume.shape}")

optprop_master = [
    [0.0, 0.0, 1.0, 1.0],       # Air
    [0.019, 7.8, 0.89, 1.37],   # Scalp
    [0.06, 4.98, 0.80, 1.37],   # Skull
    [0.004, 0.009, 0.89, 1.37], # CSF
    [0.02, 9.0, 0.89, 1.37],    # Gray Matter
    [0.08, 40.9, 0.84, 1.37]    # White Matter
]

# =====================================================
# SOURCE & NEAR-SOURCE DETECTORS
# =====================================================
srcpos_air = [-50, 120, volume.shape[2] - 10]  # z=231 (AIR)
print(f"✅ Source in AIR: {srcpos_air} (layer {volume[tuple(srcpos_air)]})")

# ✅ DETECTORS RIGHT NEXT TO SOURCE (5-15 voxels away)
detectors = [
    [130, 120, 225, 1.0],      # Directly below (6 voxels down)
    [135, 120, 228, 1.0],      # Right, slightly down (3 voxels)
    [125, 120, 228, 1.0],      # Left, slightly down
    [130, 125, 228, 1.0],      # Forward, slightly down
    [130, 115, 228, 1.0],      # Backward, slightly down
    [132, 122, 230, 1.0],      # Diagonal close
    [128, 118, 230, 1.0],      # Diagonal close
    [130, 120, 230, 1.0],      # 1 voxel down
    [133, 123, 229, 1.0]       # Very close diagonal
]

# Verify positions
print("✅ NEAR-SOURCE DETECTOR POSITIONS:")
for i, det in enumerate(detectors):
    y, x, z = det[:3]
    layer_idx = volume[y, x, z]
    layer_name = layer_names[layer_idx] if layer_idx < len(layer_names) else f"Layer{layer_idx}"
    dist = np.sqrt((y-130)**2 + (x-120)**2 + (z-231)**2)
    print(f"  DET{i+1}: [{y:3d},{x:3d},{z:3d}] → {layer_name} (dist={dist:.1f} voxels)")

# =====================================================
# TIME-DOMAIN SIMULATION
# =====================================================
cfg = {
    'nphoton': 1000000,
    'vol': volume,
    'srcpos': srcpos_air,
    'srcdir': [0, -1, 0],
    'srctype': 'isotropic',
    'prop': optprop_master,
    'detpos': detectors,
    'seed': 12345,
    'tstart': 0,
    'tend': 2e-9,        # Short time for near-surface
    'tstep': 1e-10,      
    'maxgate': 20,
    'detectorradius': 2
}

print("\n=== RUNNING: AIR SOURCE → NEAR-SOURCE DETECTORS ===")
start_time = time.time()

#result = pmcx.run(cfg)
sim_time = time.time() - start_time

print(f"✅ Simulation completed in {sim_time:.2f} seconds")
print(f"⚡ Speed: {cfg['nphoton']/1e6/sim_time:.1f} M photons/second")

# =====================================================
# PROCESS RESULTS
# =====================================================
print("\n🔍 RESULT KEYS:", list(result.keys()))
flux = result['flux']
if flux.ndim == 4:
    flux_total = flux.sum(axis=-1)
else:
    flux_total = flux

print(f"Flux stats: min={flux_total.min():.2e}, max={flux_total.max():.2e}")

# =====================================================
# PRINT FLUENCE DATA AT NEAR-SOURCE DETECTORS
# =====================================================
print("\n" + "="*100)
print("📊 NEAR-SOURCE DETECTOR FLUENCE & COUNTS")
print("="*100)
print(f"{'DET':<4} {'POS':<12} {'LAYER':<12} {'DIST':>6} {'PHOTONS':>10} {'FLUENCE':>14} {'RATE':>7}")
print("-"*100)

fluence_data = []
total_detected = 0

for i, det in enumerate(detectors):
    y, x, z = det[:3]
    
    # Safe indexing
    if (0 <= y < volume.shape[0] and 0 <= x < volume.shape[1] and 0 <= z < volume.shape[2]):
        layer_idx = volume[y, x, z]
        layer_name = layer_names[layer_idx] if layer_idx < len(layer_names) else f"Layer{layer_idx}"
        fluence_val = flux_total[y, x, z]
    else:
        layer_name = "INVALID"
        fluence_val = 0.0
    
    photons = 0
    if 'ncount' in result:
        ncount = np.array(result['ncount'])
        if i < len(ncount):
            photons = int(ncount[i])
            total_detected += photons
    
    dist = np.sqrt((y-130)**2 + (x-120)**2 + (z-231)**2)
    rate = (photons / cfg['nphoton'] * 100) if cfg['nphoton'] > 0 else 0
    
    fluence_data.append({
        'detector': i+1, 'position': f"[{y},{x},{z}]", 
        'layer': layer_name, 'photons': photons, 'fluence': float(fluence_val)
    })
    
    print(f"{i+1:<4} {f'[{y},{x},{z}]':<12} {layer_name:<12} {dist:>5.1f} "
          f"{photons:>10,} {fluence_val:>13.2e} {rate:>6.2f}%")

# Summary
detection_rate = total_detected / cfg['nphoton'] * 100
print(f"\n{'='*100}")
print(f"TOTAL: {cfg['nphoton']:,} launched | {total_detected:,} detected ({detection_rate:.2f}%)")

# =====================================================
# LAYER SUMMARY
# =====================================================
print("\n📈 NEAR-SOURCE LAYER SUMMARY:")
print(f"{'LAYER':<12} {'N_DET':<6} {'PHOTONS':>12} {'RATE%':>8} {'AVG_FLUENCE':>14}")
print("-"*65)

layer_summary = {}
for d in fluence_data:
    layer = d['layer']
    if layer not in layer_summary:
        layer_summary[layer] = {'count': 0, 'photons': 0, 'fluence': []}
    layer_summary[layer]['count'] += 1
    layer_summary[layer]['photons'] += d['photons']
    if d['fluence'] > 0:
        layer_summary[layer]['fluence'].append(d['fluence'])

for layer, stats in layer_summary.items():
    avg_fluence = np.mean(stats['fluence']) if stats['fluence'] else 0
    rate = stats['photons'] / cfg['nphoton'] * 100
    print(f"{layer:<12} {stats['count']:<6} {stats['photons']:>12,} {rate:>7.2f}% {avg_fluence:>13.2e}")

# =====================================================
# VISUALIZATION
# =====================================================
viewer = napari.Viewer(ndisplay=3)
viewer.add_image(flux_total, name=f'Near-Source Fluence ({sim_time:.1f}s)', colormap='plasma', opacity=0.9)
viewer.add_labels(volume, name='Brain Layers', opacity=0.7)
viewer.add_points([srcpos_air], name='AIR Source', size=30, face_color='red')
viewer.add_points([d[:3] for d in detectors], name='NEAR DETECTORS', size=20, face_color='lime')

print(f"\n🎉 NEAR-SOURCE SIMULATION COMPLETE! ({sim_time:.2f}s)")
print("✅ Detectors 3-6 voxels from source (scalp/air boundary)")
print("✅ Expect 1-10% detection rate!")
print("✅ Full fluence & count data above!")




