# Copyright 2020 Aptpod, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import warnings
from typing import Optional
from urllib import parse
from debtcollector import removals
from tornado import gen, httpclient, locks, queues, websocket
from intdash import _client, _models, _protocol, _utils
class StreamAlreadyOpenException(Exception):
    """Raised when streams are opened on a connection that already has them open."""
class ResultNGException(Exception):
    """Raised when a stream-open request is answered with a result code other than OK."""
# Names exported by `from <this module> import *`: only WebSocketConn.
__all__ = ["WebSocketConn"]
class WebSocketConn(object):
    """Object representing a WebSocket connection."""

    def _init(
        self,
        client,
        flush_interval,
        auto_reconnect,
        token_source: Optional[_client.TokenSource] = None,
    ):
        """Set up the connection state and start connecting.

        NOTE(review): this method is named ``_init`` rather than ``__init__``,
        so it does not run on construction; presumably whoever creates the
        object calls it explicitly -- confirm against the caller.

        Args:
            client: Client object. Its ``url``, ``edge_token``, ``jwt`` and
                ``_auth()`` are used to build the URL and the auth headers.
            flush_interval: Seconds slept between flushes of the outgoing
                buffer (passed to ``gen.sleep``).
            auto_reconnect: If true, a lost connection is re-established and
                open streams are re-opened; otherwise the connection is closed.
            token_source: Optional ``TokenSource`` whose ``token()`` supplies
                the bearer token instead of ``client.jwt``.

        Raises:
            TypeError: If ``token_source`` is given but is not a ``TokenSource``.
        """
        self._client = client

        # Derive the WebSocket endpoint from the client's HTTP(S) base URL:
        # same host, fixed path, no query string.
        httpurl = parse.urlparse(client.url)
        scheme = "wss" if httpurl.scheme == "https" else "ws"
        self._wsurl = parse.urlunparse(
            [
                scheme,
                httpurl.netloc,
                "/api/v1/ws/measurements",
                httpurl.params,
                parse.urlencode({}),
                httpurl.fragment,
            ]
        )

        self._flush_interval = flush_interval
        self._next_stream_id = cyclic_counter256(initial=1)
        self._next_req_id = cyclic_counter256(initial=0)
        self._auto_reconnect = auto_reconnect

        self._disconnected = locks.Event()
        self._reconnecting = locks.Event()
        self._quit = locks.Event()

        self._rx_queues = {}  # stream_id -> queue of received stream elements
        self._tx_queue = queues.Queue()  # elements waiting to be serialized
        self._req_queues = {}  # req_id -> queue of received request elements

        self._downstreams = None
        self._upstreams = None
        self._writer = None
        self._conn = None

        if token_source is not None and not isinstance(
            token_source, _client.TokenSource
        ):
            raise TypeError("Must be TokenSource class")
        self._token_source = token_source

        # Coroutine started without being awaited: it keeps running on the
        # IOLoop while this method returns.
        self._connect()

    @gen.coroutine
    def _connect(self):
        """Connect to the server, retrying every second until it succeeds.

        Once connected, the transmit, receive, flush and connection-watch
        coroutines are started.
        """
        while True:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
                "User-Agent": _client.USER_AGENT,
            }
            try:
                if self._client.edge_token is not None:
                    headers["X-Edge-Token"] = self._client.edge_token
                else:
                    if self._token_source is not None:
                        tk = self._token_source.token()
                    else:
                        self._client._auth()
                        tk = self._client.jwt
                    headers["Authorization"] = "Bearer " + tk

                req = httpclient.HTTPRequest(
                    url=self._wsurl, headers=headers, user_agent=_client.USER_AGENT
                )
                self._conn = yield websocket.websocket_connect(
                    url=req, ping_interval=10
                )
                self._reconnecting.clear()
                break
            except Exception as e:
                # Deliberate best effort: report any failure as a warning and
                # try again after a short pause.
                warnings.warn(Warning(e))
                yield gen.sleep(1.0)

        self._tx_routine()
        self._rx_routine()
        self._flush_routine()
        self._check_connection()

    @gen.coroutine
    def _check_connection(self):
        """Watch for a lost connection and either close or reconnect.

        The ``_reconnecting`` event is set by the receive/flush routines when
        they detect that the connection is gone.
        """
        if not self._auto_reconnect:
            while not self._quit.is_set():
                if self._reconnecting.is_set():
                    # close() sets _quit, which ends this loop.
                    self.close()
                else:
                    yield gen.sleep(1.0)
            return

        while not self._quit.is_set():
            if not self._reconnecting.is_set():
                yield gen.sleep(1.0)
                continue

            # Ask the running routines to stop, then tear the streams down.
            self._quit.set()
            # NOTE: Queue.join() returns an awaitable; it is not awaited here,
            # so these calls do not actually wait for the queues to drain.
            self._tx_queue.join()
            for q in self._rx_queues.values():
                q.join()
            for q in self._req_queues.values():
                q.join()
            if self._upstreams is not None:
                self._upstreams._stop()
            if self._downstreams is not None:
                self._downstreams._stop()
            self._quit.clear()

            # _connect() clears _reconnecting on success and starts a fresh
            # _check_connection, so this instance ends after re-opening.
            self._connect()
            while self._reconnecting.is_set():
                yield gen.sleep(0.1)
            if self._upstreams is not None:
                self._upstreams._open()
            if self._downstreams is not None:
                self._downstreams._open()
            break

    def _new_write_buffer(self):
        """Install a fresh protocol writer and return its backing BytesIO."""
        bio = io.BytesIO()
        self._writer = _protocol.Writer(io.BufferedWriter(bio))
        return bio

    @gen.coroutine
    def _flush_routine(self):
        """Periodically send the buffered outgoing bytes as one binary message."""
        bio = self._new_write_buffer()
        while not self._quit.is_set():
            yield gen.sleep(self._flush_interval)

            msg = bio.getvalue()
            if len(msg) == 0 or self._conn is None:
                # Nothing to send yet, or nowhere to send it: keep the buffer.
                yield gen.sleep(0.1)
                continue

            # Swap in a fresh buffer BEFORE yielding on the send. Elements that
            # _tx_routine serializes while the send is in flight then land in
            # the new buffer instead of being discarded with the old one.
            bio = self._new_write_buffer()
            try:
                yield self._conn.write_message(message=msg, binary=True)
            except websocket.WebSocketClosedError:
                self._reconnecting.set()
                break
            except AttributeError:
                break

    @gen.coroutine
    def _tx_routine(self):
        """Move elements from the transmit queue into the current writer."""
        while not self._quit.is_set():
            if self._writer is None:
                # The flush routine has not created a writer yet.
                yield gen.sleep(0.1)
                continue
            unit = yield self._tx_queue.get()
            self._writer.write_elem(unit)
            self._tx_queue.task_done()

    @gen.coroutine
    def _rx_routine(self):
        """Read messages and dispatch their elements to the per-id queues."""
        while not self._quit.is_set():
            if self._conn is None:
                yield gen.sleep(0.1)
                continue

            data = yield self._conn.read_message()
            if data is None:
                # No message: treat the connection as lost.
                self._reconnecting.set()
                break

            rd = _protocol.Reader(io.BufferedReader(io.BytesIO(data)))
            while not self._quit.is_set():
                try:
                    e = rd.read_elem()
                    if isinstance(e, _protocol.StreamElement):
                        yield self._rx_queues[e.stream_id].put(e)
                    elif isinstance(e, _protocol.RequestElement):
                        yield self._req_queues[e.req_id].put(e)
                except EOFError:
                    # End of this message.
                    break
                except KeyError:
                    # Element for an id nobody is listening on: drop it.
                    continue

    def close(self):
        """Close the WebSocket connection."""
        self._quit.set()
        if self._upstreams is not None:
            self._upstreams._stop()
        if self._downstreams is not None:
            self._downstreams._stop()
        # NOTE: Queue.join() returns an awaitable; it is not awaited here, so
        # these calls do not block until the queues are drained.
        self._tx_queue.join()
        for q in self._rx_queues.values():
            q.join()
        for q in self._req_queues.values():
            q.join()
        # _conn is still None when close() is called before the first
        # successful connect; there is nothing to close in that case.
        if self._conn is not None:
            self._conn.close()

    def open_downstreams(self, specs, callbacks):
        """Open downstreams according to the given downstream specs.

        Args:
            specs (list[DownstreamSpec]): List of downstream specs.
            callbacks (list[func]): Callback functions invoked to process
                each received Unit.

        Raises:
            StreamAlreadyOpenException: If downstreams are already open.
        """
        if self._downstreams is not None:
            raise StreamAlreadyOpenException()
        self._downstreams = Downstreams(conn=self, specs=specs, callbacks=callbacks)

    def open_upstreams(self, specs, iterators, marker_interval=3):
        """Open upstreams according to the given upstream specs.

        Args:
            specs (list[UpstreamSpec]): List of upstream specs.
            iterators (list[iter]): Iterators invoked to produce the Units
                to send.
            marker_interval (int): Marker interval in seconds.

        Raises:
            StreamAlreadyOpenException: If upstreams are already open.
        """
        if self._upstreams is not None:
            raise StreamAlreadyOpenException()
        self._upstreams = Upstreams(
            conn=self, specs=specs, iterators=iterators, marker_interval=marker_interval
        )
