Commit f4b3c0e5 authored by Guo, Yejun and committed by Pedro Arthur

avfilter/dnn: add a new interface to query dnn model's input info

To support DNN networks more generally, we need to know the input info
of the DNN model.

Background:
The data type of the DNN model's input could be float32, uint8, fp16, etc.,
and the width/height of the input image could be fixed or variable.
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
parent e1b45b85
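For context, a minimal usage sketch (not part of this patch) of how a filter might call the new get_input callback before preparing frames for inference. DNNModel, DNNData, DNN_SUCCESS and the get_input signature come from dnn_interface.h as changed below; the helper name, header paths and the logging are assumptions:

#include "libavutil/log.h"
#include "dnn_interface.h"

// Hypothetical helper: query the layout of the model input named `input_name`
// so the caller knows the expected data type and dimensions up front.
static int query_model_input(DNNModel *model, const char *input_name)
{
    DNNData input;
    if (model->get_input(model->model, &input, input_name) != DNN_SUCCESS)
        return -1;
    // Variable-size models may report non-positive width/height here
    // (TF_GraphGetTensorShape uses -1 for unknown dimensions), so a caller
    // should not assume fixed dimensions.
    av_log(NULL, AV_LOG_INFO, "input '%s': %dx%d, %d channel(s), data type %d\n",
           input_name, (int)input.width, (int)input.height,
           (int)input.channels, (int)input.dt);
    return 0;
}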
@@ -28,6 +28,28 @@
#include "dnn_backend_native_layer_conv2d.h"
#include "dnn_backend_native_layers.h"

static DNNReturnType get_input_native(void *model, DNNData *input, const char *input_name)
{
    ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;

    for (int i = 0; i < network->operands_num; ++i) {
        DnnOperand *oprd = &network->operands[i];
        if (strcmp(oprd->name, input_name) == 0) {
            if (oprd->type != DOT_INPUT)
                return DNN_ERROR;
            input->dt = oprd->data_type;
            av_assert0(oprd->dims[0] == 1);
            input->height = oprd->dims[1];
            input->width = oprd->dims[2];
            input->channels = oprd->dims[3];
            return DNN_SUCCESS;
        }
    }

    // input operand not found
    return DNN_ERROR;
}

static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
    ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
@@ -37,7 +59,6 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
        return DNN_ERROR;

    /* inputs */
    av_assert0(input->dt == DNN_FLOAT);
    for (int i = 0; i < network->operands_num; ++i) {
        oprd = &network->operands[i];
        if (strcmp(oprd->name, input_name) == 0) {
@@ -234,6 +255,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename)
    }

    model->set_input_output = &set_input_output_native;
    model->get_input = &get_input_native;

    return model;
}
...
@@ -105,6 +105,37 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
                             input_dims[1] * input_dims[2] * input_dims[3] * size);
}

static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
{
    TFModel *tf_model = (TFModel *)model;
    TF_Status *status;
    int64_t dims[4];

    TF_Output tf_output;
    tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
    if (!tf_output.oper)
        return DNN_ERROR;

    tf_output.index = 0;
    input->dt = TF_OperationOutputType(tf_output);

    status = TF_NewStatus();
    TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
    if (TF_GetCode(status) != TF_OK){
        TF_DeleteStatus(status);
        return DNN_ERROR;
    }
    TF_DeleteStatus(status);

    // currently only NHWC is supported
    av_assert0(dims[0] == 1);
    input->height = dims[1];
    input->width = dims[2];
    input->channels = dims[3];

    return DNN_SUCCESS;
}

static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
    TFModel *tf_model = (TFModel *)model;
@@ -568,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)

    model->model = (void *)tf_model;
    model->set_input_output = &set_input_output_tf;
    model->get_input = &get_input_tf;

    return model;
}
...
@@ -43,6 +43,9 @@ typedef struct DNNData{
typedef struct DNNModel{
    // Stores model that can be different for different backends.
    void *model;
    // Gets model input information.
    // struct DNNData is reused here; the DNNData.data field is not actually needed.
    DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
    // Sets model input and output.
    // Should be called at least once before model execution.
    DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
...
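To summarize the contract described by the struct comment above, a hedged stub for a hypothetical backend (the function name and placeholder values are illustrative, not part of the patch): fill the layout fields of the passed DNNData, leave DNNData.data untouched, and return DNN_SUCCESS, or DNN_ERROR when the named input cannot be found.

// Sketch of the get_input contract for a hypothetical backend.
static DNNReturnType get_input_stub(void *model, DNNData *input, const char *input_name)
{
    // A real backend would look up input_name in its own graph/model here.
    if (!input_name)
        return DNN_ERROR;
    input->dt       = DNN_FLOAT; // or whatever type the model declares
    input->height   = 1;         // placeholder values for this sketch
    input->width    = 1;
    input->channels = 3;
    return DNN_SUCCESS;
}

// Registration mirrors the backends in this patch:
// model->get_input = &get_input_stub;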