Merge tag 'upstream/0.17.2' into debian
Upstream version 0.17.2
Erik Johnston
7 years ago
0 | Changes in synapse v0.17.2 (2016-09-08) | |
1 | ======================================= | |
2 | ||
3 | This release contains security bug fixes. Please upgrade. | |
4 | ||
5 | ||
6 | No changes since v0.17.2 | |
7 | ||
8 | ||
9 | Changes in synapse v0.17.2-rc1 (2016-09-05) | |
10 | =========================================== | |
11 | ||
12 | Features: | |
13 | ||
14 | * Start adding store-and-forward direct-to-device messaging (PR #1046, #1050, | |
15 | #1062, #1066) | |
16 | ||
17 | ||
18 | Changes: | |
19 | ||
20 | * Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063, | |
21 | #1068) | |
22 | * Don't notify for online to online presence transitions. (PR #1054) | |
23 | * Occasionally persist unpersisted presence updates (PR #1055) | |
24 | * Allow application services to have an optional 'url' (PR #1056) | |
25 | * Clean up old sent transactions from DB (PR #1059) | |
26 | ||
27 | ||
28 | Bug fixes: | |
29 | ||
30 | * Fix None check in backfill (PR #1043) | |
31 | * Fix membership changes to be idempotent (PR #1067) | |
32 | * Fix bug in get_pdu where it would sometimes return events with incorrect | |
33 | signature | |
34 | ||
35 | ||
36 | ||
0 | 37 | Changes in synapse v0.17.1 (2016-08-24) |
1 | 38 | ======================================= |
2 | 39 |
133 | 133 | sudo pip install --upgrade ndg-httpsclient |
134 | 134 | sudo pip install --upgrade virtualenv |
135 | 135 | |
136 | Installing prerequisites on openSUSE:: | |
137 | ||
138 | sudo zypper in -t pattern devel_basis | |
139 | sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ | |
140 | python-devel libffi-devel libopenssl-devel libjpeg62-devel | |
141 | ||
136 | 142 | To install the synapse homeserver run:: |
137 | 143 | |
138 | 144 | virtualenv -p python2.7 ~/.synapse |
198 | 204 | source ./bin/activate |
199 | 205 | synctl start |
200 | 206 | |
207 | Security Note | |
208 | ============= | |
209 | ||
210 | Matrix serves raw user generated data in some APIs - specifically the content | |
211 | repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid | |
212 | Whilst we have tried to mitigate against possible XSS attacks (e.g. | |
213 | https://github.com/matrix-org/synapse/pull/1021) we recommend running | |
214 | matrix homeservers on a dedicated domain name, to limit any malicious user generated | |
215 | content served to web browsers a matrix API from being able to attack webapps hosted | |
216 | on the same domain. This is particularly true of sharing a matrix webclient and | |
217 | server on the same domain. | |
218 | ||
219 | See https://github.com/vector-im/vector-web/issues/1977 and | |
220 | https://developer.github.com/changes/2014-04-25-user-content-security for more details. | |
221 | ||
201 | 222 | Using PostgreSQL |
202 | 223 | ================ |
203 | 224 | |
213 | 234 | * allowing basic active/backup high-availability with a "hot spare" synapse |
214 | 235 | pointing at the same DB master, as well as enabling DB replication in |
215 | 236 | synapse itself. |
216 | ||
217 | The only disadvantage is that the code is relatively new as of April 2015 and | |
218 | may have a few regressions relative to SQLite. | |
219 | 237 | |
220 | 238 | For information on how to install and use PostgreSQL, please see |
221 | 239 | `docs/postgres.rst <docs/postgres.rst>`_. |
15 | 15 | """ This is a reference implementation of a Matrix home server. |
16 | 16 | """ |
17 | 17 | |
18 | __version__ = "0.17.1" | |
18 | __version__ = "0.17.2" |
51 | 51 | self.state = hs.get_state_handler() |
52 | 52 | self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 |
53 | 53 | # Docs for these currently lives at |
54 | # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst | |
54 | # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst | |
55 | 55 | # In addition, we have type == delete_pusher which grants access only to |
56 | 56 | # delete pushers. |
57 | 57 | self._KNOWN_CAVEAT_PREFIXES = set([ |
61 | 61 | "time < ", |
62 | 62 | "user_id = ", |
63 | 63 | ]) |
64 | ||
65 | @defer.inlineCallbacks | |
66 | def check_from_context(self, event, context, do_sig_check=True): | |
67 | auth_events_ids = yield self.compute_auth_events( | |
68 | event, context.prev_state_ids, for_verification=True, | |
69 | ) | |
70 | auth_events = yield self.store.get_events(auth_events_ids) | |
71 | auth_events = { | |
72 | (e.type, e.state_key): e for e in auth_events.values() | |
73 | } | |
74 | self.check(event, auth_events=auth_events, do_sig_check=False) | |
64 | 75 | |
65 | 76 | def check(self, event, auth_events, do_sig_check=True): |
66 | 77 | """ Checks if this event is correctly authed. |
266 | 277 | |
267 | 278 | @defer.inlineCallbacks |
268 | 279 | def check_host_in_room(self, room_id, host): |
269 | curr_state = yield self.state.get_current_state(room_id) | |
270 | ||
271 | for event in curr_state.values(): | |
272 | if event.type == EventTypes.Member: | |
273 | try: | |
274 | if get_domain_from_id(event.state_key) != host: | |
275 | continue | |
276 | except: | |
277 | logger.warn("state_key not user_id: %s", event.state_key) | |
278 | continue | |
279 | ||
280 | if event.content["membership"] == Membership.JOIN: | |
281 | defer.returnValue(True) | |
282 | ||
283 | defer.returnValue(False) | |
280 | with Measure(self.clock, "check_host_in_room"): | |
281 | latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) | |
282 | ||
283 | entry = yield self.state.resolve_state_groups( | |
284 | room_id, latest_event_ids | |
285 | ) | |
286 | ||
287 | ret = yield self.store.is_host_joined( | |
288 | room_id, host, entry.state_group, entry.state | |
289 | ) | |
290 | defer.returnValue(ret) | |
284 | 291 | |
285 | 292 | def check_event_sender_in_room(self, event, auth_events): |
286 | 293 | key = (EventTypes.Member, event.user_id, ) |
846 | 853 | |
847 | 854 | @defer.inlineCallbacks |
848 | 855 | def add_auth_events(self, builder, context): |
849 | auth_ids = self.compute_auth_events(builder, context.current_state) | |
856 | auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) | |
850 | 857 | |
851 | 858 | auth_events_entries = yield self.store.add_event_hashes( |
852 | 859 | auth_ids |
854 | 861 | |
855 | 862 | builder.auth_events = auth_events_entries |
856 | 863 | |
857 | def compute_auth_events(self, event, current_state): | |
864 | @defer.inlineCallbacks | |
865 | def compute_auth_events(self, event, current_state_ids, for_verification=False): | |
858 | 866 | if event.type == EventTypes.Create: |
859 | return [] | |
867 | defer.returnValue([]) | |
860 | 868 | |
861 | 869 | auth_ids = [] |
862 | 870 | |
863 | 871 | key = (EventTypes.PowerLevels, "", ) |
864 | power_level_event = current_state.get(key) | |
865 | ||
866 | if power_level_event: | |
867 | auth_ids.append(power_level_event.event_id) | |
872 | power_level_event_id = current_state_ids.get(key) | |
873 | ||
874 | if power_level_event_id: | |
875 | auth_ids.append(power_level_event_id) | |
868 | 876 | |
869 | 877 | key = (EventTypes.JoinRules, "", ) |
870 | join_rule_event = current_state.get(key) | |
878 | join_rule_event_id = current_state_ids.get(key) | |
871 | 879 | |
872 | 880 | key = (EventTypes.Member, event.user_id, ) |
873 | member_event = current_state.get(key) | |
881 | member_event_id = current_state_ids.get(key) | |
874 | 882 | |
875 | 883 | key = (EventTypes.Create, "", ) |
876 | create_event = current_state.get(key) | |
877 | if create_event: | |
878 | auth_ids.append(create_event.event_id) | |
879 | ||
880 | if join_rule_event: | |
884 | create_event_id = current_state_ids.get(key) | |
885 | if create_event_id: | |
886 | auth_ids.append(create_event_id) | |
887 | ||
888 | if join_rule_event_id: | |
889 | join_rule_event = yield self.store.get_event(join_rule_event_id) | |
881 | 890 | join_rule = join_rule_event.content.get("join_rule") |
882 | 891 | is_public = join_rule == JoinRules.PUBLIC if join_rule else False |
883 | 892 | else: |
886 | 895 | if event.type == EventTypes.Member: |
887 | 896 | e_type = event.content["membership"] |
888 | 897 | if e_type in [Membership.JOIN, Membership.INVITE]: |
889 | if join_rule_event: | |
890 | auth_ids.append(join_rule_event.event_id) | |
898 | if join_rule_event_id: | |
899 | auth_ids.append(join_rule_event_id) | |
891 | 900 | |
892 | 901 | if e_type == Membership.JOIN: |
893 | if member_event and not is_public: | |
894 | auth_ids.append(member_event.event_id) | |
902 | if member_event_id and not is_public: | |
903 | auth_ids.append(member_event_id) | |
895 | 904 | else: |
896 | if member_event: | |
897 | auth_ids.append(member_event.event_id) | |
905 | if member_event_id: | |
906 | auth_ids.append(member_event_id) | |
907 | ||
908 | if for_verification: | |
909 | key = (EventTypes.Member, event.state_key, ) | |
910 | existing_event_id = current_state_ids.get(key) | |
911 | if existing_event_id: | |
912 | auth_ids.append(existing_event_id) | |
898 | 913 | |
899 | 914 | if e_type == Membership.INVITE: |
900 | 915 | if "third_party_invite" in event.content: |
902 | 917 | EventTypes.ThirdPartyInvite, |
903 | 918 | event.content["third_party_invite"]["signed"]["token"] |
904 | 919 | ) |
905 | third_party_invite = current_state.get(key) | |
906 | if third_party_invite: | |
907 | auth_ids.append(third_party_invite.event_id) | |
908 | elif member_event: | |
920 | third_party_invite_id = current_state_ids.get(key) | |
921 | if third_party_invite_id: | |
922 | auth_ids.append(third_party_invite_id) | |
923 | elif member_event_id: | |
924 | member_event = yield self.store.get_event(member_event_id) | |
909 | 925 | if member_event.content["membership"] == Membership.JOIN: |
910 | 926 | auth_ids.append(member_event.event_id) |
911 | 927 | |
912 | return auth_ids | |
928 | defer.returnValue(auth_ids) | |
913 | 929 | |
914 | 930 | def _get_send_level(self, etype, state_key, auth_events): |
915 | 931 | key = (EventTypes.PowerLevels, "", ) |
84 | 84 | PRIVATE_CHAT = "private_chat" |
85 | 85 | PUBLIC_CHAT = "public_chat" |
86 | 86 | TRUSTED_PRIVATE_CHAT = "trusted_private_chat" |
87 | ||
88 | ||
89 | class ThirdPartyEntityKind(object): | |
90 | USER = "user" | |
91 | LOCATION = "location" |
24 | 24 | SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" |
25 | 25 | MEDIA_PREFIX = "/_matrix/media/r0" |
26 | 26 | LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" |
27 | APP_SERVICE_PREFIX = "/_matrix/appservice/v1" |
35 | 35 | from synapse.replication.slave.storage.filtering import SlavedFilteringStore |
36 | 36 | from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore |
37 | 37 | from synapse.replication.slave.storage.presence import SlavedPresenceStore |
38 | from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore | |
38 | 39 | from synapse.server import HomeServer |
39 | 40 | from synapse.storage.client_ips import ClientIpStore |
40 | 41 | from synapse.storage.engines import create_engine |
71 | 72 | SlavedRegistrationStore, |
72 | 73 | SlavedFilteringStore, |
73 | 74 | SlavedPresenceStore, |
75 | SlavedDeviceInboxStore, | |
74 | 76 | BaseSlavedStore, |
75 | 77 | ClientIpStore, # After BaseSlavedStore because the constructor is different |
76 | 78 | ): |
396 | 398 | notify_from_stream( |
397 | 399 | result, "typing", "typing_key", room="room_id" |
398 | 400 | ) |
401 | notify_from_stream( | |
402 | result, "to_device", "to_device_key", user="user_id" | |
403 | ) | |
399 | 404 | |
400 | 405 | while True: |
401 | 406 | try: |
87 | 87 | self.sender = sender |
88 | 88 | self.namespaces = self._check_namespaces(namespaces) |
89 | 89 | self.id = id |
90 | ||
91 | # .protocols is a publicly visible field | |
90 | 92 | if protocols: |
91 | 93 | self.protocols = set(protocols) |
92 | 94 | else: |
13 | 13 | # limitations under the License. |
14 | 14 | from twisted.internet import defer |
15 | 15 | |
16 | from synapse.api.constants import ThirdPartyEntityKind | |
16 | 17 | from synapse.api.errors import CodeMessageException |
17 | 18 | from synapse.http.client import SimpleHttpClient |
18 | 19 | from synapse.events.utils import serialize_event |
19 | from synapse.types import ThirdPartyEntityKind | |
20 | from synapse.util.caches.response_cache import ResponseCache | |
20 | 21 | |
21 | 22 | import logging |
22 | 23 | import urllib |
23 | 24 | |
24 | 25 | logger = logging.getLogger(__name__) |
26 | ||
27 | ||
28 | HOUR_IN_MS = 60 * 60 * 1000 | |
29 | ||
30 | ||
31 | APP_SERVICE_PREFIX = "/_matrix/app/unstable" | |
25 | 32 | |
26 | 33 | |
27 | 34 | def _is_valid_3pe_result(r, field): |
55 | 62 | super(ApplicationServiceApi, self).__init__(hs) |
56 | 63 | self.clock = hs.get_clock() |
57 | 64 | |
65 | self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) | |
66 | ||
58 | 67 | @defer.inlineCallbacks |
59 | 68 | def query_user(self, service, user_id): |
69 | if service.url is None: | |
70 | defer.returnValue(False) | |
60 | 71 | uri = service.url + ("/users/%s" % urllib.quote(user_id)) |
61 | 72 | response = None |
62 | 73 | try: |
76 | 87 | |
77 | 88 | @defer.inlineCallbacks |
78 | 89 | def query_alias(self, service, alias): |
90 | if service.url is None: | |
91 | defer.returnValue(False) | |
79 | 92 | uri = service.url + ("/rooms/%s" % urllib.quote(alias)) |
80 | 93 | response = None |
81 | 94 | try: |
96 | 109 | @defer.inlineCallbacks |
97 | 110 | def query_3pe(self, service, kind, protocol, fields): |
98 | 111 | if kind == ThirdPartyEntityKind.USER: |
99 | uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) | |
100 | 112 | required_field = "userid" |
101 | 113 | elif kind == ThirdPartyEntityKind.LOCATION: |
102 | uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) | |
103 | 114 | required_field = "alias" |
104 | 115 | else: |
105 | 116 | raise ValueError( |
106 | 117 | "Unrecognised 'kind' argument %r to query_3pe()", kind |
107 | 118 | ) |
108 | ||
119 | if service.url is None: | |
120 | defer.returnValue([]) | |
121 | ||
122 | uri = "%s%s/thirdparty/%s/%s" % ( | |
123 | service.url, | |
124 | APP_SERVICE_PREFIX, | |
125 | kind, | |
126 | urllib.quote(protocol) | |
127 | ) | |
109 | 128 | try: |
110 | 129 | response = yield self.get_json(uri, fields) |
111 | 130 | if not isinstance(response, list): |
130 | 149 | logger.warning("query_3pe to %s threw exception %s", uri, ex) |
131 | 150 | defer.returnValue([]) |
132 | 151 | |
152 | def get_3pe_protocol(self, service, protocol): | |
153 | if service.url is None: | |
154 | defer.returnValue({}) | |
155 | ||
156 | @defer.inlineCallbacks | |
157 | def _get(): | |
158 | uri = "%s%s/thirdparty/protocol/%s" % ( | |
159 | service.url, | |
160 | APP_SERVICE_PREFIX, | |
161 | urllib.quote(protocol) | |
162 | ) | |
163 | try: | |
164 | defer.returnValue((yield self.get_json(uri, {}))) | |
165 | except Exception as ex: | |
166 | logger.warning("query_3pe_protocol to %s threw exception %s", | |
167 | uri, ex) | |
168 | defer.returnValue({}) | |
169 | ||
170 | key = (service.id, protocol) | |
171 | return self.protocol_meta_cache.get(key) or ( | |
172 | self.protocol_meta_cache.set(key, _get()) | |
173 | ) | |
174 | ||
133 | 175 | @defer.inlineCallbacks |
134 | 176 | def push_bulk(self, service, events, txn_id=None): |
177 | if service.url is None: | |
178 | defer.returnValue(True) | |
179 | ||
135 | 180 | events = self._serialize(events) |
136 | 181 | |
137 | 182 | if txn_id is None: |
85 | 85 | |
86 | 86 | def _load_appservice(hostname, as_info, config_filename): |
87 | 87 | required_string_fields = [ |
88 | "id", "url", "as_token", "hs_token", "sender_localpart" | |
88 | "id", "as_token", "hs_token", "sender_localpart" | |
89 | 89 | ] |
90 | 90 | for field in required_string_fields: |
91 | 91 | if not isinstance(as_info.get(field), basestring): |
92 | 92 | raise KeyError("Required string field: '%s' (%s)" % ( |
93 | 93 | field, config_filename, |
94 | 94 | )) |
95 | ||
96 | # 'url' must either be a string or explicitly null, not missing | |
97 | # to avoid accidentally turning off push for ASes. | |
98 | if (not isinstance(as_info.get("url"), basestring) and | |
99 | as_info.get("url", "") is not None): | |
100 | raise KeyError( | |
101 | "Required string field or explicit null: 'url' (%s)" % (config_filename,) | |
102 | ) | |
95 | 103 | |
96 | 104 | localpart = as_info["sender_localpart"] |
97 | 105 | if urllib.quote(localpart) != localpart: |
131 | 139 | for p in protocols: |
132 | 140 | if not isinstance(p, str): |
133 | 141 | raise KeyError("Bad value for 'protocols' item") |
142 | ||
143 | if as_info["url"] is None: | |
144 | logger.info( | |
145 | "(%s) Explicitly empty 'url' provided. This application service" | |
146 | " will not receive events or queries.", | |
147 | config_filename, | |
148 | ) | |
134 | 149 | return ApplicationService( |
135 | 150 | token=as_info["as_token"], |
136 | 151 | url=as_info["url"], |
98 | 98 | |
99 | 99 | return d |
100 | 100 | |
101 | def get(self, key, default): | |
101 | def get(self, key, default=None): | |
102 | 102 | return self._event_dict.get(key, default) |
103 | 103 | |
104 | 104 | def get_internal_metadata_dict(self): |
14 | 14 | |
15 | 15 | |
16 | 16 | class EventContext(object): |
17 | ||
18 | def __init__(self, current_state=None): | |
19 | self.current_state = current_state | |
17 | def __init__(self): | |
18 | self.current_state_ids = None | |
19 | self.prev_state_ids = None | |
20 | 20 | self.state_group = None |
21 | 21 | self.rejected = False |
22 | 22 | self.push_actions = [] |
28 | 28 | from synapse.util.logutils import log_function |
29 | 29 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred |
30 | 30 | from synapse.events import FrozenEvent |
31 | from synapse.types import get_domain_from_id | |
31 | 32 | import synapse.metrics |
32 | 33 | |
33 | 34 | from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination |
62 | 63 | self._clock.looping_call( |
63 | 64 | self._clear_tried_cache, 60 * 1000, |
64 | 65 | ) |
66 | self.state = hs.get_state_handler() | |
65 | 67 | |
66 | 68 | def _clear_tried_cache(self): |
67 | 69 | """Clear pdu_destination_tried cache""" |
266 | 268 | |
267 | 269 | pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) |
268 | 270 | |
269 | pdu = None | |
271 | signed_pdu = None | |
270 | 272 | for destination in destinations: |
271 | 273 | now = self._clock.time_msec() |
272 | 274 | last_attempt = pdu_attempts.get(destination, 0) |
296 | 298 | pdu = pdu_list[0] |
297 | 299 | |
298 | 300 | # Check signatures are correct. |
299 | pdu = yield self._check_sigs_and_hashes([pdu])[0] | |
301 | signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] | |
300 | 302 | |
301 | 303 | break |
302 | 304 | |
319 | 321 | ) |
320 | 322 | continue |
321 | 323 | |
322 | if self._get_pdu_cache is not None and pdu: | |
323 | self._get_pdu_cache[event_id] = pdu | |
324 | ||
325 | defer.returnValue(pdu) | |
324 | if self._get_pdu_cache is not None and signed_pdu: | |
325 | self._get_pdu_cache[event_id] = signed_pdu | |
326 | ||
327 | defer.returnValue(signed_pdu) | |
326 | 328 | |
327 | 329 | @defer.inlineCallbacks |
328 | 330 | @log_function |
810 | 812 | if len(signed_events) >= limit: |
811 | 813 | defer.returnValue(signed_events) |
812 | 814 | |
813 | servers = yield self.store.get_joined_hosts_for_room(room_id) | |
815 | users = yield self.state.get_current_user_in_room(room_id) | |
816 | servers = set(get_domain_from_id(u) for u in users) | |
814 | 817 | |
815 | 818 | servers = set(servers) |
816 | 819 | servers.discard(self.server_name) |
222 | 222 | if not in_room: |
223 | 223 | raise AuthError(403, "Host not in room.") |
224 | 224 | |
225 | pdus = yield self.handler.get_state_for_pdu( | |
225 | state_ids = yield self.handler.get_state_ids_for_pdu( | |
226 | 226 | room_id, event_id, |
227 | 227 | ) |
228 | auth_chain = yield self.store.get_auth_chain( | |
229 | [pdu.event_id for pdu in pdus] | |
230 | ) | |
228 | auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) | |
231 | 229 | |
232 | 230 | defer.returnValue((200, { |
233 | "pdu_ids": [pdu.event_id for pdu in pdus], | |
234 | "auth_chain_ids": [pdu.event_id for pdu in auth_chain], | |
231 | "pdu_ids": state_ids, | |
232 | "auth_chain_ids": auth_chain_ids, | |
235 | 233 | })) |
236 | 234 | |
237 | 235 | @defer.inlineCallbacks |
64 | 64 | retry_after_ms=int(1000 * (time_allowed - time_now)), |
65 | 65 | ) |
66 | 66 | |
67 | def is_host_in_room(self, current_state): | |
68 | room_members = [ | |
69 | (state_key, event.membership) | |
70 | for ((event_type, state_key), event) in current_state.items() | |
71 | if event_type == EventTypes.Member | |
72 | ] | |
73 | if len(room_members) == 0: | |
74 | # Have we just created the room, and is this about to be the very | |
75 | # first member event? | |
76 | create_event = current_state.get(("m.room.create", "")) | |
77 | if create_event: | |
78 | return True | |
79 | for (state_key, membership) in room_members: | |
80 | if ( | |
81 | self.hs.is_mine_id(state_key) | |
82 | and membership == Membership.JOIN | |
83 | ): | |
84 | return True | |
85 | return False | |
86 | ||
87 | 67 | @defer.inlineCallbacks |
88 | def maybe_kick_guest_users(self, event, current_state): | |
68 | def maybe_kick_guest_users(self, event, context=None): | |
89 | 69 | # Technically this function invalidates current_state by changing it. |
90 | 70 | # Hopefully this isn't that important to the caller. |
91 | 71 | if event.type == EventTypes.GuestAccess: |
92 | 72 | guest_access = event.content.get("guest_access", "forbidden") |
93 | 73 | if guest_access != "can_join": |
74 | if context: | |
75 | current_state = yield self.store.get_events( | |
76 | context.current_state_ids.values() | |
77 | ) | |
78 | current_state = current_state.values() | |
79 | else: | |
80 | current_state = yield self.store.get_current_state(event.room_id) | |
81 | logger.info("maybe_kick_guest_users %r", current_state) | |
94 | 82 | yield self.kick_guest_users(current_state) |
95 | 83 | |
96 | 84 | @defer.inlineCallbacks |
175 | 175 | defer.returnValue(ret) |
176 | 176 | |
177 | 177 | @defer.inlineCallbacks |
178 | def get_3pe_protocols(self): | |
179 | services = yield self.store.get_app_services() | |
180 | protocols = {} | |
181 | for s in services: | |
182 | for p in s.protocols: | |
183 | protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) | |
184 | ||
185 | defer.returnValue(protocols) | |
186 | ||
187 | @defer.inlineCallbacks | |
178 | 188 | def _get_services_for_event(self, event): |
179 | 189 | """Retrieve a list of application services interested in this event. |
180 | 190 |
18 | 18 | |
19 | 19 | from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError |
20 | 20 | from synapse.api.constants import EventTypes |
21 | from synapse.types import RoomAlias, UserID | |
21 | from synapse.types import RoomAlias, UserID, get_domain_from_id | |
22 | 22 | |
23 | 23 | import logging |
24 | 24 | import string |
54 | 54 | # TODO(erikj): Add transactions. |
55 | 55 | # TODO(erikj): Check if there is a current association. |
56 | 56 | if not servers: |
57 | servers = yield self.store.get_joined_hosts_for_room(room_id) | |
57 | users = yield self.state.get_current_user_in_room(room_id) | |
58 | servers = set(get_domain_from_id(u) for u in users) | |
58 | 59 | |
59 | 60 | if not servers: |
60 | 61 | raise SynapseError(400, "Failed to get server list") |
192 | 193 | Codes.NOT_FOUND |
193 | 194 | ) |
194 | 195 | |
195 | extra_servers = yield self.store.get_joined_hosts_for_room(room_id) | |
196 | users = yield self.state.get_current_user_in_room(room_id) | |
197 | extra_servers = set(get_domain_from_id(u) for u in users) | |
196 | 198 | servers = set(extra_servers) | set(servers) |
197 | 199 | |
198 | 200 | # If this server is in the list of servers, return it first. |
46 | 46 | self.clock = hs.get_clock() |
47 | 47 | |
48 | 48 | self.notifier = hs.get_notifier() |
49 | self.state = hs.get_state_handler() | |
49 | 50 | |
50 | 51 | @defer.inlineCallbacks |
51 | 52 | @log_function |
89 | 90 | # Send down presence. |
90 | 91 | if event.state_key == auth_user_id: |
91 | 92 | # Send down presence for everyone in the room. |
92 | users = yield self.store.get_users_in_room(event.room_id) | |
93 | users = yield self.state.get_current_user_in_room(event.room_id) | |
93 | 94 | states = yield presence_handler.get_states( |
94 | 95 | users, |
95 | 96 | as_event=True, |
28 | 28 | from synapse.util.logcontext import ( |
29 | 29 | PreserveLoggingContext, preserve_fn, preserve_context_over_deferred |
30 | 30 | ) |
31 | from synapse.util.metrics import measure_func | |
31 | 32 | from synapse.util.logutils import log_function |
32 | 33 | from synapse.util.async import run_on_reactor |
33 | 34 | from synapse.util.frozenutils import unfreeze |
99 | 100 | def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): |
100 | 101 | """ Called by the ReplicationLayer when we have a new pdu. We need to |
101 | 102 | do auth checks and put it through the StateHandler. |
103 | ||
104 | auth_chain and state are None if we already have the necessary state | |
105 | and prev_events in the db | |
102 | 106 | """ |
103 | 107 | event = pdu |
104 | 108 | |
116 | 120 | |
117 | 121 | # FIXME (erikj): Awful hack to make the case where we are not currently |
118 | 122 | # in the room work |
119 | is_in_room = yield self.auth.check_host_in_room( | |
120 | event.room_id, | |
121 | self.server_name | |
122 | ) | |
123 | if not is_in_room and not event.internal_metadata.is_outlier(): | |
124 | logger.debug("Got event for room we're not in.") | |
123 | # If state and auth_chain are None, then we don't need to do this check | |
124 | # as we already know we have enough state in the DB to handle this | |
125 | # event. | |
126 | if state and auth_chain and not event.internal_metadata.is_outlier(): | |
127 | is_in_room = yield self.auth.check_host_in_room( | |
128 | event.room_id, | |
129 | self.server_name | |
130 | ) | |
131 | else: | |
132 | is_in_room = True | |
133 | if not is_in_room: | |
134 | logger.info( | |
135 | "Got event for room we're not in: %r %r", | |
136 | event.room_id, event.event_id | |
137 | ) | |
125 | 138 | |
126 | 139 | try: |
127 | 140 | event_stream_id, max_stream_id = yield self._persist_auth_tree( |
216 | 229 | |
217 | 230 | if event.type == EventTypes.Member: |
218 | 231 | if event.membership == Membership.JOIN: |
219 | prev_state = context.current_state.get((event.type, event.state_key)) | |
220 | if not prev_state or prev_state.membership != Membership.JOIN: | |
221 | # Only fire user_joined_room if the user has acutally | |
222 | # joined the room. Don't bother if the user is just | |
223 | # changing their profile info. | |
232 | # Only fire user_joined_room if the user has acutally | |
233 | # joined the room. Don't bother if the user is just | |
234 | # changing their profile info. | |
235 | newly_joined = True | |
236 | prev_state_id = context.prev_state_ids.get( | |
237 | (event.type, event.state_key) | |
238 | ) | |
239 | if prev_state_id: | |
240 | prev_state = yield self.store.get_event( | |
241 | prev_state_id, allow_none=True, | |
242 | ) | |
243 | if prev_state and prev_state.membership == Membership.JOIN: | |
244 | newly_joined = False | |
245 | ||
246 | if newly_joined: | |
224 | 247 | user = UserID.from_string(event.state_key) |
225 | 248 | yield user_joined_room(self.distributor, user, event.room_id) |
226 | 249 | |
250 | @measure_func("_filter_events_for_server") | |
227 | 251 | @defer.inlineCallbacks |
228 | 252 | def _filter_events_for_server(self, server_name, room_id, events): |
229 | event_to_state = yield self.store.get_state_for_events( | |
253 | event_to_state_ids = yield self.store.get_state_ids_for_events( | |
230 | 254 | frozenset(e.event_id for e in events), |
231 | 255 | types=( |
232 | 256 | (EventTypes.RoomHistoryVisibility, ""), |
233 | 257 | (EventTypes.Member, None), |
234 | 258 | ) |
235 | 259 | ) |
260 | ||
261 | # We only want to pull out member events that correspond to the | |
262 | # server's domain. | |
263 | ||
264 | def check_match(id): | |
265 | try: | |
266 | return server_name == get_domain_from_id(id) | |
267 | except: | |
268 | return False | |
269 | ||
270 | event_map = yield self.store.get_events([ | |
271 | e_id for key_to_eid in event_to_state_ids.values() | |
272 | for key, e_id in key_to_eid | |
273 | if key[0] != EventTypes.Member or check_match(key[1]) | |
274 | ]) | |
275 | ||
276 | event_to_state = { | |
277 | e_id: { | |
278 | key: event_map[inner_e_id] | |
279 | for key, inner_e_id in key_to_eid.items() | |
280 | if inner_e_id in event_map | |
281 | } | |
282 | for e_id, key_to_eid in event_to_state_ids.items() | |
283 | } | |
236 | 284 | |
237 | 285 | def redact_disallowed(event, state): |
238 | 286 | if not state: |
376 | 424 | )).addErrback(unwrapFirstError) |
377 | 425 | auth_events.update({a.event_id: a for a in results if a}) |
378 | 426 | required_auth.update( |
379 | a_id for event in results for a_id, _ in event.auth_events if event | |
427 | a_id | |
428 | for event in results if event | |
429 | for a_id, _ in event.auth_events | |
380 | 430 | ) |
381 | 431 | missing_auth = required_auth - set(auth_events) |
382 | 432 | |
559 | 609 | ])) |
560 | 610 | states = dict(zip(event_ids, [s[1] for s in states])) |
561 | 611 | |
612 | state_map = yield self.store.get_events( | |
613 | [e_id for ids in states.values() for e_id in ids], | |
614 | get_prev_content=False | |
615 | ) | |
616 | states = { | |
617 | key: { | |
618 | k: state_map[e_id] | |
619 | for k, e_id in state_dict.items() | |
620 | if e_id in state_map | |
621 | } for key, state_dict in states.items() | |
622 | } | |
623 | ||
562 | 624 | for e_id, _ in sorted_extremeties_tuple: |
563 | 625 | likely_domains = get_domains_from_state(states[e_id]) |
564 | 626 | |
721 | 783 | |
722 | 784 | # The remote hasn't signed it yet, obviously. We'll do the full checks |
723 | 785 | # when we get the event back in `on_send_join_request` |
724 | self.auth.check(event, auth_events=context.current_state, do_sig_check=False) | |
786 | yield self.auth.check_from_context(event, context, do_sig_check=False) | |
725 | 787 | |
726 | 788 | defer.returnValue(event) |
727 | 789 | |
769 | 831 | |
770 | 832 | new_pdu = event |
771 | 833 | |
772 | destinations = set() | |
773 | ||
774 | for k, s in context.current_state.items(): | |
775 | try: | |
776 | if k[0] == EventTypes.Member: | |
777 | if s.content["membership"] == Membership.JOIN: | |
778 | destinations.add(get_domain_from_id(s.state_key)) | |
779 | except: | |
780 | logger.warn( | |
781 | "Failed to get destination from event %s", s.event_id | |
782 | ) | |
783 | ||
834 | message_handler = self.hs.get_handlers().message_handler | |
835 | destinations = yield message_handler.get_joined_hosts_for_room_from_state( | |
836 | context | |
837 | ) | |
838 | destinations = set(destinations) | |
784 | 839 | destinations.discard(origin) |
785 | 840 | |
786 | 841 | logger.debug( |
791 | 846 | |
792 | 847 | self.replication_layer.send_pdu(new_pdu, destinations) |
793 | 848 | |
794 | state_ids = [e.event_id for e in context.current_state.values()] | |
849 | state_ids = context.prev_state_ids.values() | |
795 | 850 | auth_chain = yield self.store.get_auth_chain(set( |
796 | 851 | [event.event_id] + state_ids |
797 | 852 | )) |
798 | 853 | |
854 | state = yield self.store.get_events(context.prev_state_ids.values()) | |
855 | ||
799 | 856 | defer.returnValue({ |
800 | "state": context.current_state.values(), | |
857 | "state": state.values(), | |
801 | 858 | "auth_chain": auth_chain, |
802 | 859 | }) |
803 | 860 | |
953 | 1010 | try: |
954 | 1011 | # The remote hasn't signed it yet, obviously. We'll do the full checks |
955 | 1012 | # when we get the event back in `on_send_leave_request` |
956 | self.auth.check(event, auth_events=context.current_state, do_sig_check=False) | |
1013 | yield self.auth.check_from_context(event, context, do_sig_check=False) | |
957 | 1014 | except AuthError as e: |
958 | 1015 | logger.warn("Failed to create new leave %r because %s", event, e) |
959 | 1016 | raise e |
997 | 1054 | |
998 | 1055 | new_pdu = event |
999 | 1056 | |
1000 | destinations = set() | |
1001 | ||
1002 | for k, s in context.current_state.items(): | |
1003 | try: | |
1004 | if k[0] == EventTypes.Member: | |
1005 | if s.content["membership"] == Membership.LEAVE: | |
1006 | destinations.add(get_domain_from_id(s.state_key)) | |
1007 | except: | |
1008 | logger.warn( | |
1009 | "Failed to get destination from event %s", s.event_id | |
1010 | ) | |
1011 | ||
1057 | message_handler = self.hs.get_handlers().message_handler | |
1058 | destinations = yield message_handler.get_joined_hosts_for_room_from_state( | |
1059 | context | |
1060 | ) | |
1061 | destinations = set(destinations) | |
1012 | 1062 | destinations.discard(origin) |
1013 | 1063 | |
1014 | 1064 | logger.debug( |
1023 | 1073 | |
1024 | 1074 | @defer.inlineCallbacks |
1025 | 1075 | def get_state_for_pdu(self, room_id, event_id): |
1076 | """Returns the state at the event. i.e. not including said event. | |
1077 | """ | |
1026 | 1078 | yield run_on_reactor() |
1027 | 1079 | |
1028 | 1080 | state_groups = yield self.store.get_state_groups( |
1060 | 1112 | ) |
1061 | 1113 | |
1062 | 1114 | defer.returnValue(res) |
1115 | else: | |
1116 | defer.returnValue([]) | |
1117 | ||
1118 | @defer.inlineCallbacks | |
1119 | def get_state_ids_for_pdu(self, room_id, event_id): | |
1120 | """Returns the state at the event. i.e. not including said event. | |
1121 | """ | |
1122 | yield run_on_reactor() | |
1123 | ||
1124 | state_groups = yield self.store.get_state_groups_ids( | |
1125 | room_id, [event_id] | |
1126 | ) | |
1127 | ||
1128 | if state_groups: | |
1129 | _, state = state_groups.items().pop() | |
1130 | results = state | |
1131 | ||
1132 | event = yield self.store.get_event(event_id) | |
1133 | if event and event.is_state(): | |
1134 | # Get previous state | |
1135 | if "replaces_state" in event.unsigned: | |
1136 | prev_id = event.unsigned["replaces_state"] | |
1137 | if prev_id != event.event_id: | |
1138 | results[(event.type, event.state_key)] = prev_id | |
1139 | else: | |
1140 | del results[(event.type, event.state_key)] | |
1141 | ||
1142 | defer.returnValue(results.values()) | |
1063 | 1143 | else: |
1064 | 1144 | defer.returnValue([]) |
1065 | 1145 | |
1293 | 1373 | ) |
1294 | 1374 | |
1295 | 1375 | if not auth_events: |
1296 | auth_events = context.current_state | |
1376 | auth_events_ids = yield self.auth.compute_auth_events( | |
1377 | event, context.prev_state_ids, for_verification=True, | |
1378 | ) | |
1379 | auth_events = yield self.store.get_events(auth_events_ids) | |
1380 | auth_events = { | |
1381 | (e.type, e.state_key): e for e in auth_events.values() | |
1382 | } | |
1297 | 1383 | |
1298 | 1384 | # This is a hack to fix some old rooms where the initial join event |
1299 | 1385 | # didn't reference the create event in its auth events. |
1319 | 1405 | context.rejected = RejectedReason.AUTH_ERROR |
1320 | 1406 | |
1321 | 1407 | if event.type == EventTypes.GuestAccess: |
1322 | full_context = yield self.store.get_current_state(room_id=event.room_id) | |
1323 | yield self.maybe_kick_guest_users(event, full_context) | |
1408 | yield self.maybe_kick_guest_users(event) | |
1324 | 1409 | |
1325 | 1410 | defer.returnValue(context) |
1326 | 1411 | |
1387 | 1472 | # Check if we have all the auth events. |
1388 | 1473 | current_state = set(e.event_id for e in auth_events.values()) |
1389 | 1474 | event_auth_events = set(e_id for e_id, _ in event.auth_events) |
1475 | ||
1476 | if event.is_state(): | |
1477 | event_key = (event.type, event.state_key) | |
1478 | else: | |
1479 | event_key = None | |
1390 | 1480 | |
1391 | 1481 | if event_auth_events - current_state: |
1392 | 1482 | have_events = yield self.store.have_events( |
1491 | 1581 | current_state = set(e.event_id for e in auth_events.values()) |
1492 | 1582 | different_auth = event_auth_events - current_state |
1493 | 1583 | |
1494 | context.current_state.update(auth_events) | |
1495 | context.state_group = None | |
1584 | context.current_state_ids.update({ | |
1585 | k: a.event_id for k, a in auth_events.items() | |
1586 | if k != event_key | |
1587 | }) | |
1588 | context.prev_state_ids.update({ | |
1589 | k: a.event_id for k, a in auth_events.items() | |
1590 | }) | |
1591 | context.state_group = self.store.get_next_state_group() | |
1496 | 1592 | |
1497 | 1593 | if different_auth and not event.internal_metadata.is_outlier(): |
1498 | 1594 | logger.info("Different auth after resolution: %s", different_auth) |
1513 | 1609 | |
1514 | 1610 | if do_resolution: |
1515 | 1611 | # 1. Get what we think is the auth chain. |
1516 | auth_ids = self.auth.compute_auth_events( | |
1517 | event, context.current_state | |
1612 | auth_ids = yield self.auth.compute_auth_events( | |
1613 | event, context.prev_state_ids | |
1518 | 1614 | ) |
1519 | 1615 | local_auth_chain = yield self.store.get_auth_chain(auth_ids) |
1520 | 1616 | |
1570 | 1666 | # 4. Look at rejects and their proofs. |
1571 | 1667 | # TODO. |
1572 | 1668 | |
1573 | context.current_state.update(auth_events) | |
1574 | context.state_group = None | |
1669 | context.current_state_ids.update({ | |
1670 | k: a.event_id for k, a in auth_events.items() | |
1671 | if k != event_key | |
1672 | }) | |
1673 | context.prev_state_ids.update({ | |
1674 | k: a.event_id for k, a in auth_events.items() | |
1675 | }) | |
1676 | context.state_group = self.store.get_next_state_group() | |
1575 | 1677 | |
1576 | 1678 | try: |
1577 | 1679 | self.auth.check(event, auth_events=auth_events) |
1757 | 1859 | ) |
1758 | 1860 | |
1759 | 1861 | try: |
1760 | self.auth.check(event, context.current_state) | |
1862 | yield self.auth.check_from_context(event, context) | |
1761 | 1863 | except AuthError as e: |
1762 | 1864 | logger.warn("Denying new third party invite %r because %s", event, e) |
1763 | 1865 | raise e |
1764 | 1866 | |
1765 | yield self._check_signature(event, auth_events=context.current_state) | |
1867 | yield self._check_signature(event, context) | |
1766 | 1868 | member_handler = self.hs.get_handlers().room_member_handler |
1767 | 1869 | yield member_handler.send_membership_event(None, event, context) |
1768 | 1870 | else: |
1788 | 1890 | ) |
1789 | 1891 | |
1790 | 1892 | try: |
1791 | self.auth.check(event, auth_events=context.current_state) | |
1893 | self.auth.check_from_context(event, context) | |
1792 | 1894 | except AuthError as e: |
1793 | 1895 | logger.warn("Denying third party invite %r because %s", event, e) |
1794 | 1896 | raise e |
1795 | yield self._check_signature(event, auth_events=context.current_state) | |
1897 | yield self._check_signature(event, context) | |
1796 | 1898 | |
1797 | 1899 | returned_invite = yield self.send_invite(origin, event) |
1798 | 1900 | # TODO: Make sure the signatures actually are correct. |
1806 | 1908 | EventTypes.ThirdPartyInvite, |
1807 | 1909 | event.content["third_party_invite"]["signed"]["token"] |
1808 | 1910 | ) |
1809 | original_invite = context.current_state.get(key) | |
1911 | original_invite = None | |
1912 | original_invite_id = context.prev_state_ids.get(key) | |
1913 | if original_invite_id: | |
1914 | original_invite = yield self.store.get_event( | |
1915 | original_invite_id, allow_none=True | |
1916 | ) | |
1810 | 1917 | if not original_invite: |
1811 | 1918 | logger.info( |
1812 | 1919 | "Could not find invite event for third_party_invite - " |
1823 | 1930 | defer.returnValue((event, context)) |
1824 | 1931 | |
1825 | 1932 | @defer.inlineCallbacks |
1826 | def _check_signature(self, event, auth_events): | |
1933 | def _check_signature(self, event, context): | |
1827 | 1934 | """ |
1828 | 1935 | Checks that the signature in the event is consistent with its invite. |
1829 | 1936 | |
1830 | 1937 | Args: |
1831 | 1938 | event (Event): The m.room.member event to check |
1832 | auth_events (dict<(event type, state_key), event>): | |
1939 | context (EventContext): | |
1833 | 1940 | |
1834 | 1941 | Raises: |
1835 | 1942 | AuthError: if signature didn't match any keys, or key has been |
1840 | 1947 | signed = event.content["third_party_invite"]["signed"] |
1841 | 1948 | token = signed["token"] |
1842 | 1949 | |
1843 | invite_event = auth_events.get( | |
1950 | invite_event_id = context.prev_state_ids.get( | |
1844 | 1951 | (EventTypes.ThirdPartyInvite, token,) |
1845 | 1952 | ) |
1953 | ||
1954 | invite_event = None | |
1955 | if invite_event_id: | |
1956 | invite_event = yield self.store.get_event(invite_event_id, allow_none=True) | |
1846 | 1957 | |
1847 | 1958 | if not invite_event: |
1848 | 1959 | raise AuthError(403, "Could not find invite") |
29 | 29 | from synapse.util.caches.snapshot_cache import SnapshotCache |
30 | 30 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred |
31 | 31 | from synapse.util.metrics import measure_func |
32 | from synapse.util.caches.descriptors import cachedInlineCallbacks | |
32 | 33 | from synapse.visibility import filter_events_for_client |
33 | 34 | |
34 | 35 | from ._base import BaseHandler |
247 | 248 | assert self.hs.is_mine(user), "User must be our own: %s" % (user,) |
248 | 249 | |
249 | 250 | if event.is_state(): |
250 | prev_state = self.deduplicate_state_event(event, context) | |
251 | prev_state = yield self.deduplicate_state_event(event, context) | |
251 | 252 | if prev_state is not None: |
252 | 253 | defer.returnValue(prev_state) |
253 | 254 | |
262 | 263 | presence = self.hs.get_presence_handler() |
263 | 264 | yield presence.bump_presence_active_time(user) |
264 | 265 | |
266 | @defer.inlineCallbacks | |
265 | 267 | def deduplicate_state_event(self, event, context): |
266 | 268 | """ |
267 | 269 | Checks whether event is in the latest resolved state in context. |
269 | 271 | If so, returns the version of the event in context. |
270 | 272 | Otherwise, returns None. |
271 | 273 | """ |
272 | prev_event = context.current_state.get((event.type, event.state_key)) | |
274 | prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) | |
275 | prev_event = yield self.store.get_event(prev_event_id, allow_none=True) | |
276 | if not prev_event: | |
277 | return | |
278 | ||
273 | 279 | if prev_event and event.user_id == prev_event.user_id: |
274 | 280 | prev_content = encode_canonical_json(prev_event.content) |
275 | 281 | next_content = encode_canonical_json(event.content) |
276 | 282 | if prev_content == next_content: |
277 | return prev_event | |
278 | return None | |
283 | defer.returnValue(prev_event) | |
284 | return | |
279 | 285 | |
280 | 286 | @defer.inlineCallbacks |
281 | 287 | def create_and_send_nonmember_event( |
801 | 807 | event = builder.build() |
802 | 808 | |
803 | 809 | logger.debug( |
804 | "Created event %s with current state: %s", | |
805 | event.event_id, context.current_state, | |
810 | "Created event %s with state: %s", | |
811 | event.event_id, context.prev_state_ids, | |
806 | 812 | ) |
807 | 813 | |
808 | 814 | defer.returnValue( |
825 | 831 | self.ratelimit(requester) |
826 | 832 | |
827 | 833 | try: |
828 | self.auth.check(event, auth_events=context.current_state) | |
834 | yield self.auth.check_from_context(event, context) | |
829 | 835 | except AuthError as err: |
830 | 836 | logger.warn("Denying new event %r because %s", event, err) |
831 | 837 | raise err |
832 | 838 | |
833 | yield self.maybe_kick_guest_users(event, context.current_state.values()) | |
839 | yield self.maybe_kick_guest_users(event, context) | |
834 | 840 | |
835 | 841 | if event.type == EventTypes.CanonicalAlias: |
836 | 842 | # Check the alias is acually valid (at this time at least) |
858 | 864 | e.sender == event.sender |
859 | 865 | ) |
860 | 866 | |
867 | state_to_include_ids = [ | |
868 | e_id | |
869 | for k, e_id in context.current_state_ids.items() | |
870 | if k[0] in self.hs.config.room_invite_state_types | |
871 | or k[0] == EventTypes.Member and k[1] == event.sender | |
872 | ] | |
873 | ||
874 | state_to_include = yield self.store.get_events(state_to_include_ids) | |
875 | ||
861 | 876 | event.unsigned["invite_room_state"] = [ |
862 | 877 | { |
863 | 878 | "type": e.type, |
865 | 880 | "content": e.content, |
866 | 881 | "sender": e.sender, |
867 | 882 | } |
868 | for k, e in context.current_state.items() | |
869 | if e.type in self.hs.config.room_invite_state_types | |
870 | or is_inviter_member_event(e) | |
883 | for e in state_to_include.values() | |
871 | 884 | ] |
872 | 885 | |
873 | 886 | invitee = UserID.from_string(event.state_key) |
889 | 902 | ) |
890 | 903 | |
891 | 904 | if event.type == EventTypes.Redaction: |
892 | if self.auth.check_redaction(event, auth_events=context.current_state): | |
905 | auth_events_ids = yield self.auth.compute_auth_events( | |
906 | event, context.prev_state_ids, for_verification=True, | |
907 | ) | |
908 | auth_events = yield self.store.get_events(auth_events_ids) | |
909 | auth_events = { | |
910 | (e.type, e.state_key): e for e in auth_events.values() | |
911 | } | |
912 | if self.auth.check_redaction(event, auth_events=auth_events): | |
893 | 913 | original_event = yield self.store.get_event( |
894 | 914 | event.redacts, |
895 | 915 | check_redacted=False, |
903 | 923 | "You don't have permission to redact events" |
904 | 924 | ) |
905 | 925 | |
906 | if event.type == EventTypes.Create and context.current_state: | |
926 | if event.type == EventTypes.Create and context.prev_state_ids: | |
907 | 927 | raise AuthError( |
908 | 928 | 403, |
909 | 929 | "Changing the room create event is forbidden", |
924 | 944 | event_stream_id, max_stream_id |
925 | 945 | ) |
926 | 946 | |
927 | destinations = set() | |
928 | for k, s in context.current_state.items(): | |
929 | try: | |
930 | if k[0] == EventTypes.Member: | |
931 | if s.content["membership"] == Membership.JOIN: | |
932 | destinations.add(get_domain_from_id(s.state_key)) | |
933 | except SynapseError: | |
934 | logger.warn( | |
935 | "Failed to get destination from event %s", s.event_id | |
936 | ) | |
947 | destinations = yield self.get_joined_hosts_for_room_from_state(context) | |
937 | 948 | |
938 | 949 | @defer.inlineCallbacks |
939 | 950 | def _notify(): |
951 | 962 | preserve_fn(federation_handler.handle_new_event)( |
952 | 963 | event, destinations=destinations, |
953 | 964 | ) |
965 | ||
966 | def get_joined_hosts_for_room_from_state(self, context): | |
967 | state_group = context.state_group | |
968 | if not state_group: | |
969 | # If state_group is None it means it has yet to be assigned a | |
970 | # state group, i.e. we need to make sure that calls with a state_group | |
971 | # of None don't hit previous cached calls with a None state_group. | |
972 | # To do this we set the state_group to a new object as object() != object() | |
973 | state_group = object() | |
974 | ||
975 | return self._get_joined_hosts_for_room_from_state( | |
976 | state_group, context.current_state_ids | |
977 | ) | |
978 | ||
979 | @cachedInlineCallbacks(num_args=1, cache_context=True) | |
980 | def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, | |
981 | cache_context): | |
982 | ||
983 | # Don't bother getting state for people on the same HS | |
984 | current_state = yield self.store.get_events([ | |
985 | e_id for key, e_id in current_state_ids.items() | |
986 | if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) | |
987 | ]) | |
988 | ||
989 | destinations = set() | |
990 | for e in current_state.itervalues(): | |
991 | try: | |
992 | if e.type == EventTypes.Member: | |
993 | if e.content["membership"] == Membership.JOIN: | |
994 | destinations.add(get_domain_from_id(e.state_key)) | |
995 | except SynapseError: | |
996 | logger.warn( | |
997 | "Failed to get destination from event %s", e.event_id | |
998 | ) | |
999 | ||
1000 | defer.returnValue(destinations) |
86 | 86 | self.wheel_timer = WheelTimer() |
87 | 87 | self.notifier = hs.get_notifier() |
88 | 88 | self.federation = hs.get_replication_layer() |
89 | ||
90 | self.state = hs.get_state_handler() | |
89 | 91 | |
90 | 92 | self.federation.register_edu_handler( |
91 | 93 | "m.presence", self.incoming_presence |
188 | 190 | 5000, |
189 | 191 | ) |
190 | 192 | |
193 | self.clock.call_later( | |
194 | 60, | |
195 | self.clock.looping_call, | |
196 | self._persist_unpersisted_changes, | |
197 | 60 * 1000, | |
198 | ) | |
199 | ||
191 | 200 | metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) |
192 | 201 | |
193 | 202 | @defer.inlineCallbacks |
212 | 221 | for user_id in self.unpersisted_users_changes |
213 | 222 | ]) |
214 | 223 | logger.info("Finished _on_shutdown") |
224 | ||
225 | @defer.inlineCallbacks | |
226 | def _persist_unpersisted_changes(self): | |
227 | """We periodically persist the unpersisted changes, as otherwise they | |
228 | may stack up and slow down shutdown times. | |
229 | """ | |
230 | logger.info( | |
231 | "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", | |
232 | len(self.unpersisted_users_changes) | |
233 | ) | |
234 | ||
235 | unpersisted = self.unpersisted_users_changes | |
236 | self.unpersisted_users_changes = set() | |
237 | ||
238 | if unpersisted: | |
239 | yield self.store.update_presence([ | |
240 | self.user_to_current_state[user_id] | |
241 | for user_id in unpersisted | |
242 | ]) | |
243 | ||
244 | logger.info("Finished _persist_unpersisted_changes") | |
215 | 245 | |
216 | 246 | @defer.inlineCallbacks |
217 | 247 | def _update_states(self, new_states): |
531 | 561 | if not local_states: |
532 | 562 | continue |
533 | 563 | |
534 | hosts = yield self.store.get_joined_hosts_for_room(room_id) | |
564 | users = yield self.state.get_current_user_in_room(room_id) | |
565 | hosts = set(get_domain_from_id(u) for u in users) | |
566 | ||
535 | 567 | for host in hosts: |
536 | 568 | hosts_to_states.setdefault(host, []).extend(local_states) |
537 | 569 | |
724 | 756 | # don't need to send to local clients here, as that is done as part |
725 | 757 | # of the event stream/sync. |
726 | 758 | # TODO: Only send to servers not already in the room. |
759 | user_ids = yield self.state.get_current_user_in_room(room_id) | |
727 | 760 | if self.is_mine(user): |
728 | 761 | state = yield self.current_state_for_user(user.to_string()) |
729 | 762 | |
730 | hosts = yield self.store.get_joined_hosts_for_room(room_id) | |
763 | hosts = set(get_domain_from_id(u) for u in user_ids) | |
731 | 764 | self._push_to_remotes({host: (state,) for host in hosts}) |
732 | 765 | else: |
733 | user_ids = yield self.store.get_users_in_room(room_id) | |
734 | 766 | user_ids = filter(self.is_mine_id, user_ids) |
735 | 767 | |
736 | 768 | states = yield self.current_state_for_users(user_ids) |
917 | 949 | if new_state.currently_active != old_state.currently_active: |
918 | 950 | return True |
919 | 951 | |
920 | if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: | |
952 | if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: | |
953 | # Only notify about last active bumps if we're not currently acive | |
954 | if not (old_state.currently_active and new_state.currently_active): | |
955 | return True | |
956 | ||
957 | elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: | |
921 | 958 | # Always notify for a transition where last active gets bumped. |
922 | 959 | return True |
923 | 960 | |
954 | 991 | self.get_presence_handler = hs.get_presence_handler |
955 | 992 | self.clock = hs.get_clock() |
956 | 993 | self.store = hs.get_datastore() |
994 | self.state = hs.get_state_handler() | |
957 | 995 | |
958 | 996 | @defer.inlineCallbacks |
959 | 997 | @log_function |
1016 | 1054 | |
1017 | 1055 | user_ids_to_check = set() |
1018 | 1056 | for room_id in room_ids: |
1019 | users = yield self.store.get_users_in_room(room_id) | |
1057 | users = yield self.state.get_current_user_in_room(room_id) | |
1020 | 1058 | user_ids_to_check.update(users) |
1021 | 1059 | |
1022 | 1060 | user_ids_to_check.update(friends) |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | 19 | from synapse.util.logcontext import PreserveLoggingContext |
20 | from synapse.types import get_domain_from_id | |
20 | 21 | |
21 | 22 | import logging |
22 | 23 | |
36 | 37 | "m.receipt", self._received_remote_receipt |
37 | 38 | ) |
38 | 39 | self.clock = self.hs.get_clock() |
40 | self.state = hs.get_state_handler() | |
39 | 41 | |
40 | 42 | @defer.inlineCallbacks |
41 | 43 | def received_client_receipt(self, room_id, receipt_type, user_id, |
132 | 134 | event_ids = receipt["event_ids"] |
133 | 135 | data = receipt["data"] |
134 | 136 | |
135 | remotedomains = yield self.store.get_joined_hosts_for_room(room_id) | |
137 | users = yield self.state.get_current_user_in_room(room_id) | |
138 | remotedomains = set(get_domain_from_id(u) for u in users) | |
136 | 139 | remotedomains = remotedomains.copy() |
137 | 140 | remotedomains.discard(self.server_name) |
138 | 141 |
84 | 84 | prev_event_ids=prev_event_ids, |
85 | 85 | ) |
86 | 86 | |
87 | # Check if this event matches the previous membership event for the user. | |
88 | duplicate = yield msg_handler.deduplicate_state_event(event, context) | |
89 | if duplicate is not None: | |
90 | # Discard the new event since this membership change is a no-op. | |
91 | return | |
92 | ||
87 | 93 | yield msg_handler.handle_new_client_event( |
88 | 94 | requester, |
89 | 95 | event, |
92 | 98 | ratelimit=ratelimit, |
93 | 99 | ) |
94 | 100 | |
95 | prev_member_event = context.current_state.get( | |
101 | prev_member_event_id = context.prev_state_ids.get( | |
96 | 102 | (EventTypes.Member, target.to_string()), |
97 | 103 | None |
98 | 104 | ) |
99 | 105 | |
100 | 106 | if event.membership == Membership.JOIN: |
101 | if not prev_member_event or prev_member_event.membership != Membership.JOIN: | |
102 | # Only fire user_joined_room if the user has acutally joined the | |
103 | # room. Don't bother if the user is just changing their profile | |
104 | # info. | |
107 | # Only fire user_joined_room if the user has acutally joined the | |
108 | # room. Don't bother if the user is just changing their profile | |
109 | # info. | |
110 | newly_joined = True | |
111 | if prev_member_event_id: | |
112 | prev_member_event = yield self.store.get_event(prev_member_event_id) | |
113 | newly_joined = prev_member_event.membership != Membership.JOIN | |
114 | if newly_joined: | |
105 | 115 | yield user_joined_room(self.distributor, target, room_id) |
106 | 116 | elif event.membership == Membership.LEAVE: |
107 | if prev_member_event and prev_member_event.membership == Membership.JOIN: | |
108 | user_left_room(self.distributor, target, room_id) | |
117 | if prev_member_event_id: | |
118 | prev_member_event = yield self.store.get_event(prev_member_event_id) | |
119 | if prev_member_event.membership == Membership.JOIN: | |
120 | user_left_room(self.distributor, target, room_id) | |
109 | 121 | |
110 | 122 | @defer.inlineCallbacks |
111 | 123 | def remote_join(self, remote_room_hosts, room_id, user, content): |
194 | 206 | remote_room_hosts = [] |
195 | 207 | |
196 | 208 | latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) |
197 | current_state = yield self.state_handler.get_current_state( | |
209 | current_state_ids = yield self.state_handler.get_current_state_ids( | |
198 | 210 | room_id, latest_event_ids=latest_event_ids, |
199 | 211 | ) |
200 | 212 | |
201 | old_state = current_state.get((EventTypes.Member, target.to_string())) | |
202 | old_membership = old_state.content.get("membership") if old_state else None | |
203 | if action == "unban" and old_membership != "ban": | |
204 | raise SynapseError( | |
205 | 403, | |
206 | "Cannot unban user who was not banned (membership=%s)" % old_membership, | |
207 | errcode=Codes.BAD_STATE | |
208 | ) | |
209 | if old_membership == "ban" and action != "unban": | |
210 | raise SynapseError( | |
211 | 403, | |
212 | "Cannot %s user who was banned" % (action,), | |
213 | errcode=Codes.BAD_STATE | |
214 | ) | |
215 | ||
216 | is_host_in_room = self.is_host_in_room(current_state) | |
213 | old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) | |
214 | if old_state_id: | |
215 | old_state = yield self.store.get_event(old_state_id, allow_none=True) | |
216 | old_membership = old_state.content.get("membership") if old_state else None | |
217 | if action == "unban" and old_membership != "ban": | |
218 | raise SynapseError( | |
219 | 403, | |
220 | "Cannot unban user who was not banned" | |
221 | " (membership=%s)" % old_membership, | |
222 | errcode=Codes.BAD_STATE | |
223 | ) | |
224 | if old_membership == "ban" and action != "unban": | |
225 | raise SynapseError( | |
226 | 403, | |
227 | "Cannot %s user who was banned" % (action,), | |
228 | errcode=Codes.BAD_STATE | |
229 | ) | |
230 | ||
231 | is_host_in_room = yield self._is_host_in_room(current_state_ids) | |
217 | 232 | |
218 | 233 | if effective_membership_state == Membership.JOIN: |
219 | if requester.is_guest and not self._can_guest_join(current_state): | |
234 | if requester.is_guest and not self._can_guest_join(current_state_ids): | |
220 | 235 | # This should be an auth check, but guests are a local concept, |
221 | 236 | # so don't really fit into the general auth process. |
222 | 237 | raise AuthError(403, "Guest access not allowed") |
325 | 340 | requester = synapse.types.create_requester(target_user) |
326 | 341 | |
327 | 342 | message_handler = self.hs.get_handlers().message_handler |
328 | prev_event = message_handler.deduplicate_state_event(event, context) | |
343 | prev_event = yield message_handler.deduplicate_state_event(event, context) | |
329 | 344 | if prev_event is not None: |
330 | 345 | return |
331 | 346 | |
332 | 347 | if event.membership == Membership.JOIN: |
333 | if requester.is_guest and not self._can_guest_join(context.current_state): | |
334 | # This should be an auth check, but guests are a local concept, | |
335 | # so don't really fit into the general auth process. | |
336 | raise AuthError(403, "Guest access not allowed") | |
348 | if requester.is_guest: | |
349 | guest_can_join = yield self._can_guest_join(context.prev_state_ids) | |
350 | if not guest_can_join: | |
351 | # This should be an auth check, but guests are a local concept, | |
352 | # so don't really fit into the general auth process. | |
353 | raise AuthError(403, "Guest access not allowed") | |
337 | 354 | |
338 | 355 | yield message_handler.handle_new_client_event( |
339 | 356 | requester, |
343 | 360 | ratelimit=ratelimit, |
344 | 361 | ) |
345 | 362 | |
346 | prev_member_event = context.current_state.get( | |
347 | (EventTypes.Member, target_user.to_string()), | |
363 | prev_member_event_id = context.prev_state_ids.get( | |
364 | (EventTypes.Member, event.state_key), | |
348 | 365 | None |
349 | 366 | ) |
350 | 367 | |
351 | 368 | if event.membership == Membership.JOIN: |
352 | if not prev_member_event or prev_member_event.membership != Membership.JOIN: | |
353 | # Only fire user_joined_room if the user has acutally joined the | |
354 | # room. Don't bother if the user is just changing their profile | |
355 | # info. | |
369 | # Only fire user_joined_room if the user has acutally joined the | |
370 | # room. Don't bother if the user is just changing their profile | |
371 | # info. | |
372 | newly_joined = True | |
373 | if prev_member_event_id: | |
374 | prev_member_event = yield self.store.get_event(prev_member_event_id) | |
375 | newly_joined = prev_member_event.membership != Membership.JOIN | |
376 | if newly_joined: | |
356 | 377 | yield user_joined_room(self.distributor, target_user, room_id) |
357 | 378 | elif event.membership == Membership.LEAVE: |
358 | if prev_member_event and prev_member_event.membership == Membership.JOIN: | |
359 | user_left_room(self.distributor, target_user, room_id) | |
360 | ||
361 | def _can_guest_join(self, current_state): | |
379 | if prev_member_event_id: | |
380 | prev_member_event = yield self.store.get_event(prev_member_event_id) | |
381 | if prev_member_event.membership == Membership.JOIN: | |
382 | user_left_room(self.distributor, target_user, room_id) | |
383 | ||
384 | @defer.inlineCallbacks | |
385 | def _can_guest_join(self, current_state_ids): | |
362 | 386 | """ |
363 | 387 | Returns whether a guest can join a room based on its current state. |
364 | 388 | """ |
365 | guest_access = current_state.get((EventTypes.GuestAccess, ""), None) | |
366 | return ( | |
389 | guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) | |
390 | if not guest_access_id: | |
391 | defer.returnValue(False) | |
392 | ||
393 | guest_access = yield self.store.get_event(guest_access_id) | |
394 | ||
395 | defer.returnValue( | |
367 | 396 | guest_access |
368 | 397 | and guest_access.content |
369 | 398 | and "guest_access" in guest_access.content |
682 | 711 | |
683 | 712 | if membership: |
684 | 713 | yield self.store.forget(user_id, room_id) |
714 | ||
715 | @defer.inlineCallbacks | |
716 | def _is_host_in_room(self, current_state_ids): | |
717 | # Have we just created the room, and is this about to be the very | |
718 | # first member event? | |
719 | create_event_id = current_state_ids.get(("m.room.create", "")) | |
720 | if len(current_state_ids) == 1 and create_event_id: | |
721 | defer.returnValue(self.hs.is_mine_id(create_event_id)) | |
722 | ||
723 | for (etype, state_key), event_id in current_state_ids.items(): | |
724 | if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): | |
725 | continue | |
726 | ||
727 | event = yield self.store.get_event(event_id, allow_none=True) | |
728 | if not event: | |
729 | continue | |
730 | ||
731 | if event.membership == Membership.JOIN: | |
732 | defer.returnValue(True) | |
733 | ||
734 | defer.returnValue(False) |
34 | 34 | "filter_collection", |
35 | 35 | "is_guest", |
36 | 36 | "request_key", |
37 | "device_id", | |
37 | 38 | ]) |
38 | 39 | |
39 | 40 | |
112 | 113 | "joined", # JoinedSyncResult for each joined room. |
113 | 114 | "invited", # InvitedSyncResult for each invited room. |
114 | 115 | "archived", # ArchivedSyncResult for each archived room. |
116 | "to_device", # List of direct messages for the device. | |
115 | 117 | ])): |
116 | 118 | __slots__ = [] |
117 | 119 | |
125 | 127 | self.joined or |
126 | 128 | self.invited or |
127 | 129 | self.archived or |
128 | self.account_data | |
130 | self.account_data or | |
131 | self.to_device | |
129 | 132 | ) |
130 | 133 | |
131 | 134 | |
138 | 141 | self.event_sources = hs.get_event_sources() |
139 | 142 | self.clock = hs.get_clock() |
140 | 143 | self.response_cache = ResponseCache(hs) |
144 | self.state = hs.get_state_handler() | |
141 | 145 | |
142 | 146 | def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, |
143 | 147 | full_state=False): |
354 | 358 | Returns: |
355 | 359 | A Deferred map from ((type, state_key)->Event) |
356 | 360 | """ |
357 | state = yield self.store.get_state_for_event(event.event_id) | |
361 | state_ids = yield self.store.get_state_ids_for_event(event.event_id) | |
358 | 362 | if event.is_state(): |
359 | state = state.copy() | |
360 | state[(event.type, event.state_key)] = event | |
361 | defer.returnValue(state) | |
363 | state_ids = state_ids.copy() | |
364 | state_ids[(event.type, event.state_key)] = event.event_id | |
365 | defer.returnValue(state_ids) | |
362 | 366 | |
363 | 367 | @defer.inlineCallbacks |
364 | 368 | def get_state_at(self, room_id, stream_position): |
411 | 415 | with Measure(self.clock, "compute_state_delta"): |
412 | 416 | if full_state: |
413 | 417 | if batch: |
414 | current_state = yield self.store.get_state_for_event( | |
418 | current_state_ids = yield self.store.get_state_ids_for_event( | |
415 | 419 | batch.events[-1].event_id |
416 | 420 | ) |
417 | 421 | |
418 | state = yield self.store.get_state_for_event( | |
422 | state_ids = yield self.store.get_state_ids_for_event( | |
419 | 423 | batch.events[0].event_id |
420 | 424 | ) |
421 | 425 | else: |
422 | current_state = yield self.get_state_at( | |
426 | current_state_ids = yield self.get_state_at( | |
423 | 427 | room_id, stream_position=now_token |
424 | 428 | ) |
425 | 429 | |
426 | state = current_state | |
430 | state_ids = current_state_ids | |
427 | 431 | |
428 | 432 | timeline_state = { |
429 | (event.type, event.state_key): event | |
433 | (event.type, event.state_key): event.event_id | |
430 | 434 | for event in batch.events if event.is_state() |
431 | 435 | } |
432 | 436 | |
433 | state = _calculate_state( | |
437 | state_ids = _calculate_state( | |
434 | 438 | timeline_contains=timeline_state, |
435 | timeline_start=state, | |
439 | timeline_start=state_ids, | |
436 | 440 | previous={}, |
437 | current=current_state, | |
441 | current=current_state_ids, | |
438 | 442 | ) |
439 | 443 | elif batch.limited: |
440 | 444 | state_at_previous_sync = yield self.get_state_at( |
441 | 445 | room_id, stream_position=since_token |
442 | 446 | ) |
443 | 447 | |
444 | current_state = yield self.store.get_state_for_event( | |
448 | current_state_ids = yield self.store.get_state_ids_for_event( | |
445 | 449 | batch.events[-1].event_id |
446 | 450 | ) |
447 | 451 | |
448 | state_at_timeline_start = yield self.store.get_state_for_event( | |
452 | state_at_timeline_start = yield self.store.get_state_ids_for_event( | |
449 | 453 | batch.events[0].event_id |
450 | 454 | ) |
451 | 455 | |
452 | 456 | timeline_state = { |
453 | (event.type, event.state_key): event | |
457 | (event.type, event.state_key): event.event_id | |
454 | 458 | for event in batch.events if event.is_state() |
455 | 459 | } |
456 | 460 | |
457 | state = _calculate_state( | |
461 | state_ids = _calculate_state( | |
458 | 462 | timeline_contains=timeline_state, |
459 | 463 | timeline_start=state_at_timeline_start, |
460 | 464 | previous=state_at_previous_sync, |
461 | current=current_state, | |
465 | current=current_state_ids, | |
462 | 466 | ) |
463 | 467 | else: |
464 | state = {} | |
468 | state_ids = {} | |
469 | ||
470 | state = {} | |
471 | if state_ids: | |
472 | state = yield self.store.get_events(state_ids.values()) | |
465 | 473 | |
466 | 474 | defer.returnValue({ |
467 | 475 | (e.type, e.state_key): e |
526 | 534 | sync_result_builder, newly_joined_rooms, newly_joined_users |
527 | 535 | ) |
528 | 536 | |
537 | yield self._generate_sync_entry_for_to_device(sync_result_builder) | |
538 | ||
529 | 539 | defer.returnValue(SyncResult( |
530 | 540 | presence=sync_result_builder.presence, |
531 | 541 | account_data=sync_result_builder.account_data, |
532 | 542 | joined=sync_result_builder.joined, |
533 | 543 | invited=sync_result_builder.invited, |
534 | 544 | archived=sync_result_builder.archived, |
545 | to_device=sync_result_builder.to_device, | |
535 | 546 | next_batch=sync_result_builder.now_token, |
536 | 547 | )) |
548 | ||
    @defer.inlineCallbacks
    def _generate_sync_entry_for_to_device(self, sync_result_builder):
        """Generates the to-device-message portion of the sync response.
        Populates `sync_result_builder.to_device` with the messages waiting
        for this device, and advances the `to_device_key` in the token that
        will be handed back to the client.

        Args:
            sync_result_builder(SyncResultBuilder)

        Returns:
            Deferred: resolves once `sync_result_builder` has been updated.
        """
        user_id = sync_result_builder.sync_config.user.to_string()
        device_id = sync_result_builder.sync_config.device_id
        now_token = sync_result_builder.now_token
        # No since token means an initial sync: hand back everything from 0.
        since_stream_id = 0
        if sync_result_builder.since_token is not None:
            since_stream_id = int(sync_result_builder.since_token.to_device_key)

        if since_stream_id != int(now_token.to_device_key):
            # We only delete messages when a new message comes in, but that's
            # fine so long as we delete them at some point.

            # The client's since token acknowledges everything up to
            # since_stream_id, so those rows can now be dropped.
            logger.debug("Deleting messages up to %d", since_stream_id)
            yield self.store.delete_messages_for_device(
                user_id, device_id, since_stream_id
            )

            logger.debug("Getting messages up to %d", now_token.to_device_key)
            messages, stream_id = yield self.store.get_new_messages_for_device(
                user_id, device_id, since_stream_id, now_token.to_device_key
            )
            logger.debug("Got messages up to %d: %r", stream_id, messages)
            # Advance the to_device position so the next incremental sync
            # acks (and therefore deletes) the messages returned here.
            sync_result_builder.now_token = now_token.copy_and_replace(
                "to_device_key", stream_id
            )
            sync_result_builder.to_device = messages
        else:
            # Stream hasn't advanced: nothing new for this device.
            sync_result_builder.to_device = []
537 | 587 | |
538 | 588 | @defer.inlineCallbacks |
539 | 589 | def _generate_sync_entry_for_account_data(self, sync_result_builder): |
625 | 675 | |
626 | 676 | extra_users_ids = set(newly_joined_users) |
627 | 677 | for room_id in newly_joined_rooms: |
628 | users = yield self.store.get_users_in_room(room_id) | |
678 | users = yield self.state.get_current_user_in_room(room_id) | |
629 | 679 | extra_users_ids.update(users) |
630 | 680 | extra_users_ids.discard(user.to_string()) |
631 | 681 | |
765 | 815 | # the last sync (even if we have since left). This is to make sure |
766 | 816 | # we do send down the room, and with full state, where necessary |
767 | 817 | if room_id in joined_room_ids or has_join: |
768 | old_state = yield self.get_state_at(room_id, since_token) | |
769 | old_mem_ev = old_state.get((EventTypes.Member, user_id), None) | |
818 | old_state_ids = yield self.get_state_at(room_id, since_token) | |
819 | old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) | |
820 | old_mem_ev = None | |
821 | if old_mem_ev_id: | |
822 | old_mem_ev = yield self.store.get_event( | |
823 | old_mem_ev_id, allow_none=True | |
824 | ) | |
770 | 825 | if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: |
771 | 826 | newly_joined_rooms.append(room_id) |
772 | 827 | |
1058 | 1113 | Returns: |
1059 | 1114 | dict |
1060 | 1115 | """ |
1061 | event_id_to_state = { | |
1062 | e.event_id: e | |
1063 | for e in itertools.chain( | |
1064 | timeline_contains.values(), | |
1065 | previous.values(), | |
1066 | timeline_start.values(), | |
1067 | current.values(), | |
1116 | event_id_to_key = { | |
1117 | e: key | |
1118 | for key, e in itertools.chain( | |
1119 | timeline_contains.items(), | |
1120 | previous.items(), | |
1121 | timeline_start.items(), | |
1122 | current.items(), | |
1068 | 1123 | ) |
1069 | 1124 | } |
1070 | 1125 | |
1071 | c_ids = set(e.event_id for e in current.values()) | |
1072 | tc_ids = set(e.event_id for e in timeline_contains.values()) | |
1073 | p_ids = set(e.event_id for e in previous.values()) | |
1074 | ts_ids = set(e.event_id for e in timeline_start.values()) | |
1126 | c_ids = set(e for e in current.values()) | |
1127 | tc_ids = set(e for e in timeline_contains.values()) | |
1128 | p_ids = set(e for e in previous.values()) | |
1129 | ts_ids = set(e for e in timeline_start.values()) | |
1075 | 1130 | |
1076 | 1131 | state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids |
1077 | 1132 | |
1078 | evs = (event_id_to_state[e] for e in state_ids) | |
1079 | 1133 | return { |
1080 | (e.type, e.state_key): e | |
1081 | for e in evs | |
1134 | event_id_to_key[e]: e for e in state_ids | |
1082 | 1135 | } |
1083 | 1136 | |
1084 | 1137 | |
1102 | 1155 | self.joined = [] |
1103 | 1156 | self.invited = [] |
1104 | 1157 | self.archived = [] |
1158 | self.device = [] | |
1105 | 1159 | |
1106 | 1160 | |
1107 | 1161 | class RoomSyncResultBuilder(object): |
19 | 19 | PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, |
20 | 20 | ) |
21 | 21 | from synapse.util.metrics import Measure |
22 | from synapse.types import UserID | |
22 | from synapse.types import UserID, get_domain_from_id | |
23 | 23 | |
24 | 24 | import logging |
25 | 25 | |
41 | 41 | self.auth = hs.get_auth() |
42 | 42 | self.is_mine_id = hs.is_mine_id |
43 | 43 | self.notifier = hs.get_notifier() |
44 | self.state = hs.get_state_handler() | |
44 | 45 | |
45 | 46 | self.clock = hs.get_clock() |
46 | 47 | |
165 | 166 | |
166 | 167 | @defer.inlineCallbacks |
167 | 168 | def _push_update(self, room_id, user_id, typing): |
168 | domains = yield self.store.get_joined_hosts_for_room(room_id) | |
169 | users = yield self.state.get_current_user_in_room(room_id) | |
170 | domains = set(get_domain_from_id(u) for u in users) | |
169 | 171 | |
170 | 172 | deferreds = [] |
171 | 173 | for domain in domains: |
198 | 200 | # Check that the string is a valid user id |
199 | 201 | UserID.from_string(user_id) |
200 | 202 | |
201 | domains = yield self.store.get_joined_hosts_for_room(room_id) | |
203 | users = yield self.state.get_current_user_in_room(room_id) | |
204 | domains = set(get_domain_from_id(u) for u in users) | |
202 | 205 | |
203 | 206 | if self.server_name in domains: |
204 | 207 | self._push_update_local( |
422 | 422 | def _is_world_readable(self, room_id): |
423 | 423 | state = yield self.state_handler.get_current_state( |
424 | 424 | room_id, |
425 | EventTypes.RoomHistoryVisibility | |
425 | EventTypes.RoomHistoryVisibility, | |
426 | "", | |
426 | 427 | ) |
427 | 428 | if state and "history_visibility" in state.content: |
428 | 429 | defer.returnValue(state.content["history_visibility"] == "world_readable") |
39 | 39 | def handle_push_actions_for_event(self, event, context): |
40 | 40 | with Measure(self.clock, "evaluator_for_event"): |
41 | 41 | bulk_evaluator = yield evaluator_for_event( |
42 | event, self.hs, self.store, context.state_group, context.current_state | |
42 | event, self.hs, self.store, context | |
43 | 43 | ) |
44 | 44 | |
45 | 45 | with Measure(self.clock, "action_for_event_by_user"): |
46 | 46 | actions_by_user = yield bulk_evaluator.action_for_event_by_user( |
47 | event, context.current_state | |
47 | event, context | |
48 | 48 | ) |
49 | 49 | |
50 | 50 | context.push_actions = [ |
18 | 18 | |
19 | 19 | from .push_rule_evaluator import PushRuleEvaluatorForEvent |
20 | 20 | |
21 | from synapse.api.constants import EventTypes, Membership | |
22 | from synapse.visibility import filter_events_for_clients | |
21 | from synapse.api.constants import EventTypes | |
22 | from synapse.visibility import filter_events_for_clients_context | |
23 | 23 | |
24 | 24 | |
25 | 25 | logger = logging.getLogger(__name__) |
35 | 35 | |
36 | 36 | |
37 | 37 | @defer.inlineCallbacks |
38 | def evaluator_for_event(event, hs, store, state_group, current_state): | |
38 | def evaluator_for_event(event, hs, store, context): | |
39 | 39 | rules_by_user = yield store.bulk_get_push_rules_for_room( |
40 | event.room_id, state_group, current_state | |
40 | event, context | |
41 | 41 | ) |
42 | 42 | |
43 | 43 | # if this event is an invite event, we may need to run rules for the user |
71 | 71 | self.store = store |
72 | 72 | |
73 | 73 | @defer.inlineCallbacks |
74 | def action_for_event_by_user(self, event, current_state): | |
74 | def action_for_event_by_user(self, event, context): | |
75 | 75 | actions_by_user = {} |
76 | 76 | |
77 | 77 | # None of these users can be peeking since this list of users comes |
81 | 81 | (u, False) for u in self.rules_by_user.keys() |
82 | 82 | ] |
83 | 83 | |
84 | filtered_by_user = yield filter_events_for_clients( | |
85 | self.store, user_tuples, [event], {event.event_id: current_state} | |
84 | filtered_by_user = yield filter_events_for_clients_context( | |
85 | self.store, user_tuples, [event], {event.event_id: context} | |
86 | 86 | ) |
87 | 87 | |
88 | room_members = set( | |
89 | e.state_key for e in current_state.values() | |
90 | if e.type == EventTypes.Member and e.membership == Membership.JOIN | |
88 | room_members = yield self.store.get_joined_users_from_context( | |
89 | event, context | |
91 | 90 | ) |
92 | 91 | |
93 | 92 | evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) |
94 | 93 | |
95 | 94 | condition_cache = {} |
96 | 95 | |
97 | display_names = {} | |
98 | for ev in current_state.values(): | |
99 | nm = ev.content.get("displayname", None) | |
100 | if nm and ev.type == EventTypes.Member: | |
101 | display_names[ev.state_key] = nm | |
102 | ||
103 | 96 | for uid, rules in self.rules_by_user.items(): |
104 | display_name = display_names.get(uid, None) | |
97 | display_name = None | |
98 | member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) | |
99 | if member_ev_id: | |
100 | member_ev = yield self.store.get_event(member_ev_id, allow_none=True) | |
101 | if member_ev: | |
102 | display_name = member_ev.content.get("displayname", None) | |
105 | 103 | |
106 | 104 | filtered = filtered_by_user[uid] |
107 | 105 | if len(filtered) == 0: |
244 | 244 | @defer.inlineCallbacks |
245 | 245 | def _build_notification_dict(self, event, tweaks, badge): |
246 | 246 | ctx = yield push_tools.get_context_for_event( |
247 | self.state_handler, event, self.user_id | |
247 | self.store, self.state_handler, event, self.user_id | |
248 | 248 | ) |
249 | 249 | |
250 | 250 | d = { |
21 | 21 | from email.mime.multipart import MIMEMultipart |
22 | 22 | |
23 | 23 | from synapse.util.async import concurrently_execute |
24 | from synapse.util.presentable_names import ( | |
24 | from synapse.push.presentable_names import ( | |
25 | 25 | calculate_room_name, name_from_member_event, descriptor_from_member_events |
26 | 26 | ) |
27 | 27 | from synapse.types import UserID |
138 | 138 | |
139 | 139 | @defer.inlineCallbacks |
140 | 140 | def _fetch_room_state(room_id): |
141 | room_state = yield self.state_handler.get_current_state(room_id) | |
141 | room_state = yield self.state_handler.get_current_state_ids(room_id) | |
142 | 142 | state_by_room[room_id] = room_state |
143 | 143 | |
144 | 144 | # Run at most 3 of these at once: sync does 10 at a time but email |
158 | 158 | ) |
159 | 159 | rooms.append(roomvars) |
160 | 160 | |
161 | reason['room_name'] = calculate_room_name( | |
162 | state_by_room[reason['room_id']], user_id, fallback_to_members=True | |
163 | ) | |
164 | ||
165 | summary_text = self.make_summary_text( | |
161 | reason['room_name'] = yield calculate_room_name( | |
162 | self.store, state_by_room[reason['room_id']], user_id, | |
163 | fallback_to_members=True | |
164 | ) | |
165 | ||
166 | summary_text = yield self.make_summary_text( | |
166 | 167 | notifs_by_room, state_by_room, notif_events, user_id, reason |
167 | 168 | ) |
168 | 169 | |
202 | 203 | ) |
203 | 204 | |
204 | 205 | @defer.inlineCallbacks |
205 | def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state): | |
206 | my_member_event = room_state[("m.room.member", user_id)] | |
206 | def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): | |
207 | my_member_event_id = room_state_ids[("m.room.member", user_id)] | |
208 | my_member_event = yield self.store.get_event(my_member_event_id) | |
207 | 209 | is_invite = my_member_event.content["membership"] == "invite" |
208 | 210 | |
211 | room_name = yield calculate_room_name(self.store, room_state_ids, user_id) | |
212 | ||
209 | 213 | room_vars = { |
210 | "title": calculate_room_name(room_state, user_id), | |
214 | "title": room_name, | |
211 | 215 | "hash": string_ordinal_total(room_id), # See sender avatar hash |
212 | 216 | "notifs": [], |
213 | 217 | "invite": is_invite, |
217 | 221 | if not is_invite: |
218 | 222 | for n in notifs: |
219 | 223 | notifvars = yield self.get_notif_vars( |
220 | n, user_id, notif_events[n['event_id']], room_state | |
224 | n, user_id, notif_events[n['event_id']], room_state_ids | |
221 | 225 | ) |
222 | 226 | |
223 | 227 | # merge overlapping notifs together. |
242 | 246 | defer.returnValue(room_vars) |
243 | 247 | |
244 | 248 | @defer.inlineCallbacks |
245 | def get_notif_vars(self, notif, user_id, notif_event, room_state): | |
249 | def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): | |
246 | 250 | results = yield self.store.get_events_around( |
247 | 251 | notif['room_id'], notif['event_id'], |
248 | 252 | before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER |
260 | 264 | the_events.append(notif_event) |
261 | 265 | |
262 | 266 | for event in the_events: |
263 | messagevars = self.get_message_vars(notif, event, room_state) | |
267 | messagevars = yield self.get_message_vars(notif, event, room_state_ids) | |
264 | 268 | if messagevars is not None: |
265 | 269 | ret['messages'].append(messagevars) |
266 | 270 | |
267 | 271 | defer.returnValue(ret) |
268 | 272 | |
269 | def get_message_vars(self, notif, event, room_state): | |
273 | @defer.inlineCallbacks | |
274 | def get_message_vars(self, notif, event, room_state_ids): | |
270 | 275 | if event.type != EventTypes.Message: |
271 | return None | |
272 | ||
273 | sender_state_event = room_state[("m.room.member", event.sender)] | |
276 | return | |
277 | ||
278 | sender_state_event_id = room_state_ids[("m.room.member", event.sender)] | |
279 | sender_state_event = yield self.store.get_event(sender_state_event_id) | |
274 | 280 | sender_name = name_from_member_event(sender_state_event) |
275 | 281 | sender_avatar_url = sender_state_event.content.get("avatar_url") |
276 | 282 | |
298 | 304 | if "body" in event.content: |
299 | 305 | ret["body_text_plain"] = event.content["body"] |
300 | 306 | |
301 | return ret | |
307 | defer.returnValue(ret) | |
302 | 308 | |
303 | 309 | def add_text_message_vars(self, messagevars, event): |
304 | 310 | msgformat = event.content.get("format") |
320 | 326 | |
321 | 327 | return messagevars |
322 | 328 | |
329 | @defer.inlineCallbacks | |
323 | 330 | def make_summary_text(self, notifs_by_room, state_by_room, |
324 | 331 | notif_events, user_id, reason): |
325 | 332 | if len(notifs_by_room) == 1: |
329 | 336 | # If the room has some kind of name, use it, but we don't |
330 | 337 | # want the generated-from-names one here otherwise we'll |
331 | 338 | # end up with, "new message from Bob in the Bob room" |
332 | room_name = calculate_room_name( | |
333 | state_by_room[room_id], user_id, fallback_to_members=False | |
339 | room_name = yield calculate_room_name( | |
340 | self.store, state_by_room[room_id], user_id, fallback_to_members=False | |
334 | 341 | ) |
335 | 342 | |
336 | 343 | my_member_event = state_by_room[room_id][("m.room.member", user_id)] |
341 | 348 | inviter_name = name_from_member_event(inviter_member_event) |
342 | 349 | |
343 | 350 | if room_name is None: |
344 | return INVITE_FROM_PERSON % { | |
351 | defer.returnValue(INVITE_FROM_PERSON % { | |
345 | 352 | "person": inviter_name, |
346 | 353 | "app": self.app_name |
347 | } | |
354 | }) | |
348 | 355 | else: |
349 | return INVITE_FROM_PERSON_TO_ROOM % { | |
356 | defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { | |
350 | 357 | "person": inviter_name, |
351 | 358 | "room": room_name, |
352 | 359 | "app": self.app_name, |
353 | } | |
360 | }) | |
354 | 361 | |
355 | 362 | sender_name = None |
356 | 363 | if len(notifs_by_room[room_id]) == 1: |
361 | 368 | sender_name = name_from_member_event(state_event) |
362 | 369 | |
363 | 370 | if sender_name is not None and room_name is not None: |
364 | return MESSAGE_FROM_PERSON_IN_ROOM % { | |
371 | defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { | |
365 | 372 | "person": sender_name, |
366 | 373 | "room": room_name, |
367 | 374 | "app": self.app_name, |
368 | } | |
375 | }) | |
369 | 376 | elif sender_name is not None: |
370 | return MESSAGE_FROM_PERSON % { | |
377 | defer.returnValue(MESSAGE_FROM_PERSON % { | |
371 | 378 | "person": sender_name, |
372 | 379 | "app": self.app_name, |
373 | } | |
380 | }) | |
374 | 381 | else: |
375 | 382 | # There's more than one notification for this room, so just |
376 | 383 | # say there are several |
377 | 384 | if room_name is not None: |
378 | return MESSAGES_IN_ROOM % { | |
385 | defer.returnValue(MESSAGES_IN_ROOM % { | |
379 | 386 | "room": room_name, |
380 | 387 | "app": self.app_name, |
381 | } | |
388 | }) | |
382 | 389 | else: |
383 | 390 | # If the room doesn't have a name, say who the messages |
384 | 391 | # are from explicitly to avoid, "messages in the Bob room" |
387 | 394 | for n in notifs_by_room[room_id] |
388 | 395 | ])) |
389 | 396 | |
390 | return MESSAGES_FROM_PERSON % { | |
397 | defer.returnValue(MESSAGES_FROM_PERSON % { | |
391 | 398 | "person": descriptor_from_member_events([ |
392 | 399 | state_by_room[room_id][("m.room.member", s)] |
393 | 400 | for s in sender_ids |
394 | 401 | ]), |
395 | 402 | "app": self.app_name, |
396 | } | |
403 | }) | |
397 | 404 | else: |
398 | 405 | # Stuff's happened in multiple different rooms |
399 | 406 | |
400 | 407 | # ...but we still refer to the 'reason' room which triggered the mail |
401 | 408 | if reason['room_name'] is not None: |
402 | return MESSAGES_IN_ROOM_AND_OTHERS % { | |
409 | defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { | |
403 | 410 | "room": reason['room_name'], |
404 | 411 | "app": self.app_name, |
405 | } | |
412 | }) | |
406 | 413 | else: |
407 | 414 | # If the reason room doesn't have a name, say who the messages |
408 | 415 | # are from explicitly to avoid, "messages in the Bob room" |
411 | 418 | for n in notifs_by_room[reason['room_id']] |
412 | 419 | ])) |
413 | 420 | |
414 | return MESSAGES_FROM_PERSON_AND_OTHERS % { | |
421 | defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { | |
415 | 422 | "person": descriptor_from_member_events([ |
416 | 423 | state_by_room[reason['room_id']][("m.room.member", s)] |
417 | 424 | for s in sender_ids |
418 | 425 | ]), |
419 | 426 | "app": self.app_name, |
420 | } | |
427 | }) | |
421 | 428 | |
422 | 429 | def make_room_link(self, room_id): |
423 | 430 | # need /beta for Universal Links to work on iOS |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import re | |
18 | import logging | |
19 | ||
20 | logger = logging.getLogger(__name__) | |
21 | ||
22 | # intentionally looser than what aliases we allow to be registered since | |
23 | # other HSes may allow aliases that we would not | |
24 | ALIAS_RE = re.compile(r"^#.*:.+$") | |
25 | ||
26 | ALL_ALONE = "Empty Room" | |
27 | ||
28 | ||
@defer.inlineCallbacks
def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
                        fallback_to_single_member=True):
    """
    Works out a user-facing name for the given room as per Matrix
    spec recommendations.
    Does not yet support internationalisation.
    Args:
        store: DataStore used to fetch the state events referenced by
            `room_state_ids`.
        room_state_ids: Dictionary of the room's state, mapping
            (event_type, state_key) tuples to event IDs.
        user_id: The ID of the user to whom the room name is being presented
        fallback_to_members: If False, return None instead of generating a name
                             based on the room's members if the room has no
                             title or aliases.
        fallback_to_single_member: If False, return None instead of a name
            derived from a single other member (e.g. an "Invite from ..."
            label or a one-to-one chat name).

    Returns:
        Deferred[string or None]: A human readable name for the room.
    """
    # does it have a name?
    if ("m.room.name", "") in room_state_ids:
        m_room_name = yield store.get_event(
            room_state_ids[("m.room.name", "")], allow_none=True
        )
        if m_room_name and m_room_name.content and m_room_name.content["name"]:
            defer.returnValue(m_room_name.content["name"])

    # does it have a canonical alias?
    if ("m.room.canonical_alias", "") in room_state_ids:
        canon_alias = yield store.get_event(
            room_state_ids[("m.room.canonical_alias", "")], allow_none=True
        )
        if (
            canon_alias and canon_alias.content and canon_alias.content["alias"] and
            _looks_like_an_alias(canon_alias.content["alias"])
        ):
            defer.returnValue(canon_alias.content["alias"])

    # at this point we're going to need to search the state by all state keys
    # for an event type, so rearrange the data structure
    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)

    # right then, any aliases at all?
    if "m.room.aliases" in room_state_bytype_ids:
        m_room_aliases = room_state_bytype_ids["m.room.aliases"]
        for alias_id in m_room_aliases.values():
            alias_event = yield store.get_event(
                alias_id, allow_none=True
            )
            if alias_event and alias_event.content.get("aliases"):
                the_aliases = alias_event.content["aliases"]
                if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
                    defer.returnValue(the_aliases[0])

    if not fallback_to_members:
        defer.returnValue(None)

    my_member_event = None
    if ("m.room.member", user_id) in room_state_ids:
        my_member_event = yield store.get_event(
            room_state_ids[("m.room.member", user_id)], allow_none=True
        )

    # If we are invited, name the room after whoever invited us.
    if (
        my_member_event is not None and
        my_member_event.content['membership'] == "invite"
    ):
        if ("m.room.member", my_member_event.sender) in room_state_ids:
            inviter_member_event = yield store.get_event(
                room_state_ids[("m.room.member", my_member_event.sender)],
                allow_none=True,
            )
            if inviter_member_event:
                if fallback_to_single_member:
                    defer.returnValue(
                        "Invite from %s" % (
                            name_from_member_event(inviter_member_event),
                        )
                    )
                else:
                    # bare return resolves the Deferred with None
                    return
        else:
            defer.returnValue("Room Invite")

    # we're going to have to generate a name based on who's in the room,
    # so find out who is in the room that isn't the user.
    if "m.room.member" in room_state_bytype_ids:
        member_events = yield store.get_events(
            room_state_bytype_ids["m.room.member"].values()
        )
        all_members = [
            ev for ev in member_events.values()
            if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
        ]
        # Sort the member events oldest-first so that we name people in the
        # order they joined (it should at least be deterministic rather than
        # dictionary iteration order)
        all_members.sort(key=lambda e: e.origin_server_ts)
        other_members = [m for m in all_members if m.state_key != user_id]
    else:
        other_members = []
        all_members = []

    if len(other_members) == 0:
        if len(all_members) == 1:
            # self-chat, peeked room with 1 participant,
            # or inbound invite, or outbound 3PID invite.
            if all_members[0].sender == user_id:
                if "m.room.third_party_invite" in room_state_bytype_ids:
                    third_party_invites = (
                        room_state_bytype_ids["m.room.third_party_invite"].values()
                    )

                    if len(third_party_invites) > 0:
                        # technically third party invite events are not member
                        # events, but they are close enough

                        # FIXME: no they're not - they look nothing like a member;
                        # they have a great big encrypted thing as their name to
                        # prevent leaking the 3PID name...
                        # return "Inviting %s" % (
                        #     descriptor_from_member_events(third_party_invites)
                        # )
                        defer.returnValue("Inviting email address")
                    else:
                        defer.returnValue(ALL_ALONE)
            else:
                defer.returnValue(name_from_member_event(all_members[0]))
        else:
            defer.returnValue(ALL_ALONE)
    elif len(other_members) == 1 and not fallback_to_single_member:
        # One-to-one chat name suppressed by the caller.
        return
    else:
        defer.returnValue(descriptor_from_member_events(other_members))
161 | ||
162 | ||
def descriptor_from_member_events(member_events):
    """Summarise a list of member events as a short human-readable phrase.

    Produces "nobody", a single name, "A and B", or "A and N others"
    depending on how many members are given.
    """
    count = len(member_events)
    if count == 0:
        return "nobody"
    first_name = name_from_member_event(member_events[0])
    if count == 1:
        return first_name
    if count == 2:
        second_name = name_from_member_event(member_events[1])
        return "%s and %s" % (first_name, second_name)
    return "%s and %d others" % (first_name, count - 1)
178 | ||
179 | ||
def name_from_member_event(member_event):
    """Return the member's display name, falling back to their user ID.

    The fallback (the event's state_key) is used when the membership event
    has no content or its "displayname" is missing or empty.
    """
    content = member_event.content
    if content and content.get("displayname"):
        return content["displayname"]
    return member_event.state_key
187 | ||
188 | ||
189 | def _state_as_two_level_dict(state): | |
190 | ret = {} | |
191 | for k, v in state.items(): | |
192 | ret.setdefault(k[0], {})[k[1]] = v | |
193 | return ret | |
194 | ||
195 | ||
def _looks_like_an_alias(string):
    """True if the string is shaped like a room alias ("#...:...")."""
    return bool(ALIAS_RE.match(string))
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from twisted.internet import defer |
16 | from synapse.util.presentable_names import ( | |
16 | from synapse.push.presentable_names import ( | |
17 | 17 | calculate_room_name, name_from_member_event |
18 | 18 | ) |
19 | 19 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred |
48 | 48 | |
49 | 49 | |
50 | 50 | @defer.inlineCallbacks |
51 | def get_context_for_event(state_handler, ev, user_id): | |
51 | def get_context_for_event(store, state_handler, ev, user_id): | |
52 | 52 | ctx = {} |
53 | 53 | |
54 | room_state = yield state_handler.get_current_state(ev.room_id) | |
54 | room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) | |
55 | 55 | |
56 | 56 | # we no longer bother setting room_alias, and make room_name the |
57 | 57 | # human-readable name instead, be that m.room.name, an alias or |
58 | 58 | # a list of people in the room |
59 | name = calculate_room_name( | |
60 | room_state, user_id, fallback_to_single_member=False | |
59 | name = yield calculate_room_name( | |
60 | store, room_state_ids, user_id, fallback_to_single_member=False | |
61 | 61 | ) |
62 | 62 | if name: |
63 | 63 | ctx['name'] = name |
64 | 64 | |
65 | sender_state_event = room_state[("m.room.member", ev.sender)] | |
65 | sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] | |
66 | sender_state_event = yield store.get_event(sender_state_event_id) | |
66 | 67 | ctx['sender_display_name'] = name_from_member_event(sender_state_event) |
67 | 68 | |
68 | 69 | defer.returnValue(ctx) |
39 | 39 | ("backfill",), |
40 | 40 | ("push_rules",), |
41 | 41 | ("pushers",), |
42 | ("state",), | |
43 | 42 | ("caches",), |
43 | ("to_device",), | |
44 | 44 | ) |
45 | 45 | |
46 | 46 | |
129 | 129 | backfill_token = yield self.store.get_current_backfill_token() |
130 | 130 | push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() |
131 | 131 | pushers_token = self.store.get_pushers_stream_token() |
132 | state_token = self.store.get_state_stream_token() | |
133 | 132 | caches_token = self.store.get_cache_stream_token() |
134 | 133 | |
135 | 134 | defer.returnValue(_ReplicationToken( |
141 | 140 | backfill_token, |
142 | 141 | push_rules_token, |
143 | 142 | pushers_token, |
144 | state_token, | |
143 | 0, # State stream is no longer a thing | |
145 | 144 | caches_token, |
145 | int(stream_token.to_device_key), | |
146 | 146 | )) |
147 | 147 | |
148 | 148 | @request_handler() |
190 | 190 | yield self.receipts(writer, current_token, limit, request_streams) |
191 | 191 | yield self.push_rules(writer, current_token, limit, request_streams) |
192 | 192 | yield self.pushers(writer, current_token, limit, request_streams) |
193 | yield self.state(writer, current_token, limit, request_streams) | |
194 | 193 | yield self.caches(writer, current_token, limit, request_streams) |
194 | yield self.to_device(writer, current_token, limit, request_streams) | |
195 | 195 | self.streams(writer, current_token, request_streams) |
196 | 196 | |
197 | 197 | logger.info("Replicated %d rows", writer.total) |
365 | 365 | )) |
366 | 366 | |
367 | 367 | @defer.inlineCallbacks |
368 | def state(self, writer, current_token, limit, request_streams): | |
369 | current_position = current_token.state | |
370 | ||
371 | state = request_streams.get("state") | |
372 | ||
373 | if state is not None: | |
374 | state_groups, state_group_state = ( | |
375 | yield self.store.get_all_new_state_groups( | |
376 | state, current_position, limit | |
377 | ) | |
378 | ) | |
379 | writer.write_header_and_rows("state_groups", state_groups, ( | |
380 | "position", "room_id", "event_id" | |
381 | )) | |
382 | writer.write_header_and_rows("state_group_state", state_group_state, ( | |
383 | "position", "type", "state_key", "event_id" | |
384 | )) | |
385 | ||
386 | @defer.inlineCallbacks | |
387 | 368 | def caches(self, writer, current_token, limit, request_streams): |
388 | 369 | current_position = current_token.caches |
389 | 370 | |
395 | 376 | ) |
396 | 377 | writer.write_header_and_rows("caches", updated_caches, ( |
397 | 378 | "position", "cache_func", "keys", "invalidation_ts" |
379 | )) | |
380 | ||
381 | @defer.inlineCallbacks | |
382 | def to_device(self, writer, current_token, limit, request_streams): | |
383 | current_position = current_token.to_device | |
384 | ||
385 | to_device = request_streams.get("to_device") | |
386 | ||
387 | if to_device is not None: | |
388 | to_device_rows = yield self.store.get_all_new_device_messages( | |
389 | to_device, current_position, limit | |
390 | ) | |
391 | writer.write_header_and_rows("to_device", to_device_rows, ( | |
392 | "position", "user_id", "device_id", "message_json" | |
398 | 393 | )) |
399 | 394 | |
400 | 395 | |
425 | 420 | |
426 | 421 | class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( |
427 | 422 | "events", "presence", "typing", "receipts", "account_data", "backfill", |
428 | "push_rules", "pushers", "state", "caches", | |
423 | "push_rules", "pushers", "state", "caches", "to_device", | |
429 | 424 | ))): |
430 | 425 | __slots__ = [] |
431 | 426 |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from ._base import BaseSlavedStore | |
16 | from ._slaved_id_tracker import SlavedIdTracker | |
17 | from synapse.storage import DataStore | |
18 | ||
19 | ||
20 | class SlavedDeviceInboxStore(BaseSlavedStore): | |
21 | def __init__(self, db_conn, hs): | |
22 | super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) | |
23 | self._device_inbox_id_gen = SlavedIdTracker( | |
24 | db_conn, "device_inbox", "stream_id", | |
25 | ) | |
26 | ||
27 | get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ | |
28 | get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ | |
29 | delete_messages_for_device = DataStore.delete_messages_for_device.__func__ | |
30 | ||
31 | def stream_positions(self): | |
32 | result = super(SlavedDeviceInboxStore, self).stream_positions() | |
33 | result["to_device"] = self._device_inbox_id_gen.get_current_token() | |
34 | return result | |
35 | ||
36 | def process_replication(self, result): | |
37 | stream = result.get("to_device") | |
38 | if stream: | |
39 | self._device_inbox_id_gen.advance(int(stream["position"])) | |
40 | ||
41 | return super(SlavedDeviceInboxStore, self).process_replication(result) |
119 | 119 | get_state_for_event = DataStore.get_state_for_event.__func__ |
120 | 120 | get_state_for_events = DataStore.get_state_for_events.__func__ |
121 | 121 | get_state_groups = DataStore.get_state_groups.__func__ |
122 | get_state_groups_ids = DataStore.get_state_groups_ids.__func__ | |
123 | get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ | |
124 | get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ | |
125 | get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ | |
126 | get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ | |
127 | _get_joined_users_from_context = ( | |
128 | RoomMemberStore.__dict__["_get_joined_users_from_context"] | |
129 | ) | |
130 | ||
122 | 131 | get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ |
123 | 132 | get_room_events_stream_for_rooms = ( |
124 | 133 | DataStore.get_room_events_stream_for_rooms.__func__ |
125 | 134 | ) |
135 | is_host_joined = DataStore.is_host_joined.__func__ | |
136 | _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"] | |
126 | 137 | get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ |
127 | 138 | |
128 | 139 | _set_before_and_after = staticmethod(DataStore._set_before_and_after) |
210 | 221 | self._get_current_state_for_key.invalidate_all() |
211 | 222 | self.get_rooms_for_user.invalidate_all() |
212 | 223 | self.get_users_in_room.invalidate((event.room_id,)) |
213 | # self.get_joined_hosts_for_room.invalidate((event.room_id,)) | |
214 | 224 | |
215 | 225 | self._invalidate_get_event_cache(event.event_id) |
216 | 226 | |
234 | 244 | |
235 | 245 | if event.type == EventTypes.Member: |
236 | 246 | self.get_rooms_for_user.invalidate((event.state_key,)) |
237 | # self.get_joined_hosts_for_room.invalidate((event.room_id,)) | |
238 | 247 | self.get_users_in_room.invalidate((event.room_id,)) |
239 | 248 | self._membership_stream_cache.entity_has_changed( |
240 | 249 | event.state_key, event.internal_metadata.stream_ordering |
48 | 48 | notifications, |
49 | 49 | devices, |
50 | 50 | thirdparty, |
51 | sendtodevice, | |
51 | 52 | ) |
52 | 53 | |
53 | 54 | from synapse.http.server import JsonResource |
95 | 96 | notifications.register_servlets(hs, client_resource) |
96 | 97 | devices.register_servlets(hs, client_resource) |
97 | 98 | thirdparty.register_servlets(hs, client_resource) |
99 | sendtodevice.register_servlets(hs, client_resource) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import logging | |
16 | ||
17 | from twisted.internet import defer | |
18 | from synapse.http.servlet import parse_json_object_from_request | |
19 | ||
20 | from synapse.http import servlet | |
21 | from synapse.rest.client.v1.transactions import HttpTransactionStore | |
22 | from ._base import client_v2_patterns | |
23 | ||
24 | logger = logging.getLogger(__name__) | |
25 | ||
26 | ||
27 | class SendToDeviceRestServlet(servlet.RestServlet): | |
28 | PATTERNS = client_v2_patterns( | |
29 | "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", | |
30 | releases=[], v2_alpha=False | |
31 | ) | |
32 | ||
33 | def __init__(self, hs): | |
34 | """ | |
35 | Args: | |
36 | hs (synapse.server.HomeServer): server | |
37 | """ | |
38 | super(SendToDeviceRestServlet, self).__init__() | |
39 | self.hs = hs | |
40 | self.auth = hs.get_auth() | |
41 | self.store = hs.get_datastore() | |
42 | self.notifier = hs.get_notifier() | |
43 | self.is_mine_id = hs.is_mine_id | |
44 | self.txns = HttpTransactionStore() | |
45 | ||
46 | @defer.inlineCallbacks | |
47 | def on_PUT(self, request, message_type, txn_id): | |
48 | try: | |
49 | defer.returnValue( | |
50 | self.txns.get_client_transaction(request, txn_id) | |
51 | ) | |
52 | except KeyError: | |
53 | pass | |
54 | ||
55 | requester = yield self.auth.get_user_by_req(request) | |
56 | ||
57 | content = parse_json_object_from_request(request) | |
58 | ||
59 | # TODO: Prod the notifier to wake up sync streams. | |
60 | # TODO: Implement replication for the messages. | |
61 | # TODO: Send the messages to remote servers if needed. | |
62 | ||
63 | local_messages = {} | |
64 | for user_id, by_device in content["messages"].items(): | |
65 | if self.is_mine_id(user_id): | |
66 | messages_by_device = { | |
67 | device_id: { | |
68 | "content": message_content, | |
69 | "type": message_type, | |
70 | "sender": requester.user.to_string(), | |
71 | } | |
72 | for device_id, message_content in by_device.items() | |
73 | } | |
74 | if messages_by_device: | |
75 | local_messages[user_id] = messages_by_device | |
76 | ||
77 | stream_id = yield self.store.add_messages_to_device_inbox(local_messages) | |
78 | ||
79 | self.notifier.on_new_event( | |
80 | "to_device_key", stream_id, users=local_messages.keys() | |
81 | ) | |
82 | ||
83 | response = (200, {}) | |
84 | self.txns.store_client_transaction(request, txn_id, response) | |
85 | defer.returnValue(response) | |
86 | ||
87 | ||
88 | def register_servlets(hs, http_server): | |
89 | SendToDeviceRestServlet(hs).register(http_server) |
96 | 96 | request, allow_guest=True |
97 | 97 | ) |
98 | 98 | user = requester.user |
99 | device_id = requester.device_id | |
99 | 100 | |
100 | 101 | timeout = parse_integer(request, "timeout", default=0) |
101 | 102 | since = parse_string(request, "since") |
108 | 109 | |
109 | 110 | logger.info( |
110 | 111 | "/sync: user=%r, timeout=%r, since=%r," |
111 | " set_presence=%r, filter_id=%r" % ( | |
112 | user, timeout, since, set_presence, filter_id | |
113 | ) | |
114 | ) | |
115 | ||
116 | request_key = (user, timeout, since, filter_id, full_state) | |
112 | " set_presence=%r, filter_id=%r, device_id=%r" % ( | |
113 | user, timeout, since, set_presence, filter_id, device_id | |
114 | ) | |
115 | ) | |
116 | ||
117 | request_key = (user, timeout, since, filter_id, full_state, device_id) | |
117 | 118 | |
118 | 119 | if filter_id: |
119 | 120 | if filter_id.startswith('{'): |
135 | 136 | filter_collection=filter, |
136 | 137 | is_guest=requester.is_guest, |
137 | 138 | request_key=request_key, |
139 | device_id=device_id, | |
138 | 140 | ) |
139 | 141 | |
140 | 142 | if since is not None: |
172 | 174 | |
173 | 175 | response_content = { |
174 | 176 | "account_data": {"events": sync_result.account_data}, |
177 | "to_device": {"events": sync_result.to_device}, | |
175 | 178 | "presence": self.encode_presence( |
176 | 179 | sync_result.presence, time_now |
177 | 180 | ), |
17 | 17 | |
18 | 18 | from twisted.internet import defer |
19 | 19 | |
20 | from synapse.api.constants import ThirdPartyEntityKind | |
20 | 21 | from synapse.http.servlet import RestServlet |
21 | from synapse.types import ThirdPartyEntityKind | |
22 | 22 | from ._base import client_v2_patterns |
23 | 23 | |
24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 | |
26 | 26 | |
27 | class ThirdPartyProtocolsServlet(RestServlet): | |
28 | PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) | |
29 | ||
30 | def __init__(self, hs): | |
31 | super(ThirdPartyProtocolsServlet, self).__init__() | |
32 | ||
33 | self.auth = hs.get_auth() | |
34 | self.appservice_handler = hs.get_application_service_handler() | |
35 | ||
36 | @defer.inlineCallbacks | |
37 | def on_GET(self, request): | |
38 | yield self.auth.get_user_by_req(request) | |
39 | ||
40 | protocols = yield self.appservice_handler.get_3pe_protocols() | |
41 | defer.returnValue((200, protocols)) | |
42 | ||
43 | ||
27 | 44 | class ThirdPartyUserServlet(RestServlet): |
28 | PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", | |
45 | PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", | |
29 | 46 | releases=()) |
30 | 47 | |
31 | 48 | def __init__(self, hs): |
49 | 66 | |
50 | 67 | |
51 | 68 | class ThirdPartyLocationServlet(RestServlet): |
52 | PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", | |
69 | PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", | |
53 | 70 | releases=()) |
54 | 71 | |
55 | 72 | def __init__(self, hs): |
73 | 90 | |
74 | 91 | |
75 | 92 | def register_servlets(hs, http_server): |
93 | ThirdPartyProtocolsServlet(hs).register(http_server) | |
76 | 94 | ThirdPartyUserServlet(hs).register(http_server) |
77 | 95 | ThirdPartyLocationServlet(hs).register(http_server) |
22 | 22 | from synapse.api.errors import AuthError |
23 | 23 | from synapse.api.auth import AuthEventTypes |
24 | 24 | from synapse.events.snapshot import EventContext |
25 | from synapse.util.async import Linearizer | |
25 | 26 | |
26 | 27 | from collections import namedtuple |
27 | 28 | |
42 | 43 | EVICTION_TIMEOUT_SECONDS = 60 * 60 |
43 | 44 | |
44 | 45 | |
46 | _NEXT_STATE_ID = 1 | |
47 | ||
48 | ||
49 | def _gen_state_id(): | |
50 | global _NEXT_STATE_ID | |
51 | s = "X%d" % (_NEXT_STATE_ID,) | |
52 | _NEXT_STATE_ID += 1 | |
53 | return s | |
54 | ||
55 | ||
45 | 56 | class _StateCacheEntry(object): |
46 | def __init__(self, state, state_group, ts): | |
57 | __slots__ = ["state", "state_group", "state_id"] | |
58 | ||
59 | def __init__(self, state, state_group): | |
47 | 60 | self.state = state |
48 | 61 | self.state_group = state_group |
62 | ||
63 | # The `state_id` is a unique ID we generate that can be used as ID for | |
64 | # this collection of state. Usually this would be the same as the | |
65 | # state group, but on worker instances we can't generate a new state | |
66 | # group each time we resolve state, so we generate a separate one that | |
67 | # isn't persisted and is used solely for caches. | |
68 | # `state_id` is either a state_group (and so an int) or a string. This | |
69 | # ensures we don't accidentally persist a state_id as a state_group | 
70 | if state_group: | |
71 | self.state_id = state_group | |
72 | else: | |
73 | self.state_id = _gen_state_id() | |
49 | 74 | |
50 | 75 | |
51 | 76 | class StateHandler(object): |
59 | 84 | |
60 | 85 | # dict of set of event_ids -> _StateCacheEntry. |
61 | 86 | self._state_cache = None |
87 | self.resolve_linearizer = Linearizer() | |
62 | 88 | |
63 | 89 | def start_caching(self): |
64 | 90 | logger.debug("start_caching") |
92 | 118 | if not latest_event_ids: |
93 | 119 | latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) |
94 | 120 | |
95 | res = yield self.resolve_state_groups(room_id, latest_event_ids) | |
96 | state = res[1] | |
121 | ret = yield self.resolve_state_groups(room_id, latest_event_ids) | |
122 | state = ret.state | |
123 | ||
124 | if event_type: | |
125 | event_id = state.get((event_type, state_key)) | |
126 | event = None | |
127 | if event_id: | |
128 | event = yield self.store.get_event(event_id, allow_none=True) | |
129 | defer.returnValue(event) | |
130 | return | |
131 | ||
132 | state_map = yield self.store.get_events(state.values(), get_prev_content=False) | |
133 | state = { | |
134 | key: state_map[e_id] for key, e_id in state.items() if e_id in state_map | |
135 | } | |
136 | ||
137 | defer.returnValue(state) | |
138 | ||
139 | @defer.inlineCallbacks | |
140 | def get_current_state_ids(self, room_id, event_type=None, state_key="", | |
141 | latest_event_ids=None): | |
142 | if not latest_event_ids: | |
143 | latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) | |
144 | ||
145 | ret = yield self.resolve_state_groups(room_id, latest_event_ids) | |
146 | state = ret.state | |
97 | 147 | |
98 | 148 | if event_type: |
99 | 149 | defer.returnValue(state.get((event_type, state_key))) |
100 | 150 | return |
101 | 151 | |
102 | 152 | defer.returnValue(state) |
153 | ||
154 | @defer.inlineCallbacks | |
155 | def get_current_user_in_room(self, room_id): | |
156 | latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) | |
157 | entry = yield self.resolve_state_groups(room_id, latest_event_ids) | |
158 | joined_users = yield self.store.get_joined_users_from_state( | |
159 | room_id, entry.state_id, entry.state | |
160 | ) | |
161 | defer.returnValue(joined_users) | |
103 | 162 | |
104 | 163 | @defer.inlineCallbacks |
105 | 164 | def compute_event_context(self, event, old_state=None): |
122 | 181 | # state. Certainly store.get_current_state won't return any, and |
123 | 182 | # persisting the event won't store the state group. |
124 | 183 | if old_state: |
125 | context.current_state = { | |
126 | (s.type, s.state_key): s for s in old_state | |
184 | context.prev_state_ids = { | |
185 | (s.type, s.state_key): s.event_id for s in old_state | |
127 | 186 | } |
187 | if event.is_state(): | |
188 | context.current_state_events = dict(context.prev_state_ids) | |
189 | key = (event.type, event.state_key) | |
190 | context.current_state_events[key] = event.event_id | |
191 | else: | |
192 | context.current_state_events = context.prev_state_ids | |
128 | 193 | else: |
129 | context.current_state = {} | |
194 | context.current_state_ids = {} | |
195 | context.prev_state_ids = {} | |
130 | 196 | context.prev_state_events = [] |
131 | context.state_group = None | |
197 | context.state_group = self.store.get_next_state_group() | |
132 | 198 | defer.returnValue(context) |
133 | 199 | |
134 | 200 | if old_state: |
135 | context.current_state = { | |
136 | (s.type, s.state_key): s for s in old_state | |
201 | context.prev_state_ids = { | |
202 | (s.type, s.state_key): s.event_id for s in old_state | |
137 | 203 | } |
138 | context.state_group = None | |
204 | context.state_group = self.store.get_next_state_group() | |
139 | 205 | |
140 | 206 | if event.is_state(): |
141 | 207 | key = (event.type, event.state_key) |
142 | if key in context.current_state: | |
143 | replaces = context.current_state[key] | |
144 | if replaces.event_id != event.event_id: # Paranoia check | |
145 | event.unsigned["replaces_state"] = replaces.event_id | |
208 | if key in context.prev_state_ids: | |
209 | replaces = context.prev_state_ids[key] | |
210 | if replaces != event.event_id: # Paranoia check | |
211 | event.unsigned["replaces_state"] = replaces | |
212 | context.current_state_ids = dict(context.prev_state_ids) | |
213 | context.current_state_ids[key] = event.event_id | |
214 | else: | |
215 | context.current_state_ids = context.prev_state_ids | |
146 | 216 | |
147 | 217 | context.prev_state_events = [] |
148 | 218 | defer.returnValue(context) |
149 | 219 | |
150 | 220 | if event.is_state(): |
151 | ret = yield self.resolve_state_groups( | |
221 | entry = yield self.resolve_state_groups( | |
152 | 222 | event.room_id, [e for e, _ in event.prev_events], |
153 | 223 | event_type=event.type, |
154 | 224 | state_key=event.state_key, |
155 | 225 | ) |
156 | 226 | else: |
157 | ret = yield self.resolve_state_groups( | |
227 | entry = yield self.resolve_state_groups( | |
158 | 228 | event.room_id, [e for e, _ in event.prev_events], |
159 | 229 | ) |
160 | 230 | |
161 | group, curr_state, prev_state = ret | |
162 | ||
163 | context.current_state = curr_state | |
164 | context.state_group = group if not event.is_state() else None | |
231 | curr_state = entry.state | |
232 | ||
233 | context.prev_state_ids = curr_state | |
234 | if event.is_state(): | |
235 | context.state_group = self.store.get_next_state_group() | |
236 | else: | |
237 | if entry.state_group is None: | |
238 | entry.state_group = self.store.get_next_state_group() | |
239 | entry.state_id = entry.state_group | |
240 | context.state_group = entry.state_group | |
165 | 241 | |
166 | 242 | if event.is_state(): |
167 | 243 | key = (event.type, event.state_key) |
168 | if key in context.current_state: | |
169 | replaces = context.current_state[key] | |
170 | event.unsigned["replaces_state"] = replaces.event_id | |
171 | ||
172 | context.prev_state_events = prev_state | |
244 | if key in context.prev_state_ids: | |
245 | replaces = context.prev_state_ids[key] | |
246 | event.unsigned["replaces_state"] = replaces | |
247 | context.current_state_ids = dict(context.prev_state_ids) | |
248 | context.current_state_ids[key] = event.event_id | |
249 | else: | |
250 | context.current_state_ids = context.prev_state_ids | |
251 | ||
252 | context.prev_state_events = [] | |
173 | 253 | defer.returnValue(context) |
174 | 254 | |
175 | 255 | @defer.inlineCallbacks |
186 | 266 | """ |
187 | 267 | logger.debug("resolve_state_groups event_ids %s", event_ids) |
188 | 268 | |
189 | state_groups = yield self.store.get_state_groups( | |
269 | state_groups_ids = yield self.store.get_state_groups_ids( | |
190 | 270 | room_id, event_ids |
191 | 271 | ) |
192 | 272 | |
193 | 273 | logger.debug( |
194 | 274 | "resolve_state_groups state_groups %s", |
195 | state_groups.keys() | |
275 | state_groups_ids.keys() | |
196 | 276 | ) |
197 | 277 | |
198 | group_names = frozenset(state_groups.keys()) | |
278 | group_names = frozenset(state_groups_ids.keys()) | |
199 | 279 | if len(group_names) == 1: |
200 | name, state_list = state_groups.items().pop() | |
201 | state = { | |
202 | (e.type, e.state_key): e | |
203 | for e in state_list | |
280 | name, state_list = state_groups_ids.items().pop() | |
281 | ||
282 | defer.returnValue(_StateCacheEntry( | |
283 | state=state_list, | |
284 | state_group=name, | |
285 | )) | |
286 | ||
287 | with (yield self.resolve_linearizer.queue(group_names)): | |
288 | if self._state_cache is not None: | |
289 | cache = self._state_cache.get(group_names, None) | |
290 | if cache: | |
291 | defer.returnValue(cache) | |
292 | ||
293 | logger.info( | |
294 | "Resolving state for %s with %d groups", room_id, len(state_groups_ids) | |
295 | ) | |
296 | ||
297 | state = {} | |
298 | for st in state_groups_ids.values(): | |
299 | for key, e_id in st.items(): | |
300 | state.setdefault(key, set()).add(e_id) | |
301 | ||
302 | conflicted_state = { | |
303 | k: list(v) | |
304 | for k, v in state.items() | |
305 | if len(v) > 1 | |
204 | 306 | } |
205 | prev_state = state.get((event_type, state_key), None) | |
206 | if prev_state: | |
207 | prev_state = prev_state.event_id | |
208 | prev_states = [prev_state] | |
307 | ||
308 | if conflicted_state: | |
309 | logger.info("Resolving conflicted state for %r", room_id) | |
310 | state_map = yield self.store.get_events( | |
311 | [e_id for st in state_groups_ids.values() for e_id in st.values()], | |
312 | get_prev_content=False | |
313 | ) | |
314 | state_sets = [ | |
315 | [state_map[e_id] for key, e_id in st.items() if e_id in state_map] | |
316 | for st in state_groups_ids.values() | |
317 | ] | |
318 | new_state, _ = self._resolve_events( | |
319 | state_sets, event_type, state_key | |
320 | ) | |
321 | new_state = { | |
322 | key: e.event_id for key, e in new_state.items() | |
323 | } | |
209 | 324 | else: |
210 | prev_states = [] | |
211 | ||
212 | defer.returnValue((name, state, prev_states)) | |
213 | ||
214 | if self._state_cache is not None: | |
215 | cache = self._state_cache.get(group_names, None) | |
216 | if cache: | |
217 | cache.ts = self.clock.time_msec() | |
218 | ||
219 | event_dict = yield self.store.get_events(cache.state.values()) | |
220 | state = {(e.type, e.state_key): e for e in event_dict.values()} | |
221 | ||
222 | prev_state = state.get((event_type, state_key), None) | |
223 | if prev_state: | |
224 | prev_state = prev_state.event_id | |
225 | prev_states = [prev_state] | |
226 | else: | |
227 | prev_states = [] | |
228 | defer.returnValue( | |
229 | (cache.state_group, state, prev_states) | |
230 | ) | |
231 | ||
232 | logger.info("Resolving state for %s with %d groups", room_id, len(state_groups)) | |
233 | ||
234 | new_state, prev_states = self._resolve_events( | |
235 | state_groups.values(), event_type, state_key | |
236 | ) | |
237 | ||
238 | state_group = None | |
239 | new_state_event_ids = frozenset(e.event_id for e in new_state.values()) | |
240 | for sg, events in state_groups.items(): | |
241 | if new_state_event_ids == frozenset(e.event_id for e in events): | |
242 | state_group = sg | |
243 | break | |
244 | ||
245 | if self._state_cache is not None: | |
325 | new_state = { | |
326 | key: e_ids.pop() for key, e_ids in state.items() | |
327 | } | |
328 | ||
329 | state_group = None | |
330 | new_state_event_ids = frozenset(new_state.values()) | |
331 | for sg, events in state_groups_ids.items(): | |
332 | if new_state_event_ids == frozenset(e_id for e_id in events): | |
333 | state_group = sg | |
334 | break | |
335 | if state_group is None: | |
336 | # Worker instances don't have access to this method, but we want | |
337 | # to set the state_group on the main instance to increase cache | |
338 | # hits. | |
339 | if hasattr(self.store, "get_next_state_group"): | |
340 | state_group = self.store.get_next_state_group() | |
341 | ||
246 | 342 | cache = _StateCacheEntry( |
247 | state={key: event.event_id for key, event in new_state.items()}, | |
343 | state=new_state, | |
248 | 344 | state_group=state_group, |
249 | ts=self.clock.time_msec() | |
250 | 345 | ) |
251 | 346 | |
252 | self._state_cache[group_names] = cache | |
253 | ||
254 | defer.returnValue((state_group, new_state, prev_states)) | |
347 | if self._state_cache is not None: | |
348 | self._state_cache[group_names] = cache | |
349 | ||
350 | defer.returnValue(cache) | |
255 | 351 | |
256 | 352 | def resolve_events(self, state_sets, event): |
257 | 353 | logger.info( |
35 | 35 | from .media_repository import MediaRepositoryStore |
36 | 36 | from .rejections import RejectionsStore |
37 | 37 | from .event_push_actions import EventPushActionsStore |
38 | from .deviceinbox import DeviceInboxStore | |
38 | 39 | |
39 | 40 | from .state import StateStore |
40 | 41 | from .signatures import SignatureStore |
83 | 84 | OpenIdStore, |
84 | 85 | ClientIpStore, |
85 | 86 | DeviceStore, |
87 | DeviceInboxStore, | |
86 | 88 | ): |
87 | 89 | |
88 | 90 | def __init__(self, db_conn, hs): |
107 | 109 | self._presence_id_gen = StreamIdGenerator( |
108 | 110 | db_conn, "presence_stream", "stream_id" |
109 | 111 | ) |
112 | self._device_inbox_id_gen = StreamIdGenerator( | |
113 | db_conn, "device_inbox", "stream_id" | |
114 | ) | |
110 | 115 | |
111 | 116 | self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") |
112 | self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") | |
117 | self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") | |
113 | 118 | self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") |
114 | 119 | self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") |
115 | 120 | self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import logging | |
16 | import ujson | |
17 | ||
18 | from twisted.internet import defer | |
19 | ||
20 | from ._base import SQLBaseStore | |
21 | ||
22 | ||
23 | logger = logging.getLogger(__name__) | |
24 | ||
25 | ||
26 | class DeviceInboxStore(SQLBaseStore): | |
27 | ||
28 | @defer.inlineCallbacks | |
29 | def add_messages_to_device_inbox(self, messages_by_user_then_device): | |
30 | """ | |
31 | Args: | |
32 | messages_by_user_then_device(dict): | 
33 | Dictionary of user_id to device_id to message. | |
34 | Returns: | |
35 | A deferred stream_id that resolves when the messages have been | |
36 | inserted. | |
37 | """ | |
38 | ||
39 | def select_devices_txn(txn, user_id, devices): | |
40 | if not devices: | |
41 | return [] | |
42 | sql = ( | |
43 | "SELECT user_id, device_id FROM devices" | |
44 | " WHERE user_id = ? AND device_id IN (" | |
45 | + ",".join("?" * len(devices)) | |
46 | + ")" | |
47 | ) | |
48 | # TODO: Maybe this needs to be done in batches if there are | |
49 | # too many local devices for a given user. | |
50 | args = [user_id] + devices | |
51 | txn.execute(sql, args) | |
52 | return [tuple(row) for row in txn.fetchall()] | |
53 | ||
54 | def add_messages_to_device_inbox_txn(txn, stream_id): | |
55 | local_users_and_devices = set() | |
56 | for user_id, messages_by_device in messages_by_user_then_device.items(): | |
57 | local_users_and_devices.update( | |
58 | select_devices_txn(txn, user_id, messages_by_device.keys()) | |
59 | ) | |
60 | ||
61 | sql = ( | |
62 | "INSERT INTO device_inbox" | |
63 | " (user_id, device_id, stream_id, message_json)" | |
64 | " VALUES (?,?,?,?)" | |
65 | ) | |
66 | rows = [] | |
67 | for user_id, messages_by_device in messages_by_user_then_device.items(): | |
68 | for device_id, message in messages_by_device.items(): | |
69 | message_json = ujson.dumps(message) | |
70 | # Only insert into the local inbox if the device exists on | |
71 | # this server | |
72 | if (user_id, device_id) in local_users_and_devices: | |
73 | rows.append((user_id, device_id, stream_id, message_json)) | |
74 | ||
75 | txn.executemany(sql, rows) | |
76 | ||
77 | with self._device_inbox_id_gen.get_next() as stream_id: | |
78 | yield self.runInteraction( | |
79 | "add_messages_to_device_inbox", | |
80 | add_messages_to_device_inbox_txn, | |
81 | stream_id | |
82 | ) | |
83 | ||
84 | defer.returnValue(self._device_inbox_id_gen.get_current_token()) | |
85 | ||
def get_new_messages_for_device(
    self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
    """Fetch up to `limit` pending to-device messages for one device.

    Args:
        user_id(str): The recipient user_id.
        device_id(str): The recipient device_id.
        last_stream_id(int): The stream position the device has already
            received messages up to (exclusive).
        current_stream_id(int): The current position of the to device
            message stream.
        limit(int): Maximum number of messages to return.
    Returns:
        Deferred ([dict], int): List of messages for the device and where
            in the stream the messages got to.
    """
    def get_new_messages_for_device_txn(txn):
        sql = (
            "SELECT stream_id, message_json FROM device_inbox"
            " WHERE user_id = ? AND device_id = ?"
            " AND ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC"
            " LIMIT ?"
        )
        txn.execute(sql, (
            user_id, device_id, last_stream_id, current_stream_id, limit
        ))
        messages = []
        for row in txn.fetchall():
            # Track the stream position of the newest message fetched so
            # far, in case we stop at the limit below.
            stream_pos = row[0]
            messages.append(ujson.loads(row[1]))
        if len(messages) < limit:
            # We did not hit the limit, so the device has everything up to
            # current_stream_id; report that as the position reached.
            stream_pos = current_stream_id
        return (messages, stream_pos)

    return self.runInteraction(
        "get_new_messages_for_device", get_new_messages_for_device_txn,
    )
121 | ||
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
    """Remove a device's inbox messages at or below a stream position.

    Args:
        user_id(str): The recipient user_id.
        device_id(str): The recipient device_id.
        up_to_stream_id(int): Where to delete messages up to.
    Returns:
        A deferred that resolves when the messages have been deleted.
    """
    def delete_messages_for_device_txn(txn):
        txn.execute(
            "DELETE FROM device_inbox"
            " WHERE user_id = ? AND device_id = ?"
            " AND stream_id <= ?",
            (user_id, device_id, up_to_stream_id),
        )

    return self.runInteraction(
        "delete_messages_for_device", delete_messages_for_device_txn
    )
142 | ||
def get_all_new_device_messages(self, last_pos, current_pos, limit):
    """Fetch device inbox rows added between two stream positions.

    Args:
        last_pos(int): Stream position to fetch from (exclusive).
        current_pos(int): Stream position to fetch up to (inclusive).
        limit(int): Maximum number of distinct stream ids to cover.
    Returns:
        A deferred list of rows from the device inbox
    """
    if last_pos == current_pos:
        return defer.succeed([])

    def get_all_new_device_messages_txn(txn):
        # First pick the stream ids in range, capped at `limit` distinct
        # values, so the second query never splits the rows of a single
        # stream_id across two calls.
        sql = (
            "SELECT stream_id FROM device_inbox"
            " WHERE ? < stream_id AND stream_id <= ?"
            " GROUP BY stream_id"
            " ORDER BY stream_id ASC"
            " LIMIT ?"
        )
        txn.execute(sql, (last_pos, current_pos, limit))
        stream_ids = txn.fetchall()
        if not stream_ids:
            return []
        # fetchall() returns row tuples, so unpack the stream_id instead
        # of binding a 1-tuple as the SQL parameter below.
        max_stream_id_in_limit = stream_ids[-1][0]

        sql = (
            "SELECT stream_id, user_id, device_id, message_json"
            " FROM device_inbox"
            " WHERE ? < stream_id AND stream_id <= ?"
            " ORDER BY stream_id ASC"
        )
        txn.execute(sql, (last_pos, max_stream_id_in_limit))
        return txn.fetchall()

    return self.runInteraction(
        "get_all_new_device_messages", get_all_new_device_messages_txn
    )
181 | ||
def get_to_device_stream_token(self):
    """Return the current maximum position of the to-device message stream."""
    id_gen = self._device_inbox_id_gen
    return id_gen.get_current_token()
270 | 270 | len(events_and_contexts) |
271 | 271 | ) |
272 | 272 | |
273 | state_group_id_manager = self._state_groups_id_gen.get_next_mult( | |
274 | len(events_and_contexts) | |
275 | ) | |
276 | 273 | with stream_ordering_manager as stream_orderings: |
277 | with state_group_id_manager as state_group_ids: | |
278 | for (event, context), stream, state_group_id in zip( | |
279 | events_and_contexts, stream_orderings, state_group_ids | |
280 | ): | |
281 | event.internal_metadata.stream_ordering = stream | |
282 | # Assign a state group_id in case a new id is needed for | |
283 | # this context. In theory we only need to assign this | |
284 | # for contexts that have current_state and aren't outliers | |
285 | # but that make the code more complicated. Assigning an ID | |
286 | # per event only causes the state_group_ids to grow as fast | |
287 | # as the stream_ordering so in practise shouldn't be a problem. | |
288 | context.new_state_group_id = state_group_id | |
289 | ||
290 | chunks = [ | |
291 | events_and_contexts[x:x + 100] | |
292 | for x in xrange(0, len(events_and_contexts), 100) | |
293 | ] | |
294 | ||
295 | for chunk in chunks: | |
296 | # We can't easily parallelize these since different chunks | |
297 | # might contain the same event. :( | |
298 | yield self.runInteraction( | |
299 | "persist_events", | |
300 | self._persist_events_txn, | |
301 | events_and_contexts=chunk, | |
302 | backfilled=backfilled, | |
303 | delete_existing=delete_existing, | |
304 | ) | |
305 | persist_event_counter.inc_by(len(chunk)) | |
274 | for (event, context), stream, in zip( | |
275 | events_and_contexts, stream_orderings | |
276 | ): | |
277 | event.internal_metadata.stream_ordering = stream | |
278 | ||
279 | chunks = [ | |
280 | events_and_contexts[x:x + 100] | |
281 | for x in xrange(0, len(events_and_contexts), 100) | |
282 | ] | |
283 | ||
284 | for chunk in chunks: | |
285 | # We can't easily parallelize these since different chunks | |
286 | # might contain the same event. :( | |
287 | yield self.runInteraction( | |
288 | "persist_events", | |
289 | self._persist_events_txn, | |
290 | events_and_contexts=chunk, | |
291 | backfilled=backfilled, | |
292 | delete_existing=delete_existing, | |
293 | ) | |
294 | persist_event_counter.inc_by(len(chunk)) | |
306 | 295 | |
307 | 296 | @_retry_on_integrity_error |
308 | 297 | @defer.inlineCallbacks |
311 | 300 | delete_existing=False): |
312 | 301 | try: |
313 | 302 | with self._stream_id_gen.get_next() as stream_ordering: |
314 | with self._state_groups_id_gen.get_next() as state_group_id: | |
315 | event.internal_metadata.stream_ordering = stream_ordering | |
316 | context.new_state_group_id = state_group_id | |
317 | yield self.runInteraction( | |
318 | "persist_event", | |
319 | self._persist_event_txn, | |
320 | event=event, | |
321 | context=context, | |
322 | current_state=current_state, | |
323 | backfilled=backfilled, | |
324 | delete_existing=delete_existing, | |
325 | ) | |
326 | persist_event_counter.inc() | |
303 | event.internal_metadata.stream_ordering = stream_ordering | |
304 | yield self.runInteraction( | |
305 | "persist_event", | |
306 | self._persist_event_txn, | |
307 | event=event, | |
308 | context=context, | |
309 | current_state=current_state, | |
310 | backfilled=backfilled, | |
311 | delete_existing=delete_existing, | |
312 | ) | |
313 | persist_event_counter.inc() | |
327 | 314 | except _RollbackButIsFineException: |
328 | 315 | pass |
329 | 316 | |
392 | 379 | txn.call_after(self._get_current_state_for_key.invalidate_all) |
393 | 380 | txn.call_after(self.get_rooms_for_user.invalidate_all) |
394 | 381 | txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) |
395 | txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) | |
396 | 382 | |
397 | 383 | # Add an entry to the current_state_resets table to record the point |
398 | 384 | # where we clobbered the current state |
528 | 514 | # Add an entry to the ex_outlier_stream table to replicate the |
529 | 515 | # change in outlier status to our workers. |
530 | 516 | stream_order = event.internal_metadata.stream_ordering |
531 | state_group_id = context.state_group or context.new_state_group_id | |
517 | state_group_id = context.state_group | |
532 | 518 | self._simple_insert_txn( |
533 | 519 | txn, |
534 | 520 | table="ex_outlier_stream", |
15 | 15 | from ._base import SQLBaseStore |
16 | 16 | from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList |
17 | 17 | from synapse.push.baserules import list_with_base_rules |
18 | from synapse.api.constants import EventTypes, Membership | |
19 | 18 | from twisted.internet import defer |
20 | 19 | |
21 | 20 | import logging |
123 | 122 | |
124 | 123 | defer.returnValue(results) |
125 | 124 | |
126 | def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): | |
125 | def bulk_get_push_rules_for_room(self, event, context): | |
126 | state_group = context.state_group | |
127 | 127 | if not state_group: |
128 | 128 | # If state_group is None it means it has yet to be assigned a |
129 | 129 | # state group, i.e. we need to make sure that calls with a state_group |
131 | 131 | # To do this we set the state_group to a new object as object() != object() |
132 | 132 | state_group = object() |
133 | 133 | |
134 | return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) | |
134 | return self._bulk_get_push_rules_for_room( | |
135 | event.room_id, state_group, context.current_state_ids, event=event | |
136 | ) | |
135 | 137 | |
136 | 138 | @cachedInlineCallbacks(num_args=2, cache_context=True) |
137 | def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, | |
138 | cache_context): | |
139 | def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, | |
140 | cache_context, event=None): | |
139 | 141 | # We don't use `state_group`, its there so that we can cache based |
140 | 142 | # on it. However, its important that its never None, since two current_state's |
141 | 143 | # with a state_group of None are likely to be different. |
146 | 148 | # their unread countss are correct in the event stream, but to avoid |
147 | 149 | # generating them for bot / AS users etc, we only do so for people who've |
148 | 150 | # sent a read receipt into the room. |
149 | local_users_in_room = set( | |
150 | e.state_key for e in current_state.values() | |
151 | if e.type == EventTypes.Member and e.membership == Membership.JOIN | |
152 | and self.hs.is_mine_id(e.state_key) | |
153 | ) | |
151 | ||
152 | users_in_room = yield self._get_joined_users_from_context( | |
153 | room_id, state_group, current_state_ids, | |
154 | on_invalidate=cache_context.invalidate, | |
155 | event=event, | |
156 | ) | |
157 | ||
158 | local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) | |
154 | 159 | |
155 | 160 | # users in the room who have pushers need to get push rules run because |
156 | 161 | # that's how their pushers work |
144 | 144 | |
145 | 145 | defer.returnValue([ev for res in results.values() for ev in res]) |
146 | 146 | |
147 | @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) | |
147 | @cachedInlineCallbacks(num_args=3, tree=True) | |
148 | 148 | def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): |
149 | 149 | """Get receipts for a single room for sending to clients. |
150 | 150 |
19 | 19 | from ._base import SQLBaseStore |
20 | 20 | from synapse.util.caches.descriptors import cached, cachedInlineCallbacks |
21 | 21 | |
22 | from synapse.api.constants import Membership | |
22 | from synapse.api.constants import Membership, EventTypes | |
23 | 23 | from synapse.types import get_domain_from_id |
24 | 24 | |
25 | 25 | import logging |
55 | 55 | |
56 | 56 | for event in events: |
57 | 57 | txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) |
58 | txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) | |
59 | 58 | txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) |
60 | 59 | txn.call_after( |
61 | 60 | self._membership_stream_cache.entity_has_changed, |
237 | 236 | |
238 | 237 | return results |
239 | 238 | |
240 | @cachedInlineCallbacks(max_entries=5000) | |
241 | def get_joined_hosts_for_room(self, room_id): | |
242 | user_ids = yield self.get_users_in_room(room_id) | |
243 | defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids)) | |
244 | ||
245 | 239 | def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): |
246 | 240 | where_clause = "c.room_id = ?" |
247 | 241 | where_values = [room_id] |
324 | 318 | |
325 | 319 | @cachedInlineCallbacks(num_args=3) |
326 | 320 | def was_forgotten_at(self, user_id, room_id, event_id): |
327 | """Returns whether user_id has elected to discard history for room_id at event_id. | |
321 | """Returns whether user_id has elected to discard history for room_id at | |
322 | event_id. | |
328 | 323 | |
329 | 324 | event_id must be a membership event.""" |
330 | 325 | def f(txn): |
357 | 352 | }, |
358 | 353 | desc="who_forgot" |
359 | 354 | ) |
355 | ||
def get_joined_users_from_context(self, event, context):
    """Get the set of joined users given an event and its context."""
    group = context.state_group
    # A state_group of None means the context has not yet been assigned a
    # state group; substitute a fresh sentinel so such calls can never share
    # a cache entry with each other (object() != object()).
    if not group:
        group = object()

    return self._get_joined_users_from_context(
        event.room_id, group, context.current_state_ids, event=event,
    )
368 | ||
def get_joined_users_from_state(self, room_id, state_group, state_ids):
    """Get the set of joined users given a room's state event id map."""
    # Never cache under a None state_group: two unassigned contexts are
    # unlikely to describe the same state, so key on a unique sentinel
    # object instead (object() != object()).
    cache_key_group = state_group if state_group else object()

    return self._get_joined_users_from_context(
        room_id, cache_key_group, state_ids,
    )
380 | ||
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
                                   cache_context, event=None):
    """Work out which users are joined to a room from its member state.

    Args:
        room_id(str): The room in question.
        state_group(int|object): The state group of `current_state_ids`;
            together with room_id it forms the cache key (num_args=2).
        current_state_ids(dict): Map of (type, state_key) -> event_id for
            the room's state.
        event: Optionally, the event currently being handled that produced
            this state, used as a fast-path below.

    Returns:
        Deferred[set]: The user_ids of the joined members.
    """
    # We don't use `state_group`, its there so that we can cache based
    # on it. However, its important that its never None, since two current_state's
    # with a state_group of None are likely to be different.
    # See bulk_get_push_rules_for_room for how we work around this.
    assert state_group is not None

    # Pull out just the membership events from the state map.
    member_event_ids = [
        e_id
        for key, e_id in current_state_ids.iteritems()
        if key[0] == EventTypes.Member
    ]

    rows = yield self._simple_select_many_batch(
        table="room_memberships",
        column="event_id",
        iterable=member_event_ids,
        retcols=['user_id'],
        keyvalues={
            "membership": Membership.JOIN,
        },
        batch_size=1000,
        desc="_get_joined_users_from_context",
    )

    users_in_room = set(row["user_id"] for row in rows)
    if event is not None and event.type == EventTypes.Member:
        if event.membership == Membership.JOIN:
            if event.event_id in member_event_ids:
                # Include the joiner from the event itself — presumably its
                # room_memberships row may not have been persisted yet, so
                # the query above could have missed it. TODO confirm.
                users_in_room.add(event.state_key)

    defer.returnValue(users_in_room)
415 | ||
def is_host_joined(self, room_id, host, state_group, state_ids):
    """Check whether any user from `host` is joined, given the room state."""
    # A falsy state_group means the state has not yet been assigned a group;
    # use a unique sentinel so such calls never share a cache entry
    # (object() != object()).
    key_group = state_group if state_group else object()

    return self._is_host_joined(
        room_id, host, key_group, state_ids
    )
427 | ||
@cachedInlineCallbacks(num_args=3)
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
    """Check whether any user on `host` has a join membership event in the
    given room state.

    Args:
        room_id(str): The room in question.
        host(str): The server name to look for.
        state_group(int|object): State group of `current_state_ids`; part
            of the cache key only.
        current_state_ids(dict): Map of (type, state_key) -> event_id.

    Returns:
        Deferred[bool]
    """
    # We don't use `state_group`, its there so that we can cache based
    # on it. However, its important that its never None, since two current_state's
    # with a state_group of None are likely to be different.
    # See bulk_get_push_rules_for_room for how we work around this.
    assert state_group is not None

    for (etype, state_key), event_id in current_state_ids.items():
        if etype == EventTypes.Member:
            try:
                if get_domain_from_id(state_key) != host:
                    continue
            except Exception:
                # Malformed state_key: log and skip this entry rather than
                # aborting the whole check. (Was a bare `except:`, which
                # would also have swallowed KeyboardInterrupt/SystemExit.)
                logger.warn("state_key not user_id: %s", state_key)
                continue

            event = yield self.get_event(event_id, allow_none=True)
            if event and event.content["membership"] == Membership.JOIN:
                defer.returnValue(True)

    defer.returnValue(False)
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
-- Inbox for store-and-forward direct-to-device messages.  Each row is one
-- pending message for one (user, device) pair.
CREATE TABLE device_inbox (
    user_id TEXT NOT NULL,
    device_id TEXT NOT NULL,
    stream_id BIGINT NOT NULL,  -- position in the to-device message stream
    message_json TEXT NOT NULL -- {"type":, "sender":, "content",}
);

-- Look up / delete one device's messages by stream position.
CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id);
-- Scan messages across all devices by stream position.
CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);
0 | # Copyright 2016 OpenMarket Ltd | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from synapse.storage.engines import PostgresEngine | |
15 | ||
16 | import logging | |
17 | ||
18 | logger = logging.getLogger(__name__) | |
19 | ||
20 | ||
def run_create(cur, database_engine, *args, **kwargs):
    """Empty the sent_transactions table and add an index on its timestamp.

    PostgreSQL gets a TRUNCATE; SQLite has no TRUNCATE statement, so an
    unqualified DELETE is used there instead.
    """
    if isinstance(database_engine, PostgresEngine):
        clear_sql = "TRUNCATE sent_transactions"
    else:
        clear_sql = "DELETE FROM sent_transactions"
    cur.execute(clear_sql)

    cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
