import argparse
import gzip
import os
import numpy as np
import random
from PIL import Image
import torch
import torchvision.transforms.v2 as transforms
import sklearn
# This is for saving the trained SVMs. We could use onnx for SVMs and DNNs, but that is slightly more work.
import pickle

def load_mnist_ubyte(image_path):
    """
    Loads MNIST images from a gzipped IDX3 ubyte file.

    The file starts with a 16-byte big-endian header: magic number (2051),
    number of images, number of rows, number of columns. The pixel bytes
    follow. The image dimensions are taken from the header instead of being
    assumed to be 28x28, so any IDX3 ubyte image file with this layout loads.

    Args:
        image_path (str): Path to the image file (e.g., 'train-images-idx3-ubyte.gz').

    Returns:
        images (numpy.ndarray): writable uint8 array of shape
            (num_images, num_rows, num_cols).

    Raises:
        ValueError: if the magic number is not 2051, or the file is shorter
            than its header claims.
    """
    with gzip.open(image_path, 'rb') as f:
        image_data = f.read()

    # Header: four big-endian int32 values.
    magic, num_images, num_rows, num_cols = (
        int(v) for v in np.frombuffer(image_data, dtype='>i4', count=4))
    if magic != 2051:
        raise ValueError(
            f"{image_path}: bad IDX3 image magic number {magic} (expected 2051)")

    # Read exactly the pixels the header promises; np.frombuffer raises
    # ValueError if the file is truncated.
    images = np.frombuffer(image_data, dtype=np.uint8, offset=16,
                           count=num_images * num_rows * num_cols)
    # np.frombuffer returns a read-only view of `image_data`; copy so callers
    # get an ordinary, writable array.
    return images.reshape(num_images, num_rows, num_cols).copy()

def load_mnist_labels(label_path):
    """
    Loads MNIST labels from a gzipped IDX1 ubyte file.

    The file starts with an 8-byte big-endian header: magic number (2049) and
    number of items. One unsigned byte per label follows.

    Args:
        label_path (str): Path to the label file (e.g., 'train-labels-idx1-ubyte.gz').

    Returns:
        labels (numpy.ndarray): writable 1-D uint8 array with one entry per item.

    Raises:
        ValueError: if the magic number is not 2049, or the file is shorter
            than its header claims.
    """
    with gzip.open(label_path, 'rb') as f:
        label_data = f.read()

    # Header: two big-endian int32 values.
    magic, num_items = (
        int(v) for v in np.frombuffer(label_data, dtype='>i4', count=2))
    if magic != 2049:
        raise ValueError(
            f"{label_path}: bad IDX1 label magic number {magic} (expected 2049)")

    # Read exactly num_items labels; np.frombuffer raises ValueError if the
    # file is truncated.
    labels = np.frombuffer(label_data, dtype=np.uint8, offset=8, count=num_items)
    # np.frombuffer returns a read-only view; callers (e.g. label-noise
    # injection) write into the result, so hand back a writable copy.
    return labels.copy()


class SquashingFunction(torch.nn.Module):
    """Scaled hyperbolic tangent: ``amplitude * tanh(slope * x)``.

    LeCun et al. (1998) use amplitude 1.7159 together with slope 2/3, which
    makes f(1) = 1 and f(-1) = -1. With the defaults here (slope 1.0) the
    module computes ``1.7159 * tanh(x)``; pass ``slope=2/3`` for the paper's
    function.

    Args:
        amplitude: output scale, stored as ``self.const``.
        slope: input scale applied before the tanh.
    """

    def __init__(self, amplitude=1.7159, slope=1.0):
        super(SquashingFunction, self).__init__()
        self.squash = torch.nn.Tanh()
        self.const = amplitude
        self.slope = slope

    def forward(self, x):
        return self.const * self.squash(self.slope * x)


