@@ -47,44 +47,44 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
4747 if layer == model .layers [0 ]:
4848 input_layer = layer .input_shape [1 ]
4949 hidden_layers_nr += 1
50- if type (layer ) == Dense :
50+ if isinstance (layer , Dense ) :
5151 hidden_layers .append (layer .output_shape [1 ])
5252 layer_types .append ("Dense" )
5353 else :
5454 hidden_layers .append (1 )
55- if type (layer ) == Conv2D :
55+ if isinstance (layer , Conv2D ) :
5656 layer_types .append ("Conv2D" )
57- elif type (layer ) == MaxPooling2D :
57+ elif isinstance (layer , MaxPooling2D ) :
5858 layer_types .append ("MaxPooling2D" )
59- elif type (layer ) == Dropout :
59+ elif isinstance (layer , Dropout ) :
6060 layer_types .append ("Dropout" )
61- elif type (layer ) == Flatten :
61+ elif isinstance (layer , Flatten ) :
6262 layer_types .append ("Flatten" )
63- elif type (layer ) == Activation :
63+ elif isinstance (layer , Activation ) :
6464 layer_types .append ("Activation" )
6565 else :
6666 if layer == model .layers [- 1 ]:
6767 output_layer = layer .output_shape [1 ]
6868 else :
6969 hidden_layers_nr += 1
70- if type (layer ) == Dense :
70+ if isinstance (layer , Dense ) :
7171 hidden_layers .append (layer .output_shape [1 ])
7272 layer_types .append ("Dense" )
7373 else :
7474 hidden_layers .append (1 )
75- if type (layer ) == Conv2D :
75+ if isinstance (layer , Conv2D ) :
7676 layer_types .append ("Conv2D" )
77- elif type (layer ) == MaxPooling2D :
77+ elif isinstance (layer , MaxPooling2D ) :
7878 layer_types .append ("MaxPooling2D" )
79- elif type (layer ) == Dropout :
79+ elif isinstance (layer , Dropout ) :
8080 layer_types .append ("Dropout" )
81- elif type (layer ) == Flatten :
81+ elif isinstance (layer , Flatten ) :
8282 layer_types .append ("Flatten" )
83- elif type (layer ) == Activation :
83+ elif isinstance (layer , Activation ) :
8484 layer_types .append ("Activation" )
8585 last_layer_nodes = input_layer
8686 nodes_up = input_layer
87- if type (model .layers [0 ]) != Dense :
87+ if not isinstance (model .layers [0 ], Dense ) :
8888 last_layer_nodes = 1
8989 nodes_up = 1
9090 input_layer = 1
@@ -94,7 +94,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
9494 g .graph_attr .update (splines = "false" , nodesep = '1' , ranksep = '2' )
9595 # Input Layer
9696 with g .subgraph (name = 'cluster_input' ) as c :
97- if type (model .layers [0 ]) == Dense :
97+ if isinstance (model .layers [0 ], Dense ) :
9898 the_label = title + '\n \n \n \n Input Layer'
9999 if model .layers [0 ].input_shape [1 ] > 10 :
100100 the_label += " (+" + str (model .layers [0 ].input_shape [1 ] - 10 )+ ")"
@@ -107,7 +107,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
107107 c .attr (rank = 'same' )
108108 c .node_attr .update (color = "#2ecc71" , style = "filled" , fontcolor = "#2ecc71" , shape = "circle" )
109109
110- elif type (model .layers [0 ]) == Conv2D :
110+ elif isinstance (model .layers [0 ], Conv2D ) :
111111 # Conv2D Input visualizing
112112 the_label = title + '\n \n \n \n Input Layer'
113113 c .attr (color = "white" , label = the_label )
@@ -193,7 +193,7 @@ def ann_viz(model, view=False, filename="network.gv", title="My Neural Network")
193193 nodes_up += 1
194194
195195 with g .subgraph (name = 'cluster_output' ) as c :
196- if type (model .layers [- 1 ]) == Dense :
196+ if isinstance (model .layers [- 1 ], Dense ) :
197197 c .attr (color = 'white' )
198198 c .attr (rank = 'same' )
199199 c .attr (labeljust = "1" )
0 commit comments