28 | ||
29 | ||
def run_upgrade(cur, database_engine, *args, **kwargs):
    # All of the work for this schema delta happens in run_create above;
    # nothing extra to do on upgrade.
    pass
43 | 43 | """ |
44 | 44 | |
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
    """Resolve events to their state groups and return, per group, the
    state map of event ids for that group.
    """
    if not event_ids:
        defer.returnValue({})

    groups_by_event = yield self._get_state_group_for_events(event_ids)

    distinct_groups = set(groups_by_event.values())
    state_by_group = yield self._get_state_for_groups(distinct_groups)

    defer.returnValue(state_by_group)
58 | ||
59 | @defer.inlineCallbacks | |
46 | 60 | def get_state_groups(self, room_id, event_ids): |
47 | 61 | """ Get the state groups for the given list of event_ids |
48 | 62 | |
51 | 65 | if not event_ids: |
52 | 66 | defer.returnValue({}) |
53 | 67 | |
54 | event_to_groups = yield self._get_state_group_for_events( | |
55 | event_ids, | |
56 | ) | |
57 | ||
58 | groups = set(event_to_groups.values()) | |
59 | group_to_state = yield self._get_state_for_groups(groups) | |
68 | group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) | |
69 | ||
70 | state_event_map = yield self.get_events( | |
71 | [ | |
72 | ev_id for group_ids in group_to_ids.values() | |
73 | for ev_id in group_ids.values() | |
74 | ], | |
75 | get_prev_content=False | |
76 | ) | |
60 | 77 | |
61 | 78 | defer.returnValue({ |
62 | group: state_map.values() | |
63 | for group, state_map in group_to_state.items() | |
79 | group: [ | |
80 | state_event_map[v] for v in event_id_map.values() if v in state_event_map | |
81 | ] | |
82 | for group, event_id_map in group_to_ids.items() | |
64 | 83 | }) |
84 | ||
def _have_persisted_state_group_txn(self, txn, state_group):
    """Return a truthy value iff a row for `state_group` already exists in
    the state_groups table.
    """
    txn.execute(
        "SELECT count(*) FROM state_groups WHERE id = ?",
        (state_group,),
    )
    result = txn.fetchone()
    if not result:
        return result
    return result[0]
