On-device training of a text classifier model

I have built a text classifier model, and I also want to train it on device: when a text is misclassified, the user should be able to update the model on the device.

Code:

//
//  SpamClassifierHelper.swift
//  LearningML
//
//  Created by Himan Dhawan on 7/1/24.
//

import Foundation import CreateMLComponents import CoreML import NaturalLanguage

/// The two labels the spam classifier distinguishes.
/// The raw values are the label strings handed to the model as training labels.
enum TextClassifier: String {
    /// An unwanted message.
    case spam = "spam"
    /// A legitimate message, labelled "ham".
    case notASpam = "ham"
}

/// Serves predictions from the `SpamClassifier` Core ML model and applies on-device
/// updates built from user corrections.
///
/// Predictions use the personalized model saved in Application Support when one exists,
/// and otherwise the model compiled into the app bundle.
///
/// NOTE(review): `MLUpdateTask` only accepts Core ML models that are marked updatable.
/// A text classifier produced by Create ML's `MLTextClassifier` is not updatable, so
/// `updateModel` is expected to fail for such a model. Retraining a new classifier with
/// `MLTextClassifier` on device is the alternative. Confirm what kind of model
/// `SpamClassifier` is.
class SpamClassifierModel {

    // MARK: - Errors

    /// Errors thrown while preparing an on-device update.
    enum UpdateError: Error {
        /// No personalized model was saved and `SpamClassifier.mlmodelc` is not in the bundle.
        case modelNotFound
    }

    // MARK: - Private Type Properties

    /// The updated (personalized) Spam Classifier model, once it has been loaded.
    private static var updatedSpamClassifier: SpamClassifier?

    /// Whether this process has already tried to load a previously saved personalized model.
    /// This lets a model saved during an earlier launch be picked up on first use.
    private static var hasAttemptedUpdatedModelLoad = false

    /// The default Spam Classifier model.
    /// It is created once, on first access, instead of on every `liveModel` read.
    private static let defaultSpamClassifier: SpamClassifier = {
        do {
            return try SpamClassifier(configuration: .init())
        } catch {
            fatalError("Couldn't load SpamClassifier due to: \(error.localizedDescription)")
        }
    }()

    /// The Spam Classifier model currently in use.
    /// This is the personalized model if one has been saved, otherwise the bundled default.
    static var liveModel: SpamClassifier {
        if updatedSpamClassifier == nil && !hasAttemptedUpdatedModelLoad {
            hasAttemptedUpdatedModelLoad = true
            loadUpdatedModel()
        }
        return updatedSpamClassifier ?? defaultSpamClassifier
    }

    /// The location of the app's Application Support directory for the user.
    private static let appDirectory = FileManager.default.urls(for: .applicationSupportDirectory,
                                                               in: .userDomainMask).first!

    /// URL of the compiled `SpamClassifier` model in the bundle that contains this class.
    /// Crashes if the model is missing from that bundle.
    class var urlOfModelInThisBundle: URL {
        let bundle = Bundle(for: self)
        return bundle.url(forResource: "SpamClassifier", withExtension: "mlmodelc")!
    }

    /// The bundled Spam Classifier model's file URL, or `nil` if it is not in the bundle.
    private static let defaultModelURL = Bundle(for: SpamClassifierModel.self)
        .url(forResource: "SpamClassifier", withExtension: "mlmodelc")
    /// The permanent location of the updated Spam Classifier model.
    private static let updatedModelURL = appDirectory.appendingPathComponent("personalized.mlmodelc")
    /// The temporary location the updated model is written to before it replaces
    /// `updatedModelURL`.
    private static let tempUpdatedModelURL = appDirectory.appendingPathComponent("personalized_tmp.mlmodelc")

    // MARK: - Public Type Methods

    /// Classifies `value` with the live model.
    /// - Parameter value: The text to classify.
    /// - Returns: The predicted label (`nil` if the model produced none) and the confidence
    ///   of the top hypothesis. The confidence is a percentage string with two decimals,
    ///   or `"0.00"` when there is no hypothesis.
    /// - Throws: Any error from wrapping the Core ML model in an `NLModel`.
    // The tuple label `predication` (sic) is kept so existing callers keep compiling.
    static func predictLabelFor(_ value: String) throws -> (predication: String?, confidence: String) {
        let spam = try NLModel(mlModel: liveModel.model)
        let result = spam.predictedLabel(for: value)
        let confidence = spam.predictedLabelHypotheses(for: value, maximumCount: 1).first?.value ?? 0
        return (result, String(format: "%.2f", confidence * 100))
    }

    /// Starts an asynchronous on-device update of the model from one user-labelled example.
    ///
    /// If a personalized model has been saved, the update starts from it, so earlier
    /// corrections are kept. Otherwise it starts from the bundled model. When the task
    /// finishes successfully, the result is saved to Application Support and becomes
    /// the live model.
    ///
    /// - Parameters:
    ///   - newEntryText: The text the user corrected.
    ///   - spam: The label the text should have.
    ///   - completion: Optional. Called when the task finishes, with `nil` on success or
    ///     the training or saving error on failure.
    /// - Throws: `UpdateError.modelNotFound` if no model file is available. Also rethrows
    ///   any error from building the training example or creating the `MLUpdateTask`,
    ///   for example when the model is not updatable.
    static func updateModel(newEntryText: String,
                            spam: TextClassifier,
                            completion: ((Error?) -> Void)? = nil) throws {
        // Continue from the personalized model if there is one, so each update builds
        // on earlier corrections instead of restarting from the bundled model.
        let baseModelURL: URL
        if FileManager.default.fileExists(atPath: updatedModelURL.path) {
            baseModelURL = updatedModelURL
        } else if let bundledURL = defaultModelURL {
            baseModelURL = bundledURL
        } else {
            throw UpdateError.modelNotFound
        }

        // Build the single training example: the text goes under "text" and its label
        // under "label".
        // NOTE(review): these keys must match the training input names of the updatable
        // model. Confirm them against the model's update description.
        let featureProvider = try MLDictionaryFeatureProvider(dictionary: [
            "text": MLFeatureValue(string: newEntryText),
            "label": MLFeatureValue(string: spam.rawValue),
        ])
        let batchProvider = MLArrayBatchProvider(array: [featureProvider])

        // NOTE(review): the completion handler may run off the main thread, and the static
        // state it updates is not synchronized. Confirm the callers' threading.
        let updateTask = try MLUpdateTask(forModelAt: baseModelURL,
                                          trainingData: batchProvider,
                                          configuration: nil,
                                          completionHandler: { context in
            // Do not persist the model produced by a failed task.
            if let error = context.task.error {
                print("Could not update the model: \(error)")
                completion?(error)
                return
            }

            let updatedModel = context.model
            let fileManager = FileManager.default
            do {
                // Create a directory for the updated model.
                try fileManager.createDirectory(at: tempUpdatedModelURL,
                                                withIntermediateDirectories: true,
                                                attributes: nil)

                // Save to the temporary location first. If this write throws, the
                // replace below is skipped and any previously saved model is untouched.
                try updatedModel.write(to: tempUpdatedModelURL)

                // Replace any previously updated model with this one.
                _ = try fileManager.replaceItemAt(updatedModelURL,
                                                  withItemAt: tempUpdatedModelURL)

                loadUpdatedModel()

                print("Updated model saved to:\n\t\(updatedModelURL)")
                completion?(nil)
            } catch {
                print("Could not save updated model to the file system: \(error)")
                completion?(error)
            }
        })

        updateTask.resume()
    }

    /// Loads the updated Spam Classifier, if one has been saved, and makes it the live model.
    /// The current model stays in place if no file exists or if loading it fails.
    /// - Tag: LoadUpdatedModel
    private static func loadUpdatedModel() {
        guard FileManager.default.fileExists(atPath: updatedModelURL.path) else {
            // No personalized model has been saved yet.
            return
        }

        do {
            // Use this updated model to make predictions from now on.
            updatedSpamClassifier = try SpamClassifier(contentsOf: updatedModelURL)
        } catch {
            print("Could not load updated model: \(error)")
        }
    }

}

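You can use Create ML's `MLTextClassifier` API, which is available on iOS, to create a text classifier directly on the device. `MLUpdateTask` does not support text classifiers. To incorporate a user's corrections, train a new classifier from your training data plus the corrected examples.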

On-device training of a text classifier model
 
 
Q