aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNiklas Halle <niklas@niklashalle.net>2021-07-26 09:40:42 +0200
committerNiklas Halle <niklas@niklashalle.net>2021-07-26 09:40:42 +0200
commite5fb4ee1ce3d85dd2666a7c0389b7019ce9a96ff (patch)
tree29a6d0d6a7037ecc5a4dbb049bf5f479c2f93f26
parent83373cb73979c9f44503f23b6ff45f76d5fa372a (diff)
downloadbachelor_thesis-e5fb4ee1ce3d85dd2666a7c0389b7019ce9a96ff.tar.gz
bachelor_thesis-e5fb4ee1ce3d85dd2666a7c0389b7019ce9a96ff.zip
.astype(np.float32)
-rw-r--r--code/python/utils.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/code/python/utils.py b/code/python/utils.py
index 556bf7f..464216d 100644
--- a/code/python/utils.py
+++ b/code/python/utils.py
@@ -50,7 +50,7 @@ def loss(y_true, y_pred, margin=1):
def load_tester(path):
with open(path) as f:
data = json.load(f)
- return np.asarray(data)
+ return np.asarray(data).astype(np.float32)
def load_data(type_name, data_dir='./generated_graphs'):
@@ -60,11 +60,11 @@ def load_data(type_name, data_dir='./generated_graphs'):
data = load_tester(file)
input_data.append(data)
- return np.asarray(input_data)
+ return np.asarray(input_data).astype(np.float32)
def load_embeddings():
- return np.asarray(load_data('emb'))
+ return np.asarray(load_data('emb')).astype(np.float32)
def load_betweenness(max_number_of_nodes):
@@ -74,7 +74,7 @@ def load_betweenness(max_number_of_nodes):
for graph_idx, graph_data in enumerate(data):
for member_idx, member_value in graph_data:
result[graph_idx][int(member_idx)] = member_value
- return result
+ return result.astype(np.float32)
def load_closeness(max_number_of_nodes):
@@ -84,7 +84,7 @@ def load_closeness(max_number_of_nodes):
for graph_idx, graph_data in enumerate(data):
for member_idx, member_value in graph_data:
result[graph_idx][int(member_idx)] = member_value
- return result
+ return result.astype(np.float32)
def load_electrical_closeness(max_number_of_nodes):
@@ -94,7 +94,7 @@ def load_electrical_closeness(max_number_of_nodes):
for graph_idx, graph_data in enumerate(data):
for member_idx, member_value in graph_data:
result[graph_idx][int(member_idx)] = member_value
- return result
+ return result.astype(np.float32)
def load_communities(max_number_of_nodes):
@@ -105,7 +105,7 @@ def load_communities(max_number_of_nodes):
for community_idx, community in enumerate(graph_data):
for member in community:
result[graph_idx][member] = community_idx
- return result
+ return result.astype(np.float32)
def get_max_number_of_nodes(embeddings):