65 | 92 | |
66 | 93 | def _store_mult_state_groups_txn(self, txn, events_and_contexts): |
67 | 94 | state_groups = {} |
69 | 96 | if event.internal_metadata.is_outlier(): |
70 | 97 | continue |
71 | 98 | |
72 | if context.current_state is None: | |
99 | if context.current_state_ids is None: | |
73 | 100 | continue |
74 | 101 | |
75 | if context.state_group is not None: | |
76 | state_groups[event.event_id] = context.state_group | |
102 | state_groups[event.event_id] = context.state_group | |
103 | ||
104 | if self._have_persisted_state_group_txn(txn, context.state_group): | |
105 | logger.info("Already persisted state_group: %r", context.state_group) | |
77 | 106 | continue |
78 | 107 | |
79 | state_events = dict(context.current_state) | |
80 | ||
81 | if event.is_state(): | |
82 | state_events[(event.type, event.state_key)] = event | |
83 | ||
84 | state_group = context.new_state_group_id | |
108 | state_event_ids = dict(context.current_state_ids) | |
85 | 109 | |
86 | 110 | self._simple_insert_txn( |
87 | 111 | txn, |
88 | 112 | table="state_groups", |
89 | 113 | values={ |
90 | "id": state_group, | |
114 | "id": context.state_group, | |
91 | 115 | "room_id": event.room_id, |
92 | 116 | "event_id": event.event_id, |
93 | 117 | }, |
98 | 122 | table="state_groups_state", |
99 | 123 | values=[ |
100 | 124 | { |
101 | "state_group": state_group, | |
102 | "room_id": state.room_id, | |
103 | "type": state.type, | |
104 | "state_key": state.state_key, | |
105 | "event_id": state.event_id, | |
125 | "state_group": context.state_group, | |
126 | "room_id": event.room_id, | |
127 | "type": key[0], | |
128 | "state_key": key[1], | |
129 | "event_id": state_id, | |
106 | 130 | } |
107 | for state in state_events.values() | |
131 | for key, state_id in state_event_ids.items() | |
108 | 132 | ], |
109 | 133 | ) |
110 | state_groups[event.event_id] = state_group | |
111 | 134 | |
112 | 135 | self._simple_insert_many_txn( |
113 | 136 | txn, |
247 | 270 | groups = set(event_to_groups.values()) |
248 | 271 | group_to_state = yield self._get_state_for_groups(groups, types) |
249 | 272 | |
273 | state_event_map = yield self.get_events( | |
274 | [ev_id for sd in group_to_state.values() for ev_id in sd.values()], | |
275 | get_prev_content=False | |
276 | ) | |
277 | ||
278 | event_to_state = { | |
279 | event_id: { | |
280 | k: state_event_map[v] | |
281 | for k, v in group_to_state[group].items() | |
282 | if v in state_event_map | |
283 | } | |
284 | for event_id, group in event_to_groups.items() | |
285 | } | |
286 | ||
287 | defer.returnValue({event: event_to_state[event] for event in event_ids}) | |
288 | ||
289 | @defer.inlineCallbacks | |
290 | def get_state_ids_for_events(self, event_ids, types): | |
291 | event_to_groups = yield self._get_state_group_for_events( | |
292 | event_ids, | |
293 | ) | |
294 | ||
295 | groups = set(event_to_groups.values()) | |
296 | group_to_state = yield self._get_state_for_groups(groups, types) | |
297 | ||
250 | 298 | event_to_state = { |
251 | 299 | event_id: group_to_state[group] |
252 | 300 | for event_id, group in event_to_groups.items() |
269 | 317 | A deferred dict from (type, state_key) -> state_event |
270 | 318 | """ |
271 | 319 | state_map = yield self.get_state_for_events([event_id], types) |
320 | defer.returnValue(state_map[event_id]) | |
321 | ||
322 | @defer.inlineCallbacks | |
323 | def get_state_ids_for_event(self, event_id, types=None): | |
324 | """ | |
325 | Get the state dict corresponding to a particular event | |
326 | ||
327 | Args: | |
328 | event_id(str): event whose state should be returned | |
329 | types(list[(str, str)]|None): List of (type, state_key) tuples | |
330 | which are used to filter the state fetched. May be None, which | |
331 | matches any key | |
332 | ||
333 | Returns: | |
334 | A deferred dict from (type, state_key) -> state_event | |
335 | """ | |
336 | state_map = yield self.get_state_ids_for_events([event_id], types) | |
272 | 337 | defer.returnValue(state_map[event_id]) |
273 | 338 | |
274 | 339 | @cached(num_args=2, max_entries=10000) |
427 | 492 | full=(types is None), |
428 | 493 | ) |
429 | 494 | |
430 | state_events = yield self._get_events( | |
431 | [ev_id for sd in results.values() for ev_id in sd.values()], | |
432 | get_prev_content=False | |
433 | ) | |
434 | ||
435 | state_events = {e.event_id: e for e in state_events} | |
436 | ||
437 | 495 | # Remove all the entries with None values. The None values were just |
438 | 496 | # used for bookkeeping in the cache. |
439 | 497 | for group, state_dict in results.items(): |
440 | 498 | results[group] = { |
441 | key: state_events[event_id] | |
499 | key: event_id | |
442 | 500 | for key, event_id in state_dict.items() |
443 | if event_id and event_id in state_events | |
501 | if event_id | |
444 | 502 | } |
445 | 503 | |
446 | 504 | defer.returnValue(results) |
472 | 530 | "get_all_new_state_groups", get_all_new_state_groups_txn |
473 | 531 | ) |
474 | 532 | |
475 | def get_state_stream_token(self): | |
476 | return self._state_groups_id_gen.get_current_token() | |
def get_next_state_group(self):
    """Reserve and return the next state group id from the id generator."""
    id_gen = self._state_groups_id_gen
    return id_gen.get_next()
244 | 244 | |
245 | 245 | return self.cursor_to_dict(txn) |
246 | 246 | |
247 | @cached() | |
247 | @cached(max_entries=10000) | |
248 | 248 | def get_destination_retry_timings(self, destination): |
249 | 249 | """Gets the current retry timings (if any) for a given destination. |
250 | 250 | |
def _cleanup_transactions(self):
    """Prune old rows from the transaction tables.

    Received transactions are kept for a month, sent transactions for six
    hours — presumably received ones are kept longer for dedup of retried
    transactions; confirm before changing either window.
    """
    now = self._clock.time_msec()
    month_ago = now - 30 * 24 * 60 * 60 * 1000
    six_hours_ago = now - 6 * 60 * 60 * 1000

    def _cleanup_transactions_txn(txn):
        txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
        txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))

    # Label the interaction after this function: the old "_persist_in_mem_txns"
    # description was a copy-paste leftover and mislabelled the metrics.
    return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
