Skip to content

Commit 2caf533

Browse files
committed
Use isinstance instead of type(layer) ==
These were suggested in RedaOps#31 by @mmphego to support inheritance
1 parent 7fd257c commit 2caf533

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

ann_visualizer/visualize.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,44 +47,44 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
4747
if layer == model.layers[0]:
4848
input_layer = layer.input_shape[1]
4949
hidden_layers_nr += 1
50-
if type(layer) == Dense:
50+
if isinstance(layer, Dense):
5151
hidden_layers.append(layer.output_shape[1])
5252
layer_types.append("Dense")
5353
else:
5454
hidden_layers.append(1)
55-
if type(layer) == Conv2D:
55+
if isinstance(layer, Conv2D):
5656
layer_types.append("Conv2D")
57-
elif type(layer) == MaxPooling2D:
57+
elif isinstance(layer, MaxPooling2D):
5858
layer_types.append("MaxPooling2D")
59-
elif type(layer) == Dropout:
59+
elif isinstance(layer, Dropout):
6060
layer_types.append("Dropout")
61-
elif type(layer) == Flatten:
61+
elif isinstance(layer, Flatten):
6262
layer_types.append("Flatten")
63-
elif type(layer) == Activation:
63+
elif isinstance(layer, Activation):
6464
layer_types.append("Activation")
6565
else:
6666
if layer == model.layers[-1]:
6767
output_layer = layer.output_shape[1]
6868
else:
6969
hidden_layers_nr += 1
70-
if type(layer) == Dense:
70+
if isinstance(layer, Dense):
7171
hidden_layers.append(layer.output_shape[1])
7272
layer_types.append("Dense")
7373
else:
7474
hidden_layers.append(1)
75-
if type(layer) == Conv2D:
75+
if isinstance(layer, Conv2D):
7676
layer_types.append("Conv2D")
77-
elif type(layer) == MaxPooling2D:
77+
elif isinstance(layer, MaxPooling2D):
7878
layer_types.append("MaxPooling2D")
79-
elif type(layer) == Dropout:
79+
elif isinstance(layer, Dropout):
8080
layer_types.append("Dropout")
81-
elif type(layer) == Flatten:
81+
elif isinstance(layer, Flatten):
8282
layer_types.append("Flatten")
83-
elif type(layer) == Activation:
83+
elif isinstance(layer, Activation):
8484
layer_types.append("Activation")
8585
last_layer_nodes = input_layer
8686
nodes_up = input_layer
87-
if type(model.layers[0]) != Dense:
87+
if not isinstance(model.layers[0], Dense):
8888
last_layer_nodes = 1
8989
nodes_up = 1
9090
input_layer = 1
@@ -94,7 +94,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
9494
g.graph_attr.update(splines="false", nodesep='1', ranksep='2')
9595
# Input Layer
9696
with g.subgraph(name='cluster_input') as c:
97-
if type(model.layers[0]) == Dense:
97+
if isinstance(model.layers[0], Dense):
9898
the_label = title+'\n\n\n\nInput Layer'
9999
if model.layers[0].input_shape[1] > 10:
100100
the_label += " (+"+str(model.layers[0].input_shape[1] - 10)+")"
@@ -107,7 +107,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
107107
c.attr(rank='same')
108108
c.node_attr.update(color="#2ecc71", style="filled", fontcolor="#2ecc71", shape="circle")
109109

110-
elif type(model.layers[0]) == Conv2D:
110+
elif isinstance(model.layers[0], Conv2D):
111111
# Conv2D Input visualizing
112112
the_label = title+'\n\n\n\nInput Layer'
113113
c.attr(color="white", label=the_label)
@@ -193,7 +193,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
193193
nodes_up += 1
194194

195195
with g.subgraph(name='cluster_output') as c:
196-
if type(model.layers[-1]) == Dense:
196+
if isinstance(model.layers[-1], Dense):
197197
c.attr(color='white')
198198
c.attr(rank='same')
199199
c.attr(labeljust="1")

0 commit comments

Comments
 (0)