-
Notifications
You must be signed in to change notification settings - Fork 693
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding KMeans PyTorch Implementation to cfa model (#998)
* adding KMeans PyTorch Implementation to cfa model * Required modifications and implementation comparison * Update torch_model.py * Update torch_model.py * Update kmeans.py * required changes and Google docstring format * remove test notebook * Update torch_model.py * 'unindent does not match any outer indentation level (<unknown>, line 34)' (syntax-error) * Update torch_model.py * Fix pre-commit issues * Revert torch model to sklearn --------- Co-authored-by: Samet Akcay <samet.akcay@intel.com> Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
- Loading branch information
1 parent
a56121d
commit ed4d1a1
Showing
1 changed file
with
72 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""KMeans clustering algorithm implementation using PyTorch.""" | ||
|
||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
|
||
class KMeans:
    """K-Means clustering implemented with PyTorch tensor operations.

    Calling ``fit`` stores the result on the instance as ``cluster_centers_``
    and ``labels_``; ``predict`` then assigns new samples to the nearest
    stored center.
    """

    def __init__(self, n_clusters: int, max_iter: int = 10):
        """Initializes the KMeans object.

        Args:
            n_clusters (int): The number of clusters to create.
            max_iter (int, optional): The maximum number of iterations to run the algorithm. Defaults to 10.
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

    def fit(self, inputs):
        """Fits the K-means algorithm to the input data.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            tuple: A tuple containing the labels of the input data with respect to the identified clusters
                and the cluster centers themselves. The labels have a shape of (batch_size,) and the
                cluster centers have a shape of (n_clusters, n_features).

        Raises:
            ValueError: If the number of clusters is less than or equal to 0.
        """
        # Fail early with the documented exception instead of an opaque
        # error raised later from torch.randint / torch.argmin.
        if self.n_clusters <= 0:
            raise ValueError(f"n_clusters must be a positive integer, got {self.n_clusters}.")

        batch_size, _ = inputs.shape

        # Initialize centroids randomly from the data points. Sampling without
        # replacement avoids starting two clusters on the very same point; when
        # more clusters than samples are requested that is impossible, so fall
        # back to sampling with replacement.
        if self.n_clusters <= batch_size:
            centroid_indices = torch.randperm(batch_size)[: self.n_clusters]
        else:
            centroid_indices = torch.randint(0, batch_size, (self.n_clusters,))
        # Indexing with an index tensor yields a copy, so the in-place centroid
        # updates below never write into ``inputs``.
        self.cluster_centers_ = inputs[centroid_indices]

        # Run the k-means algorithm for max_iter iterations
        for _ in range(self.max_iter):
            # Compute the distance between each data point and each centroid
            distances = torch.cdist(inputs, self.cluster_centers_)

            # Assign each data point to the closest centroid
            self.labels_ = torch.argmin(distances, dim=1)

            # Update the centroids to be the mean of the data points assigned to them
            for j in range(self.n_clusters):
                mask = self.labels_ == j
                # A cluster with no assigned points keeps its previous centroid.
                if mask.any():
                    self.cluster_centers_[j] = inputs[mask].mean(dim=0)

        # Return the labels and centroids of the result
        return self.labels_, self.cluster_centers_

    def predict(self, inputs):
        """Predicts the labels of input data based on the fitted model.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            torch.Tensor: The predicted labels of the input data with respect to the identified clusters.

        Raises:
            AttributeError: If the KMeans object has not been fitted to input data.
        """
        distances = torch.cdist(inputs, self.cluster_centers_)
        return torch.argmin(distances, dim=1)