42 | 42 | @defer.inlineCallbacks |
43 | 43 | def get_current_token(self, direction='f'): |
44 | 44 | push_rules_key, _ = self.store.get_push_rules_stream_token() |
45 | to_device_key = self.store.get_to_device_stream_token() | |
45 | 46 | |
46 | 47 | token = StreamToken( |
47 | 48 | room_key=( |
60 | 61 | yield self.sources["account_data"].get_current_key() |
61 | 62 | ), |
62 | 63 | push_rules_key=push_rules_key, |
64 | to_device_key=to_device_key, | |
63 | 65 | ) |
64 | 66 | defer.returnValue(token) |
153 | 153 | "receipt_key", |
154 | 154 | "account_data_key", |
155 | 155 | "push_rules_key", |
156 | "to_device_key", | |
156 | 157 | )) |
157 | 158 | ): |
158 | 159 | _SEPARATOR = "_" |
189 | 190 | or (int(other.receipt_key) < int(self.receipt_key)) |
190 | 191 | or (int(other.account_data_key) < int(self.account_data_key)) |
191 | 192 | or (int(other.push_rules_key) < int(self.push_rules_key)) |
193 | or (int(other.to_device_key) < int(self.to_device_key)) | |
192 | 194 | ) |
193 | 195 | |
194 | 196 | def copy_and_advance(self, key, new_value): |
268 | 270 | return "t%d-%d" % (self.topological, self.stream) |
269 | 271 | else: |
270 | 272 | return "s%d" % (self.stream,) |
271 | ||
272 | ||
273 | # Some arbitrary constants used for internal API enumerations. Don't rely on | |
274 | # exact values; always pass or compare symbolically | |
275 | class ThirdPartyEntityKind(object): | |
276 | USER = 'user' | |
277 | LOCATION = 'location' |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import re | |
16 | import logging | |
17 | ||
18 | logger = logging.getLogger(__name__) | |
19 | ||
20 | # intentionally looser than what aliases we allow to be registered since | |
21 | # other HSes may allow aliases that we would not | |
22 | ALIAS_RE = re.compile(r"^#.*:.+$") | |
23 | ||
24 | ALL_ALONE = "Empty Room" | |
25 | ||
26 | ||
27 | def calculate_room_name(room_state, user_id, fallback_to_members=True, | |
28 | fallback_to_single_member=True): | |
29 | """ | |
30 | Works out a user-facing name for the given room as per Matrix | |
31 | spec recommendations. | |
32 | Does not yet support internationalisation. | |
33 | Args: | |
34 | room_state: Dictionary of the room's state | |
35 | user_id: The ID of the user to whom the room name is being presented | |
36 | fallback_to_members: If False, return None instead of generating a name | |
37 | based on the room's members if the room has no | |
38 | title or aliases. | |
39 | ||
40 | Returns: | |
41 | (string or None) A human readable name for the room. | |
42 | """ | |
43 | # does it have a name? | |
44 | if ("m.room.name", "") in room_state: | |
45 | m_room_name = room_state[("m.room.name", "")] | |
46 | if m_room_name.content and m_room_name.content["name"]: | |
47 | return m_room_name.content["name"] | |
48 | ||
49 | # does it have a canonical alias? | |
50 | if ("m.room.canonical_alias", "") in room_state: | |
51 | canon_alias = room_state[("m.room.canonical_alias", "")] | |
52 | if ( | |
53 | canon_alias.content and canon_alias.content["alias"] and | |
54 | _looks_like_an_alias(canon_alias.content["alias"]) | |
55 | ): | |
56 | return canon_alias.content["alias"] | |
57 | ||
58 | # at this point we're going to need to search the state by all state keys | |
59 | # for an event type, so rearrange the data structure | |
60 | room_state_bytype = _state_as_two_level_dict(room_state) | |
61 | ||
62 | # right then, any aliases at all? | |
63 | if "m.room.aliases" in room_state_bytype: | |
64 | m_room_aliases = room_state_bytype["m.room.aliases"] | |
65 | if len(m_room_aliases.values()) > 0: | |
66 | first_alias_event = m_room_aliases.values()[0] | |
67 | if first_alias_event.content and first_alias_event.content["aliases"]: | |
68 | the_aliases = first_alias_event.content["aliases"] | |
69 | if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): | |
70 | return the_aliases[0] | |
71 | ||
72 | if not fallback_to_members: | |
73 | return None | |
74 | ||
75 | my_member_event = None | |
76 | if ("m.room.member", user_id) in room_state: | |
77 | my_member_event = room_state[("m.room.member", user_id)] | |
78 | ||
79 | if ( | |
80 | my_member_event is not None and | |
81 | my_member_event.content['membership'] == "invite" | |
82 | ): | |
83 | if ("m.room.member", my_member_event.sender) in room_state: | |
84 | inviter_member_event = room_state[("m.room.member", my_member_event.sender)] | |
85 | if fallback_to_single_member: | |
86 | return "Invite from %s" % (name_from_member_event(inviter_member_event),) | |
87 | else: | |
88 | return None | |
89 | else: | |
90 | return "Room Invite" | |
91 | ||
92 | # we're going to have to generate a name based on who's in the room, | |
93 | # so find out who is in the room that isn't the user. | |
94 | if "m.room.member" in room_state_bytype: | |
95 | all_members = [ | |
96 | ev for ev in room_state_bytype["m.room.member"].values() | |
97 | if ev.content['membership'] == "join" or ev.content['membership'] == "invite" | |
98 | ] | |
        # Sort the member events oldest-first so that we name people in the
        # order they joined (it should at least be deterministic rather than
        # dictionary iteration order)
