
Commit 6b72b44

added inference model interface for sharing between layer model and frozen model (#1053)
* added inference model interface for sharing between layer model and frozen model
* updated api according to comments
* update the comments
* updated the doc for inference model interface
1 parent fbc498b commit 6b72b44

File tree

2 files changed: +51 -2 lines

src/index.ts (+2 -1)
@@ -43,7 +43,8 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
 export {SGDOptimizer} from './optimizers/sgd_optimizer';
 // tslint:disable-next-line:max-line-length
 export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
-export {DataType, Rank, ShapeMap} from './types';
+// tslint:disable-next-line:max-line-length
+export {DataType, InferenceModel, ModelPredictConfig, NamedTensorMap, Rank, ShapeMap} from './types';

 export * from './ops/ops';
 export {LSTMCellFunc} from './ops/lstm';
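With this change, the inference types are re-exported from the library's top-level entry, so downstream code can be written against the shared interface rather than a concrete layers or frozen model class. A minimal sketch of a consumer, assuming the package is imported under its published name (the `runInference` helper is hypothetical):

```ts
import {InferenceModel, ModelPredictConfig, NamedTensorMap, Tensor} from '@tensorflow/tfjs-core';

// Works for any model that implements the shared interface, whether it is
// a layers model or a frozen model.
function runInference(
    model: InferenceModel, input: Tensor,
    config: ModelPredictConfig): Tensor|Tensor[]|NamedTensorMap {
  return model.predict(input, config);
}
```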

src/types.ts (+49 -1)
@@ -123,5 +123,53 @@ export function sumOutType(type: DataType) {
  */
 export type TensorContainer = void|Tensor|string|number|boolean|
     TensorContainerObject|TensorContainerArray;
-export interface TensorContainerObject { [x: string]: TensorContainer; }
+export interface TensorContainerObject {
+  [x: string]: TensorContainer;
+}
 export interface TensorContainerArray extends Array<TensorContainer> {}
+
+export interface ModelPredictConfig {
+  /**
+   * Optional. Batch size (integer). If unspecified, it will default to 32.
+   */
+  batchSize?: number;
+
+  /**
+   * Optional. Verbosity mode. Defaults to false.
+   */
+  verbose?: boolean;
+
+  /**
+   * Optional. List of output node names to evaluate when running predict().
+   * Defaults to the model's default output.
+   */
+  outputs?: string|string[];
+}
+
+/**
+ * Common interface for a machine learning model that can do inference.
+ */
+export interface InferenceModel {
+  /**
+   * Execute the inference for the input tensors.
+   *
+   * @param inputs The input tensors. When the model has a single input, this
+   * should be a Tensor. For models with multiple inputs, the inputs should be
+   * provided as a Tensor[] if the input order is fixed, or as a
+   * NamedTensorMap otherwise.
+   * For batch inference execution, the tensors for each input need to be
+   * concatenated together. For example, MobileNet requires an input shape of
+   * [1, 224, 224, 3], which represents [batch, height, width, channel]; a
+   * batch of 100 images should be provided as an input tensor of shape
+   * [100, 224, 224, 3].
+   *
+   * @param config Prediction configuration for specifying the batch size and
+   * output node names.
+   *
+   * @returns Inference result tensors. A single Tensor is returned if the
+   * model has a single output node; otherwise a Tensor[] or NamedTensorMap
+   * is returned for models with multiple outputs.
+   */
+  predict(inputs: Tensor|Tensor[]|NamedTensorMap, config: ModelPredictConfig):
+      Tensor|Tensor[]|NamedTensorMap;
+}
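To make the contract concrete, here is a minimal sketch (not part of the commit) of an implementer and a batched call matching the doc comment above. The `IdentityModel` class is hypothetical, and the import path assumes the published package name:

```ts
import {InferenceModel, ModelPredictConfig, NamedTensorMap, Tensor, zeros} from '@tensorflow/tfjs-core';

// Hypothetical single-input, single-output model; the real implementers are
// the layers model and the frozen model that this interface lets share code.
class IdentityModel implements InferenceModel {
  predict(inputs: Tensor|Tensor[]|NamedTensorMap, config: ModelPredictConfig):
      Tensor|Tensor[]|NamedTensorMap {
    // This sketch handles only the single-Tensor form of the input union.
    if (inputs instanceof Tensor) {
      return inputs;  // A real model would run its forward pass here.
    }
    throw new Error('IdentityModel expects a single input tensor.');
  }
}

// Batched inference as described in the doc comment: 100 MobileNet-sized
// images concatenated into a single [100, 224, 224, 3] input tensor.
const model: InferenceModel = new IdentityModel();
const batch = zeros([100, 224, 224, 3]);
const result = model.predict(batch, {batchSize: 32, verbose: false});
```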
