1 回答

TA貢獻1821條經驗 獲得超6個贊
我假設您是想只透過 assign 來更新一個 tf.Variable 實例。但是，在使用 tf.function 時，您應該始終從外部提供變量，並在函數內部使用內置的 TensorFlow 數據結構。
例如，在對您的代碼改動最少、且不使用 tf.Variable 對象的情況下，可以寫成：
import tensorflow as tf
import numpy as np
from typeguard import typechecked
from typing import Union
@tf.function
def train_kmeans(X: Union[tf.Tensor, np.ndarray],
                 k: Union[int, tf.Tensor],
                 n_iter: Union[int, tf.Tensor] = 10) -> (tf.Tensor, tf.Tensor):
    """Cluster the rows of ``X`` with Lloyd's k-means algorithm.

    Initial centroids are ``k`` rows of a random shuffle of ``X``. Each
    iteration assigns every row to its nearest centroid (Euclidean distance)
    and then moves each centroid to the mean of the rows assigned to it.

    Args:
        X: Training data as a 2-D array of shape ``(n_samples, n_features)``.
            It is cast to ``float32``. Any number of features is supported.
        k: Requested number of clusters. If ``X`` has fewer than ``k`` rows,
            only that many centroids are produced.
        n_iter: Number of assignment/update iterations to run.

    Returns:
        A ``(centroids, clusters)`` tuple: ``centroids`` is a ``float32``
        tensor with one row per centroid, and ``clusters`` is an ``int64``
        tensor holding the centroid index assigned to each row of ``X``.
        With ``n_iter == 0`` the initial centroids and an all-zero
        assignment are returned.
    """
    X = tf.convert_to_tensor(X)
    X = tf.cast(X, tf.float32)
    assert len(tf.shape(X)) == 2, "Training data X must be represented as 2D array only"

    m = tf.shape(X)[0]
    k = tf.convert_to_tensor(k, dtype=tf.int64)

    # Random initialisation: take the first k rows of a shuffled copy of X.
    centroids = tf.random.shuffle(X)[:k]
    # The slice above yields fewer than k rows when X is short, so use the
    # actual centroid count for the per-cluster reductions below.
    num_centroids = tf.shape(centroids)[0]
    clusters = tf.zeros([m, ], dtype=tf.int64)

    for _ in tf.range(n_iter):
        # Pairwise distances, shape (num_centroids, n_samples).
        squared_diffs = tf.square(X[None, :, :] - centroids[:, None, :])
        euclidean_dists = tf.reduce_sum(squared_diffs, axis=-1) ** 0.5
        clusters = tf.argmin(euclidean_dists, axis=0)

        # Per-cluster sums and sizes in one shot. This replaces a per-cluster
        # loop writing into a TensorArray, which silently dropped its writes
        # (the result of `TensorArray.write` was never re-assigned) and
        # hard-coded the feature dimension to 2.
        sums = tf.math.unsorted_segment_sum(
            X, clusters, num_segments=num_centroids)
        counts = tf.math.unsorted_segment_sum(
            tf.ones_like(clusters, dtype=tf.float32),
            clusters,
            num_segments=num_centroids)[:, None]

        # An empty cluster has no mean; keep its previous centroid instead of
        # letting a NaN row poison the next iteration's distances.
        means = tf.math.divide_no_nan(sums, counts)
        centroids = tf.where(counts > 0, means, centroids)

    return centroids, clusters
添加回答
舉報