102 | all_members.sort(key=lambda e: e.origin_server_ts) | |
103 | other_members = [m for m in all_members if m.state_key != user_id] | |
104 | else: | |
105 | other_members = [] | |
106 | all_members = [] | |
107 | ||
108 | if len(other_members) == 0: | |
109 | if len(all_members) == 1: | |
110 | # self-chat, peeked room with 1 participant, | |
111 | # or inbound invite, or outbound 3PID invite. | |
112 | if all_members[0].sender == user_id: | |
113 | if "m.room.third_party_invite" in room_state_bytype: | |
114 | third_party_invites = ( | |
115 | room_state_bytype["m.room.third_party_invite"].values() | |
116 | ) | |
117 | ||
118 | if len(third_party_invites) > 0: | |
119 | # technically third party invite events are not member | |
120 | # events, but they are close enough | |
121 | ||
122 | # FIXME: no they're not - they look nothing like a member; | |
123 | # they have a great big encrypted thing as their name to | |
124 | # prevent leaking the 3PID name... | |
125 | # return "Inviting %s" % ( | |
126 | # descriptor_from_member_events(third_party_invites) | |
127 | # ) | |
128 | return "Inviting email address" | |
129 | else: | |
130 | return ALL_ALONE | |
131 | else: | |
132 | return name_from_member_event(all_members[0]) | |
133 | else: | |
134 | return ALL_ALONE | |
135 | elif len(other_members) == 1 and not fallback_to_single_member: | |
136 | return None | |
137 | else: | |
138 | return descriptor_from_member_events(other_members) | |
139 | ||
140 | ||
def descriptor_from_member_events(member_events):
    """Build a short human-readable summary of a list of room members.

    Returns "nobody" for an empty list, a single name for one member,
    "A and B" for two, and "A and N others" for more.
    """
    count = len(member_events)
    if count == 0:
        return "nobody"

    first_name = name_from_member_event(member_events[0])
    if count == 1:
        return first_name
    if count == 2:
        second_name = name_from_member_event(member_events[1])
        return "%s and %s" % (first_name, second_name)
    return "%s and %d others" % (first_name, count - 1)
