CCOD-DQN (Contention Window Optimization in IEEE 802.11ax Networks with Deep Reinforcement Learning)

 

Please first read the article "Contention Window Optimization in IEEE 802.11ax Networks with Deep Reinforcement Learning", published at WCNC 2021 (a preprint is available on arXiv: https://arxiv.org/pdf/2003.01492). The original code for the paper can be found at https://github.com/wwydmanski/RLinWiFi. In this lab, I show how to measure the throughput of traditional 802.11ax (CSMA/CA) and of CCOD-DQN. For CCOD-DQN, I use the PARL framework, rather than TensorFlow or Keras, for the reinforcement learning.

 

Please follow the instructions at https://github.com/tkn-tub/ns3-gym to install ns3-gym. Also, follow the instructions at https://github.com/PaddlePaddle/PARL to install PARL.

 

[steps]

Prepare cw.cc and test_dqn.py under scratch/myrlwifi

 

cw.cc (code is from https://github.com/wwydmanski/RLinWiFi/blob/master/linear-mesh/cw.cc)

#include "ns3/core-module.h"

#include "ns3/network-module.h"

#include "ns3/applications-module.h"

#include "ns3/wifi-module.h"

#include "ns3/mobility-module.h"

#include "ns3/csma-module.h"

#include "ns3/internet-module.h"

#include "ns3/flow-monitor-module.h"

#include "ns3/opengym-module.h"

#include "ns3/propagation-module.h"

#include "ns3/ipv4-flow-classifier.h"

#include "ns3/yans-wifi-channel.h"

#include <cmath>

#include <ctime>

#include <sstream>

#include <fstream>

#include <string>

#include <math.h>

#include <ctime>  

#include <iomanip>

#include <deque>

#include <algorithm>

#include <csignal>

 

#define PI 3.14159265

 

using namespace ns3;

using namespace std;

 

// Registers the ns-3 log component used by this program.
NS_LOG_COMPONENT_DEFINE ("wifi1");

// Forward declaration: MyGetObservation() calls recordHistory() before its definition.
void recordHistory();



// Not referenced anywhere else in the visible code.
double SimTime = 100.0;

// Not referenced anywhere else in the visible code.
uint64_t lastTotalRx = 0;

// Declared but not referenced anywhere else in the visible code.
uint32_t mactxno,macrxno,phyrxok,phyrxerror,phytx;



// NOTE(review): main() declares locals named `monitor` and `flowmon` that shadow
// these two objects, so the file-scope ones appear to be unused.
Ptr<FlowMonitor> monitor;

FlowMonitorHelper flowmon;



// Interval in seconds between two agent interactions (command line: --envStepTime).
double envStepTime = 0.1;

// Duration in seconds for which the traffic sources run (command line: --simTime).
double simulationTime = 10;

// Advanced by envStepTime each time MyGetObservationSpace() is called.
double current_time = 0.0;

// When true, the gym callbacks print their values with NS_LOG_UNCOND (command line: --verbose).
bool verbose = false;

// Only referenced from commented-out code in the visible source.
int end_delay = 0;

// When true, the agent's actions are not applied and main() configures cwmin/cwmax
// instead of CW (command line: --dryRun).
bool dry_run = false;



// Contention window most recently selected; written to both MinCw and MaxCw
// (command line: --CW, updated by MyExecuteActions()).
uint32_t CW = 0;

// Number of entries in the observation vector (command line: --historyLength).
uint32_t history_length = 20;

// Agent type; MyExecuteActions() accepts "discrete", "continuous" and
// "direct_continuous" (command line: --agentType).
string type = "discrete";

// Per-step ratios recorded by recordHistory(), newest entry at the front.
deque<float> history;



// Packet sink installed on the AP; assigned in main().
Ptr<PacketSink> sinkApp;

 

// Describes the observation space reported to the agent: a one-dimensional
// float box with history_length entries, each bounded by [0, 10].
//
// Side effect: the global current_time is advanced by envStepTime on every call.
Ptr<OpenGymSpace> MyGetObservationSpace(void)
{
    current_time += envStepTime;

    float lowerBound = 0.0;
    float upperBound = 10.0;
    std::vector<uint32_t> dims(1, history_length);
    std::string elementType = TypeNameGet<float>();

    Ptr<OpenGymBoxSpace> obsSpace = CreateObject<OpenGymBoxSpace>(lowerBound, upperBound, dims, elementType);
    if (verbose)
    {
        NS_LOG_UNCOND("MyGetObservationSpace: " << obsSpace);
    }
    return obsSpace;
}

 

// Describes the action space reported to the agent: a float box holding a
// single value bounded by [0, 10].
Ptr<OpenGymSpace> MyGetActionSpace(void)
{
    float lowerBound = 0.0;
    float upperBound = 10.0;
    std::vector<uint32_t> dims(1, 1);
    std::string elementType = TypeNameGet<float>();

    Ptr<OpenGymBoxSpace> actSpace = CreateObject<OpenGymBoxSpace>(lowerBound, upperBound, dims, elementType);
    if (verbose)
    {
        NS_LOG_UNCOND("MyGetActionSpace: " << actSpace);
    }
    return actSpace;
}

 

// Snapshot of my_rxPktNum, refreshed by MyGetExtraInfo(), MyGetReward() and recordHistory().
uint64_t g_rxPktNum = 0;

// Running count of packets seen by packetSent().
uint64_t g_txPktNum = 0;

// Running count of packets seen by packetReceived().
uint64_t my_rxPktNum=0;

 

std::string MyGetExtraInfo(void)

{

    static float ticks = 0.0;

    static float lastValue = 0.0;

    //g_rxPktNum = sinkApp->GetTotalRxPkt();

    g_rxPktNum = my_rxPktNum;

    //std::cout << "in MyGetExtraInfo(), g_rxPktNum=" << g_rxPktNum << std::endl;

    float obs = g_rxPktNum - lastValue;

    lastValue = g_rxPktNum;

    ticks += envStepTime;

 

    float sentMbytes = obs * (1500 - 20 - 8 - 8) * 8.0 / 1024 / 1024;

 

    std::string myInfo = std::to_string(sentMbytes);

    myInfo = myInfo + "|" + to_string(CW);  

 

    if (verbose)

        NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo);

 

    return myInfo;

}

 

bool MyExecuteActions(Ptr<OpenGymDataContainer> action)

{

    if (verbose)

        NS_LOG_UNCOND("MyExecuteActions: " << action);

 

    Ptr<OpenGymBoxContainer<float>> box = DynamicCast<OpenGymBoxContainer<float>>(action);

    std::vector<float> actionVector = box->GetData();

 

    if (type == "discrete")

    {

        CW = pow(2, int(4 + actionVector.at(0)));

    }

    else if (type == "continuous")

    {

        CW = pow(2, actionVector.at(0) + 4);

    }

    else if (type == "direct_continuous")

    {

        CW = actionVector.at(0);

    }

    else

    {

        std::cout << "Unsupported agent type!" << endl;

        exit(0);

    }

 

    if (verbose) {

        NS_LOG_UNCOND("actionVector.at(0): " << actionVector.at(0));

    }

 

    uint32_t min_cw = 16;

    uint32_t max_cw = 1024;

 

    CW = min(max_cw, max(CW, min_cw));

 

    if (verbose) {

        NS_LOG_UNCOND("CW: " << CW);

    }

 

    if(!dry_run){

        //Config::Set("/$ns3::NodeListPriv/NodeList/*/$ns3::Node/DeviceList/*/$ns3::WifiNetDevice/Mac/$ns3::RegularWifiMac/BE_Txop/$ns3::QosTxop/MinCw", UintegerValue(CW));

        //Config::Set("/$ns3::NodeListPriv/NodeList/*/$ns3::Node/DeviceList/*/$ns3::WifiNetDevice/Mac/$ns3::RegularWifiMac/BE_Txop/$ns3::QosTxop/MaxCw", UintegerValue(CW));

        Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(CW));

        Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(CW));

        //std::cout << "MinCw and MaxCw are set to " << CW << std::endl;

    }

    return true;

}

 

