1. Overview
本文讨论的ps-lite版本
本文引用的代码会删掉一些非主干代码
首先看ps-lite的几个重要基础类型
类型 | 解释 |
---|---|
Node | 一个物理进程就是一个Node,分为三类(Server,Worker,Scheduler),每个Node可以使用hostname+port标识 |
Customer | 这个类的取名很随意,用于追踪Response,但是其本身不接管网络,因此和其他类的耦合很深 |
Van | 一个Node维护一个Van object,负责与其他节点的网络通信 |
Postoffice | 维护一个Node需要的各种杂七杂八的信息,全局信息管理类 |
2. Node
Node分为三类,如图
- Server: 维护模型的部分参数
- Worker: 从Server处pull参数,计算梯度并push到Server
- Scheduler: 中心调度器,控制其他Node
3. Customer
用于追踪每个Request对应的Response情况,同时又能处理Message 但是其本身并没有接管网络,因此实际的Response和Message需要外部调用者告诉它
首先看它的主要内部变量
RecvHandle recv_handle_; // 处理message的函数,recv_handle_(message)
ThreadsafeQueue<Message> recv_queue_; // 线程安全的队列
std::unique_ptr<std::thread> recv_thread_; // 不断从recv_queue读取message并调用recv_handle_
// 对于tracker_
// index表示Request编号
// Pair的first表示应收到的Response数量
// Pair的second表示目前为止实际收到的Response数量
std::mutex tracker_mu_;
std::condition_variable tracker_cond_;
std::vector<std::pair<int, int>> tracker_;
当我们需要给一个Resquest计数的时候,使用
// recver表示接收者的node_id,ps-lite中一个整数可能对应于多个node_id,因此使用Postoffice解码获得所有的真实node_id
int Customer::NewRequest(int recver) {
std::lock_guard<std::mutex> lk(tracker_mu_);
int num = Postoffice::Get()->GetNodeIDs(recver).size();
tracker_.push_back(std::make_pair(num, 0));
return tracker_.size() - 1; // 后续customer使用这个值代表这个request
}
当我们需要等待某个发出去的Request对应的Response全部收到时,使用
void Customer::WaitRequest(int timestamp) {
std::unique_lock<std::mutex> lk(tracker_mu_);
tracker_cond_.wait(lk, [this, timestamp]{
return tracker_[timestamp].first == tracker_[timestamp].second;
});
}
这个类有个缺陷,对于过期的以后不会再用到的Request信息,没有删除操作。
而这个类的单个对象的生存周期又近乎等于进程的生存周期。
因此,个人推测,基于ps-lite程序跑的时间久了都会OOM。
当外部调用者收到Response时,调用AddResponse告诉Customer对象
void Customer::AddResponse(int timestamp, int num) {
std::lock_guard<std::mutex> lk(tracker_mu_);
tracker_[timestamp].second += num;
}
4. Van
ps-lite的通信类,负责与其他Node通信,每个Node仅有一个Van对象
首先来看一下Van对象的初始化过程
首先从环境变量中得知这个Node的职责(Worker/Server/Scheduler),然后Bind一个端口,并建立到Scheduler的连接
// get scheduler info
scheduler_.hostname = std::string(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI")));
scheduler_.port = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT")));
scheduler_.role = Node::SCHEDULER;
scheduler_.id = kScheduler;
is_scheduler_ = Postoffice::Get()->is_scheduler();
// get my node info
if (is_scheduler_) {
my_node_ = scheduler_;
} else {
auto role = is_scheduler_ ? Node::SCHEDULER :
(Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER);
const char* nhost = Environment::Get()->find("DMLC_NODE_HOST");
std::string ip;
if (nhost) ip = std::string(nhost);
if (ip.empty()) {
const char* itf = Environment::Get()->find("DMLC_INTERFACE");
std::string interface;
if (itf) interface = std::string(itf);
if (interface.size()) {
GetIP(interface, &ip);
} else {
GetAvailableInterfaceAndIP(&interface, &ip);
}
}
int port = GetAvailablePort();
const char* pstr = Environment::Get()->find("PORT");
if (pstr) port = atoi(pstr);
my_node_.hostname = ip;
my_node_.role = role;
my_node_.port = port;
// cannot determine my id now, the scheduler will assign it later
// set it explicitly to make re-register within a same process possible
my_node_.id = Node::kEmpty;
}
// bind.
my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40);
// connect to the scheduler
Connect(scheduler_);
建立到Scheduler的连接后,启动本地Node的接收线程
并将本地Node的信息告知Scheduler
然后等待Scheduler通知Ready
Ready后建立到Scheduler的Heartbeat
// start receiver
receiver_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Receiving, this));
if (!is_scheduler_) {
// let the scheduler know myself
Message msg;
msg.meta.recver = kScheduler;
msg.meta.control.cmd = Control::ADD_NODE;
msg.meta.control.node.push_back(my_node_);
msg.meta.timestamp = timestamp_++;
Send(msg);
}
// wait until ready
while (!ready_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
if (!is_scheduler_) {
// start heartbeat thread
heartbeat_thread_ = std::unique_ptr<std::thread>(
new std::thread(&Van::Heartbeat, this));
}
Van对象自此完成初始化
而本地Node的全局rank等信息是由receiver_thread_获取
receiver_thread_的主要结构如下
while (true) {
Message msg;
int recv_bytes = RecvMsg(&msg);
recv_bytes_ += recv_bytes;
auto& ctrl = msg.meta.control;
if (ctrl.cmd == Control::ADD_NODE) {
// do something
} else if (ctrl.cmd == Control::BARRIER) {
// do something
} else if (ctrl.cmd == Control::HEARTBEAT) {
// 发回Heartbeat的ACK
} else {
// do something
}
}
BARRIER的部分在Postoffice中再看,Van中实现了结束Barrier状态的部分代码
这里主要看下ADD_NODE的实现
首先,如果这个message的发送方的id是未设定的值,那么处理此message的一定是Scheduler,进入如下分支,Scheduler记录这个新的node,如果这个node是重启产生的,则将旧node的信息更新。
if (msg.meta.sender == Meta::kEmpty) {
CHECK(is_scheduler_);
CHECK_EQ(ctrl.node.size(), 1);
if (nodes.control.node.size() < num_nodes) {
nodes.control.node.push_back(ctrl.node[0]);
} else {
// some node dies and restarts
CHECK(ready_);
for (size_t i = 0; i < nodes.control.node.size() - 1; ++i) {
const auto& node = nodes.control.node[i];
if (dead_set.find(node.id) != dead_set.end() && node.role == ctrl.node[0].role) {
auto& recovery_node = ctrl.node[0];
// assign previous node id
recovery_node.id = node.id;
recovery_node.is_recovery = true;
nodes.control.node[i] = recovery_node;
recovery_nodes.control.node.push_back(recovery_node);
break;
}
}
}
}
对普通的node,更新其rank
// update my id
for (size_t i = 0; i < ctrl.node.size(); ++i) {
const auto& node = ctrl.node[i];
if (my_node_.hostname == node.hostname &&
my_node_.port == node.port) {
my_node_ = node;
std::string rank = std::to_string(Postoffice::IDtoRank(node.id));
setenv("DMLC_RANK", rank.c_str(), true);
}
}
最后,对于Scheduler节点来说,其需要设定最新的所有node的rank并发送给所有Worker和Server
if (is_scheduler_) {
time_t t = time(NULL);
if (nodes.control.node.size() == num_nodes) {
// sort the nodes according their ip and port,
std::sort(nodes.control.node.begin(), nodes.control.node.end(),
[](const Node& a, const Node& b) {
return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0;
});
// assign node rank
for (auto& node : nodes.control.node) {
CHECK_EQ(node.id, Node::kEmpty);
int id = node.role == Node::SERVER ?
Postoffice::ServerRankToID(num_servers_) :
Postoffice::WorkerRankToID(num_workers_);
node.id = id;
Connect(node);
if (node.role == Node::SERVER) ++num_servers_;
if (node.role == Node::WORKER) ++num_workers_;
Postoffice::Get()->UpdateHeartbeat(node.id, t);
}
nodes.control.node.push_back(my_node_);
nodes.control.cmd = Control::ADD_NODE;
Message back; back.meta = nodes;
for (int r : Postoffice::Get()->GetNodeIDs(
kWorkerGroup + kServerGroup)) {
back.meta.recver = r;
back.meta.timestamp = timestamp_++;
Send(back);
}
ready_ = true;
} else if (recovery_nodes.control.node.size() > 0) {
// send back the recovery node
CHECK_EQ(recovery_nodes.control.node.size(), 1);
Connect(recovery_nodes.control.node[0]);
Postoffice::Get()->UpdateHeartbeat(recovery_nodes.control.node[0].id, t);
Message back;
for (int r : Postoffice::Get()->GetNodeIDs(
kWorkerGroup + kServerGroup)) {
if (r != recovery_nodes.control.node[0].id
&& dead_set.find(r) != dead_set.end()) {
// do not try to send anything to dead node
continue;
}
// only send recovery_node to nodes already exist
// but send all nodes to the recovery_node
back.meta = (r == recovery_nodes.control.node[0].id) ? nodes : recovery_nodes;
back.meta.recver = r;
back.meta.timestamp = timestamp_++;
Send(back);
}
}
} else {
for (const auto& node : ctrl.node) {
Connect(node);
if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_;
if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
}
ready_ = true;
}
5. Postoffice
Postoffice是一个很杂的类,每个node都维护一个Postoffice对象
这个Postoffice对象主要维护了一个Van对象,并提供了一套ID转换工具,以及其他一些比较杂的功能
这里我们提一下Postoffice的两个功能
一个是key与server的对应关系,另一个是Barrier
由于ps-lite的key只支持int类型
#if USE_KEY32
/*! \brief Use unsigned 32-bit int as the key type */
using Key = uint32_t;
#else
/*! \brief Use unsigned 64-bit int as the key type */
using Key = uint64_t;
#endif
/*! \brief The maximal allowed key value */
static const Key kMaxKey = std::numeric_limits<Key>::max();
将int范围均分即可
const std::vector<Range>& Postoffice::GetServerKeyRanges() {
if (server_key_ranges_.empty()) {
for (int i = 0; i < num_servers_; ++i) {
server_key_ranges_.push_back(Range(
kMaxKey / num_servers_ * i,
kMaxKey / num_servers_ * (i+1)));
}
}
return server_key_ranges_;
}
说完key-server的对应关系,我们再看Barrier
开始Barrier的Node会告知Scheduler并进入等待状态
void Postoffice::Barrier(int node_group) {
auto role = van_->my_node().role;
std::unique_lock<std::mutex> ulk(barrier_mu_);
barrier_done_ = false;
Message req;
req.meta.recver = kScheduler;
req.meta.request = true;
req.meta.control.cmd = Control::BARRIER;
req.meta.control.barrier_group = node_group;
req.meta.timestamp = van_->GetTimestamp();
CHECK_GT(van_->Send(req), 0);
barrier_cond_.wait(ulk, [this] {
return barrier_done_;
});
}
而Scheduler会对Barrier请求进行计数,当收到最后一个请求时,发送结束Barrier的命令
if (msg.meta.request) {
if (barrier_count_.empty()) {
barrier_count_.resize(8, 0);
}
int group = ctrl.barrier_group;
++barrier_count_[group];
if (barrier_count_[group] ==
static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) {
barrier_count_[group] = 0;
Message res;
res.meta.request = false;
res.meta.control.cmd = Control::BARRIER;
for (int r : Postoffice::Get()->GetNodeIDs(group)) {
res.meta.recver = r;
res.meta.timestamp = timestamp_++;
CHECK_GT(Send(res), 0);
}
} else {
Postoffice::Get()->Manage(msg);
}
}
6. KVApp
在ps-lite中,key和value是分开存储的,每个key可能对应多个value,因此需要记录每个key的长度
template <typename Val>
struct KVPairs {
// /** \brief empty constructor */
// KVPairs() {}
/** \brief the list of keys */
SArray<Key> keys;
/** \brief the according values */
SArray<Val> vals;
/** \brief the according value lengths (could be empty) */
SArray<int> lens;
};
Server中维护一个哈希表,记录key和value,并对push和pull请求进行响应
template <typename Val>
struct KVServerDefaultHandle {
void operator()(
const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
size_t n = req_data.keys.size();
KVPairs<Val> res;
if (req_meta.push) {
CHECK_EQ(n, req_data.vals.size());
} else {
res.keys = req_data.keys; res.vals.resize(n);
}
for (size_t i = 0; i < n; ++i) {
Key key = req_data.keys[i];
if (req_meta.push) {
store[key] += req_data.vals[i];
} else {
res.vals[i] = store[key];
}
}
server->Response(req_meta, res);
}
std::unordered_map<Key, Val> store;
};
Worker中的push和pull操作都是异步返回一个ID,使用ID进行wait阻塞等待进行同步操作,或者异步调用时传入一个Callback进行后续操作
7. 声明
本文为个人笔记,均是个人理解,如有错误,还请指正。