class Downstreams(object):
    """Set of downstreams opened on one connection, each with its own callback."""

    def __init__(self, conn, specs, callbacks):
        """Assign a stream id to every spec/callback pair and open the streams.

        Args:
            conn: Owning connection; its ``_next_stream_id``, ``_next_req_id``,
                ``_rx_queues``, ``_req_queues`` and ``_tx_queue`` are used.
            specs: Downstream specs, paired positionally with ``callbacks``.
            callbacks: Callables invoked with each received unit.
        """
        self.conn = conn
        self.quit = locks.Event()

        # NOTE: stream ids are assigned with self.conn._next_stream_id().
        self.specs = {}
        self.callbacks = {}
        for spec, callback in zip(specs, callbacks):
            stream_id = self.conn._next_stream_id()
            self.specs[stream_id] = spec
            self.callbacks[stream_id] = callback

        # Coroutine started without being awaited.
        self._open()

    def _stop(self):
        """Signal the receive routines to stop and unregister the receive queues.

        Safe to call more than once, or when the queues are already gone: a
        missing queue is skipped instead of raising KeyError.
        """
        self.quit.set()
        for stream_id in self.specs:
            q = self.conn._rx_queues.pop(stream_id, None)
            if q is not None:
                # NOTE: join() returns an awaitable that is not awaited here.
                q.join()

    @gen.coroutine
    def _open(self):
        """Register receive queues, send the open requests and await the results.

        Raises:
            ResultNGException: If any response has a result code other than OK.
        """
        self.quit.clear()
        for stream_id in self.specs:
            self.conn._rx_queues[stream_id] = queues.Queue()

        reqs = _utils._create_req_downstream(specs=self.specs)
        for req in reqs:
            # Register the response queue before the request goes out.
            req.req_id = self.conn._next_req_id()
            self.conn._req_queues[req.req_id] = queues.Queue()
            yield self.conn._tx_queue.put(req)

        for req in reqs:
            q = self.conn._req_queues[req.req_id]
            resp = yield q.get()
            del self.conn._req_queues[req.req_id]
            q.task_done()
            if resp.result_code != _protocol.ResultCode.OK:
                raise ResultNGException()

        self._receiving()

    @gen.coroutine
    def _receiving(self):
        """Start one routine per stream that feeds received units to its callback."""

        @gen.coroutine
        def _routine(q, callback):
            while not self.quit.is_set():
                unit = yield q.get()
                callback(
                    _models.Unit(
                        elapsed_time=unit.elapsed_time,
                        channel=unit.channel,
                        data=unit.data,
                    )
                )
                q.task_done()

        for stream_id in self.specs:
            _routine(self.conn._rx_queues[stream_id], self.callbacks[stream_id])