float MyGetReward(void)

{

    static float ticks = 0.0;

    static uint32_t last_packets = 0;

    static float last_reward = 0.0;

    ticks += envStepTime;

    g_rxPktNum = my_rxPktNum;

    //g_rxPktNum = sinkApp->GetTotalRxPkt();

    float res = g_rxPktNum - last_packets;

    //Need to understand why

    float reward = res * (1500 - 20 - 8 - 8) * 8.0 / 1024 / 1024 / (5 * 150 * envStepTime) * 10;

 

    last_packets = g_rxPktNum;

 

    if (ticks <= 2 * envStepTime)

        return 0.0;

 

    if (verbose)

        NS_LOG_UNCOND("MyGetReward: " << reward);

 

    if(reward>1.0f || reward<0.0f)

        reward = last_reward;

    last_reward = reward;

    return last_reward;

}

 

Ptr<OpenGymDataContainer> MyGetObservation()

{

    recordHistory();

 

    std::vector<uint32_t> shape = {

        history_length,

    };

    Ptr<OpenGymBoxContainer<float>> box = CreateObject<OpenGymBoxContainer<float>>(shape);

 

    for (uint32_t i = 0; i < history.size(); i++)

    {

        if (history[i] >= -100 && history[i] <= 100)

            box->AddValue(history[i]);

        else

            box->AddValue(0);

    }

    for (uint32_t i = history.size(); i < history_length; i++)

    {

        box->AddValue(0);

    }

    if (verbose)

        NS_LOG_UNCOND("MyGetObservation: " << box);

    return box;

}

 

