ps-lite代码笔记

Oct 15, 2016


1. Overview

本文讨论的ps-lite版本

原始项目地址

本文引用的代码会删掉一些非主干代码

首先看ps-lite的几个重要基础类型

类型 解释
Node 一个物理进程就是一个Node,分为三类(Server,Worker,Scheduler),每个Node可以使用hostname+port标识
Customer 这个类的取名很随意,用于追踪Response,但是其本身不接管网络,因此和其他类的耦合很深
Van 一个Node维护一个Van object,负责与其他节点的网络通信
Postoffice 维护一个Node需要的各种杂七杂八的信息,全局信息管理类

2. Node

Node分为三类,如图

  • Server: 维护模型的部分参数
  • Worker: 从Server处pull参数,计算梯度并push到Server
  • Scheduler: 中心调度器,控制其他Node

3. Customer

用于追踪每个Request对应的Response情况,同时又能处理Message 但是其本身并没有接管网络,因此实际的Response和Message需要外部调用者告诉它

首先看它的主要内部变量

RecvHandle recv_handle_; // 处理message的函数,recv_handle_(message)
ThreadsafeQueue<Message> recv_queue_; // 线程安全的队列
std::unique_ptr<std::thread> recv_thread_; // 不断从recv_queue读取message并调用recv_handle_

// 对于tracker_
//     index表示Request编号
//     Pair的first表示应收到的Response数量
//     Pair的second表示目前为止实际收到的Response数量
std::mutex tracker_mu_;
std::condition_variable tracker_cond_;
std::vector<std::pair<int, int>> tracker_;

当我们需要给一个Resquest计数的时候,使用

// recver表示接收者的node_id,ps-lite中一个整数可能对应于多个node_id,因此使用Postoffice解码获得所有的真实node_id
int Customer::NewRequest(int recver) {
  std::lock_guard<std::mutex> lk(tracker_mu_);
  int num = Postoffice::Get()->GetNodeIDs(recver).size();
  tracker_.push_back(std::make_pair(num, 0));
  return tracker_.size() - 1; // 后续customer使用这个值代表这个request
}

当我们需要等待某个发出去的Request对应的Response全部收到时,使用

void Customer::WaitRequest(int timestamp) {
  std::unique_lock<std::mutex> lk(tracker_mu_);
  tracker_cond_.wait(lk, [this, timestamp]{
      return tracker_[timestamp].first == tracker_[timestamp].second;
    });
}

这个类有个缺陷,对于过期的以后不会再用到的Request信息,没有删除操作。

而这个类的单个对象的生存周期又近乎等于进程的生存周期。

因此,个人推测,基于ps-lite程序跑的时间久了都会OOM。

当外部调用者收到Response时,调用AddResponse告诉Customer对象

void Customer::AddResponse(int timestamp, int num) {
  std::lock_guard<std::mutex> lk(tracker_mu_);
  tracker_[timestamp].second += num;
}

4. Van

ps-lite的通信类,负责与其他Node通信,每个Node仅有一个Van对象

首先来看一下Van对象的初始化过程

首先从环境变量中得知这个Node的职责(Worker/Server/Scheduler),然后Bind一个端口,并建立到Scheduler的连接

// get scheduler info
scheduler_.hostname = std::string(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI")));
scheduler_.port     = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT")));
scheduler_.role     = Node::SCHEDULER;
scheduler_.id       = kScheduler;
is_scheduler_       = Postoffice::Get()->is_scheduler();
// get my node info
if (is_scheduler_) {
  my_node_ = scheduler_;
} else {
  auto role = is_scheduler_ ? Node::SCHEDULER :
              (Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER);
  const char* nhost = Environment::Get()->find("DMLC_NODE_HOST");
  std::string ip;
  if (nhost) ip = std::string(nhost);
  if (ip.empty()) {
    const char*  itf = Environment::Get()->find("DMLC_INTERFACE");
    std::string interface;
    if (itf) interface = std::string(itf);
    if (interface.size()) {
      GetIP(interface, &ip);
    } else {
      GetAvailableInterfaceAndIP(&interface, &ip);
    }
  }
  int port = GetAvailablePort();
  const char* pstr = Environment::Get()->find("PORT");
  if (pstr) port = atoi(pstr);
  my_node_.hostname = ip;
  my_node_.role     = role;
  my_node_.port     = port;
  // cannot determine my id now, the scheduler will assign it later
  // set it explicitly to make re-register within a same process possible
  my_node_.id = Node::kEmpty;
}
// bind.
my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40);
// connect to the scheduler
Connect(scheduler_);

建立到Scheduler的连接后,启动本地Node的接收线程

