1 Answer
You can create a user-defined function (UDF) that returns the index of the maximum value:
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType
max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))
Example:
>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
| [0.2, 0.3, 0.5]|
+-----------------+
>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
| [0.2, 0.3, 0.5]| 2|
+-----------------+-------+
Edit:
Since you mentioned that the lists in topicDistribution are actually numpy arrays, you can update the max_index UDF as follows (numpy arrays have no .index method, so convert to a Python list first):
max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
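As a side note, the core logic inside the UDF is plain Python, so you can test it without a Spark session before wrapping it with f.udf. A minimal sketch (the helper name is my own, not from the answer above):

```python
def index_of_max(x):
    """Return the index of the largest element in a sequence.

    Works on plain Python lists; for a numpy array, call
    x.tolist() first, since ndarrays have no .index method.
    """
    return x.index(max(x))

print(index_of_max([0.2, 0.3, 0.5]))  # 2
```

Note that list.index returns the position of the first occurrence, so ties resolve to the earliest maximum.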