156 | ||
157 | ||
def name_from_member_event(member_event):
    """Return a displayable name for a member.

    Uses the event's non-empty "displayname" content field if present,
    otherwise falls back to the member's user id (the event's state_key).
    """
    content = member_event.content
    if content and content.get("displayname"):
        return content["displayname"]
    return member_event.state_key
165 | ||
166 | ||
167 | def _state_as_two_level_dict(state): | |
168 | ret = {} | |
169 | for k, v in state.items(): | |
170 | ret.setdefault(k[0], {})[k[1]] = v | |
171 | return ret | |
172 | ||
173 | ||
def _looks_like_an_alias(string):
    """Return True if *string* matches the module's room-alias pattern."""
    return bool(ALIAS_RE.match(string))
180 | 180 | |
181 | 181 | |
@defer.inlineCallbacks
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
    """Variant of filter_events_for_clients that takes EventContexts
    instead of pre-fetched state maps.

    For each context, loads only the state events relevant to visibility
    checking -- the room's history-visibility event plus the membership
    events of the users being checked -- then delegates to
    filter_events_for_clients.

    Args:
        store: datastore; used via get_events to fetch state events by id.
        user_tuples: iterable of tuples whose first element is a user_id.
        events: the events to filter.
        event_id_to_context: dict of event_id -> EventContext providing
            current_state_ids for that event.

    Returns:
        Deferred: whatever filter_events_for_clients returns.
    """
    user_ids = set(u[0] for u in user_tuples)
    event_id_to_state = {}
    for event_id, context in event_id_to_context.items():
        # Pull out just the visibility-relevant subset of the state,
        # avoiding a fetch of the full room state per event.
        state = yield store.get_events([
            e_id
            for key, e_id in context.current_state_ids.iteritems()
            if key == (EventTypes.RoomHistoryVisibility, "")
            or (key[0] == EventTypes.Member and key[1] in user_ids)
        ])
        event_id_to_state[event_id] = state

    res = yield filter_events_for_clients(
        store, user_tuples, events, event_id_to_state
    )
    defer.returnValue(res)
