﻿using Microsoft.ML;
using SpamClassifier.Core.Abstractions;
using SpamClassifier.Core.Models;

namespace SpamClassifier.Core.Services;

/// <summary>
/// Trains and serves a binary spam/ham classifier for SMS text using ML.NET.
/// </summary>
/// <remarks>
/// <para>
/// Call <see cref="Train"/> once (or again to retrain) before calling <see cref="Predict"/>.
/// The instance owns the prediction engine it creates and releases it in <see cref="Dispose()"/>.
/// </para>
/// <para>
/// NOTE(review): ML.NET documents <c>PredictionEngine</c> as not thread-safe. This class adds no
/// synchronization, so confirm that callers do not invoke <see cref="Predict"/> concurrently
/// (or while <see cref="Train"/> is running) on the same instance.
/// </para>
/// </remarks>
public class SmsSpamClassifierService : ISpamClassifier, IDisposable
{
    // Column names shared by the label mapping, the featurizer, the trainer and the evaluator.
    // They are passed explicitly so the pipeline stages cannot silently drift apart.
    private const string LabelColumn = "Label";
    private const string FeaturesColumn = "Features";

    private readonly MLContext _mlContext;
    private ITransformer? _model;
    private PredictionEngine<SmsMessageData, TextMessagePrediction>? _predictionEngine;
    private bool _disposed;

    /// <summary>
    /// Creates the service with an <see cref="MLContext"/> that uses a fixed seed.
    /// </summary>
    public SmsSpamClassifierService()
    {
        _mlContext = new MLContext(seed: 1); // hardcoded seed for reproducibility
    }

    /// <summary>
    /// Loads labelled messages from a comma-separated file with a header row, trains a
    /// logistic-regression model on 80% of the rows, evaluates it on the remaining 20%,
    /// writes accuracy and F1 score to the console, and makes the model available to
    /// <see cref="Predict"/>.
    /// </summary>
    /// <param name="dataPath">Path of the training data file.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="dataPath"/> is null, empty, or whitespace.
    /// </exception>
    /// <exception cref="ObjectDisposedException">The service has been disposed.</exception>
    /// <remarks>
    /// The new model replaces the current one only after training, evaluation and engine
    /// creation have all succeeded; if any step throws, the previously trained model (if any)
    /// stays in place.
    /// </remarks>
    public void Train(string dataPath)
    {
        ThrowIfDisposed();

        if (string.IsNullOrWhiteSpace(dataPath))
        {
            throw new ArgumentException("Data path is required.", nameof(dataPath));
        }

        // NOTE(review): quoted fields are not enabled (allowQuoting defaults to false), so a
        // message containing a comma would be split across columns. Verify against the
        // dataset format before changing this, since it alters what the model is trained on.
        var data = _mlContext.Data.LoadFromTextFile<SmsMessageData>(
            path: dataPath,
            hasHeader: true,
            separatorChar: ',');

        var split = _mlContext.Data.TrainTestSplit(data, testFraction: 0.2);

        var pipeline =
            _mlContext.Transforms.Conversion
                .MapValue(
                    outputColumnName: LabelColumn,
                    keyValuePairs: new[]
                    {
                        new KeyValuePair<string, bool>("ham", false),
                        new KeyValuePair<string, bool>("spam", true)
                    },
                    inputColumnName: nameof(SmsMessageData.LabelText))
            .Append(
                _mlContext.Transforms.Text.FeaturizeText(
                    outputColumnName: FeaturesColumn,
                    inputColumnName: nameof(SmsMessageData.Text)))
            .Append(
                _mlContext.BinaryClassification.Trainers
                    .SdcaLogisticRegression(
                        labelColumnName: LabelColumn,
                        featureColumnName: FeaturesColumn));

        // Work on locals until every step has succeeded so a failure cannot leave
        // _model and _predictionEngine pointing at different models.
        ITransformer trainedModel = pipeline.Fit(split.TrainSet);

        var predictions = trainedModel.Transform(split.TestSet);
        var metrics = _mlContext.BinaryClassification.Evaluate(predictions, labelColumnName: LabelColumn);

        Console.WriteLine("=== SMS model evaluation ===");
        Console.WriteLine($"Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"F1 Score: {metrics.F1Score:P2}");
        Console.WriteLine();

        var newEngine =
            _mlContext.Model.CreatePredictionEngine<SmsMessageData, TextMessagePrediction>(trainedModel);

        // Commit the new state, then release the engine from any earlier training run.
        var previousEngine = _predictionEngine;
        _model = trainedModel;
        _predictionEngine = newEngine;
        previousEngine?.Dispose();
    }

    /// <summary>
    /// Classifies a single message with the most recently trained model.
    /// </summary>
    /// <param name="text">Message text to classify.</param>
    /// <returns>The prediction produced by the trained model for <paramref name="text"/>.</returns>
    /// <exception cref="InvalidOperationException">
    /// <see cref="Train"/> has not completed successfully on this instance.
    /// </exception>
    /// <exception cref="ObjectDisposedException">The service has been disposed.</exception>
    public TextMessagePrediction Predict(string text)
    {
        ThrowIfDisposed();

        if (_predictionEngine is null)
        {
            throw new InvalidOperationException("Model not trained. Call Train() first.");
        }

        var input = new SmsMessageData
        {
            Text = text,
        };

        return _predictionEngine.Predict(input);
    }

    /// <summary>
    /// Releases the prediction engine held by this instance. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Releases resources held by this instance.
    /// </summary>
    /// <param name="disposing">
    /// <see langword="true"/> when called from <see cref="Dispose()"/>; managed resources are
    /// released only in that case.
    /// </param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        if (disposing)
        {
            _predictionEngine?.Dispose();
            _predictionEngine = null;
            _model = null;
        }

        _disposed = true;
    }

    // ObjectDisposedException derives from InvalidOperationException, so callers that
    // already catch InvalidOperationException around Predict keep working.
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(SmsSpamClassifierService));
        }
    }
}