// Tells the gym interface whether the episode has ended.
//
// The simulation never signals a game over through this callback. Two
// alternatives exist only as disabled code in the original source: a check of
// the simulated time against simulationTime + end_delay + 1.0, and a counter
// that ended the episode after 200 calls in verbose mode.
//
// Returns: always false.
bool MyGetGameOver(void)
{
    const bool isGameOver = false;
    return isGameOver;
}

 

void ScheduleNextStateRead(double envStepTime, Ptr<OpenGymInterface> openGymInterface)

{

    // if(ns3::Simulator::Now().GetSeconds()<simulationTime + end_delay + 1.0)

    // {

    Simulator::Schedule(Seconds(envStepTime), &ScheduleNextStateRead, envStepTime, openGymInterface);

    // }

    openGymInterface->NotifyCurrentState();

}

 

void recordHistory()

{

    static uint32_t last_rx = 0;

    static uint32_t last_tx = 0;

    static uint32_t calls = 0;

    calls++;

    g_rxPktNum = my_rxPktNum;

    //g_rxPktNum = sinkApp->GetTotalRxPkt();

    float received = g_rxPktNum - last_rx;

    float sent = g_txPktNum - last_tx;

    float errs = sent - received;

    float ratio;

 

    ratio = errs / sent;

    history.push_front(ratio);

 

    if (history.size() > history_length)

    {

        history.pop_back();

    }

    last_rx = g_rxPktNum;

    last_tx = g_txPktNum;

}

 

// Trace sink that counts one received packet per invocation.
// The packet itself is not inspected.
void packetReceived(Ptr<const Packet> packet)
{
    my_rxPktNum += 1;
}

 

 

// Trace sink that counts one sent packet per invocation.
// The packet itself is not inspected.
void packetSent(Ptr<const Packet> packet)
{
    g_txPktNum += 1;
}

 

// Signal handler: reports the signal number on standard output and ends the
// process, using the signal number as the exit status.
void signalHandler(int signum)
{
    std::ostringstream notice;
    notice << "Interrupt signal " << signum << " received.\n";
    std::cout << notice.str();

    exit(signum);
}

 

int

main(int argc, char *argv[])

