Image API

Overview

This document summarizes the supporting functions and iterators for reading and processing images. They are provided in the following module:

mxnet.image Image Iterators and image augmentation functions

Image processing functions

image.imdecode Decode an image to an NDArray.
image.scale_down Scales down crop size if it’s larger than image size.
image.resize_short Resizes shorter edge to size.
image.fixed_crop Crop src at fixed location, and (optionally) resize it to size.
image.random_crop Randomly crop src with size (width, height).
image.center_crop Crops the image src to the given size by trimming on all four sides and preserving the center of the image.
image.color_normalize Normalize src with mean and std.
image.random_size_crop Randomly crop src with size.

Image iterators

Iterators support loading images from binary RecordIO (.rec) files and from raw image files.

image.ImageIter Image data iterator with a large number of augmentation choices.
>>> import mxnet as mx
>>> data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 224, 224), label_width=1,
...                                path_imglist='data/custom.lst')
>>> data_iter.reset()
>>> for data in data_iter:
...     d = data.data[0]
...     print(d.shape)
...
>>> # many augmentations can be applied as well
>>> data_iter = mx.image.ImageIter(4, (3, 224, 224), path_imglist='data/custom.lst',
...                                rand_crop=True, rand_resize=True, rand_mirror=True, mean=True,
...                                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1,
...                                pca_noise=0.1, rand_gray=0.05)
>>> data = data_iter.next()
>>> # specifying augmenters manually is also supported
>>> data_iter = mx.image.ImageIter(32, (3, 224, 224), path_imgrec='data/caltech.rec',
...                                path_imgidx='data/caltech.idx', shuffle=True,
...                                aug_list=[mx.image.HorizontalFlipAug(0.5),
...                                          mx.image.ColorJitterAug(0.1, 0.1, 0.1)])

A helper function initializes the augmenters:

image.CreateAugmenter Creates an augmenter list.

The following augmenters are supported:

image.Augmenter Image Augmenter base class
image.SequentialAug Composes a sequential augmenter list.
image.RandomOrderAug Applies a list of augmenters in random order.
image.ResizeAug Resizes the shorter edge to size.
image.ForceResizeAug Forces a resize to size, regardless of aspect ratio.
image.RandomCropAug Random crop augmenter.
image.RandomSizedCropAug Random crop augmenter with random resizing and random aspect ratio jitter.
image.CenterCropAug Center crop augmenter.
image.BrightnessJitterAug Random brightness jitter augmentation.
image.ContrastJitterAug Random contrast jitter augmentation.
image.SaturationJitterAug Random saturation jitter augmentation.
image.HueJitterAug Random hue jitter augmentation.
image.ColorJitterAug Apply random brightness, contrast and saturation jitter in random order.
image.LightingAug Adds PCA-based noise.
image.ColorNormalizeAug Mean and std normalization.
image.RandomGrayAug Randomly converts to a grayscale image.
image.HorizontalFlipAug Random horizontal flip.
image.CastAug Casts to a given type (float32 by default).

ImageDetIter is similar to ImageIter, but is designed for object detection tasks.

image.ImageDetIter Image iterator with a large number of augmentation choices for detection.
>>> data_iter = mx.image.ImageDetIter(batch_size=4, data_shape=(3, 224, 224),
...                                   path_imglist='data/train.lst')
>>> data_iter.reset()
>>> for data in data_iter:
...     d = data.data[0]
...     l = data.label[0]
...     print(d.shape)
...     print(l.shape)

Unlike image classification, where every image has a fixed number of labels (label_width), in object detection the number of objects varies from image to image. Object detection labels therefore use a special format. Usually, the .lst file generated by tools/im2rec.py is a list of lines of the form

index_0  label_0  image_path_0
index_1  label_1  image_path_1

where label_N is a number or a fixed-width vector. The label format used for object detection is instead a variable-length vector:

A  B  [extra header]  [(object0), (object1), ... (objectN)]

where A is the width of the header (2 + the length of the extra header) and B is the width of each object. The extra header is optional and is used for inserting helper information such as (width, height). Each object is usually described by 5 or 6 numbers, for example [id, xmin, ymin, xmax, ymax, difficulty]. Putting it all together, a .lst file for object detection looks like this:

0  4  5  640  480  1  0.1  0.2  0.8  0.9  2  0.5  0.3  0.6  0.8  data/xxx.jpg
1  4  5  480  640  3  0.05  0.16  0.75  0.9  data/yyy.jpg
2  4  5  500  600  2  0.6  0.1  0.7  0.5  0  0.1  0.3  0.2  0.4  3  0.25  0.25  0.3  0.3 data/zzz.jpg
...
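To make the label layout concrete, here is a minimal pure-Python sketch that splits one detection .lst line into its index, extra header, per-object rows, and image path. It assumes whitespace-separated fields as in the sample lines above and that every object has exactly B values. The function name parse_det_lst_line is hypothetical and is not part of mxnet.image.

```python
def parse_det_lst_line(line):
    """Hypothetical helper: split one detection .lst line into its parts."""
    fields = line.split()
    index = int(fields[0])
    path = fields[-1]                       # the image path is always the last field
    label = [float(v) for v in fields[1:-1]]
    header_width = int(label[0])            # A: 2 + length of the extra header
    object_width = int(label[1])            # B: number of values per object
    extra_header = label[2:header_width]    # e.g. (width, height)
    body = label[header_width:]
    if len(body) % object_width != 0:
        raise ValueError("object values are not a multiple of B")
    objects = [body[i:i + object_width] for i in range(0, len(body), object_width)]
    return index, extra_header, objects, path


line = "0  4  5  640  480  1  0.1  0.2  0.8  0.9  2  0.5  0.3  0.6  0.8  data/xxx.jpg"
index, extra_header, objects, path = parse_det_lst_line(line)
print(index, extra_header, len(objects), path)
```

Because A is stored in the label itself, a reader can skip an extra header of any length without knowing what it contains.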

A helper function initializes augmenters for object detection tasks:

image.CreateDetAugmenter Create augmenters for detection.

Because object detection is sensitive to object localization, any image modification that shifts object locations requires a corresponding correction to the labels. The following augmenters are therefore provided specifically for object detection:

image.DetBorrowAug Borrow standard augmenter from image classification.
image.DetRandomSelectAug Randomly select one augmenter to apply, with chance to skip all.
image.DetHorizontalFlipAug Random horizontal flipping.
image.DetRandomCropAug Random cropping with constraints
image.DetRandomPadAug Random padding augmenter.

API Reference

Image Iterators and image augmentation functions

class mxnet.image.ImageIter(batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, data_name='data', label_name='softmax_label', **kwargs)

Image data iterator with a large number of augmentation choices. This iterator supports reading from both .rec files and raw image files.

To load input images from .rec files, use the path_imgrec parameter; to load from raw image files, use the path_imglist and path_root parameters.

To use data partitioning (for distributed training) or shuffling with a .rec source, specify the path_imgidx parameter.

Parameters:
  • batch_size (int) – Number of examples per batch.
  • data_shape (tuple) – Data shape in (channels, height, width) format. For now, only RGB images with 3 channels are supported.
  • label_width (int, optional) – Number of labels per example. The default label width is 1.
  • path_imgrec (str) – Path to image record file (.rec). Created with tools/im2rec.py or bin/im2rec.
  • path_imglist (str) – Path to image list (.lst). Created with tools/im2rec.py or with a custom script. Format: tab-separated record of index, one or more labels, and relative_path_from_root.
  • imglist (list) – A list of images with the label(s). Each item is a list [imagelabel: float or list of float, imgpath].
  • path_root (str) – Root folder of image files.
  • path_imgidx (str) – Path to image index file. Needed for partitioning and shuffling when using a .rec source.
  • shuffle (bool) – Whether to shuffle all images at the start of each iteration. Can be slow on HDDs.
  • part_index (int) – Partition index.
  • num_parts (int) – Total number of partitions.
  • data_name (str) – Data name for provided symbols.
  • label_name (str) – Label name for provided symbols.
  • kwargs – More arguments for creating augmenter. See mx.image.CreateAugmenter.
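The path_imglist format above can be produced with a few lines of Python. The sketch below writes tab-separated index, label(s), and relative path lines; the helper name write_lst, the file names, and the "%f" float formatting are illustrative assumptions, and tools/im2rec.py remains the usual way to create these files.

```python
import io


def write_lst(entries, fout):
    """Hypothetical helper: write (labels, relative_path) pairs as .lst lines."""
    for index, (labels, rel_path) in enumerate(entries):
        if not isinstance(labels, (list, tuple)):
            labels = [labels]               # a single label becomes a 1-element list
        fields = [str(index)] + ["%f" % float(x) for x in labels] + [rel_path]
        fout.write("\t".join(fields) + "\n")


# Hypothetical images (relative to path_root) with one label each.
entries = [(0, "cats/001.jpg"), (1, "dogs/002.jpg"), (1, "dogs/003.jpg")]
buf = io.StringIO()
write_lst(entries, buf)
print(buf.getvalue())
```

For label_width greater than 1, each entry would carry a list of exactly label_width labels.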
reset()

Resets the iterator to the beginning of the data.

next_sample()

Helper function for reading in the next sample.

next()

Returns the next batch of data.

check_data_shape(data_shape)

Checks whether the input data shape is valid.

check_valid_image(data)

Checks whether the input data is valid.

imdecode(s)

Decodes a string or byte string to an NDArray. See mx.img.imdecode for more details.

read_image(fname)

Reads the input image file fname (relative to path_root) and returns its raw, undecoded bytes.

>>> data_iter.read_image('Face.jpg') # returns the raw bytes of the file

augmentation_transform(data)

Transforms input data with specified augmentation.

postprocess_data(datum)

Final postprocessing step before image is loaded into the batch.

image.imdecode(buf, *args, **kwargs)

Decode an image to an NDArray.

Note: imdecode uses OpenCV (not the CV2 Python library). MXNet must have been built with USE_OPENCV=1 for imdecode to work.

Parameters:
  • buf (str/bytes or numpy.ndarray) – Binary image data as string or numpy ndarray.
  • flag (int, optional, default=1) – 1 for three channel color output. 0 for grayscale output.
  • to_rgb (int, optional, default=1) – 1 for RGB formatted output (MXNet default). 0 for BGR formatted output (OpenCV default).
  • out (NDArray, optional) – Output buffer. Use None for automatic allocation.
Returns:

An NDArray containing the image.

Return type:

NDArray

Example

>>> with open("flower.jpg", 'rb') as fp:
...     str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image)
>>> image

Set flag parameter to 0 to get grayscale output

>>> with open("flower.jpg", 'rb') as fp:
...     str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image, flag=0)
>>> image

Set to_rgb parameter to 0 to get output in OpenCV format (BGR)

>>> with open("flower.jpg", 'rb') as fp:
...     str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image, to_rgb=0)
>>> image
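The to_rgb flag only changes the channel order. As a rough illustration, the NumPy sketch below performs the equivalent BGR-to-RGB conversion on a plain array in (height, width, channels) layout; it does not call MXNet, and the tiny two-pixel image is made up for the example.

```python
import numpy as np

# A 1x2 "image" in BGR order (OpenCV's default): a blue pixel, then a red pixel.
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the channel axis swaps B and R, giving RGB order.
rgb = bgr[:, :, ::-1]
print(rgb[0, 0], rgb[0, 1])
```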

image.scale_down(src_size, size)

Scales down crop size if it’s larger than image size.

If the width or height of the crop is larger than that of the image, the crop size is scaled down, preserving its aspect ratio, so that it fits inside the image.

Parameters:
  • src_size (tuple of int) – Size of the image in (width, height) format.
  • size (tuple of int) – Size of the crop in (width, height) format.
Returns:

A tuple containing the scaled crop size in (width, height) format.

Return type:

tuple of int

Example

>>> src_size = (640,480)
>>> size = (720,120)
>>> new_size = mx.img.scale_down(src_size, size)
>>> new_size
(640, 106)
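A pure-Python sketch of the scaling rule that reproduces the example above. It is written from the behavior described in this section and may differ in detail from MXNet's own implementation.

```python
def scale_down(src_size, size):
    """Shrink crop `size` (w, h) so it fits in `src_size` (w, h), keeping its aspect ratio."""
    w, h = size
    sw, sh = src_size
    if sh < h:                       # crop is taller than the image
        w, h = float(w * sh) / h, sh
    if sw < w:                       # crop is (still) wider than the image
        w, h = sw, float(h * sw) / w
    return int(w), int(h)


print(scale_down((640, 480), (720, 120)))   # (640, 106)
```

The height 106 comes from 120 * 640 / 720 = 106.67, truncated to an integer.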
image.resize_short(src, size, interp=2)

Resizes shorter edge to size.

Note: resize_short uses OpenCV (not the CV2 Python library). MXNet must have been built with OpenCV for resize_short to work.

Resizes the original image by setting the shorter edge to size and scaling the longer edge accordingly, which preserves the aspect ratio. The resizing itself is performed by OpenCV.

Parameters:
  • src (NDArray) – The original image.
  • size (int) – The length to be set for the shorter edge.
  • interp (int, optional, default=2) – Interpolation method used for resizing the image. Possible values:
      0: Nearest neighbors interpolation.
      1: Bilinear interpolation.
      2: Area-based (resampling using pixel area relation). It may be the preferred method for image decimation, as it gives moire-free results, but when the image is zoomed, it is similar to the nearest neighbors method. (Used by default.)
      3: Bicubic interpolation over a 4x4 pixel neighborhood.
      4: Lanczos interpolation over an 8x8 pixel neighborhood.
      9: Cubic for enlarging, area for shrinking, bilinear for others.
      10: Random selection from the interpolation methods mentioned above.
    Note: When shrinking an image, it will generally look best with area-based interpolation, whereas when enlarging an image, it will generally look best with bicubic (slow) or bilinear (faster but still looks OK) interpolation. More details can be found in the OpenCV documentation: http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
Returns:

An ‘NDArray’ containing the resized image.

Return type:

NDArray

Example

>>> with open("flower.jpeg", 'rb') as fp:
...     str_image = fp.read()
...
>>> image = mx.img.imdecode(str_image)
>>> image

>>> size = 640
>>> new_image = mx.img.resize_short(image, size)
>>> new_image
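The output shape of resize_short follows from simple arithmetic. This sketch computes it for an image of a given height and width; the helper name short_edge_size is hypothetical, and the use of integer truncation for the longer edge is an assumption about rounding.

```python
def short_edge_size(height, width, size):
    """Return (new_height, new_width) with the shorter edge set to `size`."""
    if height > width:
        return size * height // width, size     # portrait: width is the short edge
    return size, size * width // height         # landscape or square: height is short


print(short_edge_size(480, 640, 240))   # (240, 320)
```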

image.fixed_crop(src, x0, y0, w, h, size=None, interp=2)

Crop src at fixed location, and (optionally) resize it to size.

Parameters:
  • src (NDArray) – Input image
  • x0 (int) – Left boundary of the cropping area
  • y0 (int) – Top boundary of the cropping area
  • w (int) – Width of the cropping area
  • h (int) – Height of the cropping area
  • size (tuple of (w, h)) – Optional, resize to new size after cropping
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
Returns:

An NDArray containing the cropped image.

Return type:

NDArray
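Note that x0 and w refer to columns while y0 and h refer to rows. The NumPy sketch below shows the corresponding slice on a (height, width, channels) array; it omits the optional resize step, and the name fixed_crop_np is hypothetical.

```python
import numpy as np


def fixed_crop_np(src, x0, y0, w, h):
    """Crop an HWC array: rows y0..y0+h, columns x0..x0+w, all channels."""
    return src[y0:y0 + h, x0:x0 + w, :]


img = np.arange(6 * 8 * 3).reshape(6, 8, 3)   # height 6, width 8, 3 channels
crop = fixed_crop_np(img, x0=2, y0=1, w=4, h=3)
print(crop.shape)   # (3, 4, 3)
```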

image.random_crop(src, size, interp=2)

Randomly crop src with size (width, height). Upsample result if src is smaller than size.

Parameters:
  • src (NDArray) – Source image.
  • size (tuple of int) – Size of the crop formatted as (width, height). If the size is larger than the image, the source image is upsampled to size and returned.
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
Returns:

  • NDArray – An NDArray containing the cropped image.
  • Tuple – A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the original image and (width, height) are the dimensions of the cropped image.

Example

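>>> import cv2
>>> import mxnet as mx
>>> im = mx.nd.array(cv2.imread("flower.jpg"))  # note: cv2.imread returns BGR
>>> cropped_im, rect = mx.image.random_crop(im, (100, 100))
>>> print(cropped_im)
>>> print(rect)  # the position is random and varies between calls
(20, 21, 100, 100)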
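The random part of random_crop is the choice of the top-left corner. Below is a minimal sketch of that step, assuming the crop size already fits inside the image (for example after scale_down); the helper name random_crop_rect is hypothetical.

```python
import random


def random_crop_rect(src_w, src_h, crop_w, crop_h, rng=random):
    """Pick a random top-left corner so a crop_w x crop_h window fits inside the image."""
    x0 = rng.randint(0, src_w - crop_w)    # randint includes both end points
    y0 = rng.randint(0, src_h - crop_h)
    return x0, y0, crop_w, crop_h


rect = random_crop_rect(640, 480, 100, 100, random.Random(0))
print(rect)
```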
image.center_crop(src, size, interp=2)

Crops the image src to the given size by trimming on all four sides and preserving the center of the image. Upsamples if src is smaller than size.

Note

This requires MXNet to be compiled with USE_OPENCV.

Parameters:
  • src (NDArray) – Source image.
  • size (list or tuple of int) – The desired output image size as (width, height).
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
Returns:

  • NDArray – The cropped image.
  • Tuple – (x, y, width, height) where x, y are the positions of the crop in the original image and width, height the dimensions of the crop.

Example

>>> with open("flower.jpg", 'rb') as fp:
...     str_image = fp.read()
...
>>> image = mx.image.imdecode(str_image)
>>> image

>>> cropped_image, (x, y, width, height) = mx.image.center_crop(image, (1000, 500))
>>> cropped_image

>>> x, y, width, height
(1241, 910, 1000, 500)
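The crop position returned by center_crop is the leftover margin split evenly between the two sides. A sketch of that arithmetic follows; the 3482x2320 source size is an assumption, chosen so that the result matches the example above, and the helper name center_crop_rect is hypothetical.

```python
def center_crop_rect(src_w, src_h, crop_w, crop_h):
    """Top-left corner and size of a centered crop (the crop is assumed to fit)."""
    x0 = (src_w - crop_w) // 2
    y0 = (src_h - crop_h) // 2
    return x0, y0, crop_w, crop_h


# Assumed source size of 3482x2320.
print(center_crop_rect(3482, 2320, 1000, 500))   # (1241, 910, 1000, 500)
```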
image.color_normalize(src, mean, std=None)

Normalize src with mean and std.

Parameters:
  • src (NDArray) – Input image
  • mean (NDArray) – RGB mean to be subtracted
  • std (NDArray, optional) – RGB standard deviation to divide by
Returns:

An NDArray containing the normalized image.

Return type:

NDArray
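Per-channel normalization is a broadcasted subtract-then-divide. The NumPy sketch below illustrates it on a (height, width, channels) array. The helper name color_normalize_np and the mean and std values are made up for the example and are not MXNet defaults.

```python
import numpy as np


def color_normalize_np(src, mean, std=None):
    """Subtract a per-channel mean and optionally divide by a per-channel std (HWC)."""
    out = src.astype(np.float32) - mean     # mean broadcasts over the last (channel) axis
    if std is not None:
        out = out / std
    return out


img = np.full((2, 2, 3), 128, dtype=np.uint8)
mean = np.array([120.0, 110.0, 100.0], dtype=np.float32)
std = np.array([2.0, 4.0, 8.0], dtype=np.float32)
result = color_normalize_np(img, mean, std)
print(result[0, 0])
```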

image.random_size_crop(src, size, min_area, ratio, interp=2)

Randomly crops src with a randomized area and aspect ratio, then resizes the crop to size.

Parameters:
  • src (NDArray) – Input image
  • size (tuple of (int, int)) – Size of the crop formatted as (width, height).
  • min_area (float) – Minimum fraction of the source image area to keep after cropping. The crop area is sampled between min_area and 1.0 times the source area.
  • ratio (tuple of (float, float)) – Aspect ratio range as (min_aspect_ratio, max_aspect_ratio)
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
Returns:

  • NDArray – An NDArray containing the cropped image.
  • Tuple – A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the original image and (width, height) are the dimensions of the cropped image.
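The sampling step can be sketched as follows. This is an illustration under stated assumptions, not MXNet's code: min_area is treated as a fraction of the source area in (0, 1], the aspect ratio is width / height, and the retry count of 10 is arbitrary. The helper name sample_size_crop is hypothetical; a real implementation also needs a fallback (such as a center crop) when no attempt fits.

```python
import math
import random


def sample_size_crop(src_w, src_h, min_area, ratio, rng=random, attempts=10):
    """Sample (x0, y0, w, h) with a random area fraction and aspect ratio, or return None."""
    area = src_w * src_h
    for _ in range(attempts):
        target_area = rng.uniform(min_area, 1.0) * area   # fraction of the source area
        aspect = rng.uniform(*ratio)                        # width / height
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= src_w and 0 < h <= src_h:
            x0 = rng.randint(0, src_w - w)
            y0 = rng.randint(0, src_h - h)
            return x0, y0, w, h
    return None   # caller should fall back, e.g. to a center crop


print(sample_size_crop(640, 480, 0.08, (3 / 4, 4 / 3), random.Random(42)))
```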

class mxnet.image.Augmenter(**kwargs)

Image Augmenter base class

dumps()

Saves the Augmenter to a string.

Returns: JSON-formatted string that describes the Augmenter.
Return type: str
class mxnet.image.ResizeAug(size, interp=2)

Makes an augmenter that resizes the shorter edge to size.

Parameters:
  • size (int) – The length to be set for the shorter edge.
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
class mxnet.image.ForceResizeAug(size, interp=2)

Forces a resize to size, regardless of aspect ratio.

Parameters:
  • size (tuple of (int, int)) – The desired size as in (width, height)
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
class mxnet.image.RandomCropAug(size, interp=2)

Makes a random crop augmenter.

Parameters:
  • size (tuple of (int, int)) – Size of the crop formatted as (width, height).
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
class mxnet.image.RandomSizedCropAug(size, min_area, ratio, interp=2)

Makes a random crop augmenter with random resizing and random aspect ratio jitter.

Parameters:
  • size (tuple of (int, int)) – Size of the crop formatted as (width, height).
  • min_area (float) – Minimum fraction of the source image area to keep after cropping.
  • ratio (tuple of (float, float)) – Aspect ratio range as (min_aspect_ratio, max_aspect_ratio)
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
class mxnet.image.CenterCropAug(size, interp=2)

Makes a center crop augmenter.

Parameters:
  • size (list or tuple of int) – The desired output image size as (width, height).
  • interp (int, optional, default=2) – Interpolation method. See resize_short for details.
class mxnet.image.RandomOrderAug(ts)

Applies a list of augmenters in random order.

Parameters: ts (list of augmenters) – A series of augmenters to be applied in random order.
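Random ordering matters because augmenters generally do not commute. A minimal sketch, assuming each augmenter is a callable that takes an image and returns a new one; the numeric stand-in augmenters and the helper name apply_in_random_order are hypothetical and chosen so the effect of the order is easy to see.

```python
import random


def apply_in_random_order(augmenters, src, rng=random):
    """Apply every augmenter exactly once, in a freshly shuffled order."""
    order = list(augmenters)
    rng.shuffle(order)
    for aug in order:
        src = aug(src)
    return src


def add_one(x):
    return x + 1


def double(x):
    return x * 2


# (3 + 1) * 2 = 8 or 3 * 2 + 1 = 7, depending on the order drawn.
results = {apply_in_random_order([add_one, double], 3, random.Random(seed)) for seed in range(20)}
print(sorted(results))
```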
class mxnet.image.BrightnessJitterAug(brightness)

Random brightness jitter augmentation.

Parameters: brightness (float) – The brightness jitter ratio range, [0, 1]
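One common way to implement brightness jitter is to multiply every pixel by a random factor drawn from [1 - brightness, 1 + brightness]. The NumPy sketch below shows that idea; it is an assumption that BrightnessJitterAug works exactly this way, and the helper name brightness_jitter is hypothetical.

```python
import random

import numpy as np


def brightness_jitter(src, brightness, rng=random):
    """Scale all pixel values by a random factor in [1 - brightness, 1 + brightness]."""
    alpha = 1.0 + rng.uniform(-brightness, brightness)
    return src.astype(np.float32) * alpha, alpha


img = np.full((2, 2, 3), 100, dtype=np.uint8)
out, alpha = brightness_jitter(img, 0.2, random.Random(1))
print(alpha, out[0, 0])
```

The contrast and saturation jitters follow the same pattern of a random factor around 1, but blend the image with a gray version instead of scaling it directly.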
class mxnet.image.ContrastJitterAug(contrast)

Random contrast jitter augmentation.

Parameters: contrast (float) – The contrast jitter ratio range, [0, 1]
class mxnet.image.SaturationJitterAug(saturation)

Random saturation jitter augmentation.

Parameters: saturation (float) – The saturation jitter ratio range, [0, 1]
class mxnet.image.HueJitterAug(hue)

Random hue jitter augmentation.

Parameters: hue (float) – The hue jitter ratio range, [0, 1]
class mxnet.image.ColorJitterAug(brightness, contrast, saturation)

Apply random brightness, contrast and saturation jitter in random order.

Parameters:
  • brightness (float) – The brightness jitter ratio range, [0, 1]
  • contrast (float) – The contrast jitter ratio range, [0, 1]
  • saturation (float) – The saturation jitter ratio range, [0, 1]
class mxnet.image.LightingAug(alphastd, eigval, eigvec)

Adds PCA-based noise.

Parameters:
  • alphastd (float) – Noise level
  • eigval (3x1 np.array) – Eigenvalues
  • eigvec (3x3 np.array) – Eigenvectors
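PCA-based lighting noise adds one random RGB offset to every pixel, built from the principal components of the RGB pixel distribution. The NumPy sketch below assumes eigvec stores eigenvectors as columns; the eigenvalues and eigenvectors are commonly quoted ImageNet statistics, used here only as illustrative input. The helper name pca_lighting_noise is hypothetical.

```python
import numpy as np


def pca_lighting_noise(src, alphastd, eigval, eigvec, rng):
    """Add the same random RGB offset, built from the PCA basis, to every pixel (HWC)."""
    alpha = rng.normal(0.0, alphastd, size=(3,))   # one random weight per component
    rgb = eigvec @ (alpha * eigval)                # sum_i alpha_i * eigval_i * eigvec[:, i]
    return src.astype(np.float32) + rgb


eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.8140],
                   [-0.5836, -0.6948, 0.4203]])
img = np.full((4, 4, 3), 128, dtype=np.uint8)
out = pca_lighting_noise(img, 0.1, eigval, eigvec, np.random.default_rng(0))
print(out[0, 0])
```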
class mxnet.image.ColorNormalizeAug(mean, std)

Mean and std normalization.

Parameters:
  • mean (NDArray) – RGB mean to be subtracted
  • std (NDArray) – RGB standard deviation to divide by
class mxnet.image.RandomGrayAug(p)

Randomly converts to a grayscale image.

Parameters: p (float) – Probability of converting to grayscale
class mxnet.image.HorizontalFlipAug(p)

Random horizontal flip.

Parameters: p (float) – Probability of flipping the image horizontally
class mxnet.image.CastAug(typ='float32')

Casts the image to type typ (float32 by default).

image.CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False, mean=None, std=None, brightness=0, contrast=0, saturation=0, hue=0, pca_noise=0, rand_gray=0, inter_method=2)

Creates an augmenter list.

Parameters:
  • data_shape (tuple of int) – Shape for output data
  • resize (int) – If larger than 0, resize the shorter edge to this value at the beginning
  • rand_crop (bool) – Whether to use random cropping instead of center cropping
  • rand_resize (bool) – Whether to enable random sized cropping; requires rand_crop to be enabled
  • rand_gray (float) – [0, 1], probability to convert to grayscale for all channels; the number of channels is not reduced to 1
  • rand_mirror (bool) – Whether to apply horizontal flip to image with probability 0.5
  • mean (np.ndarray, True, or None) – Mean pixel values for [r, g, b]; True selects built-in default values
  • std (np.ndarray, True, or None) – Standard deviations for [r, g, b]; True selects built-in default values
  • brightness (float) – Brightness jittering range (percent)
  • contrast (float) – Contrast jittering range (percent)
  • saturation (float) – Saturation jittering range (percent)
  • hue (float) – Hue jittering range (percent)
  • pca_noise (float) – PCA noise level (percent)
  • inter_method (int, default=2 (area-based)) –

    Interpolation method for all resizing operations

    Possible values:
      0: Nearest neighbors interpolation.
      1: Bilinear interpolation.
      2: Area-based (resampling using pixel area relation). It may be the preferred method for image decimation, as it gives moire-free results, but when the image is zoomed, it is similar to the nearest neighbors method. (Used by default.)
      3: Bicubic interpolation over a 4x4 pixel neighborhood.
      4: Lanczos interpolation over an 8x8 pixel neighborhood.
      9: Cubic for enlarging, area for shrinking, bilinear for others.
      10: Random selection from the interpolation methods mentioned above.
    Note: When shrinking an image, it will generally look best with area-based interpolation, whereas when enlarging an image, it will generally look best with bicubic (slow) or bilinear (faster but still looks OK) interpolation.

Examples

>>> # An example of creating multiple augmenters
>>> augs = mx.image.CreateAugmenter(data_shape=(3, 300, 300), rand_mirror=True,
...    mean=True, brightness=0.125, contrast=0.125, rand_gray=0.05,
...    saturation=0.125, pca_noise=0.05, inter_method=10)
>>> # dump the details
>>> for aug in augs:
...    aug.dumps()
class mxnet.image.ImageDetIter(batch_size, data_shape, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, data_name='data', label_name='label', **kwargs)

Image iterator with a large number of augmentation choices for detection.

Parameters:
  • aug_list (list or None) – Augmenter list for generating distorted images
  • batch_size (int) – Number of examples per batch.
  • data_shape (tuple) – Data shape in (channels, height, width) format. For now, only RGB images with 3 channels are supported.
  • path_imgrec (str) – Path to image record file (.rec). Created with tools/im2rec.py or bin/im2rec.
  • path_imglist (str) – Path to image list (.lst). Created with tools/im2rec.py or with a custom script. Format: tab-separated record of index, one or more labels, and relative_path_from_root.
  • imglist (list) – A list of images with the label(s). Each item is a list [imagelabel: float or list of float, imgpath].
  • path_root (str) – Root folder of image files.
  • path_imgidx (str) – Path to image index file. Needed for partitioning and shuffling when using a .rec source.
  • shuffle (bool) – Whether to shuffle all images at the start of each iteration. Can be slow on HDDs.
  • part_index (int) – Partition index.
  • num_parts (int) – Total number of partitions.
  • data_name (str) – Data name for provided symbols.
  • label_name (str) – Name for detection labels
  • kwargs – More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
reshape(data_shape=None, label_shape=None)

Reshapes the iterator for a new data_shape and/or label_shape.

Parameters:
  • data_shape (tuple or None) – Reshape the data_shape to the new shape if not None
  • label_shape (tuple or None) – Reshape label shape to new shape if not None
next()

Returns the next batch. Overrides ImageIter.next().

augmentation_transform(data, label)

Transforms input data and labels with the specified augmentations. Overrides ImageIter.augmentation_transform().

check_label_shape(label_shape)

Checks whether the new label shape is valid.

draw_next(color=None, thickness=2, mean=None, std=None, clip=True, waitKey=None, window_name='draw_next')

Display next image with bounding boxes drawn.

Parameters:
  • color (tuple) – Bounding box color in RGB, use None for random color
  • thickness (int) – Bounding box border thickness
  • mean (True or numpy.ndarray) – Compensate for the mean to have better visual effect
  • std (True or numpy.ndarray) – Revert standard deviations
  • clip (bool) – If true, clip to [0, 255] for better visual effect
  • waitKey (None or int) – Hold the window for waitKey milliseconds if set; skip plotting if None
  • window_name (str) – Plot window name if waitKey is set.
Returns:

Images with bounding boxes drawn, yielded one at a time.

Return type:

numpy.ndarray

Examples

>>> # use draw_next to get images with bounding boxes drawn
>>> iterator = mx.image.ImageDetIter(1, (3, 600, 600), path_imgrec='train.rec')
>>> for image in iterator.draw_next(waitKey=None):
...     pass  # display or save the image here
...
>>> # or let draw_next display the images using the cv2 module
>>> iterator.reset()
>>> for image in iterator.draw_next(waitKey=0, window_name='disp'):
...     pass
...
sync_label_shape(it, verbose=False)

Synchronize label shape with the input iterator. This is useful when train/validation iterators have different label padding.

Parameters:
  • it (ImageDetIter) – The other iterator to synchronize
  • verbose (bool) – Print verbose log if true
Returns:

The synchronized other iterator. The label shape of this iterator is updated as well, if needed.

Return type:

ImageDetIter

Examples

>>> train_iter = mx.image.ImageDetIter(32, (3, 300, 300), path_imgrec='train.rec')
>>> val_iter = mx.image.ImageDetIter(32, (3, 300, 300), path_imgrec='val.rec')
>>> train_iter.label_shape
(30, 6)
>>> val_iter.label_shape
(25, 6)
>>> val_iter = train_iter.sync_label_shape(val_iter, verbose=False)
>>> train_iter.label_shape
(30, 6)
>>> val_iter.label_shape
(30, 6)
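The synchronization shown above amounts to padding both label shapes to the larger object count. A pure-Python sketch of that arithmetic on (max_objects, object_width) tuples follows. The helper name sync_label_shapes is hypothetical, and rejecting mismatched object widths is an assumption of this sketch.

```python
def sync_label_shapes(shape_a, shape_b):
    """Pad both (max_objects, object_width) label shapes to the larger object count."""
    if shape_a[1] != shape_b[1]:
        raise ValueError("object widths differ: %d vs %d" % (shape_a[1], shape_b[1]))
    max_count = max(shape_a[0], shape_b[0])
    return (max_count, shape_a[1]), (max_count, shape_b[1])


print(sync_label_shapes((30, 6), (25, 6)))   # ((30, 6), (30, 6))
```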
class mxnet.image.DetAugmenter(**kwargs)

Base class for detection augmenters.

dumps()

Saves the Augmenter to a string.

Returns: JSON-formatted string that describes the Augmenter.
Return type: str
class mxnet.image.DetBorrowAug(augmenter)

Borrows a standard augmenter from image classification. Use it only when you know the augmenter does not affect the labels.

Parameters: augmenter (mx.image.Augmenter) – The borrowed standard augmenter, which has no effect on the labels
class mxnet.image.DetRandomSelectAug(aug_list, skip_prob=0)

Randomly select one augmenter to apply, with chance to skip all.

Parameters:
  • aug_list (list of DetAugmenter) – The random selection will be applied to one of the augmenters
  • skip_prob (float) – The probability to skip all augmenters and return input directly
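The selection logic can be sketched in a few lines, assuming each detection augmenter is a callable taking and returning an (image, label) pair, as augmentation_transform(data, label) suggests. The string "images" and the stand-in augmenters tag_a and tag_b are hypothetical, as is the helper name random_select.

```python
import random


def random_select(aug_list, src, label, skip_prob=0.0, rng=random):
    """With probability skip_prob return the input unchanged; otherwise apply one random augmenter."""
    if not aug_list or rng.random() < skip_prob:
        return src, label
    aug = rng.choice(aug_list)
    return aug(src, label)


def tag_a(src, label):
    return src + "A", label


def tag_b(src, label):
    return src + "B", label


print(random_select([tag_a, tag_b], "img", [], skip_prob=0.5, rng=random.Random(3)))
```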
class mxnet.image.DetHorizontalFlipAug(p)

Random horizontal flipping.

Parameters: p (float) – Probability, in [0, 1], of flipping
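Flipping the image also requires flipping the boxes. With normalized coordinates in [0, 1], as in the .lst example earlier, the new xmin is 1 - xmax and the new xmax is 1 - xmin, while y coordinates are unchanged. A pure-Python sketch of that label correction, assuming rows of the form [id, xmin, ymin, xmax, ymax, ...]; the helper name flip_boxes is hypothetical.

```python
def flip_boxes(label):
    """Mirror normalized [id, xmin, ymin, xmax, ymax, ...] rows for a horizontally flipped image."""
    flipped = []
    for row in label:
        cls, xmin, ymin, xmax, ymax = row[:5]
        # The left edge of the flipped box comes from the old right edge, and vice versa.
        flipped.append([cls, 1.0 - xmax, ymin, 1.0 - xmin, ymax] + list(row[5:]))
    return flipped


print(flip_boxes([[1, 0.1, 0.2, 0.8, 0.9, 0]]))
```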
class mxnet.image.DetRandomCropAug(min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 1.0), min_eject_coverage=0.3, max_attempts=50)

Random cropping with constraints

Parameters:
  • min_object_covered (float, default=0.1) – The cropped area of the image must contain at least this fraction of any bounding box supplied. The value of this parameter should be non-negative. In the case of 0, the cropped area does not need to overlap any of the bounding boxes supplied.
  • min_eject_coverage (float, default=0.3) – The minimum coverage of a cropped object with respect to its original size. With this constraint, objects that retain only a marginal area after cropping are discarded.
  • aspect_ratio_range (tuple of floats, default=(0.75, 1.33)) – The cropped area of the image must have an aspect ratio = width / height within this range.
  • area_range (tuple of floats, default=(0.05, 1.0)) – The cropped area of the image must contain a fraction of the supplied image within this range.
  • max_attempts (int, default=50) – Number of attempts at generating a cropped/padded region of the image of the specified constraints. After max_attempts failures, return the original image.
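One plausible reading of min_eject_coverage is sketched below: for each box, compute the fraction of its area that lies inside the crop, and drop the object if that fraction is below the threshold. Boxes and the crop are (xmin, ymin, xmax, ymax) in normalized coordinates. This is an illustrative assumption; MXNet's internal computation may differ, for example in how surviving boxes are clipped and renormalized. The helper names box_coverage and keep_objects are hypothetical.

```python
def box_coverage(box, crop):
    """Fraction of `box` area that lies inside `crop`; both are (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(box[2], crop[2]) - max(box[0], crop[0]))
    iy = max(0.0, min(box[3], crop[3]) - max(box[1], crop[1]))
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    return ix * iy / box_area if box_area > 0 else 0.0


def keep_objects(boxes, crop, min_eject_coverage=0.3):
    """Drop objects whose remaining area after cropping is below min_eject_coverage."""
    return [b for b in boxes if box_coverage(b, crop) >= min_eject_coverage]


crop = (0.0, 0.0, 0.5, 0.5)
boxes = [(0.1, 0.1, 0.3, 0.3),    # fully inside the crop: coverage 1.0
         (0.4, 0.4, 0.8, 0.8)]    # only a 0.1 x 0.1 corner inside: coverage 0.0625
print(keep_objects(boxes, crop))  # [(0.1, 0.1, 0.3, 0.3)]
```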
class mxnet.image.DetRandomPadAug(aspect_ratio_range=(0.75, 1.33), area_range=(1.0, 3.0), max_attempts=50, pad_val=(128, 128, 128))

Random padding augmenter.

Parameters:
  • aspect_ratio_range (tuple of floats, default=(0.75, 1.33)) – The padded area of the image must have an aspect ratio = width / height within this range.
  • area_range (tuple of floats, default=(1.0, 3.0)) – The padded area of the image must be larger than the original area
  • max_attempts (int, default=50) – Number of attempts at generating a padded region of the image of the specified constraints. After max_attempts failures, return the original image.
  • pad_val (float or tuple of float, default=(128, 128, 128)) – pixel value to be filled when padding is enabled.
image.CreateDetAugmenter(data_shape, resize=0, rand_crop=0, rand_pad=0, rand_gray=0, rand_mirror=False, mean=None, std=None, brightness=0, contrast=0, saturation=0, pca_noise=0, hue=0, inter_method=2, min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 3.0), min_eject_coverage=0.3, max_attempts=50, pad_val=(127, 127, 127))

Create augmenters for detection.

Parameters:
  • data_shape (tuple of int) – Shape for output data
  • resize (int) – If larger than 0, resize the shorter edge to this value at the beginning
  • rand_crop (float) – [0, 1], probability to apply random cropping
  • rand_pad (float) – [0, 1], probability to apply random padding
  • rand_gray (float) – [0, 1], probability to convert to grayscale for all channels
  • rand_mirror (bool) – Whether to apply horizontal flip to image with probability 0.5
  • mean (np.ndarray or None) – Mean pixel values for [r, g, b]
  • std (np.ndarray or None) – Standard deviations for [r, g, b]
  • brightness (float) – Brightness jittering range (percent)
  • contrast (float) – Contrast jittering range (percent)
  • saturation (float) – Saturation jittering range (percent)
  • hue (float) – Hue jittering range (percent)
  • pca_noise (float) – PCA noise level (percent)
  • inter_method (int, default=2 (area-based)) –

    Interpolation method for all resizing operations

    Possible values:
      0: Nearest neighbors interpolation.
      1: Bilinear interpolation.
      2: Area-based (resampling using pixel area relation). It may be the preferred method for image decimation, as it gives moire-free results, but when the image is zoomed, it is similar to the nearest neighbors method. (Used by default.)
      3: Bicubic interpolation over a 4x4 pixel neighborhood.
      4: Lanczos interpolation over an 8x8 pixel neighborhood.
      9: Cubic for enlarging, area for shrinking, bilinear for others.
      10: Random selection from the interpolation methods mentioned above.
    Note: When shrinking an image, it will generally look best with area-based interpolation, whereas when enlarging an image, it will generally look best with bicubic (slow) or bilinear (faster but still looks OK) interpolation.

  • min_object_covered (float or list of float) – The cropped area of the image must contain at least this fraction of any bounding box supplied. The value of this parameter should be non-negative. In the case of 0, the cropped area does not need to overlap any of the bounding boxes supplied.
  • min_eject_coverage (float) – The minimum coverage of a cropped object with respect to its original size. With this constraint, objects that retain only a marginal area after cropping are discarded.
  • aspect_ratio_range (tuple of floats) – The cropped area of the image must have an aspect ratio = width / height within this range.
  • area_range (tuple of floats) – The cropped area of the image must contain a fraction of the supplied image within this range.
  • max_attempts (int) – Number of attempts at generating a cropped/padded region of the image of the specified constraints. After max_attempts failures, return the original image.
  • pad_val (float or tuple of float) – Pixel value to be filled when padding is enabled. If applicable, pad_val is automatically normalized: the mean is subtracted and the result is divided by std.

Examples

>>> # An example of creating multiple augmenters
>>> augs = mx.image.CreateDetAugmenter(data_shape=(3, 300, 300), rand_crop=0.5,
...    rand_pad=0.5, rand_mirror=True, mean=True, brightness=0.125, contrast=0.125,
...    saturation=0.125, pca_noise=0.05, inter_method=10, min_object_covered=[0.3, 0.5, 0.9],
...    area_range=(0.3, 3.0))
>>> # dump the details
>>> for aug in augs:
...    aug.dumps()