-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improving Python notebooks (no ticket)
- Loading branch information
searchivairus
committed
Feb 7, 2018
1 parent
077cbea
commit d515b32
Showing
3 changed files
with
463 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,376 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy \n", | ||
"import sys \n", | ||
"import nmslib \n", | ||
"import time \n", | ||
"import math \n", | ||
"from sklearn.neighbors import NearestNeighbors\n", | ||
"from sklearn.model_selection import train_test_split" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Just read the data\n", | ||
"all_data_matrix = numpy.loadtxt('../../sample_data/final128_10K.txt')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a held-out query data set\n", | ||
"(data_matrix, query_matrix) = train_test_split(all_data_matrix, test_size = 0.1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"# of queries 1000, # of data points 9000\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"# of queries %d, # of data points %d\" % (query_matrix.shape[0], data_matrix.shape[0]) )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set index parameters\n", | ||
"# These are the most important onese\n", | ||
"M = 15\n", | ||
"efC = 100\n", | ||
"\n", | ||
"num_threads = 4\n", | ||
"index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post' : 0,\n", | ||
" 'skip_optimized_index' : 1 # using non-optimized index!\n", | ||
" }" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Number of neighbors \n", | ||
"K=100" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Space name should correspond to the space name \n", | ||
"# used for brute-force search\n", | ||
"space_name='l2'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"9000" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Intitialize the library, specify the space, the type of the vector and add data points \n", | ||
"index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR) \n", | ||
"index.addDataPointBatch(data_matrix) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Index-time parameters {'M': 15, 'indexThreadQty': 4, 'efConstruction': 100, 'skip_optimized_index': 1, 'post': 0}\n", | ||
"Indexing time = 0.291947\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Create an index\n", | ||
"start = time.time()\n", | ||
"index.createIndex(index_time_params) \n", | ||
"end = time.time() \n", | ||
"print('Index-time parameters', index_time_params)\n", | ||
"print('Indexing time = %f' % (end-start))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Setting query-time parameters {'efSearch': 100}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Setting query-time parameters\n", | ||
"efS = 100\n", | ||
"query_time_params = {'efSearch': efS}\n", | ||
"print('Setting query-time parameters', query_time_params)\n", | ||
"index.setQueryTimeParams(query_time_params)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"kNN time total=0.037669 (sec), per query=0.000038 (sec), per query adjusted for thread number=0.000151 (sec)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Querying\n", | ||
"query_qty = query_matrix.shape[0]\n", | ||
"start = time.time() \n", | ||
"nbrs = index.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)\n", | ||
"end = time.time() \n", | ||
"print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % \n", | ||
" (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty)) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Computing gold-standard data\n", | ||
"Brute-force preparation time 0.001139\n", | ||
"brute-force kNN time total=0.319399 (sec), per query=0.000319 (sec)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Computing gold-standard data \n", | ||
"print('Computing gold-standard data')\n", | ||
"\n", | ||
"start = time.time()\n", | ||
"sindx = NearestNeighbors(n_neighbors=K, metric='l2', algorithm='brute').fit(data_matrix)\n", | ||
"end = time.time()\n", | ||
"\n", | ||
"print('Brute-force preparation time %f' % (end - start))\n", | ||
"\n", | ||
"start = time.time() \n", | ||
"gs = sindx.kneighbors(query_matrix)\n", | ||
"end = time.time()\n", | ||
"\n", | ||
"print('brute-force kNN time total=%f (sec), per query=%f (sec)' % \n", | ||
" (end-start, float(end-start)/query_qty) )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"kNN recall 0.993040\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Finally computing recall\n", | ||
"recall=0.0\n", | ||
"for i in range(0, query_qty):\n", | ||
" correct_set = set(gs[1][i])\n", | ||
" ret_set = set(nbrs[i][0])\n", | ||
" recall = recall + float(len(correct_set.intersection(ret_set))) / len(correct_set)\n", | ||
"recall = recall / query_qty\n", | ||
"print('kNN recall %f' % recall)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Save a meta index\n", | ||
"index.saveIndex('dense_index_nonoptim.bin')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Re-intitialize the library, specify the space, the type of the vector.\n", | ||
"newIndex = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"9000" | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# For non-optimized indices or methods different from HNSW we DO need to re-add data points\n", | ||
"newIndex.addDataPointBatch(data_matrix) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Re-load the index and re-run queries\n", | ||
"newIndex.loadIndex('dense_index_nonoptim.bin')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Setting query-time parameters {'efSearch': 100}\n", | ||
"kNN time total=0.031991 (sec), per query=0.000032 (sec), per query adjusted for thread number=0.000128 (sec)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Setting query-time parameters and querying\n", | ||
"print('Setting query-time parameters', query_time_params)\n", | ||
"newIndex.setQueryTimeParams(query_time_params)\n", | ||
"\n", | ||
"query_qty = query_matrix.shape[0]\n", | ||
"start = time.time() \n", | ||
"new_nbrs = newIndex.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)\n", | ||
"end = time.time() \n", | ||
"print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % \n", | ||
" (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty)) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"kNN recall 0.993040\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Finally computing recall for the new result set\n", | ||
"recall=0.0\n", | ||
"for i in range(0, query_qty):\n", | ||
" correct_set = set(gs[1][i])\n", | ||
" ret_set = set(new_nbrs[i][0])\n", | ||
" recall = recall + float(len(correct_set.intersection(ret_set))) / len(correct_set)\n", | ||
"recall = recall / query_qty\n", | ||
"print('kNN recall %f' % recall)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.