Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
JuliaDT/dt.jl
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
144 lines (124 sloc)
3.58 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Create a decision tree that used the Shannon Entropy to decide how to best branch the tree based | |
on the data provide above. | |
CAN SURVIVE WITHOUT | |
INSTANCE COMING TO SURFACE? HAS FLIPPERS? FISH? | |
1 Yes Yes Yes | |
2 Yes Yes Yes | |
3 Yes No No | |
4 No Yes No | |
5 No Yes No | |
Check if every item in the dataset is in the same class: | |
If so return the class label | |
Else | |
find the best feature to split the data split the dataset | |
create a branch node | |
for each split | |
call createBranch and add the result to the branch node | |
return branch node | |
""" | |
"""
    shannon_entropy(x, y)

Compute the Shannon entropy of the class labels `y[i]` for which the
corresponding binary feature value `x[i] == 1`.

# Arguments
- `x`: vector of binary (0/1) feature values.
- `y`: vector of class labels, same length as `x`.

# Returns
The Shannon entropy `-Σ p * log2(p)` of the selected labels as a
`Float64`; `0.0` when no element of `x` equals 1 (empty partition).
"""
function shannon_entropy(x, y)
    # Labels belonging to the x == 1 partition.
    sub_y = [y[i] for i in eachindex(x) if x[i] == 1]
    total = length(sub_y)
    # Empty partition carries no information; also avoids iterating an
    # empty tally below with an integer-typed accumulator.
    total == 0 && return 0.0
    # Count occurrences of each label (typed Dict, not Dict{Any,Any}).
    counts = Dict{eltype(sub_y),Int}()
    for label in sub_y
        counts[label] = get(counts, label, 0) + 1
    end
    # Entropy: -Σ p * log2(p); Float64 accumulator keeps this type-stable.
    entropy = 0.0
    for c in values(counts)
        p = c / total
        entropy -= p * log2(p)
    end
    return entropy
end
"""
    create_tree(labels, X, Y)

Recursively build a decision tree by splitting on the feature whose
`x == 1` partition has the lowest Shannon entropy.

# Arguments
- `labels`: feature names; the LAST entry is the class-label name.
- `X`: matrix of binary feature values, one row per instance, one column
  per feature (columns align with `labels[1:end-1]`).
- `Y`: vector of class labels, one per row of `X`.

# Returns
A nested `Dict`: internal nodes map a feature label to
`Dict("No" => subtree, "Yes" => subtree)`; leaves map `labels[end]` to a
class value.

Unlike a naive implementation, `labels` is NOT mutated, so the caller's
array is untouched and the two recursive branches cannot corrupt each
other's feature list.
"""
function create_tree(labels, X, Y)
    # Pure node: every remaining example shares one class -> leaf.
    if all(y -> y == Y[1], Y)
        return Dict(labels[end] => Y[1])
    end
    # No features left to split on: fall back to the majority class
    # (the previous version crashed here with best_index == 0).
    if length(labels) == 1
        tally = Dict{eltype(Y),Int}()
        for cls in Y
            tally[cls] = get(tally, cls, 0) + 1
        end
        majority, best_count = Y[1], -1
        for (cls, cnt) in tally
            if cnt > best_count
                majority, best_count = cls, cnt
            end
        end
        return Dict(labels[end] => majority)
    end
    # Pick the feature whose x == 1 partition has the lowest entropy.
    best_index = 0
    best_entropy = Inf
    for j in 1:(length(labels) - 1)
        e = shannon_entropy(X[:, j], Y)
        if e < best_entropy
            best_entropy = e
            best_index = j
        end
    end
    best_label = labels[best_index]
    # Copies, so recursion never mutates shared state.
    child_labels = [labels[k] for k in eachindex(labels) if k != best_index]
    keep_cols = [j for j in 1:size(X, 2) if j != best_index]
    # Row indices for the "Yes" (feature == 1) and "No" partitions.
    right = [i for i in 1:size(X, 1) if X[i, best_index] == 1]
    left = [i for i in 1:size(X, 1) if X[i, best_index] != 1]
    # Sub-matrices stay 2-D (the old vcat-based version flattened them).
    return Dict(best_label => Dict(
        "No" => create_tree(child_labels, X[left, keep_cols], Y[left]),
        "Yes" => create_tree(child_labels, X[right, keep_cols], Y[right]),
    ))
end
"""
    print_tree(tree, tabs)

Pretty-print a nested decision-tree `Dict` to stdout. Each level of
nesting prepends `"----"` to the `tabs` prefix; non-`Dict` values are
printed as leaves preceded by a tab.
"""
function print_tree(tree, tabs)
    # Leaf: anything that is not a Dict is a terminal class value.
    if !(tree isa Dict)
        println("\t$tabs$tree")
        return
    end
    # Internal node: print each key, then recurse into its subtree.
    for (node, subtree) in pairs(tree)
        println("$tabs$node")
        print_tree(subtree, "----" * tabs)
    end
end
"""
    main()

Build a decision tree from the five-example fish dataset and print it.
"""
function main()
    # Feature names; the last entry names the class column.
    feature_names = ["survive without coming to surface", "has flippers", "fish"]
    # One row per instance, one column per feature (binary encoded).
    data = [1 1;
            1 1;
            1 0;
            0 1;
            0 1]
    # Class label for each row of `data`.
    classes = ["yes", "yes", "no", "no", "no"]
    print_tree(create_tree(feature_names, data, classes), "")
end

main()