Skip to content
Snippets Groups Projects
Verified Commit 627aa7f5 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Packet loss generation model

parent 7d328f5b
No related merge requests found
Pipeline #4704 passed
import torch
from torch import nn
import torch.nn.functional as F
class LossGen(nn.Module):
def __init__(self, gru1_size=16, gru2_size=16):
super(LossGen, self).__init__()
self.gru1_size = gru1_size
self.gru2_size = gru2_size
self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
self.dense_out = nn.Linear(self.gru2_size, 1)
def forward(self, loss, perc, states=None):
#print(states)
device = loss.device
batch_size = loss.size(0)
if states is None:
gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
else:
gru1_state = states[0]
gru2_state = states[1]
x = torch.cat([loss, perc], dim=-1)
gru1_out, gru1_state = self.gru1(x, gru1_state)
gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
return self.dense_out(gru2_out), [gru1_state, gru2_state]
#!/bin/sh
#directory containing the loss files
datadir=$1
for i in $datadir/*_is_lost.txt
do
perc=`cat $i | awk '{a+=$1}END{print a/NR}'`
echo $perc $i
done > percentage_list.txt
sort -n percentage_list.txt | awk '{print $2}' > percentage_sorted.txt
for i in `cat percentage_sorted.txt`
do
cat $i
done > loss_sorted.txt
import lossgen
import os
import argparse
import torch
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='CELPNet model')
parser.add_argument('percentage', type=float, help='percentage loss')
parser.add_argument('output', type=str, help='path to output file (ascii)')
parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
args = parser.parse_args()
checkpoint = torch.load(args.model, map_location='cpu')
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['state_dict'], strict=False)
states=None
last = torch.zeros((1,1,1))
perc = torch.tensor((args.percentage,))[None,None,:]
seq = torch.zeros((0,1,1))
one = torch.ones((1,1,1))
zero = torch.zeros((1,1,1))
if __name__ == '__main__':
for i in range(args.length):
prob, states = model(last, perc, states=states)
prob = torch.sigmoid(prob)
states[0] = states[0].detach()
states[1] = states[1].detach()
loss = one if np.random.rand() < prob else zero
last = loss
seq = torch.cat([seq, loss])
np.savetxt(args.output, seq[:,:,0].numpy().astype('int'), fmt='%d')
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
from scipy.signal import lfilter
import os
import lossgen
class LossDataset(torch.utils.data.Dataset):
def __init__(self,
loss_file,
sequence_length=997):
self.sequence_length = sequence_length
self.loss = np.loadtxt(loss_file, dtype='float32')
self.nb_sequences = self.loss.shape[0]//self.sequence_length
self.loss = self.loss[:self.nb_sequences*self.sequence_length]
self.perc = lfilter(np.array([.001], dtype='float32'), np.array([1., -.999], dtype='float32'), self.loss)
self.loss = np.reshape(self.loss, (self.nb_sequences, self.sequence_length, 1))
self.perc = np.reshape(self.perc, (self.nb_sequences, self.sequence_length, 1))
def __len__(self):
return self.nb_sequences
def __getitem__(self, index):
r0 = np.random.normal(scale=.02, size=(1,1)).astype('float32')
r1 = np.random.normal(scale=.02, size=(self.sequence_length,1)).astype('float32')
return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
adam_betas = [0.8, 0.99]
adam_eps = 1e-8
batch_size=512
lr_decay = 0.0001
lr = 0.001
epsilon = 1e-5
epochs = 20
checkpoint_dir='checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint = dict()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint['model_args'] = ()
checkpoint['model_kwargs'] = {'gru1_size': 16, 'gru2_size': 48}
model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
dataset = LossDataset('loss_sorted.txt')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
# learning rate scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
if __name__ == '__main__':
model.to(device)
for epoch in range(1, epochs + 1):
running_loss = 0
print(f"training epoch {epoch}...")
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (loss, perc) in enumerate(tepoch):
optimizer.zero_grad()
loss = loss.to(device)
perc = perc.to(device)
out, _ = model(loss, perc)
out = torch.sigmoid(out[:,:-1,:])
target = loss[:,1:,:]
loss = torch.mean(-target*torch.log(out+epsilon) - (1-target)*torch.log(1-out+epsilon))
loss.backward()
optimizer.step()
scheduler.step()
running_loss += loss.detach().cpu().item()
tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
)
# save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = running_loss / len(dataloader)
checkpoint['epoch'] = epoch
torch.save(checkpoint, checkpoint_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment