Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Science prizes #85

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
185 changes: 85 additions & 100 deletions metrics_evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# import string\n",
Expand All @@ -36,8 +38,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
Expand All @@ -64,8 +68,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"metricslist = ['Brier', 'LogLoss']\n",
Expand All @@ -90,8 +96,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"mystery = {}\n",
Expand All @@ -108,8 +116,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"snphotcc = {}\n",
Expand All @@ -131,8 +141,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"plasticc = {}\n",
Expand All @@ -141,38 +153,16 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"# old_snphotcc_names = []\n",
"# for prefix in ['templates_', 'wavelets_']:\n",
"# for suffix in ['boost_forest', 'knn', 'nb', 'neural_network', 'svm']:\n",
"# old_snphotcc_names.append(prefix+suffix+'.dat')\n",
"\n",
"# for i in range(len(snphotcc_names)):\n",
"# name = old_snphotcc_names[i]\n",
"# fileloc = dirname+'classifications/'+name\n",
"# snphotcc_info = pd.read_csv(fileloc, sep=' ')\n",
"# full = snphotcc_info.set_index('Object').join(truth_snphotcc.set_index('Object'))\n",
"# name = snphotcc_names[i]\n",
" \n",
"# truth = full['Type'] - 1\n",
"# snphotcc_truth_table = proclam.metrics.util.det_to_prob(truth)\n",
"# fileloc = 'examples/'+name+'/truth_table_'+name+'.csv'\n",
"# with open(fileloc, 'wb') as truth_place:\n",
"# np.savetxt(fileloc, snphotcc_truth_table, delimiter=' ')\n",
" \n",
"# probs = full[['1', '2', '3']]\n",
"# fileloc = 'examples/'+name+'/predicted_prob_'+name+'.csv'\n",
"# probs.to_csv(fileloc, sep=' ', index=False, header=True)"
]
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# more_names = snphotcc_names\n",
Expand All @@ -183,8 +173,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def make_class_pairs(data_info_dict):\n",
Expand All @@ -201,11 +193,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 'ProClaM', 'names': ['Idealized', 'Guess', 'Tunnel', 'Broadbrush', 'Cruise', 'SubsumedTo', 'SubsumedFrom'], 'dirname': 'examples/ProClaM/', 'classifications': ['Idealized/predicted_prob_Idealized.csv', 'Guess/predicted_prob_Guess.csv', 'Tunnel/predicted_prob_Tunnel.csv', 'Broadbrush/predicted_prob_Broadbrush.csv', 'Cruise/predicted_prob_Cruise.csv', 'SubsumedTo/predicted_prob_SubsumedTo.csv', 'SubsumedFrom/predicted_prob_SubsumedFrom.csv'], 'truth_tables': ['Idealized/truth_table_Idealized.csv', 'Guess/truth_table_Guess.csv', 'Tunnel/truth_table_Tunnel.csv', 'Broadbrush/truth_table_Broadbrush.csv', 'Cruise/truth_table_Cruise.csv', 'SubsumedTo/truth_table_SubsumedTo.csv', 'SubsumedFrom/truth_table_SubsumedFrom.csv']}\n"
]
}
],
"source": [
"for dataset in [mystery, snphotcc, plasticc]:\n",
"for dataset in [ plasticc]: #mystery, snphotcc,\n",
" dataset = make_file_locs(dataset)\n",
" dataset['class_pairs'] = make_class_pairs(dataset)"
]
Expand All @@ -221,8 +223,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def plot_cm(probs, truth, name, loc=''):\n",
Expand All @@ -234,14 +238,17 @@
" plt.ylabel('true class')\n",
" plt.colorbar()\n",
" plt.title(name)\n",
" plt.savefig(loc+name+'_cm.png')\n",
" plt.close()"
" #plt.savefig(loc+name+'_cm.png')\n",
" plt.show()\n",
" #plt.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": []
},
Expand All @@ -263,6 +270,7 @@
" nobj_truth = np.shape(truth_values)[0]\n",
" nclass_truth = np.shape(truth_values)[1]\n",
" tvec = np.where(truth_values==1)[1]\n",
" print(tvec)\n",
"# if nclass_truth!= nclass:\n",
"# print('Truth table of size %i x %i and prob matrix of size %i x %i do not match up in size'%(nobj,nclass,nobj_truth,nclass_truth))\n",
"# else:\n",
Expand All @@ -274,8 +282,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def make_patch_spines_invisible(ax):\n",
Expand Down Expand Up @@ -324,6 +334,7 @@
" plt.legend(handles, metric_names)\n",
" plt.suptitle(title)\n",
" plt.savefig(fileloc)\n",
" plt.show()\n",
" return"
]
},
Expand All @@ -336,13 +347,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"ename": "KeyError",
"evalue": "'class_pairs'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-15f56f172c88>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mmystery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msnphotcc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplasticc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetricslist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'names'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mcc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpair\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'class_pairs'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpair\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprobm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtruthv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_class_pairs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m#loc=dataset['dirname'], title=dataset['label']+' '+dataset['names'][cc])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'class_pairs'"
]
}
],
"source": [
"for dataset in [mystery, snphotcc, plasticc]:\n",
" data = np.empty((len(metricslist), len(dataset['names'])))\n",
" for cc, pair in enumerate(dataset['class_pairs']):\n",
" print(pair)\n",
" probm, truthv = read_class_pairs(pair, dataset, cc)#loc=dataset['dirname'], title=dataset['label']+' '+dataset['names'][cc])\n",
"# plot_cm(probm, truthv, str(cc), loc='./sandbox/')\n",
" det = proclam.metrics.util.prob_to_det(probm)\n",
Expand All @@ -361,55 +387,14 @@
"# metric_plot(dataset, metricslist, markerlist, colors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# more_data = np.empty((len(metricslist), len(more_names)))\n",
"# for cc, pair in enumerate(more_class_pairs):\n",
"# probm, truthv = read_class_pairs(pair, dirname)\n",
"# for count, metric in enumerate(metricslist):\n",
"# D = getattr(proclam.metrics, metric)()\n",
"# hm = D.evaluate(probm, truthv)\n",
"# more_data[count][cc] = hm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# metric_plot(more_names, metricslist, more_data, markerlist, colors, title='SNPhotCC', fileloc=dirname+'snphotccdata.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# data = np.empty((len(metricslist), len(names)))\n",
"# for cc, pair in enumerate(class_pairs):\n",
"# probm, truthv = read_class_pairs(pair, dirname)\n",
"# for count, metric in enumerate(metricslist):\n",
"# D = getattr(proclam.metrics, metric)()\n",
"# hm = D.evaluate(probm, truthv)\n",
"# data[count][cc] = hm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# metric_plot(names, metricslist, data, markerlist, colors, title='Mystery Dataset', fileloc=dirname+'mysterydata.png')"
]
"source": []
},
{
"cell_type": "code",
Expand Down Expand Up @@ -443,7 +428,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.6.8"
}
},
"nbformat": 4,
Expand Down
Loading