En muchas ocasiones, mientras trabaja con la biblioteca scikit-learn , deberá guardar sus modelos de predicción en un archivo y luego restaurarlos para reutilizar su trabajo anterior para: probar su modelo con datos nuevos, comparar varios modelos o Algo más. Este procedimiento de guardado también se conoce como serialización de objetos: representa un objeto con un flujo de bytes para almacenarlo en el disco, enviarlo a través de una red o guardarlo en una base de datos, mientras que el procedimiento de restauración se conoce como deserialización. En este artículo, analizamos tres formas posibles de hacer esto en Python y scikit-learn, cada una presentada con sus pros y sus contras.
Herramientas para guardar y restaurar modelos
La primera herramienta que describimos es Pickle , la herramienta estándar de Python para la (des) serialización de objetos. Luego, miramos la biblioteca Joblib que ofrece (des) serialización fácil de objetos que contienen matrices de datos grandes, y finalmente presentamos un enfoque manual para guardar y restaurar objetos hacia / desde JSON (JavaScript Object Notation). Ninguno de estos enfoques representa una solución óptima, pero se debe elegir el ajuste correcto de acuerdo con las necesidades de su proyecto.
Inicialización del modelo
Inicialmente, creemos un modelo de scikit-learn. En nuestro ejemplo usaremos un modelo de regresión logística y el conjunto de datos Iris .
Vamos a importar las bibliotecas necesarias, cargar los datos y dividirlos en conjuntos de prueba y entrenamiento.
from sklearn.linear_model import LogisticRegression from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # Load and split data data = load_iris() Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
Ahora creemos el modelo con algunos parámetros no predeterminados y ajustémoslo a los datos de entrenamiento. Suponemos que ha encontrado previamente los parámetros óptimos del modelo, es decir, los que producen la mayor precisión estimada.
# Create a model model = LogisticRegression(C=0.1, max_iter=20, fit_intercept=True, n_jobs=3, solver='liblinear') model.fit(Xtrain, Ytrain)
Y nuestro modelo resultante:
LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=20, multi_class='ovr', n_jobs=3, penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
Usando el fit
método, el modelo ha aprendido sus coeficientes que están almacenados en model.coef_
. El objetivo es guardar los parámetros y los coeficientes del modelo en un archivo, por lo que no es necesario repetir el entrenamiento del modelo y los pasos de optimización de los parámetros nuevamente con datos nuevos.
Módulo pickle
En las siguientes líneas de código, el modelo que creamos en el paso anterior se guarda en un archivo y luego se carga como un nuevo objeto llamado pickled_model
. A continuación, el modelo cargado se utiliza para calcular la puntuación de precisión y predecir los resultados sobre nuevos datos no vistos (de prueba).
import pickle # # Create your model here (same as above) # # Save to file in the current working directory pkl_filename = "pickle_model.pkl" with open(pkl_filename, 'wb') as file: pickle.dump(model, file) # Load from file with open(pkl_filename, 'rb') as file: pickle_model = pickle.load(file) # Calculate the accuracy score and predict target values score = pickle_model.score(Xtest, Ytest) print("Test score: {0:.2f} %".format(100 * score)) Ypredict = pickle_model.predict(Xtest)
La ejecución de este código debería generar su puntuación y guardar el modelo a través de Pickle:
$ python save_model_pickle.py Test score: 91.11 %
Lo mejor de usar Pickle para guardar y restaurar nuestros modelos de aprendizaje es que es rápido: puede hacerlo en dos líneas de código. Es útil si ha optimizado los parámetros del modelo en los datos de entrenamiento, por lo que no necesita repetir este paso nuevamente. De todos modos, no guarda los resultados de la prueba ni ningún dato. Aún así, puede hacer esto guardando una tupla, o una lista, de varios objetos (y recuerde qué objeto va a dónde), de la siguiente manera:
tuple_objects = (model, Xtrain, Ytrain, score) # Save tuple pickle.dump(tuple_objects, open("tuple_model.pkl", 'wb')) # Restore tuple pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(open("tuple_model.pkl", 'rb'))
Módulo Joblib
La biblioteca Joblib está destinada a ser un reemplazo de Pickle, para objetos que contienen datos grandes. Repetiremos el procedimiento de guardar y restaurar como con Pickle.
from sklearn.externals import joblib # Save to file in the current working directory joblib_file = "joblib_model.pkl" joblib.dump(model, joblib_file) # Load from file joblib_model = joblib.load(joblib_file) # Calculate the accuracy and predictions score = joblib_model.score(Xtest, Ytest) print("Test score: {0:.2f} %".format(100 * score)) Ypredict = pickle_model.predict(Xtest)
$ python save_model_joblib.py Test score: 91.11 %
Como se ve en el ejemplo, la biblioteca Joblib ofrece un flujo de trabajo un poco más simple en comparación con Pickle. Si bien Pickle requiere que se pase un objeto de archivo como argumento, Joblib funciona tanto con objetos de archivo como con nombres de archivo de cadena. En caso de que su modelo contenga grandes conjuntos de datos, cada conjunto se almacenará en un archivo separado, pero el procedimiento de guardar y restaurar seguirá siendo el mismo. Joblib también permite diferentes métodos de compresión, como 'zlib', 'gzip', 'bz2' y diferentes niveles de compresión.
Guardar y restaurar manualmente a JSON
Dependiendo de su proyecto, muchas veces encontrará Pickle y Joblib como soluciones inadecuadas. De todos modos, siempre que desee tener un control total sobre el proceso de guardar y restaurar, la mejor manera es crear sus propias funciones manualmente.
A continuación, se muestra un ejemplo de cómo guardar y restaurar objetos manualmente mediante JSON. Este enfoque nos permite seleccionar los datos que deben guardarse, como los parámetros del modelo, los coeficientes, los datos de entrenamiento y cualquier otra cosa que necesitemos.
Dado que queremos guardar todos estos datos en un solo objeto, una forma posible de hacerlo es crear una nueva clase que herede de la clase modelo, que en nuestro ejemplo es LogisticRegression
. La nueva clase, llamada MyLogReg
, implementa los métodos save_json
y load_json
para guardar y restaurar a / desde un archivo JSON, respectivamente.
Para simplificar, guardaremos solo tres parámetros del modelo y los datos de entrenamiento. Algunos datos adicionales que podríamos almacenar con este enfoque son, por ejemplo, una puntuación de validación cruzada en el conjunto de entrenamiento, datos de prueba, puntuación de precisión en los datos de prueba, etc.
import json import numpy as np class MyLogReg(LogisticRegression): # Override the class constructor def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None): LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter) self.X_train = X_train self.Y_train = Y_train # A method for saving object data to JSON file def save_json(self, filepath): dict_ = {} dict_['C'] = self.C dict_['max_iter'] = self.max_iter dict_['solver'] = self.solver dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None' dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None' # Creat json and save to file json_txt = json.dumps(dict_, indent=4) with open(filepath, 'w') as file: file.write(json_txt) # A method for loading data from JSON file def load_json(self, filepath): with open(filepath, 'r') as file: dict_ = json.load(file) self.C = dict_['C'] self.max_iter = dict_['max_iter'] self.solver = dict_['solver'] self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None
Ahora probemos la MyLogReg
clase. Primero creamos un objeto mylogreg
, le pasamos los datos de entrenamiento y lo guardamos en un archivo. Luego creamos un nuevo objeto json_mylogreg
y llamamos al métod load_json
para cargar los datos del archivo.
filepath = "mylogreg.json" # Create a model and train it mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain) mylogreg.save_json(filepath) # Create a new object and load its data from JSON file json_mylogreg = MyLogReg() json_mylogreg.load_json(filepath) json_mylogreg
Al imprimir el nuevo objeto, podemos ver nuestros parámetros y datos de entrenamiento según sea necesario.
MyLogReg(C=1.0, X_train=array([[ 4.3, 3. , 1.1, 0.1], [ 5.7, 4.4, 1.5, 0.4], ..., [ 7.2, 3. , 5.8, 1.6], [ 7.7, 2.8, 6.7, 2. ]]), Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1, penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
Dado que la serialización de datos usando JSON realmente guarda el objeto en un formato de cadena, en lugar de un flujo de bytes, el archivo 'mylogreg.json' podría abrirse y modificarse con un editor de texto. Aunque este enfoque sería conveniente para el desarrollador, es menos seguro ya que un intruso puede ver y modificar el contenido del archivo JSON. Además, este enfoque es más adecuado para objetos con una pequeña cantidad de variables de instancia, como los modelos scikit-learn, porque cualquier adición de nuevas variables requiere cambios en los métodos de guardar y restaurar.
Problemas de compatibilidad
Si bien algunos de los pros y los contras de cada herramienta se cubrieron en el texto hasta ahora, probablemente el mayor inconveniente de las herramientas Pickle y Joblib es su compatibilidad con diferentes modelos y versiones de Python.
Compatibilidad con la versión de Python : la documentación de ambas herramientas indica que no se recomienda (des) serializar objetos en diferentes versiones de Python, aunque podría funcionar con cambios menores de versión.
Compatibilidad del modelo : uno de los errores más frecuentes es guardar su modelo con Pickle o Joblib y luego cambiar el modelo antes de intentar restaurar desde un archivo. La estructura interna del modelo debe permanecer sin cambios entre guardar y recargar.
Un último problema con Pickle y Joblib está relacionado con la seguridad. Ambas herramientas pueden contener código malicioso, por lo que no se recomienda restaurar datos de fuentes no confiables o no autenticadas.
Conclusiones
En esta publicación describimos tres herramientas para guardar y restaurar modelos de scikit-learn. Las bibliotecas Pickle y Joblib son rápidas y fáciles de usar, pero tienen problemas de compatibilidad en diferentes versiones de Python y cambios en el modelo de aprendizaje. Por otro lado, el enfoque manual es más difícil de implementar y debe modificarse con cualquier cambio en la estructura del modelo, pero en el lado positivo podría adaptarse fácilmente a varias necesidades y no tiene problemas de compatibilidad.