Commit 05f86f05 authored by Guo, Yejun's avatar Guo, Yejun Committed by Pedro Arthur

libavfilter/dnn: remove limit for the name of DNN model input/output

remove the requirment that the name of DNN model input/output
should be "x"/"y",
Signed-off-by: 's avatarGuo, Yejun <yejun.guo@intel.com>
Signed-off-by: 's avatarPedro Arthur <bygrandao@gmail.com>
parent 05aec8bb
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "dnn_backend_native.h" #include "dnn_backend_native.h"
static DNNReturnType set_input_output_native(void *model, DNNData *input, DNNData *output) static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
{ {
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
InputParams *input_params; InputParams *input_params;
......
...@@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename) ...@@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf; return graph_buf;
} }
static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output) static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
{ {
TFModel *tf_model = (TFModel *)model; TFModel *tf_model = (TFModel *)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels}; int64_t input_dims[] = {1, input->height, input->width, input->channels};
...@@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o ...@@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
TF_Tensor *output_tensor; TF_Tensor *output_tensor;
// Input operation should be named 'x' // Input operation
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x"); tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
if (!tf_model->input.oper){ if (!tf_model->input.oper){
return DNN_ERROR; return DNN_ERROR;
} }
...@@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o ...@@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
} }
input->data = (float *)TF_TensorData(tf_model->input_tensor); input->data = (float *)TF_TensorData(tf_model->input_tensor);
// Output operation should be named 'y' // Output operation
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y"); tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, output_name);
if (!tf_model->output.oper){ if (!tf_model->output.oper){
return DNN_ERROR; return DNN_ERROR;
} }
......
...@@ -40,7 +40,7 @@ typedef struct DNNModel{ ...@@ -40,7 +40,7 @@ typedef struct DNNModel{
void *model; void *model;
// Sets model input and output, while allocating additional memory for intermediate calculations. // Sets model input and output, while allocating additional memory for intermediate calculations.
// Should be called at least once before model execution. // Should be called at least once before model execution.
DNNReturnType (*set_input_output)(void *model, DNNData *input, DNNData *output); DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name);
} DNNModel; } DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
......
...@@ -121,7 +121,7 @@ static int config_props(AVFilterLink *inlink) ...@@ -121,7 +121,7 @@ static int config_props(AVFilterLink *inlink)
sr_context->input.height = inlink->h * sr_context->scale_factor; sr_context->input.height = inlink->h * sr_context->scale_factor;
sr_context->input.channels = 1; sr_context->input.channels = 1;
result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output); result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
if (result != DNN_SUCCESS){ if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n"); av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO); return AVERROR(EIO);
...@@ -130,7 +130,7 @@ static int config_props(AVFilterLink *inlink) ...@@ -130,7 +130,7 @@ static int config_props(AVFilterLink *inlink)
if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){ if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
sr_context->input.width = inlink->w; sr_context->input.width = inlink->w;
sr_context->input.height = inlink->h; sr_context->input.height = inlink->h;
result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output); result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
if (result != DNN_SUCCESS){ if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n"); av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO); return AVERROR(EIO);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment