﻿using Microsoft.ML;
using SpamClassifier.Core.Abstractions;
using SpamClassifier.Core.Models;

namespace SpamClassifier.Core.Services;

/// <summary>
/// Spam classifier built on an ML.NET pipeline that featurizes message text and
/// trains an SDCA logistic-regression binary classifier.
/// </summary>
/// <remarks>
/// Call <see cref="Train"/> before <see cref="Predict"/>. Predictions go through a single
/// <see cref="PredictionEngine{TSrc, TDst}"/>, which ML.NET documents as not thread-safe,
/// so one instance of this service must not serve concurrent <see cref="Predict"/> calls.
/// Dispose the service to release the prediction engine.
/// </remarks>
public class MailSpamClassifierService : ISpamClassifier, IDisposable
{
    private readonly MLContext _mlContext;
    private ITransformer? _model;
    private PredictionEngine<TextMessageData, TextMessagePrediction>? _predictionEngine;

    /// <summary>
    /// Creates an untrained classifier whose ML.NET context uses a fixed seed.
    /// </summary>
    public MailSpamClassifierService()
    {
        // Hardcoded seed so randomized ML.NET operations (e.g. the train/test split)
        // are reproducible between runs.
        _mlContext = new MLContext(seed: 1);
    }

    /// <summary>
    /// Loads a comma-separated data file with a header row, trains on 80% of it,
    /// prints accuracy and F1 score on the remaining 20%, and makes the resulting
    /// model available to <see cref="Predict"/>.
    /// </summary>
    /// <param name="dataPath">Path of the training data file.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="dataPath"/> is null, empty, or whitespace.
    /// </exception>
    /// <remarks>
    /// The new model replaces the current one only after fitting, evaluation and
    /// prediction-engine creation have all succeeded. If any step throws, the previously
    /// trained model (if any) remains in use. The replaced prediction engine is disposed.
    /// </remarks>
    public void Train(string dataPath)
    {
        if (string.IsNullOrWhiteSpace(dataPath))
            throw new ArgumentException("Data path is required.", nameof(dataPath));

        IDataView data = _mlContext.Data.LoadFromTextFile<TextMessageData>(
            path: dataPath,
            hasHeader: true,
            separatorChar: ',');

        // Hold out 20% of the rows so the reported metrics come from unseen data.
        var split = _mlContext.Data.TrainTestSplit(data, testFraction: 0.2);

        var pipeline =
            _mlContext.Transforms.Text.FeaturizeText(
                outputColumnName: "Features",
                inputColumnName: nameof(TextMessageData.Text))
            .Append(
                _mlContext.BinaryClassification.Trainers
                    .SdcaLogisticRegression());

        // Build into locals first so a failure part-way through cannot leave
        // _model and _predictionEngine describing different models.
        ITransformer trainedModel = pipeline.Fit(split.TrainSet);

        var testPredictions = trainedModel.Transform(split.TestSet);
        var metrics = _mlContext.BinaryClassification.Evaluate(testPredictions);

        Console.WriteLine("=== Model evaluation ===");
        Console.WriteLine($"Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"F1 Score: {metrics.F1Score:P2}");
        Console.WriteLine();

        var newEngine =
            _mlContext.Model.CreatePredictionEngine<TextMessageData, TextMessagePrediction>(trainedModel);

        // Publish the new model and engine together, then release the engine they replace.
        var previousEngine = _predictionEngine;
        _model = trainedModel;
        _predictionEngine = newEngine;
        previousEngine?.Dispose();
    }

    /// <summary>
    /// Classifies a single message with the most recently trained model.
    /// </summary>
    /// <param name="text">Message text to classify.</param>
    /// <returns>
    /// A new <see cref="TextMessagePrediction"/> carrying the predicted label and probability.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// <see cref="Train"/> has not completed successfully, or the service has been disposed.
    /// </exception>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    public TextMessagePrediction Predict(string text)
    {
        var engine = _predictionEngine
            ?? throw new InvalidOperationException("Model not trained. Call Train() first.");

        ArgumentNullException.ThrowIfNull(text);

        var result = engine.Predict(new TextMessageData { Text = text });

        // Copy into a fresh object so callers only get the label/probability pair.
        return new TextMessagePrediction
        {
            PredictedLabel = result.PredictedLabel,
            Probability = result.Probability
        };
    }

    /// <summary>
    /// Releases the prediction engine. After disposal, <see cref="Predict"/> throws
    /// <see cref="InvalidOperationException"/> until <see cref="Train"/> is called again.
    /// </summary>
    public void Dispose()
    {
        _predictionEngine?.Dispose();
        _predictionEngine = null;
        _model = null;
        GC.SuppressFinalize(this);
    }
}
