Skip to content

Commit

Permalink
Merge pull request #1 from cdb17006/indentation_problem
Browse files Browse the repository at this point in the history
Indentation problem
  • Loading branch information
cjz18001 authored Jun 21, 2020
2 parents 6fb1309 + 998e41c commit 6db359e
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions ann_benchmarks/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,50 +69,50 @@ def write_output(train, test, fn, distance, point_type='float', count=1000, SMIL
f.close()

print('Write Dataset %s' % fn)
f = h5sparse.File(fn, 'w')
f.attrs['distance'] = distance
f.attrs['point_type'] = point_type
print('train size: %9d * %4d' % train.shape)
print('test size: %9d * %4d' % test.shape)
if issparse(train):
f.create_dataset('train',data=train)
else:
f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train
if issparse(test):
f.create_dataset('test',data=test)
else:
f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test
neighbors = f.create_dataset('neighbors', (test.shape[0], count), dtype='i')
distances = f.create_dataset('distances', (test.shape[0], count), dtype='f')

# use which method to compute the groundtruth
if issparse(train):
train = train.toarray()
method = 'bruteforce'
f = h5sparse.File(fn, 'w')
f.attrs['distance'] = distance
f.attrs['point_type'] = point_type
print('train size: %9d * %4d' % train.shape)
print('test size: %9d * %4d' % test.shape)
if issparse(train):
f.create_dataset('train',data=train)
else:
f.create_dataset('train', train.shape, dtype=train.dtype)[:] = train
if issparse(test):
f.create_dataset('test',data=test)
else:
f.create_dataset('test', test.shape, dtype=test.dtype)[:] = test
neighbors = f.create_dataset('neighbors', (test.shape[0], count), dtype='i')
distances = f.create_dataset('distances', (test.shape[0], count), dtype='f')

# use which method to compute the groundtruth
if issparse(train):
train = train.toarray()
method = 'bruteforce'
if method == 'balltree':
tree = sklearn.neighbors.BallTree(train, leaf_size=1000000, metric=distance)
else:
bf = BruteForceBLAS(metric=distance, precision=train.dtype)
bf.fit(train)

print(test)
for i, x in enumerate(test):
if i % 1 == 0:
print('%d/%d...' % (i, test.shape[0]))
if method == 'balltree':
tree = sklearn.neighbors.BallTree(train, leaf_size=1000000, metric=distance)
dist, ind = tree.query([x], k=count)
neighbors[i] = ind[0]
distances[i] = dist[0]
else:
bf = BruteForceBLAS(metric=distance, precision=train.dtype)
bf.fit(train)

print(test)
for i, x in enumerate(test):
if i % 1 == 0:
print('%d/%d...' % (i, test.shape[0]))
if method == 'balltree':
dist, ind = tree.query([x], k=count)
neighbors[i] = ind[0]
distances[i] = dist[0]
else:
res = list(bf.query_with_distances(x, count))
print(len(res))
res.sort(key=lambda t: t[-1])
neighbors[i] = [j for j, _ in res]
distances[i] = [d for _, d in res]
print(neighbors[i])
print(distances[i])
f.close()
print('Finish.')
res = list(bf.query_with_distances(x, count))
print(len(res))
res.sort(key=lambda t: t[-1])
neighbors[i] = [j for j, _ in res]
distances[i] = [d for _, d in res]
print(neighbors[i])
print(distances[i])
f.close()
print('Finish.')


def train_test_split(X, test_size=10000):
Expand Down

0 comments on commit 6db359e

Please sign in to comment.