Python 侧训练并保存模型的代码如下,
import os
import tensorflow as tf
from matplotlib import pyplot as plt
# Train a small dense classifier on MNIST, keep the best weights in a checkpoint,
# export the model as a TensorFlow SavedModel for the C++ side, and plot the
# training / validation curves.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values to [0, 1]; the C++ preprocessing also divides by 255.
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The last layer already applies softmax, so the loss receives probabilities,
# not logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from previously checkpointed weights when a checkpoint exists.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when validation improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=10,
                    validation_data=(x_test, y_test),
                    validation_freq=1, callbacks=[cp_callback])
model.summary()

# Export in the TensorFlow SavedModel ("tf") format: the C++ test loads this
# directory with LoadSavedModel, so it must be a SavedModel directory rather
# than an HDF5 file. The format is given explicitly instead of being inferred
# from the path.
model.save("mnist_model", save_format="tf")

# Plot per-epoch accuracy and loss for the training and validation sets.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
C++侧的工程结构如下,
image.png
conanfile.txt文件如下,
# Conan (1.x, "cmake" generator) dependencies for the C++ MNIST prediction test.
# TensorFlow C++ and Arrow/Parquet are not listed here; CMakeLists.txt locates
# them separately via find_package(TensorflowCC) and pkg-config.
[requires]
# gtest: test framework used by tf_mnist_model_test.cpp (TEST, RUN_ALL_TESTS).
gtest/1.10.0
# glog, protobuf, eigen, dataframe: presumably required by the TensorFlow C++
# build and the helper libraries under include/ — TODO confirm.
glog/0.4.0
protobuf/3.9.1
eigen/3.4.0
dataframe/1.20.0
# opencv: image loading and preprocessing (cv::imread, cv::resize, cv::cvtColor).
opencv/3.4.17
[generators]
# Produces conanbuildinfo.cmake, which CMakeLists.txt includes before conan_basic_setup().
cmake
CMakeLists.txt文件如下,
# CMake 3.8+ is required: CXX_STANDARD 17 is recognised from 3.8 on, and the
# IMPORTED_TARGET option of pkg_search_module was added in 3.6.
cmake_minimum_required(VERSION 3.8)
project(test_mnist_predict)

# Let pkg-config also search /usr/local/lib/pkgconfig (Arrow/Parquet .pc files).
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# Always emit debug symbols. -g is a compiler option, not a preprocessor
# definition, so it belongs in add_compile_options rather than add_definitions.
add_compile_options(-g)

# Dependencies declared in conanfile.txt (cmake generator output).
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)

# Arrow / Parquet libraries, each exposed as a PkgConfig::<NAME> imported target.
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})
set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)
include_directories(${INCLUDE_DIRS})

# Every *.cpp in this directory becomes its own test executable.
# Note: GLOB results are only refreshed when CMake re-runs.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Shared helper sources compiled once into a library the tests link against.
file(GLOB APP_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

foreach(test_file ${test_file_list})
    # Executable name = source file name without its ".cpp" suffix.
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file} ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
inspect_model.sh文件如下,用于查看模型 serving_default 签名的输入输出参数(张量名、类型和形状),
#!/usr/bin/env bash
# Show the serving_default signature (input/output tensor names, dtypes, shapes)
# of a TensorFlow SavedModel directory.
#
# Usage: ./inspect_model.sh <saved_model_dir>
#
# saved_model_cli is installed with the TensorFlow Python package:
#   <yourenv>/bin/saved_model_cli
set -euo pipefail

if [ "$#" -lt 1 ]; then
    echo "usage: $0 <saved_model_dir>" >&2
    exit 1
fi

saved_model_cli show --dir "$1" --tag_set serve --signature_def serving_default
使用方法如下,
./inspect_model.sh ./mnist_model
image.png
tf_mnist_model_test.cpp 文件代码如下,
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <tensorflow/c/c_api.h>
#include "death_handler/death_handler.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include "img_util/img_util.h"
using namespace tensorflow;

// List of dimension sizes used to build a TensorShape, e.g. {1, 28, 28, 1}.
using BatchDef = std::initializer_list<tensorflow::int64>;

