亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何從給定模型中獲取 Graph(或 GraphDef)?

如何從給定模型中獲取 Graph(或 GraphDef)?

qq_花開花謝_0 2023-04-18 17:15:55
我有一個使用 Tensorflow 2 和 Keras 定義的大模型。該模型在 Python 中運行良好?,F在,我想將它導入到 C++ 項目中。在我的 C++ 項目中,我使用TF_GraphImportGraphDef函數。*.pb如果我使用以下代碼準備文件,效果很好:    with open('load_model.pb', 'wb') as f:         f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())我已經在使用 Tensorflow 1(使用 tf.compat.v1.* 函數)編寫的簡單網絡上嘗試了這段代碼。它運作良好?,F在我想將我的大模型(開頭提到的,使用Tensorflow 2編寫)導出到C++項目中。為此,我需要從我的模型中獲取一個Graph或GraphDef對象。問題是:如何做到這一點?我沒有找到任何屬性或函數來獲取它。我也試過用它tf.saved_model.save(model, 'model')來保存整個模型。它生成一個包含不同文件的目錄,包括saved_model.pb文件。不幸的是,當我嘗試使用TF_GraphImportGraphDef函數在 C++ 中加載此文件時,程序拋出異常。
查看完整描述

2 回答

?
海綿寶寶撒

TA貢獻1809條經驗 獲得超8個贊

生成的協議緩沖區文件tf.saved_model.save不包含GraphDef消息,而是包含一個SavedModel.?您可以在 Python 中遍歷它SavedModel以獲取其中的嵌入圖,但這不會立即用作凍結圖,因此正確處理它可能很困難。取而代之的是,C++ API 現在包含一個LoadSavedModel調用,允許您從目錄加載整個保存的模型。它應該看起來像這樣:

#include <iostream>

#include <...>? // Add necessary TF include directives


using namespace std;

using namespace tensorflow;


int main()

{

? ? // Path to saved model directory

? ? const string export_dir = "...";

? ? // Load model

? ? Status s;

? ? SavedModelBundle bundle;

? ? SessionOptions session_options;

? ? RunOptions run_options;

? ? s = LoadSavedModel(session_options, run_options, export_dir,

? ? ? ? ? ? ? ? ? ? ? ?// default "serve" tag set by tf.saved_model.save

? ? ? ? ? ? ? ? ? ? ? ?{"serve"}, &bundle));

? ? if (!.ok())

? ? {

? ? ? ? cerr << "Could not load model: " << s.error_message() << endl;

? ? ? ? return -1;

? ? }

? ? // Model is loaded

? ? // ...

? ? return 0;

}

從這里開始,您可以做不同的事情。也許您最愿意使用 將保存的模型轉換為凍結圖FreezeSavedModel,這應該讓您可以像以前一樣做事:


GraphDef frozen_graph_def;

std::unordered_set<string> inputs;

std::unordered_set<string> outputs;

s = FreezeSavedModel(bundle, &frozen_graph_def,

? ? ? ? ? ? ? ? ? ? ?&inputs, &outputs));

if (!s.ok())

{

? ? cerr << "Could not freeze model: " << s.error_message() << endl;

? ? return -1;

}

否則,您可以直接使用保存的模型對象:


// Default "serving_default" signature name set by tf.saved_model_save

const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");

// Get input and output names (different from layer names)

// Key is input and output layer names

const string input_name = signature_def.inputs().at("my_input").name();

const string output_name = signature_def.inputs().at("my_output").name();

// Run model

Tensor input = ...;

std::vector<Tensor> outputs;

s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));

if (!s.ok())

{

? ? cerr << "Error running model: " << s.error_message() << endl;

? ? return -1;

}

// Get result

Tensor& output = outputs[0];


查看完整回答
反對 回復 2023-04-18
?
寶慕林4294392

TA貢獻2021條經驗 獲得超8個贊

我找到了以下問題的解決方案:


g = tf.Graph()

with g.as_default():


    # Create model

    inputs = tf.keras.Input(...) 

    x = tf.keras.layers.Conv2D(1, (1,1), padding='same')(inputs)

    # Done creating model


    # Optionally get graph operations

    ops = g.get_operations()

    for op in ops:

        print(op.name, op.type)


    # Save graph

    tf.io.write_graph(g.as_graph_def(), 'path', 'filename.pb', as_text=False)


查看完整回答
反對 回復 2023-04-18
  • 2 回答
  • 0 關注
  • 173 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號