diff --git a/ann_benchmarks/datasets.py b/ann_benchmarks/datasets.py index 2ee8551..ad80eca 100644 --- a/ann_benchmarks/datasets.py +++ b/ann_benchmarks/datasets.py @@ -69,50 +69,50 @@ def write_output(train, test, fn, distance, point_type='float', count=1000, SMIL f.close() print('Write Dataset %s' % fn) - f = h5sparse.File(fn, 'w') - f.attrs['distance'] = distance - f.attrs['point_type'] = point_type - print('train size: %9d * %4d' % train.shape) - print('test size: %9d * %4d' % test.shape) - if issparse(train): - f.create_dataset('train',data=train) - else: - f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train - if issparse(test): - f.create_dataset('test',data=test) - else: - f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test - neighbors = f.create_dataset('neighbors', (test.shape[0], count), dtype='i') - distances = f.create_dataset('distances', (test.shape[0], count), dtype='f') - - # use which method to compute the groundtruth - if issparse(train): - train = train.toarray() - method = 'bruteforce' + f = h5sparse.File(fn, 'w') + f.attrs['distance'] = distance + f.attrs['point_type'] = point_type + print('train size: %9d * %4d' % train.shape) + print('test size: %9d * %4d' % test.shape) + if issparse(train): + f.create_dataset('train',data=train) + else: + f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train + if issparse(test): + f.create_dataset('test',data=test) + else: + f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test + neighbors = f.create_dataset('neighbors', (test.shape[0], count), dtype='i') + distances = f.create_dataset('distances', (test.shape[0], count), dtype='f') + + # use which method to compute the groundtruth + if issparse(train): + train = train.toarray() + method = 'bruteforce' + if method == 'balltree': + tree = sklearn.neighbors.BallTree(train, leaf_size=1000000, metric=distance) + else: + bf = BruteForceBLAS(metric=distance, precision=train.dtype) + bf.fit(train) + + print(test) + for i, x in enumerate(test): + if i % 1 == 0: + print('%d/%d...' % (i, test.shape[0])) if method == 'balltree': - tree = sklearn.neighbors.BallTree(train, leaf_size=1000000, metric=distance) + dist, ind = tree.query([x], k=count) + neighbors[i] = ind[0] + distances[i] = dist[0] else: - bf = BruteForceBLAS(metric=distance, precision=train.dtype) - bf.fit(train) - - print(test) - for i, x in enumerate(test): - if i % 1 == 0: - print('%d/%d...' % (i, test.shape[0])) - if method == 'balltree': - dist, ind = tree.query([x], k=count) - neighbors[i] = ind[0] - distances[i] = dist[0] - else: - res = list(bf.query_with_distances(x, count)) - print(len(res)) - res.sort(key=lambda t: t[-1]) - neighbors[i] = [j for j, _ in res] - distances[i] = [d for _, d in res] - print(neighbors[i]) - print(distances[i]) - f.close() - print('Finish.') + res = list(bf.query_with_distances(x, count)) + print(len(res)) + res.sort(key=lambda t: t[-1]) + neighbors[i] = [j for j, _ in res] + distances[i] = [d for _, d in res] + print(neighbors[i]) + print(distances[i]) + f.close() + print('Finish.') def train_test_split(X, test_size=10000):