class ResidualBlock(torch.nn.Module):
    """Basic two-convolution residual block.

    Computes ``act(main(x) + shortcut(x))``, where ``main`` is
    conv3x3 -> BN -> act -> conv3x3(stride) -> BN. ``shortcut`` is a strided
    1x1 conv + BN projection when the block changes the channel count or the
    spatial size; otherwise it is an empty Sequential that returns its input.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        nonlinearity: activation module class, instantiated with no arguments.
        stride: stride of the second 3x3 conv (and of the projection).
    """

    def __init__(self, in_channels, out_channels, nonlinearity=torch.nn.ReLU, stride=1):
        super().__init__()
        # Main path. Downsampling (if any) happens in the second convolution.
        main_path = [
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            nonlinearity(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
        ]
        self.residual = torch.nn.Sequential(*main_path)

        # The identity can only be added when the shapes already agree.
        changes_shape = stride != 1 or in_channels != out_channels
        projection = (
            [
                torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_channels),
            ]
            if changes_shape
            else []
        )
        self.shortcut = torch.nn.Sequential(*projection)

        # Activation applied to the sum of both paths.
        self.nonlinearity = nonlinearity()

    def forward(self, x):
        return self.nonlinearity(self.residual(x) + self.shortcut(x))

class ResNet(torch.nn.Module):
    """A small residual network for single-channel images (built for 1x32x32 inputs).

    Layout of ``self.net`` (the indices matter for ``features`` and for saved
    state dicts):
        0-2    stem: 5x5 conv (1 -> 16 maps; 32x32 -> 28x28), batch norm, nonlinearity
        3-11   nine ResidualBlocks; blocks 5, 8 and 11 use stride 2 and widen the
               channels 16 -> 32 -> 64 -> 128 (28x28 -> 14x14 -> 7x7 -> 4x4)
        12-13  global average pool and flatten -> 128 features
        14-15  Linear(128, 84) then Linear(84, 10); no nonlinearity between them

    ``forward`` returns softmax probabilities; ``self.net`` alone returns logits.

    Args:
        nonlinearity: activation module class (instantiated with no arguments),
            used in the stem and inside every residual block.
    """

    def __init__(self, nonlinearity = torch.nn.ReLU):
        super(ResNet, self).__init__()
        self.net = torch.nn.Sequential(
                # Stem: 5x5 convolution with 16 output feature maps
                torch.nn.Conv2d(1, 16, kernel_size=5),
                torch.nn.BatchNorm2d(16),
                nonlinearity(),
                ## Now we are working with 28x28 feature maps.
                ## Three stages, each ending in a stride-2 block: 28x28 -> 14x14 -> 7x7 -> 4x4.
                ResidualBlock(16, 16, nonlinearity=nonlinearity),
                ResidualBlock(16, 16, nonlinearity=nonlinearity),
                ResidualBlock(16, 32, nonlinearity=nonlinearity, stride=2),
                ResidualBlock(32, 32, nonlinearity=nonlinearity),
                ResidualBlock(32, 32, nonlinearity=nonlinearity),
                ResidualBlock(32, 64, nonlinearity=nonlinearity, stride=2),
                ResidualBlock(64, 64, nonlinearity=nonlinearity),
                ResidualBlock(64, 64, nonlinearity=nonlinearity),
                ResidualBlock(64, 128, nonlinearity=nonlinearity, stride=2),
                # A single average pool to reduce all feature channels to 1x1
                torch.nn.AdaptiveAvgPool2d((1, 1)),
                torch.nn.Flatten(),
                torch.nn.Linear(128, 84),
                # We are not going to try to recreate the original exemplar-based function in LeNet5
                #euclidean_rbf(84, 12)
                torch.nn.Linear(84, 10),
                )
        self.decision = torch.nn.Softmax(dim=1)

        # Re-initialize the stem kernels uniformly in [-1, 1].
        torch.nn.init.uniform_(self.net[0].weight.data, a=-1, b=1)

    def features(self, x):
        """Return the 128-dim pooled feature vector (output of net[0:14], i.e. up to Flatten)."""
        for layer in self.net[:14]:
            x = layer(x)
        return x

    def forward(self, x):
        """Forward through the network, returning class probabilities."""
        y_hat = self.decision(self.net(x))
        return y_hat


