/*-
 * Copyright 2016 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package org.eclipse.january.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers for broadcasting: padding and matching shapes, creating broadcast
 * strides, checking item-size compatibility and converting objects to datasets
 * broadcast to a common shape.
 */
public final class BroadcastUtils {

	/**
	 * Calculate shapes for broadcasting.
	 * <p>
	 * The shorter of the two shapes is prefixed with ones so both have the same rank.
	 * Then, in every dimension, the two sizes must be equal or one of them must be 1.
	 * A zero-rank new shape is only accepted when {@code size} is 1.
	 * <p>
	 * The returned arrays may be the very instances passed in when no padding was needed.
	 * @param oldShape shape to broadcast from
	 * @param size size associated with the old shape; only consulted when {@code newShape} is zero-rank
	 * @param newShape shape to broadcast to
	 * @return broadcasted (padded) old shape and full (padded) new shape, or null if it cannot be done
	 */
	public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
		if (newShape == null) {
			return null;
		}

		if (newShape.length == 0) {
			return size == 1 ? new int[][] {oldShape, newShape} : null;
		}

		if (Arrays.equals(oldShape, newShape)) {
			return new int[][] {oldShape, newShape};
		}

		int offset = newShape.length - oldShape.length;
		if (offset < 0) { // when new shape is incomplete, prefix it with ones
			newShape = padShape(newShape, -offset);
			offset = 0;
		}

		// prefix old shape with ones when new shape has extra dimensions (no-op if offset is zero)
		final int[] bshape = padShape(oldShape, offset);