199 | ||
200 | ||
201 | @defer.inlineCallbacks | |
183 | 202 | def filter_events_for_client(store, user_id, events, is_peeking=False): |
184 | 203 | """ |
185 | 204 | Check which events a user is allowed to see |
114 | 114 | ), |
115 | 115 | ], any_order=True) |
116 | 116 | |
    def test_online_to_online_last_active_noop(self):
        """An online->online update that merely bumps last_active_ts
        (already past LAST_ACTIVE_GRANULARITY) should not be persisted or
        notified to clients, but should still ping federation and re-arm
        all three wheel timers.
        """
        wheel_timer = Mock()
        user_id = "@foo:bar"
        now = 5000000

        # Previously online and currently_active, last active just beyond
        # the granularity window.
        prev_state = UserPresenceState.default(user_id)
        prev_state = prev_state.copy_and_replace(
            state=PresenceState.ONLINE,
            last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
            currently_active=True,
        )

        # Same state, only the activity timestamp moves forward.
        new_state = prev_state.copy_and_replace(
            state=PresenceState.ONLINE,
            last_active_ts=now,
        )

        state, persist_and_notify, federation_ping = handle_update(
            prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
        )

        # Nothing client-visible changed, so no persist/notify -- but
        # federation still gets a ping and the state stays active.
        self.assertFalse(persist_and_notify)
        self.assertTrue(federation_ping)
        self.assertTrue(state.currently_active)
        self.assertEquals(new_state.state, state.state)
        self.assertEquals(new_state.status_msg, state.status_msg)
        self.assertEquals(state.last_federation_update_ts, now)

        # All three timers (idle, sync timeout, active granularity) re-armed.
        self.assertEquals(wheel_timer.insert.call_count, 3)
        wheel_timer.insert.assert_has_calls([
            call(
                now=now,
                obj=user_id,
                then=new_state.last_active_ts + IDLE_TIMER
            ),
            call(
                now=now,
                obj=user_id,
                then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
            ),
            call(
                now=now,
                obj=user_id,
                then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
            ),
        ], any_order=True)
163 | ||
117 | 164 | def test_online_to_online_last_active(self): |
118 | 165 | wheel_timer = Mock() |
119 | 166 | user_id = "@foo:bar" |
61 | 61 | self.on_new_event = mock_notifier.on_new_event |
62 | 62 | |
63 | 63 | self.auth = Mock(spec=[]) |
64 | self.state_handler = Mock() | |
64 | 65 | |
65 | 66 | hs = yield setup_test_homeserver( |
66 | 67 | "test", |
74 | 75 | "set_received_txn_response", |
75 | 76 | "get_destination_retry_timings", |
76 | 77 | ]), |
78 | state_handler=self.state_handler, | |
77 | 79 | handlers=None, |
78 | 80 | notifier=mock_notifier, |
79 | 81 | resource_for_client=Mock(), |
111 | 113 | def get_joined_hosts_for_room(room_id): |
112 | 114 | return set(member.domain for member in self.room_members) |
113 | 115 | self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room |
116 | ||
117 | def get_current_user_in_room(room_id): | |
118 | return set(str(u) for u in self.room_members) | |
119 | self.state_handler.get_current_user_in_room = get_current_user_in_room | |
114 | 120 | |
115 | 121 | self.auth.check_joined_room = check_joined_room |
116 | 122 |
304 | 304 | |
305 | 305 | self.event_id += 1 |
306 | 306 | |
307 | context = EventContext(current_state=state) | |
307 | if state is not None: | |
308 | state_ids = { | |
309 | key: e.event_id for key, e in state.items() | |
310 | } | |
311 | else: | |
312 | state_ids = None | |
313 | ||
314 | context = EventContext() | |
315 | context.current_state_ids = state_ids | |
316 | context.prev_state_ids = state_ids | |
308 | 317 | context.push_actions = push_actions |
309 | 318 | |
310 | 319 | ordering = None |
59 | 59 | self.assertEquals(body, {}) |
60 | 60 | |
61 | 61 | @defer.inlineCallbacks |
62 | def test_events_and_state(self): | |
63 | get = self.get(events="-1", state="-1", timeout="0") | |
62 | def test_events(self): | |
63 | get = self.get(events="-1", timeout="0") | |
64 | 64 | yield self.hs.get_handlers().room_creation_handler.create_room( |
65 | 65 | synapse.types.create_requester(self.user), {} |
66 | 66 | ) |
68 | 68 | self.assertEquals(code, 200) |
69 | 69 | self.assertEquals(body["events"]["field_names"], [ |
70 | 70 | "position", "internal", "json", "state_group" |
71 | ]) | |
72 | self.assertEquals(body["state_groups"]["field_names"], [ | |
73 | "position", "room_id", "event_id" | |
74 | ]) | |
75 | self.assertEquals(body["state_group_state"]["field_names"], [ | |
76 | "position", "type", "state_key", "event_id" | |
77 | 71 | ]) |
78 | 72 | |
79 | 73 | @defer.inlineCallbacks |
1031 | 1031 | |
1032 | 1032 | @defer.inlineCallbacks |
1033 | 1033 | def test_topo_token_is_accepted(self): |
1034 | token = "t1-0_0_0_0_0_0" | |
1034 | token = "t1-0_0_0_0_0_0_0" | |
1035 | 1035 | (code, response) = yield self.mock_resource.trigger_get( |
1036 | 1036 | "/rooms/%s/messages?access_token=x&from=%s" % |
1037 | 1037 | (self.room_id, token)) |
1043 | 1043 | |
1044 | 1044 | @defer.inlineCallbacks |
1045 | 1045 | def test_stream_token_is_accepted_for_fwd_pagianation(self): |
1046 | token = "s0_0_0_0_0_0" | |
1046 | token = "s0_0_0_0_0_0_0" | |
1047 | 1047 | (code, response) = yield self.mock_resource.trigger_get( |
1048 | 1048 | "/rooms/%s/messages?access_token=x&from=%s" % |
1049 | 1049 | (self.room_id, token)) |
77 | 77 | ) |
78 | 78 | )] |
79 | 79 | ) |
80 | ||
    @defer.inlineCallbacks
    def test_room_hosts(self):
        """get_joined_hosts_for_room should track the distinct set of
        servers that have at least one joined member, across joins and
        leaves.
        """
        yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)

        self.assertEquals(
            {"test"},
            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
        )

        # Should still have just one host after second join from it
        yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)

        self.assertEquals(
            {"test"},
            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
        )

        # Should now have two hosts after join from other host
        yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)

        self.assertEquals(
            {"test", "elsewhere"},
            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
        )

        # Should still have both hosts
        yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)

        self.assertEquals(
            {"test", "elsewhere"},
            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
        )

        # Should have only one host after other leaves
        yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)

        self.assertEquals(
            {"test"},
            (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
        )
66 | 66 | self._event_to_state_group = {} |
67 | 67 | self._group_to_state = {} |
68 | 68 | |
69 | self._event_id_to_event = {} | |
70 | ||
69 | 71 | self._next_group = 1 |
70 | 72 | |
71 | def get_state_groups(self, room_id, event_ids): | |
73 | def get_state_groups_ids(self, room_id, event_ids): | |
72 | 74 | groups = {} |
73 | 75 | for event_id in event_ids: |
74 | 76 | group = self._event_to_state_group.get(event_id) |
78 | 80 | return defer.succeed(groups) |
79 | 81 | |
80 | 82 | def store_state_groups(self, event, context): |
81 | if context.current_state is None: | |
83 | if context.current_state_ids is None: | |
82 | 84 | return |
83 | 85 | |
84 | state_events = context.current_state | |
85 | ||
86 | if event.is_state(): | |
87 | state_events[(event.type, event.state_key)] = event | |
88 | ||
89 | state_group = context.state_group | |
90 | if not state_group: | |
91 | state_group = self._next_group | |
92 | self._next_group += 1 | |
93 | ||
94 | self._group_to_state[state_group] = state_events.values() | |
95 | ||
96 | self._event_to_state_group[event.event_id] = state_group | |
86 | state_events = dict(context.current_state_ids) | |
87 | ||
88 | self._group_to_state[context.state_group] = state_events | |
89 | self._event_to_state_group[event.event_id] = context.state_group | |
90 | ||
91 | def get_events(self, event_ids, **kwargs): | |
92 | return { | |
93 | e_id: self._event_id_to_event[e_id] for e_id in event_ids | |
94 | if e_id in self._event_id_to_event | |
95 | } | |
96 | ||
97 | def register_events(self, events): | |
98 | for e in events: | |
99 | self._event_id_to_event[e.event_id] = e | |
97 | 100 | |
98 | 101 | |
99 | 102 | class DictObj(dict): |
135 | 138 | def setUp(self): |
136 | 139 | self.store = Mock( |
137 | 140 | spec_set=[ |
138 | "get_state_groups", | |
141 | "get_state_groups_ids", | |
139 | 142 | "add_event_hashes", |
143 | "get_events", | |
144 | "get_next_state_group", | |
140 | 145 | ] |
141 | 146 | ) |
142 | 147 | hs = Mock(spec_set=[ |
147 | 152 | hs.get_clock.return_value = MockClock() |
148 | 153 | hs.get_auth.return_value = Auth(hs) |
149 | 154 | |
155 | self.store.get_next_state_group.side_effect = Mock | |
156 | ||
150 | 157 | self.state = StateHandler(hs) |
151 | 158 | self.event_id = 0 |
152 | 159 | |
186 | 193 | ) |
187 | 194 | |
188 | 195 | store = StateGroupStore() |
189 | self.store.get_state_groups.side_effect = store.get_state_groups | |
196 | self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |
190 | 197 | |
191 | 198 | context_store = {} |
192 | 199 | |
195 | 202 | store.store_state_groups(event, context) |
196 | 203 | context_store[event.event_id] = context |
197 | 204 | |
198 | self.assertEqual(2, len(context_store["D"].current_state)) | |
205 | self.assertEqual(2, len(context_store["D"].prev_state_ids)) | |
199 | 206 | |
200 | 207 | @defer.inlineCallbacks |
201 | 208 | def test_branch_basic_conflict(self): |
238 | 245 | ) |
239 | 246 | |
240 | 247 | store = StateGroupStore() |
241 | self.store.get_state_groups.side_effect = store.get_state_groups | |
248 | self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |
249 | self.store.get_events = store.get_events | |
250 | store.register_events(graph.walk()) | |
242 | 251 | |
243 | 252 | context_store = {} |
244 | 253 | |
249 | 258 | |
250 | 259 | self.assertSetEqual( |
251 | 260 | {"START", "A", "C"}, |
252 | {e.event_id for e in context_store["D"].current_state.values()} | |
261 | {e_id for e_id in context_store["D"].prev_state_ids.values()} | |
253 | 262 | ) |
254 | 263 | |
255 | 264 | @defer.inlineCallbacks |
302 | 311 | ) |
303 | 312 | |
304 | 313 | store = StateGroupStore() |
305 | self.store.get_state_groups.side_effect = store.get_state_groups | |
314 | self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |
315 | self.store.get_events = store.get_events | |
316 | store.register_events(graph.walk()) | |
306 | 317 | |
307 | 318 | context_store = {} |
308 | 319 | |
313 | 324 | |
314 | 325 | self.assertSetEqual( |
315 | 326 | {"START", "A", "B", "C"}, |
316 | {e.event_id for e in context_store["E"].current_state.values()} | |
327 | {e for e in context_store["E"].prev_state_ids.values()} | |
317 | 328 | ) |
318 | 329 | |
319 | 330 | @defer.inlineCallbacks |
383 | 394 | graph = Graph(nodes, edges) |
384 | 395 | |
385 | 396 | store = StateGroupStore() |
386 | self.store.get_state_groups.side_effect = store.get_state_groups | |
397 | self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |
398 | self.store.get_events = store.get_events | |
399 | store.register_events(graph.walk()) | |
387 | 400 | |
388 | 401 | context_store = {} |
389 | 402 | |
394 | 407 | |
395 | 408 | self.assertSetEqual( |
396 | 409 | {"A1", "A2", "A3", "A5", "B"}, |
397 | {e.event_id for e in context_store["D"].current_state.values()} | |
410 | {e for e in context_store["D"].prev_state_ids.values()} | |
398 | 411 | ) |
399 | 412 | |
400 | 413 | def _add_depths(self, nodes, edges): |
423 | 436 | event, old_state=old_state |
424 | 437 | ) |
425 | 438 | |
426 | for k, v in context.current_state.items(): | |
427 | type, state_key = k | |
428 | self.assertEqual(type, v.type) | |
429 | self.assertEqual(state_key, v.state_key) | |
430 | ||
431 | 439 | self.assertEqual( |
432 | set(old_state), set(context.current_state.values()) | |
433 | ) | |
434 | ||
435 | self.assertIsNone(context.state_group) | |
440 | set(e.event_id for e in old_state), set(context.current_state_ids.values()) | |
441 | ) | |
442 | ||
443 | self.assertIsNotNone(context.state_group) | |
436 | 444 | |
437 | 445 | @defer.inlineCallbacks |
438 | 446 | def test_annotate_with_old_state(self): |
448 | 456 | event, old_state=old_state |
449 | 457 | ) |
450 | 458 | |
451 | for k, v in context.current_state.items(): | |
452 | type, state_key = k | |
453 | self.assertEqual(type, v.type) | |
454 | self.assertEqual(state_key, v.state_key) | |
455 | ||
456 | 459 | self.assertEqual( |
457 | set(old_state), | |
458 | set(context.current_state.values()) | |
459 | ) | |
460 | ||
461 | self.assertIsNone(context.state_group) | |
460 | set(e.event_id for e in old_state), set(context.prev_state_ids.values()) | |
461 | ) | |
462 | 462 | |
463 | 463 | @defer.inlineCallbacks |
464 | 464 | def test_trivial_annotate_message(self): |
472 | 472 | |
473 | 473 | group_name = "group_name_1" |
474 | 474 | |
475 | self.store.get_state_groups.return_value = { | |
476 | group_name: old_state, | |
475 | self.store.get_state_groups_ids.return_value = { | |
476 | group_name: {(e.type, e.state_key): e.event_id for e in old_state}, | |
477 | 477 | } |
478 | 478 | |
479 | 479 | context = yield self.state.compute_event_context(event) |
480 | ||
481 | for k, v in context.current_state.items(): | |
482 | type, state_key = k | |
483 | self.assertEqual(type, v.type) | |
484 | self.assertEqual(state_key, v.state_key) | |
485 | 480 | |
486 | 481 | self.assertEqual( |
487 | 482 | set([e.event_id for e in old_state]), |
488 | set([e.event_id for e in context.current_state.values()]) | |
483 | set(context.current_state_ids.values()) | |
489 | 484 | ) |
490 | 485 | |
491 | 486 | self.assertEqual(group_name, context.state_group) |
502 | 497 | |
503 | 498 | group_name = "group_name_1" |
504 | 499 | |
505 | self.store.get_state_groups.return_value = { | |
506 | group_name: old_state, | |
500 | self.store.get_state_groups_ids.return_value = { | |
501 | group_name: {(e.type, e.state_key): e.event_id for e in old_state}, | |
507 | 502 | } |
508 | 503 | |
509 | 504 | context = yield self.state.compute_event_context(event) |
510 | ||
511 | for k, v in context.current_state.items(): | |
512 | type, state_key = k | |
513 | self.assertEqual(type, v.type) | |
514 | self.assertEqual(state_key, v.state_key) | |
515 | 505 | |
516 | 506 | self.assertEqual( |
517 | 507 | set([e.event_id for e in old_state]), |
518 | set([e.event_id for e in context.current_state.values()]) | |
519 | ) | |
520 | ||
521 | self.assertIsNone(context.state_group) | |
508 | set(context.prev_state_ids.values()) | |
509 | ) | |
510 | ||
511 | self.assertIsNotNone(context.state_group) | |
522 | 512 | |
523 | 513 | @defer.inlineCallbacks |
524 | 514 | def test_resolve_message_conflict(self): |
542 | 532 | create_event(type="test4", state_key=""), |
543 | 533 | ] |
544 | 534 | |
535 | store = StateGroupStore() | |
536 | store.register_events(old_state_1) | |
537 | store.register_events(old_state_2) | |
538 | self.store.get_events = store.get_events | |
539 | ||
545 | 540 | context = yield self._get_context(event, old_state_1, old_state_2) |
546 | 541 | |
547 | self.assertEqual(len(context.current_state), 6) | |
548 | ||
549 | self.assertIsNone(context.state_group) | |
542 | self.assertEqual(len(context.current_state_ids), 6) | |
543 | ||
544 | self.assertIsNotNone(context.state_group) | |
550 | 545 | |
551 | 546 | @defer.inlineCallbacks |
552 | 547 | def test_resolve_state_conflict(self): |
570 | 565 | create_event(type="test4", state_key=""), |
571 | 566 | ] |
572 | 567 | |
568 | store = StateGroupStore() | |
569 | store.register_events(old_state_1) | |
570 | store.register_events(old_state_2) | |
571 | self.store.get_events = store.get_events | |
572 | ||
573 | 573 | context = yield self._get_context(event, old_state_1, old_state_2) |
574 | 574 | |
575 | self.assertEqual(len(context.current_state), 6) | |
576 | ||
577 | self.assertIsNone(context.state_group) | |
575 | self.assertEqual(len(context.current_state_ids), 6) | |
576 | ||
577 | self.assertIsNotNone(context.state_group) | |
578 | 578 | |
579 | 579 | @defer.inlineCallbacks |
580 | 580 | def test_standard_depth_conflict(self): |
605 | 605 | create_event(type="test1", state_key="1", depth=2), |
606 | 606 | ] |
607 | 607 | |
608 | store = StateGroupStore() | |
609 | store.register_events(old_state_1) | |
610 | store.register_events(old_state_2) | |
611 | self.store.get_events = store.get_events | |
612 | ||
608 | 613 | context = yield self._get_context(event, old_state_1, old_state_2) |
609 | 614 | |
610 | self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) | |
615 | self.assertEqual( | |
616 | old_state_2[2].event_id, context.current_state_ids[("test1", "1")] | |
617 | ) | |
611 | 618 | |
612 | 619 | # Reverse the depth to make sure we are actually using the depths |
613 | 620 | # during state resolution. |
624 | 631 | create_event(type="test1", state_key="1", depth=1), |
625 | 632 | ] |
626 | 633 | |
634 | store.register_events(old_state_1) | |
635 | store.register_events(old_state_2) | |
636 | ||
627 | 637 | context = yield self._get_context(event, old_state_1, old_state_2) |
628 | 638 | |
629 | self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) | |
639 | self.assertEqual( | |
640 | old_state_1[2].event_id, context.current_state_ids[("test1", "1")] | |
641 | ) | |
630 | 642 | |
631 | 643 | def _get_context(self, event, old_state_1, old_state_2): |
632 | 644 | group_name_1 = "group_name_1" |
633 | 645 | group_name_2 = "group_name_2" |
634 | 646 | |
635 | self.store.get_state_groups.return_value = { | |
636 | group_name_1: old_state_1, | |
637 | group_name_2: old_state_2, | |
647 | self.store.get_state_groups_ids.return_value = { | |
648 | group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, | |
649 | group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, | |
638 | 650 | } |
639 | 651 | |
640 | 652 | return self.state.compute_event_context(event) |