import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torchvision.transforms as T
import torchvision.models as models
from torchvision.datasets import CocoDetection
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import ResNet18_Weights
import sys
import atexit

# === Setup for Multi-GPU & Multi-Node ===
def setup():
    """Join the distributed process group and bind this process to its GPU.

    Intended to be launched through torchrun, which supplies the
    ``LOCAL_RANK`` environment variable read here.

    Returns:
        tuple: ``(rank, world_size, local_rank)`` for the calling process.

    Raises:
        KeyError: If ``LOCAL_RANK`` is missing from the environment.
    """
    dist.init_process_group(backend="nccl")

    # torchrun provides LOCAL_RANK automatically; use it to select the GPU
    # this process works on.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    rank, world_size = dist.get_rank(), dist.get_world_size()
    print(
        f"[GPU {local_rank}] Process group initialized. Rank {rank}/{world_size}",
        flush=True,
    )
    return rank, world_size, local_rank

# === Model Definition ===
#create the model
class ResNetForClassification(nn.Module):
    """ResNet-18 feature extractor followed by a freshly created linear head.

    The backbone is initialised from ``ResNet18_Weights.DEFAULT`` and keeps
    every child module of the torchvision model except its last one (the
    fully connected layer), which is replaced by ``classifier``.
    """

    def __init__(self, num_classes):
        """Build the backbone and a ``num_classes``-way linear classifier."""
        super().__init__()
        resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        feature_dim = resnet.fc.in_features

        # Keep all children but the final FC layer.
        backbone_layers = list(resnet.children())[:-1]
        self.base_model = nn.Sequential(*backbone_layers)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        batch_size = x.size(0)
        pooled = self.base_model(x)
        # Collapse everything after the batch dimension into one feature axis.
        flat = pooled.view(batch_size, -1)
        return self.classifier(flat)

# === Custom Dataset ===
class COCODatasetForClassification(CocoDetection):
    """COCO detection dataset exposed as single-label classification samples.

    Each item is ``(image_tensor, label)``. The label is the ``category_id``
    of the first annotation of the image, or 0 when the image has no
    annotations.
    """

    # Built once for the whole class instead of on every __getitem__ call.
    # Resizes to 224x224, converts to a tensor and normalises with the
    # standard ImageNet channel statistics.
    _PREPROCESS = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __getitem__(self, idx):
        """Return the preprocessed image and its label as a ``torch.long`` tensor."""
        img, target = super().__getitem__(idx)
        # Only the first annotation decides the class; fall back to 0 when
        # the image carries no annotation at all.
        label = target[0]["category_id"] if target else 0
        return self._PREPROCESS(img), torch.tensor(label, dtype=torch.long)

# === DataLoader Helper ===
# Build the dataset, its DistributedSampler and the DataLoader.
def get_dataloader(image_path, annotation_path, batch_size=256, num_workers=8):
    """Create the distributed training DataLoader for the COCO dataset.

    The sampler is built without an explicit rank or replica count, so the
    distributed process group must already be initialised when this is called.

    Args:
        image_path: Directory containing the COCO images.
        annotation_path: Path to the COCO annotation file.
        batch_size: Number of samples per batch on each process.
        num_workers: Number of DataLoader worker processes. With 0, data is
            loaded in the main process and the worker-only options are omitted.

    Returns:
        DataLoader: Loader whose ``sampler`` is a shuffling
        ``DistributedSampler`` (call ``set_epoch`` on it every epoch).
    """
    dataset = COCODatasetForClassification(root=image_path, annFile=annotation_path)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)

    loader_kwargs = {
        "batch_size": batch_size,
        "sampler": sampler,
        "num_workers": num_workers,
        "pin_memory": True,
        "drop_last": False,
    }
    # prefetch_factor and persistent_workers are only valid when worker
    # processes exist; passing them with num_workers=0 raises a ValueError.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4
        loader_kwargs["persistent_workers"] = True

    return DataLoader(dataset, **loader_kwargs)

# === Cleanup Function ===
def cleanup():
    """Synchronise the ranks and destroy the process group, if one is initialised.

    Does nothing when no process group is initialised, so it is safe to call
    more than once.
    """
    if not dist.is_initialized():
        return
    # Wait for every rank before tearing the group down.
    dist.barrier()
    dist.destroy_process_group()

# Register cleanup to run at normal interpreter exit. atexit handlers are
# skipped when the process ends through os._exit() or an unhandled signal.
atexit.register(cleanup)