class Upstreams(object):
    """Set of upstreams opened on one connection, each fed by its own iterator."""

    def __init__(self, conn, specs, iterators, marker_interval):
        """Assign a stream id to every spec/iterator pair and open the streams.

        Args:
            conn: Owning connection; its ``_next_stream_id``, ``_next_req_id``,
                ``_rx_queues``, ``_req_queues`` and ``_tx_queue`` are used.
            specs: Upstream specs, paired positionally with ``iterators``.
            iterators: Iterators yielding the units to send (``None`` items
                are skipped).
            marker_interval: Seconds between section markers.
        """
        self.conn = conn
        self.quit = locks.Event()
        self.marker_interval = marker_interval

        # NOTE: stream ids are assigned with self.conn._next_stream_id().
        self.specs = {}
        self.iterators = {}
        for spec, iterator in zip(specs, iterators):
            stream_id = self.conn._next_stream_id()
            self.specs[stream_id] = spec
            self.iterators[stream_id] = iterator

        self.stream_infos = {}

        # Coroutine started without being awaited.
        self._open()

    def _stop(self):
        """Signal the stream routines to stop and unregister the receive queues.

        Safe to call more than once, or when the queues are already gone: a
        missing queue is skipped instead of raising KeyError.
        """
        self.quit.set()
        for stream_id in self.specs:
            q = self.conn._rx_queues.pop(stream_id, None)
            if q is not None:
                # NOTE: join() returns an awaitable that is not awaited here.
                q.join()

    @gen.coroutine
    def _open(self):
        """Register queues, send the open requests, then start the stream routines.

        Raises:
            ResultNGException: If any response has a result code other than OK.
        """
        self.quit.clear()
        for stream_id in self.specs:
            self.conn._rx_queues[stream_id] = queues.Queue()

        # Fresh per-stream counters on every (re)open.
        self.stream_infos = {stream_id: StreamInfo() for stream_id in self.specs}

        reqs = _utils._create_req_upstream(specs=self.specs)
        for req in reqs:
            # Register the response queue before the request goes out.
            req.req_id = self.conn._next_req_id()
            self.conn._req_queues[req.req_id] = queues.Queue()
            yield self.conn._tx_queue.put(req)

        for req in reqs:
            q = self.conn._req_queues[req.req_id]
            resp = yield q.get()
            del self.conn._req_queues[req.req_id]
            q.task_done()
            if resp.result_code != _protocol.ResultCode.OK:
                raise ResultNGException()

        self._check_section()
        self._add_section()
        self._sending()

    @gen.coroutine
    def _check_section(self):
        """Start one routine per stream that drains its receive queue.

        Received elements (presumably section acknowledgements -- confirm) are
        taken off the queue and discarded.
        """

        @gen.coroutine
        def _routine(q):
            while not self.quit.is_set():
                yield q.get()
                q.task_done()

        for stream_id in self.specs:
            _routine(self.conn._rx_queues[stream_id])

    @gen.coroutine
    def _add_section(self):
        """Send the initial SOS markers and start the periodic section routines."""

        @gen.coroutine
        def _routine(stream_id):
            stream_info = self.stream_infos[stream_id]
            while not self.quit.is_set():
                yield gen.sleep(self.marker_interval)
                if stream_info.count == 0:
                    # No unit was sent in this interval: keep the section open.
                    continue
                # Close the current section, then open the next serial number.
                yield self.conn._tx_queue.put(
                    _protocol.EOSMarker(
                        stream_id=stream_id,
                        final=False,
                        serial_number=stream_info.serial,
                    )
                )
                stream_info.next_serial()
                yield self.conn._tx_queue.put(
                    _protocol.SOSMarker(
                        stream_id=stream_id, serial_number=stream_info.serial
                    )
                )

        for stream_id in self.specs:
            yield self.conn._tx_queue.put(
                _protocol.SOSMarker(
                    stream_id=stream_id,
                    serial_number=self.stream_infos[stream_id].serial,
                )
            )

        for stream_id in self.specs:
            _routine(stream_id)

    @gen.coroutine
    def _sending(self):
        """Start one routine per stream that sends the units its iterator yields."""

        @gen.coroutine
        def _routine(stream_id, iterator):
            while not self.quit.is_set():
                try:
                    unit = next(iterator)
                    # Bare yield between items, presumably so other coroutines
                    # get a turn on the IOLoop -- confirm for the Tornado
                    # version in use.
                    yield
                except StopIteration:
                    break
                if unit is None:
                    continue
                self.stream_infos[stream_id].count_up()
                u = _protocol.Unit(
                    stream_id=stream_id,
                    channel=unit.channel,
                    elapsed_time=unit.elapsed_time,
                    data=unit.data,
                    time_precision="ns",
                )
                yield self.conn._tx_queue.put(u)

        for stream_id in self.specs:
            _routine(stream_id, self.iterators[stream_id])
class StreamInfo(object):
    """Per-stream bookkeeping: a unit counter and a section serial number."""

    def __init__(self):
        # Both counters start at zero.
        self.count = self.serial = 0

    def count_up(self):
        """Increase the counter by one."""
        self.count = self.count + 1

    def next_serial(self):
        """Advance the serial number by one and reset the counter to zero."""
        self.serial, self.count = self.serial + 1, 0
def cyclic_counter256(initial):
    """Return a function that yields successive integers cycling through 0..255.

    Each call of the returned function gives the current value and then
    advances it, wrapping from 255 back to 0.

    Args:
        initial: First value to return. It is reduced modulo 256, so an
            out-of-range start (e.g. 256 or -1) still yields a value in 0..255.

    Returns:
        A zero-argument callable producing the next counter value.
    """
    # Normalize once so that even the very first value is within 0..255.
    cnt = initial % 256

    def inner():
        nonlocal cnt
        current = cnt
        cnt = (cnt + 1) % 256
        return current

    return inner