matrix-synapse / c1c15ad
Imported Upstream version 0.17.1 Erik Johnston 7 years ago
98 changed file(s) with 3357 addition(s) and 1510 deletion(s).
0 Changes in synapse v0.17.1 (2016-08-24)
1 =======================================
2
3 Changes:
4
5 * Delete old received_transactions rows (PR #1038)
6 * Pass through user-supplied content in /join/$room_id (PR #1039)
7
8
9 Bug fixes:
10
11 * Fix bug with backfill (PR #1040)
12
13
14 Changes in synapse v0.17.1-rc1 (2016-08-22)
15 ===========================================
16
17 Features:
18
19 * Add notification API (PR #1028)
20
21
22 Changes:
23
24 * Don't print stack traces when failing to get remote keys (PR #996)
25 * Various federation /event/ perf improvements (PR #998)
26 * Only process one local membership event per room at a time (PR #1005)
27 * Move default display name push rule (PR #1011, #1023)
28 * Fix up preview URL API. Add tests. (PR #1015)
29 * Set ``Content-Security-Policy`` on media repo (PR #1021)
30 * Make notify_interested_services faster (PR #1022)
31 * Add usage stats to prometheus monitoring (PR #1037)
32
33
34 Bug fixes:
35
36 * Fix token login (PR #993)
37 * Fix CAS login (PR #994, #995)
38 * Fix /sync to not clobber status_msg (PR #997)
39 * Fix redacted state events to include prev_content (PR #1003)
40 * Fix some bugs in the auth/ldap handler (PR #1007)
41 * Fix backfill request to limit URI length, so that remotes don't reject the
42 requests due to path length limits (PR #1012)
43 * Fix AS push code to not send duplicate events (PR #1025)
44
45
46
047 Changes in synapse v0.17.0 (2016-08-08)
148 =======================================
249
9494 System requirements:
9595 - POSIX-compliant system (tested on Linux & OS X)
9696 - Python 2.7
97 - At least 512 MB RAM.
97 - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
9898
9999 Synapse is written in python but some of the libraries it uses are written in
100100 C. So before we can install synapse itself we need a working C compiler and the
0 Scaling synapse via workers
1 ---------------------------
2
3 Synapse has experimental support for splitting out functionality into
4 multiple separate python processes, helping greatly with scalability. These
5 processes are called 'workers', and are (eventually) intended to scale
6 horizontally independently.
7
8 All processes continue to share the same database instance, and as such, workers
9 only work with postgres based synapse deployments (sharing a single sqlite
10 across multiple processes is a recipe for disaster, plus you should be using
11 postgres anyway if you care about scalability).
12
13 The workers communicate with the master synapse process via a synapse-specific
14 HTTP protocol called 'replication' - analogous to MySQL or Postgres style
15 database replication - which feeds a stream of relevant data to the workers so
16 they can be kept in sync with the main synapse process and database state.
17
18 To enable workers, you need to add a replication listener to the master synapse, e.g.::
19
20 listeners:
21 - port: 9092
22 bind_address: '127.0.0.1'
23 type: http
24 tls: false
25 x_forwarded: false
26 resources:
27 - names: [replication]
28 compress: false
29
30 Under **no circumstances** should this replication API listener be exposed to the
31 public internet; it currently implements no authentication whatsoever and is
32 unencrypted HTTP.
33
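Workers keep themselves up to date by long-polling this replication listener.
For illustration, the polling loop looks roughly like this (a simplified sketch
using the ``requests`` library rather than synapse's internal HTTP client; the
``process()`` handler is hypothetical, and the response shape follows the
worker code in ``synapse/app/``)::

    import requests

    REPLICATION_URL = "http://127.0.0.1:9092/_synapse/replication"

    positions = {}                   # last-seen position for each stream
    while True:
        args = dict(positions)
        args["timeout"] = 30000      # ask the master to hold the request open
        result = requests.get(REPLICATION_URL, params=args).json()
        for stream_name, stream in result.items():
            # each stream carries new rows plus the position to resume from
            positions[stream_name] = stream["position"]
            process(stream_name, stream["rows"])
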
34 You then create a set of configs for the various worker processes. These
35 worker configuration files should be stored in a dedicated subdirectory, to
36 allow synctl to manipulate them.
37
38 The currently available worker applications are:
39 * synapse.app.pusher - handles sending push notifications to sygnal and email
40 * synapse.app.synchrotron - handles /sync endpoints; can scale horizontally across multiple instances.
41 * synapse.app.appservice - handles output traffic to Application Services
42 * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
43 * synapse.app.media_repository - handles the media repository.
44
45 Each worker configuration file inherits the configuration of the main homeserver
46 configuration file. You can then override configuration specific to that worker,
47 e.g. the HTTP listener that it provides (if any); logging configuration; etc.
48 You should minimise the number of overrides though to maintain a usable config.
49
50 You must specify the type of worker application (worker_app) and the replication
51 endpoint that it's talking to on the main synapse process (worker_replication_url).
52
53 For instance::
54
55 worker_app: synapse.app.synchrotron
56
57 # The replication listener on the synapse to talk to.
58 worker_replication_url: http://127.0.0.1:9092/_synapse/replication
59
60 worker_listeners:
61 - type: http
62 port: 8083
63 resources:
64 - names:
65 - client
66
67 worker_daemonize: True
68 worker_pid_file: /home/matrix/synapse/synchrotron.pid
69 worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
70
71 ...is a full configuration for a synchrotron worker instance, which will expose a
72 plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
73 by the main synapse.
74
75 Obviously you should configure your load balancer to route the /sync endpoint to
76 the synchrotron instance(s) in this setup.
77
78 Finally, to actually run your worker-based synapse, you must pass synctl the -a
79 commandline option to tell it to operate on all the worker configurations found
80 in the given directory, e.g.::
81
82 synctl -a $CONFIG/workers start
83
84 Currently one should always restart all workers when restarting or upgrading
85 synapse, unless you explicitly know it's safe not to. For instance, restarting
86 synapse without restarting all the synchrotrons may result in broken typing
87 notifications.
88
89 To manipulate a specific worker, you pass the -w option to synctl::
90
91 synctl -w $CONFIG/workers/synchrotron.yaml restart
92
93 All of the above is highly experimental and subject to change as Synapse evolves,
94 but is documented here to help folks needing highly scalable Synapses similar
95 to the one running matrix.org!
96
1313 tox -e py27 --notest -v
1414
1515 TOX_BIN=$TOX_DIR/py27/bin
16 $TOX_BIN/pip install setuptools
1617 python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
1718 $TOX_BIN/pip install lxml
1819 $TOX_BIN/pip install psycopg2
2424 tox --notest -e py27
2525 TOX_BIN=$WORKSPACE/.tox/py27/bin
2626 python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
27 $TOX_BIN/pip install lxml
2728
2829 tox -e py27
1515 """ This is a reference implementation of a Matrix home server.
1616 """
1717
18 __version__ = "0.17.0"
18 __version__ = "0.17.1"
674674 try:
675675 macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
676676
677 user_prefix = "user_id = "
678 user = None
679 user_id = None
680 guest = False
681 for caveat in macaroon.caveats:
682 if caveat.caveat_id.startswith(user_prefix):
683 user_id = caveat.caveat_id[len(user_prefix):]
684 user = UserID.from_string(user_id)
685 elif caveat.caveat_id == "guest = true":
686 guest = True
677 user_id = self.get_user_id_from_macaroon(macaroon)
678 user = UserID.from_string(user_id)
687679
688680 self.validate_macaroon(
689681 macaroon, rights, self.hs.config.expire_access_token,
690682 user_id=user_id,
691683 )
692684
693 if user is None:
694 raise AuthError(
695 self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
696 errcode=Codes.UNKNOWN_TOKEN
697 )
685 guest = False
686 for caveat in macaroon.caveats:
687 if caveat.caveat_id == "guest = true":
688 guest = True
698689
699690 if guest:
700691 ret = {
742733 errcode=Codes.UNKNOWN_TOKEN
743734 )
744735
736 def get_user_id_from_macaroon(self, macaroon):
737 """Retrieve the user_id given by the caveats on the macaroon.
738
739 Does *not* validate the macaroon.
740
741 Args:
742 macaroon (pymacaroons.Macaroon): The macaroon to extract the user_id from
743
744 Returns:
745 (str) user id
746
747 Raises:
748 AuthError if there is no user_id caveat in the macaroon
749 """
750 user_prefix = "user_id = "
751 for caveat in macaroon.caveats:
752 if caveat.caveat_id.startswith(user_prefix):
753 return caveat.caveat_id[len(user_prefix):]
754 raise AuthError(
755 self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
756 errcode=Codes.UNKNOWN_TOKEN
757 )
758
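For context, a minimal sketch of how a macaroon carrying these caveats is
built and read back with pymacaroons (the server name, secret and user id are
illustrative; the caveat strings match the prefixes used above)::

    import pymacaroons

    macaroon = pymacaroons.Macaroon(
        location="example.com",          # illustrative server name
        identifier="key",
        key="macaroon_secret_key",       # illustrative signing secret
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")

    restored = pymacaroons.Macaroon.deserialize(macaroon.serialize())
    # get_user_id_from_macaroon() scans the caveats for the "user_id = "
    # prefix, so this recovers "@alice:example.com"
    user_id = next(
        c.caveat_id[len("user_id = "):]
        for c in restored.caveats
        if c.caveat_id.startswith("user_id = ")
    )
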
745759 def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
746760 """
747761 validate that a Macaroon is understood by and was signed by this server.
753767 verify_expiry(bool): Whether to verify whether the macaroon has expired.
754768 This should really always be True, but no clients currently implement
755769 token refresh, so we can't enforce expiry yet.
770 user_id (str): The user_id required
756771 """
757772 v = pymacaroons.Verifier()
758773 v.satisfy_exact("gen = 1")
0 #!/usr/bin/env python
1 # -*- coding: utf-8 -*-
2 # Copyright 2016 OpenMarket Ltd
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 import synapse
17
18 from synapse.server import HomeServer
19 from synapse.config._base import ConfigError
20 from synapse.config.logger import setup_logging
21 from synapse.config.homeserver import HomeServerConfig
22 from synapse.http.site import SynapseSite
23 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
24 from synapse.replication.slave.storage.directory import DirectoryStore
25 from synapse.replication.slave.storage.events import SlavedEventStore
26 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
27 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
28 from synapse.storage.engines import create_engine
29 from synapse.util.async import sleep
30 from synapse.util.httpresourcetree import create_resource_tree
31 from synapse.util.logcontext import LoggingContext
32 from synapse.util.manhole import manhole
33 from synapse.util.rlimit import change_resource_limit
34 from synapse.util.versionstring import get_version_string
35
36 from twisted.internet import reactor, defer
37 from twisted.web.resource import Resource
38
39 from daemonize import Daemonize
40
41 import sys
42 import logging
43 import gc
44
45 logger = logging.getLogger("synapse.app.appservice")
46
47
48 class AppserviceSlaveStore(
49 DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
50 SlavedRegistrationStore,
51 ):
52 pass
53
54
55 class AppserviceServer(HomeServer):
56 def get_db_conn(self, run_new_connection=True):
57 # Any param beginning with cp_ is a parameter for adbapi, and should
58 # not be passed to the database engine.
59 db_params = {
60 k: v for k, v in self.db_config.get("args", {}).items()
61 if not k.startswith("cp_")
62 }
63 db_conn = self.database_engine.module.connect(**db_params)
64
65 if run_new_connection:
66 self.database_engine.on_new_connection(db_conn)
67 return db_conn
68
69 def setup(self):
70 logger.info("Setting up.")
71 self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
72 logger.info("Finished setting up.")
73
74 def _listen_http(self, listener_config):
75 port = listener_config["port"]
76 bind_address = listener_config.get("bind_address", "")
77 site_tag = listener_config.get("tag", port)
78 resources = {}
79 for res in listener_config["resources"]:
80 for name in res["names"]:
81 if name == "metrics":
82 resources[METRICS_PREFIX] = MetricsResource(self)
83
84 root_resource = create_resource_tree(resources, Resource())
85 reactor.listenTCP(
86 port,
87 SynapseSite(
88 "synapse.access.http.%s" % (site_tag,),
89 site_tag,
90 listener_config,
91 root_resource,
92 ),
93 interface=bind_address
94 )
95 logger.info("Synapse appservice now listening on port %d", port)
96
97 def start_listening(self, listeners):
98 for listener in listeners:
99 if listener["type"] == "http":
100 self._listen_http(listener)
101 elif listener["type"] == "manhole":
102 reactor.listenTCP(
103 listener["port"],
104 manhole(
105 username="matrix",
106 password="rabbithole",
107 globals={"hs": self},
108 ),
109 interface=listener.get("bind_address", '127.0.0.1')
110 )
111 else:
112 logger.warn("Unrecognized listener type: %s", listener["type"])
113
114 @defer.inlineCallbacks
115 def replicate(self):
116 http_client = self.get_simple_http_client()
117 store = self.get_datastore()
118 replication_url = self.config.worker_replication_url
119 appservice_handler = self.get_application_service_handler()
120
121 @defer.inlineCallbacks
122 def replicate(results):
123 stream = results.get("events")
124 if stream:
125 max_stream_id = stream["position"]
126 yield appservice_handler.notify_interested_services(max_stream_id)
127
128 while True:
129 try:
130 args = store.stream_positions()
131 args["timeout"] = 30000
132 result = yield http_client.get_json(replication_url, args=args)
133 yield store.process_replication(result)
134 replicate(result)
135 except:
136 logger.exception("Error replicating from %r", replication_url)
137 yield sleep(30)
138
139
140 def start(config_options):
141 try:
142 config = HomeServerConfig.load_config(
143 "Synapse appservice", config_options
144 )
145 except ConfigError as e:
146 sys.stderr.write("\n" + e.message + "\n")
147 sys.exit(1)
148
149 assert config.worker_app == "synapse.app.appservice"
150
151 setup_logging(config.worker_log_config, config.worker_log_file)
152
153 database_engine = create_engine(config.database_config)
154
155 if config.notify_appservices:
156 sys.stderr.write(
157 "\nThe appservices must be disabled in the main synapse process"
158 "\nbefore they can be run in a separate worker."
159 "\nPlease add ``notify_appservices: false`` to the main config"
160 "\n"
161 )
162 sys.exit(1)
163
164 # Force appservice notifications on for this worker, since they are disabled in the main config
165 config.notify_appservices = True
166
167 ps = AppserviceServer(
168 config.server_name,
169 db_config=config.database_config,
170 config=config,
171 version_string="Synapse/" + get_version_string(synapse),
172 database_engine=database_engine,
173 )
174
175 ps.setup()
176 ps.start_listening(config.worker_listeners)
177
178 def run():
179 with LoggingContext("run"):
180 logger.info("Running")
181 change_resource_limit(config.soft_file_limit)
182 if config.gc_thresholds:
183 gc.set_threshold(*config.gc_thresholds)
184 reactor.run()
185
186 def start():
187 ps.replicate()
188 ps.get_datastore().start_profiling()
189
190 reactor.callWhenRunning(start)
191
192 if config.worker_daemonize:
193 daemon = Daemonize(
194 app="synapse-appservice",
195 pid=config.worker_pid_file,
196 action=run,
197 auto_close_fds=False,
198 verbose=True,
199 logger=logger,
200 )
201 daemon.start()
202 else:
203 run()
204
205
206 if __name__ == '__main__':
207 with LoggingContext("main"):
208 start(sys.argv[1:])
5050 from synapse.config.homeserver import HomeServerConfig
5151 from synapse.crypto import context_factory
5252 from synapse.util.logcontext import LoggingContext
53 from synapse.metrics import register_memory_metrics
53 from synapse.metrics import register_memory_metrics, get_metrics_for
5454 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
5555 from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
5656 from synapse.federation.transport.server import TransportLayerServer
384384
385385 start_time = hs.get_clock().time()
386386
387 stats = {}
388
387389 @defer.inlineCallbacks
388390 def phone_stats_home():
389391 logger.info("Gathering stats for reporting")
392394 if uptime < 0:
393395 uptime = 0
394396
395 stats = {}
397 # If the stats dict is empty then this is the first time we've
398 # reported stats.
399 first_time = not stats
400
396401 stats["homeserver"] = hs.config.server_name
397402 stats["timestamp"] = now
398403 stats["uptime_seconds"] = uptime
405410 daily_messages = yield hs.get_datastore().count_daily_messages()
406411 if daily_messages is not None:
407412 stats["daily_messages"] = daily_messages
413 else:
414 stats.pop("daily_messages", None)
415
416 if first_time:
417 # Add callbacks to report the synapse stats as metrics whenever
418 # prometheus requests them, typically every 30s.
419 # As some of the stats are expensive to calculate we only update
420 # them when synapse phones home to matrix.org every 24 hours.
421 metrics = get_metrics_for("synapse.usage")
422 metrics.add_callback("timestamp", lambda: stats["timestamp"])
423 metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
424 metrics.add_callback("total_users", lambda: stats["total_users"])
425 metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
426 metrics.add_callback(
427 "daily_active_users", lambda: stats["daily_active_users"]
428 )
429 metrics.add_callback(
430 "daily_messages", lambda: stats.get("daily_messages", 0)
431 )
408432
409433 logger.info("Reporting stats to matrix.org: %s" % (stats,))
410434 try:
0 #!/usr/bin/env python
1 # -*- coding: utf-8 -*-
2 # Copyright 2016 OpenMarket Ltd
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 import synapse
17
18 from synapse.config._base import ConfigError
19 from synapse.config.homeserver import HomeServerConfig
20 from synapse.config.logger import setup_logging
21 from synapse.http.site import SynapseSite
22 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
23 from synapse.replication.slave.storage._base import BaseSlavedStore
24 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
25 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
26 from synapse.rest.media.v0.content_repository import ContentRepoResource
27 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
28 from synapse.server import HomeServer
29 from synapse.storage.client_ips import ClientIpStore
30 from synapse.storage.engines import create_engine
31 from synapse.storage.media_repository import MediaRepositoryStore
32 from synapse.util.async import sleep
33 from synapse.util.httpresourcetree import create_resource_tree
34 from synapse.util.logcontext import LoggingContext
35 from synapse.util.manhole import manhole
36 from synapse.util.rlimit import change_resource_limit
37 from synapse.util.versionstring import get_version_string
38 from synapse.api.urls import (
39 CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
40 )
41 from synapse.crypto import context_factory
42
43
44 from twisted.internet import reactor, defer
45 from twisted.web.resource import Resource
46
47 from daemonize import Daemonize
48
49 import sys
50 import logging
51 import gc
52
53 logger = logging.getLogger("synapse.app.media_repository")
54
55
56 class MediaRepositorySlavedStore(
57 SlavedApplicationServiceStore,
58 SlavedRegistrationStore,
59 BaseSlavedStore,
60 MediaRepositoryStore,
61 ClientIpStore,
62 ):
63 pass
64
65
66 class MediaRepositoryServer(HomeServer):
67 def get_db_conn(self, run_new_connection=True):
68 # Any param beginning with cp_ is a parameter for adbapi, and should
69 # not be passed to the database engine.
70 db_params = {
71 k: v for k, v in self.db_config.get("args", {}).items()
72 if not k.startswith("cp_")
73 }
74 db_conn = self.database_engine.module.connect(**db_params)
75
76 if run_new_connection:
77 self.database_engine.on_new_connection(db_conn)
78 return db_conn
79
80 def setup(self):
81 logger.info("Setting up.")
82 self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
83 logger.info("Finished setting up.")
84
85 def _listen_http(self, listener_config):
86 port = listener_config["port"]
87 bind_address = listener_config.get("bind_address", "")
88 site_tag = listener_config.get("tag", port)
89 resources = {}
90 for res in listener_config["resources"]:
91 for name in res["names"]:
92 if name == "metrics":
93 resources[METRICS_PREFIX] = MetricsResource(self)
94 elif name == "media":
95 media_repo = MediaRepositoryResource(self)
96 resources.update({
97 MEDIA_PREFIX: media_repo,
98 LEGACY_MEDIA_PREFIX: media_repo,
99 CONTENT_REPO_PREFIX: ContentRepoResource(
100 self, self.config.uploads_path
101 ),
102 })
103
104 root_resource = create_resource_tree(resources, Resource())
105 reactor.listenTCP(
106 port,
107 SynapseSite(
108 "synapse.access.http.%s" % (site_tag,),
109 site_tag,
110 listener_config,
111 root_resource,
112 ),
113 interface=bind_address
114 )
115 logger.info("Synapse media repository now listening on port %d", port)
116
117 def start_listening(self, listeners):
118 for listener in listeners:
119 if listener["type"] == "http":
120 self._listen_http(listener)
121 elif listener["type"] == "manhole":
122 reactor.listenTCP(
123 listener["port"],
124 manhole(
125 username="matrix",
126 password="rabbithole",
127 globals={"hs": self},
128 ),
129 interface=listener.get("bind_address", '127.0.0.1')
130 )
131 else:
132 logger.warn("Unrecognized listener type: %s", listener["type"])
133
134 @defer.inlineCallbacks
135 def replicate(self):
136 http_client = self.get_simple_http_client()
137 store = self.get_datastore()
138 replication_url = self.config.worker_replication_url
139
140 while True:
141 try:
142 args = store.stream_positions()
143 args["timeout"] = 30000
144 result = yield http_client.get_json(replication_url, args=args)
145 yield store.process_replication(result)
146 except:
147 logger.exception("Error replicating from %r", replication_url)
148 yield sleep(5)
149
150
151 def start(config_options):
152 try:
153 config = HomeServerConfig.load_config(
154 "Synapse media repository", config_options
155 )
156 except ConfigError as e:
157 sys.stderr.write("\n" + e.message + "\n")
158 sys.exit(1)
159
160 assert config.worker_app == "synapse.app.media_repository"
161
162 setup_logging(config.worker_log_config, config.worker_log_file)
163
164 database_engine = create_engine(config.database_config)
165
166 tls_server_context_factory = context_factory.ServerContextFactory(config)
167
168 ss = MediaRepositoryServer(
169 config.server_name,
170 db_config=config.database_config,
171 tls_server_context_factory=tls_server_context_factory,
172 config=config,
173 version_string="Synapse/" + get_version_string(synapse),
174 database_engine=database_engine,
175 )
176
177 ss.setup()
178 ss.get_handlers()
179 ss.start_listening(config.worker_listeners)
180
181 def run():
182 with LoggingContext("run"):
183 logger.info("Running")
184 change_resource_limit(config.soft_file_limit)
185 if config.gc_thresholds:
186 gc.set_threshold(*config.gc_thresholds)
187 reactor.run()
188
189 def start():
190 ss.get_datastore().start_profiling()
191 ss.replicate()
192
193 reactor.callWhenRunning(start)
194
195 if config.worker_daemonize:
196 daemon = Daemonize(
197 app="synapse-media-repository",
198 pid=config.worker_pid_file,
199 action=run,
200 auto_close_fds=False,
201 verbose=True,
202 logger=logger,
203 )
204 daemon.start()
205 else:
206 run()
207
208
209 if __name__ == '__main__':
210 with LoggingContext("main"):
211 start(sys.argv[1:])
7979 DataStore.get_profile_displayname.__func__
8080 )
8181
82 # XXX: This is a bit broken because we don't persist forgotten rooms
83 # in a way that they can be streamed. This means that we don't have a
84 # way to invalidate the forgotten rooms cache correctly.
85 # For now we expire the cache every 10 minutes.
86 BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
8782 who_forgot_in_room = (
8883 RoomMemberStore.__dict__["who_forgot_in_room"]
8984 )
167162 store = self.get_datastore()
168163 replication_url = self.config.worker_replication_url
169164 pusher_pool = self.get_pusherpool()
170 clock = self.get_clock()
171165
172166 def stop_pusher(user_id, app_id, pushkey):
173167 key = "%s:%s" % (app_id, pushkey)
219213 min_stream_id, max_stream_id, affected_room_ids
220214 )
221215
222 def expire_broken_caches():
223 store.who_forgot_in_room.invalidate_all()
224
225 next_expire_broken_caches_ms = 0
226216 while True:
227217 try:
228218 args = store.stream_positions()
229219 args["timeout"] = 30000
230220 result = yield http_client.get_json(replication_url, args=args)
231 now_ms = clock.time_msec()
232 if now_ms > next_expire_broken_caches_ms:
233 expire_broken_caches()
234 next_expire_broken_caches_ms = (
235 now_ms + store.BROKEN_CACHE_EXPIRY_MS
236 )
237221 yield store.process_replication(result)
238222 poke_pushers(result)
239223 except:
2525 from synapse.http.server import JsonResource
2626 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
2727 from synapse.rest.client.v2_alpha import sync
28 from synapse.rest.client.v1 import events
2829 from synapse.replication.slave.storage._base import BaseSlavedStore
2930 from synapse.replication.slave.storage.events import SlavedEventStore
3031 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
7374 BaseSlavedStore,
7475 ClientIpStore, # After BaseSlavedStore because the constructor is different
7576 ):
76 # XXX: This is a bit broken because we don't persist forgotten rooms
77 # in a way that they can be streamed. This means that we don't have a
78 # way to invalidate the forgotten rooms cache correctly.
79 # For now we expire the cache every 10 minutes.
80 BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
8177 who_forgot_in_room = (
8278 RoomMemberStore.__dict__["who_forgot_in_room"]
8379 )
8884 get_presence_list_accepted = PresenceStore.__dict__[
8985 "get_presence_list_accepted"
9086 ]
87 get_presence_list_observers_accepted = PresenceStore.__dict__[
88 "get_presence_list_observers_accepted"
89 ]
90
9191
9292 UPDATE_SYNCING_USERS_MS = 10 * 1000
9393
9494
9595 class SynchrotronPresence(object):
9696 def __init__(self, hs):
97 self.is_mine_id = hs.is_mine_id
9798 self.http_client = hs.get_simple_http_client()
9899 self.store = hs.get_datastore()
99100 self.user_to_num_current_syncs = {}
100101 self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
101102 self.clock = hs.get_clock()
103 self.notifier = hs.get_notifier()
102104
103105 active_presence = self.store.take_presence_startup_info()
104106 self.user_to_current_state = {
118120
119121 reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
120122
121 def set_state(self, user, state):
123 def set_state(self, user, state, ignore_status_msg=False):
122124 # TODO Hows this supposed to work?
123125 pass
124126
125127 get_states = PresenceHandler.get_states.__func__
128 get_state = PresenceHandler.get_state.__func__
129 _get_interested_parties = PresenceHandler._get_interested_parties.__func__
126130 current_state_for_users = PresenceHandler.current_state_for_users.__func__
127131
128132 @defer.inlineCallbacks
193197 self._need_to_send_sync = False
194198 yield self._send_syncing_users_now()
195199
200 @defer.inlineCallbacks
201 def notify_from_replication(self, states, stream_id):
202 parties = yield self._get_interested_parties(
203 states, calculate_remote_hosts=False
204 )
205 room_ids_to_states, users_to_states, _ = parties
206
207 self.notifier.on_new_event(
208 "presence_key", stream_id, rooms=room_ids_to_states.keys(),
209 users=users_to_states.keys()
210 )
211
212 @defer.inlineCallbacks
196213 def process_replication(self, result):
197214 stream = result.get("presence", {"rows": []})
215 states = []
198216 for row in stream["rows"]:
199217 (
200218 position, user_id, state, last_active_ts,
201219 last_federation_update_ts, last_user_sync_ts, status_msg,
202220 currently_active
203221 ) = row
204 self.user_to_current_state[user_id] = UserPresenceState(
222 state = UserPresenceState(
205223 user_id, state, last_active_ts,
206224 last_federation_update_ts, last_user_sync_ts, status_msg,
207225 currently_active
208226 )
227 self.user_to_current_state[user_id] = state
228 states.append(state)
229
230 if states and "position" in stream:
231 stream_id = int(stream["position"])
232 yield self.notify_from_replication(states, stream_id)
209233
210234
211235 class SynchrotronTyping(object):
265289 elif name == "client":
266290 resource = JsonResource(self, canonical_json=False)
267291 sync.register_servlets(self, resource)
292 events.register_servlets(self, resource)
268293 resources.update({
269294 "/_matrix/client/r0": resource,
270295 "/_matrix/client/unstable": resource,
271296 "/_matrix/client/v2_alpha": resource,
297 "/_matrix/client/api/v1": resource,
272298 })
273299
274300 root_resource = create_resource_tree(resources, Resource())
306332 http_client = self.get_simple_http_client()
307333 store = self.get_datastore()
308334 replication_url = self.config.worker_replication_url
309 clock = self.get_clock()
310335 notifier = self.get_notifier()
311336 presence_handler = self.get_presence_handler()
312337 typing_handler = self.get_typing_handler()
313
314 def expire_broken_caches():
315 store.who_forgot_in_room.invalidate_all()
316 store.get_presence_list_accepted.invalidate_all()
317338
318339 def notify_from_stream(
319340 result, stream_name, stream_key, room=None, user=None
376397 result, "typing", "typing_key", room="room_id"
377398 )
378399
379 next_expire_broken_caches_ms = 0
380400 while True:
381401 try:
382402 args = store.stream_positions()
383403 args.update(typing_handler.stream_positions())
384404 args["timeout"] = 30000
385405 result = yield http_client.get_json(replication_url, args=args)
386 now_ms = clock.time_msec()
387 if now_ms > next_expire_broken_caches_ms:
388 expire_broken_caches()
389 next_expire_broken_caches_ms = (
390 now_ms + store.BROKEN_CACHE_EXPIRY_MS
391 )
392406 yield store.process_replication(result)
393407 typing_handler.process_replication(result)
394 presence_handler.process_replication(result)
408 yield presence_handler.process_replication(result)
395409 notify(result)
396410 except:
397411 logger.exception("Error replicating from %r", replication_url)
1313 # limitations under the License.
1414 from synapse.api.constants import EventTypes
1515
16 from twisted.internet import defer
17
1618 import logging
1719 import re
1820
7880 NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
7981
8082 def __init__(self, token, url=None, namespaces=None, hs_token=None,
81 sender=None, id=None):
83 sender=None, id=None, protocols=None):
8284 self.token = token
8385 self.url = url
8486 self.hs_token = hs_token
8587 self.sender = sender
8688 self.namespaces = self._check_namespaces(namespaces)
8789 self.id = id
90 if protocols:
91 self.protocols = set(protocols)
92 else:
93 self.protocols = set()
8894
8995 def _check_namespaces(self, namespaces):
9096 # Sanity check that it is of the form:
137143 return regex_obj["exclusive"]
138144 return False
139145
140 def _matches_user(self, event, member_list):
141 if (hasattr(event, "sender") and
142 self.is_interested_in_user(event.sender)):
143 return True
146 @defer.inlineCallbacks
147 def _matches_user(self, event, store):
148 if not event:
149 defer.returnValue(False)
150
151 if self.is_interested_in_user(event.sender):
152 defer.returnValue(True)
144153 # also check m.room.member state key
145 if (hasattr(event, "type") and event.type == EventTypes.Member
146 and hasattr(event, "state_key")
147 and self.is_interested_in_user(event.state_key)):
148 return True
154 if (event.type == EventTypes.Member and
155 self.is_interested_in_user(event.state_key)):
156 defer.returnValue(True)
157
158 if not store:
159 defer.returnValue(False)
160
161 member_list = yield store.get_users_in_room(event.room_id)
162
149163 # check joined member events
150164 for user_id in member_list:
151165 if self.is_interested_in_user(user_id):
152 return True
153 return False
166 defer.returnValue(True)
167 defer.returnValue(False)
154168
155169 def _matches_room_id(self, event):
156170 if hasattr(event, "room_id"):
157171 return self.is_interested_in_room(event.room_id)
158172 return False
159173
160 def _matches_aliases(self, event, alias_list):
174 @defer.inlineCallbacks
175 def _matches_aliases(self, event, store):
176 if not store or not event:
177 defer.returnValue(False)
178
179 alias_list = yield store.get_aliases_for_room(event.room_id)
161180 for alias in alias_list:
162181 if self.is_interested_in_alias(alias):
163 return True
164 return False
165
166 def is_interested(self, event, restrict_to=None, aliases_for_event=None,
167 member_list=None):
182 defer.returnValue(True)
183 defer.returnValue(False)
184
185 @defer.inlineCallbacks
186 def is_interested(self, event, store=None):
168187 """Check if this service is interested in this event.
169188
170189 Args:
171190 event(Event): The event to check.
172 restrict_to(str): The namespace to restrict regex tests to.
173 aliases_for_event(list): A list of all the known room aliases for
174 this event.
175 member_list(list): A list of all joined user_ids in this room.
191 store(DataStore)
176192 Returns:
177193 bool: True if this service would like to know about this event.
178194 """
179 if aliases_for_event is None:
180 aliases_for_event = []
181 if member_list is None:
182 member_list = []
183
184 if restrict_to and restrict_to not in ApplicationService.NS_LIST:
185 # this is a programming error, so fail early and raise a general
186 # exception
187 raise Exception("Unexpected restrict_to value: %s". restrict_to)
188
189 if not restrict_to:
190 return (self._matches_user(event, member_list)
191 or self._matches_aliases(event, aliases_for_event)
192 or self._matches_room_id(event))
193 elif restrict_to == ApplicationService.NS_ALIASES:
194 return self._matches_aliases(event, aliases_for_event)
195 elif restrict_to == ApplicationService.NS_ROOMS:
196 return self._matches_room_id(event)
197 elif restrict_to == ApplicationService.NS_USERS:
198 return self._matches_user(event, member_list)
195 # Do cheap checks first
196 if self._matches_room_id(event):
197 defer.returnValue(True)
198
199 if (yield self._matches_aliases(event, store)):
200 defer.returnValue(True)
201
202 if (yield self._matches_user(event, store)):
203 defer.returnValue(True)
204
205 defer.returnValue(False)
199206
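Since ``is_interested`` now returns a Deferred and performs its own store
lookups, callers invoke it along these lines (a sketch; everything other than
``is_interested`` itself is illustrative)::

    @defer.inlineCallbacks
    def _maybe_send_to_service(service, event, store):
        # replaces the old synchronous call that took restrict_to,
        # aliases_for_event and member_list arguments
        if (yield service.is_interested(event, store=store)):
            submit_event_for_as(service, event)   # illustrative handler
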
200207 def is_interested_in_user(self, user_id):
201208 return (
215222 or user_id == self.sender
216223 )
217224
225 def is_interested_in_protocol(self, protocol):
226 return protocol in self.protocols
227
218228 def is_exclusive_alias(self, alias):
219229 return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
220230
1616 from synapse.api.errors import CodeMessageException
1717 from synapse.http.client import SimpleHttpClient
1818 from synapse.events.utils import serialize_event
19 from synapse.types import ThirdPartyEntityKind
1920
2021 import logging
2122 import urllib
2223
2324 logger = logging.getLogger(__name__)
25
26
27 def _is_valid_3pe_result(r, field):
28 if not isinstance(r, dict):
29 return False
30
31 for k in (field, "protocol"):
32 if k not in r:
33 return False
34 if not isinstance(r[k], str):
35 return False
36
37 if "fields" not in r:
38 return False
39 fields = r["fields"]
40 if not isinstance(fields, dict):
41 return False
42 for k in fields.keys():
43 if not isinstance(fields[k], str):
44 return False
45
46 return True
2447
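For reference, a 3PU result that would pass ``_is_valid_3pe_result`` looks
something like the following (protocol and field values are illustrative; for
3PL lookups the required field is ``alias`` rather than ``userid``, as in
``query_3pe`` below)::

    {
        "protocol": "irc",
        "userid": "@_irc_alice:example.org",   # the required field for 3PU
        "fields": {"nickname": "alice"},       # all values must be strings
    }
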
2548
2649 class ApplicationServiceApi(SimpleHttpClient):
7194 defer.returnValue(False)
7295
7396 @defer.inlineCallbacks
97 def query_3pe(self, service, kind, protocol, fields):
98 if kind == ThirdPartyEntityKind.USER:
99 uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
100 required_field = "userid"
101 elif kind == ThirdPartyEntityKind.LOCATION:
102 uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
103 required_field = "alias"
104 else:
105 raise ValueError(
106 "Unrecognised 'kind' argument %r to query_3pe()", kind
107 )
108
109 try:
110 response = yield self.get_json(uri, fields)
111 if not isinstance(response, list):
112 logger.warning(
113 "query_3pe to %s returned an invalid response %r",
114 uri, response
115 )
116 defer.returnValue([])
117
118 ret = []
119 for r in response:
120 if _is_valid_3pe_result(r, field=required_field):
121 ret.append(r)
122 else:
123 logger.warning(
124 "query_3pe to %s returned an invalid result %r",
125 uri, r
126 )
127
128 defer.returnValue(ret)
129 except Exception as ex:
130 logger.warning("query_3pe to %s threw exception %s", uri, ex)
131 defer.returnValue([])
132
133 @defer.inlineCallbacks
74134 def push_bulk(self, service, events, txn_id=None):
75135 events = self._serialize(events)
76136
4747 This is all tied together by the AppServiceScheduler which DIs the required
4848 components.
4949 """
50 from twisted.internet import defer
5051
5152 from synapse.appservice import ApplicationServiceState
52 from twisted.internet import defer
53 from synapse.util.logcontext import preserve_fn
54 from synapse.util.metrics import Measure
55
5356 import logging
5457
5558 logger = logging.getLogger(__name__)
7275 self.txn_ctrl = _TransactionController(
7376 self.clock, self.store, self.as_api, create_recoverer
7477 )
75 self.queuer = _ServiceQueuer(self.txn_ctrl)
78 self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
7679
7780 @defer.inlineCallbacks
7881 def start(self):
9396 this schedules any other events in the queue to run.
9497 """
9598
96 def __init__(self, txn_ctrl):
99 def __init__(self, txn_ctrl, clock):
97100 self.queued_events = {} # dict of {service_id: [events]}
98 self.pending_requests = {} # dict of {service_id: Deferred}
101 self.requests_in_flight = set()
99102 self.txn_ctrl = txn_ctrl
103 self.clock = clock
100104
101105 def enqueue(self, service, event):
102106 # if this service isn't being sent something
103 if not self.pending_requests.get(service.id):
104 self._send_request(service, [event])
105 else:
106 # add to queue for this service
107 if service.id not in self.queued_events:
108 self.queued_events[service.id] = []
109 self.queued_events[service.id].append(event)
110
111 def _send_request(self, service, events):
112 # send request and add callbacks
113 d = self.txn_ctrl.send(service, events)
114 d.addBoth(self._on_request_finish)
115 d.addErrback(self._on_request_fail)
116 self.pending_requests[service.id] = d
117
118 def _on_request_finish(self, service):
119 self.pending_requests[service.id] = None
120 # if there are queued events, then send them.
121 if (service.id in self.queued_events
122 and len(self.queued_events[service.id]) > 0):
123 self._send_request(service, self.queued_events[service.id])
124 self.queued_events[service.id] = []
125
126 def _on_request_fail(self, err):
127 logger.error("AS request failed: %s", err)
107 self.queued_events.setdefault(service.id, []).append(event)
108 preserve_fn(self._send_request)(service)
109
110 @defer.inlineCallbacks
111 def _send_request(self, service):
112 if service.id in self.requests_in_flight:
113 return
114
115 self.requests_in_flight.add(service.id)
116 try:
117 while True:
118 events = self.queued_events.pop(service.id, [])
119 if not events:
120 return
121
122 with Measure(self.clock, "servicequeuer.send"):
123 try:
124 yield self.txn_ctrl.send(service, events)
125 except:
126 logger.exception("AS request failed")
127 finally:
128 self.requests_in_flight.discard(service.id)
128129
129130
130131 class _TransactionController(object):
148149 if service_is_up:
149150 sent = yield txn.send(self.as_api)
150151 if sent:
151 txn.complete(self.store)
152 yield txn.complete(self.store)
152153 else:
153 self._start_recoverer(service)
154 preserve_fn(self._start_recoverer)(service)
154155 except Exception as e:
155156 logger.exception(e)
156 self._start_recoverer(service)
157 # request has finished
158 defer.returnValue(service)
157 preserve_fn(self._start_recoverer)(service)
159158
160159 @defer.inlineCallbacks
161160 def on_recovered(self, recoverer):
2727
2828 def read_config(self, config):
2929 self.app_service_config_files = config.get("app_service_config_files", [])
30 self.notify_appservices = config.get("notify_appservices", True)
3031
3132 def default_config(cls, **kwargs):
3233 return """\
121122 raise ValueError(
122123 "Missing/bad type 'exclusive' key in %s", regex_obj
123124 )
125 # protocols check
126 protocols = as_info.get("protocols")
127 if protocols:
128 # A plain string is iterable too, so reject it explicitly
129 if isinstance(protocols, str) or not isinstance(protocols, list):
130 raise KeyError("Optional 'protocols' must be a list if present.")
131 for p in protocols:
132 if not isinstance(p, str):
133 raise KeyError("Bad value for 'protocols' item")
124134 return ApplicationService(
125135 token=as_info["as_token"],
126136 url=as_info["url"],
128138 hs_token=as_info["hs_token"],
129139 sender=user_id,
130140 id=as_info["id"],
141 protocols=protocols,
131142 )
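For illustration, the loaded registration (``as_info``) for a service that
declares protocols would contain something like the following (values are
illustrative; the keys shown are those referenced in the code above)::

    {
        "id": "irc-bridge",
        "url": "http://localhost:9999",
        "as_token": "...",
        "hs_token": "...",
        "namespaces": {"users": [], "aliases": [], "rooms": []},
        "protocols": ["irc"],    # optional; must be a list of strings
    }
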
2121 preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
2222 preserve_fn
2323 )
24 from synapse.util.metrics import Measure
2425
2526 from twisted.internet import defer
2627
5859 A deferred (server_name, key_id, verify_key) tuple that resolves when
5960 a verify key has been fetched
6061 """
62
63
64 class KeyLookupError(ValueError):
65 pass
6166
6267
6368 class Keyring(object):
238243
239244 @defer.inlineCallbacks
240245 def do_iterations():
241 merged_results = {}
242
243 missing_keys = {}
244 for verify_request in verify_requests:
245 missing_keys.setdefault(verify_request.server_name, set()).update(
246 verify_request.key_ids
247 )
248
249 for fn in key_fetch_fns:
250 results = yield fn(missing_keys.items())
251 merged_results.update(results)
252
253 # We now need to figure out which verify requests we have keys
254 # for and which we don't
246 with Measure(self.clock, "get_server_verify_keys"):
247 merged_results = {}
248
255249 missing_keys = {}
256 requests_missing_keys = []
257250 for verify_request in verify_requests:
258 server_name = verify_request.server_name
259 result_keys = merged_results[server_name]
260
261 if verify_request.deferred.called:
262 # We've already called this deferred, which probably
263 # means that we've already found a key for it.
264 continue
265
266 for key_id in verify_request.key_ids:
267 if key_id in result_keys:
268 with PreserveLoggingContext():
269 verify_request.deferred.callback((
270 server_name,
271 key_id,
272 result_keys[key_id],
273 ))
274 break
275 else:
276 # The else block is only reached if the loop above
277 # doesn't break.
278 missing_keys.setdefault(server_name, set()).update(
279 verify_request.key_ids
280 )
281 requests_missing_keys.append(verify_request)
282
283 if not missing_keys:
284 break
285
286 for verify_request in requests_missing_keys.values():
287 verify_request.deferred.errback(SynapseError(
288 401,
289 "No key for %s with id %s" % (
290 verify_request.server_name, verify_request.key_ids,
291 ),
292 Codes.UNAUTHORIZED,
293 ))
251 missing_keys.setdefault(verify_request.server_name, set()).update(
252 verify_request.key_ids
253 )
254
255 for fn in key_fetch_fns:
256 results = yield fn(missing_keys.items())
257 merged_results.update(results)
258
259 # We now need to figure out which verify requests we have keys
260 # for and which we don't
261 missing_keys = {}
262 requests_missing_keys = []
263 for verify_request in verify_requests:
264 server_name = verify_request.server_name
265 result_keys = merged_results[server_name]
266
267 if verify_request.deferred.called:
268 # We've already called this deferred, which probably
269 # means that we've already found a key for it.
270 continue
271
272 for key_id in verify_request.key_ids:
273 if key_id in result_keys:
274 with PreserveLoggingContext():
275 verify_request.deferred.callback((
276 server_name,
277 key_id,
278 result_keys[key_id],
279 ))
280 break
281 else:
282 # The else block is only reached if the loop above
283 # doesn't break.
284 missing_keys.setdefault(server_name, set()).update(
285 verify_request.key_ids
286 )
287 requests_missing_keys.append(verify_request)
288
289 if not missing_keys:
290 break
291
292 for verify_request in requests_missing_keys.values():
293 verify_request.deferred.errback(SynapseError(
294 401,
295 "No key for %s with id %s" % (
296 verify_request.server_name, verify_request.key_ids,
297 ),
298 Codes.UNAUTHORIZED,
299 ))
294300
295301 def on_err(err):
296302 for verify_request in verify_requests:
301307
302308 @defer.inlineCallbacks
303309 def get_keys_from_store(self, server_name_and_key_ids):
304 res = yield defer.gatherResults(
310 res = yield preserve_context_over_deferred(defer.gatherResults(
305311 [
306 self.store.get_server_verify_keys(
312 preserve_fn(self.store.get_server_verify_keys)(
307313 server_name, key_ids
308314 ).addCallback(lambda ks, server: (server, ks), server_name)
309315 for server_name, key_ids in server_name_and_key_ids
310316 ],
311317 consumeErrors=True,
312 ).addErrback(unwrapFirstError)
318 )).addErrback(unwrapFirstError)
313319
314320 defer.returnValue(dict(res))
315321
330336 )
331337 defer.returnValue({})
332338
333 results = yield defer.gatherResults(
339 results = yield preserve_context_over_deferred(defer.gatherResults(
334340 [
335 get_key(p_name, p_keys)
341 preserve_fn(get_key)(p_name, p_keys)
336342 for p_name, p_keys in self.perspective_servers.items()
337343 ],
338344 consumeErrors=True,
339 ).addErrback(unwrapFirstError)
345 )).addErrback(unwrapFirstError)
340346
341347 union_of_keys = {}
342348 for result in results:
362368 )
363369 except Exception as e:
364370 logger.info(
365 "Unable to getting key %r for %r directly: %s %s",
371 "Unable to get key %r for %r directly: %s %s",
366372 key_ids, server_name,
367373 type(e).__name__, str(e.message),
368374 )
376382
377383 defer.returnValue(keys)
378384
379 results = yield defer.gatherResults(
385 results = yield preserve_context_over_deferred(defer.gatherResults(
380386 [
381 get_key(server_name, key_ids)
387 preserve_fn(get_key)(server_name, key_ids)
382388 for server_name, key_ids in server_name_and_key_ids
383389 ],
384390 consumeErrors=True,
385 ).addErrback(unwrapFirstError)
391 )).addErrback(unwrapFirstError)
386392
387393 merged = {}
388394 for result in results:
424430 for response in responses:
425431 if (u"signatures" not in response
426432 or perspective_name not in response[u"signatures"]):
427 raise ValueError(
433 raise KeyLookupError(
428434 "Key response not signed by perspective server"
429435 " %r" % (perspective_name,)
430436 )
447453 list(response[u"signatures"][perspective_name]),
448454 list(perspective_keys)
449455 )
450 raise ValueError(
456 raise KeyLookupError(
451457 "Response not signed with a known key for perspective"
452458 " server %r" % (perspective_name,)
453459 )
459465 for server_name, response_keys in processed_response.items():
460466 keys.setdefault(server_name, {}).update(response_keys)
461467
462 yield defer.gatherResults(
468 yield preserve_context_over_deferred(defer.gatherResults(
463469 [
464 self.store_keys(
470 preserve_fn(self.store_keys)(
465471 server_name=server_name,
466472 from_server=perspective_name,
467473 verify_keys=response_keys,
469475 for server_name, response_keys in keys.items()
470476 ],
471477 consumeErrors=True
472 ).addErrback(unwrapFirstError)
478 )).addErrback(unwrapFirstError)
473479
474480 defer.returnValue(keys)
475481
490496
491497 if (u"signatures" not in response
492498 or server_name not in response[u"signatures"]):
493 raise ValueError("Key response not signed by remote server")
499 raise KeyLookupError("Key response not signed by remote server")
494500
495501 if "tls_fingerprints" not in response:
496 raise ValueError("Key response missing TLS fingerprints")
502 raise KeyLookupError("Key response missing TLS fingerprints")
497503
498504 certificate_bytes = crypto.dump_certificate(
499505 crypto.FILETYPE_ASN1, tls_certificate
507513 response_sha256_fingerprints.add(fingerprint[u"sha256"])
508514
509515 if sha256_fingerprint_b64 not in response_sha256_fingerprints:
510 raise ValueError("TLS certificate not allowed by fingerprints")
516 raise KeyLookupError("TLS certificate not allowed by fingerprints")
511517
512518 response_keys = yield self.process_v2_response(
513519 from_server=server_name,
517523
518524 keys.update(response_keys)
519525
520 yield defer.gatherResults(
526 yield preserve_context_over_deferred(defer.gatherResults(
521527 [
522528 preserve_fn(self.store_keys)(
523529 server_name=key_server_name,
527533 for key_server_name, verify_keys in keys.items()
528534 ],
529535 consumeErrors=True
530 ).addErrback(unwrapFirstError)
536 )).addErrback(unwrapFirstError)
531537
532538 defer.returnValue(keys)
533539
559565 server_name = response_json["server_name"]
560566 if only_from_server:
561567 if server_name != from_server:
562 raise ValueError(
568 raise KeyLookupError(
563569 "Expected a response for server %r not %r" % (
564570 from_server, server_name
565571 )
566572 )
567573 for key_id in response_json["signatures"].get(server_name, {}):
568574 if key_id not in response_json["verify_keys"]:
569 raise ValueError(
575 raise KeyLookupError(
570576 "Key response must include verification keys for all"
571577 " signatures"
572578 )
593599 response_keys.update(verify_keys)
594600 response_keys.update(old_verify_keys)
595601
596 yield defer.gatherResults(
602 yield preserve_context_over_deferred(defer.gatherResults(
597603 [
598604 preserve_fn(self.store.store_server_keys_json)(
599605 server_name=server_name,
606612 for key_id in updated_key_ids
607613 ],
608614 consumeErrors=True,
609 ).addErrback(unwrapFirstError)
615 )).addErrback(unwrapFirstError)
610616
611617 results[server_name] = response_keys
612618
634640
635641 if ("signatures" not in response
636642 or server_name not in response["signatures"]):
637 raise ValueError("Key response not signed by remote server")
643 raise KeyLookupError("Key response not signed by remote server")
638644
639645 if "tls_certificate" not in response:
640 raise ValueError("Key response missing TLS certificate")
646 raise KeyLookupError("Key response missing TLS certificate")
641647
642648 tls_certificate_b64 = response["tls_certificate"]
643649
644650 if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
645 raise ValueError("TLS certificate doesn't match")
651 raise KeyLookupError("TLS certificate doesn't match")
646652
647653 # Cache the result in the datastore.
648654
658664
659665 for key_id in response["signatures"][server_name]:
660666 if key_id not in response["verify_keys"]:
661 raise ValueError(
667 raise KeyLookupError(
662668 "Key response must include verification keys for all"
663669 " signatures"
664670 )
695701 A deferred that completes when the keys are stored.
696702 """
697703 # TODO(markjh): Store whether the keys have expired.
698 yield defer.gatherResults(
704 yield preserve_context_over_deferred(defer.gatherResults(
699705 [
700706 preserve_fn(self.store.store_server_verify_key)(
701707 server_name, server_name, key.time_added, key
703709 for key_id, key in verify_keys.items()
704710 ],
705711 consumeErrors=True,
706 ).addErrback(unwrapFirstError)
712 )).addErrback(unwrapFirstError)
8787
8888 if "age_ts" in event.unsigned:
8989 allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
90 if "replaces_state" in event.unsigned:
91 allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
9092
9193 return type(event)(
9294 allowed_fields,
2222 from synapse.api.errors import SynapseError
2323
2424 from synapse.util import unwrapFirstError
25 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
2526
2627 import logging
2728
101102 warn, pdu
102103 )
103104
104 valid_pdus = yield defer.gatherResults(
105 valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
105106 deferreds,
106107 consumeErrors=True
107 ).addErrback(unwrapFirstError)
108 )).addErrback(unwrapFirstError)
108109
109110 if include_none:
110111 defer.returnValue(valid_pdus)
128129 for pdu in pdus
129130 ]
130131
131 deferreds = self.keyring.verify_json_objects_for_server([
132 deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
132133 (p.origin, p.get_pdu_json())
133134 for p in redacted_pdus
134135 ])
2626 from synapse.util.async import concurrently_execute
2727 from synapse.util.caches.expiringcache import ExpiringCache
2828 from synapse.util.logutils import log_function
29 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
2930 from synapse.events import FrozenEvent
3031 import synapse.metrics
3132
5051 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
5152
5253
54 PDU_RETRY_TIME_MS = 1 * 60 * 1000
55
56
5357 class FederationClient(FederationBase):
5458 def __init__(self, hs):
5559 super(FederationClient, self).__init__(hs)
60
61 self.pdu_destination_tried = {}
62 self._clock.looping_call(
63 self._clear_tried_cache, 60 * 1000,
64 )
65
66 def _clear_tried_cache(self):
67 """Clear pdu_destination_tried cache"""
68 now = self._clock.time_msec()
69
70 old_dict = self.pdu_destination_tried
71 self.pdu_destination_tried = {}
72
73 for event_id, destination_dict in old_dict.items():
74 destination_dict = {
75 dest: time
76 for dest, time in destination_dict.items()
77 if time + PDU_RETRY_TIME_MS > now
78 }
79 if destination_dict:
80 self.pdu_destination_tried[event_id] = destination_dict
5681
5782 def start_get_pdu_cache(self):
5883 self._get_pdu_cache = ExpiringCache(
200225 ]
201226
202227 # FIXME: We should handle signature failures more gracefully.
203 pdus[:] = yield defer.gatherResults(
228 pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
204229 self._check_sigs_and_hashes(pdus),
205230 consumeErrors=True,
206 ).addErrback(unwrapFirstError)
231 )).addErrback(unwrapFirstError)
207232
208233 defer.returnValue(pdus)
209234
239264 if ev:
240265 defer.returnValue(ev)
241266
267 pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
268
242269 pdu = None
243270 for destination in destinations:
271 now = self._clock.time_msec()
272 last_attempt = pdu_attempts.get(destination, 0)
273 if last_attempt + PDU_RETRY_TIME_MS > now:
274 continue
275
244276 try:
245277 limiter = yield get_retry_limiter(
246278 destination,
268300
269301 break
270302
303 pdu_attempts[destination] = now
304
271305 except SynapseError as e:
272306 logger.info(
273307 "Failed to get PDU %s from %s because %s",
274308 event_id, destination, e,
275309 )
276 continue
277 except CodeMessageException as e:
278 if 400 <= e.code < 500:
279 raise
280
281 logger.info(
282 "Failed to get PDU %s from %s because %s",
283 event_id, destination, e,
284 )
285 continue
286310 except NotRetryingDestination as e:
287311 logger.info(e.message)
288312 continue
289313 except Exception as e:
314 pdu_attempts[destination] = now
315
290316 logger.info(
291317 "Failed to get PDU %s from %s because %s",
292318 event_id, destination, e,
405431 events and the second is a list of event ids that we failed to fetch.
406432 """
407433 if return_local:
408 seen_events = yield self.store.get_events(event_ids)
434 seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
409435 signed_events = seen_events.values()
410436 else:
411437 seen_events = yield self.store.have_events(event_ids)
431457 batch = set(missing_events[i:i + batch_size])
432458
433459 deferreds = [
434 self.get_pdu(
460 preserve_fn(self.get_pdu)(
435461 destinations=random_server_list(),
436462 event_id=e_id,
437463 )
438464 for e_id in batch
439465 ]
440466
441 res = yield defer.DeferredList(deferreds, consumeErrors=True)
467 res = yield preserve_context_over_deferred(
468 defer.DeferredList(deferreds, consumeErrors=True)
469 )
442470 for success, result in res:
443471 if success:
444472 signed_events.append(result)
827855 return srvs
828856
829857 deferreds = [
830 self.get_pdu(
858 preserve_fn(self.get_pdu)(
831859 destinations=random_server_list(),
832860 event_id=e_id,
833861 )
834862 for e_id, depth in ordered_missing[:limit - len(signed_events)]
835863 ]
836864
837 res = yield defer.DeferredList(deferreds, consumeErrors=True)
865 res = yield preserve_context_over_deferred(
866 defer.DeferredList(deferreds, consumeErrors=True)
867 )
838868 for (result, val), (e_id, _) in zip(res, ordered_missing):
839869 if result and val:
840870 signed_events.append(val)
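
Several hunks in this file also wrap ``defer.gatherResults`` / ``defer.DeferredList`` in ``preserve_context_over_deferred`` and wrap the parallel callables with ``preserve_fn``. The pair keeps synapse's per-request logging context attached while work is forked off and restores it when the gather completes; the calling convention looks like this (a usage sketch only; the helpers are the real ones from ``synapse.util.logcontext``, the ``fetch_all`` wrapper is illustrative)::

    from twisted.internet import defer

    from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


    @defer.inlineCallbacks
    def fetch_all(get_pdu, destinations, event_ids):
        # Fork one fetch per event id, keeping the current logcontext on each fork.
        deferreds = [
            preserve_fn(get_pdu)(destinations=destinations, event_id=e_id)
            for e_id in event_ids
        ]
        # Wait for the whole batch without leaking the sentinel context back
        # into the caller when the DeferredList fires.
        results = yield preserve_context_over_deferred(
            defer.DeferredList(deferreds, consumeErrors=True)
        )
        defer.returnValue([value for succeeded, value in results if succeeded])

Yielding a bare ``DeferredList`` or calling the deferred-returning function directly detaches the request's logging context from the work being done, which is the pattern these changes are cleaning up.
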
2020
2121 from synapse.api.errors import HttpResponseException
2222 from synapse.util.async import run_on_reactor
23 from synapse.util.logutils import log_function
24 from synapse.util.logcontext import PreserveLoggingContext
23 from synapse.util.logcontext import preserve_context_over_fn
2524 from synapse.util.retryutils import (
2625 get_retry_limiter, NotRetryingDestination,
2726 )
27 from synapse.util.metrics import measure_func
2828 import synapse.metrics
2929
3030 import logging
5050
5151 self.transport_layer = transport_layer
5252
53 self._clock = hs.get_clock()
53 self.clock = hs.get_clock()
5454
5555 # Is a mapping from destinations -> deferreds. Used to keep track
5656 # of which destinations have transactions in flight and when they are
8181 self.pending_failures_by_dest = {}
8282
8383 # HACK to get unique tx id
84 self._next_txn_id = int(self._clock.time_msec())
84 self._next_txn_id = int(self.clock.time_msec())
8585
8686 def can_send_to(self, destination):
8787 """Can we send messages to the given server?
118118 if not destinations:
119119 return
120120
121 deferreds = []
122
123121 for destination in destinations:
124 deferred = defer.Deferred()
125122 self.pending_pdus_by_dest.setdefault(destination, []).append(
126 (pdu, deferred, order)
123 (pdu, order)
127124 )
128125
129 def chain(failure):
130 if not deferred.called:
131 deferred.errback(failure)
132
133 def log_failure(f):
134 logger.warn("Failed to send pdu to %s: %s", destination, f.value)
135
136 deferred.addErrback(log_failure)
137
138 with PreserveLoggingContext():
139 self._attempt_new_transaction(destination).addErrback(chain)
140
141 deferreds.append(deferred)
142
143 # NO inlineCallbacks
126 preserve_context_over_fn(
127 self._attempt_new_transaction, destination
128 )
129
144130 def enqueue_edu(self, edu):
145131 destination = edu.destination
146132
147133 if not self.can_send_to(destination):
148134 return
149135
150 deferred = defer.Deferred()
151 self.pending_edus_by_dest.setdefault(destination, []).append(
152 (edu, deferred)
153 )
154
155 def chain(failure):
156 if not deferred.called:
157 deferred.errback(failure)
158
159 def log_failure(f):
160 logger.warn("Failed to send edu to %s: %s", destination, f.value)
161
162 deferred.addErrback(log_failure)
163
164 with PreserveLoggingContext():
165 self._attempt_new_transaction(destination).addErrback(chain)
166
167 return deferred
168
169 @defer.inlineCallbacks
136 self.pending_edus_by_dest.setdefault(destination, []).append(edu)
137
138 preserve_context_over_fn(
139 self._attempt_new_transaction, destination
140 )
141
170142 def enqueue_failure(self, failure, destination):
171143 if destination == self.server_name or destination == "localhost":
172144 return
173145
174 deferred = defer.Deferred()
175
176146 if not self.can_send_to(destination):
177147 return
178148
179149 self.pending_failures_by_dest.setdefault(
180150 destination, []
181 ).append(
182 (failure, deferred)
183 )
184
185 def chain(f):
186 if not deferred.called:
187 deferred.errback(f)
188
189 def log_failure(f):
190 logger.warn("Failed to send failure to %s: %s", destination, f.value)
191
192 deferred.addErrback(log_failure)
193
194 with PreserveLoggingContext():
195 self._attempt_new_transaction(destination).addErrback(chain)
196
197 yield deferred
151 ).append(failure)
152
153 preserve_context_over_fn(
154 self._attempt_new_transaction, destination
155 )
198156
199157 @defer.inlineCallbacks
200 @log_function
201158 def _attempt_new_transaction(self, destination):
202159 yield run_on_reactor()
203
204 # list of (pending_pdu, deferred, order)
205 if destination in self.pending_transactions:
206 # XXX: pending_transactions can get stuck on by a never-ending
207 # request at which point pending_pdus_by_dest just keeps growing.
208 # we need application-layer timeouts of some flavour for these
209 # requests
210 logger.debug(
211 "TX [%s] Transaction already in progress",
212 destination
160 while True:
161 # list of (pending_pdu, order)
162 if destination in self.pending_transactions:
163 # XXX: pending_transactions can get stuck on by a never-ending
164 # request at which point pending_pdus_by_dest just keeps growing.
165 # we need application-layer timeouts of some flavour for these
166 # requests
167 logger.debug(
168 "TX [%s] Transaction already in progress",
169 destination
170 )
171 return
172
173 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
174 pending_edus = self.pending_edus_by_dest.pop(destination, [])
175 pending_failures = self.pending_failures_by_dest.pop(destination, [])
176
177 if pending_pdus:
178 logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
179 destination, len(pending_pdus))
180
181 if not pending_pdus and not pending_edus and not pending_failures:
182 logger.debug("TX [%s] Nothing to send", destination)
183 return
184
185 yield self._send_new_transaction(
186 destination, pending_pdus, pending_edus, pending_failures
213187 )
214 return
215
216 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
217 pending_edus = self.pending_edus_by_dest.pop(destination, [])
218 pending_failures = self.pending_failures_by_dest.pop(destination, [])
219
220 if pending_pdus:
221 logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
222 destination, len(pending_pdus))
223
224 if not pending_pdus and not pending_edus and not pending_failures:
225 logger.debug("TX [%s] Nothing to send", destination)
226 return
227
228 try:
229 self.pending_transactions[destination] = 1
230
231 logger.debug("TX [%s] _attempt_new_transaction", destination)
188
189 @measure_func("_send_new_transaction")
190 @defer.inlineCallbacks
191 def _send_new_transaction(self, destination, pending_pdus, pending_edus,
192 pending_failures):
232193
233194 # Sort based on the order field
234 pending_pdus.sort(key=lambda t: t[2])
235
195 pending_pdus.sort(key=lambda t: t[1])
236196 pdus = [x[0] for x in pending_pdus]
237 edus = [x[0] for x in pending_edus]
238 failures = [x[0].get_dict() for x in pending_failures]
239 deferreds = [
240 x[1]
241 for x in pending_pdus + pending_edus + pending_failures
242 ]
243
244 txn_id = str(self._next_txn_id)
245
246 limiter = yield get_retry_limiter(
247 destination,
248 self._clock,
249 self.store,
250 )
251
252 logger.debug(
253 "TX [%s] {%s} Attempting new transaction"
254 " (pdus: %d, edus: %d, failures: %d)",
255 destination, txn_id,
256 len(pending_pdus),
257 len(pending_edus),
258 len(pending_failures)
259 )
260
261 logger.debug("TX [%s] Persisting transaction...", destination)
262
263 transaction = Transaction.create_new(
264 origin_server_ts=int(self._clock.time_msec()),
265 transaction_id=txn_id,
266 origin=self.server_name,
267 destination=destination,
268 pdus=pdus,
269 edus=edus,
270 pdu_failures=failures,
271 )
272
273 self._next_txn_id += 1
274
275 yield self.transaction_actions.prepare_to_send(transaction)
276
277 logger.debug("TX [%s] Persisted transaction", destination)
278 logger.info(
279 "TX [%s] {%s} Sending transaction [%s],"
280 " (PDUs: %d, EDUs: %d, failures: %d)",
281 destination, txn_id,
282 transaction.transaction_id,
283 len(pending_pdus),
284 len(pending_edus),
285 len(pending_failures),
286 )
287
288 with limiter:
289 # Actually send the transaction
290
291 # FIXME (erikj): This is a bit of a hack to make the Pdu age
292 # keys work
293 def json_data_cb():
294 data = transaction.get_dict()
295 now = int(self._clock.time_msec())
296 if "pdus" in data:
297 for p in data["pdus"]:
298 if "age_ts" in p:
299 unsigned = p.setdefault("unsigned", {})
300 unsigned["age"] = now - int(p["age_ts"])
301 del p["age_ts"]
302 return data
303
304 try:
305 response = yield self.transport_layer.send_transaction(
306 transaction, json_data_cb
197 edus = pending_edus
198 failures = [x.get_dict() for x in pending_failures]
199
200 try:
201 self.pending_transactions[destination] = 1
202
203 logger.debug("TX [%s] _attempt_new_transaction", destination)
204
205 txn_id = str(self._next_txn_id)
206
207 limiter = yield get_retry_limiter(
208 destination,
209 self.clock,
210 self.store,
211 )
212
213 logger.debug(
214 "TX [%s] {%s} Attempting new transaction"
215 " (pdus: %d, edus: %d, failures: %d)",
216 destination, txn_id,
217 len(pending_pdus),
218 len(pending_edus),
219 len(pending_failures)
220 )
221
222 logger.debug("TX [%s] Persisting transaction...", destination)
223
224 transaction = Transaction.create_new(
225 origin_server_ts=int(self.clock.time_msec()),
226 transaction_id=txn_id,
227 origin=self.server_name,
228 destination=destination,
229 pdus=pdus,
230 edus=edus,
231 pdu_failures=failures,
232 )
233
234 self._next_txn_id += 1
235
236 yield self.transaction_actions.prepare_to_send(transaction)
237
238 logger.debug("TX [%s] Persisted transaction", destination)
239 logger.info(
240 "TX [%s] {%s} Sending transaction [%s],"
241 " (PDUs: %d, EDUs: %d, failures: %d)",
242 destination, txn_id,
243 transaction.transaction_id,
244 len(pending_pdus),
245 len(pending_edus),
246 len(pending_failures),
247 )
248
249 with limiter:
250 # Actually send the transaction
251
252 # FIXME (erikj): This is a bit of a hack to make the Pdu age
253 # keys work
254 def json_data_cb():
255 data = transaction.get_dict()
256 now = int(self.clock.time_msec())
257 if "pdus" in data:
258 for p in data["pdus"]:
259 if "age_ts" in p:
260 unsigned = p.setdefault("unsigned", {})
261 unsigned["age"] = now - int(p["age_ts"])
262 del p["age_ts"]
263 return data
264
265 try:
266 response = yield self.transport_layer.send_transaction(
267 transaction, json_data_cb
268 )
269 code = 200
270
271 if response:
272 for e_id, r in response.get("pdus", {}).items():
273 if "error" in r:
274 logger.warn(
275 "Transaction returned error for %s: %s",
276 e_id, r,
277 )
278 except HttpResponseException as e:
279 code = e.code
280 response = e.response
281
282 logger.info(
283 "TX [%s] {%s} got %d response",
284 destination, txn_id, code
307285 )
308 code = 200
309
310 if response:
311 for e_id, r in response.get("pdus", {}).items():
312 if "error" in r:
313 logger.warn(
314 "Transaction returned error for %s: %s",
315 e_id, r,
316 )
317 except HttpResponseException as e:
318 code = e.code
319 response = e.response
320
286
287 logger.debug("TX [%s] Sent transaction", destination)
288 logger.debug("TX [%s] Marking as delivered...", destination)
289
290 yield self.transaction_actions.delivered(
291 transaction, code, response
292 )
293
294 logger.debug("TX [%s] Marked as delivered", destination)
295
296 if code != 200:
297 for p in pdus:
298 logger.info(
299 "Failed to send event %s to %s", p.event_id, destination
300 )
301 except NotRetryingDestination:
321302 logger.info(
322 "TX [%s] {%s} got %d response",
323 destination, txn_id, code
324 )
325
326 logger.debug("TX [%s] Sent transaction", destination)
327 logger.debug("TX [%s] Marking as delivered...", destination)
328
329 yield self.transaction_actions.delivered(
330 transaction, code, response
331 )
332
333 logger.debug("TX [%s] Marked as delivered", destination)
334
335 logger.debug("TX [%s] Yielding to callbacks...", destination)
336
337 for deferred in deferreds:
338 if code == 200:
339 deferred.callback(None)
340 else:
341 deferred.errback(RuntimeError("Got status %d" % code))
342
343 # Ensures we don't continue until all callbacks on that
344 # deferred have fired
345 try:
346 yield deferred
347 except:
348 pass
349
350 logger.debug("TX [%s] Yielded to callbacks", destination)
351 except NotRetryingDestination:
352 logger.info(
353 "TX [%s] not ready for retry yet - "
354 "dropping transaction for now",
355 destination,
356 )
357 except RuntimeError as e:
358 # We capture this here as there is nothing actually listening
359 # for this function's finishing deferred.
360 logger.warn(
361 "TX [%s] Problem in _attempt_transaction: %s",
362 destination,
363 e,
364 )
365 except Exception as e:
366 # We capture this here as there is nothing actually listening
367 # for this function's finishing deferred.
368 logger.warn(
369 "TX [%s] Problem in _attempt_transaction: %s",
370 destination,
371 e,
372 )
373
374 for deferred in deferreds:
375 if not deferred.called:
376 deferred.errback(e)
377
378 finally:
379 # We want to be *very* sure we delete this after we stop processing
380 self.pending_transactions.pop(destination, None)
381
382 # Check to see if there is anything else to send.
383 self._attempt_new_transaction(destination)
303 "TX [%s] not ready for retry yet - "
304 "dropping transaction for now",
305 destination,
306 )
307 except RuntimeError as e:
308 # We capture this here as there is nothing actually listening
309 # for this function's finishing deferred.
310 logger.warn(
311 "TX [%s] Problem in _attempt_transaction: %s",
312 destination,
313 e,
314 )
315
316 for p in pdus:
317 logger.info("Failed to send event %s to %s", p.event_id, destination)
318 except Exception as e:
319 # We capture this here as there is nothing actually listening
320 # for this function's finishing deferred.
321 logger.warn(
322 "TX [%s] Problem in _attempt_transaction: %s",
323 destination,
324 e,
325 )
326
327 for p in pdus:
328 logger.info("Failed to send event %s to %s", p.event_id, destination)
329
330 finally:
331 # We want to be *very* sure we delete this after we stop processing
332 self.pending_transactions.pop(destination, None)
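
The ``TransactionQueue`` rewrite above removes the per-PDU/EDU/failure deferred plumbing: the ``enqueue_*`` methods now only append to the per-destination pending lists and poke ``_attempt_new_transaction``, which loops draining those lists until nothing is queued, while ``pending_transactions`` guarantees a single in-flight sender per destination. The control flow reduced to a synchronous sketch (names are illustrative; the real ``_send_new_transaction`` is an ``inlineCallbacks`` coroutine measured with ``@measure_func``)::

    class DrainQueue(object):
        """One sender at a time per destination; each run drains everything queued."""

        def __init__(self, send_fn):
            self._send = send_fn        # send_fn(destination, pdus, edus, failures)
            self.pending_pdus = {}      # destination -> [(pdu, order), ...]
            self.pending_edus = {}
            self.pending_failures = {}
            self.in_flight = set()

        def enqueue_pdu(self, destination, pdu, order):
            # enqueue_edu / enqueue_failure append to the other dicts the same way
            self.pending_pdus.setdefault(destination, []).append((pdu, order))
            self._attempt(destination)

        def _attempt(self, destination):
            if destination in self.in_flight:
                return                  # the running sender will pick the new work up
            self.in_flight.add(destination)
            try:
                while True:
                    pdus = self.pending_pdus.pop(destination, [])
                    edus = self.pending_edus.pop(destination, [])
                    failures = self.pending_failures.pop(destination, [])
                    if not pdus and not edus and not failures:
                        return
                    pdus.sort(key=lambda t: t[1])   # send in submission order
                    self._send(destination, [p for p, _ in pdus], edus, failures)
            finally:
                self.in_flight.discard(destination)

Errors from a send are now logged per PDU ("Failed to send event ... to ...") instead of errbacking individual caller deferreds, which is why the ``chain``/``log_failure`` helpers disappear.
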
1818 )
1919 from .room_member import RoomMemberHandler
2020 from .message import MessageHandler
21 from .events import EventStreamHandler, EventHandler
2221 from .federation import FederationHandler
2322 from .profile import ProfileHandler
2423 from .directory import DirectoryHandler
5251 self.message_handler = MessageHandler(hs)
5352 self.room_creation_handler = RoomCreationHandler(hs)
5453 self.room_member_handler = RoomMemberHandler(hs)
55 self.event_stream_handler = EventStreamHandler(hs)
56 self.event_handler = EventHandler(hs)
5754 self.federation_handler = FederationHandler(hs)
5855 self.profile_handler = ProfileHandler(hs)
5956 self.directory_handler = DirectoryHandler(hs)
1515 from twisted.internet import defer
1616
1717 from synapse.api.constants import EventTypes
18 from synapse.appservice import ApplicationService
18 from synapse.util.metrics import Measure
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
1920
2021 import logging
2122
4142 self.appservice_api = hs.get_application_service_api()
4243 self.scheduler = hs.get_application_service_scheduler()
4344 self.started_scheduler = False
44
45 @defer.inlineCallbacks
46 def notify_interested_services(self, event):
45 self.clock = hs.get_clock()
46 self.notify_appservices = hs.config.notify_appservices
47
48 self.current_max = 0
49 self.is_processing = False
50
51 @defer.inlineCallbacks
52 def notify_interested_services(self, current_id):
4753 """Notifies (pushes) all application services interested in this event.
4854
4955 Pushing is done asynchronously, so this method won't block for any
5056 prolonged length of time.
5157
5258 Args:
53 event(Event): The event to push out to interested services.
54 """
55 # Gather interested services
56 services = yield self._get_services_for_event(event)
57 if len(services) == 0:
58 return # no services need notifying
59
60 # Do we know this user exists? If not, poke the user query API for
61 # all services which match that user regex. This needs to block as these
62 # user queries need to be made BEFORE pushing the event.
63 yield self._check_user_exists(event.sender)
64 if event.type == EventTypes.Member:
65 yield self._check_user_exists(event.state_key)
66
67 if not self.started_scheduler:
68 self.scheduler.start().addErrback(log_failure)
69 self.started_scheduler = True
70
71 # Fork off pushes to these services
72 for service in services:
73 self.scheduler.submit_event_for_as(service, event)
59 current_id(int): The current maximum ID.
60 """
61 services = yield self.store.get_app_services()
62 if not services or not self.notify_appservices:
63 return
64
65 self.current_max = max(self.current_max, current_id)
66 if self.is_processing:
67 return
68
69 with Measure(self.clock, "notify_interested_services"):
70 self.is_processing = True
71 try:
72 upper_bound = self.current_max
73 limit = 100
74 while True:
75 upper_bound, events = yield self.store.get_new_events_for_appservice(
76 upper_bound, limit
77 )
78
79 if not events:
80 break
81
82 for event in events:
83 # Gather interested services
84 services = yield self._get_services_for_event(event)
85 if len(services) == 0:
86 continue # no services need notifying
87
88 # Do we know this user exists? If not, poke the user
89 # query API for all services which match that user regex.
90 # This needs to block as these user queries need to be
91 # made BEFORE pushing the event.
92 yield self._check_user_exists(event.sender)
93 if event.type == EventTypes.Member:
94 yield self._check_user_exists(event.state_key)
95
96 if not self.started_scheduler:
97 self.scheduler.start().addErrback(log_failure)
98 self.started_scheduler = True
99
100 # Fork off pushes to these services
101 for service in services:
102 preserve_fn(self.scheduler.submit_event_for_as)(
103 service, event
104 )
105
106 yield self.store.set_appservice_last_pos(upper_bound)
107
108 if len(events) < limit:
109 break
110 finally:
111 self.is_processing = False
74112
75113 @defer.inlineCallbacks
76114 def query_user_exists(self, user_id):
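
``notify_interested_services`` now receives the current maximum event stream id rather than a single event: callers just bump ``current_max``, and a single processing loop (guarded by ``is_processing``) pages new events out of the store in batches of 100, runs the per-event service matching that used to happen inline, and records its position with ``set_appservice_last_pos``. The coalescing shape on its own (an illustrative sketch; the store callback mirrors ``get_new_events_for_appservice``)::

    class StreamCatchUp(object):
        """Coalesce 'new data up to id X' pokes into a single catch-up loop."""

        def __init__(self, get_new_events, handle_event, save_pos, limit=100):
            self._get_new_events = get_new_events  # (bound, limit) -> (new_bound, events)
            self._handle_event = handle_event
            self._save_pos = save_pos
            self._limit = limit
            self.current_max = 0
            self.is_processing = False

        def notify(self, current_id):
            self.current_max = max(self.current_max, current_id)
            if self.is_processing:
                return                  # the running loop is already catching up
            self.is_processing = True
            try:
                upper_bound = self.current_max
                while True:
                    upper_bound, events = self._get_new_events(upper_bound, self._limit)
                    if not events:
                        break
                    for event in events:
                        self._handle_event(event)
                    self._save_pos(upper_bound)
                    if len(events) < self._limit:
                        break
            finally:
                self.is_processing = False

Because the position is persisted, a restarted synapse resumes application-service pushes from where it left off instead of relying on the notifier handing it each event exactly once.
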
103141 association can be found.
104142 """
105143 room_alias_str = room_alias.to_string()
106 alias_query_services = yield self._get_services_for_event(
107 event=None,
108 restrict_to=ApplicationService.NS_ALIASES,
109 alias_list=[room_alias_str]
110 )
144 services = yield self.store.get_app_services()
145 alias_query_services = [
146 s for s in services if (
147 s.is_interested_in_alias(room_alias_str)
148 )
149 ]
111150 for alias_service in alias_query_services:
112151 is_known_alias = yield self.appservice_api.query_alias(
113152 alias_service, room_alias_str
120159 defer.returnValue(result)
121160
122161 @defer.inlineCallbacks
123 def _get_services_for_event(self, event, restrict_to="", alias_list=None):
162 def query_3pe(self, kind, protocol, fields):
163 services = yield self._get_services_for_3pn(protocol)
164
165 results = yield preserve_context_over_deferred(defer.DeferredList([
166 preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
167 for service in services
168 ], consumeErrors=True))
169
170 ret = []
171 for (success, result) in results:
172 if success:
173 ret.extend(result)
174
175 defer.returnValue(ret)
176
177 @defer.inlineCallbacks
178 def _get_services_for_event(self, event):
124179 """Retrieve a list of application services interested in this event.
125180
126181 Args:
127182 event(Event): The event to check.
128 restrict_to(str): The namespace to restrict regex tests to.
129 alias_list: A list of aliases to get services for. If None, this
130 list is obtained from the database.
131183 Returns:
132184 list<ApplicationService>: A list of services interested in this
133185 event based on the service regex.
134186 """
135 member_list = None
136 if hasattr(event, "room_id"):
137 # We need to know the aliases associated with this event.room_id,
138 # if any.
139 if not alias_list:
140 alias_list = yield self.store.get_aliases_for_room(
141 event.room_id
142 )
143 # We need to know the members associated with this event.room_id,
144 # if any.
145 member_list = yield self.store.get_users_in_room(event.room_id)
146
147187 services = yield self.store.get_app_services()
148188 interested_list = [
149189 s for s in services if (
150 s.is_interested(event, restrict_to, alias_list, member_list)
190 yield s.is_interested(event, self.store)
151191 )
152192 ]
153193 defer.returnValue(interested_list)
159199 s for s in services if (
160200 s.is_interested_in_user(user_id)
161201 )
202 ]
203 defer.returnValue(interested_list)
204
205 @defer.inlineCallbacks
206 def _get_services_for_3pn(self, protocol):
207 services = yield self.store.get_app_services()
208 interested_list = [
209 s for s in services if s.is_interested_in_protocol(protocol)
162210 ]
163211 defer.returnValue(interested_list)
164212
6969 self.ldap_uri = hs.config.ldap_uri
7070 self.ldap_start_tls = hs.config.ldap_start_tls
7171 self.ldap_base = hs.config.ldap_base
72 self.ldap_filter = hs.config.ldap_filter
7372 self.ldap_attributes = hs.config.ldap_attributes
7473 if self.ldap_mode == LDAPMode.SEARCH:
7574 self.ldap_bind_dn = hs.config.ldap_bind_dn
7675 self.ldap_bind_password = hs.config.ldap_bind_password
76 self.ldap_filter = hs.config.ldap_filter
7777
7878 self.hs = hs # FIXME better possibility to access registrationHandler later?
7979 self.device_handler = hs.get_device_handler()
659659 else:
660660 logger.warn(
661661 "ldap registration failed: unexpected (%d!=1) amount of results",
662 len(result)
662 len(conn.response)
663663 )
664664 defer.returnValue(False)
665665
718718 return macaroon.serialize()
719719
720720 def validate_short_term_login_token_and_get_user_id(self, login_token):
721 auth_api = self.hs.get_auth()
721722 try:
722723 macaroon = pymacaroons.Macaroon.deserialize(login_token)
723 auth_api = self.hs.get_auth()
724 auth_api.validate_macaroon(macaroon, "login", True)
725 return self.get_user_from_macaroon(macaroon)
726 except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
727 raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
724 user_id = auth_api.get_user_id_from_macaroon(macaroon)
725 auth_api.validate_macaroon(macaroon, "login", True, user_id)
726 return user_id
727 except Exception:
728 raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
728729
729730 def _generate_base_macaroon(self, user_id):
730731 macaroon = pymacaroons.Macaroon(
735736 macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
736737 return macaroon
737738
738 def get_user_from_macaroon(self, macaroon):
739 user_prefix = "user_id = "
740 for caveat in macaroon.caveats:
741 if caveat.caveat_id.startswith(user_prefix):
742 return caveat.caveat_id[len(user_prefix):]
743 raise AuthError(
744 self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
745 errcode=Codes.UNKNOWN_TOKEN
746 )
747
748739 @defer.inlineCallbacks
749740 def set_password(self, user_id, newpassword, requester=None):
750741 password_hash = self.hash(newpassword)
751742
752 except_access_token_ids = [requester.access_token_id] if requester else []
743 except_access_token_id = requester.access_token_id if requester else None
753744
754745 try:
755746 yield self.store.user_set_password_hash(user_id, password_hash)
758749 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
759750 raise e
760751 yield self.store.user_delete_access_tokens(
761 user_id, except_access_token_ids
752 user_id, except_access_token_id
762753 )
763754 yield self.hs.get_pusherpool().remove_pushers_by_user(
764 user_id, except_access_token_ids
755 user_id, except_access_token_id
765756 )
766757
767758 @defer.inlineCallbacks
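
The login-token change above stops this handler from parsing the user out of the macaroon itself: it now asks the auth layer for the ``user_id`` caveat, validates the macaroon against that user, and maps any failure to a 403 rather than a 401. Pulling a first-party caveat out by prefix, which is what the removed ``get_user_from_macaroon`` did and what the auth layer now does, looks roughly like this (the helper name is illustrative)::

    import pymacaroons


    def user_id_from_macaroon(serialized_token):
        """Return the value of the 'user_id = ...' first-party caveat, or None."""
        macaroon = pymacaroons.Macaroon.deserialize(serialized_token)
        prefix = "user_id = "
        for caveat in macaroon.caveats:
            if caveat.caveat_id.startswith(prefix):
                return caveat.caveat_id[len(prefix):]
        return None
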
2525 from synapse.api.constants import EventTypes, Membership, RejectedReason
2626 from synapse.events.validator import EventValidator
2727 from synapse.util import unwrapFirstError
28 from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
28 from synapse.util.logcontext import (
29 PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
30 )
2931 from synapse.util.logutils import log_function
3032 from synapse.util.async import run_on_reactor
3133 from synapse.util.frozenutils import unfreeze
248250 if ev.type != EventTypes.Member:
249251 continue
250252 try:
251 domain = UserID.from_string(ev.state_key).domain
253 domain = get_domain_from_id(ev.state_key)
252254 except:
253255 continue
254256
273275
274276 @log_function
275277 @defer.inlineCallbacks
276 def backfill(self, dest, room_id, limit, extremities=[]):
278 def backfill(self, dest, room_id, limit, extremities):
277279 """ Trigger a backfill request to `dest` for the given `room_id`
278280
279281 This will attempt to get more events from the remote. This may return
282284 """
283285 if dest == self.server_name:
284286 raise SynapseError(400, "Can't backfill from self.")
285
286 if not extremities:
287 extremities = yield self.store.get_oldest_events_in_room(room_id)
288287
289288 events = yield self.replication_layer.backfill(
290289 dest,
363362 missing_auth - failed_to_fetch
364363 )
365364
366 results = yield defer.gatherResults(
365 results = yield preserve_context_over_deferred(defer.gatherResults(
367366 [
368 self.replication_layer.get_pdu(
367 preserve_fn(self.replication_layer.get_pdu)(
369368 [dest],
370369 event_id,
371370 outlier=True,
374373 for event_id in missing_auth - failed_to_fetch
375374 ],
376375 consumeErrors=True
377 ).addErrback(unwrapFirstError)
378 auth_events.update({a.event_id: a for a in results})
376 )).addErrback(unwrapFirstError)
377 auth_events.update({a.event_id: a for a in results if a})
379378 required_auth.update(
380 a_id for event in results for a_id, _ in event.auth_events
379 a_id for event in results for a_id, _ in event.auth_events if event
381380 )
382381 missing_auth = required_auth - set(auth_events)
383382
453452 key=lambda e: -int(e[1])
454453 )
455454 max_depth = sorted_extremeties_tuple[0][1]
455
456 # We don't want to specify too many extremities as it causes the backfill
457 # request URI to be too long.
458 extremities = dict(sorted_extremeties_tuple[:5])
456459
457460 if current_depth > max_depth:
458461 logger.debug(
550553
551554 event_ids = list(extremities.keys())
552555
553 states = yield defer.gatherResults([
554 self.state_handler.resolve_state_groups(room_id, [e])
556 states = yield preserve_context_over_deferred(defer.gatherResults([
557 preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
555558 for e in event_ids
556 ])
559 ]))
557560 states = dict(zip(event_ids, [s[1] for s in states]))
558561
559562 for e_id, _ in sorted_extremeties_tuple:
10921095 )
10931096
10941097 if event:
1095 # FIXME: This is a temporary work around where we occasionally
1096 # return events slightly differently than when they were
1097 # originally signed
1098 event.signatures.update(
1099 compute_event_signature(
1100 event,
1101 self.hs.hostname,
1102 self.hs.config.signing_key[0]
1103 )
1104 )
1098 if self.hs.is_mine_id(event.event_id):
1099 # FIXME: This is a temporary work around where we occasionally
1100 # return events slightly differently than when they were
1101 # originally signed
1102 event.signatures.update(
1103 compute_event_signature(
1104 event,
1105 self.hs.hostname,
1106 self.hs.config.signing_key[0]
1107 )
1108 )
11051109
11061110 if do_auth:
11071111 in_room = yield self.auth.check_host_in_room(
11101114 )
11111115 if not in_room:
11121116 raise AuthError(403, "Host not in room.")
1117
1118 events = yield self._filter_events_for_server(
1119 origin, event.room_id, [event]
1120 )
1121
1122 event = events[0]
11131123
11141124 defer.returnValue(event)
11151125 else:
11571167 a bunch of outliers, but not a chunk of individual events that depend
11581168 on each other for state calculations.
11591169 """
1160 contexts = yield defer.gatherResults(
1170 contexts = yield preserve_context_over_deferred(defer.gatherResults(
11611171 [
1162 self._prep_event(
1172 preserve_fn(self._prep_event)(
11631173 origin,
11641174 ev_info["event"],
11651175 state=ev_info.get("state"),
11671177 )
11681178 for ev_info in event_infos
11691179 ]
1170 )
1180 ))
11711181
11721182 yield self.store.persist_events(
11731183 [
14511461 # Do auth conflict res.
14521462 logger.info("Different auth: %s", different_auth)
14531463
1454 different_events = yield defer.gatherResults(
1464 different_events = yield preserve_context_over_deferred(defer.gatherResults(
14551465 [
1456 self.store.get_event(
1466 preserve_fn(self.store.get_event)(
14571467 d,
14581468 allow_none=True,
14591469 allow_rejected=False,
14621472 if d in have_events and not have_events[d]
14631473 ],
14641474 consumeErrors=True
1465 ).addErrback(unwrapFirstError)
1475 )).addErrback(unwrapFirstError)
14661476
14671477 if different_events:
14681478 local_view = dict(auth_events)
2727 from synapse.util import unwrapFirstError
2828 from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
2929 from synapse.util.caches.snapshot_cache import SnapshotCache
30 from synapse.util.logcontext import preserve_fn
30 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
31 from synapse.util.metrics import measure_func
3132 from synapse.visibility import filter_events_for_client
3233
3334 from ._base import BaseHandler
501502 lambda states: states[event.event_id]
502503 )
503504
504 (messages, token), current_state = yield defer.gatherResults(
505 [
506 self.store.get_recent_events_for_room(
507 event.room_id,
508 limit=limit,
509 end_token=room_end_token,
510 ),
511 deferred_room_state,
512 ]
505 (messages, token), current_state = yield preserve_context_over_deferred(
506 defer.gatherResults(
507 [
508 preserve_fn(self.store.get_recent_events_for_room)(
509 event.room_id,
510 limit=limit,
511 end_token=room_end_token,
512 ),
513 deferred_room_state,
514 ]
515 )
513516 ).addErrback(unwrapFirstError)
514517
515518 messages = yield filter_events_for_client(
718721
719722 presence, receipts, (messages, token) = yield defer.gatherResults(
720723 [
721 get_presence(),
722 get_receipts(),
723 self.store.get_recent_events_for_room(
724 preserve_fn(get_presence)(),
725 preserve_fn(get_receipts)(),
726 preserve_fn(self.store.get_recent_events_for_room)(
724727 room_id,
725728 limit=limit,
726729 end_token=now_token.room_key,
754757
755758 defer.returnValue(ret)
756759
760 @measure_func("_create_new_client_event")
757761 @defer.inlineCallbacks
758762 def _create_new_client_event(self, builder, prev_event_ids=None):
759763 if prev_event_ids:
805809 (event, context,)
806810 )
807811
812 @measure_func("handle_new_client_event")
808813 @defer.inlineCallbacks
809814 def handle_new_client_event(
810815 self,
933938 @defer.inlineCallbacks
934939 def _notify():
935940 yield run_on_reactor()
936 self.notifier.on_new_room_event(
941 yield self.notifier.on_new_room_event(
937942 event, event_stream_id, max_stream_id,
938943 extra_users=extra_users
939944 )
943948 # If invite, remove room_state from unsigned before sending.
944949 event.unsigned.pop("invite_room_state", None)
945950
946 federation_handler.handle_new_event(
951 preserve_fn(federation_handler.handle_new_event)(
947952 event, destinations=destinations,
948953 )
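
The message-handler hunks add ``@measure_func(...)`` to ``_create_new_client_event`` and ``handle_new_client_event`` so their cost shows up in the performance metrics under a stable name. In spirit the decorator is just a named timer around the call; a much-simplified stand-in (not ``synapse.util.metrics.measure_func`` itself, which times until the returned Deferred resolves and also pulls DB/CPU usage out of the logging context)::

    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)


    def measure_func(name):
        """Log how long each call to the decorated function takes, under `name`."""
        def decorator(func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                start = time.time()
                try:
                    return func(*args, **kwargs)
                finally:
                    logger.info("%s took %.3fs", name, time.time() - start)
            return wrapped
        return decorator
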
502502 defer.returnValue(states)
503503
504504 @defer.inlineCallbacks
505 def _get_interested_parties(self, states):
505 def _get_interested_parties(self, states, calculate_remote_hosts=True):
506506 """Given a list of states return which entities (rooms, users, servers)
507507 are interested in the given states.
508508
525525 users_to_states.setdefault(state.user_id, []).append(state)
526526
527527 hosts_to_states = {}
528 for room_id, states in room_ids_to_states.items():
529 local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
530 if not local_states:
531 continue
532
533 hosts = yield self.store.get_joined_hosts_for_room(room_id)
534 for host in hosts:
535 hosts_to_states.setdefault(host, []).extend(local_states)
528 if calculate_remote_hosts:
529 for room_id, states in room_ids_to_states.items():
530 local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
531 if not local_states:
532 continue
533
534 hosts = yield self.store.get_joined_hosts_for_room(room_id)
535 for host in hosts:
536 hosts_to_states.setdefault(host, []).extend(local_states)
536537
537538 for user_id, states in users_to_states.items():
538539 local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
563564 )
564565
565566 self._push_to_remotes(hosts_to_states)
567
568 @defer.inlineCallbacks
569 def notify_for_states(self, state, stream_id):
570 parties = yield self._get_interested_parties([state])
571 room_ids_to_states, users_to_states, hosts_to_states = parties
572
573 self.notifier.on_new_event(
574 "presence_key", stream_id, rooms=room_ids_to_states.keys(),
575 users=[UserID.from_string(u) for u in users_to_states.keys()]
576 )
566577
567578 def _push_to_remotes(self, hosts_to_states):
568579 """Sends state updates to remote servers.
671682 ])
672683
673684 @defer.inlineCallbacks
674 def set_state(self, target_user, state):
685 def set_state(self, target_user, state, ignore_status_msg=False):
675686 """Set the presence state of the user.
676687 """
677688 status_msg = state.get("status_msg", None)
688699 prev_state = yield self.current_state_for_user(user_id)
689700
690701 new_fields = {
691 "state": presence,
692 "status_msg": status_msg if presence != PresenceState.OFFLINE else None
702 "state": presence
693703 }
704
705 if not ignore_status_msg:
706 msg = status_msg if presence != PresenceState.OFFLINE else None
707 new_fields["status_msg"] = msg
694708
695709 if presence == PresenceState.ONLINE:
696710 new_fields["last_active_ts"] = self.clock.time_msec()
5858 prev_event_ids,
5959 txn_id=None,
6060 ratelimit=True,
61 content=None,
6162 ):
63 if content is None:
64 content = {}
6265 msg_handler = self.hs.get_handlers().message_handler
6366
64 content = {"membership": membership}
67 content["membership"] = membership
6568 if requester.is_guest:
6669 content["kind"] = "guest"
6770
139142 remote_room_hosts=None,
140143 third_party_signed=None,
141144 ratelimit=True,
145 content=None,
142146 ):
143 key = (target, room_id,)
147 key = (room_id,)
144148
145149 with (yield self.member_linearizer.queue(key)):
146150 result = yield self._update_membership(
152156 remote_room_hosts=remote_room_hosts,
153157 third_party_signed=third_party_signed,
154158 ratelimit=ratelimit,
159 content=content,
155160 )
156161
157162 defer.returnValue(result)
167172 remote_room_hosts=None,
168173 third_party_signed=None,
169174 ratelimit=True,
175 content=None,
170176 ):
177 if content is None:
178 content = {}
179
171180 effective_membership_state = action
172181 if action in ["kick", "unban"]:
173182 effective_membership_state = "leave"
217226 if inviter and not self.hs.is_mine(inviter):
218227 remote_room_hosts.append(inviter.domain)
219228
220 content = {"membership": Membership.JOIN}
229 content["membership"] = Membership.JOIN
221230
222231 profile = self.hs.get_handlers().profile_handler
223232 content["displayname"] = yield profile.get_displayname(target)
271280 txn_id=txn_id,
272281 ratelimit=ratelimit,
273282 prev_event_ids=latest_event_ids,
283 content=content,
274284 )
275285
276286 @defer.inlineCallbacks
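
These membership hunks let callers pass extra event content through ``update_membership`` (this is the change the changelog describes as passing user-supplied content through ``/join/$room_id``): the handler now starts from the caller's dict and overlays only the fields it owns, instead of building a fresh ``{"membership": ...}``. A toy illustration of the merge (hypothetical helper, not the handler itself)::

    def build_join_content(user_content, displayname):
        content = dict(user_content or {})
        content["membership"] = "join"          # the handler's value always wins
        content["displayname"] = displayname    # filled in from the profile handler
        return content

    # A client body of {"reason": "invited by alice"} keeps its extra key:
    #   build_join_content({"reason": "invited by alice"}, "Bob")
    #   == {"reason": "invited by alice", "membership": "join", "displayname": "Bob"}

Note the membership linearizer key also changes from ``(target, room_id)`` to ``(room_id,)``, serialising all membership changes in a room rather than only those for the same user.
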
463463 else:
464464 state = {}
465465
466 defer.returnValue({
467 (e.type, e.state_key): e
468 for e in sync_config.filter_collection.filter_room_state(state.values())
469 })
466 defer.returnValue({
467 (e.type, e.state_key): e
468 for e in sync_config.filter_collection.filter_room_state(state.values())
469 })
470470
471471 @defer.inlineCallbacks
472472 def unread_notifs_for_room_id(self, room_id, sync_config):
484484 )
485485 defer.returnValue(notifs)
486486
487 # There is no new information in this period, so your notification
488 # count is whatever it was last time.
489 defer.returnValue(None)
487 # There is no new information in this period, so your notification
488 # count is whatever it was last time.
489 defer.returnValue(None)
490490
491491 @defer.inlineCallbacks
492492 def generate_sync_result(self, sync_config, since_token=None, full_state=False):
1515 from twisted.internet import defer
1616
1717 from synapse.api.errors import SynapseError, AuthError
18 from synapse.util.logcontext import PreserveLoggingContext
18 from synapse.util.logcontext import (
19 PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
20 )
1921 from synapse.util.metrics import Measure
2022 from synapse.types import UserID
2123
168170 deferreds = []
169171 for domain in domains:
170172 if domain == self.server_name:
171 self._push_update_local(
173 preserve_fn(self._push_update_local)(
172174 room_id=room_id,
173175 user_id=user_id,
174176 typing=typing
175177 )
176178 else:
177 deferreds.append(self.federation.send_edu(
179 deferreds.append(preserve_fn(self.federation.send_edu)(
178180 destination=domain,
179181 edu_type="m.typing",
180182 content={
184186 },
185187 ))
186188
187 yield defer.DeferredList(deferreds, consumeErrors=True)
189 yield preserve_context_over_deferred(
190 defer.DeferredList(deferreds, consumeErrors=True)
191 )
188192
189193 @defer.inlineCallbacks
190194 def _recv_edu(self, origin, content):
154154 time_out=timeout / 1000. if timeout else 60,
155155 )
156156
157 response = yield preserve_context_over_fn(
158 send_request,
159 )
157 response = yield preserve_context_over_fn(send_request)
160158
161159 log_result = "%d %s" % (response.code, response.phrase,)
162160 break
1818 )
1919 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
2020 from synapse.util.caches import intern_dict
21 from synapse.util.metrics import Measure
2122 import synapse.metrics
2223 import synapse.events
2324
7374 _next_request_id = 0
7475
7576
76 def request_handler(report_metrics=True):
77 def request_handler(include_metrics=False):
7778 """Decorator for ``wrap_request_handler``"""
78 return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
79
80
81 def wrap_request_handler(request_handler, report_metrics):
79 return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
80
81
82 def wrap_request_handler(request_handler, include_metrics=False):
8283 """Wraps a method that acts as a request handler with the necessary logging
8384 and exception handling.
8485
102103 _next_request_id += 1
103104
104105 with LoggingContext(request_id) as request_context:
105 if report_metrics:
106 with Measure(self.clock, "wrapped_request_handler"):
106107 request_metrics = RequestMetrics()
107 request_metrics.start(self.clock)
108
109 request_context.request = request_id
110 with request.processing():
111 try:
112 with PreserveLoggingContext(request_context):
113 yield request_handler(self, request)
114 except CodeMessageException as e:
115 code = e.code
116 if isinstance(e, SynapseError):
117 logger.info(
118 "%s SynapseError: %s - %s", request, code, e.msg
108 request_metrics.start(self.clock, name=self.__class__.__name__)
109
110 request_context.request = request_id
111 with request.processing():
112 try:
113 with PreserveLoggingContext(request_context):
114 if include_metrics:
115 yield request_handler(self, request, request_metrics)
116 else:
117 yield request_handler(self, request)
118 except CodeMessageException as e:
119 code = e.code
120 if isinstance(e, SynapseError):
121 logger.info(
122 "%s SynapseError: %s - %s", request, code, e.msg
123 )
124 else:
125 logger.exception(e)
126 outgoing_responses_counter.inc(request.method, str(code))
127 respond_with_json(
128 request, code, cs_exception(e), send_cors=True,
129 pretty_print=_request_user_agent_is_curl(request),
130 version_string=self.version_string,
119131 )
120 else:
121 logger.exception(e)
122 outgoing_responses_counter.inc(request.method, str(code))
123 respond_with_json(
124 request, code, cs_exception(e), send_cors=True,
125 pretty_print=_request_user_agent_is_curl(request),
126 version_string=self.version_string,
127 )
128 except:
129 logger.exception(
130 "Failed handle request %s.%s on %r: %r",
131 request_handler.__module__,
132 request_handler.__name__,
133 self,
134 request
135 )
136 respond_with_json(
137 request,
138 500,
139 {
140 "error": "Internal server error",
141 "errcode": Codes.UNKNOWN,
142 },
143 send_cors=True
144 )
145 finally:
146 try:
147 if report_metrics:
132 except:
133 logger.exception(
134 "Failed handle request %s.%s on %r: %r",
135 request_handler.__module__,
136 request_handler.__name__,
137 self,
138 request
139 )
140 respond_with_json(
141 request,
142 500,
143 {
144 "error": "Internal server error",
145 "errcode": Codes.UNKNOWN,
146 },
147 send_cors=True
148 )
149 finally:
150 try:
148151 request_metrics.stop(
149 self.clock, request, self.__class__.__name__
152 self.clock, request
150153 )
151 except:
152 pass
154 except Exception as e:
155 logger.warn("Failed to stop metrics: %r", e)
153156 return wrapped_request_handler
154157
155158
219222 # It does its own metric reporting because _async_render dispatches to
220223 # a callback and it's the class name of that callback we want to report
221224 # against rather than the JsonResource itself.
222 @request_handler(report_metrics=False)
225 @request_handler(include_metrics=True)
223226 @defer.inlineCallbacks
224 def _async_render(self, request):
227 def _async_render(self, request, request_metrics):
225228 """ This gets called from render() every time someone sends us a request.
226229 This checks if anyone has registered a callback for that method and
227230 path.
230233 self._send_response(request, 200, {})
231234 return
232235
233 request_metrics = RequestMetrics()
234 request_metrics.start(self.clock)
235
236236 # Loop through all the registered callbacks to check if the method
237237 # and path regex match
238238 for path_entry in self.path_regexs.get(request.method, []):
246246
247247 callback = path_entry.callback
248248
249 kwargs = intern_dict({
250 name: urllib.unquote(value).decode("UTF-8") if value else value
251 for name, value in m.groupdict().items()
252 })
253
254 callback_return = yield callback(request, **kwargs)
255 if callback_return is not None:
256 code, response = callback_return
257 self._send_response(request, code, response)
258
249259 servlet_instance = getattr(callback, "__self__", None)
250260 if servlet_instance is not None:
251261 servlet_classname = servlet_instance.__class__.__name__
252262 else:
253263 servlet_classname = "%r" % callback
254264
255 kwargs = intern_dict({
256 name: urllib.unquote(value).decode("UTF-8") if value else value
257 for name, value in m.groupdict().items()
258 })
259
260 callback_return = yield callback(request, **kwargs)
261 if callback_return is not None:
262 code, response = callback_return
263 self._send_response(request, code, response)
264
265 try:
266 request_metrics.stop(self.clock, request, servlet_classname)
267 except:
268 pass
265 request_metrics.name = servlet_classname
269266
270267 return
271268
297294
298295
299296 class RequestMetrics(object):
300 def start(self, clock):
297 def start(self, clock, name):
301298 self.start = clock.time_msec()
302299 self.start_context = LoggingContext.current_context()
303
304 def stop(self, clock, request, servlet_classname):
300 self.name = name
301
302 def stop(self, clock, request):
305303 context = LoggingContext.current_context()
306304
307305 tag = ""
315313 )
316314 return
317315
318 incoming_requests_counter.inc(request.method, servlet_classname, tag)
316 incoming_requests_counter.inc(request.method, self.name, tag)
319317
320318 response_timer.inc_by(
321319 clock.time_msec() - self.start, request.method,
322 servlet_classname, tag
320 self.name, tag
323321 )
324322
325323 ru_utime, ru_stime = context.get_resource_usage()
326324
327325 response_ru_utime.inc_by(
328 ru_utime, request.method, servlet_classname, tag
326 ru_utime, request.method, self.name, tag
329327 )
330328 response_ru_stime.inc_by(
331 ru_stime, request.method, servlet_classname, tag
329 ru_stime, request.method, self.name, tag
332330 )
333331 response_db_txn_count.inc_by(
334 context.db_txn_count, request.method, servlet_classname, tag
332 context.db_txn_count, request.method, self.name, tag
335333 )
336334 response_db_txn_duration.inc_by(
337 context.db_txn_duration, request.method, servlet_classname, tag
335 context.db_txn_duration, request.method, self.name, tag
338336 )
339337
340338
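
In ``http/server.py`` the ``report_metrics`` flag becomes ``include_metrics``: rather than opting out of metrics, a resource can now ask to be handed the ``RequestMetrics`` object so it can re-label the request once it knows which servlet matched, which is what ``JsonResource._async_render`` does above by setting ``request_metrics.name``. A sketch of the two calling conventions (the resource and its ``_do_work`` / ``_dispatch`` methods are hypothetical, and a real resource also needs the ``clock`` and ``version_string`` attributes the wrapper uses)::

    from twisted.internet import defer

    from synapse.http.server import request_handler


    class ExampleResource(object):

        @request_handler()                      # metrics named after this class
        @defer.inlineCallbacks
        def _async_render_GET(self, request):
            yield self._do_work(request)

        @request_handler(include_metrics=True)  # opt in to re-labelling the metrics
        @defer.inlineCallbacks
        def _async_render(self, request, request_metrics):
            servlet = yield self._dispatch(request)
            request_metrics.name = servlet.__class__.__name__
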
1818
1919 from synapse.util.logutils import log_function
2020 from synapse.util.async import ObservableDeferred
21 from synapse.util.logcontext import PreserveLoggingContext
21 from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
22 from synapse.util.metrics import Measure
2223 from synapse.types import StreamToken
2324 from synapse.visibility import filter_events_for_client
2425 import synapse.metrics
6667 so that it can remove itself from the indexes in the Notifier class.
6768 """
6869
69 def __init__(self, user_id, rooms, current_token, time_now_ms,
70 appservice=None):
70 def __init__(self, user_id, rooms, current_token, time_now_ms):
7171 self.user_id = user_id
72 self.appservice = appservice
7372 self.rooms = set(rooms)
7473 self.current_token = current_token
7574 self.last_notified_ms = time_now_ms
106105
107106 notifier.user_to_user_stream.pop(self.user_id)
108107
109 if self.appservice:
110 notifier.appservice_to_user_streams.get(
111 self.appservice, set()
112 ).discard(self)
113
114108 def count_listeners(self):
115109 return len(self.notify_deferred.observers())
116110
141135 def __init__(self, hs):
142136 self.user_to_user_stream = {}
143137 self.room_to_user_streams = {}
144 self.appservice_to_user_streams = {}
145138
146139 self.event_sources = hs.get_event_sources()
147140 self.store = hs.get_datastore()
167160 all_user_streams |= x
168161 for x in self.user_to_user_stream.values():
169162 all_user_streams.add(x)
170 for x in self.appservice_to_user_streams.values():
171 all_user_streams |= x
172163
173164 return sum(stream.count_listeners() for stream in all_user_streams)
174165 metrics.register_callback("listeners", count_listeners)
181172 "users",
182173 lambda: len(self.user_to_user_stream),
183174 )
184 metrics.register_callback(
185 "appservices",
186 lambda: count(bool, self.appservice_to_user_streams.values()),
187 )
188
175
176 @preserve_fn
189177 def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
190178 extra_users=[]):
191179 """ Used by handlers to inform the notifier something has happened
207195
208196 self.notify_replication()
209197
198 @preserve_fn
210199 def _notify_pending_new_room_events(self, max_room_stream_id):
211200 """Notify for the room events that were queued waiting for a previous
212201 event to be persisted.
224213 else:
225214 self._on_new_room_event(event, room_stream_id, extra_users)
226215
216 @preserve_fn
227217 def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
228218 """Notify any user streams that are interested in this room event"""
229219 # poke any interested application service.
230 self.appservice_handler.notify_interested_services(event)
231
232 app_streams = set()
233
234 for appservice in self.appservice_to_user_streams:
235 # TODO (kegan): Redundant appservice listener checks?
236 # App services will already be in the room_to_user_streams set, but
237 # that isn't enough. They need to be checked here in order to
238 # receive *invites* for users they are interested in. Does this
239 # make the room_to_user_streams check somewhat obsolete?
240 if appservice.is_interested(event):
241 app_user_streams = self.appservice_to_user_streams.get(
242 appservice, set()
243 )
244 app_streams |= app_user_streams
220 self.appservice_handler.notify_interested_services(room_stream_id)
245221
246222 if event.type == EventTypes.Member and event.membership == Membership.JOIN:
247223 self._user_joined_room(event.state_key, event.room_id)
250226 "room_key", room_stream_id,
251227 users=extra_users,
252228 rooms=[event.room_id],
253 extra_streams=app_streams,
254 )
255
256 def on_new_event(self, stream_key, new_token, users=[], rooms=[],
257 extra_streams=set()):
229 )
230
231 @preserve_fn
232 def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
258233 """ Used to inform listeners that something has happend event wise.
259234
260235 Will wake up all listeners for the given users and rooms.
261236 """
262237 with PreserveLoggingContext():
263 user_streams = set()
264
265 for user in users:
266 user_stream = self.user_to_user_stream.get(str(user))
267 if user_stream is not None:
268 user_streams.add(user_stream)
269
270 for room in rooms:
271 user_streams |= self.room_to_user_streams.get(room, set())
272
273 time_now_ms = self.clock.time_msec()
274 for user_stream in user_streams:
275 try:
276 user_stream.notify(stream_key, new_token, time_now_ms)
277 except:
278 logger.exception("Failed to notify listener")
279
280 self.notify_replication()
281
238 with Measure(self.clock, "on_new_event"):
239 user_streams = set()
240
241 for user in users:
242 user_stream = self.user_to_user_stream.get(str(user))
243 if user_stream is not None:
244 user_streams.add(user_stream)
245
246 for room in rooms:
247 user_streams |= self.room_to_user_streams.get(room, set())
248
249 time_now_ms = self.clock.time_msec()
250 for user_stream in user_streams:
251 try:
252 user_stream.notify(stream_key, new_token, time_now_ms)
253 except:
254 logger.exception("Failed to notify listener")
255
256 self.notify_replication()
257
258 @preserve_fn
282259 def on_new_replication_data(self):
283260 """Used to inform replication listeners that something has happend
284261 without waking up any of the normal user event streams"""
293270 """
294271 user_stream = self.user_to_user_stream.get(user_id)
295272 if user_stream is None:
296 appservice = yield self.store.get_app_service_by_user_id(user_id)
297273 current_token = yield self.event_sources.get_current_token()
298274 if room_ids is None:
299275 rooms = yield self.store.get_rooms_for_user(user_id)
301277 user_stream = _NotifierUserStream(
302278 user_id=user_id,
303279 rooms=room_ids,
304 appservice=appservice,
305280 current_token=current_token,
306281 time_now_ms=self.clock.time_msec(),
307282 )
476451 s = self.room_to_user_streams.setdefault(room, set())
477452 s.add(user_stream)
478453
479 if user_stream.appservice:
480 self.appservice_to_user_stream.setdefault(
481 user_stream.appservice, set()
482 ).add(user_stream)
483
484454 def _user_joined_room(self, user_id, room_id):
485455 new_user_stream = self.user_to_user_stream.get(user_id)
486456 if new_user_stream is not None:
3737
3838 @defer.inlineCallbacks
3939 def handle_push_actions_for_event(self, event, context):
40 with Measure(self.clock, "handle_push_actions_for_event"):
40 with Measure(self.clock, "evaluator_for_event"):
4141 bulk_evaluator = yield evaluator_for_event(
42 event, self.hs, self.store, context.current_state
42 event, self.hs, self.store, context.state_group, context.current_state
4343 )
4444
45 with Measure(self.clock, "action_for_event_by_user"):
4546 actions_by_user = yield bulk_evaluator.action_for_event_by_user(
4647 event, context.current_state
4748 )
4849
49 context.push_actions = [
50 (uid, actions) for uid, actions in actions_by_user.items()
51 ]
50 context.push_actions = [
51 (uid, actions) for uid, actions in actions_by_user.items()
52 ]
216216 'dont_notify'
217217 ]
218218 },
219 # This was changed from underride to override so it's closer in priority
220 # to the content rules where the user name highlight rule lives. This
221 # way a room rule is lower priority than both but a custom override rule
222 # is higher priority than both.
223 {
224 'rule_id': 'global/override/.m.rule.contains_display_name',
225 'conditions': [
226 {
227 'kind': 'contains_display_name'
228 }
229 ],
230 'actions': [
231 'notify',
232 {
233 'set_tweak': 'sound',
234 'value': 'default'
235 }, {
236 'set_tweak': 'highlight'
237 }
238 ]
239 },
219240 ]
220241
221242
238259 }, {
239260 'set_tweak': 'highlight',
240261 'value': False
241 }
242 ]
243 },
244 {
245 'rule_id': 'global/underride/.m.rule.contains_display_name',
246 'conditions': [
247 {
248 'kind': 'contains_display_name'
249 }
250 ],
251 'actions': [
252 'notify',
253 {
254 'set_tweak': 'sound',
255 'value': 'default'
256 }, {
257 'set_tweak': 'highlight'
258262 }
259263 ]
260264 },
3535
3636
3737 @defer.inlineCallbacks
38 def evaluator_for_event(event, hs, store, current_state):
39 room_id = event.room_id
40 # We also will want to generate notifs for other people in the room so
41 # their unread counts are correct in the event stream, but to avoid
42 # generating them for bot / AS users etc, we only do so for people who've
43 # sent a read receipt into the room.
44
45 local_users_in_room = set(
46 e.state_key for e in current_state.values()
47 if e.type == EventTypes.Member and e.membership == Membership.JOIN
48 and hs.is_mine_id(e.state_key)
38 def evaluator_for_event(event, hs, store, state_group, current_state):
39 rules_by_user = yield store.bulk_get_push_rules_for_room(
40 event.room_id, state_group, current_state
4941 )
50
51 # users in the room who have pushers need to get push rules run because
52 # that's how their pushers work
53 if_users_with_pushers = yield store.get_if_users_have_pushers(
54 local_users_in_room
55 )
56 user_ids = set(
57 uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
58 )
59
60 users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
61
62 # any users with pushers must be ours: they have pushers
63 for uid in users_with_receipts:
64 if uid in local_users_in_room:
65 user_ids.add(uid)
6642
6743 # if this event is an invite event, we may need to run rules for the user
6844 # who's been invited, otherwise they won't get told they've been invited
7147 if invited_user and hs.is_mine_id(invited_user):
7248 has_pusher = yield store.user_has_pusher(invited_user)
7349 if has_pusher:
74 user_ids.add(invited_user)
75
76 rules_by_user = yield _get_rules(room_id, user_ids, store)
50 rules_by_user[invited_user] = yield store.get_push_rules_for_user(
51 invited_user
52 )
7753
7854 defer.returnValue(BulkPushRuleEvaluator(
79 room_id, rules_by_user, user_ids, store
55 event.room_id, rules_by_user, store
8056 ))
8157
8258
8965 the same logic to run the actual rules, but could be optimised further
9066 (see https://matrix.org/jira/browse/SYN-562)
9167 """
92 def __init__(self, room_id, rules_by_user, users_in_room, store):
68 def __init__(self, room_id, rules_by_user, store):
9369 self.room_id = room_id
9470 self.rules_by_user = rules_by_user
95 self.users_in_room = users_in_room
9671 self.store = store
9772
9873 @defer.inlineCallbacks
1616 from synapse.util.presentable_names import (
1717 calculate_room_name, name_from_member_event
1818 )
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
1920
2021
2122 @defer.inlineCallbacks
2223 def get_badge_count(store, user_id):
23 invites, joins = yield defer.gatherResults([
24 store.get_invited_rooms_for_user(user_id),
25 store.get_rooms_for_user(user_id),
26 ], consumeErrors=True)
24 invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
25 preserve_fn(store.get_invited_rooms_for_user)(user_id),
26 preserve_fn(store.get_rooms_for_user)(user_id),
27 ], consumeErrors=True))
2728
2829 my_receipts_by_room = yield store.get_receipts_for_user(
2930 user_id, "m.read",
1616 from twisted.internet import defer
1717
1818 import pusher
19 from synapse.util.logcontext import preserve_fn
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
2020 from synapse.util.async import run_on_reactor
2121
2222 import logging
101101 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
102102
103103 @defer.inlineCallbacks
104 def remove_pushers_by_user(self, user_id, except_token_ids=[]):
104 def remove_pushers_by_user(self, user_id, except_access_token_id=None):
105105 all = yield self.store.get_all_pushers()
106106 logger.info(
107 "Removing all pushers for user %s except access tokens ids %r",
108 user_id, except_token_ids
107 "Removing all pushers for user %s except access tokens id %r",
108 user_id, except_access_token_id
109109 )
110110 for p in all:
111 if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
111 if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
112112 logger.info(
113113 "Removing pusher for app id %s, pushkey %s, user %s",
114114 p['app_id'], p['pushkey'], p['user_name']
129129 if u in self.pushers:
130130 for p in self.pushers[u].values():
131131 deferreds.append(
132 p.on_new_notifications(min_stream_id, max_stream_id)
132 preserve_fn(p.on_new_notifications)(
133 min_stream_id, max_stream_id
134 )
133135 )
134136
135 yield defer.gatherResults(deferreds)
137 yield preserve_context_over_deferred(defer.gatherResults(deferreds))
136138 except:
137139 logger.exception("Exception in pusher on_new_notifications")
138140
154156 if u in self.pushers:
155157 for p in self.pushers[u].values():
156158 deferreds.append(
157 p.on_new_receipts(min_stream_id, max_stream_id)
159 preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
158160 )
159161
160 yield defer.gatherResults(deferreds)
162 yield preserve_context_over_deferred(defer.gatherResults(deferreds))
161163 except:
162164 logger.exception("Exception in pusher on_new_receipts")
163165
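
The changes above wrap each pusher callback in ``preserve_fn`` and the combined
``gatherResults`` in ``preserve_context_over_deferred`` so the caller's Twisted
log context survives the fan-out. A minimal sketch of the pattern, assuming only
synapse's logcontext helpers (the ``pushers`` list and stream ids here are
placeholders)::

    from twisted.internet import defer
    from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

    @defer.inlineCallbacks
    def notify_all(pushers, min_stream_id, max_stream_id):
        # preserve_fn() restores the caller's log context around each call that
        # returns a Deferred; preserve_context_over_deferred() does the same for
        # the Deferred returned by gatherResults().
        deferreds = [
            preserve_fn(p.on_new_notifications)(min_stream_id, max_stream_id)
            for p in pushers
        ]
        yield preserve_context_over_deferred(
            defer.gatherResults(deferreds, consumeErrors=True)
        )
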
4040 ("push_rules",),
4141 ("pushers",),
4242 ("state",),
43 ("caches",),
4344 )
4445
4546
6970 * "backfill": Old events that have been backfilled from other servers.
7071 * "push_rules": Per user changes to push rules.
7172 * "pushers": Per user changes to their pushers.
73 * "caches": Cache invalidations.
7274
7375 The API takes two additional query parameters:
7476
128130 push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
129131 pushers_token = self.store.get_pushers_stream_token()
130132 state_token = self.store.get_state_stream_token()
133 caches_token = self.store.get_cache_stream_token()
131134
132135 defer.returnValue(_ReplicationToken(
133136 room_stream_token,
139142 push_rules_token,
140143 pushers_token,
141144 state_token,
145 caches_token,
142146 ))
143147
144148 @request_handler()
187191 yield self.push_rules(writer, current_token, limit, request_streams)
188192 yield self.pushers(writer, current_token, limit, request_streams)
189193 yield self.state(writer, current_token, limit, request_streams)
194 yield self.caches(writer, current_token, limit, request_streams)
190195 self.streams(writer, current_token, request_streams)
191196
192197 logger.info("Replicated %d rows", writer.total)
378383 "position", "type", "state_key", "event_id"
379384 ))
380385
386 @defer.inlineCallbacks
387 def caches(self, writer, current_token, limit, request_streams):
388 current_position = current_token.caches
389
390 caches = request_streams.get("caches")
391
392 if caches is not None:
393 updated_caches = yield self.store.get_all_updated_caches(
394 caches, current_position, limit
395 )
396 writer.write_header_and_rows("caches", updated_caches, (
397 "position", "cache_func", "keys", "invalidation_ts"
398 ))
399
381400
382401 class _Writer(object):
383402 """Writes the streams as a JSON object as the response to the request"""
406425
407426 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
408427 "events", "presence", "typing", "receipts", "account_data", "backfill",
409 "push_rules", "pushers", "state"
428 "push_rules", "pushers", "state", "caches",
410429 ))):
411430 __slots__ = []
412431
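
With the new ``caches`` stream, a worker can poll the master's replication
listener for cache invalidations alongside the other streams. A rough,
illustrative poll only; the ``/_synapse/replication`` path and the port are
assumptions based on the replication listener configuration, not part of this
diff::

    import json
    import urllib
    import urllib2

    params = urllib.urlencode({
        "caches": "0",        # last "caches" token this worker has seen
        "limit": "100",
        "timeout": "30000",
    })
    url = "http://127.0.0.1:9092/_synapse/replication?" + params
    result = json.loads(urllib2.urlopen(url).read())

    # When there is new data, result["caches"] carries a "position" and "rows"
    # of (position, cache_func, keys, invalidation_ts), matching the columns
    # written by the caches() handler above.
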
1313 # limitations under the License.
1414
1515 from synapse.storage._base import SQLBaseStore
16 from synapse.storage.engines import PostgresEngine
1617 from twisted.internet import defer
18
19 from ._slaved_id_tracker import SlavedIdTracker
20
21 import logging
22
23 logger = logging.getLogger(__name__)
1724
1825
1926 class BaseSlavedStore(SQLBaseStore):
2027 def __init__(self, db_conn, hs):
2128 super(BaseSlavedStore, self).__init__(hs)
29 if isinstance(self.database_engine, PostgresEngine):
30 self._cache_id_gen = SlavedIdTracker(
31 db_conn, "cache_invalidation_stream", "stream_id",
32 )
33 else:
34 self._cache_id_gen = None
2235
2336 def stream_positions(self):
24 return {}
37 pos = {}
38 if self._cache_id_gen:
39 pos["caches"] = self._cache_id_gen.get_current_token()
40 return pos
2541
2642 def process_replication(self, result):
43 stream = result.get("caches")
44 if stream:
45 for row in stream["rows"]:
46 (
47 position, cache_func, keys, invalidation_ts,
48 ) = row
49
50 try:
51 getattr(self, cache_func).invalidate(tuple(keys))
52 except AttributeError:
53 logger.info("Got unexpected cache_func: %r", cache_func)
54 self._cache_id_gen.advance(int(stream["position"]))
2755 return defer.succeed(None)
2727
2828 get_app_service_by_token = DataStore.get_app_service_by_token.__func__
2929 get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
30 get_app_services = DataStore.get_app_services.__func__
31 get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
32 create_appservice_txn = DataStore.create_appservice_txn.__func__
33 get_appservices_by_state = DataStore.get_appservices_by_state.__func__
34 get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
35 _get_last_txn = DataStore._get_last_txn.__func__
36 complete_appservice_txn = DataStore.complete_appservice_txn.__func__
37 get_appservice_state = DataStore.get_appservice_state.__func__
38 set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
39 set_appservice_state = DataStore.set_appservice_state.__func__
1919 class DirectoryStore(BaseSlavedStore):
2020 get_aliases_for_room = DirectoryStore.__dict__[
2121 "get_aliases_for_room"
22 ].orig
22 ]
2424 # TODO: use the cached version and invalidate deleted tokens
2525 get_user_by_access_token = RegistrationStore.__dict__[
2626 "get_user_by_access_token"
27 ].orig
27 ]
2828
2929 _query_for_auth = DataStore._query_for_auth.__func__
30 get_user_by_id = RegistrationStore.__dict__[
31 "get_user_by_id"
32 ]
4545 account_data,
4646 report_event,
4747 openid,
48 notifications,
4849 devices,
50 thirdparty,
4951 )
5052
5153 from synapse.http.server import JsonResource
9092 account_data.register_servlets(hs, client_resource)
9193 report_event.register_servlets(hs, client_resource)
9294 openid.register_servlets(hs, client_resource)
95 notifications.register_servlets(hs, client_resource)
9396 devices.register_servlets(hs, client_resource)
97 thirdparty.register_servlets(hs, client_resource)
2626
2727 class WhoisRestServlet(ClientV1RestServlet):
2828 PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
29
30 def __init__(self, hs):
31 super(WhoisRestServlet, self).__init__(hs)
32 self.handlers = hs.get_handlers()
2933
3034 @defer.inlineCallbacks
3135 def on_GET(self, request, user_id):
8185 "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
8286 )
8387
88 def __init__(self, hs):
89 super(PurgeHistoryRestServlet, self).__init__(hs)
90 self.handlers = hs.get_handlers()
91
8492 @defer.inlineCallbacks
8593 def on_POST(self, request, room_id, event_id):
8694 requester = yield self.auth.get_user_by_req(request)
5656 hs (synapse.server.HomeServer):
5757 """
5858 self.hs = hs
59 self.handlers = hs.get_handlers()
6059 self.builder_factory = hs.get_event_builder_factory()
6160 self.auth = hs.get_v1auth()
6261 self.txns = HttpTransactionStore()
3434
3535 class ClientDirectoryServer(ClientV1RestServlet):
3636 PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
37
38 def __init__(self, hs):
39 super(ClientDirectoryServer, self).__init__(hs)
40 self.handlers = hs.get_handlers()
3741
3842 @defer.inlineCallbacks
3943 def on_GET(self, request, room_alias):
145149 def __init__(self, hs):
146150 super(ClientDirectoryListServer, self).__init__(hs)
147151 self.store = hs.get_datastore()
152 self.handlers = hs.get_handlers()
148153
149154 @defer.inlineCallbacks
150155 def on_GET(self, request, room_id):
3131
3232 DEFAULT_LONGPOLL_TIME_MS = 30000
3333
34 def __init__(self, hs):
35 super(EventStreamRestServlet, self).__init__(hs)
36 self.event_stream_handler = hs.get_event_stream_handler()
37
3438 @defer.inlineCallbacks
3539 def on_GET(self, request):
3640 requester = yield self.auth.get_user_by_req(
4549 if "room_id" in request.args:
4650 room_id = request.args["room_id"][0]
4751
48 handler = self.handlers.event_stream_handler
4952 pagin_config = PaginationConfig.from_request(request)
5053 timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
5154 if "timeout" in request.args:
5659
5760 as_client_event = "raw" not in request.args
5861
59 chunk = yield handler.get_stream(
62 chunk = yield self.event_stream_handler.get_stream(
6063 requester.user.to_string(),
6164 pagin_config,
6265 timeout=timeout,
7982 def __init__(self, hs):
8083 super(EventRestServlet, self).__init__(hs)
8184 self.clock = hs.get_clock()
85 self.event_handler = hs.get_event_handler()
8286
8387 @defer.inlineCallbacks
8488 def on_GET(self, request, event_id):
8589 requester = yield self.auth.get_user_by_req(request)
86 handler = self.handlers.event_handler
87 event = yield handler.get_event(requester.user, event_id)
90 event = yield self.event_handler.get_event(requester.user, event_id)
8891
8992 time_now = self.clock.time_msec()
9093 if event:
2222 class InitialSyncRestServlet(ClientV1RestServlet):
2323 PATTERNS = client_path_patterns("/initialSync$")
2424
25 def __init__(self, hs):
26 super(InitialSyncRestServlet, self).__init__(hs)
27 self.handlers = hs.get_handlers()
28
2529 @defer.inlineCallbacks
2630 def on_GET(self, request):
2731 requester = yield self.auth.get_user_by_req(request)
5353 self.jwt_secret = hs.config.jwt_secret
5454 self.jwt_algorithm = hs.config.jwt_algorithm
5555 self.cas_enabled = hs.config.cas_enabled
56 self.cas_server_url = hs.config.cas_server_url
57 self.cas_required_attributes = hs.config.cas_required_attributes
58 self.servername = hs.config.server_name
59 self.http_client = hs.get_simple_http_client()
6056 self.auth_handler = self.hs.get_auth_handler()
6157 self.device_handler = self.hs.get_device_handler()
58 self.handlers = hs.get_handlers()
6259
6360 def on_GET(self, request):
6461 flows = []
109106 LoginRestServlet.JWT_TYPE):
110107 result = yield self.do_jwt_login(login_submission)
111108 defer.returnValue(result)
112 # TODO Delete this after all CAS clients switch to token login instead
113 elif self.cas_enabled and (login_submission["type"] ==
114 LoginRestServlet.CAS_TYPE):
115 uri = "%s/proxyValidate" % (self.cas_server_url,)
116 args = {
117 "ticket": login_submission["ticket"],
118 "service": login_submission["service"]
119 }
120 body = yield self.http_client.get_raw(uri, args)
121 result = yield self.do_cas_login(body)
122 defer.returnValue(result)
123109 elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
124110 result = yield self.do_token_login(login_submission)
125111 defer.returnValue(result)
187173 "home_server": self.hs.hostname,
188174 "device_id": device_id,
189175 }
190
191 defer.returnValue((200, result))
192
193 # TODO Delete this after all CAS clients switch to token login instead
194 @defer.inlineCallbacks
195 def do_cas_login(self, cas_response_body):
196 user, attributes = self.parse_cas_response(cas_response_body)
197
198 for required_attribute, required_value in self.cas_required_attributes.items():
199 # If required attribute was not in CAS Response - Forbidden
200 if required_attribute not in attributes:
201 raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
202
203 # Also need to check value
204 if required_value is not None:
205 actual_value = attributes[required_attribute]
206 # If required attribute value does not match expected - Forbidden
207 if required_value != actual_value:
208 raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
209
210 user_id = UserID.create(user, self.hs.hostname).to_string()
211 auth_handler = self.auth_handler
212 registered_user_id = yield auth_handler.check_user_exists(user_id)
213 if registered_user_id:
214 access_token, refresh_token = (
215 yield auth_handler.get_login_tuple_for_user_id(
216 registered_user_id
217 )
218 )
219 result = {
220 "user_id": registered_user_id, # may have changed
221 "access_token": access_token,
222 "refresh_token": refresh_token,
223 "home_server": self.hs.hostname,
224 }
225
226 else:
227 user_id, access_token = (
228 yield self.handlers.registration_handler.register(localpart=user)
229 )
230 result = {
231 "user_id": user_id, # may have changed
232 "access_token": access_token,
233 "home_server": self.hs.hostname,
234 }
235176
236177 defer.returnValue((200, result))
237178
291232 }
292233
293234 defer.returnValue((200, result))
294
295 # TODO Delete this after all CAS clients switch to token login instead
296 def parse_cas_response(self, cas_response_body):
297 root = ET.fromstring(cas_response_body)
298 if not root.tag.endswith("serviceResponse"):
299 raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
300 if not root[0].tag.endswith("authenticationSuccess"):
301 raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
302 for child in root[0]:
303 if child.tag.endswith("user"):
304 user = child.text
305 if child.tag.endswith("attributes"):
306 attributes = {}
307 for attribute in child:
308 # ElementTree library expands the namespace in attribute tags
309 # to the full URL of the namespace.
310 # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
311 # We don't care about namespace here and it will always be encased in
312 # curly braces, so we remove them.
313 if "}" in attribute.tag:
314 attributes[attribute.tag.split("}")[1]] = attribute.text
315 else:
316 attributes[attribute.tag] = attribute.text
317 if user is None or attributes is None:
318 raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
319
320 return (user, attributes)
321235
322236 def _register_device(self, user_id, login_submission):
323237 """Register a device for a user.
346260 def __init__(self, hs):
347261 super(SAML2RestServlet, self).__init__(hs)
348262 self.sp_config = hs.config.saml2_config_path
263 self.handlers = hs.get_handlers()
349264
350265 @defer.inlineCallbacks
351266 def on_POST(self, request):
383298 defer.returnValue((200, {"status": "not_authenticated"}))
384299
385300
386 # TODO Delete this after all CAS clients switch to token login instead
387 class CasRestServlet(ClientV1RestServlet):
388 PATTERNS = client_path_patterns("/login/cas", releases=())
389
390 def __init__(self, hs):
391 super(CasRestServlet, self).__init__(hs)
392 self.cas_server_url = hs.config.cas_server_url
393
394 def on_GET(self, request):
395 return (200, {"serverUrl": self.cas_server_url})
396
397
398301 class CasRedirectServlet(ClientV1RestServlet):
399302 PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
400303
426329 self.cas_server_url = hs.config.cas_server_url
427330 self.cas_service_url = hs.config.cas_service_url
428331 self.cas_required_attributes = hs.config.cas_required_attributes
332 self.auth_handler = hs.get_auth_handler()
333 self.handlers = hs.get_handlers()
429334
430335 @defer.inlineCallbacks
431336 def on_GET(self, request):
478383 return urlparse.urlunparse(url_parts)
479384
480385 def parse_cas_response(self, cas_response_body):
481 root = ET.fromstring(cas_response_body)
482 if not root.tag.endswith("serviceResponse"):
483 raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
484 if not root[0].tag.endswith("authenticationSuccess"):
485 raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
486 for child in root[0]:
487 if child.tag.endswith("user"):
488 user = child.text
489 if child.tag.endswith("attributes"):
490 attributes = {}
491 for attribute in child:
492 # ElementTree library expands the namespace in attribute tags
493 # to the full URL of the namespace.
494 # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
495 # We don't care about namespace here and it will always be encased in
496 # curly braces, so we remove them.
497 if "}" in attribute.tag:
498 attributes[attribute.tag.split("}")[1]] = attribute.text
499 else:
500 attributes[attribute.tag] = attribute.text
501 if user is None or attributes is None:
502 raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
503
504 return (user, attributes)
386 user = None
387 attributes = None
388 try:
389 root = ET.fromstring(cas_response_body)
390 if not root.tag.endswith("serviceResponse"):
391 raise Exception("root of CAS response is not serviceResponse")
392 success = (root[0].tag.endswith("authenticationSuccess"))
393 for child in root[0]:
394 if child.tag.endswith("user"):
395 user = child.text
396 if child.tag.endswith("attributes"):
397 attributes = {}
398 for attribute in child:
399 # ElementTree library expands the namespace in
400 # attribute tags to the full URL of the namespace.
401 # We don't care about namespace here and it will always
402 # be encased in curly braces, so we remove them.
403 tag = attribute.tag
404 if "}" in tag:
405 tag = tag.split("}")[1]
406 attributes[tag] = attribute.text
407 if user is None:
408 raise Exception("CAS response does not contain user")
409 if attributes is None:
410 raise Exception("CAS response does not contain attributes")
411 except Exception:
412 logger.error("Error parsing CAS response", exc_info=1)
413 raise LoginError(401, "Invalid CAS response",
414 errcode=Codes.UNAUTHORIZED)
415 if not success:
416 raise LoginError(401, "Unsuccessful CAS response",
417 errcode=Codes.UNAUTHORIZED)
418 return user, attributes
505419
506420
507421 def register_servlets(hs, http_server):
511425 if hs.config.cas_enabled:
512426 CasRedirectServlet(hs).register(http_server)
513427 CasTicketServlet(hs).register(http_server)
514 CasRestServlet(hs).register(http_server)
515 # TODO PasswordResetRestServlet(hs).register(http_server)
2222
2323 class ProfileDisplaynameRestServlet(ClientV1RestServlet):
2424 PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
25
26 def __init__(self, hs):
27 super(ProfileDisplaynameRestServlet, self).__init__(hs)
28 self.handlers = hs.get_handlers()
2529
2630 @defer.inlineCallbacks
2731 def on_GET(self, request, user_id):
6165 class ProfileAvatarURLRestServlet(ClientV1RestServlet):
6266 PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
6367
68 def __init__(self, hs):
69 super(ProfileAvatarURLRestServlet, self).__init__(hs)
70 self.handlers = hs.get_handlers()
71
6472 @defer.inlineCallbacks
6573 def on_GET(self, request, user_id):
6674 user = UserID.from_string(user_id)
98106 class ProfileRestServlet(ClientV1RestServlet):
99107 PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
100108
109 def __init__(self, hs):
110 super(ProfileRestServlet, self).__init__(hs)
111 self.handlers = hs.get_handlers()
112
101113 @defer.inlineCallbacks
102114 def on_GET(self, request, user_id):
103115 user = UserID.from_string(user_id)
6464 self.sessions = {}
6565 self.enable_registration = hs.config.enable_registration
6666 self.auth_handler = hs.get_auth_handler()
67 self.handlers = hs.get_handlers()
6768
6869 def on_GET(self, request):
6970 if self.hs.config.enable_registration_captcha:
382383 super(CreateUserRestServlet, self).__init__(hs)
383384 self.store = hs.get_datastore()
384385 self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
386 self.handlers = hs.get_handlers()
385387
386388 @defer.inlineCallbacks
387389 def on_POST(self, request):
3434 class RoomCreateRestServlet(ClientV1RestServlet):
3535 # No PATTERN; we have custom dispatch rules here
3636
37 def __init__(self, hs):
38 super(RoomCreateRestServlet, self).__init__(hs)
39 self.handlers = hs.get_handlers()
40
3741 def register(self, http_server):
3842 PATTERNS = "/createRoom"
3943 register_txn_path(self, PATTERNS, http_server)
8185
8286 # TODO: Needs unit testing for generic events
8387 class RoomStateEventRestServlet(ClientV1RestServlet):
88 def __init__(self, hs):
89 super(RoomStateEventRestServlet, self).__init__(hs)
90 self.handlers = hs.get_handlers()
91
8492 def register(self, http_server):
8593 # /room/$roomid/state/$eventtype
8694 no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
165173 # TODO: Needs unit testing for generic events + feedback
166174 class RoomSendEventRestServlet(ClientV1RestServlet):
167175
176 def __init__(self, hs):
177 super(RoomSendEventRestServlet, self).__init__(hs)
178 self.handlers = hs.get_handlers()
179
168180 def register(self, http_server):
169181 # /rooms/$roomid/send/$event_type[/$txn_id]
170182 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
209221
210222 # TODO: Needs unit testing for room ID + alias joins
211223 class JoinRoomAliasServlet(ClientV1RestServlet):
224 def __init__(self, hs):
225 super(JoinRoomAliasServlet, self).__init__(hs)
226 self.handlers = hs.get_handlers()
212227
213228 def register(self, http_server):
214229 # /join/$room_identifier[/$txn_id]
252267 action="join",
253268 txn_id=txn_id,
254269 remote_room_hosts=remote_room_hosts,
270 content=content,
255271 third_party_signed=content.get("third_party_signed", None),
256272 )
257273
294310 # TODO: Needs unit testing
295311 class RoomMemberListRestServlet(ClientV1RestServlet):
296312 PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
313
314 def __init__(self, hs):
315 super(RoomMemberListRestServlet, self).__init__(hs)
316 self.handlers = hs.get_handlers()
297317
298318 @defer.inlineCallbacks
299319 def on_GET(self, request, room_id):
320340 # TODO: Needs better unit testing
321341 class RoomMessageListRestServlet(ClientV1RestServlet):
322342 PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
343
344 def __init__(self, hs):
345 super(RoomMessageListRestServlet, self).__init__(hs)
346 self.handlers = hs.get_handlers()
323347
324348 @defer.inlineCallbacks
325349 def on_GET(self, request, room_id):
350374 class RoomStateRestServlet(ClientV1RestServlet):
351375 PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
352376
377 def __init__(self, hs):
378 super(RoomStateRestServlet, self).__init__(hs)
379 self.handlers = hs.get_handlers()
380
353381 @defer.inlineCallbacks
354382 def on_GET(self, request, room_id):
355383 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
367395 class RoomInitialSyncRestServlet(ClientV1RestServlet):
368396 PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
369397
398 def __init__(self, hs):
399 super(RoomInitialSyncRestServlet, self).__init__(hs)
400 self.handlers = hs.get_handlers()
401
370402 @defer.inlineCallbacks
371403 def on_GET(self, request, room_id):
372404 requester = yield self.auth.get_user_by_req(request, allow_guest=True)
387419 def __init__(self, hs):
388420 super(RoomEventContext, self).__init__(hs)
389421 self.clock = hs.get_clock()
422 self.handlers = hs.get_handlers()
390423
391424 @defer.inlineCallbacks
392425 def on_GET(self, request, room_id, event_id):
423456
424457
425458 class RoomForgetRestServlet(ClientV1RestServlet):
459 def __init__(self, hs):
460 super(RoomForgetRestServlet, self).__init__(hs)
461 self.handlers = hs.get_handlers()
462
426463 def register(self, http_server):
427464 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
428465 register_txn_path(self, PATTERNS, http_server)
460497
461498 # TODO: Needs unit testing
462499 class RoomMembershipRestServlet(ClientV1RestServlet):
500
501 def __init__(self, hs):
502 super(RoomMembershipRestServlet, self).__init__(hs)
503 self.handlers = hs.get_handlers()
463504
464505 def register(self, http_server):
465506 # /rooms/$roomid/[invite|join|leave]
541582
542583
543584 class RoomRedactEventRestServlet(ClientV1RestServlet):
585 def __init__(self, hs):
586 super(RoomRedactEventRestServlet, self).__init__(hs)
587 self.handlers = hs.get_handlers()
588
544589 def register(self, http_server):
545590 PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
546591 register_txn_path(self, PATTERNS, http_server)
622667 PATTERNS = client_path_patterns(
623668 "/search$"
624669 )
670
671 def __init__(self, hs):
672 super(SearchRestServlet, self).__init__(hs)
673 self.handlers = hs.get_handlers()
625674
626675 @defer.inlineCallbacks
627676 def on_POST(self, request):
0 # -*- coding: utf-8 -*-
1 # Copyright 2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from twisted.internet import defer
16
17 from synapse.http.servlet import (
18 RestServlet, parse_string, parse_integer
19 )
20 from synapse.events.utils import (
21 serialize_event, format_event_for_client_v2_without_room_id,
22 )
23
24 from ._base import client_v2_patterns
25
26 import logging
27
28 logger = logging.getLogger(__name__)
29
30
31 class NotificationsServlet(RestServlet):
32 PATTERNS = client_v2_patterns("/notifications$", releases=())
33
34 def __init__(self, hs):
35 super(NotificationsServlet, self).__init__()
36 self.store = hs.get_datastore()
37 self.auth = hs.get_auth()
38 self.clock = hs.get_clock()
39
40 @defer.inlineCallbacks
41 def on_GET(self, request):
42 requester = yield self.auth.get_user_by_req(request)
43 user_id = requester.user.to_string()
44
45 from_token = parse_string(request, "from", required=False)
46 limit = parse_integer(request, "limit", default=50)
47
48 limit = min(limit, 500)
49
50 push_actions = yield self.store.get_push_actions_for_user(
51 user_id, from_token, limit
52 )
53
54 receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
55 user_id, 'm.read'
56 )
57
58 notif_event_ids = [pa["event_id"] for pa in push_actions]
59 notif_events = yield self.store.get_events(notif_event_ids)
60
61 returned_push_actions = []
62
63 next_token = None
64
65 for pa in push_actions:
66 returned_pa = {
67 "room_id": pa["room_id"],
68 "profile_tag": pa["profile_tag"],
69 "actions": pa["actions"],
70 "ts": pa["received_ts"],
71 "event": serialize_event(
72 notif_events[pa["event_id"]],
73 self.clock.time_msec(),
74 event_format=format_event_for_client_v2_without_room_id,
75 ),
76 }
77
78 if pa["room_id"] not in receipts_by_room:
79 returned_pa["read"] = False
80 else:
81 receipt = receipts_by_room[pa["room_id"]]
82
83 returned_pa["read"] = (
84 receipt["topological_ordering"], receipt["stream_ordering"]
85 ) >= (
86 pa["topological_ordering"], pa["stream_ordering"]
87 )
88 returned_push_actions.append(returned_pa)
89 next_token = pa["stream_ordering"]
90
91 defer.returnValue((200, {
92 "notifications": returned_push_actions,
93 "next_token": next_token,
94 }))
95
96
97 def register_servlets(hs, http_server):
98 NotificationsServlet(hs).register(http_server)
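
The servlet above backs the new notification API from PR #1028. An illustrative
request against it, assuming the endpoint is exposed under the unstable client
prefix (implied by ``releases=()``); the homeserver URL and access token are
placeholders::

    import json
    import urllib
    import urllib2

    base = "http://localhost:8008/_matrix/client/unstable/notifications"
    qs = urllib.urlencode({"access_token": "<access_token>", "limit": "20"})
    body = json.loads(urllib2.urlopen(base + "?" + qs).read())

    for notif in body["notifications"]:
        print notif["room_id"], notif["read"], notif["event"]["event_id"]

    # Pass body["next_token"] back as the "from" parameter to page through
    # older notifications.
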
402402 # register the user's device
403403 device_id = params.get("device_id")
404404 initial_display_name = params.get("initial_device_display_name")
405 device_id = self.device_handler.check_device_registered(
405 return self.device_handler.check_device_registered(
406406 user_id, device_id, initial_display_name
407407 )
408 return device_id
409408
410409 @defer.inlineCallbacks
411410 def _do_guest_registration(self):
145145 affect_presence = set_presence != PresenceState.OFFLINE
146146
147147 if affect_presence:
148 yield self.presence_handler.set_state(user, {"presence": set_presence})
148 yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
149149
150150 context = yield self.presence_handler.user_syncing(
151151 user.to_string(), affect_presence=affect_presence,
0 # -*- coding: utf-8 -*-
1 # Copyright 2015, 2016 OpenMarket Ltd
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15
16 import logging
17
18 from twisted.internet import defer
19
20 from synapse.http.servlet import RestServlet
21 from synapse.types import ThirdPartyEntityKind
22 from ._base import client_v2_patterns
23
24 logger = logging.getLogger(__name__)
25
26
27 class ThirdPartyUserServlet(RestServlet):
28 PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
29 releases=())
30
31 def __init__(self, hs):
32 super(ThirdPartyUserServlet, self).__init__()
33
34 self.auth = hs.get_auth()
35 self.appservice_handler = hs.get_application_service_handler()
36
37 @defer.inlineCallbacks
38 def on_GET(self, request, protocol):
39 yield self.auth.get_user_by_req(request)
40
41 fields = request.args
42 del fields["access_token"]
43
44 results = yield self.appservice_handler.query_3pe(
45 ThirdPartyEntityKind.USER, protocol, fields
46 )
47
48 defer.returnValue((200, results))
49
50
51 class ThirdPartyLocationServlet(RestServlet):
52 PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
53 releases=())
54
55 def __init__(self, hs):
56 super(ThirdPartyLocationServlet, self).__init__()
57
58 self.auth = hs.get_auth()
59 self.appservice_handler = hs.get_application_service_handler()
60
61 @defer.inlineCallbacks
62 def on_GET(self, request, protocol):
63 yield self.auth.get_user_by_req(request)
64
65 fields = request.args
66 del fields["access_token"]
67
68 results = yield self.appservice_handler.query_3pe(
69 ThirdPartyEntityKind.LOCATION, protocol, fields
70 )
71
72 defer.returnValue((200, results))
73
74
75 def register_servlets(hs, http_server):
76 ThirdPartyUserServlet(hs).register(http_server)
77 ThirdPartyLocationServlet(hs).register(http_server)
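
These servlets expose the third party lookup API (``/3pu`` and ``/3pl``) by
forwarding the query fields to the application service handler. A hypothetical
lookup against a bridged protocol, again assuming the unstable client prefix;
the protocol name and search field here are examples only::

    import json
    import urllib
    import urllib2

    base = "http://localhost:8008/_matrix/client/unstable/3pu/irc"
    qs = urllib.urlencode({"access_token": "<access_token>", "nick": "Alice"})
    results = json.loads(urllib2.urlopen(base + "?" + qs).read())

    # 'results' is whatever the interested application services return from
    # query_3pe(ThirdPartyEntityKind.USER, "irc", {"nick": ["Alice"]}).
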
1414 from synapse.http.server import request_handler, respond_with_json_bytes
1515 from synapse.http.servlet import parse_integer, parse_json_object_from_request
1616 from synapse.api.errors import SynapseError, Codes
17 from synapse.crypto.keyring import KeyLookupError
1718
1819 from twisted.web.resource import Resource
1920 from twisted.web.server import NOT_DONE_YET
209210 yield self.keyring.get_server_verify_key_v2_direct(
210211 server_name, key_ids
211212 )
213 except KeyLookupError as e:
214 logger.info("Failed to fetch key: %s", e)
212215 except:
213216 logger.exception("Failed to get key for %r", server_name)
214 pass
215217 yield self.query_keys(
216218 request, query, query_remote_on_cache_miss=False
217219 )
4444 @request_handler()
4545 @defer.inlineCallbacks
4646 def _async_render_GET(self, request):
47 request.setHeader("Content-Security-Policy", "sandbox")
4748 server_name, media_id, name = parse_media_id(request)
4849 if server_name == self.server_name:
4950 yield self._respond_local_file(request, media_id, name)
2828 from synapse.util.async import ObservableDeferred
2929 from synapse.util.stringutils import is_ascii
3030
31 from copy import deepcopy
32
3331 import os
3432 import re
3533 import fnmatch
3634 import cgi
3735 import ujson as json
3836 import urlparse
37 import itertools
3938
4039 import logging
4140 logger = logging.getLogger(__name__)
162161
163162 logger.debug("got media_info of '%s'" % media_info)
164163
165 if self._is_media(media_info['media_type']):
164 if _is_media(media_info['media_type']):
166165 dims = yield self.media_repo._generate_local_thumbnails(
167166 media_info['filesystem_id'], media_info
168167 )
183182 logger.warn("Couldn't get dims for %s" % url)
184183
185184 # define our OG response for this media
186 elif self._is_html(media_info['media_type']):
185 elif _is_html(media_info['media_type']):
187186 # TODO: somehow stop a big HTML tree from exploding synapse's RAM
188
189 from lxml import etree
190187
191188 file = open(media_info['filename'])
192189 body = file.read()
198195 match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
199196 encoding = match.group(1) if match else "utf-8"
200197
201 try:
202 parser = etree.HTMLParser(recover=True, encoding=encoding)
203 tree = etree.fromstring(body, parser)
204 og = yield self._calc_og(tree, media_info, requester)
205 except UnicodeDecodeError:
206 # blindly try decoding the body as utf-8, which seems to fix
207 # the charset mismatches on https://google.com
208 parser = etree.HTMLParser(recover=True, encoding=encoding)
209 tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
210 og = yield self._calc_og(tree, media_info, requester)
211
198 og = decode_and_calc_og(body, media_info['uri'], encoding)
199
200 # pre-cache the image for posterity
201 # FIXME: it might be cleaner to use the same flow as the main /preview_url
202 # request itself and benefit from the same caching etc. But for now we
203 # just rely on the caching on the master request to speed things up.
204 if 'og:image' in og and og['og:image']:
205 image_info = yield self._download_url(
206 _rebase_url(og['og:image'], media_info['uri']), requester.user
207 )
208
209 if _is_media(image_info['media_type']):
210 # TODO: make sure we don't choke on white-on-transparent images
211 dims = yield self.media_repo._generate_local_thumbnails(
212 image_info['filesystem_id'], image_info
213 )
214 if dims:
215 og["og:image:width"] = dims['width']
216 og["og:image:height"] = dims['height']
217 else:
218 logger.warn("Couldn't get dims for %s" % og["og:image"])
219
220 og["og:image"] = "mxc://%s/%s" % (
221 self.server_name, image_info['filesystem_id']
222 )
223 og["og:image:type"] = image_info['media_type']
224 og["matrix:image:size"] = image_info['media_length']
225 else:
226 del og["og:image"]
212227 else:
213228 logger.warn("Failed to find any OG data in %s", url)
214229 og = {}
230245 )
231246
232247 respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
233
234 @defer.inlineCallbacks
235 def _calc_og(self, tree, media_info, requester):
236 # suck our tree into lxml and define our OG response.
237
238 # if we see any image URLs in the OG response, then spider them
239 # (although the client could choose to do this by asking for previews of those
240 # URLs to avoid DoSing the server)
241
242 # "og:type" : "video",
243 # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
244 # "og:site_name" : "YouTube",
245 # "og:video:type" : "application/x-shockwave-flash",
246 # "og:description" : "Fun stuff happening here",
247 # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
248 # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
249 # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
250 # "og:video:width" : "1280"
251 # "og:video:height" : "720",
252 # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
253
254 og = {}
255 for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
256 if 'content' in tag.attrib:
257 og[tag.attrib['property']] = tag.attrib['content']
258
259 # TODO: grab article: meta tags too, e.g.:
260
261 # "article:publisher" : "https://www.facebook.com/thethudonline" />
262 # "article:author" content="https://www.facebook.com/thethudonline" />
263 # "article:tag" content="baby" />
264 # "article:section" content="Breaking News" />
265 # "article:published_time" content="2016-03-31T19:58:24+00:00" />
266 # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
267
268 if 'og:title' not in og:
269 # do some basic spidering of the HTML
270 title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
271 og['og:title'] = title[0].text.strip() if title else None
272
273 if 'og:image' not in og:
274 # TODO: extract a favicon failing all else
275 meta_image = tree.xpath(
276 "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
277 )
278 if meta_image:
279 og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
280 else:
281 # TODO: consider inlined CSS styles as well as width & height attribs
282 images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
283 images = sorted(images, key=lambda i: (
284 -1 * float(i.attrib['width']) * float(i.attrib['height'])
285 ))
286 if not images:
287 images = tree.xpath("//img[@src]")
288 if images:
289 og['og:image'] = images[0].attrib['src']
290
291 # pre-cache the image for posterity
292 # FIXME: it might be cleaner to use the same flow as the main /preview_url
293 # request itself and benefit from the same caching etc. But for now we
294 # just rely on the caching on the master request to speed things up.
295 if 'og:image' in og and og['og:image']:
296 image_info = yield self._download_url(
297 self._rebase_url(og['og:image'], media_info['uri']), requester.user
298 )
299
300 if self._is_media(image_info['media_type']):
301 # TODO: make sure we don't choke on white-on-transparent images
302 dims = yield self.media_repo._generate_local_thumbnails(
303 image_info['filesystem_id'], image_info
304 )
305 if dims:
306 og["og:image:width"] = dims['width']
307 og["og:image:height"] = dims['height']
308 else:
309 logger.warn("Couldn't get dims for %s" % og["og:image"])
310
311 og["og:image"] = "mxc://%s/%s" % (
312 self.server_name, image_info['filesystem_id']
313 )
314 og["og:image:type"] = image_info['media_type']
315 og["matrix:image:size"] = image_info['media_length']
316 else:
317 del og["og:image"]
318
319 if 'og:description' not in og:
320 meta_description = tree.xpath(
321 "//*/meta"
322 "[translate(@name, 'DESCRIPTION', 'description')='description']"
323 "/@content")
324 if meta_description:
325 og['og:description'] = meta_description[0]
326 else:
327 # grab any text nodes which are inside the <body/> tag...
328 # unless they are within an HTML5 semantic markup tag...
329 # <header/>, <nav/>, <aside/>, <footer/>
330 # ...or if they are within a <script/> or <style/> tag.
331 # This is a very very very coarse approximation to a plain text
332 # render of the page.
333
334 # We don't just use XPATH here as that is slow on some machines.
335
336 # We clone `tree` as we modify it.
337 cloned_tree = deepcopy(tree.find("body"))
338
339 TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
340 for el in cloned_tree.iter(TAGS_TO_REMOVE):
341 el.getparent().remove(el)
342
343 # Split all the text nodes into paragraphs (by splitting on new
344 # lines)
345 text_nodes = (
346 re.sub(r'\s+', '\n', el.text).strip()
347 for el in cloned_tree.iter()
348 if el.text and isinstance(el.tag, basestring) # Removes comments
349 )
350 og['og:description'] = summarize_paragraphs(text_nodes)
351
352 # TODO: delete the url downloads to stop diskfilling,
353 # as we only ever cared about its OG
354 defer.returnValue(og)
355
356 def _rebase_url(self, url, base):
357 base = list(urlparse.urlparse(base))
358 url = list(urlparse.urlparse(url))
359 if not url[0]: # fix up schema
360 url[0] = base[0] or "http"
361 if not url[1]: # fix up hostname
362 url[1] = base[1]
363 if not url[2].startswith('/'):
364 url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
365 return urlparse.urlunparse(url)
366248
367249 @defer.inlineCallbacks
368250 def _download_url(self, url, user):
444326 "etag": headers["ETag"][0] if "ETag" in headers else None,
445327 })
446328
447 def _is_media(self, content_type):
448 if content_type.lower().startswith("image/"):
449 return True
450
451 def _is_html(self, content_type):
452 content_type = content_type.lower()
453 if (
454 content_type.startswith("text/html") or
455 content_type.startswith("application/xhtml")
456 ):
457 return True
329
330 def decode_and_calc_og(body, media_uri, request_encoding=None):
331 from lxml import etree
332
333 try:
334 parser = etree.HTMLParser(recover=True, encoding=request_encoding)
335 tree = etree.fromstring(body, parser)
336 og = _calc_og(tree, media_uri)
337 except UnicodeDecodeError:
338 # blindly try decoding the body as utf-8, which seems to fix
339 # the charset mismatches on https://google.com
340 parser = etree.HTMLParser(recover=True, encoding=request_encoding)
341 tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
342 og = _calc_og(tree, media_uri)
343
344 return og
345
346
347 def _calc_og(tree, media_uri):
348 # suck our tree into lxml and define our OG response.
349
350 # if we see any image URLs in the OG response, then spider them
351 # (although the client could choose to do this by asking for previews of those
352 # URLs to avoid DoSing the server)
353
354 # "og:type" : "video",
355 # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
356 # "og:site_name" : "YouTube",
357 # "og:video:type" : "application/x-shockwave-flash",
358 # "og:description" : "Fun stuff happening here",
359 # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
360 # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
361 # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
362 # "og:video:width" : "1280"
363 # "og:video:height" : "720",
364 # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
365
366 og = {}
367 for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
368 if 'content' in tag.attrib:
369 og[tag.attrib['property']] = tag.attrib['content']
370
371 # TODO: grab article: meta tags too, e.g.:
372
373 # "article:publisher" : "https://www.facebook.com/thethudonline" />
374 # "article:author" content="https://www.facebook.com/thethudonline" />
375 # "article:tag" content="baby" />
376 # "article:section" content="Breaking News" />
377 # "article:published_time" content="2016-03-31T19:58:24+00:00" />
378 # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
379
380 if 'og:title' not in og:
381 # do some basic spidering of the HTML
382 title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
383 og['og:title'] = title[0].text.strip() if title else None
384
385 if 'og:image' not in og:
386 # TODO: extract a favicon failing all else
387 meta_image = tree.xpath(
388 "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
389 )
390 if meta_image:
391 og['og:image'] = _rebase_url(meta_image[0], media_uri)
392 else:
393 # TODO: consider inlined CSS styles as well as width & height attribs
394 images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
395 images = sorted(images, key=lambda i: (
396 -1 * float(i.attrib['width']) * float(i.attrib['height'])
397 ))
398 if not images:
399 images = tree.xpath("//img[@src]")
400 if images:
401 og['og:image'] = images[0].attrib['src']
402
403 if 'og:description' not in og:
404 meta_description = tree.xpath(
405 "//*/meta"
406 "[translate(@name, 'DESCRIPTION', 'description')='description']"
407 "/@content")
408 if meta_description:
409 og['og:description'] = meta_description[0]
410 else:
411 # grab any text nodes which are inside the <body/> tag...
412 # unless they are within an HTML5 semantic markup tag...
413 # <header/>, <nav/>, <aside/>, <footer/>
414 # ...or if they are within a <script/> or <style/> tag.
415 # This is a very very very coarse approximation to a plain text
416 # render of the page.
417
418 # We don't just use XPATH here as that is slow on some machines.
419
420 from lxml import etree
421
422 TAGS_TO_REMOVE = (
423 "header", "nav", "aside", "footer", "script", "style", etree.Comment
424 )
425
426 # Split all the text nodes into paragraphs (by splitting on new
427 # lines)
428 text_nodes = (
429 re.sub(r'\s+', '\n', el).strip()
430 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
431 )
432 og['og:description'] = summarize_paragraphs(text_nodes)
433
434 # TODO: delete the url downloads to stop diskfilling,
435 # as we only ever cared about its OG
436 return og
437
438
439 def _iterate_over_text(tree, *tags_to_ignore):
440 """Iterate over the tree returning text nodes in a depth first fashion,
441 skipping text nodes inside certain tags.
442 """
443 # This is basically a stack that we extend using itertools.chain.
444 # This will either consist of an element to iterate over *or* a string
445 # to be returned.
446 elements = iter([tree])
447 while True:
448 el = elements.next()
449 if isinstance(el, basestring):
450 yield el
451 elif el is not None and el.tag not in tags_to_ignore:
452 # el.text is the text before the first child, so we can immediately
453 # return it if the text exists.
454 if el.text:
455 yield el.text
456
457 # We add to the stack all the elements children, interspersed with
458 # each child's tail text (if it exists). The tail text of a node
459 # is text that comes *after* the node, so we always include it even
460 # if we ignore the child node.
461 elements = itertools.chain(
462 itertools.chain.from_iterable( # Basically a flatmap
463 [child, child.tail] if child.tail else [child]
464 for child in el.iterchildren()
465 ),
466 elements
467 )
468
469
470 def _rebase_url(url, base):
471 base = list(urlparse.urlparse(base))
472 url = list(urlparse.urlparse(url))
473 if not url[0]: # fix up scheme
474 url[0] = base[0] or "http"
475 if not url[1]: # fix up hostname
476 url[1] = base[1]
477 if not url[2].startswith('/'):
478 url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
479 return urlparse.urlunparse(url)
480
481
482 def _is_media(content_type):
483 if content_type.lower().startswith("image/"):
484 return True
485
486
487 def _is_html(content_type):
488 content_type = content_type.lower()
489 if (
490 content_type.startswith("text/html") or
491 content_type.startswith("application/xhtml")
492 ):
493 return True
458494
459495
460496 def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
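
Pulling the OG extraction out into module-level helpers (``decode_and_calc_og``,
``_calc_og``, ``_rebase_url`` and friends) makes the preview logic easier to
exercise in isolation, without a running media repository. A minimal sketch of
calling the helper directly on a small document::

    from synapse.rest.media.v1.preview_url_resource import decode_and_calc_og

    html = '''
    <html>
      <head>
        <title>Some page</title>
        <meta property="og:title" content="The Title" />
        <meta property="og:description" content="A description." />
      </head>
      <body>Hello.</body>
    </html>
    '''

    og = decode_and_calc_og(html, "http://example.com/page.html")
    # og == {"og:title": "The Title", "og:description": "A description."}
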
4040 from synapse.handlers.room import RoomListHandler
4141 from synapse.handlers.sync import SyncHandler
4242 from synapse.handlers.typing import TypingHandler
43 from synapse.handlers.events import EventHandler, EventStreamHandler
4344 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
4445 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
4546 from synapse.notifier import Notifier
9394 'auth_handler',
9495 'device_handler',
9596 'e2e_keys_handler',
97 'event_handler',
98 'event_stream_handler',
9699 'application_service_api',
97100 'application_service_scheduler',
98101 'application_service_handler',
213216 def build_application_service_handler(self):
214217 return ApplicationServicesHandler(self)
215218
219 def build_event_handler(self):
220 return EventHandler(self)
221
222 def build_event_stream_handler(self):
223 return EventStreamHandler(self)
224
216225 def build_event_sources(self):
217226 return EventSources(self)
218227
0 import synapse.api.auth
01 import synapse.handlers
12 import synapse.handlers.auth
23 import synapse.handlers.device
56 import synapse.state
67
78 class HomeServer(object):
9 def get_auth(self) -> synapse.api.auth.Auth:
10 pass
11
812 def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
913 pass
1014
4949 from .client_ips import ClientIpStore
5050
5151 from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
52 from .engines import PostgresEngine
5253
5354 from synapse.api.constants import PresenceState
5455 from synapse.util.caches.stream_change_cache import StreamChangeCache
122123 extra_tables=[("deleted_pushers", "stream_id")],
123124 )
124125
126 if isinstance(self.database_engine, PostgresEngine):
127 self._cache_id_gen = StreamIdGenerator(
128 db_conn, "cache_invalidation_stream", "stream_id",
129 )
130 else:
131 self._cache_id_gen = None
132
125133 events_max = self._stream_id_gen.get_current_token()
126134 event_cache_prefill, min_event_val = self._get_cache_dict(
127135 db_conn, "events",
1818 from synapse.util.caches.dictionary_cache import DictionaryCache
1919 from synapse.util.caches.descriptors import Cache
2020 from synapse.util.caches import intern_dict
21 from synapse.storage.engines import PostgresEngine
2122 import synapse.metrics
2223
2324
164165 self._txn_perf_counters = PerformanceCounters()
165166 self._get_event_counters = PerformanceCounters()
166167
167 self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
168 self._get_event_cache = Cache("*getEvent*", keylen=3,
168169 max_entries=hs.config.event_cache_size)
169170
170171 self._state_group_cache = DictionaryCache(
304305 func, *args, **kwargs
305306 )
306307
307 with PreserveLoggingContext():
308 result = yield self._db_pool.runWithConnection(
309 inner_func, *args, **kwargs
310 )
311
312 for after_callback, after_args in after_callbacks:
313 after_callback(*after_args)
308 try:
309 with PreserveLoggingContext():
310 result = yield self._db_pool.runWithConnection(
311 inner_func, *args, **kwargs
312 )
313 finally:
314 for after_callback, after_args in after_callbacks:
315 after_callback(*after_args)
314316 defer.returnValue(result)
315317
316318 @defer.inlineCallbacks
859861
860862 return cache, min_val
861863
864 def _invalidate_cache_and_stream(self, txn, cache_func, keys):
865 """Invalidates the cache and adds it to the cache stream so slaves
866 will know to invalidate their caches.
867
868 This should only be used to invalidate caches where slaves won't
869 otherwise know from other replication streams that the cache should
870 be invalidated.
871 """
872 txn.call_after(cache_func.invalidate, keys)
873
874 if isinstance(self.database_engine, PostgresEngine):
875 # get_next() returns a context manager which is designed to wrap
876 # the transaction. However, we want to only get an ID when we want
877 # to use it, here, so we need to call __enter__ manually, and have
878 # __exit__ called after the transaction finishes.
879 ctx = self._cache_id_gen.get_next()
880 stream_id = ctx.__enter__()
881 txn.call_after(ctx.__exit__, None, None, None)
882 txn.call_after(self.hs.get_notifier().on_new_replication_data)
883
884 self._simple_insert_txn(
885 txn,
886 table="cache_invalidation_stream",
887 values={
888 "stream_id": stream_id,
889 "cache_func": cache_func.__name__,
890 "keys": list(keys),
891 "invalidation_ts": self.clock.time_msec(),
892 }
893 )
894
895 def get_all_updated_caches(self, last_id, current_id, limit):
896 if last_id == current_id:
897 return defer.succeed([])
898
899 def get_all_updated_caches_txn(txn):
900 # We purposefully don't bound by the current token, as we want to
901 # send across cache invalidations as quickly as possible. Cache
902 # invalidations are idempotent, so duplicates are fine.
903 sql = (
904 "SELECT stream_id, cache_func, keys, invalidation_ts"
905 " FROM cache_invalidation_stream"
906 " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
907 )
908 txn.execute(sql, (last_id, limit,))
909 return txn.fetchall()
910 return self.runInteraction(
911 "get_all_updated_caches", get_all_updated_caches_txn
912 )
913
914 def get_cache_stream_token(self):
915 if self._cache_id_gen:
916 return self._cache_id_gen.get_current_token()
917 else:
918 return 0
919
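
``_invalidate_cache_and_stream`` is the hook stores use when a write must also
evict a cache on the workers: it invalidates the local cache once the
transaction commits and appends a row to ``cache_invalidation_stream`` so the
``caches`` replication stream picks it up. A minimal sketch of the calling
pattern, assuming a hypothetical cached accessor ``get_thing`` and a
hypothetical ``things`` table (the directory store change below shows the real
thing for ``get_aliases_for_room``)::

    def set_thing(self, key, value):
        def set_thing_txn(txn):
            self._simple_upsert_txn(
                txn,
                table="things",              # hypothetical table
                keyvalues={"key": key},
                values={"value": value},
            )
            # Evict our copy once the txn commits, and tell the workers to
            # evict theirs via the caches stream.
            self._invalidate_cache_and_stream(txn, self.get_thing, (key,))
        return self.runInteraction("set_thing", set_thing_txn)
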
862920
863921 class _RollbackButIsFineException(Exception):
864922 """ This exception is used to rollback a transaction without implying
217217 Returns:
218218 AppServiceTransaction: A new transaction.
219219 """
220 def _create_appservice_txn(txn):
221 # work out new txn id (highest txn id for this service += 1)
222 # The highest id may be the last one sent (in which case it is last_txn)
223 # or it may be the highest in the txns list (which are waiting to be/are
224 # being sent)
225 last_txn_id = self._get_last_txn(txn, service.id)
226
227 txn.execute(
228 "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
229 (service.id,)
230 )
231 highest_txn_id = txn.fetchone()[0]
232 if highest_txn_id is None:
233 highest_txn_id = 0
234
235 new_txn_id = max(highest_txn_id, last_txn_id) + 1
236
237 # Insert new txn into txn table
238 event_ids = json.dumps([e.event_id for e in events])
239 txn.execute(
240 "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
241 "VALUES(?,?,?)",
242 (service.id, new_txn_id, event_ids)
243 )
244 return AppServiceTransaction(
245 service=service, id=new_txn_id, events=events
246 )
247
220248 return self.runInteraction(
221249 "create_appservice_txn",
222 self._create_appservice_txn,
223 service, events
224 )
225
226 def _create_appservice_txn(self, txn, service, events):
227 # work out new txn id (highest txn id for this service += 1)
228 # The highest id may be the last one sent (in which case it is last_txn)
229 # or it may be the highest in the txns list (which are waiting to be/are
230 # being sent)
231 last_txn_id = self._get_last_txn(txn, service.id)
232
233 txn.execute(
234 "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
235 (service.id,)
236 )
237 highest_txn_id = txn.fetchone()[0]
238 if highest_txn_id is None:
239 highest_txn_id = 0
240
241 new_txn_id = max(highest_txn_id, last_txn_id) + 1
242
243 # Insert new txn into txn table
244 event_ids = json.dumps([e.event_id for e in events])
245 txn.execute(
246 "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
247 "VALUES(?,?,?)",
248 (service.id, new_txn_id, event_ids)
249 )
250 return AppServiceTransaction(
251 service=service, id=new_txn_id, events=events
250 _create_appservice_txn,
252251 )
253252
254253 def complete_appservice_txn(self, txn_id, service):
262261 A Deferred which resolves if this transaction was stored
263262 successfully.
264263 """
264 txn_id = int(txn_id)
265
266 def _complete_appservice_txn(txn):
267 # Debugging query: Make sure the txn being completed is EXACTLY +1 from
268 # what was there before. If it isn't, we've got problems (e.g. the AS
269 # has probably missed some events), so whine loudly but still continue,
270 # since it shouldn't fail completion of the transaction.
271 last_txn_id = self._get_last_txn(txn, service.id)
272 if (last_txn_id + 1) != txn_id:
273 logger.error(
274 "appservice: Completing a transaction which has an ID > 1 from "
275 "the last ID sent to this AS. We've either dropped events or "
276 "sent it to the AS out of order. FIX ME. last_txn=%s "
277 "completing_txn=%s service_id=%s", last_txn_id, txn_id,
278 service.id
279 )
280
281 # Set current txn_id for AS to 'txn_id'
282 self._simple_upsert_txn(
283 txn, "application_services_state", dict(as_id=service.id),
284 dict(last_txn=txn_id)
285 )
286
287 # Delete txn
288 self._simple_delete_txn(
289 txn, "application_services_txns",
290 dict(txn_id=txn_id, as_id=service.id)
291 )
292
265293 return self.runInteraction(
266294 "complete_appservice_txn",
267 self._complete_appservice_txn,
268 txn_id, service
269 )
270
271 def _complete_appservice_txn(self, txn, txn_id, service):
272 txn_id = int(txn_id)
273
274 # Debugging query: Make sure the txn being completed is EXACTLY +1 from
275 # what was there before. If it isn't, we've got problems (e.g. the AS
276 # has probably missed some events), so whine loudly but still continue,
277 # since it shouldn't fail completion of the transaction.
278 last_txn_id = self._get_last_txn(txn, service.id)
279 if (last_txn_id + 1) != txn_id:
280 logger.error(
281 "appservice: Completing a transaction which has an ID > 1 from "
282 "the last ID sent to this AS. We've either dropped events or "
283 "sent it to the AS out of order. FIX ME. last_txn=%s "
284 "completing_txn=%s service_id=%s", last_txn_id, txn_id,
285 service.id
286 )
287
288 # Set current txn_id for AS to 'txn_id'
289 self._simple_upsert_txn(
290 txn, "application_services_state", dict(as_id=service.id),
291 dict(last_txn=txn_id)
292 )
293
294 # Delete txn
295 self._simple_delete_txn(
296 txn, "application_services_txns",
297 dict(txn_id=txn_id, as_id=service.id)
295 _complete_appservice_txn,
298296 )
299297
300298 @defer.inlineCallbacks
308306 A Deferred which resolves to an AppServiceTransaction or
309307 None.
310308 """
309 def _get_oldest_unsent_txn(txn):
310 # Monotonically increasing txn ids, so just select the smallest
311 # one in the txns table (we delete them when they are sent)
312 txn.execute(
313 "SELECT * FROM application_services_txns WHERE as_id=?"
314 " ORDER BY txn_id ASC LIMIT 1",
315 (service.id,)
316 )
317 rows = self.cursor_to_dict(txn)
318 if not rows:
319 return None
320
321 entry = rows[0]
322
323 return entry
324
311325 entry = yield self.runInteraction(
312326 "get_oldest_unsent_appservice_txn",
313 self._get_oldest_unsent_txn,
314 service
327 _get_oldest_unsent_txn,
315328 )
316329
317330 if not entry:
324337 defer.returnValue(AppServiceTransaction(
325338 service=service, id=entry["txn_id"], events=events
326339 ))
327
328 def _get_oldest_unsent_txn(self, txn, service):
329 # Monotonically increasing txn ids, so just select the smallest
330 # one in the txns table (we delete them when they are sent)
331 txn.execute(
332 "SELECT * FROM application_services_txns WHERE as_id=?"
333 " ORDER BY txn_id ASC LIMIT 1",
334 (service.id,)
335 )
336 rows = self.cursor_to_dict(txn)
337 if not rows:
338 return None
339
340 entry = rows[0]
341
342 return entry
343340
344341 def _get_last_txn(self, txn, service_id):
345342 txn.execute(
351348 return 0
352349 else:
353350 return int(last_txn_id[0]) # select 'last_txn' col
351
352 def set_appservice_last_pos(self, pos):
353 def set_appservice_last_pos_txn(txn):
354 txn.execute(
355 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
356 )
357 return self.runInteraction(
358 "set_appservice_last_pos", set_appservice_last_pos_txn
359 )
360
361 @defer.inlineCallbacks
362 def get_new_events_for_appservice(self, current_id, limit):
363 """Get all new evnets"""
364
365 def get_new_events_for_appservice_txn(txn):
366 sql = (
367 "SELECT e.stream_ordering, e.event_id"
368 " FROM events AS e"
369 " WHERE"
370 " (SELECT stream_ordering FROM appservice_stream_position)"
371 " < e.stream_ordering"
372 " AND e.stream_ordering <= ?"
373 " ORDER BY e.stream_ordering ASC"
374 " LIMIT ?"
375 )
376
377 txn.execute(sql, (current_id, limit))
378 rows = txn.fetchall()
379
380 upper_bound = current_id
381 if len(rows) == limit:
382 upper_bound = rows[-1][0]
383
384 return upper_bound, [row[1] for row in rows]
385
386 upper_bound, event_ids = yield self.runInteraction(
387 "get_new_events_for_appservice", get_new_events_for_appservice_txn,
388 )
389
390 events = yield self._get_events(event_ids)
391
392 defer.returnValue((upper_bound, events))
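Taken together, ``get_new_events_for_appservice`` and ``set_appservice_last_pos`` form a simple poll-and-checkpoint loop over the new ``appservice_stream_position`` table. A minimal sketch of a caller, assuming a hypothetical ``handle_events`` callback (the real logic lives in the appservice handler)::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def drain_appservice_stream(store, handle_events, current_id, limit=100):
        # Fetch events strictly above the stored stream position, up to
        # `current_id`, in batches of at most `limit`.
        upper_bound, events = yield store.get_new_events_for_appservice(
            current_id, limit
        )
        if events:
            yield handle_events(events)
        # Checkpoint so a restart resumes from here rather than replaying.
        yield store.set_appservice_last_pos(upper_bound)
        defer.returnValue(upper_bound)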
8181 Returns:
8282 Deferred
8383 """
84 try:
85 yield self._simple_insert(
84 def alias_txn(txn):
85 self._simple_insert_txn(
86 txn,
8687 "room_aliases",
8788 {
8889 "room_alias": room_alias.to_string(),
8990 "room_id": room_id,
9091 "creator": creator,
9192 },
92 desc="create_room_alias_association",
93 )
94
95 self._simple_insert_many_txn(
96 txn,
97 table="room_alias_servers",
98 values=[{
99 "room_alias": room_alias.to_string(),
100 "server": server,
101 } for server in servers],
102 )
103
104 self._invalidate_cache_and_stream(
105 txn, self.get_aliases_for_room, (room_id,)
106 )
107
108 try:
109 ret = yield self.runInteraction(
110 "create_room_alias_association", alias_txn
93111 )
94112 except self.database_engine.module.IntegrityError:
95113 raise SynapseError(
96114 409, "Room alias %s already exists" % room_alias.to_string()
97115 )
98
99 for server in servers:
100 # TODO(erikj): Fix this to bulk insert
101 yield self._simple_insert(
102 "room_alias_servers",
103 {
104 "room_alias": room_alias.to_string(),
105 "server": server,
106 },
107 desc="create_room_alias_association",
108 )
109 self.get_aliases_for_room.invalidate((room_id,))
116 defer.returnValue(ret)
110117
111118 def get_room_alias_creator(self, room_alias):
112119 return self._simple_select_one_onecol(
5555 )
5656 self._simple_insert_many_txn(txn, "event_push_actions", values)
5757
58 @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
58 @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
5959 def get_unread_event_push_actions_by_room_for_user(
6060 self, room_id, user_id, last_read_event_id
6161 ):
337337 defer.returnValue(notifs[:limit])
338338
339339 @defer.inlineCallbacks
340 def get_push_actions_for_user(self, user_id, before=None, limit=50):
341 def f(txn):
342 before_clause = ""
343 if before:
344 before_clause = "AND stream_ordering < ?"
345 args = [user_id, before, limit]
346 else:
347 args = [user_id, limit]
348 sql = (
349 "SELECT epa.event_id, epa.room_id,"
350 " epa.stream_ordering, epa.topological_ordering,"
351 " epa.actions, epa.profile_tag, e.received_ts"
352 " FROM event_push_actions epa, events e"
353 " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
354 " AND epa.user_id = ? %s"
355 " ORDER BY epa.stream_ordering DESC"
356 " LIMIT ?"
357 % (before_clause,)
358 )
359 txn.execute(sql, args)
360 return self.cursor_to_dict(txn)
361
362 push_actions = yield self.runInteraction(
363 "get_push_actions_for_user", f
364 )
365 for pa in push_actions:
366 pa["actions"] = json.loads(pa["actions"])
367 defer.returnValue(push_actions)
368
369 @defer.inlineCallbacks
340370 def get_time_of_last_push_action_before(self, stream_ordering):
341371 def f(txn):
342372 sql = (
1919 from synapse.events.utils import prune_event
2020
2121 from synapse.util.async import ObservableDeferred
22 from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
22 from synapse.util.logcontext import (
23 preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
24 )
2325 from synapse.util.logutils import log_function
26 from synapse.util.metrics import Measure
2427 from synapse.api.constants import EventTypes
2528 from synapse.api.errors import SynapseError
2629
200203
201204 deferreds = []
202205 for room_id, evs_ctxs in partitioned.items():
203 d = self._event_persist_queue.add_to_queue(
206 d = preserve_fn(self._event_persist_queue.add_to_queue)(
204207 room_id, evs_ctxs,
205208 backfilled=backfilled,
206209 current_state=None,
210213 for room_id in partitioned.keys():
211214 self._maybe_start_persisting(room_id)
212215
213 return defer.gatherResults(deferreds, consumeErrors=True)
216 return preserve_context_over_deferred(
217 defer.gatherResults(deferreds, consumeErrors=True)
218 )
214219
215220 @defer.inlineCallbacks
216221 @log_function
223228
224229 self._maybe_start_persisting(event.room_id)
225230
226 yield deferred
231 yield preserve_context_over_deferred(deferred)
227232
228233 max_persisted_id = yield self._stream_id_gen.get_current_token()
229234 defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
599604 "rejections",
600605 "redactions",
601606 "room_memberships",
602 "state_events"
607 "state_events",
608 "topics"
603609 ):
604610 txn.executemany(
605611 "DELETE FROM %s WHERE event_id = ?" % (table,),
10851091 if not allow_rejected:
10861092 rows[:] = [r for r in rows if not r["rejects"]]
10871093
1088 res = yield defer.gatherResults(
1094 res = yield preserve_context_over_deferred(defer.gatherResults(
10891095 [
10901096 preserve_fn(self._get_event_from_row)(
10911097 row["internal_metadata"], row["json"], row["redacts"],
10941100 for row in rows
10951101 ],
10961102 consumeErrors=True
1097 )
1103 ))
10981104
10991105 defer.returnValue({
11001106 e.event.event_id: e
11301136 @defer.inlineCallbacks
11311137 def _get_event_from_row(self, internal_metadata, js, redacted,
11321138 rejected_reason=None):
1133 d = json.loads(js)
1134 internal_metadata = json.loads(internal_metadata)
1135
1136 if rejected_reason:
1137 rejected_reason = yield self._simple_select_one_onecol(
1138 table="rejections",
1139 keyvalues={"event_id": rejected_reason},
1140 retcol="reason",
1141 desc="_get_event_from_row_rejected_reason",
1142 )
1143
1144 original_ev = FrozenEvent(
1145 d,
1146 internal_metadata_dict=internal_metadata,
1147 rejected_reason=rejected_reason,
1148 )
1149
1150 redacted_event = None
1151 if redacted:
1152 redacted_event = prune_event(original_ev)
1153
1154 redaction_id = yield self._simple_select_one_onecol(
1155 table="redactions",
1156 keyvalues={"redacts": redacted_event.event_id},
1157 retcol="event_id",
1158 desc="_get_event_from_row_redactions",
1159 )
1160
1161 redacted_event.unsigned["redacted_by"] = redaction_id
1162 # Get the redaction event.
1163
1164 because = yield self.get_event(
1165 redaction_id,
1166 check_redacted=False,
1167 allow_none=True,
1168 )
1169
1170 if because:
1171 # It's fine to do add the event directly, since get_pdu_json
1172 # will serialise this field correctly
1173 redacted_event.unsigned["redacted_because"] = because
1174
1175 cache_entry = _EventCacheEntry(
1176 event=original_ev,
1177 redacted_event=redacted_event,
1178 )
1179
1180 self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
1139 with Measure(self._clock, "_get_event_from_row"):
1140 d = json.loads(js)
1141 internal_metadata = json.loads(internal_metadata)
1142
1143 if rejected_reason:
1144 rejected_reason = yield self._simple_select_one_onecol(
1145 table="rejections",
1146 keyvalues={"event_id": rejected_reason},
1147 retcol="reason",
1148 desc="_get_event_from_row_rejected_reason",
1149 )
1150
1151 original_ev = FrozenEvent(
1152 d,
1153 internal_metadata_dict=internal_metadata,
1154 rejected_reason=rejected_reason,
1155 )
1156
1157 redacted_event = None
1158 if redacted:
1159 redacted_event = prune_event(original_ev)
1160
1161 redaction_id = yield self._simple_select_one_onecol(
1162 table="redactions",
1163 keyvalues={"redacts": redacted_event.event_id},
1164 retcol="event_id",
1165 desc="_get_event_from_row_redactions",
1166 )
1167
1168 redacted_event.unsigned["redacted_by"] = redaction_id
1169 # Get the redaction event.
1170
1171 because = yield self.get_event(
1172 redaction_id,
1173 check_redacted=False,
1174 allow_none=True,
1175 )
1176
1177 if because:
1178 # It's fine to do add the event directly, since get_pdu_json
1179 # will serialise this field correctly
1180 redacted_event.unsigned["redacted_because"] = because
1181
1182 cache_entry = _EventCacheEntry(
1183 event=original_ev,
1184 redacted_event=redacted_event,
1185 )
1186
1187 self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
11811188
11821189 defer.returnValue(cache_entry)
11831190
2424
2525 # Remember to update this number every time a change is made to database
2626 # schema files, so the users will be informed on server restarts.
27 SCHEMA_VERSION = 33
27 SCHEMA_VERSION = 34
2828
2929 dir_path = os.path.abspath(os.path.dirname(__file__))
3030
188188 desc="add_presence_list_pending",
189189 )
190190
191 @defer.inlineCallbacks
192191 def set_presence_list_accepted(self, observer_localpart, observed_userid):
193 result = yield self._simple_update_one(
194 table="presence_list",
195 keyvalues={"user_id": observer_localpart,
196 "observed_user_id": observed_userid},
197 updatevalues={"accepted": True},
198 desc="set_presence_list_accepted",
199 )
200 self.get_presence_list_accepted.invalidate((observer_localpart,))
201 self.get_presence_list_observers_accepted.invalidate((observed_userid,))
202 defer.returnValue(result)
192 def update_presence_list_txn(txn):
193 result = self._simple_update_one_txn(
194 txn,
195 table="presence_list",
196 keyvalues={
197 "user_id": observer_localpart,
198 "observed_user_id": observed_userid
199 },
200 updatevalues={"accepted": True},
201 )
202
203 self._invalidate_cache_and_stream(
204 txn, self.get_presence_list_accepted, (observer_localpart,)
205 )
206 self._invalidate_cache_and_stream(
207 txn, self.get_presence_list_observers_accepted, (observed_userid,)
208 )
209
210 return result
211
212 return self.runInteraction(
213 "set_presence_list_accepted", update_presence_list_txn,
214 )
203215
204216 def get_presence_list(self, observer_localpart, accepted=None):
205217 if accepted:
1515 from ._base import SQLBaseStore
1616 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
1717 from synapse.push.baserules import list_with_base_rules
18 from synapse.api.constants import EventTypes, Membership
1819 from twisted.internet import defer
1920
2021 import logging
4748
4849
4950 class PushRuleStore(SQLBaseStore):
50 @cachedInlineCallbacks(lru=True)
51 @cachedInlineCallbacks()
5152 def get_push_rules_for_user(self, user_id):
5253 rows = yield self._simple_select_list(
5354 table="push_rules",
7172
7273 defer.returnValue(rules)
7374
74 @cachedInlineCallbacks(lru=True)
75 @cachedInlineCallbacks()
7576 def get_push_rules_enabled_for_user(self, user_id):
7677 results = yield self._simple_select_list(
7778 table="push_rules_enable",
121122 )
122123
123124 defer.returnValue(results)
125
126 def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
127 if not state_group:
128 # If state_group is None it means it has yet to be assigned a
129 # state group, i.e. we need to make sure that calls with a state_group
130 # of None don't hit previous cached calls with a None state_group.
131 # To do this we set the state_group to a new object as object() != object()
132 state_group = object()
133
134 return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
135
136 @cachedInlineCallbacks(num_args=2, cache_context=True)
137 def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
138 cache_context):
139 # We don't use `state_group`, it's there so that we can cache based
140 # on it. However, it's important that it's never None, since two
141 # current_states with a state_group of None are likely to be different.
142 # See bulk_get_push_rules_for_room for how we work around this.
143 assert state_group is not None
144
145 # We will also want to generate notifs for other people in the room so
146 # their unread counts are correct in the event stream, but to avoid
147 # generating them for bot / AS users etc, we only do so for people who've
148 # sent a read receipt into the room.
149 local_users_in_room = set(
150 e.state_key for e in current_state.values()
151 if e.type == EventTypes.Member and e.membership == Membership.JOIN
152 and self.hs.is_mine_id(e.state_key)
153 )
154
155 # users in the room who have pushers need to get push rules run because
156 # that's how their pushers work
157 if_users_with_pushers = yield self.get_if_users_have_pushers(
158 local_users_in_room, on_invalidate=cache_context.invalidate,
159 )
160 user_ids = set(
161 uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
162 )
163
164 users_with_receipts = yield self.get_users_with_read_receipts_in_room(
165 room_id, on_invalidate=cache_context.invalidate,
166 )
167
168 # also run rules for local users in the room who have sent read receipts
169 for uid in users_with_receipts:
170 if uid in local_users_in_room:
171 user_ids.add(uid)
172
173 rules_by_user = yield self.bulk_get_push_rules(
174 user_ids, on_invalidate=cache_context.invalidate,
175 )
176
177 rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
178
179 defer.returnValue(rules_by_user)
124180
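The ``object()`` sentinel above is worth spelling out: every ``object()`` compares unequal to every other, so two calls made before a state group has been assigned can never share a cache entry, while real state group IDs still cache normally. A contrived sketch of the effect (not Synapse code)::

    cache = {}

    def rules_for(state_group, compute):
        # None must not be used as a key: two unassigned states would collide.
        # A fresh object() is unique, so they never do.
        if state_group is None:
            state_group = object()
        if state_group not in cache:
            cache[state_group] = compute()
        return cache[state_group]

    a = rules_for(None, lambda: {"rules": ["a"]})
    b = rules_for(None, lambda: {"rules": ["b"]})
    assert a != b            # no false cache hit for unassigned state groups
    c = rules_for(42, lambda: {"rules": ["c"]})
    assert rules_for(42, lambda: {"rules": ["d"]}) == c   # real IDs still cache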
125181 @cachedList(cached_method_name="get_push_rules_enabled_for_user",
126182 list_name="user_ids", num_args=1, inlineCallbacks=True)
134134 "get_all_updated_pushers", get_all_updated_pushers_txn
135135 )
136136
137 @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
137 @cachedInlineCallbacks(num_args=1, max_entries=15000)
138138 def get_if_user_has_pusher(self, user_id):
139139 result = yield self._simple_select_many_batch(
140140 table='pushers',
9494 defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
9595
9696 @defer.inlineCallbacks
97 def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
98 def f(txn):
99 sql = (
100 "SELECT rl.room_id, rl.event_id,"
101 " e.topological_ordering, e.stream_ordering"
102 " FROM receipts_linearized AS rl"
103 " INNER JOIN events AS e USING (room_id, event_id)"
104 " WHERE rl.room_id = e.room_id"
105 " AND rl.event_id = e.event_id"
106 " AND user_id = ?"
107 )
108 txn.execute(sql, (user_id,))
109 return txn.fetchall()
110 rows = yield self.runInteraction(
111 "get_receipts_for_user_with_orderings", f
112 )
113 defer.returnValue({
114 row[0]: {
115 "event_id": row[1],
116 "topological_ordering": row[2],
117 "stream_ordering": row[3],
118 } for row in rows
119 })
120
121 @defer.inlineCallbacks
97122 def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
98123 """Get receipts for multiple rooms for sending to clients.
99124
119144
120145 defer.returnValue([ev for res in results.values() for ev in res])
121146
122 @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
147 @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
123148 def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
124149 """Get receipts for a single room for sending to clients.
125150
9292 desc="add_refresh_token_to_user",
9393 )
9494
95 @defer.inlineCallbacks
9695 def register(self, user_id, token=None, password_hash=None,
9796 was_guest=False, make_guest=False, appservice_id=None,
9897 create_profile_with_localpart=None, admin=False):
114113 Raises:
115114 StoreError if the user_id could not be registered.
116115 """
117 yield self.runInteraction(
116 return self.runInteraction(
118117 "register",
119118 self._register,
120119 user_id,
126125 create_profile_with_localpart,
127126 admin
128127 )
129 self.get_user_by_id.invalidate((user_id,))
130 self.is_guest.invalidate((user_id,))
131128
132129 def _register(
133130 self,
209206 (create_profile_with_localpart,)
210207 )
211208
209 self._invalidate_cache_and_stream(
210 txn, self.get_user_by_id, (user_id,)
211 )
212 txn.call_after(self.is_guest.invalidate, (user_id,))
213
212214 @cached()
213215 def get_user_by_id(self, user_id):
214216 return self._simple_select_one(
235237
236238 return self.runInteraction("get_users_by_id_case_insensitive", f)
237239
238 @defer.inlineCallbacks
239240 def user_set_password_hash(self, user_id, password_hash):
240241 """
241242 NB. This does *not* evict any cache because the one use for this
242243 removes most of the entries subsequently anyway so it would be
243244 pointless. Use flush_user separately.
244245 """
245 yield self._simple_update_one('users', {
246 'name': user_id
247 }, {
248 'password_hash': password_hash
249 })
250 self.get_user_by_id.invalidate((user_id,))
251
252 @defer.inlineCallbacks
253 def user_delete_access_tokens(self, user_id, except_token_ids=[],
246 def user_set_password_hash_txn(txn):
247 self._simple_update_one_txn(
248 txn,
249 'users', {
250 'name': user_id
251 },
252 {
253 'password_hash': password_hash
254 }
255 )
256 self._invalidate_cache_and_stream(
257 txn, self.get_user_by_id, (user_id,)
258 )
259 return self.runInteraction(
260 "user_set_password_hash", user_set_password_hash_txn
261 )
262
263 @defer.inlineCallbacks
264 def user_delete_access_tokens(self, user_id, except_token_id=None,
254265 device_id=None,
255266 delete_refresh_tokens=False):
256267 """
258269
259270 Args:
260271 user_id (str): ID of user the tokens belong to
261 except_token_ids (list[str]): list of access_tokens which should
272 except_token_id (str): ID of an access token which should
262273 *not* be deleted
263274 device_id (str|None): ID of device the tokens are associated with.
264275 If None, tokens associated with any device (or no device) will
268279 Returns:
269280 defer.Deferred:
270281 """
271 def f(txn, table, except_tokens, call_after_delete):
272 sql = "SELECT token FROM %s WHERE user_id = ?" % table
273 clauses = [user_id]
274
282 def f(txn):
283 keyvalues = {
284 "user_id": user_id,
285 }
275286 if device_id is not None:
276 sql += " AND device_id = ?"
277 clauses.append(device_id)
278
279 if except_tokens:
280 sql += " AND id NOT IN (%s)" % (
281 ",".join(["?" for _ in except_tokens]),
287 keyvalues["device_id"] = device_id
288
289 if delete_refresh_tokens:
290 self._simple_delete_txn(
291 txn,
292 table="refresh_tokens",
293 keyvalues=keyvalues,
282294 )
283 clauses += except_tokens
284
285 txn.execute(sql, clauses)
286
287 rows = txn.fetchall()
288
289 n = 100
290 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
291 for chunk in chunks:
292 if call_after_delete:
293 for row in chunk:
294 txn.call_after(call_after_delete, (row[0],))
295
296 txn.execute(
297 "DELETE FROM %s WHERE token in (%s)" % (
298 table,
299 ",".join(["?" for _ in chunk]),
300 ), [r[0] for r in chunk]
295
296 items = keyvalues.items()
297 where_clause = " AND ".join(k + " = ?" for k, _ in items)
298 values = [v for _, v in items]
299 if except_token_id:
300 where_clause += " AND id != ?"
301 values.append(except_token_id)
302
303 txn.execute(
304 "SELECT token FROM access_tokens WHERE %s" % where_clause,
305 values
306 )
307 rows = self.cursor_to_dict(txn)
308
309 for row in rows:
310 self._invalidate_cache_and_stream(
311 txn, self.get_user_by_access_token, (row["token"],)
301312 )
302313
303 # delete refresh tokens first, to stop new access tokens being
304 # allocated while our backs are turned
305 if delete_refresh_tokens:
306 yield self.runInteraction(
307 "user_delete_access_tokens", f,
308 table="refresh_tokens",
309 except_tokens=[],
310 call_after_delete=None,
314 txn.execute(
315 "DELETE FROM access_tokens WHERE %s" % where_clause,
316 values
311317 )
312318
313319 yield self.runInteraction(
314320 "user_delete_access_tokens", f,
315 table="access_tokens",
316 except_tokens=except_token_ids,
317 call_after_delete=self.get_user_by_access_token.invalidate,
318321 )
319322
320323 def delete_access_token(self, access_token):
327330 },
328331 )
329332
330 txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
333 self._invalidate_cache_and_stream(
334 txn, self.get_user_by_access_token, (access_token,)
335 )
331336
332337 return self.runInteraction("delete_access_token", f)
333338
276276 user_id, membership_list=[Membership.JOIN],
277277 )
278278
279 @defer.inlineCallbacks
280279 def forget(self, user_id, room_id):
281280 """Indicate that user_id wishes to discard history for room_id."""
282281 def f(txn):
291290 " room_id = ?"
292291 )
293292 txn.execute(sql, (user_id, room_id))
294 yield self.runInteraction("forget_membership", f)
295 self.was_forgotten_at.invalidate_all()
296 self.who_forgot_in_room.invalidate_all()
297 self.did_forget.invalidate((user_id, room_id))
293
294 txn.call_after(self.was_forgotten_at.invalidate_all)
295 txn.call_after(self.did_forget.invalidate, (user_id, room_id))
296 self._invalidate_cache_and_stream(
297 txn, self.who_forgot_in_room, (room_id,)
298 )
299 return self.runInteraction("forget_membership", f)
298300
299301 @cachedInlineCallbacks(num_args=2)
300302 def did_forget(self, user_id, room_id):
0 /* Copyright 2016 OpenMarket Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 CREATE TABLE IF NOT EXISTS appservice_stream_position(
16 Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
17 stream_ordering BIGINT,
18 CHECK (Lock='X')
19 );
20
21 INSERT INTO appservice_stream_position (stream_ordering)
22 SELECT COALESCE(MAX(stream_ordering), 0) FROM events;
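The ``Lock CHAR(1) ... CHECK (Lock='X')`` idiom guarantees this table can never hold more than one row, which is why ``set_appservice_last_pos`` above can run an ``UPDATE`` with no ``WHERE`` clause. A standalone demonstration of the constraint, using sqlite3 purely for illustration::

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE appservice_stream_position("
        " Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,"
        " stream_ordering BIGINT,"
        " CHECK (Lock='X'))"
    )
    conn.execute(
        "INSERT INTO appservice_stream_position (stream_ordering) VALUES (0)"
    )
    try:
        # A second row must fail: Lock can only ever be 'X', and 'X' is UNIQUE.
        conn.execute(
            "INSERT INTO appservice_stream_position (stream_ordering) VALUES (1)"
        )
    except sqlite3.IntegrityError:
        pass
    conn.execute("UPDATE appservice_stream_position SET stream_ordering = 42")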
0 # Copyright 2016 OpenMarket Ltd
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from synapse.storage.prepare_database import get_statements
15 from synapse.storage.engines import PostgresEngine
16
17 import logging
18
19 logger = logging.getLogger(__name__)
20
21
22 # This stream is used to notify replication slaves that some caches have
23 # been invalidated in ways they cannot infer from the other streams.
24 CREATE_TABLE = """
25 CREATE TABLE cache_invalidation_stream (
26 stream_id BIGINT,
27 cache_func TEXT,
28 keys TEXT[],
29 invalidation_ts BIGINT
30 );
31
32 CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
33 """
34
35
36 def run_create(cur, database_engine, *args, **kwargs):
37 if not isinstance(database_engine, PostgresEngine):
38 return
39
40 for statement in get_statements(CREATE_TABLE.splitlines()):
41 cur.execute(statement)
42
43
44 def run_upgrade(cur, database_engine, *args, **kwargs):
45 pass
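Given the table created above, a replica that remembers the highest ``stream_id`` it has processed can ask only for newer invalidations. A hypothetical query in the style of the surrounding storage code (the real transport between master and workers is the HTTP replication API, not direct table reads)::

    def get_cache_invalidations_txn(txn, last_seen_id, limit=100):
        # Rows above our last-seen position, oldest first, bounded by `limit`.
        txn.execute(
            "SELECT stream_id, cache_func, keys"
            " FROM cache_invalidation_stream"
            " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?",
            (last_seen_id, limit)
        )
        return txn.fetchall()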
0 /* Copyright 2016 OpenMarket Ltd
1 *
2 * Licensed under the Apache License, Version 2.0 (the "License");
3 * you may not use this file except in compliance with the License.
4 * You may obtain a copy of the License at
5 *
6 * http://www.apache.org/licenses/LICENSE-2.0
7 *
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
13 */
14
15 DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name';
16 UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
17
18 DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name';
19 UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name';
0 # Copyright 2016 OpenMarket Ltd
1 #
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
5 #
6 # http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 from synapse.storage.engines import PostgresEngine
15
16 import logging
17
18 logger = logging.getLogger(__name__)
19
20
21 def run_create(cur, database_engine, *args, **kwargs):
22 if isinstance(database_engine, PostgresEngine):
23 cur.execute("TRUNCATE received_transactions")
24 else:
25 cur.execute("DELETE FROM received_transactions")
26
27 cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")
28
29
30 def run_upgrade(cur, database_engine, *args, **kwargs):
31 pass
2424 class SignatureStore(SQLBaseStore):
2525 """Persistence for event signatures and hashes"""
2626
27 @cached(lru=True)
27 @cached()
2828 def get_event_reference_hash(self, event_id):
2929 return self._get_event_reference_hashes_txn(event_id)
3030
173173 return [r[0] for r in results]
174174 return self.runInteraction("get_current_state_for_key", f)
175175
176 @cached(num_args=2, lru=True, max_entries=1000)
176 @cached(num_args=2, max_entries=1000)
177177 def _get_state_group_from_group(self, group, types):
178178 raise NotImplementedError()
179179
271271 state_map = yield self.get_state_for_events([event_id], types)
272272 defer.returnValue(state_map[event_id])
273273
274 @cached(num_args=2, lru=True, max_entries=10000)
274 @cached(num_args=2, max_entries=10000)
275275 def _get_state_group_for_event(self, room_id, event_id):
276276 return self._simple_select_one_onecol(
277277 table="event_to_state_groups",
3838 from synapse.util.caches.descriptors import cached
3939 from synapse.api.constants import EventTypes
4040 from synapse.types import RoomStreamToken
41 from synapse.util.logcontext import preserve_fn
41 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
4242 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
4343
4444 import logging
233233 results = {}
234234 room_ids = list(room_ids)
235235 for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
236 res = yield defer.gatherResults([
236 res = yield preserve_context_over_deferred(defer.gatherResults([
237237 preserve_fn(self.get_room_events_stream_for_room)(
238238 room_id, from_key, to_key, limit, order=order,
239239 )
240240 for room_id in rm_ids
241 ])
241 ]))
242242 results.update(dict(zip(rm_ids, res)))
243243
244244 defer.returnValue(results)
6161 self.last_transaction = {}
6262
6363 reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
64 hs.get_clock().looping_call(
65 self._persist_in_mem_txns,
66 1000,
67 )
64 self._clock.looping_call(self._persist_in_mem_txns, 1000)
65
66 self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
6867
6968 def get_received_txn_response(self, transaction_id, origin):
7069 """For an incoming transaction from a given origin, check if we have
126125 "origin": origin,
127126 "response_code": code,
128127 "response_json": buffer(encode_canonical_json(response_dict)),
128 "ts": self._clock.time_msec(),
129129 },
130130 or_ignore=True,
131131 desc="set_received_txn_response",
382382 yield self.runInteraction("_persist_in_mem_txns", f)
383383 except:
384384 logger.exception("Failed to persist transactions!")
385
386 def _cleanup_transactions(self):
387 now = self._clock.time_msec()
388 month_ago = now - 30 * 24 * 60 * 60 * 1000
389
390 def _cleanup_transactions_txn(txn):
391 txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
392
393 return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
268268 return "t%d-%d" % (self.topological, self.stream)
269269 else:
270270 return "s%d" % (self.stream,)
271
272
273 # Some arbitrary constants used for internal API enumerations. Don't rely on
274 # exact values; always pass or compare symbolically
275 class ThirdPartyEntityKind(object):
276 USER = 'user'
277 LOCATION = 'location'
145145 except StopIteration:
146146 pass
147147
148 return defer.gatherResults([
148 return preserve_context_over_deferred(defer.gatherResults([
149149 preserve_fn(_concurrently_execute_inner)()
150150 for _ in xrange(limit)
151 ], consumeErrors=True).addErrback(unwrapFirstError)
151 ], consumeErrors=True)).addErrback(unwrapFirstError)
152152
153153
154154 class Linearizer(object):
180180 self.key_to_defer[key] = new_defer
181181
182182 if current_defer:
183 yield preserve_context_over_deferred(current_defer)
183 with PreserveLoggingContext():
184 yield current_defer
184185
185186 @contextmanager
186187 def _ctx_manager():
263264 curr_readers.clear()
264265 self.key_to_current_writer[key] = new_defer
265266
266 yield defer.gatherResults(to_wait_on)
267 yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
267268
268269 @contextmanager
269270 def _ctx_manager():
2424 from . import DEBUG_CACHES, register_cache
2525
2626 from twisted.internet import defer
27
28 from collections import OrderedDict
27 from collections import namedtuple
2928
3029 import os
3130 import functools
5352 "metrics",
5453 )
5554
56 def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
57 if lru:
58 cache_type = TreeCache if tree else dict
59 self.cache = LruCache(
60 max_size=max_entries, keylen=keylen, cache_type=cache_type
61 )
62 self.max_entries = None
63 else:
64 self.cache = OrderedDict()
65 self.max_entries = max_entries
55 def __init__(self, name, max_entries=1000, keylen=1, tree=False):
56 cache_type = TreeCache if tree else dict
57 self.cache = LruCache(
58 max_size=max_entries, keylen=keylen, cache_type=cache_type
59 )
6660
6761 self.name = name
6862 self.keylen = keylen
8074 "Cache objects can only be accessed from the main thread"
8175 )
8276
83 def get(self, key, default=_CacheSentinel):
84 val = self.cache.get(key, _CacheSentinel)
77 def get(self, key, default=_CacheSentinel, callback=None):
78 val = self.cache.get(key, _CacheSentinel, callback=callback)
8579 if val is not _CacheSentinel:
8680 self.metrics.inc_hits()
8781 return val
9387 else:
9488 return default
9589
96 def update(self, sequence, key, value):
90 def update(self, sequence, key, value, callback=None):
9791 self.check_thread()
9892 if self.sequence == sequence:
9993 # Only update the cache if the caches sequence number matches the
10094 # number that the cache had before the SELECT was started (SYN-369)
101 self.prefill(key, value)
102
103 def prefill(self, key, value):
104 if self.max_entries is not None:
105 while len(self.cache) >= self.max_entries:
106 self.cache.popitem(last=False)
107
108 self.cache[key] = value
95 self.prefill(key, value, callback=callback)
96
97 def prefill(self, key, value, callback=None):
98 self.cache.set(key, value, callback=callback)
10999
110100 def invalidate(self, key):
111101 self.check_thread()
150140 The wrapped function has another additional callable, called "prefill",
151141 which can be used to insert values into the cache specifically, without
152142 calling the calculation function.
143
144 Cached functions can be "chained" (i.e. a cached function can call other cached
145 functions and get appropriately invalidated when the caches they call are
146 invalidated) by adding a special "cache_context" argument to the function
147 and passing that as a kwarg to all caches called. For example::
148
149 @cachedInlineCallbacks(cache_context=True)
150 def foo(self, key, cache_context):
151 r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
152 r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
153 defer.returnValue(r1 + r2)
154
153155 """
154 def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
155 inlineCallbacks=False):
156 def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
157 inlineCallbacks=False, cache_context=False):
156158 max_entries = int(max_entries * CACHE_SIZE_FACTOR)
157159
158160 self.orig = orig
164166
165167 self.max_entries = max_entries
166168 self.num_args = num_args
167 self.lru = lru
168169 self.tree = tree
169170
170 self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
171 all_args = inspect.getargspec(orig)
172 self.arg_names = all_args.args[1:num_args + 1]
173
174 if "cache_context" in all_args.args:
175 if not cache_context:
176 raise ValueError(
177 "Cannot have a 'cache_context' arg without setting"
178 " cache_context=True"
179 )
180 try:
181 self.arg_names.remove("cache_context")
182 except ValueError:
183 pass
184 elif cache_context:
185 raise ValueError(
186 "Cannot have cache_context=True without having an arg"
187 " named `cache_context`"
188 )
189
190 self.add_cache_context = cache_context
171191
172192 if len(self.arg_names) < self.num_args:
173193 raise Exception(
174194 "Not enough explicit positional arguments to key off of for %r."
175 " (@cached cannot key off of *args or **kwars)"
195 " (@cached cannot key off of *args or **kwargs)"
176196 % (orig.__name__,)
177197 )
178198
181201 name=self.orig.__name__,
182202 max_entries=self.max_entries,
183203 keylen=self.num_args,
184 lru=self.lru,
185204 tree=self.tree,
186205 )
187206
188207 @functools.wraps(self.orig)
189208 def wrapped(*args, **kwargs):
209 # If we're passed a cache_context then we'll want to call its invalidate()
210 # whenever we are invalidated
211 invalidate_callback = kwargs.pop("on_invalidate", None)
212
213 # Add temp cache_context so inspect.getcallargs doesn't explode
214 if self.add_cache_context:
215 kwargs["cache_context"] = None
216
190217 arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
191218 cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
219
220 # Add our own `cache_context` to argument list if the wrapped function
221 # has asked for one
222 if self.add_cache_context:
223 kwargs["cache_context"] = _CacheContext(cache, cache_key)
224
192225 try:
193 cached_result_d = cache.get(cache_key)
226 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
194227
195228 observer = cached_result_d.observe()
196229 if DEBUG_CACHES:
227260 ret.addErrback(onErr)
228261
229262 ret = ObservableDeferred(ret, consumeErrors=True)
230 cache.update(sequence, cache_key, ret)
263 cache.update(sequence, cache_key, ret, callback=invalidate_callback)
231264
232265 return preserve_context_over_deferred(ret.observe())
233266
296329
297330 @functools.wraps(self.orig)
298331 def wrapped(*args, **kwargs):
332 # If we're passed a cache_context then we'll want to call its invalidate()
333 # whenever we are invalidated
334 invalidate_callback = kwargs.pop("on_invalidate", None)
335
299336 arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
300337 keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
301338 list_args = arg_dict[self.list_name]
310347 key[self.list_pos] = arg
311348
312349 try:
313 res = cache.get(tuple(key))
350 res = cache.get(tuple(key), callback=invalidate_callback)
314351 if not res.has_succeeded():
315352 res = res.observe()
316353 res.addCallback(lambda r, arg: (arg, r), arg)
344381
345382 key = list(keyargs)
346383 key[self.list_pos] = arg
347 cache.update(sequence, tuple(key), observer)
384 cache.update(
385 sequence, tuple(key), observer,
386 callback=invalidate_callback
387 )
348388
349389 def invalidate(f, key):
350390 cache.invalidate(key)
375415 return wrapped
376416
377417
378 def cached(max_entries=1000, num_args=1, lru=True, tree=False):
418 class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
419 def invalidate(self):
420 self.cache.invalidate(self.key)
421
422
423 def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
379424 return lambda orig: CacheDescriptor(
380425 orig,
381426 max_entries=max_entries,
382427 num_args=num_args,
383 lru=lru,
384428 tree=tree,
429 cache_context=cache_context,
385430 )
386431
387432
388 def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
433 def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
389434 return lambda orig: CacheDescriptor(
390435 orig,
391436 max_entries=max_entries,
392437 num_args=num_args,
393 lru=lru,
394438 tree=tree,
395439 inlineCallbacks=True,
440 cache_context=cache_context,
396441 )
397442
398443
2929
3030
3131 class _Node(object):
32 __slots__ = ["prev_node", "next_node", "key", "value"]
33
34 def __init__(self, prev_node, next_node, key, value):
32 __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
33
34 def __init__(self, prev_node, next_node, key, value, callbacks=set()):
3535 self.prev_node = prev_node
3636 self.next_node = next_node
3737 self.key = key
3838 self.value = value
39 self.callbacks = callbacks
3940
4041
4142 class LruCache(object):
4344 Least-recently-used cache.
4445 Supports del_multi only if cache_type=TreeCache
4546 If cache_type=TreeCache, all keys must be tuples.
47
48 Callbacks can also be registered against entries when getting/setting; they
49 are fired when that key is invalidated or evicted.
4650 """
4751 def __init__(self, max_size, keylen=1, cache_type=dict):
4852 cache = cache_type()
6165
6266 return inner
6367
64 def add_node(key, value):
68 def add_node(key, value, callbacks=set()):
6569 prev_node = list_root
6670 next_node = prev_node.next_node
67 node = _Node(prev_node, next_node, key, value)
71 node = _Node(prev_node, next_node, key, value, callbacks)
6872 prev_node.next_node = node
6973 next_node.prev_node = node
7074 cache[key] = node
8791 prev_node.next_node = next_node
8892 next_node.prev_node = prev_node
8993
90 @synchronized
91 def cache_get(key, default=None):
94 for cb in node.callbacks:
95 cb()
96 node.callbacks.clear()
97
98 @synchronized
99 def cache_get(key, default=None, callback=None):
92100 node = cache.get(key, None)
93101 if node is not None:
94102 move_node_to_front(node)
103 if callback:
104 node.callbacks.add(callback)
95105 return node.value
96106 else:
97107 return default
98108
99109 @synchronized
100 def cache_set(key, value):
110 def cache_set(key, value, callback=None):
101111 node = cache.get(key, None)
102112 if node is not None:
113 if value != node.value:
114 for cb in node.callbacks:
115 cb()
116 node.callbacks.clear()
117
118 if callback:
119 node.callbacks.add(callback)
120
103121 move_node_to_front(node)
104122 node.value = value
105123 else:
106 add_node(key, value)
124 if callback:
125 callbacks = set([callback])
126 else:
127 callbacks = set()
128 add_node(key, value, callbacks)
107129 if len(cache) > max_size:
108130 todelete = list_root.prev_node
109131 delete_node(todelete)
147169 def cache_clear():
148170 list_root.next_node = list_root
149171 list_root.prev_node = list_root
172 for node in cache.values():
173 for cb in node.callbacks:
174 cb()
150175 cache.clear()
151176
152177 @synchronized
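This callback plumbing is what the ``Cache``/``CacheDescriptor`` changes build on: a callback registered when an entry is read or written fires when the entry is evicted, overwritten with a different value, invalidated, or the cache is cleared. A small sketch, assuming the usual import path and that ``LruCache`` exposes ``get``/``set`` wrappers for the ``cache_get``/``cache_set`` functions shown above::

    from synapse.util.caches.lrucache import LruCache

    fired = []
    cache = LruCache(max_size=1)
    cache.set("a", 1, callback=lambda: fired.append("a"))
    cache.set("b", 2, callback=lambda: fired.append("b"))
    assert fired == ["a"]      # "a" was evicted to make room for "b"
    cache.set("b", 3)          # value changed, so "b"'s callback fires too
    assert fired == ["a", "b"]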
6363 self.size -= cnt
6464 return popped
6565
66 def values(self):
67 return [e.value for e in self.root.values()]
68
6669 def __len__(self):
6770 return self.size
6871
296296 return res
297297
298298
299 def preserve_context_over_deferred(deferred):
299 def preserve_context_over_deferred(deferred, context=None):
300300 """Given a deferred wrap it such that any callbacks added later to it will
301301 be invoked with the current context.
302302 """
303 current_context = LoggingContext.current_context()
304 d = _PreservingContextDeferred(current_context)
303 if context is None:
304 context = LoggingContext.current_context()
305 d = _PreservingContextDeferred(context)
305306 deferred.chainDeferred(d)
306307 return d
307308
315316
316317 def g(*args, **kwargs):
317318 with PreserveLoggingContext(current):
318 return f(*args, **kwargs)
319
319 res = f(*args, **kwargs)
320 if isinstance(res, defer.Deferred):
321 return preserve_context_over_deferred(
322 res, context=LoggingContext.sentinel
323 )
324 else:
325 return res
320326 return g
321327
322328
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from twisted.internet import defer
1516
1617 from synapse.util.logcontext import LoggingContext
1718 import synapse.metrics
1819
20 from functools import wraps
1921 import logging
2022
2123
4648 )
4749
4850
51 def measure_func(name):
52 def wrapper(func):
53 @wraps(func)
54 @defer.inlineCallbacks
55 def measured_func(self, *args, **kwargs):
56 with Measure(self.clock, name):
57 r = yield func(self, *args, **kwargs)
58 defer.returnValue(r)
59 return measured_func
60 return wrapper
61
62
4963 class Measure(object):
5064 __slots__ = [
5165 "clock", "name", "start_context", "start", "new_context", "ru_utime",
6377 self.start = self.clock.time_msec()
6478 self.start_context = LoggingContext.current_context()
6579 if not self.start_context:
66 logger.warn("Entered Measure without log context: %s", self.name)
6780 self.start_context = LoggingContext("Measure")
6881 self.start_context.__enter__()
6982 self.created_context = True
7386 self.db_txn_duration = self.start_context.db_txn_duration
7487
7588 def __exit__(self, exc_type, exc_val, exc_tb):
76 if exc_type is not None or not self.start_context:
89 if isinstance(exc_type, Exception) or not self.start_context:
7790 return
7891
7992 duration = self.clock.time_msec() - self.start
8497 if context != self.start_context:
8598 logger.warn(
8699 "Context has unexpectedly changed from '%s' to '%s'. (%r)",
87 context, self.start_context, self.name
100 self.start_context, context, self.name
88101 )
89102 return
90103
1616
1717 from synapse.api.constants import Membership, EventTypes
1818
19 from synapse.util.logcontext import preserve_fn
19 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
2020
2121 import logging
2222
5454 given events
5555 events ([synapse.events.EventBase]): list of events to filter
5656 """
57 forgotten = yield defer.gatherResults([
57 forgotten = yield preserve_context_over_deferred(defer.gatherResults([
5858 preserve_fn(store.who_forgot_in_room)(
5959 room_id,
6060 )
6161 for room_id in frozenset(e.room_id for e in events)
62 ], consumeErrors=True)
62 ], consumeErrors=True))
6363
6464 # Set of membership event_ids that have been forgotten
6565 event_id_forgotten = frozenset(
1313 # limitations under the License.
1414 from synapse.appservice import ApplicationService
1515
16 from twisted.internet import defer
17
1618 from mock import Mock
1719 from tests import unittest
1820
4143 type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
4244 )
4345
46 self.store = Mock()
47
48 @defer.inlineCallbacks
4449 def test_regex_user_id_prefix_match(self):
4550 self.service.namespaces[ApplicationService.NS_USERS].append(
4651 _regex("@irc_.*")
4752 )
4853 self.event.sender = "@irc_foobar:matrix.org"
49 self.assertTrue(self.service.is_interested(self.event))
50
54 self.assertTrue((yield self.service.is_interested(self.event)))
55
56 @defer.inlineCallbacks
5157 def test_regex_user_id_prefix_no_match(self):
5258 self.service.namespaces[ApplicationService.NS_USERS].append(
5359 _regex("@irc_.*")
5460 )
5561 self.event.sender = "@someone_else:matrix.org"
56 self.assertFalse(self.service.is_interested(self.event))
57
62 self.assertFalse((yield self.service.is_interested(self.event)))
63
64 @defer.inlineCallbacks
5865 def test_regex_room_member_is_checked(self):
5966 self.service.namespaces[ApplicationService.NS_USERS].append(
6067 _regex("@irc_.*")
6269 self.event.sender = "@someone_else:matrix.org"
6370 self.event.type = "m.room.member"
6471 self.event.state_key = "@irc_foobar:matrix.org"
65 self.assertTrue(self.service.is_interested(self.event))
66
72 self.assertTrue((yield self.service.is_interested(self.event)))
73
74 @defer.inlineCallbacks
6775 def test_regex_room_id_match(self):
6876 self.service.namespaces[ApplicationService.NS_ROOMS].append(
6977 _regex("!some_prefix.*some_suffix:matrix.org")
7078 )
7179 self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
72 self.assertTrue(self.service.is_interested(self.event))
73
80 self.assertTrue((yield self.service.is_interested(self.event)))
81
82 @defer.inlineCallbacks
7483 def test_regex_room_id_no_match(self):
7584 self.service.namespaces[ApplicationService.NS_ROOMS].append(
7685 _regex("!some_prefix.*some_suffix:matrix.org")
7786 )
7887 self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
79 self.assertFalse(self.service.is_interested(self.event))
80
88 self.assertFalse((yield self.service.is_interested(self.event)))
89
90 @defer.inlineCallbacks
8191 def test_regex_alias_match(self):
8292 self.service.namespaces[ApplicationService.NS_ALIASES].append(
8393 _regex("#irc_.*:matrix.org")
8494 )
85 self.assertTrue(self.service.is_interested(
86 self.event,
87 aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
88 ))
95 self.store.get_aliases_for_room.return_value = [
96 "#irc_foobar:matrix.org", "#athing:matrix.org"
97 ]
98 self.store.get_users_in_room.return_value = []
99 self.assertTrue((yield self.service.is_interested(
100 self.event, self.store
101 )))
89102
90103 def test_non_exclusive_alias(self):
91104 self.service.namespaces[ApplicationService.NS_ALIASES].append(
135148 "!irc_foobar:matrix.org"
136149 ))
137150
151 @defer.inlineCallbacks
138152 def test_regex_alias_no_match(self):
139153 self.service.namespaces[ApplicationService.NS_ALIASES].append(
140154 _regex("#irc_.*:matrix.org")
141155 )
142 self.assertFalse(self.service.is_interested(
143 self.event,
144 aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
145 ))
146
156 self.store.get_aliases_for_room.return_value = [
157 "#xmpp_foobar:matrix.org", "#athing:matrix.org"
158 ]
159 self.store.get_users_in_room.return_value = []
160 self.assertFalse((yield self.service.is_interested(
161 self.event, self.store
162 )))
163
164 @defer.inlineCallbacks
147165 def test_regex_multiple_matches(self):
148166 self.service.namespaces[ApplicationService.NS_ALIASES].append(
149167 _regex("#irc_.*:matrix.org")
152170 _regex("@irc_.*")
153171 )
154172 self.event.sender = "@irc_foobar:matrix.org"
155 self.assertTrue(self.service.is_interested(
156 self.event,
157 aliases_for_event=["#irc_barfoo:matrix.org"]
158 ))
159
160 def test_restrict_to_rooms(self):
161 self.service.namespaces[ApplicationService.NS_ROOMS].append(
162 _regex("!flibble_.*:matrix.org")
163 )
164 self.service.namespaces[ApplicationService.NS_USERS].append(
165 _regex("@irc_.*")
166 )
167 self.event.sender = "@irc_foobar:matrix.org"
168 self.event.room_id = "!wibblewoo:matrix.org"
169 self.assertFalse(self.service.is_interested(
170 self.event,
171 restrict_to=ApplicationService.NS_ROOMS
172 ))
173
174 def test_restrict_to_aliases(self):
175 self.service.namespaces[ApplicationService.NS_ALIASES].append(
176 _regex("#xmpp_.*:matrix.org")
177 )
178 self.service.namespaces[ApplicationService.NS_USERS].append(
179 _regex("@irc_.*")
180 )
181 self.event.sender = "@irc_foobar:matrix.org"
182 self.assertFalse(self.service.is_interested(
183 self.event,
184 restrict_to=ApplicationService.NS_ALIASES,
185 aliases_for_event=["#irc_barfoo:matrix.org"]
186 ))
187
188 def test_restrict_to_senders(self):
189 self.service.namespaces[ApplicationService.NS_ALIASES].append(
190 _regex("#xmpp_.*:matrix.org")
191 )
192 self.service.namespaces[ApplicationService.NS_USERS].append(
193 _regex("@irc_.*")
194 )
195 self.event.sender = "@xmpp_foobar:matrix.org"
196 self.assertFalse(self.service.is_interested(
197 self.event,
198 restrict_to=ApplicationService.NS_USERS,
199 aliases_for_event=["#xmpp_barfoo:matrix.org"]
200 ))
201
173 self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
174 self.store.get_users_in_room.return_value = []
175 self.assertTrue((yield self.service.is_interested(
176 self.event, self.store
177 )))
178
179 @defer.inlineCallbacks
202180 def test_interested_in_self(self):
203181 # make sure invites get through
204182 self.service.sender = "@appservice:name"
210188 "membership": "invite"
211189 }
212190 self.event.state_key = self.service.sender
213 self.assertTrue(self.service.is_interested(self.event))
214
191 self.assertTrue((yield self.service.is_interested(self.event)))
192
193 @defer.inlineCallbacks
215194 def test_member_list_match(self):
216195 self.service.namespaces[ApplicationService.NS_USERS].append(
217196 _regex("@irc_.*")
218197 )
219 join_list = [
198 self.store.get_users_in_room.return_value = [
220199 "@alice:here",
221200 "@irc_fo:here", # AS user
222201 "@bob:here",
223202 ]
203 self.store.get_aliases_for_room.return_value = []
224204
225205 self.event.sender = "@xmpp_foobar:matrix.org"
226 self.assertTrue(self.service.is_interested(
227 event=self.event,
228 member_list=join_list
229 ))
206 self.assertTrue((yield self.service.is_interested(
207 event=self.event, store=self.store
208 )))
192192
193193 def setUp(self):
194194 self.txn_ctrl = Mock()
195 self.queuer = _ServiceQueuer(self.txn_ctrl)
195 self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
196196
197197 def test_send_single_event_no_queue(self):
198198 # Expect the event to be sent immediately.
1414
1515 from twisted.internet import defer
1616 from .. import unittest
17 from tests.utils import MockClock
1718
1819 from synapse.handlers.appservice import ApplicationServicesHandler
1920
3132 hs.get_datastore = Mock(return_value=self.mock_store)
3233 hs.get_application_service_api = Mock(return_value=self.mock_as_api)
3334 hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
35 hs.get_clock.return_value = MockClock()
3436 self.handler = ApplicationServicesHandler(hs)
3537
3638 @defer.inlineCallbacks
5052 type="m.room.message",
5153 room_id="!foo:bar"
5254 )
55 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
5356 self.mock_as_api.push = Mock()
54 yield self.handler.notify_interested_services(event)
57 yield self.handler.notify_interested_services(0)
5558 self.mock_scheduler.submit_event_for_as.assert_called_once_with(
5659 interested_service, event
5760 )
7174 )
7275 self.mock_as_api.push = Mock()
7376 self.mock_as_api.query_user = Mock()
74 yield self.handler.notify_interested_services(event)
77 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
78 yield self.handler.notify_interested_services(0)
7579 self.mock_as_api.query_user.assert_called_once_with(
7680 services[0], user_id
7781 )
9397 )
9498 self.mock_as_api.push = Mock()
9599 self.mock_as_api.query_user = Mock()
96 yield self.handler.notify_interested_services(event)
100 self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
101 yield self.handler.notify_interested_services(0)
97102 self.assertFalse(
98103 self.mock_as_api.query_user.called,
99104 "query_user called when it shouldn't have been."
107112
108113 room_id = "!alpha:bet"
109114 servers = ["aperture"]
110 interested_service = self._mkservice(is_interested=True)
115 interested_service = self._mkservice_alias(is_interested_in_alias=True)
111116 services = [
112 self._mkservice(is_interested=False),
117 self._mkservice_alias(is_interested_in_alias=False),
113118 interested_service,
114 self._mkservice(is_interested=False)
119 self._mkservice_alias(is_interested_in_alias=False)
115120 ]
116121
117122 self.mock_store.get_app_services = Mock(return_value=services)
134139 service.token = "mock_service_token"
135140 service.url = "mock_service_url"
136141 return service
142
143 def _mkservice_alias(self, is_interested_in_alias):
144 service = Mock()
145 service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
146 service.token = "mock_service_token"
147 service.url = "mock_service_url"
148 return service
1313 # limitations under the License.
1414
1515 import pymacaroons
16 from twisted.internet import defer
1617
18 import synapse
19 import synapse.api.errors
1720 from synapse.handlers.auth import AuthHandler
1821 from tests import unittest
1922 from tests.utils import setup_test_homeserver
20 from twisted.internet import defer
2123
2224
2325 class AuthHandlers(object):
3032 def setUp(self):
3133 self.hs = yield setup_test_homeserver(handlers=None)
3234 self.hs.handlers = AuthHandlers(self.hs)
35 self.auth_handler = self.hs.handlers.auth_handler
3336
3437 def test_token_is_a_macaroon(self):
3538 self.hs.config.macaroon_secret_key = "this key is a huge secret"
3639
37 token = self.hs.handlers.auth_handler.generate_access_token("some_user")
40 token = self.auth_handler.generate_access_token("some_user")
3841 # Check that we can parse the thing with pymacaroons
3942 macaroon = pymacaroons.Macaroon.deserialize(token)
4043 # The most basic of sanity checks
4548 self.hs.config.macaroon_secret_key = "this key is a massive secret"
4649 self.hs.clock.now = 5000
4750
48 token = self.hs.handlers.auth_handler.generate_access_token("a_user")
51 token = self.auth_handler.generate_access_token("a_user")
4952 macaroon = pymacaroons.Macaroon.deserialize(token)
5053
5154 def verify_gen(caveat):
6669 v.satisfy_general(verify_type)
6770 v.satisfy_general(verify_expiry)
6871 v.verify(macaroon, self.hs.config.macaroon_secret_key)
72
73 def test_short_term_login_token_gives_user_id(self):
74 self.hs.clock.now = 1000
75
76 token = self.auth_handler.generate_short_term_login_token(
77 "a_user", 5000
78 )
79
80 self.assertEqual(
81 "a_user",
82 self.auth_handler.validate_short_term_login_token_and_get_user_id(
83 token
84 )
85 )
86
87 # when we advance the clock, the token should be rejected
88 self.hs.clock.now = 6000
89 with self.assertRaises(synapse.api.errors.AuthError):
90 self.auth_handler.validate_short_term_login_token_and_get_user_id(
91 token
92 )
93
94 def test_short_term_login_token_cannot_replace_user_id(self):
95 token = self.auth_handler.generate_short_term_login_token(
96 "a_user", 5000
97 )
98 macaroon = pymacaroons.Macaroon.deserialize(token)
99
100 self.assertEqual(
101 "a_user",
102 self.auth_handler.validate_short_term_login_token_and_get_user_id(
103 macaroon.serialize()
104 )
105 )
106
107 # add another "user_id" caveat, which might allow us to override the
108 # user_id.
109 macaroon.add_first_party_caveat("user_id = b_user")
110
111 with self.assertRaises(synapse.api.errors.AuthError):
112 self.auth_handler.validate_short_term_login_token_and_get_user_id(
113 macaroon.serialize()
114 )
1616 from tests import unittest
1717 from twisted.internet import defer
1818
19 from mock import Mock
20
1921 from synapse.util.async import ObservableDeferred
2022
2123 from synapse.util.caches.descriptors import Cache, cached
7173 cache.get(3)
7274
7375 def test_eviction_lru(self):
74 cache = Cache("test", max_entries=2, lru=True)
76 cache = Cache("test", max_entries=2)
7577
7678 cache.prefill(1, "one")
7779 cache.prefill(2, "two")
198200
199201 self.assertEquals(a.func("foo").result, d.result)
200202 self.assertEquals(callcount[0], 0)
203
204 @defer.inlineCallbacks
205 def test_invalidate_context(self):
206 callcount = [0]
207 callcount2 = [0]
208
209 class A(object):
210 @cached()
211 def func(self, key):
212 callcount[0] += 1
213 return key
214
215 @cached(cache_context=True)
216 def func2(self, key, cache_context):
217 callcount2[0] += 1
218 return self.func(key, on_invalidate=cache_context.invalidate)
219
220 a = A()
221 yield a.func2("foo")
222
223 self.assertEquals(callcount[0], 1)
224 self.assertEquals(callcount2[0], 1)
225
226 a.func.invalidate(("foo",))
227 yield a.func("foo")
228
229 self.assertEquals(callcount[0], 2)
230 self.assertEquals(callcount2[0], 1)
231
232 yield a.func2("foo")
233
234 self.assertEquals(callcount[0], 2)
235 self.assertEquals(callcount2[0], 2)
236
237 @defer.inlineCallbacks
238 def test_eviction_context(self):
239 callcount = [0]
240 callcount2 = [0]
241
242 class A(object):
243 @cached(max_entries=2)
244 def func(self, key):
245 callcount[0] += 1
246 return key
247
248 @cached(cache_context=True)
249 def func2(self, key, cache_context):
250 callcount2[0] += 1
251 return self.func(key, on_invalidate=cache_context.invalidate)
252
253 a = A()
254 yield a.func2("foo")
255 yield a.func2("foo2")
256
257 self.assertEquals(callcount[0], 2)
258 self.assertEquals(callcount2[0], 2)
259
260 yield a.func("foo3")
261
262 self.assertEquals(callcount[0], 3)
263 self.assertEquals(callcount2[0], 2)
264
265 yield a.func2("foo")
266
267 self.assertEquals(callcount[0], 4)
268 self.assertEquals(callcount2[0], 3)
269
270 @defer.inlineCallbacks
271 def test_double_get(self):
272 callcount = [0]
273 callcount2 = [0]
274
275 class A(object):
276 @cached()
277 def func(self, key):
278 callcount[0] += 1
279 return key
280
281 @cached(cache_context=True)
282 def func2(self, key, cache_context):
283 callcount2[0] += 1
284 return self.func(key, on_invalidate=cache_context.invalidate)
285
286 a = A()
287 a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
288
289 yield a.func2("foo")
290
291 self.assertEquals(callcount[0], 1)
292 self.assertEquals(callcount2[0], 1)
293
294 a.func2.invalidate(("foo",))
295 self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
296
297 yield a.func2("foo")
298 a.func2.invalidate(("foo",))
299 self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
300
301 self.assertEquals(callcount[0], 1)
302 self.assertEquals(callcount2[0], 2)
303
304 a.func.invalidate(("foo",))
305 self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
306 yield a.func("foo")
307
308 self.assertEquals(callcount[0], 2)
309 self.assertEquals(callcount2[0], 2)
310
311 yield a.func2("foo")
312
313 self.assertEquals(callcount[0], 2)
314 self.assertEquals(callcount2[0], 3)
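
The new test_invalidate_context, test_eviction_context and test_double_get cases pin down the cache_context plumbing for ``@cached`` methods: a cached method declared with ``cache_context=True`` receives a ``cache_context`` argument and can pass ``cache_context.invalidate`` as ``on_invalidate`` to another cached call, so that invalidating or evicting the inner entry drops the outer one too. A minimal usage sketch, assuming a synapse development checkout (the class and method names are illustrative only)::

    from synapse.util.caches.descriptors import cached

    class RoomStore(object):  # hypothetical class for illustration
        @cached()
        def get_room(self, room_id):
            return room_id  # normally this would hit the database

        @cached(cache_context=True)
        def get_room_summary(self, room_id, cache_context):
            # If get_room(room_id) is later invalidated or evicted, this
            # entry is dropped too, via the on_invalidate callback.
            return self.get_room(room_id, on_invalidate=cache_context.invalidate)

    store = RoomStore()
    store.get_room_summary("!room:example.com")
    store.get_room.invalidate(("!room:example.com",))  # also invalidates the summary entry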
1414
1515 from . import unittest
1616
17 from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs
17 from synapse.rest.media.v1.preview_url_resource import (
18 summarize_paragraphs, decode_and_calc_og
19 )
1820
1921
2022 class PreviewTestCase(unittest.TestCase):
136138 " of old wooden houses in Northern Norway, the oldest house dating from"
137139 " 1789. The Arctic Cathedral, a modern church…"
138140 )
141
142
143 class PreviewUrlTestCase(unittest.TestCase):
144 def test_simple(self):
145 html = """
146 <html>
147 <head><title>Foo</title></head>
148 <body>
149 Some text.
150 </body>
151 </html>
152 """
153
154 og = decode_and_calc_og(html, "http://example.com/test.html")
155
156 self.assertEquals(og, {
157 "og:title": "Foo",
158 "og:description": "Some text."
159 })
160
161 def test_comment(self):
162 html = """
163 <html>
164 <head><title>Foo</title></head>
165 <body>
166 <!-- HTML comment -->
167 Some text.
168 </body>
169 </html>
170 """
171
172 og = decode_and_calc_og(html, "http://example.com/test.html")
173
174 self.assertEquals(og, {
175 "og:title": "Foo",
176 "og:description": "Some text."
177 })
178
179 def test_comment2(self):
180 html = """
181 <html>
182 <head><title>Foo</title></head>
183 <body>
184 Some text.
185 <!-- HTML comment -->
186 Some more text.
187 <p>Text</p>
188 More text
189 </body>
190 </html>
191 """
192
193 og = decode_and_calc_og(html, "http://example.com/test.html")
194
195 self.assertEquals(og, {
196 "og:title": "Foo",
197 "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text"
198 })
199
200 def test_script(self):
201 html = """
202 <html>
203 <head><title>Foo</title></head>
204 <body>
205 <script> (function() {})() </script>
206 Some text.
207 </body>
208 </html>
209 """
210
211 og = decode_and_calc_og(html, "http://example.com/test.html")
212
213 self.assertEquals(og, {
214 "og:title": "Foo",
215 "og:description": "Some text."
216 })
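
These new cases exercise decode_and_calc_og, which the preview URL resource now exposes for turning raw HTML into Open Graph data, falling back to the ``<title>`` element and visible body text when no ``og:`` meta tags are present, and ignoring comments and ``<script>`` blocks. A small usage sketch, assuming a synapse checkout (the URL and markup are made up)::

    from synapse.rest.media.v1.preview_url_resource import decode_and_calc_og

    html = """
    <html>
    <head><title>Example page</title></head>
    <body>
    <!-- navigation markup would be ignored -->
    <script>console.log("also ignored")</script>
    Some interesting text.
    </body>
    </html>
    """

    og = decode_and_calc_og(html, "http://example.com/page.html")
    # og is a plain dict, e.g.
    # {"og:title": "Example page", "og:description": "Some interesting text."}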
1818 from synapse.util.caches.lrucache import LruCache
1919 from synapse.util.caches.treecache import TreeCache
2020
21 from mock import Mock
22
2123
2224 class LruCacheTestCase(unittest.TestCase):
2325
4749 self.assertEquals(cache.get("key"), 1)
4850 self.assertEquals(cache.setdefault("key", 2), 1)
4951 self.assertEquals(cache.get("key"), 1)
52 cache["key"] = 2 # Make sure overriding works.
53 self.assertEquals(cache.get("key"), 2)
5054
5155 def test_pop(self):
5256 cache = LruCache(1)
7882 cache["key"] = 1
7983 cache.clear()
8084 self.assertEquals(len(cache), 0)
85
86
87 class LruCacheCallbacksTestCase(unittest.TestCase):
88 def test_get(self):
89 m = Mock()
90 cache = LruCache(1)
91
92 cache.set("key", "value")
93 self.assertFalse(m.called)
94
95 cache.get("key", callback=m)
96 self.assertFalse(m.called)
97
98 cache.get("key", "value")
99 self.assertFalse(m.called)
100
101 cache.set("key", "value2")
102 self.assertEquals(m.call_count, 1)
103
104 cache.set("key", "value")
105 self.assertEquals(m.call_count, 1)
106
107 def test_multi_get(self):
108 m = Mock()
109 cache = LruCache(1)
110
111 cache.set("key", "value")
112 self.assertFalse(m.called)
113
114 cache.get("key", callback=m)
115 self.assertFalse(m.called)
116
117 cache.get("key", callback=m)
118 self.assertFalse(m.called)
119
120 cache.set("key", "value2")
121 self.assertEquals(m.call_count, 1)
122
123 cache.set("key", "value")
124 self.assertEquals(m.call_count, 1)
125
126 def test_set(self):
127 m = Mock()
128 cache = LruCache(1)
129
130 cache.set("key", "value", m)
131 self.assertFalse(m.called)
132
133 cache.set("key", "value")
134 self.assertFalse(m.called)
135
136 cache.set("key", "value2")
137 self.assertEquals(m.call_count, 1)
138
139 cache.set("key", "value")
140 self.assertEquals(m.call_count, 1)
141
142 def test_pop(self):
143 m = Mock()
144 cache = LruCache(1)
145
146 cache.set("key", "value", m)
147 self.assertFalse(m.called)
148
149 cache.pop("key")
150 self.assertEquals(m.call_count, 1)
151
152 cache.set("key", "value")
153 self.assertEquals(m.call_count, 1)
154
155 cache.pop("key")
156 self.assertEquals(m.call_count, 1)
157
158 def test_del_multi(self):
159 m1 = Mock()
160 m2 = Mock()
161 m3 = Mock()
162 m4 = Mock()
163 cache = LruCache(4, 2, cache_type=TreeCache)
164
165 cache.set(("a", "1"), "value", m1)
166 cache.set(("a", "2"), "value", m2)
167 cache.set(("b", "1"), "value", m3)
168 cache.set(("b", "2"), "value", m4)
169
170 self.assertEquals(m1.call_count, 0)
171 self.assertEquals(m2.call_count, 0)
172 self.assertEquals(m3.call_count, 0)
173 self.assertEquals(m4.call_count, 0)
174
175 cache.del_multi(("a",))
176
177 self.assertEquals(m1.call_count, 1)
178 self.assertEquals(m2.call_count, 1)
179 self.assertEquals(m3.call_count, 0)
180 self.assertEquals(m4.call_count, 0)
181
182 def test_clear(self):
183 m1 = Mock()
184 m2 = Mock()
185 cache = LruCache(5)
186
187 cache.set("key1", "value", m1)
188 cache.set("key2", "value", m2)
189
190 self.assertEquals(m1.call_count, 0)
191 self.assertEquals(m2.call_count, 0)
192
193 cache.clear()
194
195 self.assertEquals(m1.call_count, 1)
196 self.assertEquals(m2.call_count, 1)
197
198 def test_eviction(self):
199 m1 = Mock(name="m1")
200 m2 = Mock(name="m2")
201 m3 = Mock(name="m3")
202 cache = LruCache(2)
203
204 cache.set("key1", "value", m1)
205 cache.set("key2", "value", m2)
206
207 self.assertEquals(m1.call_count, 0)
208 self.assertEquals(m2.call_count, 0)
209 self.assertEquals(m3.call_count, 0)
210
211 cache.set("key3", "value", m3)
212
213 self.assertEquals(m1.call_count, 1)
214 self.assertEquals(m2.call_count, 0)
215 self.assertEquals(m3.call_count, 0)
216
217 cache.set("key3", "value")
218
219 self.assertEquals(m1.call_count, 1)
220 self.assertEquals(m2.call_count, 0)
221 self.assertEquals(m3.call_count, 0)
222
223 cache.get("key2")
224
225 self.assertEquals(m1.call_count, 1)
226 self.assertEquals(m2.call_count, 0)
227 self.assertEquals(m3.call_count, 0)
228
229 cache.set("key1", "value", m1)
230
231 self.assertEquals(m1.call_count, 1)
232 self.assertEquals(m2.call_count, 0)
233 self.assertEquals(m3.call_count, 1)
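
The LruCacheCallbacksTestCase additions pin down the new per-entry callback support: a callback can be attached when writing an entry (third positional argument to ``set``) or when reading it (``callback=`` keyword to ``get``), and it fires once when that entry is overwritten with a different value, popped, evicted, cleared, or removed via ``del_multi`` on a TreeCache-backed cache. A condensed sketch of the behaviour being asserted, assuming a synapse checkout (the keys are illustrative)::

    from mock import Mock

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    on_gone = Mock()

    cache = LruCache(2)
    cache.set("key1", "value", on_gone)  # attach the callback at write time
    cache.get("key1", callback=on_gone)  # or at read time
    cache.set("key1", "value")           # same value: callback not fired
    cache.set("key1", "other")           # changed value: callback fires
    assert on_gone.call_count >= 1

    # With a TreeCache backend, del_multi invalidates a whole subtree,
    # firing the callback of every entry under the given key prefix.
    tree = LruCache(4, 2, cache_type=TreeCache)
    gone = Mock()
    tree.set(("room", "alice"), "value", gone)
    tree.set(("room", "bob"), "value", gone)
    tree.del_multi(("room",))
    assert gone.call_count == 2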