Revista Informática

Cómo funciona k-means e implementación en Python

Publicado el 07 octubre 2022 por Daniel Rodríguez @analyticslane
Cómo funciona k-means e implementación en Python

El algoritmo de k-means o k-medias es uno de los más utilizados dentro del análisis de clúster. Algo que se puede explicar porque este es un algoritmo sencillo, fácil de interpretar y generalmente ofrece buenos resultados en la mayoría de los conjuntos de datos. Por lo que suele estar implementado en la mayoría de las librerías estadísticas y de aprendizaje automático como stats de R o Scikit-learn de Python. Dado que una de las mejores formas para comprender un algoritmo es implementarlo, veamos cómo crear una versión básica de k-means en Python.

El algoritmo de k-means

Al ejecutar el algoritmo de k-means sobre un conjunto de datos el resultado son k centroides. Los cuales son puntos en el espacio del conjunto de datos cuya posición es la media de los registros que pertenecen a cada uno de los grupos en los que se separan, de ahí en nombre de k-means o k-medias. Para asignar un nuevo registro a uno de los grupos simplemente se debe calcular la distancia a cada uno de los centroides, seleccionado aquel cuya distancia sea menor.

La separación entre los puntos y cada uno de los centroides se suele calcular mediante la distancia euclídea. Por esto, k-means solamente se puede usar sobre conjunto de datos numéricos, nunca sobre categóricos. Para trabajar con datos categóricos se debe usar otro algoritmo tales como k-modes.

Por ejemplo, en un conjunto de datos con dos características cada uno de los centroides es un punto en el plano. Indicando el valor medio de los elementos de cada uno de los grupos en los que se dividen los datos. Pudiéndose interpretar así los centroides como los valores "típicos" de los registros que se pueden encontrar en cada uno de los grupos. Los que facilita identificar los datos de cada uno de los grupos. Aunque, en la mayoría de los casos, posiblemente no exista ningún registro en los datos exactamente igual al centroide.

Implementación de Python de k-means

La búsqueda de los centroides en k-means usa una técnica de refinamiento iterativo. Inicialmente se parte con un conjunto de centroides aleatorios, por lo que el algoritmo no converge siempre a los mismos resultados, y se va afinando la posición de estos mediante un proceso iterativo.

El algoritmo de k-means se puede resumir en los siguientes pasos:

  1. Generar aleatoriamente k centroides en espacios del conjunto de datos.
  2. Obtener la distancia de cada uno de los datos a todos los centroides y asignar al grupo cuya distancia sea menor.
  3. Calcular la posición media de cada uno de los grupos para actualizar los centroides.
  4. Comprobar si los centroides se han desplazado por debajo de un valor límite, si es así esta es la solución, en caso contrario volver al punto 2.

Una implementación de este algoritmo se puede ver en el siguiente código.

import numpy as np
from scipy.spatial.distance import cdist

def next_step(centroids, previous, iterations, max_iter=10, stop_limit=0):
    """Comprueba si se debe calcular la siguiente iteración.
    
    Parameters
    ----------
    centroids : ndarray
        La posición actual de los centroides.
        
    previous : ndarray
        La posición previa de los centroides.
        
    iterations : integer
        La iteración actual.
        
    max_iter : integer
        El número máximo de iteraciones permitidas.
        
    stop_limit : real
        La diferencia entre a partir de la cual se considera que el
        algoritmo ha convergido.
        
    Returns
    -------
    mode : boolean
        Verdadero si el algoritmo debe continuar, falso en el resto de
        los casos.
    """   
    if iterations == 0 or previous is None:
        return True
    elif iterations > max_iter:
        return False
    elif np.sum(np.abs(centroids - previous)) <= stop_limit:
        return False
    else:
        return True

    