{

  uint32_t nSta = 1;

  uint32_t cwmin = 15;

  uint32_t cwmax = 1023;

  uint32_t openGymPort = 5555;

 

  double txStartTime = 0.1;

 

  int mcs = 6;

  int channelWidth = 20;

  int guardInterval = 800;

 

  signal(SIGTERM, signalHandler);

 

  CommandLine cmd;

  cmd.AddValue("openGymPort", "Specify port number. Default: 5555", openGymPort);

  cmd.AddValue("CW", "Value of Contention Window", CW);

  cmd.AddValue("historyLength", "Length of history window", history_length);

  cmd.AddValue("verbose", "Tell echo applications to log if true", verbose);

  cmd.AddValue("dryRun", "Execute scenario with BEB and no agent interaction", dry_run);

  cmd.AddValue("simTime", "Simulation time in seconds. Default: 10s", simulationTime);

  cmd.AddValue("envStepTime", "Step time in seconds. Default: 0.1s", envStepTime);

  cmd.AddValue ("nSta", "Number of wifi STA devices", nSta);

  cmd.AddValue ("cwmin", "Minimum contention window size", cwmin);

  cmd.AddValue ("cwmax", "Maximum contention window size", cwmax);

  cmd.AddValue ("agentType", "Agent Type", type);

  cmd.Parse (argc, argv);

 

  Config::SetDefault ("ns3::WifiRemoteStationManager::FragmentationThreshold", StringValue ("2200"));

  Config::SetDefault ("ns3::WifiRemoteStationManager::RtsCtsThreshold", StringValue ("2200"));

 

  NS_LOG_UNCOND("Ns3Env parameters:");

  NS_LOG_UNCOND("--nSta: " << nSta);

  NS_LOG_UNCOND("--simulationTime: " << simulationTime);

  NS_LOG_UNCOND("--openGymPort: " << openGymPort);

  NS_LOG_UNCOND("--envStepTime: " << envStepTime);

  NS_LOG_UNCOND("--agentType: " << type);

  NS_LOG_UNCOND("--dryRun: " << dry_run);

  NS_LOG_UNCOND("--verbose: " << verbose);

 

  WifiMacHelper wifiMac;

  WifiHelper wifiHelper;

  wifiHelper.SetStandard (WIFI_PHY_STANDARD_80211ax_5GHZ);

  std::ostringstream oss;

  oss << "HeMcs" << mcs;

  wifiHelper.SetRemoteStationManager("ns3::ConstantRateWifiManager", "DataMode", StringValue(oss.str()), "ControlMode", StringValue(oss.str()));

 

  Ptr<MatrixPropagationLossModel> lossModel = CreateObject<MatrixPropagationLossModel>();

  lossModel->SetDefaultLoss(50);

  YansWifiChannelHelper channel = YansWifiChannelHelper::Default ();

  Ptr<YansWifiChannel> chan = channel.Create();

  chan->SetPropagationLossModel(lossModel);

  chan->SetPropagationDelayModel(CreateObject<ConstantSpeedPropagationDelayModel>());

  YansWifiPhyHelper wifiPhy;

  wifiPhy = YansWifiPhyHelper::Default();

  wifiPhy.SetChannel(chan);

  wifiPhy.Set("GuardInterval", TimeValue(NanoSeconds(guardInterval)));

 

  NodeContainer wifiStaNodes;

  wifiStaNodes.Create (nSta);

  NodeContainer wifiApNode;

  wifiApNode.Create (uint32_t (1));

 

  Ssid ssid = Ssid ("wifi1");

  wifiMac.SetType ("ns3::ApWifiMac",

                   "Ssid", SsidValue (ssid));

 

  NetDeviceContainer apDevice;

  apDevice = wifiHelper.Install (wifiPhy, wifiMac, wifiApNode);

 

  wifiMac.SetType ("ns3::StaWifiMac",

                   "Ssid", SsidValue (ssid));

 

  NetDeviceContainer staDevices;

  staDevices = wifiHelper.Install (wifiPhy, wifiMac, wifiStaNodes);

 

  Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/Phy/ChannelWidth", UintegerValue(channelWidth));

 

  std::cout << "----------------------------" << std::endl;

  if (!dry_run)

  {

       Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(CW));

       Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(CW));

  }

  else

  {

       NS_LOG_UNCOND("Default CW");

       Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MinCw", UintegerValue(cwmin));

       Config::Set("/NodeList/*/DeviceList/*/Mac/Txop/MaxCw", UintegerValue(cwmax));

  }

 

  MobilityHelper mobility;

  Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator> ();

  positionAlloc->Add (Vector (0.0, 0.0, 0.0));

  float rho = 0.5;

  for (uint32_t i = 0; i < nSta; i++)

  {

    double theta = i * 2 * PI / nSta;

    positionAlloc->Add (Vector (rho * cos(theta), rho * sin(theta), 0.0));

    std::cout << "node " << i << " position:" << "(" << rho * cos(theta) << "," << rho * sin(theta) << ",0.0)" << std::endl;

  }

 

  mobility.SetPositionAllocator (positionAlloc);

  mobility.SetMobilityModel ("ns3::ConstantPositionMobilityModel");

  mobility.Install (wifiApNode);

  mobility.Install (wifiStaNodes);

 

  InternetStackHelper stack;

  stack.Install(wifiApNode);

  stack.Install(wifiStaNodes);

 

  Ipv4AddressHelper address;

  address.SetBase("10.1.1.0", "255.255.255.0");

  Ipv4InterfaceContainer ApInterface = address.Assign(apDevice);

  Ipv4InterfaceContainer StaInterface =       address.Assign(staDevices);

 

  for (uint32_t i = 0; i < nSta; i++)

  {

          OnOffHelper onoff("ns3::UdpSocketFactory", Address(InetSocketAddress(ApInterface.GetAddress(0), 9)));

          onoff.SetConstantRate(DataRate ("100000kb/s"), 1500 - 20 - 8 - 8);

          ApplicationContainer temp = onoff.Install(wifiStaNodes.Get (i));

          temp.Start(Seconds(txStartTime));

          temp.Stop(Seconds(simulationTime));

  }

 

  PacketSinkHelper sink ("ns3::UdpSocketFactory", Address(InetSocketAddress(ApInterface.GetAddress(0), 9)));

  ApplicationContainer Serverapp = sink.Install(wifiApNode.Get (0));

  Serverapp.Start(Seconds (0.0));

 

  //Config::ConnectWithoutContext("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/Phy/PhyTxBegin", MakeCallback(&packetSent));

  Config::ConnectWithoutContext("/NodeList/*/DeviceList/*/Mac/MacTx", MakeCallback(&packetSent));

  Config::ConnectWithoutContext ("/NodeList/*/DeviceList/*/Mac/MacRx", MakeCallback (&packetReceived));

 

  Ipv4GlobalRoutingHelper::PopulateRoutingTables ();

 

  sinkApp = DynamicCast<PacketSink> (Serverapp.Get (0));

  FlowMonitorHelper flowmon;

  Ptr<FlowMonitor> monitor = flowmon.InstallAll ();

 

  Ptr<OpenGymInterface> openGymInterface = CreateObject<OpenGymInterface>(openGymPort);

  openGymInterface->SetGetActionSpaceCb(MakeCallback(&MyGetActionSpace));

  openGymInterface->SetGetObservationSpaceCb(MakeCallback(&MyGetObservationSpace));

  openGymInterface->SetGetGameOverCb(MakeCallback(&MyGetGameOver));

  openGymInterface->SetGetObservationCb(MakeCallback(&MyGetObservation));

  openGymInterface->SetGetRewardCb(MakeCallback(&MyGetReward));

  openGymInterface->SetGetExtraInfoCb(MakeCallback(&MyGetExtraInfo));

  openGymInterface->SetExecuteActionsCb(MakeCallback(&MyExecuteActions));

  Simulator::Schedule(Seconds(1.0), &ScheduleNextStateRead, envStepTime, openGymInterface);

 

  Simulator::Stop(Seconds(simulationTime + 1.0 + envStepTime*(history_length+1)));

  Simulator::Run();

 

  Ptr<Ipv4FlowClassifier> classifier = DynamicCast<Ipv4FlowClassifier>(flowmon.GetClassifier());

  std::map<FlowId, FlowMonitor::FlowStats> stats = monitor->GetFlowStats();

  double lastRxTime = 0;

  double firstRxTime = simulationTime + 10;;

  double flowThr;

  double timediff;

  uint32_t totalRx =0;

  uint32_t totalTx =0;

  uint32_t totalRxBytes =0;

 

  for(std::map<FlowId, FlowMonitor::FlowStats>::const_iterator set = stats.begin(); set != stats.end(); set++)

  {

    if(lastRxTime < set->second.timeLastRxPacket.GetSeconds())

    {

            lastRxTime = set->second.timeLastRxPacket.GetSeconds();

    }

    if(firstRxTime > set->second.timeFirstRxPacket.GetSeconds())

    {

            firstRxTime = set->second.timeFirstRxPacket.GetSeconds();

    }

 

    totalRx +=  set->second.rxPackets;

    totalTx +=  set->second.txPackets;

    totalRxBytes += set->second.rxBytes;

 

    Ipv4FlowClassifier::FiveTuple t = classifier->FindFlow(set->first);

    timediff = set->second.timeLastRxPacket.GetSeconds() - set->second.timeFirstRxPacket.GetSeconds();

    flowThr = set->second.rxBytes * 8.0 / timediff / 1000 / 1000;

    std::cout << "Flow " << set->first << " (" << t.sourceAddress << " -> " << t.destinationAddress << ")\tThroughput: " << flowThr << " Mbps\tTime: " << set->second.timeLastRxPacket.GetSeconds() - set->second.timeFirstRxPacket.GetSeconds() << " s\tRx packets " << set->second.rxPackets << std::endl;

    //std::cout << "packetsDropped:" << set->second.packetsDropped.size() << std::endl;

  }

 

  std::cout << "totalTx:" << totalTx << " totalRx:" << totalRx << std::endl;

  std::cout << "sinkApp->GetTotalRxPkt()=" << sinkApp->GetTotalRxPkt() << std::endl;

  std::cout << "g_txPktNum=" << g_txPktNum << std::endl;

  std::cout << "my_rxPktNum=" << my_rxPktNum << std::endl;

  //std::cout << "totalRxBytes=" << totalRxBytes << std::endl;

  //std::cout << "sinkApp->GetTotalRx()=" << sinkApp->GetTotalRx() << std::endl;

 

  double totalBytes = sinkApp->GetTotalRx();

  float throughput = totalBytes * 8.0/1000/1000/(lastRxTime - firstRxTime);

  std::cout << "throughput:\t" << throughput << " Mbps" << std::endl;

  //std::cout << "cwmin: " << cwmin << ", cwmax: " << cwmax << ", nSta: " << nSta << std::endl;

  //std::cout << "firstRxTime: " << firstRxTime << "sec,\t lastRxTime: " << lastRxTime << "sec" << std::endl;

 

  openGymInterface->NotifySimulationEnd();

  Simulator::Destroy ();

}

 

