# I am trying to determine the semantic similarity between one sentence and
# the other sentences, as follows:
#
# NOTE: this script uses TensorFlow 1.x APIs (hub.Module, tf.Session,
# tf.logging, tf.global_variables_initializer, tf.tables_initializer).
# Under TensorFlow 2 these live only under tf.compat.v1, and hub.Module
# cannot run in eager mode. Run it on TF 1.x, or use tf.compat.v1 with
# v2 behavior disabled.
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import os
import sys
from sklearn.metrics.pairwise import cosine_similarity


def cos_sim(input_vectors):
    """Return the pairwise cosine-similarity matrix of the rows of *input_vectors*.

    Entry [i, j] is the cosine similarity between row i and row j.
    """
    similarity = cosine_similarity(input_vectors)
    return similarity


def get_top_similar(sentence, sentence_list, similarity_matrix, topN):
    """Return the *topN* sentences most similar to *sentence*, best first.

    Args:
        sentence: the query; must be an element of *sentence_list*.
        sentence_list: sentences whose order matches the rows/columns of
            *similarity_matrix*.
        similarity_matrix: square similarity matrix (e.g. from ``cos_sim``).
        topN: number of sentences to return; ``topN <= 0`` returns ``[]``.

    Returns:
        A list of at most *topN* sentences, ordered by decreasing similarity.
        The query sentence itself is included; its self-similarity is
        normally the highest score.

    Raises:
        ValueError: if *sentence* is not in *sentence_list*.
    """
    if topN <= 0:
        # The old ``argsort()[-topN:]`` slice became ``[0:]`` for topN == 0
        # and returned every sentence; return nothing instead.
        return []
    # find the index of sentence in list (raises ValueError if absent)
    index = sentence_list.index(sentence)
    # get the corresponding row in similarity matrix
    similarity_row = np.asarray(similarity_matrix[index, :])
    # indices sorted by decreasing similarity, truncated to topN
    indices = similarity_row.argsort()[::-1][:topN]
    return [sentence_list[i] for i in indices]


module_url = "https://tfhub.dev/google/universal-sentence-encoder/2"  #@param ["https://tfhub.dev/google/universal-sentence-encoder/2", "https://tfhub.dev/google/universal-sentence-encoder-large/3"]

# Import the Universal Sentence Encoder's TF Hub module
embed = hub.Module(module_url)

# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)

sentences_list = [
    # phone related
    'My phone is slow',
    'My phone is not good',
    'I need to change my phone. It does not work well',
    'How is your phone?',
    # age related
    'What is your age?',
    'How old are you?',
    'I am 10 years old',
    # weather related
    'It is raining today',
    'Would it be sunny tomorrow?',
    'The summers are here.',
]

with tf.Session() as session:
    # Initialize variables and lookup tables (presumably created by the hub
    # module) before evaluating the embeddings.
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])
    sentences_embeddings = session.run(embed(sentences_list))

similarity_matrix = cos_sim(np.array(sentences_embeddings))

sentence = "It is raining today"
top_similar = get_top_similar(sentence, sentences_list, similarity_matrix, 3)

# Print the results, most similar first.
for similar_sentence in top_similar:
    print(similar_sentence)
1 回答

慕雪6442864
TA貢獻1812條經驗 獲得超5個贊
問題的原因似乎是 TF2 不支持 `hub.Module`(即 TF1 Hub 格式的模型)。
一個簡單的辦法是禁用 TensorFlow 2 的行為,您試過嗎?
# Bind `tf` to the TF1-compatible API surface (tf.Session, tf.logging, ...).
import tensorflow.compat.v1 as tf
# Switch TensorFlow 2 back to TF1 semantics (graph mode instead of eager).
# Graph-mode-only APIs such as hub.Module rely on this.
# NOTE(review): this should run before any ops/graphs are created.
tf.disable_v2_behavior()
這兩行會禁用 TensorFlow 2 的行為,但仍然可能會出現一些與導入模塊和計算圖相關的錯誤。
如果仍然出錯,可以改用 TensorFlow 1.15,請嘗試下面的命令:
# IPython/Jupyter shell escape (`!`): installs TensorFlow 1.15. This line is
# not valid in a plain .py file.
# NOTE(review): in an already-running notebook, a kernel/runtime restart is
# likely needed before the new version is imported — confirm.
!pip install --upgrade tensorflow==1.15
# Check which TensorFlow version is actually being imported.
import tensorflow as tf
print(tf.__version__)
添加回答
舉報
0/150
提交
取消