def k_means(data, n_clusters, max_iter=10, stop_limit=0, random_state=None):
    """Implementación básica del algoritmo de Kmeans.
    
    Nota: esta función es solamente una implementación básica con fines
    pedagógicos, no usar en producción.
    
    Parameters
    ----------
    data : ndarray
        El conjunto de datos sobre el que se desea aplicar el algoritmo
        de Kmeans.
        
    n_clusters : integer
        El número de clústeres.
        
    max_iter : integer
        El número máximo de iteraciones permitidas.
        
    stop_limit : real
        La diferencia entre a partir de la cual se considera que el
        algoritmo ha convergido.
    
    random_state : number or None
        La semilla con la que se generan los números aleatorios.
        
    Returns
    -------
    centroids : ndarray
        Los centroides obtenidos.
    """ 
    if random_state is not None:
        np.random.seed(random_state)
        
    centroids = np.random.randn(n_clusters * data.shape[1]).reshape(n_clusters, data.shape[1])
    
    iteration = 0
    previous = None
   
    while next_step(centroids, previous, iteration, max_iter = max_iter, stop_limit = stop_limit):
        iteration += 1
        previous = np.empty_like(centroids)

        distance = cdist(data, centroids)
        clusters = np.argmin(distance, axis=1)

        for num in range(n_clusters):
            if np.any(clusters == num):
                centroids[num] = np.mean(data[clusters == num, :], axis = 0)
        
    return centroids

Explicación del código

En el código anterior se han implementado dos funciones: next_step(), en la cual se comprueba si el algoritmo debe seguir, y k_means(), donde se implementa el algoritmo de k-means. La función next_step() devele verdadero cuando se debe calcular el siguiente paso y falso en caso contrario. Las condiciones para detener el algoritmo son que se alcance el límite de iteraciones indicado o la distancia entre los centroides actuales y del paso anterior sean inferiores a un límite. En el resto de los casos se continuará iterando.

El primer paso de la función k_means() es fijar una semilla en el caso de que se hubiese indicado en los parámetros, algo recomendable para garantizar que los resultados se puedan repetir. Una vez hecho esto se seleccionan de forma aleatoria la posición de los k centroides iniciales. Procediendo a continuación con el refinamiento iterativo. Calculando la distancia de todos los puntos a los centroides, mediante la función cdist(), asignando cada uno de los registros a un clúster, identificando el índice más cercando con np.argmin(), y recalcando la posición de los centroides como la media de cada uno de los grupos. Lo que solamente se hace cuando exista algún registro asignado a cada uno de los clústeres.

Al terminar el proceso iterativo la función devuelve el último valor de los centroides.

Comparación con Scikit-learn

Una forma para saber si el algoritmo está correctamente implementado es compararlo con una implementación estándar como puede ser la de Scikit-learn. Algo que se puede realizar con el siguiente código.

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

data, _ = make_blobs(300, centers=3, random_state=1)

centroids = k_means(data, 3, random_state=1)

kmeans = KMeans(n_clusters=3, random_state=1).fit(data)

print('Implementación')
print(centroids)
print('Scikit-learn')
print(kmeans.cluster_centers_)
Implementación
[[ -1.4531567    4.40756967]
 [-10.07499139  -3.8699274 ]
 [ -7.05318146  -8.00168371]]
Scikit-learn
[[ -1.4531567    4.40756967]
 [ -7.05318146  -8.00168371]
 [-10.07499139  -3.8699274 ]]

Como se puede ver los resultados son los mismos, aunque el orden de aparición de los centroides no sea exactamente el mismo en ambos casos. Posición que dependerá de la semilla.

Conclusiones

En esta entrada se ha visto los fundamentos de k-means y una como implementar este en Python. La implementación que se ha creado ofrece resultados similares a los de Scikit-learn en conjunto de datos sencillos.

Aunque los resultados de la implementación de la entrada son correctos, no recomiendo su uso en producción. La implementación de Scikit-learn, o cualquier otra librería estándar, estará más probada y será más robusta y eficiente que la desarrollada aquí.

Imagen de WikiImages en Pixabay


Volver a la Portada de Logo Paperblog