# === Training Loop ===
def train(args):
    """Run DDP training of the ResNet-18 classifier on COCO, then end the process.

    Args:
        args: Namespace providing ``num_epochs``, ``coco_image_path`` and
            ``annotation_path``.

    Note:
        This function never returns. Once the training loop stops, the
        process is terminated with ``os._exit``: status 0 after a clean run,
        status 1 if the loop raised or was interrupted. A failure is reported
        with its traceback before exiting instead of being silently dropped.
    """
    import traceback  # only needed to report a failure before the hard exit

    rank, world_size, local_rank = setup()
    device = torch.device(f"cuda:{local_rank}")

    train_loader = get_dataloader(args.coco_image_path, args.annotation_path)

    model = ResNetForClassification(num_classes=91).to(device)
    # Wrap the model with DDP: one replica of the same model per GPU.
    ddp_model = DDP(model, device_ids=[local_rank])
    # Adam optimizer with learning rate 1e-5.
    optimizer = optim.Adam(ddp_model.parameters(), lr=1e-5)
    loss_fn = nn.CrossEntropyLoss()

    # Assume failure until the whole run, final barrier included, succeeded.
    exit_code = 1
    try:
        start_time = time.time()
        num_batches = len(train_loader)

        for epoch in range(args.num_epochs):
            # Reseed the sampler so each epoch uses a different shuffle.
            train_loader.sampler.set_epoch(epoch)
            ddp_model.train()
            epoch_loss = 0.0
            seen_samples = 0
            start_epoch = time.time()

            for batch_idx, (images, labels) in enumerate(train_loader):
                # The loader pins host memory, so the copies can be asynchronous.
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad()
                outputs = ddp_model(images)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                seen_samples += images.size(0)

                if batch_idx % 100 == 0:
                    print(f"[GPU {local_rank}] Epoch [{epoch+1}/{args.num_epochs}], "
                          f"Step [{batch_idx+1}/{num_batches}], Loss: {batch_loss:.4f}",
                          flush=True)

            # epoch_loss is the sum of the per-batch losses, not their mean.
            print(f"[GPU {local_rank}] Epoch {epoch+1} completed. Loss: {epoch_loss:.4f}",
                  flush=True)

            # Finish all queued GPU work before reading the clock.
            torch.cuda.synchronize(device)
            epoch_time = time.time() - start_epoch
            samples_t = torch.tensor([seen_samples], device=device, dtype=torch.float64)
            time_t = torch.tensor([epoch_time], device=device, dtype=torch.float64)
            if dist.is_initialized():
                # Global throughput = total samples over the slowest rank's time.
                dist.all_reduce(samples_t, op=dist.ReduceOp.SUM)
                dist.all_reduce(time_t, op=dist.ReduceOp.MAX)
            if rank == 0:
                ips = samples_t.item() / time_t.item()
                print(f"[THROUGHPUT] Epoch {epoch+1}: {ips:.1f} img/s (global, {world_size} proc)",
                      flush=True)

        # Save the checkpoint from rank 0 only. The state dict is taken from
        # the DDP wrapper, so its keys carry the wrapper's "module." prefix.
        if rank == 0:
            torch.save(ddp_model.state_dict(), "demo_coco_resnet18_ddp.pth")

        print(f"[GPU {local_rank}] Training finished. Total time: {(time.time() - start_time) / 60:.2f} minutes",
              flush=True)

        # Synchronise only on the success path: after a failure the other
        # ranks are not waiting at this barrier.
        dist.barrier(device_ids=[local_rank])
        exit_code = 0
    except Exception:
        # The hard exit below would otherwise discard the error entirely.
        traceback.print_exc()
    finally:
        if dist.is_initialized():
            del ddp_model  # Release the DDP model before destroying the group
            dist.destroy_process_group()
        # os._exit does not flush stdio buffers, so flush them explicitly.
        sys.stdout.flush()
        sys.stderr.flush()
        time.sleep(2)
        os._exit(exit_code)

# === Main Function ===
def main():
    """Parse the command-line options and start training.

    Options: ``--num_epochs`` (int, default 50), ``--coco_image_path``
    (required) and ``--annotation_path`` (required).
    """
    import argparse

    cli = argparse.ArgumentParser()
    option_specs = (
        ("--num_epochs", {"type": int, "default": 50, "help": "Number of epochs"}),
        ("--coco_image_path", {"type": str, "required": True, "help": "Path to COCO images"}),
        ("--annotation_path", {"type": str, "required": True, "help": "Path to COCO annotations"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    print("resnet18 started")
    train(parsed)

if __name__ == "__main__":
    # Print the RANK environment variable, or "unknown" when it is unset,
    # then hand over to main().
    print("Hello from rank", os.getenv("RANK", "unknown"))
    main()