		// Check every dimension of the padded shapes. The bound must be the padded rank:
		// using the rank passed in would skip the trailing (real) dimensions whenever
		// the new shape had to be padded.
		for (int i = 0; i < newShape.length; i++) {
			if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
				return null;
			}
		}

		return new int[][] {bshape, newShape};
	}

	/**
	 * Pad shape by prefixing with ones
	 * @param shape shape to pad
	 * @param padding number of leading ones to add
	 * @return new shape or old shape (same instance) if padding is zero
	 * @throws IllegalArgumentException if padding is negative
	 */
	public static int[] padShape(final int[] shape, final int padding) {
		if (padding < 0) {
			throw new IllegalArgumentException("Padding must be zero or greater");
		}

		if (padding == 0) {
			return shape;
		}

		final int[] nshape = new int[shape.length + padding];
		Arrays.fill(nshape, 1);
		System.arraycopy(shape, 0, nshape, padding, shape.length);
		return nshape;
	}

	/**
	 * Take in shapes and broadcast them to same rank.
	 * <p>
	 * Null shapes are skipped. In each dimension all sizes must be 1 or equal to a common
	 * size, and that common size becomes the maximum shape's size. A dimension where some
	 * shapes are 1 and others are 0 broadcasts to 0. With no (non-null) shapes, the
	 * maximum shape is empty (zero rank).
	 * @param shapes shapes to broadcast
	 * @return list of broadcasted shapes plus the first entry is the maximum shape
	 * @throws IllegalArgumentException if a shape's dimension is neither one nor equal to the maximum
	 */
	public static List<int[]> broadcastShapes(int[]... shapes) {
		// start at zero so that empty or all-null input gives a zero-rank maximum shape
		int maxRank = 0;
		for (int[] s : shapes) {
			if (s != null && s.length > maxRank) {
				maxRank = s.length;
			}
		}

		List<int[]> newShapes = new ArrayList<int[]>();
		for (int[] s : shapes) {
			if (s != null) {
				newShapes.add(padShape(s, maxRank - s.length));
			}
		}

		int[] maxShape = new int[maxRank];
		for (int i = 0; i < maxRank; i++) {
			// The broadcast size is the single non-1 size present (1 if all are 1).
			// Treating 1 specially, rather than taking a numeric maximum, lets zero-length
			// dimensions broadcast against 1.
			int m = 1;
			for (int[] s : newShapes) {
				int l = s[i];
				if (l == 1 || l == m) {
					continue;
				}
				if (m != 1) {
					throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
				}
				m = l;
			}
			maxShape[i] = m;
		}

		checkShapes(maxShape, newShapes);
		newShapes.add(0, maxShape);
		return newShapes;
	}

	/**
	 * Take in shapes and broadcast them to maximum shape
	 * @param maxShape maximum shape to broadcast to
	 * @param shapes shapes to broadcast (null entries are skipped)
	 * @return list of broadcasted shapes
	 * @throws IllegalArgumentException if a shape has higher rank than the maximum shape or
	 * a dimension is neither one nor equal to the maximum
	 */
	public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
		int maxRank = maxShape.length;
		for (int[] s : shapes) {
			if (s != null && s.length > maxRank) {
				throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
			}
		}

		List<int[]> newShapes = new ArrayList<int[]>();
		for (int[] s : shapes) {
			if (s != null) {
				newShapes.add(padShape(s, maxRank - s.length));
			}
		}

		checkShapes(maxShape, newShapes);
		return newShapes;
	}

	/**
	 * Verify that every dimension of every (already padded) shape is one or equal to
	 * the corresponding dimension of the maximum shape.
	 * @param maxShape maximum shape
	 * @param newShapes shapes with the same rank as maxShape
	 * @throws IllegalArgumentException if a dimension does not match
	 */
	private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
		for (int i = 0; i < maxShape.length; i++) {
			int m = maxShape[i];
			for (int[] s : newShapes) {
				int l = s[i];
				if (l != 1 && l != m) {
					throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
				}
			}
		}
	}

	/**
	 * Create a zero-filled dataset to hold the result of a binary operation on a and b.
	 * <p>
	 * Its class is the best common interface of both inputs. The exception is when
	 * exactly one input is zero-rank and has no floating point elements: then the
	 * other input's class is used. Its number of elements per item is the larger of
	 * the two inputs'.
	 * @param a first input
	 * @param b second input
	 * @param shape shape of the result
	 * @return new dataset
	 */
	static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
		final Class<? extends Dataset> rc;
		final int ar = a.getRank();
		final int br = b.getRank();
		Class<? extends Dataset> tc = InterfaceUtils.getBestInterface(a.getClass(), b.getClass());
		if ((ar == 0) ^ (br == 0)) { // ignore type of zero-rank dataset unless it's floating point
			if (ar == 0) {
				rc = a.hasFloatingPointElements() ? tc : b.getClass();
			} else {
				rc = b.hasFloatingPointElements() ? tc : a.getClass();
			}
		} else {
			rc = tc;
		}
		final int ia = a.getElementsPerItem();
		final int ib = b.getElementsPerItem();

		return DatasetFactory.zeros(Math.max(ia, ib), rc, shape);
	}

	/**
	 * Check if dataset item sizes are compatible
	 * <p>
	 * Dataset a is considered compatible with the output dataset if any of the
	 * conditions are true:
	 * <ul>
	 * <li>o is undefined</li>
	 * <li>a has item size equal to o's</li>
	 * <li>a has item size equal to 1</li>
	 * <li>o has item size equal to 1</li>
	 * </ul>
	 * @param a input dataset a
	 * @param o output dataset (can be null)
	 * @throws IllegalArgumentException if item sizes are incompatible
	 */
	static void checkItemSize(Dataset a, Dataset o) {
		final int isa = a.getElementsPerItem();
		if (o != null) {
			final int iso = o.getElementsPerItem();
			if (isa != iso && isa != 1 && iso != 1) {
				throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
			}
		}
	}

	/**
	 * Check if dataset item sizes are compatible
	 * <p>
	 * Datasets a and b are considered compatible with each other if any of the
	 * conditions are true:
	 * <ul>
	 * <li>a has item size equal to b's</li>
	 * <li>a has item size equal to 1</li>
	 * <li>b has item size equal to 1</li>
	 * <li>a or b are single-valued</li>
	 * </ul>
	 * and, o is undefined, o is a boolean dataset, or any of the following are true:
	 * <ul>
	 * <li>o has item size equal to maximum of a and b's</li>
	 * <li>o has item size equal to 1</li>
	 * <li>a and b have item sizes of 1</li>
	 * </ul>
	 * @param a input dataset a
	 * @param b input dataset b
	 * @param o output dataset (can be null)
	 * @throws IllegalArgumentException if item sizes are incompatible
	 */
	static void checkItemSize(Dataset a, Dataset b, Dataset o) {
		final int isa = a.getElementsPerItem();
		final int isb = b.getElementsPerItem();
		// mismatched multi-element items are only allowed when either dataset is single-valued
		if (isa != isb && isa != 1 && isb != 1 && b.getSize() != 1 && a.getSize() != 1) {
			throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
		}
		// boolean outputs are exempt from the output item size check
		if (o != null && o.getDType() != Dataset.BOOL) {
			final int ism = Math.max(isa, isb);
			final int iso = o.getElementsPerItem();
			if (iso != ism && iso != 1 && ism != 1) {
				throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
			}
		}
	}

	/**
	 * Create a stride array from a dataset to a broadcast shape
	 * @param a dataset
	 * @param broadcastShape shape to broadcast to (must have the same rank as a)
	 * @return stride array
	 */
	public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
		return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
	}

	/**
	 * Create a stride array from a dataset to a broadcast shape.
	 * <p>
	 * Dimensions whose original size differs from the broadcast size get a stride of zero.
	 * Otherwise the original stride is kept. When no original stride is given,
	 * row-major strides are computed from the shape, counting {@code isize} elements
	 * per item.
	 * @param isize number of elements per item
	 * @param oShape original shape
	 * @param oStride original stride (can be null)
	 * @param broadcastShape shape to broadcast to
	 * @return stride array
	 * @throws IllegalArgumentException if ranks differ
	 */
	public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
		int rank = oShape.length;
		if (broadcastShape.length != rank) {
			throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
		}

		int[] stride = new int[rank];
		if (oStride == null) {
			int s = isize;
			for (int j = rank - 1; j >= 0; j--) {
				if (broadcastShape[j] == oShape[j]) {
					stride[j] = s;
					s *= oShape[j];
				} else {
					stride[j] = 0;
				}
			}
		} else {
			for (int j = 0; j < rank; j++) {
				stride[j] = broadcastShape[j] == oShape[j] ? oStride[j] : 0;
			}
		}

		return stride;
	}

	/**
	 * Converts and broadcast all objects as datasets of same shape
	 * @param objects objects to convert to datasets
	 * @return all as broadcasted to same shape
	 * @throws IllegalArgumentException if the shapes cannot be broadcast together
	 */
	public static Dataset[] convertAndBroadcast(Object... objects) {
		final int n = objects.length;

		Dataset[] datasets = new Dataset[n];
		int[][] shapes = new int[n][];
		for (int i = 0; i < n; i++) {
			Dataset d = DatasetFactory.createFromObject(objects[i]);
			datasets[i] = d;
			shapes[i] = d.getShapeRef();
		}

		List<int[]> nShapes = broadcastShapes(shapes);
		int[] mshape = nShapes.get(0);
		for (int i = 0; i < n; i++) {
			datasets[i] = datasets[i].getBroadcastView(mshape);
		}

		return datasets;
	}
}