Unverified Commit c5c214df authored by Jan Buethe

added rudimentary support for dumping nn.Conv2d layers

parent 25c65a0c
Pipeline #4097 passed
@@ -28,4 +28,4 @@ from .c_writer import CWriter
 */
 """
-from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_vector
+from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
@@ -291,6 +291,7 @@ def print_conv1d_layer(writer : CWriter,
     lin_weight = np.reshape(weight, (-1, weight.shape[-1]))
     print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize)
     writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n")
+    writer.header.write(f"\n#define {name.upper()}_IN_SIZE {weight.shape[1]}\n")
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE ({weight.shape[1]} * ({weight.shape[0] - 1}))\n")
@@ -298,6 +299,29 @@ def print_conv1d_layer(writer : CWriter,
     return weight.shape[0] * weight.shape[1]
+def print_conv2d_layer(writer : CWriter,
+                       name : str,
+                       weight : np.ndarray,
+                       bias : np.ndarray,
+                       scale : float=1/128,
+                       quantize : bool=False):
+
+    if quantize:
+        print("[print_conv2d_layer] warning: quantize argument ignored")
+
+    bias_name = name + "_bias"
+    float_weight_name = name + "_weight_float"
+
+    print_vector(writer, weight, float_weight_name)
+    print_vector(writer, bias, bias_name)
+
+    # init function
+    out_channels, in_channels, ksize1, ksize2 = weight.shape
+    init_call = f'conv2d_init(&model->{name}, arrays, "{bias_name}", "{float_weight_name}", {in_channels}, {out_channels}, {ksize1}, {ksize2})'
+
+    writer.layer_dict[name] = ('Conv2dLayer', init_call)
+
 def print_gru_layer(writer : CWriter,
                     name : str,
...
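The registered init call is a plain format string over the layer name and kernel shape. A minimal sketch of the string print_conv2d_layer records, using a hypothetical layer name and shape (both illustrative, not from the commit):

import numpy as np

# Hypothetical name, and a weight already in the (out, in, kw, kh) layout
# that dump_torch_conv2d_weights produces below.
name = "my_conv"
weight = np.zeros((8, 4, 3, 3), dtype=np.float32)

out_channels, in_channels, ksize1, ksize2 = weight.shape
bias_name = name + "_bias"
float_weight_name = name + "_weight_float"
init_call = f'conv2d_init(&model->{name}, arrays, "{bias_name}", "{float_weight_name}", {in_channels}, {out_channels}, {ksize1}, {ksize2})'

print(init_call)
# conv2d_init(&model->my_conv, arrays, "my_conv_bias", "my_conv_weight_float", 4, 8, 3, 3)

The kernel axes arrive swapped (kw before kh) because the dump helper permutes PyTorch's OIHW layout to OIWH before export, as the weight_oiwh.npy file name below indicates.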
@@ -28,6 +28,7 @@
 """
 from .torch import dump_torch_conv1d_weights, load_torch_conv1d_weights
+from .torch import dump_torch_conv2d_weights, load_torch_conv2d_weights
 from .torch import dump_torch_dense_weights, load_torch_dense_weights
 from .torch import dump_torch_gru_weights, load_torch_gru_weights
 from .torch import dump_torch_embedding_weights, load_torch_embedding_weights
...
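With this re-export in place, the new helpers resolve from the wexchange.torch package directly, alongside the existing conv1d, dense, GRU, and embedding helpers; a quick import check, assuming the package is on the import path:

# Resolves via the package __init__ rather than the inner wexchange.torch.torch module.
from wexchange.torch import dump_torch_conv2d_weights, load_torch_conv2d_weights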
@@ -32,7 +32,7 @@ import os
 import torch
 import numpy as np
-from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer
+from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer

 def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
...
@@ -138,6 +138,33 @@ def load_torch_conv1d_weights(where, conv):
         conv.bias.set_(torch.from_numpy(b))
+def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+    w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
+    if conv.bias is None:
+        b = np.zeros(conv.out_channels, dtype=w.dtype)
+    else:
+        b = conv.bias.detach().cpu().numpy().copy()
+
+    if isinstance(where, CWriter):
+        return print_conv2d_layer(where, name, w, b, scale=scale, quantize=quantize)
+    else:
+        os.makedirs(where, exist_ok=True)
+
+        np.save(os.path.join(where, 'weight_oiwh.npy'), w)
+        np.save(os.path.join(where, 'bias.npy'), b)
+
+def load_torch_conv2d_weights(where, conv):
+    with torch.no_grad():
+        w = np.load(os.path.join(where, 'weight_oiwh.npy'))
+        conv.weight.set_(torch.from_numpy(w).permute(0, 1, 3, 2))
+        if conv.bias is not None:
+            b = np.load(os.path.join(where, 'bias.npy'))
+            conv.bias.set_(torch.from_numpy(b))
+
 def dump_torch_embedding_weights(where, emb):
     os.makedirs(where, exist_ok=True)
...
@@ -162,6 +189,8 @@ def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
         return dump_torch_gru_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Conv1d):
         return dump_torch_conv1d_weights(where, module, name, **kwargs)
+    elif isinstance(module, torch.nn.Conv2d):
+        return dump_torch_conv2d_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Embedding):
         return dump_torch_embedding_weights(where, module)
     else:
...
@@ -175,6 +204,8 @@ def load_torch_weights(where, module):
         load_torch_gru_weights(where, module)
     elif isinstance(module, torch.nn.Conv1d):
         load_torch_conv1d_weights(where, module)
+    elif isinstance(module, torch.nn.Conv2d):
+        load_torch_conv2d_weights(where, module)
     elif isinstance(module, torch.nn.Embedding):
         load_torch_embedding_weights(where, module)
     else:
...
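Taken together, the two dispatch branches give nn.Conv2d the same file-based round trip the other layer types already have. A minimal sketch, assuming wexchange is importable; the layer shape is an arbitrary example, not from the commit:

import tempfile

import torch

from wexchange.torch import dump_torch_conv2d_weights, load_torch_conv2d_weights

src = torch.nn.Conv2d(4, 8, kernel_size=3)  # example shape
dst = torch.nn.Conv2d(4, 8, kernel_size=3)

with tempfile.TemporaryDirectory() as where:
    # 'where' is a directory rather than a CWriter, so the weights land in
    # weight_oiwh.npy (OIHW permuted to OIWH) and bias.npy ...
    dump_torch_conv2d_weights(where, src)
    # ... and loading permutes the kernel axes back to PyTorch's OIHW.
    load_torch_conv2d_weights(where, dst)

assert torch.equal(src.weight, dst.weight)
assert torch.equal(src.bias, dst.bias)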