并将本地Node的信息告知Scheduler

然后等待Scheduler通知Ready

Ready后建立到Scheduler的Heartbeat

// start receiver
receiver_thread_ = std::unique_ptr<std::thread>(
    new std::thread(&Van::Receiving, this));
    
if (!is_scheduler_) {
  // let the scheduler know myself
  Message msg;
  msg.meta.recver = kScheduler;
  msg.meta.control.cmd = Control::ADD_NODE;
  msg.meta.control.node.push_back(my_node_);
  msg.meta.timestamp = timestamp_++;
  Send(msg);
}
// wait until ready
while (!ready_) {
  std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
if (!is_scheduler_) {
  // start heartbeat thread
  heartbeat_thread_ = std::unique_ptr<std::thread>(
  new std::thread(&Van::Heartbeat, this));
}

Van对象自此完成初始化

而本地Node的全局rank等信息是由receiver_thread_获取

receiver_thread_的主要结构如下

while (true) {
  Message msg;
  int recv_bytes = RecvMsg(&msg);
  recv_bytes_ += recv_bytes;
  auto& ctrl = msg.meta.control;
  if (ctrl.cmd == Control::ADD_NODE) {
    // do something
  } else if (ctrl.cmd == Control::BARRIER) {
    // do something
  } else if (ctrl.cmd == Control::HEARTBEAT) {
    // 发回Heartbeat的ACK
  } else {
    // do something
  }
}

BARRIER的部分在Postoffice中再看,Van中实现了结束Barrier状态的部分代码

这里主要看下ADD_NODE的实现

首先,如果这个message的发送方的id是未设定的值,那么处理此message的一定是Scheduler,进入如下分支,Scheduler记录这个新的node,如果这个node是重启产生的,则将旧node的信息更新。

if (msg.meta.sender == Meta::kEmpty) {
  CHECK(is_scheduler_);
  CHECK_EQ(ctrl.node.size(), 1);
  if (nodes.control.node.size() < num_nodes) {
    nodes.control.node.push_back(ctrl.node[0]);
  } else {
    // some node dies and restarts
    CHECK(ready_);
    for (size_t i = 0; i < nodes.control.node.size() - 1; ++i) {
      const auto& node = nodes.control.node[i];
      if (dead_set.find(node.id) != dead_set.end() && node.role == ctrl.node[0].role) {
        auto& recovery_node = ctrl.node[0];
        // assign previous node id
        recovery_node.id = node.id;
        recovery_node.is_recovery = true;
        nodes.control.node[i] = recovery_node;
        recovery_nodes.control.node.push_back(recovery_node);
        break;
      }
    }
  }
}

对普通的node,更新其rank

// update my id
for (size_t i = 0; i < ctrl.node.size(); ++i) {
  const auto& node = ctrl.node[i];
  if (my_node_.hostname == node.hostname &&
      my_node_.port == node.port) {
    my_node_ = node;
    std::string rank = std::to_string(Postoffice::IDtoRank(node.id));
    setenv("DMLC_RANK", rank.c_str(), true);
  }
}

最后,对于Scheduler节点来说,其需要设定最新的所有node的rank并发送给所有Worker和Server

if (is_scheduler_) {
  time_t t = time(NULL);
  if (nodes.control.node.size() == num_nodes) {
    // sort the nodes according their ip and port,
    std::sort(nodes.control.node.begin(), nodes.control.node.end(),
              [](const Node& a, const Node& b) {
                return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0;
              });
    // assign node rank
    for (auto& node : nodes.control.node) {
      CHECK_EQ(node.id, Node::kEmpty);
      int id = node.role == Node::SERVER ?
               Postoffice::ServerRankToID(num_servers_) :
               Postoffice::WorkerRankToID(num_workers_);
      node.id = id;
      Connect(node);
      if (node.role == Node::SERVER) ++num_servers_;
      if (node.role == Node::WORKER) ++num_workers_;
      Postoffice::Get()->UpdateHeartbeat(node.id, t);
    }
    nodes.control.node.push_back(my_node_);
    nodes.control.cmd = Control::ADD_NODE;
    Message back; back.meta = nodes;
    for (int r : Postoffice::Get()->GetNodeIDs(
             kWorkerGroup + kServerGroup)) {
      back.meta.recver = r;
      back.meta.timestamp = timestamp_++;
      Send(back);
    }
    ready_ = true;
  } else if (recovery_nodes.control.node.size() > 0) {
    // send back the recovery node
    CHECK_EQ(recovery_nodes.control.node.size(), 1);
    Connect(recovery_nodes.control.node[0]);
    Postoffice::Get()->UpdateHeartbeat(recovery_nodes.control.node[0].id, t);
    Message back;
    for (int r : Postoffice::Get()->GetNodeIDs(
             kWorkerGroup + kServerGroup)) {
      if (r != recovery_nodes.control.node[0].id
            && dead_set.find(r) != dead_set.end()) {
        // do not try to send anything to dead node
        continue;
      }
      // only send recovery_node to nodes already exist
      // but send all nodes to the recovery_node
      back.meta = (r == recovery_nodes.control.node[0].id) ? nodes : recovery_nodes;
      back.meta.recver = r;
      back.meta.timestamp = timestamp_++;
      Send(back);
    }
  }
} else {
  for (const auto& node : ctrl.node) {
    Connect(node);
    if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_;
    if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
  }
  ready_ = true;
}

5. Postoffice

Postoffice是一个很杂的类,每个node都维护一个Postoffice对象

这个Postoffice对象主要维护了一个Van对象,并提供了一套ID转换工具,以及其他一些比较杂的功能

这里我们提一下Postoffice的两个功能

一个是key与server的对应关系,另一个是Barrier

由于ps-lite的key只支持int类型

#if USE_KEY32
/*! \brief Use unsigned 32-bit int as the key type */
using Key = uint32_t;
#else
/*! \brief Use unsigned 64-bit int as the key type */
using Key = uint64_t;
#endif
/*! \brief The maximal allowed key value */
static const Key kMaxKey = std::numeric_limits<Key>::max();

将int范围均分即可

const std::vector<Range>& Postoffice::GetServerKeyRanges() {
  if (server_key_ranges_.empty()) {
    for (int i = 0; i < num_servers_; ++i) {
      server_key_ranges_.push_back(Range(
          kMaxKey / num_servers_ * i,
          kMaxKey / num_servers_ * (i+1)));
    }
  }
  return server_key_ranges_;
}

说完key-server的对应关系,我们再看Barrier

开始Barrier的Node会告知Scheduler并进入等待状态

void Postoffice::Barrier(int node_group) {
  auto role = van_->my_node().role;

  std::unique_lock<std::mutex> ulk(barrier_mu_);
  barrier_done_ = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.control.barrier_group = node_group;
  req.meta.timestamp = van_->GetTimestamp();
  CHECK_GT(van_->Send(req), 0);

  barrier_cond_.wait(ulk, [this] {
      return barrier_done_;
    });
}

而Scheduler会对Barrier请求进行计数,当收到最后一个请求时,发送结束Barrier的命令

if (msg.meta.request) {
  if (barrier_count_.empty()) {
    barrier_count_.resize(8, 0);
  }
  int group = ctrl.barrier_group;
  ++barrier_count_[group];
  if (barrier_count_[group] ==
      static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) {
    barrier_count_[group] = 0;
    Message res;
    res.meta.request = false;
    res.meta.control.cmd = Control::BARRIER;
    for (int r : Postoffice::Get()->GetNodeIDs(group)) {
      res.meta.recver = r;
      res.meta.timestamp = timestamp_++;
      CHECK_GT(Send(res), 0);
    }
  } else {
    Postoffice::Get()->Manage(msg);
  }
}

6. KVApp

在ps-lite中,key和value是分开存储的,每个key可能对应多个value,因此需要记录每个key的长度

template <typename Val>
struct KVPairs {
  // /** \brief empty constructor */
  // KVPairs() {}
  /** \brief the list of keys */
  SArray<Key> keys;
  /** \brief the according values */
  SArray<Val> vals;
  /** \brief the according value lengths (could be empty) */
  SArray<int> lens;
};

Server中维护一个哈希表,记录key和value,并对push和pull请求进行响应

template <typename Val>
struct KVServerDefaultHandle {
  void operator()(
      const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
    size_t n = req_data.keys.size();
    KVPairs<Val> res;
    if (req_meta.push) {
      CHECK_EQ(n, req_data.vals.size());
    } else {
      res.keys = req_data.keys; res.vals.resize(n);
    }
    for (size_t i = 0; i < n; ++i) {
      Key key = req_data.keys[i];
      if (req_meta.push) {
        store[key] += req_data.vals[i];
      } else {
        res.vals[i] = store[key];
      }
    }
    server->Response(req_meta, res);
  }
  std::unordered_map<Key, Val> store;
};

Worker中的push和pull操作都是异步返回一个ID,使用ID进行wait阻塞等待进行同步操作,或者异步调用时传入一个Callback进行后续操作

7. 声明

本文为个人笔记,均是个人理解,如有错误,还请指正。