mx.symbol.amp_multicast

Description

Cast function used by AMP (automatic mixed precision) that casts its inputs to their common widest type.

It casts only between a low-precision float type and FP32. Inputs of any other type are left unchanged.
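The type rule can be modeled in a few lines of Python/NumPy. This is a minimal sketch, not MXNet's implementation or its R API. It makes three assumptions. First, float16 stands in for "low-precision float". Second, the function name `amp_multicast_sketch` is hypothetical. Third, its `cast_narrow` flag imitates the `cast.narrow` argument described below. Float inputs all move to one target type, and every other input passes through unchanged.

```python
import numpy as np

# Assumption for this sketch: float16 plays the role of AMP's low-precision float.
F16 = np.dtype(np.float16)
F32 = np.dtype(np.float32)
AMP_FLOATS = (F16, F32)


def amp_multicast_sketch(*inputs, cast_narrow=False):
    """Hypothetical model of the AMP multicast type rule (not MXNet code).

    float16/float32 inputs are cast to the widest float type present
    (or the narrowest one when cast_narrow is True). Inputs of any other
    dtype are returned unchanged.
    """
    float_types = {a.dtype for a in inputs if a.dtype in AMP_FLOATS}
    if not float_types:
        # No low-precision/FP32 inputs: nothing to cast.
        return list(inputs)
    if cast_narrow:
        target = F16 if F16 in float_types else F32
    else:
        target = F32 if F32 in float_types else F16
    return [a.astype(target) if a.dtype in AMP_FLOATS else a for a in inputs]


# Usage: one FP16 array, one FP32 array, and an integer array.
x16 = np.ones(3, dtype=np.float16)
x32 = np.zeros(3, dtype=np.float32)
idx = np.arange(3, dtype=np.int32)
print([o.dtype for o in amp_multicast_sketch(x16, x32, idx)])
print([o.dtype for o in amp_multicast_sketch(x16, x32, idx, cast_narrow=True)])
```

With the default setting, both float inputs come out as float32. With `cast_narrow=True`, both come out as float16. The integer array keeps its type in both cases.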

Usage

mx.symbol.amp_multicast(...)

Arguments

Argument

Description

data

NDArray-or-Symbol[].

Inputs to be cast.

num.outputs

int, required.

Number of input/output pairs to be cast to the widest type.

cast.narrow

boolean, optional, default=0.

Whether to cast to the narrowest type instead of the widest.

name

string, optional.

Name of the resulting symbol.