// Size (in pixels) every input picture is resized to before prediction.
constexpr int img_width = 28;
constexpr int img_height = 28;
// Test-runner entry point: set up the project's death handler, initialise
// GoogleTest with the command-line arguments, and run every registered test.
int main(int argc, char** argv) {
    // Lives for the whole run; presumably reports fatal signals/stack traces —
    // TODO confirm in death_handler/death_handler.h.
    Debug::DeathHandler dh;
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
std::vector<float> GetInputData(std::string const& file_name) {
auto raw_data = cv::imread(file_name);
cv::Mat scaled_data {};
// 1. C++ 先resize 参考这个 https://learnopencv.com/image-resizing-with-opencv/
cv::resize(raw_data, scaled_data, cv::Size(img_width, img_height), cv::INTER_LANCZOS4);
// 2. 再 BGR 2 GRAY
cv::Mat gray_data {};
cv::cvtColor(scaled_data, gray_data, cv::COLOR_BGR2GRAY);
// 3. 再转换成 std::vector 参考 Mat to vector https://stackoverflow.com/questions/26681713/convert-mat-to-array-vector-in-opencv
auto gray_conti = gray_data.isContinuous()? gray_data: gray_data.clone();
cv::Mat flat = gray_data.reshape(1, gray_data.total()*gray_data.channels());
std::vector<uint8_t> vec = flat;
// 4. 图像反色 255 - src Array的颜色, 这个就用std::transform算子 https://blog.csdn.net/weixin_42100963/article/details/106446563
// 5. 再除以255.0f
std::vector<float> res_tensor_data {};
res_tensor_data.resize(vec.size());
std::transform(vec.begin(), vec.end(), res_tensor_data.begin(), [](auto ele) {
return ((float)(255 - ele))/ 255.0f;
});
return res_tensor_data;
}
std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
auto tensor_res = test::GetTensorValue<float>(tensor_);
std::vector<int> predict_res{};
for(int i=0; i<tensor_res.size(); ++i) {
if(i!=0 && (i+1)%10==0) {
auto max_idx = std::max_element(tensor_res.begin() + (i-9), tensor_res.begin() + (i+1)) - (tensor_res.begin() + (i-9));
predict_res.emplace_back((int)max_idx);
}
}
return predict_res;
}
// Build a DT_FLOAT tensor with shape `batch_def` and fill it, in row-major
// order, with the values in `batch`.
//
// Throws std::invalid_argument when batch.size() differs from the element count
// implied by `batch_def`, rather than writing past the end of the tensor or
// leaving some of its elements uninitialised.
Tensor MakeTensor(std::vector<float> const& batch, BatchDef const& batch_def) {
    Tensor t(DT_FLOAT, TensorShape(batch_def));
    if (static_cast<tensorflow::int64>(batch.size()) != t.NumElements()) {
        throw std::invalid_argument("MakeTensor: got " + std::to_string(batch.size()) +
                                    " values for a tensor of " +
                                    std::to_string(t.NumElements()) + " elements");
    }
    std::copy(batch.begin(), batch.end(), t.flat<float>().data());
    return t;
}
// Interactive end-to-end check of the exported SavedModel.
//
// Loads ../mnist_model, reads a picture count and then one picture path per
// prediction from stdin. For each picture it preprocesses the image, runs the
// serving signature, and prints the raw scores and the predicted digit.
//
// The feed/fetch tensor names are presumably those reported by
// inspect_model.sh for the serving_default signature.
TEST(TfMnistModelTest, LoadAndPredict) {
    SavedModelBundleLite bundle;
    SessionOptions session_options;
    RunOptions run_options;
    // Relative to the directory the test binary is launched from.
    const string export_dir = "../mnist_model";
    // TF_ASSERT_OK reports a gtest failure instead of aborting the process.
    TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
                                {kSavedModelTagServe}, &bundle));

    std::cout << "input the number of test pictures: \n";
    int pre_num{0};
    std::cin >> pre_num;
    for (int i = 0; i < pre_num; ++i) {
        // 1. Read the picture path.
        std::cout << "the path of test picture: \n";
        std::string img_path{};
        ASSERT_TRUE(static_cast<bool>(std::cin >> img_path))
            << "expected a picture path on stdin";

        // 2. Load and preprocess the picture.
        auto const input_batches = GetInputData(img_path);

        // 3. Shape it as a batch of one single-channel picture.
        auto const input_tensor = MakeTensor(input_batches, {1, img_height, img_width, 1});

        // 4. Run the prediction.
        std::vector<tensorflow::Tensor> out_tensors;
        TF_ASSERT_OK(bundle.GetSession()->Run({{"serving_default_flatten_input:0", input_tensor}},
                                              {"StatefulPartitionedCall:0"}, {}, &out_tensors));

        // 5. Print the raw output scores.
        std::cout << "Print Tensor Value\n";
        test::PrintTensorValue<float>(std::cout, out_tensors[0], 10);
        std::cout << "\n";

        // 6. Convert the scores into a digit and print it.
        auto const predict_res = ConvertTensorToIndexValue(out_tensors[0]);
        ASSERT_FALSE(predict_res.empty()) << "model output has fewer than 10 scores";
        std::cout << "Predict result: " << predict_res[0] << "\n";
    }
}
程序输出如下,
image.png











网友评论