# %%
import onnxruntime as ort
import numpy as np
import torch
from pathlib import Path
from viscy.transforms import NormalizeSampled


def _resolve_input_shape(declared_shape, default_shape):
    """Return a concrete input shape, preferring the model's static dimensions.

    Each entry of ``declared_shape`` (as reported by ``NodeArg.shape`` in
    onnxruntime) is used when it is a positive ``int``. Symbolic (``str``) or
    unknown (``None``) dimensions fall back to the matching entry of
    ``default_shape``.

    Args:
        declared_shape: Shape the ONNX model reports for its input, or None.
        default_shape: Fallback shape, one entry per expected dimension.

    Returns:
        tuple[int, ...]: Concrete shape to build the sample input with.

    Raises:
        ValueError: If the declared rank differs from ``len(default_shape)``.
    """
    if declared_shape is None:
        # Rank unknown to the model metadata: trust the defaults.
        return tuple(default_shape)
    if len(declared_shape) != len(default_shape):
        raise ValueError(
            f"Model expects a rank-{len(declared_shape)} input {declared_shape}, "
            f"but this test builds rank-{len(default_shape)} inputs "
            f"{tuple(default_shape)}."
        )
    return tuple(
        dim if isinstance(dim, int) and dim > 0 else default
        for dim, default in zip(declared_shape, default_shape)
    )


def test_onnx_model():
    """
    Test the ONNX model by loading it and running inference with a sample input.

    The model is expected to be a DynaCLR model for microglia analysis.
    Input shape: (batch_size, channels, z_slice, height, width)
    Based on training configuration:
    - Patch size: 256x256
    - z_slice: 1
    - Channel: Phase3D

    Any dimension the model declares statically overrides the default above.
    Symbolic or unknown dimensions keep the default.

    Raises:
        FileNotFoundError: If ``dynaclr_microglia.onnx`` is not next to this file.
        ValueError: If the model's input rank is not 5.
        AssertionError: If the model returns no outputs or non-finite values.
    """
    # Load ONNX model
    model_path = Path(__file__).parent / "dynaclr_microglia.onnx"
    if not model_path.is_file():
        # Fail with an actionable message instead of an opaque ORT load error.
        raise FileNotFoundError(f"ONNX model not found: {model_path}")
    session = ort.InferenceSession(str(model_path))

    # Get model input details (query the input metadata once)
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    input_shape = model_input.shape

    # Print model input specifications
    print("Model input specifications:")
    print(f"Input name: {input_name}")
    print(f"Input shape: {input_shape}")
    print(f"Input type: {model_input.type}")

    # Default dimensions taken from the training configuration
    batch_size = 1
    channels = 1  # Phase3D channel
    z_slice = 1  # Number of z-slices (depth) per patch
    height = 256  # Final patch size from training
    width = 256  # Final patch size from training
    shape = _resolve_input_shape(
        input_shape, (batch_size, channels, z_slice, height, width)
    )

    # Seeded generator so repeated runs feed the model the same input.
    # NOTE(review): assumes the model takes float32 input; check the printed
    # input type above if inference fails with a type error.
    rng = np.random.default_rng(0)
    sample_input = rng.standard_normal(shape, dtype=np.float32)

    # Z-score the whole sample with its own mean/std, approximating the
    # per-sample normalization used in training. NOTE(review): this does not
    # reproduce the exact training transform's settings.
    mean = sample_input.mean()
    std = sample_input.std()
    sample_input = (sample_input - mean) / (std + 1e-8)

    print(f"\nCreated input tensor with shape: {sample_input.shape}")
    print("Input statistics after normalization:")
    print(f"  Mean: {sample_input.mean():.4f}")
    print(f"  Std: {sample_input.std():.4f}")

    # Run inference (None -> return every model output)
    outputs = session.run(None, {input_name: sample_input})
    assert outputs, "Model produced no outputs"

    # Print model information, then sanity-check each output
    print(f"\nNumber of outputs: {len(outputs)}")
    for i, output in enumerate(outputs):
        print(f"Output {i} shape: {output.shape}")
        print(f"Output {i} statistics:")
        print(f"  Mean: {output.mean():.4f}")
        print(f"  Std: {output.std():.4f}")
        print(f"  Min: {output.min():.4f}")
        print(f"  Max: {output.max():.4f}")
        assert np.isfinite(output).all(), f"Output {i} contains NaN or Inf values"


if __name__ == "__main__":
    # Script / notebook-cell entry point: run the smoke test directly,
    # without going through pytest.
    test_onnx_model()

# %%