test_dqn.py

#!/usr/bin/env python3

# -*- coding: utf-8 -*-

 

import argparse

from ns3gym import ns3env

 

import numpy as np

import random

import os

import parl

from parl import layers 

import copy

import paddle.fluid as fluid

import collections

 

MEMORY_SIZE = 20000 

MEMORY_WARMUP_SIZE = 100 

BATCH_SIZE = 32

LEARNING_RATE = 0.001

GAMMA = 0.9

 

class Model(parl.Model):
    """Fully connected Q-network with two hidden layers.

    Both hidden layers have 128 units and ReLU activation; the output layer
    has ``act_dim`` units and no activation.
    """

    def __init__(self, act_dim):
        hidden_units = (128, 128)
        self.fc1 = layers.fc(size=hidden_units[0], act='relu')
        self.fc2 = layers.fc(size=hidden_units[1], act='relu')
        self.fc3 = layers.fc(size=act_dim, act=None)

    def value(self, obs):
        """Pass ``obs`` through the three layers and return the output of the last one."""
        hidden = self.fc2(self.fc1(obs))
        return self.fc3(hidden)

 

class DQN(parl.Algorithm):
    """DQN update rule with a separate target network.

    Args:
        model: network providing ``value(obs)``.
        act_dim (int): number of discrete actions.
        gamma (float): discount factor.
        lr (float): learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, act_dim=None, gamma=None, lr=None):
        self.model = model
        self.target_model = copy.deepcopy(model)

        for value, expected_type in ((act_dim, int), (gamma, float), (lr, float)):
            assert isinstance(value, expected_type)
        self.act_dim = act_dim
        self.gamma = gamma
        self.lr = lr

    def predict(self, obs):
        """Return the online network's output for ``obs``."""
        return self.model.value(obs)

    def learn(self, obs, action, reward, next_obs, terminal):
        """Build one training step and return its mean squared-error cost.

        The target is ``reward + (1 - terminal) * gamma * max_a Q_target(next_obs, a)``;
        no gradient flows through the target network's maximum.
        """
        # Target side: largest output of the target network for the next observation.
        next_q = self.target_model.value(next_obs)
        max_next_q = layers.reduce_max(next_q, dim=1)
        max_next_q.stop_gradient = True
        terminal_flag = layers.cast(terminal, dtype='float32')
        td_target = reward + (1.0 - terminal_flag) * self.gamma * max_next_q

        # Prediction side: pick the output belonging to the action actually taken.
        q_values = self.model.value(obs)
        action_mask = layers.cast(
            layers.one_hot(action, self.act_dim), dtype='float32')
        chosen_q = layers.reduce_sum(
            layers.elementwise_mul(action_mask, q_values), dim=1)

        loss = layers.reduce_mean(layers.square_error_cost(chosen_q, td_target))
        fluid.optimizer.Adam(learning_rate=self.lr).minimize(loss)
        return loss

    def sync_target(self):
        """Copy the online network's weights into the target network."""
        self.model.sync_weights_to(self.target_model)

 

