Los árboles de decisión son un de la familia de modelos de aprendizaje automático más utilizados. Se pueden utilizar tanto para resolver problemas de clasificación como de regresión. Una de sus principales ventajas es la facilidad con la que se puede interpretar los resultados en base a reglas. Permitiendo no solo obtener un resultado, sino que inspeccionar los motivos por los que se llega a una predicción dada. Por ejemplo, en un modelo que permita predecir la aparición de fallos en una maquinaria se puede explorar el proceso lógico del algoritmo, identificando los valores de las características que llevar a una conclusión dada. Esto permite nos solo utilizar el modelo para predecir un valor, sino actuar sobre las causas para evitar la aparición de resultados no deseados. Siendo la visualización de árboles de decisión una forma de facilitar esto.
Para la visualización de los árboles de decisión en Python se puede utilizar la librería PyDotPlus. Esta es una versión mejorada del antiguo proyecto pydot que proporciona una interfaz Python al lenguaje de Graphviz.
Instalación de PyDotPlus en Anaconda
En el entorno de trabajo Anaconda no se encuentra instalada por defecto la librería PyDotPlus
, por lo que es necesario instalarla antes de poder utilizarla. Lo que se puede realizar mediante el gestor de paquetes conda
. Para ello solamente se ha de abrir una terminal y lanzar la siguiente línea de comando
conda install pydotplus
Al lanzar este comando conda
buscará los paquetes necesarios y pedirá confirmación para su instalación. Una vez finalizado el proceso la librería se podrá utilizar sin problemas.
Alternativamente se puede realizar la instalación del paquete mediante pip
. Aunque en este caso puede ser necesario resolver manualmente dependencias. Para ello se ha de ejecutar el siguiente comando en una terminal:
pip install pydotplus
Creación de un árbol de decisión
Antes de poder representar un árbol de decisión en primer lugar se ha de crear un modelo. Para ello se puede simular un conjunto de datos con el método make_blobs
de scikit-learn.
from sklearn.datasets.samples_generator import make_blobs X, y = make_blobs(n_samples=50, centers=3, random_state=0, cluster_std=0.60) feature_names = ['X', 'Y']
En este conjunto de datos se han creado tres burbujas en un espacio bidimensional. A la primera característica se la ha llamado X
e Y
a la segunda. El resultado se puede representar gráficamente con el siguiente código.
import matplotlib.pyplot as plt plt.scatter(X[:, 0], X[:, 1], c=y) plt.xlabel(feature_names[0]) plt.ylabel(feature_names[1]) plt.show()
Ahora se puede emplear la clase DecisionTreeClassifier
para el entrenamiento de un árbol de decisión.
from sklearn.tree import DecisionTreeClassifier tree = DecisionTreeClassifier(random_state=0).fit(X, y)
Visualización de árboles de decisión con PyDotPlus
La visualización del árbol se puede realizar con el siguiente código.
from sklearn.tree import export_graphviz from pydotplus import graph_from_dot_data dot_data = export_graphviz(tree, feature_names=feature_names) graph = graph_from_dot_data(dot_data) graph.write_png('tree.png')
En primer lugar, se importan los métodos necesarios. Desde scikit-learn se importa export_graphviz
, un método que permite exportar los resultados de un árbol de decisión al formato DOT de Graphviz. A este método se le pasado el árbol e indicado el nombre de las características. Posteriormente se utiliza el método graph_from_dot_data
para la creación del gráfico y explotar este a un archivo PNG.
En la representación del árbol de decisión se pude observar fácilmente el proceso que emplear el modelo a la hora de tomar una decisión. Incluso se pude reproducir este sin demasiado problema.
Opciones en la representación de PyDotPlus
El método export_graphviz
cuenta con múltiples opciones a la hora de crear representaciones de los árboles de decisión. Por ejemplo, se pueden redondear los cuadros y utilizar colores. Además, es posible exportar la figura en otros formatos gráficos, como en JPG.
Conclusiones
Se ha visto una forma de representar gráficamente árboles de decisión creados en Python con scikit-learn. Esta visualización facilita la interpretación de los resultados y la explicación de estos. Permitiendo no solo conocer la predicción, sino que explicar el origen de la misma.
Imágenes: Pixabay (pixel2013)