001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.List;
013
014/**
015 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset
016 */
017public class BroadcastSingleIterator extends BroadcastSelfIterator {
018        private int[] bShape;
019        private int[] aStride;
020        private int[] bStride;
021
022        final private int endrank;
023
024        private final int[] aDelta, bDelta;
025        private final int aStep, bStep;
026        private int aMax, bMax;
027        private int aStart, bStart;
028
029        /**
030         * @param a dataset to iterate over
031         * @param b dataset to iterate over (will broadcast to first)
032         */
033        public BroadcastSingleIterator(Dataset a, Dataset b) {
034                super(a, b);
035
036                int[] aShape = a.getShapeRef();
037                maxShape = aShape;
038                List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef());
039                bShape = fullShapes.remove(0);
040
041                int rank = maxShape.length;
042                endrank = rank - 1;
043
044                bDataset = b.reshape(bShape);
045                int[] aOffset = new int[1];
046                aStride = AbstractDataset.createStrides(aDataset, aOffset);
047                bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);
048
049                pos = new int[rank];
050                aDelta = new int[rank];
051                aStep = aDataset.getElementsPerItem();
052                bDelta = new int[rank];
053                bStep = bDataset.getElementsPerItem();
054                for (int j = endrank; j >= 0; j--) {
055                        aDelta[j] = aStride[j] * aShape[j];
056                        bDelta[j] = bStride[j] * bShape[j];
057                }
058                aStart = aOffset[0];
059                bStart = bDataset.getOffset();
060                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
061                bMax = endrank < 0 ? bStep + bStart: Integer.MIN_VALUE;
062                reset();
063        }
064
065        @Override
066        public boolean hasNext() {
067                int j = endrank;
068                int oldB = bIndex;
069                for (; j >= 0; j--) {
070                        pos[j]++;
071                        aIndex += aStride[j];
072                        bIndex += bStride[j];
073                        if (pos[j] >= maxShape[j]) {
074                                pos[j] = 0;
075                                aIndex -= aDelta[j]; // reset these dimensions
076                                bIndex -= bDelta[j];
077                        } else {
078                                break;
079                        }
080                }
081                if (j == -1) {
082                        if (endrank >= 0) {
083                                return false;
084                        }
085                        aIndex += aStep;
086                        bIndex += bStep;
087                }
088
089                if (aIndex == aMax || bIndex == bMax) {
090                        return false;
091                }
092
093                if (read) {
094                        if (oldB != bIndex) {
095                                if (asDouble) {
096                                        bDouble = bDataset.getElementDoubleAbs(bIndex);
097                                } else {
098                                        bLong = bDataset.getElementLongAbs(bIndex);
099                                }
100                        }
101                }
102
103                return true;
104        }
105
106        /**
107         * @return shape of first broadcasted dataset
108         */
109        public int[] getFirstShape() {
110                return maxShape;
111        }
112
113        /**
114         * @return shape of second broadcasted dataset
115         */
116        public int[] getSecondShape() {
117                return bShape;
118        }
119
120        @Override
121        public void reset() {
122                for (int i = 0; i <= endrank; i++) {
123                        pos[i] = 0;
124                }
125
126                if (endrank >= 0) {
127                        pos[endrank] = -1;
128                        aIndex = aStart - aStride[endrank];
129                        bIndex = bStart - bStride[endrank];
130                } else {
131                        aIndex = aStart - aStep;
132                        bIndex = bStart - bStep;
133                }
134
135                if (aIndex == 0 || bIndex == 0 || (endrank >= 0 && bStride[endrank] == 0)) { // for zero-ranked datasets or extended shape
136                        if (read) {
137                                storeCurrentValues();
138                        }
139                }
140        }
141}