class Agent(parl.Agent):
    """Epsilon-greedy agent that runs the DQN programs.

    Args:
        algorithm: object providing ``predict``, ``learn`` and ``sync_target``.
        obs_dim (int): length of one observation vector.
        act_dim (int): number of discrete actions.
        e_greed (float): initial probability of choosing a random action.
        e_greed_decrement (float): amount subtracted from ``e_greed`` after
            every call to ``sample`` (never below 0.01).
    """

    def __init__(self, algorithm, obs_dim, act_dim, e_greed=0.1, e_greed_decrement=0):
        for dim in (obs_dim, act_dim):
            assert isinstance(dim, int)
        # Set before the base constructor runs; presumably it triggers
        # build_program(), which reads obs_dim -- confirm against PARL.
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(Agent, self).__init__(algorithm)

        self.global_step = 0
        # The target network is synchronised every 200 learn() calls.
        self.update_target_steps = 200

        self.e_greed = e_greed
        self.e_greed_decrement = e_greed_decrement

    def build_program(self):
        """Create the prediction program and the learning program."""
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(name='obs', shape=[self.obs_dim], dtype='float32')
            self.value = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(name='obs', shape=[self.obs_dim], dtype='float32')
            action = layers.data(name='act', shape=[1], dtype='int32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

    def sample(self, obs):
        """Choose an action epsilon-greedily, then decay epsilon."""
        explore = np.random.rand() < self.e_greed
        if explore:
            chosen = np.random.randint(self.act_dim)
        else:
            chosen = self.predict(obs)
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return chosen

    def predict(self, obs):
        """Return the index of the largest network output for a single observation."""
        batch = np.expand_dims(obs, axis=0)
        outputs = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': batch.astype('float32')},
            fetch_list=[self.value])
        q_row = np.squeeze(outputs[0], axis=0)
        return np.argmax(q_row)

    def learn(self, obs, act, reward, next_obs, terminal):
        """Run one training step on a batch and return the fetched cost."""
        sync_due = (self.global_step % self.update_target_steps == 0)
        if sync_due:
            self.alg.sync_target()
        self.global_step += 1

        act_column = np.expand_dims(act, -1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act_column.astype('int32'),
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal,
        }
        fetched = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])
        return fetched[0]

 