class LeNet5(torch.nn.Module):
    """LeNet-5 (LeCun et al., 1998) for 1x32x32 inputs, with simplifications.

    Layer map (index in ``self.net``):
        0      C1: 5x5 conv, 1 -> 6 maps (32x32 -> 28x28)
        1-3    S2: 2x2 average pool, per-map learned scale and bias (a grouped
               1x1 conv), nonlinearity (-> 6 maps of 14x14)
        4      C3: 5x5 conv, 6 -> 16 maps (-> 10x10), fully connected across maps
        5-7    S4: same structure as S2 (-> 16 maps of 5x5)
        8      C5: 5x5 conv, 16 -> 120 maps (-> 1x1)
        9      flatten to a 120-vector
        10-11  F6 Linear(120, 84) and output Linear(84, 10)
    No nonlinearity follows C5 or F6. The paper's Euclidean RBF output layer is
    replaced by the plain linear layer.

    ``forward`` returns softmax probabilities; ``self.net`` alone returns logits.

    Args:
        nonlinearity: activation module class, instantiated with no arguments.
    """

    def __init__(self, nonlinearity = SquashingFunction):
        super().__init__()
        layers = [
            torch.nn.Conv2d(1, 6, 5),                              # C1
            torch.nn.AvgPool2d(kernel_size=2, stride=2),           # S2: subsample
            torch.nn.Conv2d(6, 6, kernel_size=1, groups=6),        # S2: per-map weight + bias
            nonlinearity(),
            torch.nn.Conv2d(6, 16, kernel_size=5),                 # C3
            torch.nn.AvgPool2d(kernel_size=2, stride=2),           # S4: subsample
            torch.nn.Conv2d(16, 16, kernel_size=1, groups=16),     # S4: per-map weight + bias
            nonlinearity(),
            torch.nn.Conv2d(16, 120, kernel_size=5, stride=1),     # C5 (output is 1x1)
            torch.nn.Flatten(),
            torch.nn.Linear(120, 84),                              # F6
            torch.nn.Linear(84, 10),                               # output
        ]
        self.net = torch.nn.Sequential(*layers)
        self.decision = torch.nn.Softmax(dim=1)

        # Re-initialize the C1 kernels uniformly in [-1, 1].
        torch.nn.init.uniform_(self.net[0].weight.data, a=-1, b=1)

    def features(self, x):
        """Return the 120-dim C5 feature vector (output of net[0:10], i.e. up to Flatten)."""
        feature_extractor = self.net[:10]
        return feature_extractor(x)

    def forward(self, x):
        """Return class probabilities for the batch ``x``."""
        logits = self.net(x)
        return self.decision(logits)


class Linear(torch.nn.Module):
    """A fully connected (MLP) classifier for 1x32x32 inputs.

    Despite the class name, the two hidden layers use ``nonlinearity``.
    Layer map (index in ``self.net``):
        0      flatten 1x32x32 -> 1024
        1-2    Linear(1024, 2048) + nonlinearity
        3-4    Linear(2048, 120) + nonlinearity
        5-6    Linear(120, 84) and Linear(84, 10), with no nonlinearity between them

    ``forward`` returns softmax probabilities; ``self.net`` alone returns logits.

    Args:
        nonlinearity: activation module class, instantiated with no arguments.
    """

    def __init__(self, nonlinearity = torch.nn.ReLU):
        super(Linear, self).__init__()
        self.net = torch.nn.Sequential(
                torch.nn.Flatten(),
                torch.nn.Linear(1024, 2048),
                nonlinearity(),
                torch.nn.Linear(2048, 120),
                nonlinearity(),
                torch.nn.Linear(120, 84),
                torch.nn.Linear(84, 10)
                )
        self.decision = torch.nn.Softmax(dim=1)

        # Re-initialize the first weight matrix uniformly in [-1, 1].
        torch.nn.init.uniform_(self.net[1].weight.data, a=-1, b=1)

    def features(self, x):
        """Return the 120-dim hidden representation (output of net[0:5]).

        This mirrors the 120-dim ``features`` of the convolutional models, so
        the model can also feed an SVM classifier.
        """
        for layer in self.net[:5]:
            x = layer(x)
        return x

    def forward(self, x):
        """Forward through the network, returning class probabilities."""
        y_hat = self.decision(self.net(x))
        return y_hat