class ReplayMemory(object):
    """Fixed-capacity experience buffer.

    Experiences are 5-tuples ``(obs, action, reward, next_obs, done)``. Once
    ``max_size`` entries are stored, appending discards the oldest one.
    """

    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        """Store one experience tuple."""
        self.buffer.append(exp)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct experiences at random.

        Returns:
            Five float32 arrays: observations, actions, rewards, next
            observations and done flags, each with ``batch_size`` rows.
        """
        picked = random.sample(self.buffer, batch_size)

        # One list per field, filled in the order the experiences were drawn.
        columns = [[], [], [], [], []]
        for experience in picked:
            obs, action, reward, next_obs, done = experience
            for column, field in zip(columns, (obs, action, reward, next_obs, done)):
                column.append(field)

        return tuple(np.array(column).astype('float32') for column in columns)

    def __len__(self):
        return len(self.buffer)

 

 

# ---------------------------------------------------------------------------
# Training driver: connects to the ns-3 simulation through ns3-gym and trains
# the DQN agent while the simulation runs.
# ---------------------------------------------------------------------------
port = 5555
seed = 1
env = ns3env.Ns3Env(port=port, simSeed=seed)

stepIdx = 0
# Length of one observation; matches the default history_length (20) of cw.cc.
obs_dim=20
# Actions 0..6: in the "discrete" mode of cw.cc an action a selects
# CW = 2^(4 + a), i.e. 16 ... 1024.
act_dim=7

model = Model(act_dim=act_dim)
algorithm = DQN(model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = Agent(
        algorithm,
        obs_dim=obs_dim,
        act_dim=act_dim,
        e_greed=0.1,
        e_greed_decrement=1e-6)

rpm = ReplayMemory(MEMORY_SIZE)

try:
    # The environment is reset inside the try block so that env.close() in the
    # finally clause also runs when the reset itself fails.
    obs = env.reset()
    print("Step: ", stepIdx)
    print("---obs:", obs)

    while True:
        # Train on every 5th step once the replay memory holds at least
        # MEMORY_WARMUP_SIZE experiences; before that, only collect data.
        if len(rpm) >= MEMORY_WARMUP_SIZE and stepIdx%5==0:
            print("="*20,"agent learn","="*20)
            (batch_obs, batch_action, batch_reward, batch_next_obs, batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs, batch_done)

        action = agent.sample(obs)
        # The environment receives the action wrapped in a one-element array.
        action2 = np.array([action])
        print("---action: ", action)

        next_obs, reward, done, info = env.step(action2)
        print("---obs, reward, done, info: ", next_obs, reward, done, info)
        rpm.append((obs, action, reward, next_obs, done))
        obs=next_obs

        # Checked on every step, including the warm-up phase, so the loop never
        # keeps stepping an environment that has already finished.
        if done:
            print("done")
            break

        stepIdx += 1
        print("Step: ", stepIdx)

except KeyboardInterrupt:
    print("Ctrl-C -> Exit")
finally:
    env.close()
    print("Done")

 

[Executions] (My Test OS environment: ubuntu18.04)

(For CSMA/CA: 30 nodes, simulation time: 10 s)

 

Open another terminal

 

(Wait)

In the first terminal, you can see that the throughput for CSMA/CA is 39.3 Mbps.

 

For CCOD-DQN

 

Open another terminal

 

(Wait)

You can see that the throughput for CCOD-DQN is 51.5933 Mbps (better than CSMA/CA).

 

If you are interested in improving 802.11 throughput via reinforcement learning, you can also refer to my work at https://csie.nqu.edu.tw/smallko/setl-rl.htm

 

Back to NS3 Learning Guide

Last Modified: 2022/2/27 done

 

[Author]

Dr. Chih-Heng Ke

Department of Computer Science and Information Engineering, National Quemoy University, Kinmen, Taiwan

Email: smallko@gmail.com