def preprocess(X_train, order, device, background=-0.1, foreground=1.175):
    """Reorder, normalize, and pad digit images into a batch of network inputs.

    MNIST stores 0 for background and the maximum value for ink, so after
    dividing by 255 a pixel value of 0 is background and 1 is full ink. Values
    are mapped linearly so that 0 -> ``background`` and 1 -> ``foreground``.
    The defaults (-0.1 and 1.175) are the levels from the LeNet-5 paper. The
    images are then padded by 2 pixels on every side with the background level,
    so the added border looks like the rest of the background (28x28 -> 32x32
    for MNIST).

    Args:
        X_train: numpy array of images, presumably shaped (N, H, W) with values
            in [0, 1].
        order: index array selecting and ordering the images (e.g. a shuffle).
        device: torch device (or device string) for the returned tensor.
        background: value assigned to pixels equal to 0 and to the padding.
        foreground: value assigned to pixels equal to 1.

    Returns:
        float32 tensor of shape (len(order), 1, H + 4, W + 4) on ``device``.
    """
    normalized = background + (foreground - background) * X_train[order]
    preprocessed = torch.tensor(normalized, dtype=torch.float32)

    # Pad 2 on every side with the background level, e.g. 28x28 -> 32x32.
    preprocessed = torch.nn.functional.pad(preprocessed, pad=(2, 2, 2, 2), value=background)
    # Add a channel dimension.
    return preprocessed.unsqueeze(1).to(device)


# Batch size for every evaluation / feature-extraction pass; chunking keeps
# memory bounded.
_EVAL_BATCH_SIZE = 1000

# Step-wise learning-rate schedule: 0.0005 for epochs 0-1, 0.0002 for 2-4,
# 0.0001 for 5-8, 0.00005 for 9-12, and 0.00001 from epoch 13 on.
# NOTE(review): loosely modeled on the LeNet-5 paper's schedule (which used a
# different optimizer); confirm the epoch boundaries if exact replication matters.
_LR_INITIAL = 0.0005
_LR_STEPS = [2, 5, 9, 13]
_LR_SCHEDULE = [0.0002, 0.0001, 0.00005, 0.00001]


def _parse_args():
    """Define and parse the command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        required=True,
        help="gzip file with mnist data.")
    parser.add_argument(
        "--test",
        required=True,
        help="gzip file with mnist data.")
    parser.add_argument(
        "--train_labels",
        required=True,
        help="gzip file with mnist labels.")
    parser.add_argument(
        "--test_labels",
        required=True,
        help="gzip file with mnist labels.")
    parser.add_argument(
        "--epochs",
        required=False,
        type=int,
        default=5,
        help="Number of epochs to train.")
    parser.add_argument(
        "--train_samples",
        required=False,
        type=int,
        default=60000,
        help="Number of samples to use for training.")
    # NOTE(review): --save_mismatch is accepted but not used anywhere.
    parser.add_argument(
        "--save_mismatch",
        required=False,
        type=int,
        default=0,
        help="The number of mismatches to save.")
    parser.add_argument(
        "--batch_size",
        required=False,
        type=int,
        default=32,
        help="The batch size.")
    parser.add_argument(
        "--error_rate",
        required=False,
        type=float,
        default=0.0,
        help="The training label error rate.")
    parser.add_argument(
        "--model",
        required=False,
        default="lenet",
        type=str,
        choices=["lenet", "lenet_relu", "resnet", "linear"],
        help="Model type")
    parser.add_argument(
        "--save",
        required=False,
        default=None,
        type=str,
        help="Path to save the trained model")
    parser.add_argument(
        "--load",
        required=False,
        default=None,
        type=str,
        help="Path to load the trained model")
    parser.add_argument(
        "--random_seed",
        required=False,
        type=int,
        default=112,
        help="The random seed.")
    ####
    # These are the SVM Options
    parser.add_argument(
        "--use_svm",
        default=False,
        action='store_true',
        help="Use an SVM for final classification after training or model loading.")
    parser.add_argument(
        "--kernel",
        required=False,
        type=str,
        default='rbf',
        help="linear or rbf or poly")
    # C is a float in scikit-learn; type=float still accepts integer input.
    parser.add_argument(
        "--C",
        required=False,
        type=float,
        default=None,
        help="C value for svm soft margin. Defaults to 1 within scikit's implementation")
    parser.add_argument(
        "--gamma",
        required=False,
        type=float,
        default=0.1,
        help="Gamma for the rbf kernel")
    parser.add_argument(
        "--degree",
        required=False,
        type=int,
        default=None,
        help="Degree for the polynomial kernel (try 2)")
    parser.add_argument(
        "--coef0",
        required=False,
        type=float,
        default=None,
        help="Offset for the polynomial kernel (try 1)")
    parser.add_argument(
        "--save_svm",
        required=False,
        default=None,
        type=str,
        help="Path to save pickle of trained scikit svm.")
    parser.add_argument(
        "--load_svm",
        required=False,
        default=None,
        type=str,
        help="Path to load the pickle of the trained scikit svm.")
    return parser.parse_args()


def _batch_bounds(num_items, batch_size):
    """Yield (begin, end) slice bounds covering range(num_items) in chunks of batch_size."""
    for begin in range(0, num_items, batch_size):
        yield begin, begin + batch_size


def _mean_loss(model, criterion, X, Y, batch_size):
    """Average per-example loss over (X, Y), computed in batches on the logits (model.net)."""
    total = 0.0
    for begin, end in _batch_bounds(X.size(0), batch_size):
        logits = model.net(X[begin:end])
        total += criterion(logits, Y[begin:end]).item() * logits.size(0)
    return total / X.size(0)


def _count_correct(model, X, Y, batch_size):
    """Count the rows of X whose argmax prediction equals the label in Y."""
    correct = 0
    for begin, end in _batch_bounds(X.size(0), batch_size):
        classes = torch.argmax(model(X[begin:end]), dim=1)
        # Reduce each batch to a Python int; accumulating the raw per-batch
        # boolean tensors fails when the last batch is shorter.
        correct += int((classes == Y[begin:end]).sum().item())
    return correct


def _extract_features(model, X, batch_size):
    """Run model.features over X in batches and stack the results into one numpy array."""
    chunks = [model.features(X[begin:end]).cpu().numpy()
              for begin, end in _batch_bounds(X.size(0), batch_size)]
    return np.concatenate(chunks)


def _save_examples(images, prefix, count=20):
    """Save up to `count` images (values in [0, 1]) as 8-bit PNGs named f"{prefix}_{i}.png"."""
    for i in range(min(count, len(images))):
        # rint avoids 254.999... truncating to 254 in astype.
        pixels = np.rint(255 * images[i]).astype(np.uint8)
        Image.fromarray(pixels).save(f"{prefix}_{i}.png")


def _build_model(name):
    """Instantiate the network selected by --model."""
    factories = {
        "lenet": LeNet5,
        "lenet_relu": lambda: LeNet5(nonlinearity=torch.nn.ReLU),
        "resnet": ResNet,
        "linear": Linear,
    }
    return factories[name]()


def _corrupt_labels(labels, error_rate, rng):
    """Replace (in place) a fraction `error_rate` of `labels` with a uniformly chosen wrong digit."""
    total_errors = int(error_rate * len(labels))
    # Distinct indices, so exactly total_errors labels change.
    to_change = rng.choice(len(labels), size=total_errors, replace=False)
    # An offset in 1..9 (mod 10) picks uniformly among the nine wrong labels.
    offsets = rng.integers(1, 10, size=total_errors)
    labels[to_change] = (labels[to_change].astype(np.int64) + offsets) % 10


def _train(model, X_train, Y_train, X_test, Y_test, *, epochs, batch_size, tag):
    """Train `model` with Adam, printing losses and accuracies after each epoch.

    The loss is computed on `model.net(...)` (raw logits). CrossEntropyLoss
    applies log-softmax itself, so feeding it the softmax output of
    `model(...)` would apply softmax twice and flatten the gradients.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=_LR_INITIAL)
    criterion = torch.nn.CrossEntropyLoss()

    num_train = X_train.size(0)
    num_batches = -(-num_train // batch_size)
    print(f"Training on {num_train} examples over {num_batches} batches")

    for epoch in range(epochs):
        if epoch in _LR_STEPS:
            new_lr = _LR_SCHEDULE[_LR_STEPS.index(epoch)]
            # The learning rate lives in the param groups; assigning an
            # attribute on the optimizer object has no effect.
            for group in optimizer.param_groups:
                group["lr"] = new_lr

        model.train()
        total_loss = 0.0
        for begin, end in _batch_bounds(num_train, batch_size):
            X_batch = X_train[begin:end]
            Y_batch = Y_train[begin:end]

            optimizer.zero_grad()
            loss = criterion(model.net(X_batch), Y_batch)
            total_loss += loss.item() * X_batch.size(0)
            loss.backward()
            optimizer.step()

        print(f"{tag} Epoch {epoch} training loss {total_loss / num_train}")

        # Evaluation without gradient tracking.
        model.eval()
        with torch.no_grad():
            test_loss = _mean_loss(model, criterion, X_test, Y_test, _EVAL_BATCH_SIZE)
            print(f"{tag} Epoch {epoch} testing loss {test_loss}")
            train_accuracy = _count_correct(model, X_train, Y_train, _EVAL_BATCH_SIZE) / num_train
            test_accuracy = _count_correct(model, X_test, Y_test, _EVAL_BATCH_SIZE) / X_test.size(0)
            print(f"{tag} Epoch {epoch} accuracies are {train_accuracy} {test_accuracy}")


def _fit_or_load_svm(model, X_train, Y_train, args):
    """Return an SVC unpickled from --load_svm or fit on model features; pickle it to --save_svm if set."""
    if args.load_svm:
        # SECURITY: pickle.load can execute arbitrary code; only load files you trust.
        with open(args.load_svm, 'rb') as infile:
            svm = pickle.load(infile)
    else:
        # `import sklearn` alone does not guarantee the svm submodule is loaded.
        from sklearn.svm import SVC
        svm_args = {name: getattr(args, name)
                    for name in ('kernel', 'gamma', 'degree', 'coef0', 'C')
                    if getattr(args, name) is not None}
        svm = SVC(**svm_args)
        print("Building SVM inputs.")
        features = _extract_features(model, X_train, _EVAL_BATCH_SIZE)
        print("Training the SVM.")
        svm.fit(features, Y_train.cpu().numpy())

    if args.save_svm:
        with open(args.save_svm, 'wb') as out:
            pickle.dump(svm, out)
    return svm


def main():
    """Load MNIST, train (or load) a model, and report final test accuracy."""
    args = _parse_args()

    # Seed every RNG the script uses. np.random.default_rng returns a new
    # Generator (it does not seed numpy's global state), so keep and use it.
    rng = np.random.default_rng(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading data")
    X_train_digits = load_mnist_ubyte(args.train)[:args.train_samples] / 255
    X_test_digits = load_mnist_ubyte(args.test) / 255
    # np.array copies: arrays backed by np.frombuffer may be read-only, and
    # label corruption below writes into this one.
    Y_train_digits = np.array(load_mnist_labels(args.train_labels)[:args.train_samples])
    Y_test_digits = load_mnist_labels(args.test_labels)

    # Save some digits for the homework
    _save_examples(X_test_digits, "example_test_digit")
    print(f"Example classes test are {Y_test_digits[:20]}")
    _save_examples(X_train_digits, "example_train_digit")
    print(f"Example classes are {Y_train_digits[:20]}")

    model = _build_model(args.model)

    # Don't shuffle the test data, but otherwise treat it the same as the training data.
    X_test = preprocess(X_test_digits, np.arange(X_test_digits.shape[0]), device)
    Y_test = torch.tensor(Y_test_digits).long().to(device)

    if args.error_rate > 0.0:
        # Insert errors into the training labels at the given error rate.
        _corrupt_labels(Y_train_digits, args.error_rate, rng)

    # Shuffle (reproducibly) and preprocess the training data.
    order = rng.permutation(X_train_digits.shape[0])
    X_train = preprocess(X_train_digits, order, device)
    Y_train = torch.tensor(Y_train_digits[order]).long().to(device)

    # Are we doing training, or just reloading?
    if args.load is not None:
        model.load_state_dict(torch.load(args.load, map_location=torch.device("cpu"), weights_only=True))
        model.to(device)
    else:
        model.to(device)
        _train(model, X_train, Y_train, X_test, Y_test,
               epochs=args.epochs, batch_size=args.batch_size, tag=args.train_samples)

    if args.save is not None:
        torch.save(model.state_dict(), args.save)

    # Final evaluation
    model.eval()
    num_test = X_test.size(0)
    with torch.no_grad():
        if args.use_svm:
            svm = _fit_or_load_svm(model, X_train, Y_train, args)
            print("Building test inputs.")
            features = _extract_features(model, X_test, _EVAL_BATCH_SIZE)
            print("Inference with the SVM.")
            results = svm.predict(features)
            sum_matches = int(np.sum(results == Y_test.cpu().numpy()))
        else:
            # DNN classification
            sum_matches = _count_correct(model, X_test, Y_test, _EVAL_BATCH_SIZE)
        test_accuracy = sum_matches / num_test

    print(f"{args.train_samples} Final accuracy {sum_matches}/{num_test} ({test_accuracy})")


if __name__ == "__main__":
    main()



