Imported Upstream version 0.17.1
Erik Johnston
7 years ago
0 | Changes in synapse v0.17.1 (2016-08-24) | |
1 | ======================================= | |
2 | ||
3 | Changes: | |
4 | ||
5 | * Delete old received_transactions rows (PR #1038) | |
6 | * Pass through user-supplied content in /join/$room_id (PR #1039) | |
7 | ||
8 | ||
9 | Bug fixes: | |
10 | ||
11 | * Fix bug with backfill (PR #1040) | |
12 | ||
13 | ||
14 | Changes in synapse v0.17.1-rc1 (2016-08-22) | |
15 | =========================================== | |
16 | ||
17 | Features: | |
18 | ||
19 | * Add notification API (PR #1028) | |
20 | ||
21 | ||
22 | Changes: | |
23 | ||
24 | * Don't print stack traces when failing to get remote keys (PR #996) | |
25 | * Various federation /event/ perf improvements (PR #998) | |
26 | * Only process one local membership event per room at a time (PR #1005) | |
27 | * Move default display name push rule (PR #1011, #1023) | |
28 | * Fix up preview URL API. Add tests. (PR #1015) | |
29 | * Set ``Content-Security-Policy`` on media repo (PR #1021) | |
30 | * Make notify_interested_services faster (PR #1022) | |
31 | * Add usage stats to prometheus monitoring (PR #1037) | |
32 | ||
33 | ||
34 | Bug fixes: | |
35 | ||
36 | * Fix token login (PR #993) | |
37 | * Fix CAS login (PR #994, #995) | |
38 | * Fix /sync to not clobber status_msg (PR #997) | |
39 | * Fix redacted state events to include prev_content (PR #1003) | |
40 | * Fix some bugs in the auth/ldap handler (PR #1007) | |
41 | * Fix backfill request to limit URI length, so that remotes don't reject the | |
42 | requests due to path length limits (PR #1012) | |
43 | * Fix AS push code to not send duplicate events (PR #1025) | |
44 | ||
45 | ||
46 | ||
0 | 47 | Changes in synapse v0.17.0 (2016-08-08) |
1 | 48 | ======================================= |
2 | 49 |
94 | 94 | System requirements: |
95 | 95 | - POSIX-compliant system (tested on Linux & OS X) |
96 | 96 | - Python 2.7 |
97 | - At least 512 MB RAM. | |
97 | - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org | |
98 | 98 | |
99 | 99 | Synapse is written in Python but some of the libraries it uses are written in
100 | 100 | C. So before we can install synapse itself we need a working C compiler and the |
0 | Scaling synapse via workers | |
1 | --------------------------- | |
2 | ||
3 | Synapse has experimental support for splitting out functionality into | |
4 | multiple separate Python processes, which helps greatly with scalability. These | |
5 | processes are called 'workers', and are (eventually) intended to scale | |
6 | horizontally and independently of one another. | |
7 | ||
8 | All processes continue to share the same database instance, and as such, workers | |
9 | only work with postgres-based synapse deployments (sharing a single sqlite | |
10 | across multiple processes is a recipe for disaster, plus you should be using | |
11 | postgres anyway if you care about scalability). | |
12 | ||
13 | The workers communicate with the master synapse process via a synapse-specific | |
14 | HTTP protocol called 'replication', analogous to MySQL or Postgres-style | |
15 | database replication: it feeds a stream of relevant data to the workers so they | |
16 | can be kept in sync with the main synapse process and database state. | |
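The long-poll pattern described above, where a worker repeatedly sends its current stream positions and folds the returned batches back in, can be sketched as follows. This is a minimal illustration, not Synapse's actual replication client; the `fetch` callable and the stream payload shape are assumptions for the sketch:

```python
# Minimal sketch of the worker-side replication loop described above.
# `fetch(args)` stands in for the HTTP GET against the master's
# replication listener; it is an assumption, not Synapse's real API.

def replication_step(positions, fetch, timeout_ms=30000):
    """Long-poll the master once and fold the returned streams into our
    local stream positions, returning the updated positions."""
    args = dict(positions)
    args["timeout"] = timeout_ms  # master holds the request until data arrives
    result = fetch(args)
    for stream_name, stream in result.items():
        # Each stream advertises the highest position included in this batch.
        positions[stream_name] = stream["position"]
    return positions


# Example: a stub master that advances the "events" stream by one row.
def stub_fetch(args):
    return {"events": {"position": args["events"] + 1, "rows": [["..."]]}}

positions = replication_step({"events": 10}, stub_fetch)  # -> {"events": 11}
```

Calling `replication_step` in a loop, with error handling and a back-off sleep, is essentially what each worker's `replicate` method does.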
17 | ||
18 | To enable workers, you need to add a replication listener to the master synapse, e.g.:: | |
19 | ||
20 | listeners: | |
21 | - port: 9092 | |
22 | bind_address: '127.0.0.1' | |
23 | type: http | |
24 | tls: false | |
25 | x_forwarded: false | |
26 | resources: | |
27 | - names: [replication] | |
28 | compress: false | |
29 | ||
30 | Under **no circumstances** should this replication API listener be exposed to the | |
31 | public internet; it currently implements no authentication whatsoever and is | |
32 | unencrypted HTTP. | |
33 | ||
34 | You then create a set of configs for the various worker processes. These | |
35 | worker configuration files should be stored in a dedicated subdirectory, to allow | |
36 | synctl to manipulate them. | |
37 | ||
38 | The currently available worker applications are: | |
39 | * synapse.app.pusher - handles sending push notifications to sygnal and email | |
40 | * synapse.app.synchrotron - handles /sync endpoints; can scale horizontally across multiple instances. | |
41 | * synapse.app.appservice - handles output traffic to Application Services | |
42 | * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API) | |
43 | * synapse.app.media_repository - handles the media repository. | |
44 | ||
45 | Each worker configuration file inherits the configuration of the main homeserver | |
46 | configuration file. You can then override configuration specific to that worker, | |
47 | e.g. the HTTP listener that it provides (if any); logging configuration; etc. | |
48 | You should minimise the number of overrides, though, to keep the config manageable. | |
49 | ||
50 | You must specify the type of worker application (worker_app) and the replication | |
51 | endpoint that it's talking to on the main synapse process (worker_replication_url). | |
52 | ||
53 | For instance:: | |
54 | ||
55 | worker_app: synapse.app.synchrotron | |
56 | ||
57 | # The replication listener on the synapse to talk to. | |
58 | worker_replication_url: http://127.0.0.1:9092/_synapse/replication | |
59 | ||
60 | worker_listeners: | |
61 | - type: http | |
62 | port: 8083 | |
63 | resources: | |
64 | - names: | |
65 | - client | |
66 | ||
67 | worker_daemonize: True | |
68 | worker_pid_file: /home/matrix/synapse/synchrotron.pid | |
69 | worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml | |
70 | ||
71 | ...is a full configuration for a synchrotron worker instance, which will expose a | |
72 | plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided | |
73 | by the main synapse. | |
74 | ||
75 | In this case you should configure your load balancer to route the /sync endpoint | |
76 | to the synchrotron instance(s). | |
77 | ||
78 | Finally, to actually run your worker-based synapse, you must pass synctl the -a | |
79 | commandline option to tell it to operate on all the worker configurations found | |
80 | in the given directory, e.g.:: | |
81 | ||
82 | synctl -a $CONFIG/workers start | |
83 | ||
84 | Currently you should always restart all workers when restarting or upgrading | |
85 | synapse, unless you know it's safe not to. For instance, restarting | |
86 | synapse without restarting all the synchrotrons may result in broken typing | |
87 | notifications. | |
88 | ||
89 | To manipulate a specific worker, you pass the -w option to synctl:: | |
90 | ||
91 | synctl -w $CONFIG/workers/synchrotron.yaml restart | |
92 | ||
93 | All of the above is highly experimental and subject to change as Synapse evolves, | |
94 | but we document it here to help folks who need highly scalable Synapses similar | |
95 | to the one running matrix.org! | |
96 |
13 | 13 | tox -e py27 --notest -v |
14 | 14 | |
15 | 15 | TOX_BIN=$TOX_DIR/py27/bin |
16 | $TOX_BIN/pip install setuptools | |
16 | 17 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install |
17 | 18 | $TOX_BIN/pip install lxml |
18 | 19 | $TOX_BIN/pip install psycopg2 |
24 | 24 | tox --notest -e py27 |
25 | 25 | TOX_BIN=$WORKSPACE/.tox/py27/bin |
26 | 26 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install |
27 | $TOX_BIN/pip install lxml | |
27 | 28 | |
28 | 29 | tox -e py27 |
15 | 15 | """ This is a reference implementation of a Matrix home server. |
16 | 16 | """ |
17 | 17 | |
18 | __version__ = "0.17.0" | |
18 | __version__ = "0.17.1" |
674 | 674 | try: |
675 | 675 | macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) |
676 | 676 | |
677 | user_prefix = "user_id = " | |
678 | user = None | |
679 | user_id = None | |
680 | guest = False | |
681 | for caveat in macaroon.caveats: | |
682 | if caveat.caveat_id.startswith(user_prefix): | |
683 | user_id = caveat.caveat_id[len(user_prefix):] | |
684 | user = UserID.from_string(user_id) | |
685 | elif caveat.caveat_id == "guest = true": | |
686 | guest = True | |
677 | user_id = self.get_user_id_from_macaroon(macaroon) | |
678 | user = UserID.from_string(user_id) | |
687 | 679 | |
688 | 680 | self.validate_macaroon( |
689 | 681 | macaroon, rights, self.hs.config.expire_access_token, |
690 | 682 | user_id=user_id, |
691 | 683 | ) |
692 | 684 | |
693 | if user is None: | |
694 | raise AuthError( | |
695 | self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", | |
696 | errcode=Codes.UNKNOWN_TOKEN | |
697 | ) | |
685 | guest = False | |
686 | for caveat in macaroon.caveats: | |
687 | if caveat.caveat_id == "guest = true": | |
688 | guest = True | |
698 | 689 | |
699 | 690 | if guest: |
700 | 691 | ret = { |
742 | 733 | errcode=Codes.UNKNOWN_TOKEN |
743 | 734 | ) |
744 | 735 | |
736 | def get_user_id_from_macaroon(self, macaroon): | |
737 | """Retrieve the user_id given by the caveats on the macaroon. | |
738 | ||
739 | Does *not* validate the macaroon. | |
740 | ||
741 | Args: | |
742 | macaroon (pymacaroons.Macaroon): The macaroon to extract the user_id from | |
743 | ||
744 | Returns: | |
745 | (str) user id | |
746 | ||
747 | Raises: | |
748 | AuthError if there is no user_id caveat in the macaroon | |
749 | """ | |
750 | user_prefix = "user_id = " | |
751 | for caveat in macaroon.caveats: | |
752 | if caveat.caveat_id.startswith(user_prefix): | |
753 | return caveat.caveat_id[len(user_prefix):] | |
754 | raise AuthError( | |
755 | self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", | |
756 | errcode=Codes.UNKNOWN_TOKEN | |
757 | ) | |
758 | ||
745 | 759 | def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): |
746 | 760 | """ |
747 | 761 | validate that a Macaroon is understood by and was signed by this server. |
753 | 767 | verify_expiry(bool): Whether to verify whether the macaroon has expired. |
754 | 768 | This should really always be True, but no clients currently implement |
755 | 769 | token refresh, so we can't enforce expiry yet. |
770 | user_id (str): The user_id required | |
756 | 771 | """ |
757 | 772 | v = pymacaroons.Verifier() |
758 | 773 | v.satisfy_exact("gen = 1") |
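The refactored `get_user_id_from_macaroon` scans the macaroon's first-party caveats for one of the form `user_id = @user:server`. The scan itself is plain string-prefix matching and can be illustrated without pymacaroons; the `Caveat` stand-in below is an assumption made for the sketch, mirroring only the `caveat_id` attribute the real objects expose:

```python
from collections import namedtuple

# Stand-in for pymacaroons' caveat objects: only caveat_id matters here.
Caveat = namedtuple("Caveat", ["caveat_id"])

USER_PREFIX = "user_id = "

def get_user_id_from_caveats(caveats):
    """Return the user_id asserted by the caveats, or raise if absent.
    Mirrors the prefix scan in get_user_id_from_macaroon; it does *not*
    verify the macaroon's signature."""
    for caveat in caveats:
        if caveat.caveat_id.startswith(USER_PREFIX):
            return caveat.caveat_id[len(USER_PREFIX):]
    raise ValueError("No user caveat in macaroon")

caveats = [Caveat("gen = 1"), Caveat("user_id = @alice:example.com")]
user_id = get_user_id_from_caveats(caveats)  # "@alice:example.com"
```

Note the order of operations in the refactored `get_user_granted_by_access_token` path: the user_id is extracted first, then `validate_macaroon` checks the signature, so a missing user caveat fails fast before any cryptographic work.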
0 | #!/usr/bin/env python | |
1 | # -*- coding: utf-8 -*- | |
2 | # Copyright 2016 OpenMarket Ltd | |
3 | # | |
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 | # you may not use this file except in compliance with the License. | |
6 | # You may obtain a copy of the License at | |
7 | # | |
8 | # http://www.apache.org/licenses/LICENSE-2.0 | |
9 | # | |
10 | # Unless required by applicable law or agreed to in writing, software | |
11 | # distributed under the License is distributed on an "AS IS" BASIS, | |
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | # See the License for the specific language governing permissions and | |
14 | # limitations under the License. | |
15 | ||
16 | import synapse | |
17 | ||
18 | from synapse.server import HomeServer | |
19 | from synapse.config._base import ConfigError | |
20 | from synapse.config.logger import setup_logging | |
21 | from synapse.config.homeserver import HomeServerConfig | |
22 | from synapse.http.site import SynapseSite | |
23 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | |
24 | from synapse.replication.slave.storage.directory import DirectoryStore | |
25 | from synapse.replication.slave.storage.events import SlavedEventStore | |
26 | from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore | |
27 | from synapse.replication.slave.storage.registration import SlavedRegistrationStore | |
28 | from synapse.storage.engines import create_engine | |
29 | from synapse.util.async import sleep | |
30 | from synapse.util.httpresourcetree import create_resource_tree | |
31 | from synapse.util.logcontext import LoggingContext | |
32 | from synapse.util.manhole import manhole | |
33 | from synapse.util.rlimit import change_resource_limit | |
34 | from synapse.util.versionstring import get_version_string | |
35 | ||
36 | from twisted.internet import reactor, defer | |
37 | from twisted.web.resource import Resource | |
38 | ||
39 | from daemonize import Daemonize | |
40 | ||
41 | import sys | |
42 | import logging | |
43 | import gc | |
44 | ||
45 | logger = logging.getLogger("synapse.app.appservice") | |
46 | ||
47 | ||
48 | class AppserviceSlaveStore( | |
49 | DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, | |
50 | SlavedRegistrationStore, | |
51 | ): | |
52 | pass | |
53 | ||
54 | ||
55 | class AppserviceServer(HomeServer): | |
56 | def get_db_conn(self, run_new_connection=True): | |
57 | # Any param beginning with cp_ is a parameter for adbapi, and should | |
58 | # not be passed to the database engine. | |
59 | db_params = { | |
60 | k: v for k, v in self.db_config.get("args", {}).items() | |
61 | if not k.startswith("cp_") | |
62 | } | |
63 | db_conn = self.database_engine.module.connect(**db_params) | |
64 | ||
65 | if run_new_connection: | |
66 | self.database_engine.on_new_connection(db_conn) | |
67 | return db_conn | |
68 | ||
69 | def setup(self): | |
70 | logger.info("Setting up.") | |
71 | self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) | |
72 | logger.info("Finished setting up.") | |
73 | ||
74 | def _listen_http(self, listener_config): | |
75 | port = listener_config["port"] | |
76 | bind_address = listener_config.get("bind_address", "") | |
77 | site_tag = listener_config.get("tag", port) | |
78 | resources = {} | |
79 | for res in listener_config["resources"]: | |
80 | for name in res["names"]: | |
81 | if name == "metrics": | |
82 | resources[METRICS_PREFIX] = MetricsResource(self) | |
83 | ||
84 | root_resource = create_resource_tree(resources, Resource()) | |
85 | reactor.listenTCP( | |
86 | port, | |
87 | SynapseSite( | |
88 | "synapse.access.http.%s" % (site_tag,), | |
89 | site_tag, | |
90 | listener_config, | |
91 | root_resource, | |
92 | ), | |
93 | interface=bind_address | |
94 | ) | |
95 | logger.info("Synapse appservice now listening on port %d", port) | |
96 | ||
97 | def start_listening(self, listeners): | |
98 | for listener in listeners: | |
99 | if listener["type"] == "http": | |
100 | self._listen_http(listener) | |
101 | elif listener["type"] == "manhole": | |
102 | reactor.listenTCP( | |
103 | listener["port"], | |
104 | manhole( | |
105 | username="matrix", | |
106 | password="rabbithole", | |
107 | globals={"hs": self}, | |
108 | ), | |
109 | interface=listener.get("bind_address", '127.0.0.1') | |
110 | ) | |
111 | else: | |
112 | logger.warn("Unrecognized listener type: %s", listener["type"]) | |
113 | ||
114 | @defer.inlineCallbacks | |
115 | def replicate(self): | |
116 | http_client = self.get_simple_http_client() | |
117 | store = self.get_datastore() | |
118 | replication_url = self.config.worker_replication_url | |
119 | appservice_handler = self.get_application_service_handler() | |
120 | ||
121 | @defer.inlineCallbacks | |
122 | def replicate(results): | |
123 | stream = results.get("events") | |
124 | if stream: | |
125 | max_stream_id = stream["position"] | |
126 | yield appservice_handler.notify_interested_services(max_stream_id) | |
127 | ||
128 | while True: | |
129 | try: | |
130 | args = store.stream_positions() | |
131 | args["timeout"] = 30000 | |
132 | result = yield http_client.get_json(replication_url, args=args) | |
133 | yield store.process_replication(result) | |
134 | replicate(result) | |
135 | except: | |
136 | logger.exception("Error replicating from %r", replication_url) | |
137 | yield sleep(30) | |
138 | ||
139 | ||
140 | def start(config_options): | |
141 | try: | |
142 | config = HomeServerConfig.load_config( | |
143 | "Synapse appservice", config_options | |
144 | ) | |
145 | except ConfigError as e: | |
146 | sys.stderr.write("\n" + e.message + "\n") | |
147 | sys.exit(1) | |
148 | ||
149 | assert config.worker_app == "synapse.app.appservice" | |
150 | ||
151 | setup_logging(config.worker_log_config, config.worker_log_file) | |
152 | ||
153 | database_engine = create_engine(config.database_config) | |
154 | ||
155 | if config.notify_appservices: | |
156 | sys.stderr.write( | |
157 | "\nThe appservices must be disabled in the main synapse process" | |
158 | "\nbefore they can be run in a separate worker." | |
159 | "\nPlease add ``notify_appservices: false`` to the main config" | |
160 | "\n" | |
161 | ) | |
162 | sys.exit(1) | |
163 | ||
164 | # Force the appservices to start since they will be disabled in the main config | |
165 | config.notify_appservices = True | |
166 | ||
167 | ps = AppserviceServer( | |
168 | config.server_name, | |
169 | db_config=config.database_config, | |
170 | config=config, | |
171 | version_string="Synapse/" + get_version_string(synapse), | |
172 | database_engine=database_engine, | |
173 | ) | |
174 | ||
175 | ps.setup() | |
176 | ps.start_listening(config.worker_listeners) | |
177 | ||
178 | def run(): | |
179 | with LoggingContext("run"): | |
180 | logger.info("Running") | |
181 | change_resource_limit(config.soft_file_limit) | |
182 | if config.gc_thresholds: | |
183 | gc.set_threshold(*config.gc_thresholds) | |
184 | reactor.run() | |
185 | ||
186 | def start(): | |
187 | ps.replicate() | |
188 | ps.get_datastore().start_profiling() | |
189 | ||
190 | reactor.callWhenRunning(start) | |
191 | ||
192 | if config.worker_daemonize: | |
193 | daemon = Daemonize( | |
194 | app="synapse-appservice", | |
195 | pid=config.worker_pid_file, | |
196 | action=run, | |
197 | auto_close_fds=False, | |
198 | verbose=True, | |
199 | logger=logger, | |
200 | ) | |
201 | daemon.start() | |
202 | else: | |
203 | run() | |
204 | ||
205 | ||
206 | if __name__ == '__main__': | |
207 | with LoggingContext("main"): | |
208 | start(sys.argv[1:]) |
50 | 50 | from synapse.config.homeserver import HomeServerConfig |
51 | 51 | from synapse.crypto import context_factory |
52 | 52 | from synapse.util.logcontext import LoggingContext |
53 | from synapse.metrics import register_memory_metrics | |
53 | from synapse.metrics import register_memory_metrics, get_metrics_for | |
54 | 54 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX |
55 | 55 | from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX |
56 | 56 | from synapse.federation.transport.server import TransportLayerServer |
384 | 384 | |
385 | 385 | start_time = hs.get_clock().time() |
386 | 386 | |
387 | stats = {} | |
388 | ||
387 | 389 | @defer.inlineCallbacks |
388 | 390 | def phone_stats_home(): |
389 | 391 | logger.info("Gathering stats for reporting") |
392 | 394 | if uptime < 0: |
393 | 395 | uptime = 0 |
394 | 396 | |
395 | stats = {} | |
397 | # If the stats directory is empty then this is the first time we've | |
398 | # reported stats. | |
399 | first_time = not stats | |
400 | ||
396 | 401 | stats["homeserver"] = hs.config.server_name |
397 | 402 | stats["timestamp"] = now |
398 | 403 | stats["uptime_seconds"] = uptime |
405 | 410 | daily_messages = yield hs.get_datastore().count_daily_messages() |
406 | 411 | if daily_messages is not None: |
407 | 412 | stats["daily_messages"] = daily_messages |
413 | else: | |
414 | stats.pop("daily_messages", None) | |
415 | ||
416 | if first_time: | |
417 | # Add callbacks to report the synapse stats as metrics whenever | |
418 | # prometheus requests them, typically every 30s. | |
419 | # As some of the stats are expensive to calculate we only update | |
420 | # them when synapse phones home to matrix.org every 24 hours. | |
421 | metrics = get_metrics_for("synapse.usage") | |
422 | metrics.add_callback("timestamp", lambda: stats["timestamp"]) | |
423 | metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"]) | |
424 | metrics.add_callback("total_users", lambda: stats["total_users"]) | |
425 | metrics.add_callback("total_room_count", lambda: stats["total_room_count"]) | |
426 | metrics.add_callback( | |
427 | "daily_active_users", lambda: stats["daily_active_users"] | |
428 | ) | |
429 | metrics.add_callback( | |
430 | "daily_messages", lambda: stats.get("daily_messages", 0) | |
431 | ) | |
408 | 432 | |
409 | 433 | logger.info("Reporting stats to matrix.org: %s" % (stats,)) |
410 | 434 | try: |
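The point of hoisting `stats = {}` out of `phone_stats_home` in the hunk above is that the metric callbacks close over one shared dict, so every Prometheus scrape reads the values from the most recent reporting cycle. The pattern in miniature (the `add_callback`/`scrape` registry here is a stand-in, not Synapse's metrics API):

```python
# Shared dict: written once per reporting cycle, read on every scrape.
stats = {}
callbacks = {}  # stand-in for the metrics callback registry

def add_callback(name, fn):
    callbacks[name] = fn

def scrape():
    """Evaluate every registered callback, as a Prometheus pull would."""
    return {name: fn() for name, fn in callbacks.items()}

# Registered once, on the first reporting cycle. The lambda captures the
# dict object itself, not a snapshot of its contents.
add_callback("uptime_seconds", lambda: stats["uptime_seconds"])

stats["uptime_seconds"] = 60
first = scrape()            # {"uptime_seconds": 60}
stats["uptime_seconds"] = 120
second = scrape()           # {"uptime_seconds": 120}
```

Had `stats` stayed local to `phone_stats_home`, each cycle would rebind a fresh dict while the callbacks kept reading the one they captured at registration time, so scrapes would serve stale data.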
0 | #!/usr/bin/env python | |
1 | # -*- coding: utf-8 -*- | |
2 | # Copyright 2016 OpenMarket Ltd | |
3 | # | |
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 | # you may not use this file except in compliance with the License. | |
6 | # You may obtain a copy of the License at | |
7 | # | |
8 | # http://www.apache.org/licenses/LICENSE-2.0 | |
9 | # | |
10 | # Unless required by applicable law or agreed to in writing, software | |
11 | # distributed under the License is distributed on an "AS IS" BASIS, | |
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | # See the License for the specific language governing permissions and | |
14 | # limitations under the License. | |
15 | ||
16 | import synapse | |
17 | ||
18 | from synapse.config._base import ConfigError | |
19 | from synapse.config.homeserver import HomeServerConfig | |
20 | from synapse.config.logger import setup_logging | |
21 | from synapse.http.site import SynapseSite | |
22 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | |
23 | from synapse.replication.slave.storage._base import BaseSlavedStore | |
24 | from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore | |
25 | from synapse.replication.slave.storage.registration import SlavedRegistrationStore | |
26 | from synapse.rest.media.v0.content_repository import ContentRepoResource | |
27 | from synapse.rest.media.v1.media_repository import MediaRepositoryResource | |
28 | from synapse.server import HomeServer | |
29 | from synapse.storage.client_ips import ClientIpStore | |
30 | from synapse.storage.engines import create_engine | |
31 | from synapse.storage.media_repository import MediaRepositoryStore | |
32 | from synapse.util.async import sleep | |
33 | from synapse.util.httpresourcetree import create_resource_tree | |
34 | from synapse.util.logcontext import LoggingContext | |
35 | from synapse.util.manhole import manhole | |
36 | from synapse.util.rlimit import change_resource_limit | |
37 | from synapse.util.versionstring import get_version_string | |
38 | from synapse.api.urls import ( | |
39 | CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX | |
40 | ) | |
41 | from synapse.crypto import context_factory | |
42 | ||
43 | ||
44 | from twisted.internet import reactor, defer | |
45 | from twisted.web.resource import Resource | |
46 | ||
47 | from daemonize import Daemonize | |
48 | ||
49 | import sys | |
50 | import logging | |
51 | import gc | |
52 | ||
53 | logger = logging.getLogger("synapse.app.media_repository") | |
54 | ||
55 | ||
56 | class MediaRepositorySlavedStore( | |
57 | SlavedApplicationServiceStore, | |
58 | SlavedRegistrationStore, | |
59 | BaseSlavedStore, | |
60 | MediaRepositoryStore, | |
61 | ClientIpStore, | |
62 | ): | |
63 | pass | |
64 | ||
65 | ||
66 | class MediaRepositoryServer(HomeServer): | |
67 | def get_db_conn(self, run_new_connection=True): | |
68 | # Any param beginning with cp_ is a parameter for adbapi, and should | |
69 | # not be passed to the database engine. | |
70 | db_params = { | |
71 | k: v for k, v in self.db_config.get("args", {}).items() | |
72 | if not k.startswith("cp_") | |
73 | } | |
74 | db_conn = self.database_engine.module.connect(**db_params) | |
75 | ||
76 | if run_new_connection: | |
77 | self.database_engine.on_new_connection(db_conn) | |
78 | return db_conn | |
79 | ||
80 | def setup(self): | |
81 | logger.info("Setting up.") | |
82 | self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) | |
83 | logger.info("Finished setting up.") | |
84 | ||
85 | def _listen_http(self, listener_config): | |
86 | port = listener_config["port"] | |
87 | bind_address = listener_config.get("bind_address", "") | |
88 | site_tag = listener_config.get("tag", port) | |
89 | resources = {} | |
90 | for res in listener_config["resources"]: | |
91 | for name in res["names"]: | |
92 | if name == "metrics": | |
93 | resources[METRICS_PREFIX] = MetricsResource(self) | |
94 | elif name == "media": | |
95 | media_repo = MediaRepositoryResource(self) | |
96 | resources.update({ | |
97 | MEDIA_PREFIX: media_repo, | |
98 | LEGACY_MEDIA_PREFIX: media_repo, | |
99 | CONTENT_REPO_PREFIX: ContentRepoResource( | |
100 | self, self.config.uploads_path | |
101 | ), | |
102 | }) | |
103 | ||
104 | root_resource = create_resource_tree(resources, Resource()) | |
105 | reactor.listenTCP( | |
106 | port, | |
107 | SynapseSite( | |
108 | "synapse.access.http.%s" % (site_tag,), | |
109 | site_tag, | |
110 | listener_config, | |
111 | root_resource, | |
112 | ), | |
113 | interface=bind_address | |
114 | ) | |
115 | logger.info("Synapse media repository now listening on port %d", port) | |
116 | ||
117 | def start_listening(self, listeners): | |
118 | for listener in listeners: | |
119 | if listener["type"] == "http": | |
120 | self._listen_http(listener) | |
121 | elif listener["type"] == "manhole": | |
122 | reactor.listenTCP( | |
123 | listener["port"], | |
124 | manhole( | |
125 | username="matrix", | |
126 | password="rabbithole", | |
127 | globals={"hs": self}, | |
128 | ), | |
129 | interface=listener.get("bind_address", '127.0.0.1') | |
130 | ) | |
131 | else: | |
132 | logger.warn("Unrecognized listener type: %s", listener["type"]) | |
133 | ||
134 | @defer.inlineCallbacks | |
135 | def replicate(self): | |
136 | http_client = self.get_simple_http_client() | |
137 | store = self.get_datastore() | |
138 | replication_url = self.config.worker_replication_url | |
139 | ||
140 | while True: | |
141 | try: | |
142 | args = store.stream_positions() | |
143 | args["timeout"] = 30000 | |
144 | result = yield http_client.get_json(replication_url, args=args) | |
145 | yield store.process_replication(result) | |
146 | except: | |
147 | logger.exception("Error replicating from %r", replication_url) | |
148 | yield sleep(5) | |
149 | ||
150 | ||
151 | def start(config_options): | |
152 | try: | |
153 | config = HomeServerConfig.load_config( | |
154 | "Synapse media repository", config_options | |
155 | ) | |
156 | except ConfigError as e: | |
157 | sys.stderr.write("\n" + e.message + "\n") | |
158 | sys.exit(1) | |
159 | ||
160 | assert config.worker_app == "synapse.app.media_repository" | |
161 | ||
162 | setup_logging(config.worker_log_config, config.worker_log_file) | |
163 | ||
164 | database_engine = create_engine(config.database_config) | |
165 | ||
166 | tls_server_context_factory = context_factory.ServerContextFactory(config) | |
167 | ||
168 | ss = MediaRepositoryServer( | |
169 | config.server_name, | |
170 | db_config=config.database_config, | |
171 | tls_server_context_factory=tls_server_context_factory, | |
172 | config=config, | |
173 | version_string="Synapse/" + get_version_string(synapse), | |
174 | database_engine=database_engine, | |
175 | ) | |
176 | ||
177 | ss.setup() | |
178 | ss.get_handlers() | |
179 | ss.start_listening(config.worker_listeners) | |
180 | ||
181 | def run(): | |
182 | with LoggingContext("run"): | |
183 | logger.info("Running") | |
184 | change_resource_limit(config.soft_file_limit) | |
185 | if config.gc_thresholds: | |
186 | gc.set_threshold(*config.gc_thresholds) | |
187 | reactor.run() | |
188 | ||
189 | def start(): | |
190 | ss.get_datastore().start_profiling() | |
191 | ss.replicate() | |
192 | ||
193 | reactor.callWhenRunning(start) | |
194 | ||
195 | if config.worker_daemonize: | |
196 | daemon = Daemonize( | |
197 | app="synapse-media-repository", | |
198 | pid=config.worker_pid_file, | |
199 | action=run, | |
200 | auto_close_fds=False, | |
201 | verbose=True, | |
202 | logger=logger, | |
203 | ) | |
204 | daemon.start() | |
205 | else: | |
206 | run() | |
207 | ||
208 | ||
209 | if __name__ == '__main__': | |
210 | with LoggingContext("main"): | |
211 | start(sys.argv[1:]) |
79 | 79 | DataStore.get_profile_displayname.__func__ |
80 | 80 | ) |
81 | 81 | |
82 | # XXX: This is a bit broken because we don't persist forgotten rooms | |
83 | # in a way that they can be streamed. This means that we don't have a | |
84 | # way to invalidate the forgotten rooms cache correctly. | |
85 | # For now we expire the cache every 10 minutes. | |
86 | BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 | |
87 | 82 | who_forgot_in_room = ( |
88 | 83 | RoomMemberStore.__dict__["who_forgot_in_room"] |
89 | 84 | ) |
167 | 162 | store = self.get_datastore() |
168 | 163 | replication_url = self.config.worker_replication_url |
169 | 164 | pusher_pool = self.get_pusherpool() |
170 | clock = self.get_clock() | |
171 | 165 | |
172 | 166 | def stop_pusher(user_id, app_id, pushkey): |
173 | 167 | key = "%s:%s" % (app_id, pushkey) |
219 | 213 | min_stream_id, max_stream_id, affected_room_ids |
220 | 214 | ) |
221 | 215 | |
222 | def expire_broken_caches(): | |
223 | store.who_forgot_in_room.invalidate_all() | |
224 | ||
225 | next_expire_broken_caches_ms = 0 | |
226 | 216 | while True: |
227 | 217 | try: |
228 | 218 | args = store.stream_positions() |
229 | 219 | args["timeout"] = 30000 |
230 | 220 | result = yield http_client.get_json(replication_url, args=args) |
231 | now_ms = clock.time_msec() | |
232 | if now_ms > next_expire_broken_caches_ms: | |
233 | expire_broken_caches() | |
234 | next_expire_broken_caches_ms = ( | |
235 | now_ms + store.BROKEN_CACHE_EXPIRY_MS | |
236 | ) | |
237 | 221 | yield store.process_replication(result) |
238 | 222 | poke_pushers(result) |
239 | 223 | except: |
25 | 25 | from synapse.http.server import JsonResource |
26 | 26 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX |
27 | 27 | from synapse.rest.client.v2_alpha import sync |
28 | from synapse.rest.client.v1 import events | |
28 | 29 | from synapse.replication.slave.storage._base import BaseSlavedStore |
29 | 30 | from synapse.replication.slave.storage.events import SlavedEventStore |
30 | 31 | from synapse.replication.slave.storage.receipts import SlavedReceiptsStore |
73 | 74 | BaseSlavedStore, |
74 | 75 | ClientIpStore, # After BaseSlavedStore because the constructor is different |
75 | 76 | ): |
76 | # XXX: This is a bit broken because we don't persist forgotten rooms | |
77 | # in a way that they can be streamed. This means that we don't have a | |
78 | # way to invalidate the forgotten rooms cache correctly. | |
79 | # For now we expire the cache every 10 minutes. | |
80 | BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000 | |
81 | 77 | who_forgot_in_room = ( |
82 | 78 | RoomMemberStore.__dict__["who_forgot_in_room"] |
83 | 79 | ) |
88 | 84 | get_presence_list_accepted = PresenceStore.__dict__[ |
89 | 85 | "get_presence_list_accepted" |
90 | 86 | ] |
87 | get_presence_list_observers_accepted = PresenceStore.__dict__[ | |
88 | "get_presence_list_observers_accepted" | |
89 | ] | |
90 | ||
91 | 91 | |
92 | 92 | UPDATE_SYNCING_USERS_MS = 10 * 1000 |
93 | 93 | |
94 | 94 | |
95 | 95 | class SynchrotronPresence(object): |
96 | 96 | def __init__(self, hs): |
97 | self.is_mine_id = hs.is_mine_id | |
97 | 98 | self.http_client = hs.get_simple_http_client() |
98 | 99 | self.store = hs.get_datastore() |
99 | 100 | self.user_to_num_current_syncs = {} |
100 | 101 | self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" |
101 | 102 | self.clock = hs.get_clock() |
103 | self.notifier = hs.get_notifier() | |
102 | 104 | |
103 | 105 | active_presence = self.store.take_presence_startup_info() |
104 | 106 | self.user_to_current_state = { |
118 | 120 | |
119 | 121 | reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) |
120 | 122 | |
121 | def set_state(self, user, state): | |
123 | def set_state(self, user, state, ignore_status_msg=False): | |
122 | 124 | # TODO How's this supposed to work?
123 | 125 | pass |
124 | 126 | |
125 | 127 | get_states = PresenceHandler.get_states.__func__ |
128 | get_state = PresenceHandler.get_state.__func__ | |
129 | _get_interested_parties = PresenceHandler._get_interested_parties.__func__ | |
126 | 130 | current_state_for_users = PresenceHandler.current_state_for_users.__func__ |
127 | 131 | |
128 | 132 | @defer.inlineCallbacks |
193 | 197 | self._need_to_send_sync = False |
194 | 198 | yield self._send_syncing_users_now() |
195 | 199 | |
200 | @defer.inlineCallbacks | |
201 | def notify_from_replication(self, states, stream_id): | |
202 | parties = yield self._get_interested_parties( | |
203 | states, calculate_remote_hosts=False | |
204 | ) | |
205 | room_ids_to_states, users_to_states, _ = parties | |
206 | ||
207 | self.notifier.on_new_event( | |
208 | "presence_key", stream_id, rooms=room_ids_to_states.keys(), | |
209 | users=users_to_states.keys() | |
210 | ) | |
211 | ||
212 | @defer.inlineCallbacks | |
196 | 213 | def process_replication(self, result): |
197 | 214 | stream = result.get("presence", {"rows": []}) |
215 | states = [] | |
198 | 216 | for row in stream["rows"]: |
199 | 217 | ( |
200 | 218 | position, user_id, state, last_active_ts, |
201 | 219 | last_federation_update_ts, last_user_sync_ts, status_msg, |
202 | 220 | currently_active |
203 | 221 | ) = row |
204 | self.user_to_current_state[user_id] = UserPresenceState( | |
222 | state = UserPresenceState( | |
205 | 223 | user_id, state, last_active_ts, |
206 | 224 | last_federation_update_ts, last_user_sync_ts, status_msg, |
207 | 225 | currently_active |
208 | 226 | ) |
227 | self.user_to_current_state[user_id] = state | |
228 | states.append(state) | |
229 | ||
230 | if states and "position" in stream: | |
231 | stream_id = int(stream["position"]) | |
232 | yield self.notify_from_replication(states, stream_id) | |
209 | 233 | |
210 | 234 | |
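The presence hunk above changes `process_replication` to collect every row into a batch and wake the notifier once per stream position, instead of mutating state row by row. A plain-Python sketch of that batching shape (Twisted and the real `UserPresenceState` omitted; the `notify` callable is a stand-in for `notify_from_replication`):

```python
from collections import namedtuple

# Stand-in for synapse's richer UserPresenceState (an assumption).
UserPresenceState = namedtuple("UserPresenceState", ["user_id", "state"])


class PresenceReplicator(object):
    def __init__(self, notify):
        self.user_to_current_state = {}
        self.notify = notify  # called once per batch: notify(states, stream_id)

    def process_replication(self, result):
        stream = result.get("presence", {"rows": []})
        states = []
        for position, user_id, state in stream["rows"]:
            current = UserPresenceState(user_id, state)
            self.user_to_current_state[user_id] = current
            states.append(current)
        # One wake-up for the whole batch, keyed by the stream position.
        if states and "position" in stream:
            self.notify(states, int(stream["position"]))
```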
211 | 235 | class SynchrotronTyping(object): |
265 | 289 | elif name == "client": |
266 | 290 | resource = JsonResource(self, canonical_json=False) |
267 | 291 | sync.register_servlets(self, resource) |
292 | events.register_servlets(self, resource) | |
268 | 293 | resources.update({ |
269 | 294 | "/_matrix/client/r0": resource, |
270 | 295 | "/_matrix/client/unstable": resource, |
271 | 296 | "/_matrix/client/v2_alpha": resource, |
297 | "/_matrix/client/api/v1": resource, | |
272 | 298 | }) |
273 | 299 | |
274 | 300 | root_resource = create_resource_tree(resources, Resource()) |
306 | 332 | http_client = self.get_simple_http_client() |
307 | 333 | store = self.get_datastore() |
308 | 334 | replication_url = self.config.worker_replication_url |
309 | clock = self.get_clock() | |
310 | 335 | notifier = self.get_notifier() |
311 | 336 | presence_handler = self.get_presence_handler() |
312 | 337 | typing_handler = self.get_typing_handler() |
313 | ||
314 | def expire_broken_caches(): | |
315 | store.who_forgot_in_room.invalidate_all() | |
316 | store.get_presence_list_accepted.invalidate_all() | |
317 | 338 | |
318 | 339 | def notify_from_stream( |
319 | 340 | result, stream_name, stream_key, room=None, user=None |
376 | 397 | result, "typing", "typing_key", room="room_id" |
377 | 398 | ) |
378 | 399 | |
379 | next_expire_broken_caches_ms = 0 | |
380 | 400 | while True: |
381 | 401 | try: |
382 | 402 | args = store.stream_positions() |
383 | 403 | args.update(typing_handler.stream_positions()) |
384 | 404 | args["timeout"] = 30000 |
385 | 405 | result = yield http_client.get_json(replication_url, args=args) |
386 | now_ms = clock.time_msec() | |
387 | if now_ms > next_expire_broken_caches_ms: | |
388 | expire_broken_caches() | |
389 | next_expire_broken_caches_ms = ( | |
390 | now_ms + store.BROKEN_CACHE_EXPIRY_MS | |
391 | ) | |
392 | 406 | yield store.process_replication(result) |
393 | 407 | typing_handler.process_replication(result) |
394 | presence_handler.process_replication(result) | |
408 | yield presence_handler.process_replication(result) | |
395 | 409 | notify(result) |
396 | 410 | except: |
397 | 411 | logger.exception("Error replicating from %r", replication_url) |
13 | 13 | # limitations under the License. |
14 | 14 | from synapse.api.constants import EventTypes |
15 | 15 | |
16 | from twisted.internet import defer | |
17 | ||
16 | 18 | import logging |
17 | 19 | import re |
18 | 20 | |
78 | 80 | NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] |
79 | 81 | |
80 | 82 | def __init__(self, token, url=None, namespaces=None, hs_token=None, |
81 | sender=None, id=None): | |
83 | sender=None, id=None, protocols=None): | |
82 | 84 | self.token = token |
83 | 85 | self.url = url |
84 | 86 | self.hs_token = hs_token |
85 | 87 | self.sender = sender |
86 | 88 | self.namespaces = self._check_namespaces(namespaces) |
87 | 89 | self.id = id |
90 | if protocols: | |
91 | self.protocols = set(protocols) | |
92 | else: | |
93 | self.protocols = set() | |
88 | 94 | |
89 | 95 | def _check_namespaces(self, namespaces): |
90 | 96 | # Sanity check that it is of the form: |
137 | 143 | return regex_obj["exclusive"] |
138 | 144 | return False |
139 | 145 | |
140 | def _matches_user(self, event, member_list): | |
141 | if (hasattr(event, "sender") and | |
142 | self.is_interested_in_user(event.sender)): | |
143 | return True | |
146 | @defer.inlineCallbacks | |
147 | def _matches_user(self, event, store): | |
148 | if not event: | |
149 | defer.returnValue(False) | |
150 | ||
151 | if self.is_interested_in_user(event.sender): | |
152 | defer.returnValue(True) | |
144 | 153 | # also check m.room.member state key |
145 | if (hasattr(event, "type") and event.type == EventTypes.Member | |
146 | and hasattr(event, "state_key") | |
147 | and self.is_interested_in_user(event.state_key)): | |
148 | return True | |
154 | if (event.type == EventTypes.Member and | |
155 | self.is_interested_in_user(event.state_key)): | |
156 | defer.returnValue(True) | |
157 | ||
158 | if not store: | |
159 | defer.returnValue(False) | |
160 | ||
161 | member_list = yield store.get_users_in_room(event.room_id) | |
162 | ||
149 | 163 | # check joined member events |
150 | 164 | for user_id in member_list: |
151 | 165 | if self.is_interested_in_user(user_id): |
152 | return True | |
153 | return False | |
166 | defer.returnValue(True) | |
167 | defer.returnValue(False) | |
154 | 168 | |
155 | 169 | def _matches_room_id(self, event): |
156 | 170 | if hasattr(event, "room_id"): |
157 | 171 | return self.is_interested_in_room(event.room_id) |
158 | 172 | return False |
159 | 173 | |
160 | def _matches_aliases(self, event, alias_list): | |
174 | @defer.inlineCallbacks | |
175 | def _matches_aliases(self, event, store): | |
176 | if not store or not event: | |
177 | defer.returnValue(False) | |
178 | ||
179 | alias_list = yield store.get_aliases_for_room(event.room_id) | |
161 | 180 | for alias in alias_list: |
162 | 181 | if self.is_interested_in_alias(alias): |
163 | return True | |
164 | return False | |
165 | ||
166 | def is_interested(self, event, restrict_to=None, aliases_for_event=None, | |
167 | member_list=None): | |
182 | defer.returnValue(True) | |
183 | defer.returnValue(False) | |
184 | ||
185 | @defer.inlineCallbacks | |
186 | def is_interested(self, event, store=None): | |
168 | 187 | """Check if this service is interested in this event. |
169 | 188 | |
170 | 189 | Args: |
171 | 190 | event(Event): The event to check. |
172 | restrict_to(str): The namespace to restrict regex tests to. | |
173 | aliases_for_event(list): A list of all the known room aliases for | |
174 | this event. | |
175 | member_list(list): A list of all joined user_ids in this room. | |
191 | store(DataStore): The datastore to query for room membership and aliases. | 
176 | 192 | Returns: |
177 | 193 | bool: True if this service would like to know about this event. |
178 | 194 | """ |
179 | if aliases_for_event is None: | |
180 | aliases_for_event = [] | |
181 | if member_list is None: | |
182 | member_list = [] | |
183 | ||
184 | if restrict_to and restrict_to not in ApplicationService.NS_LIST: | |
185 | # this is a programming error, so fail early and raise a general | |
186 | # exception | |
187 | raise Exception("Unexpected restrict_to value: %s". restrict_to) | |
188 | ||
189 | if not restrict_to: | |
190 | return (self._matches_user(event, member_list) | |
191 | or self._matches_aliases(event, aliases_for_event) | |
192 | or self._matches_room_id(event)) | |
193 | elif restrict_to == ApplicationService.NS_ALIASES: | |
194 | return self._matches_aliases(event, aliases_for_event) | |
195 | elif restrict_to == ApplicationService.NS_ROOMS: | |
196 | return self._matches_room_id(event) | |
197 | elif restrict_to == ApplicationService.NS_USERS: | |
198 | return self._matches_user(event, member_list) | |
195 | # Do cheap checks first | |
196 | if self._matches_room_id(event): | |
197 | defer.returnValue(True) | |
198 | ||
199 | if (yield self._matches_aliases(event, store)): | |
200 | defer.returnValue(True) | |
201 | ||
202 | if (yield self._matches_user(event, store)): | |
203 | defer.returnValue(True) | |
204 | ||
205 | defer.returnValue(False) | |
199 | 206 | |
200 | 207 | def is_interested_in_user(self, user_id): |
201 | 208 | return ( |
215 | 222 | or user_id == self.sender |
216 | 223 | ) |
217 | 224 | |
225 | def is_interested_in_protocol(self, protocol): | |
226 | return protocol in self.protocols | |
227 | ||
218 | 228 | def is_exclusive_alias(self, alias): |
219 | 229 | return self._is_exclusive(ApplicationService.NS_ALIASES, alias) |
220 | 230 |
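The rewritten `is_interested` above reorders its tests so the checks that need no storage access run first. A minimal synchronous sketch of that ordering (the `service` and `store` objects here are hypothetical stand-ins, and the Deferred plumbing is omitted):

```python
def is_interested(service, event, store=None):
    # Cheapest first: the room-id check needs no storage access.
    if service.is_interested_in_room(event["room_id"]):
        return True
    if store is None:
        return False
    # One storage hit for the alias list...
    for alias in store.get_aliases_for_room(event["room_id"]):
        if service.is_interested_in_alias(alias):
            return True
    # ...and the most expensive check, the full member list, last.
    for user_id in store.get_users_in_room(event["room_id"]):
        if service.is_interested_in_user(user_id):
            return True
    return False
```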
16 | 16 | from synapse.api.errors import CodeMessageException |
17 | 17 | from synapse.http.client import SimpleHttpClient |
18 | 18 | from synapse.events.utils import serialize_event |
19 | from synapse.types import ThirdPartyEntityKind | |
19 | 20 | |
20 | 21 | import logging |
21 | 22 | import urllib |
22 | 23 | |
23 | 24 | logger = logging.getLogger(__name__) |
25 | ||
26 | ||
27 | def _is_valid_3pe_result(r, field): | |
28 | if not isinstance(r, dict): | |
29 | return False | |
30 | ||
31 | for k in (field, "protocol"): | |
32 | if k not in r: | |
33 | return False | |
34 | if not isinstance(r[k], str): | |
35 | return False | |
36 | ||
37 | if "fields" not in r: | |
38 | return False | |
39 | fields = r["fields"] | |
40 | if not isinstance(fields, dict): | |
41 | return False | |
42 | for k in fields.keys(): | |
43 | if not isinstance(fields[k], str): | |
44 | return False | |
45 | ||
46 | return True | |
24 | 47 | |
25 | 48 | |
26 | 49 | class ApplicationServiceApi(SimpleHttpClient): |
71 | 94 | defer.returnValue(False) |
72 | 95 | |
73 | 96 | @defer.inlineCallbacks |
97 | def query_3pe(self, service, kind, protocol, fields): | |
98 | if kind == ThirdPartyEntityKind.USER: | |
99 | uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) | |
100 | required_field = "userid" | |
101 | elif kind == ThirdPartyEntityKind.LOCATION: | |
102 | uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) | |
103 | required_field = "alias" | |
104 | else: | |
105 | raise ValueError( | |
106 | "Unrecognised 'kind' argument %r to query_3pe()" % (kind,) | 
107 | ) | |
108 | ||
109 | try: | |
110 | response = yield self.get_json(uri, fields) | |
111 | if not isinstance(response, list): | |
112 | logger.warning( | |
113 | "query_3pe to %s returned an invalid response %r", | |
114 | uri, response | |
115 | ) | |
116 | defer.returnValue([]) | |
117 | ||
118 | ret = [] | |
119 | for r in response: | |
120 | if _is_valid_3pe_result(r, field=required_field): | |
121 | ret.append(r) | |
122 | else: | |
123 | logger.warning( | |
124 | "query_3pe to %s returned an invalid result %r", | |
125 | uri, r | |
126 | ) | |
127 | ||
128 | defer.returnValue(ret) | |
129 | except Exception as ex: | |
130 | logger.warning("query_3pe to %s threw exception %s", uri, ex) | |
131 | defer.returnValue([]) | |
132 | ||
133 | @defer.inlineCallbacks | |
74 | 134 | def push_bulk(self, service, events, txn_id=None): |
75 | 135 | events = self._serialize(events) |
76 | 136 |
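`_is_valid_3pe_result` and the filtering loop in `query_3pe` can be restated as one small pure function, which makes the validation rules easy to see (a sketch only; the real code also logs each rejected entry and works on the deserialized HTTP response):

```python
def filter_3pe_results(response, required_field):
    """Keep only well-formed third-party lookup results."""
    def valid(r):
        if not isinstance(r, dict):
            return False
        # Both the required entity field and "protocol" must be strings.
        for key in (required_field, "protocol"):
            if not isinstance(r.get(key), str):
                return False
        # "fields" must be a dict of string values.
        fields = r.get("fields")
        return isinstance(fields, dict) and all(
            isinstance(v, str) for v in fields.values()
        )

    if not isinstance(response, list):
        return []
    return [r for r in response if valid(r)]
```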
47 | 47 | This is all tied together by the AppServiceScheduler which DIs the required |
48 | 48 | components. |
49 | 49 | """ |
50 | from twisted.internet import defer | |
50 | 51 | |
51 | 52 | from synapse.appservice import ApplicationServiceState |
52 | from twisted.internet import defer | |
53 | from synapse.util.logcontext import preserve_fn | |
54 | from synapse.util.metrics import Measure | |
55 | ||
53 | 56 | import logging |
54 | 57 | |
55 | 58 | logger = logging.getLogger(__name__) |
72 | 75 | self.txn_ctrl = _TransactionController( |
73 | 76 | self.clock, self.store, self.as_api, create_recoverer |
74 | 77 | ) |
75 | self.queuer = _ServiceQueuer(self.txn_ctrl) | |
78 | self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) | |
76 | 79 | |
77 | 80 | @defer.inlineCallbacks |
78 | 81 | def start(self): |
93 | 96 | this schedules any other events in the queue to run. |
94 | 97 | """ |
95 | 98 | |
96 | def __init__(self, txn_ctrl): | |
99 | def __init__(self, txn_ctrl, clock): | |
97 | 100 | self.queued_events = {} # dict of {service_id: [events]} |
98 | self.pending_requests = {} # dict of {service_id: Deferred} | |
101 | self.requests_in_flight = set() | |
99 | 102 | self.txn_ctrl = txn_ctrl |
103 | self.clock = clock | |
100 | 104 | |
101 | 105 | def enqueue(self, service, event): |
102 | 106 | # if this service isn't being sent something |
103 | if not self.pending_requests.get(service.id): | |
104 | self._send_request(service, [event]) | |
105 | else: | |
106 | # add to queue for this service | |
107 | if service.id not in self.queued_events: | |
108 | self.queued_events[service.id] = [] | |
109 | self.queued_events[service.id].append(event) | |
110 | ||
111 | def _send_request(self, service, events): | |
112 | # send request and add callbacks | |
113 | d = self.txn_ctrl.send(service, events) | |
114 | d.addBoth(self._on_request_finish) | |
115 | d.addErrback(self._on_request_fail) | |
116 | self.pending_requests[service.id] = d | |
117 | ||
118 | def _on_request_finish(self, service): | |
119 | self.pending_requests[service.id] = None | |
120 | # if there are queued events, then send them. | |
121 | if (service.id in self.queued_events | |
122 | and len(self.queued_events[service.id]) > 0): | |
123 | self._send_request(service, self.queued_events[service.id]) | |
124 | self.queued_events[service.id] = [] | |
125 | ||
126 | def _on_request_fail(self, err): | |
127 | logger.error("AS request failed: %s", err) | |
107 | self.queued_events.setdefault(service.id, []).append(event) | |
108 | preserve_fn(self._send_request)(service) | |
109 | ||
110 | @defer.inlineCallbacks | |
111 | def _send_request(self, service): | |
112 | if service.id in self.requests_in_flight: | |
113 | return | |
114 | ||
115 | self.requests_in_flight.add(service.id) | |
116 | try: | |
117 | while True: | |
118 | events = self.queued_events.pop(service.id, []) | |
119 | if not events: | |
120 | return | |
121 | ||
122 | with Measure(self.clock, "servicequeuer.send"): | |
123 | try: | |
124 | yield self.txn_ctrl.send(service, events) | |
125 | except: | |
126 | logger.exception("AS request failed") | |
127 | finally: | |
128 | self.requests_in_flight.discard(service.id) | |
128 | 129 | |
129 | 130 | |
130 | 131 | class _TransactionController(object): |
148 | 149 | if service_is_up: |
149 | 150 | sent = yield txn.send(self.as_api) |
150 | 151 | if sent: |
151 | txn.complete(self.store) | |
152 | yield txn.complete(self.store) | |
152 | 153 | else: |
153 | self._start_recoverer(service) | |
154 | preserve_fn(self._start_recoverer)(service) | |
154 | 155 | except Exception as e: |
155 | 156 | logger.exception(e) |
156 | self._start_recoverer(service) | |
157 | # request has finished | |
158 | defer.returnValue(service) | |
157 | preserve_fn(self._start_recoverer)(service) | |
159 | 158 | |
160 | 159 | @defer.inlineCallbacks |
161 | 160 | def on_recovered(self, recoverer): |
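The new `_ServiceQueuer` replaces per-service callback chaining with a drain loop guarded by an in-flight set. Stripped of the Twisted `yield` machinery, the control flow looks roughly like this (`send_fn` is a synchronous stand-in for `txn_ctrl.send`):

```python
class ServiceQueuer(object):
    def __init__(self, send_fn):
        self.queued_events = {}       # {service_id: [events]}
        self.requests_in_flight = set()
        self.send_fn = send_fn

    def enqueue(self, service_id, event):
        self.queued_events.setdefault(service_id, []).append(event)
        self._send_request(service_id)

    def _send_request(self, service_id):
        # At most one request per service at a time.
        if service_id in self.requests_in_flight:
            return
        self.requests_in_flight.add(service_id)
        try:
            # Keep draining until this service's queue is empty.
            while True:
                events = self.queued_events.pop(service_id, [])
                if not events:
                    return
                self.send_fn(service_id, events)
        finally:
            self.requests_in_flight.discard(service_id)
```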
27 | 27 | |
28 | 28 | def read_config(self, config): |
29 | 29 | self.app_service_config_files = config.get("app_service_config_files", []) |
30 | self.notify_appservices = config.get("notify_appservices", True) | |
30 | 31 | |
31 | 32 | def default_config(cls, **kwargs): |
32 | 33 | return """\ |
121 | 122 | raise ValueError( |
122 | 123 | "Missing/bad type 'exclusive' key in %s", regex_obj |
123 | 124 | ) |
125 | # protocols check | |
126 | protocols = as_info.get("protocols") | |
127 | if protocols: | |
128 | # Reject bare strings explicitly: they are iterable but not lists | 
129 | if isinstance(protocols, str) or not isinstance(protocols, list): | |
130 | raise KeyError("Optional 'protocols' must be a list if present.") | |
131 | for p in protocols: | |
132 | if not isinstance(p, str): | |
133 | raise KeyError("Bad value for 'protocols' item") | |
124 | 134 | return ApplicationService( |
125 | 135 | token=as_info["as_token"], |
126 | 136 | url=as_info["url"], |
128 | 138 | hs_token=as_info["hs_token"], |
129 | 139 | sender=user_id, |
130 | 140 | id=as_info["id"], |
141 | protocols=protocols, | |
131 | 142 | ) |
21 | 21 | preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, |
22 | 22 | preserve_fn |
23 | 23 | ) |
24 | from synapse.util.metrics import Measure | |
24 | 25 | |
25 | 26 | from twisted.internet import defer |
26 | 27 | |
58 | 59 | A deferred (server_name, key_id, verify_key) tuple that resolves when |
59 | 60 | a verify key has been fetched |
60 | 61 | """ |
62 | ||
63 | ||
64 | class KeyLookupError(ValueError): | |
65 | pass | |
61 | 66 | |
62 | 67 | |
63 | 68 | class Keyring(object): |
238 | 243 | |
239 | 244 | @defer.inlineCallbacks |
240 | 245 | def do_iterations(): |
241 | merged_results = {} | |
242 | ||
243 | missing_keys = {} | |
244 | for verify_request in verify_requests: | |
245 | missing_keys.setdefault(verify_request.server_name, set()).update( | |
246 | verify_request.key_ids | |
247 | ) | |
248 | ||
249 | for fn in key_fetch_fns: | |
250 | results = yield fn(missing_keys.items()) | |
251 | merged_results.update(results) | |
252 | ||
253 | # We now need to figure out which verify requests we have keys | |
254 | # for and which we don't | |
246 | with Measure(self.clock, "get_server_verify_keys"): | |
247 | merged_results = {} | |
248 | ||
255 | 249 | missing_keys = {} |
256 | requests_missing_keys = [] | |
257 | 250 | for verify_request in verify_requests: |
258 | server_name = verify_request.server_name | |
259 | result_keys = merged_results[server_name] | |
260 | ||
261 | if verify_request.deferred.called: | |
262 | # We've already called this deferred, which probably | |
263 | # means that we've already found a key for it. | |
264 | continue | |
265 | ||
266 | for key_id in verify_request.key_ids: | |
267 | if key_id in result_keys: | |
268 | with PreserveLoggingContext(): | |
269 | verify_request.deferred.callback(( | |
270 | server_name, | |
271 | key_id, | |
272 | result_keys[key_id], | |
273 | )) | |
274 | break | |
275 | else: | |
276 | # The else block is only reached if the loop above | |
277 | # doesn't break. | |
278 | missing_keys.setdefault(server_name, set()).update( | |
279 | verify_request.key_ids | |
280 | ) | |
281 | requests_missing_keys.append(verify_request) | |
282 | ||
283 | if not missing_keys: | |
284 | break | |
285 | ||
286 | for verify_request in requests_missing_keys.values(): | |
287 | verify_request.deferred.errback(SynapseError( | |
288 | 401, | |
289 | "No key for %s with id %s" % ( | |
290 | verify_request.server_name, verify_request.key_ids, | |
291 | ), | |
292 | Codes.UNAUTHORIZED, | |
293 | )) | |
251 | missing_keys.setdefault(verify_request.server_name, set()).update( | |
252 | verify_request.key_ids | |
253 | ) | |
254 | ||
255 | for fn in key_fetch_fns: | |
256 | results = yield fn(missing_keys.items()) | |
257 | merged_results.update(results) | |
258 | ||
259 | # We now need to figure out which verify requests we have keys | |
260 | # for and which we don't | |
261 | missing_keys = {} | |
262 | requests_missing_keys = [] | |
263 | for verify_request in verify_requests: | |
264 | server_name = verify_request.server_name | |
265 | result_keys = merged_results[server_name] | |
266 | ||
267 | if verify_request.deferred.called: | |
268 | # We've already called this deferred, which probably | |
269 | # means that we've already found a key for it. | |
270 | continue | |
271 | ||
272 | for key_id in verify_request.key_ids: | |
273 | if key_id in result_keys: | |
274 | with PreserveLoggingContext(): | |
275 | verify_request.deferred.callback(( | |
276 | server_name, | |
277 | key_id, | |
278 | result_keys[key_id], | |
279 | )) | |
280 | break | |
281 | else: | |
282 | # The else block is only reached if the loop above | |
283 | # doesn't break. | |
284 | missing_keys.setdefault(server_name, set()).update( | |
285 | verify_request.key_ids | |
286 | ) | |
287 | requests_missing_keys.append(verify_request) | |
288 | ||
289 | if not missing_keys: | |
290 | break | |
291 | ||
292 | for verify_request in requests_missing_keys: | 
293 | verify_request.deferred.errback(SynapseError( | |
294 | 401, | |
295 | "No key for %s with id %s" % ( | |
296 | verify_request.server_name, verify_request.key_ids, | |
297 | ), | |
298 | Codes.UNAUTHORIZED, | |
299 | )) | |
294 | 300 | |
295 | 301 | def on_err(err): |
296 | 302 | for verify_request in verify_requests: |
301 | 307 | |
302 | 308 | @defer.inlineCallbacks |
303 | 309 | def get_keys_from_store(self, server_name_and_key_ids): |
304 | res = yield defer.gatherResults( | |
310 | res = yield preserve_context_over_deferred(defer.gatherResults( | |
305 | 311 | [ |
306 | self.store.get_server_verify_keys( | |
312 | preserve_fn(self.store.get_server_verify_keys)( | |
307 | 313 | server_name, key_ids |
308 | 314 | ).addCallback(lambda ks, server: (server, ks), server_name) |
309 | 315 | for server_name, key_ids in server_name_and_key_ids |
310 | 316 | ], |
311 | 317 | consumeErrors=True, |
312 | ).addErrback(unwrapFirstError) | |
318 | )).addErrback(unwrapFirstError) | |
313 | 319 | |
314 | 320 | defer.returnValue(dict(res)) |
315 | 321 | |
330 | 336 | ) |
331 | 337 | defer.returnValue({}) |
332 | 338 | |
333 | results = yield defer.gatherResults( | |
339 | results = yield preserve_context_over_deferred(defer.gatherResults( | |
334 | 340 | [ |
335 | get_key(p_name, p_keys) | |
341 | preserve_fn(get_key)(p_name, p_keys) | |
336 | 342 | for p_name, p_keys in self.perspective_servers.items() |
337 | 343 | ], |
338 | 344 | consumeErrors=True, |
339 | ).addErrback(unwrapFirstError) | |
345 | )).addErrback(unwrapFirstError) | |
340 | 346 | |
341 | 347 | union_of_keys = {} |
342 | 348 | for result in results: |
362 | 368 | ) |
363 | 369 | except Exception as e: |
364 | 370 | logger.info( |
365 | "Unable to getting key %r for %r directly: %s %s", | |
371 | "Unable to get key %r for %r directly: %s %s", | |
366 | 372 | key_ids, server_name, |
367 | 373 | type(e).__name__, str(e.message), |
368 | 374 | ) |
376 | 382 | |
377 | 383 | defer.returnValue(keys) |
378 | 384 | |
379 | results = yield defer.gatherResults( | |
385 | results = yield preserve_context_over_deferred(defer.gatherResults( | |
380 | 386 | [ |
381 | get_key(server_name, key_ids) | |
387 | preserve_fn(get_key)(server_name, key_ids) | |
382 | 388 | for server_name, key_ids in server_name_and_key_ids |
383 | 389 | ], |
384 | 390 | consumeErrors=True, |
385 | ).addErrback(unwrapFirstError) | |
391 | )).addErrback(unwrapFirstError) | |
386 | 392 | |
387 | 393 | merged = {} |
388 | 394 | for result in results: |
424 | 430 | for response in responses: |
425 | 431 | if (u"signatures" not in response |
426 | 432 | or perspective_name not in response[u"signatures"]): |
427 | raise ValueError( | |
433 | raise KeyLookupError( | |
428 | 434 | "Key response not signed by perspective server" |
429 | 435 | " %r" % (perspective_name,) |
430 | 436 | ) |
447 | 453 | list(response[u"signatures"][perspective_name]), |
448 | 454 | list(perspective_keys) |
449 | 455 | ) |
450 | raise ValueError( | |
456 | raise KeyLookupError( | |
451 | 457 | "Response not signed with a known key for perspective" |
452 | 458 | " server %r" % (perspective_name,) |
453 | 459 | ) |
459 | 465 | for server_name, response_keys in processed_response.items(): |
460 | 466 | keys.setdefault(server_name, {}).update(response_keys) |
461 | 467 | |
462 | yield defer.gatherResults( | |
468 | yield preserve_context_over_deferred(defer.gatherResults( | |
463 | 469 | [ |
464 | self.store_keys( | |
470 | preserve_fn(self.store_keys)( | |
465 | 471 | server_name=server_name, |
466 | 472 | from_server=perspective_name, |
467 | 473 | verify_keys=response_keys, |
469 | 475 | for server_name, response_keys in keys.items() |
470 | 476 | ], |
471 | 477 | consumeErrors=True |
472 | ).addErrback(unwrapFirstError) | |
478 | )).addErrback(unwrapFirstError) | |
473 | 479 | |
474 | 480 | defer.returnValue(keys) |
475 | 481 | |
490 | 496 | |
491 | 497 | if (u"signatures" not in response |
492 | 498 | or server_name not in response[u"signatures"]): |
493 | raise ValueError("Key response not signed by remote server") | |
499 | raise KeyLookupError("Key response not signed by remote server") | |
494 | 500 | |
495 | 501 | if "tls_fingerprints" not in response: |
496 | raise ValueError("Key response missing TLS fingerprints") | |
502 | raise KeyLookupError("Key response missing TLS fingerprints") | |
497 | 503 | |
498 | 504 | certificate_bytes = crypto.dump_certificate( |
499 | 505 | crypto.FILETYPE_ASN1, tls_certificate |
507 | 513 | response_sha256_fingerprints.add(fingerprint[u"sha256"]) |
508 | 514 | |
509 | 515 | if sha256_fingerprint_b64 not in response_sha256_fingerprints: |
510 | raise ValueError("TLS certificate not allowed by fingerprints") | |
516 | raise KeyLookupError("TLS certificate not allowed by fingerprints") | |
511 | 517 | |
512 | 518 | response_keys = yield self.process_v2_response( |
513 | 519 | from_server=server_name, |
517 | 523 | |
518 | 524 | keys.update(response_keys) |
519 | 525 | |
520 | yield defer.gatherResults( | |
526 | yield preserve_context_over_deferred(defer.gatherResults( | |
521 | 527 | [ |
522 | 528 | preserve_fn(self.store_keys)( |
523 | 529 | server_name=key_server_name, |
527 | 533 | for key_server_name, verify_keys in keys.items() |
528 | 534 | ], |
529 | 535 | consumeErrors=True |
530 | ).addErrback(unwrapFirstError) | |
536 | )).addErrback(unwrapFirstError) | |
531 | 537 | |
532 | 538 | defer.returnValue(keys) |
533 | 539 | |
559 | 565 | server_name = response_json["server_name"] |
560 | 566 | if only_from_server: |
561 | 567 | if server_name != from_server: |
562 | raise ValueError( | |
568 | raise KeyLookupError( | |
563 | 569 | "Expected a response for server %r not %r" % ( |
564 | 570 | from_server, server_name |
565 | 571 | ) |
566 | 572 | ) |
567 | 573 | for key_id in response_json["signatures"].get(server_name, {}): |
568 | 574 | if key_id not in response_json["verify_keys"]: |
569 | raise ValueError( | |
575 | raise KeyLookupError( | |
570 | 576 | "Key response must include verification keys for all" |
571 | 577 | " signatures" |
572 | 578 | ) |
593 | 599 | response_keys.update(verify_keys) |
594 | 600 | response_keys.update(old_verify_keys) |
595 | 601 | |
596 | yield defer.gatherResults( | |
602 | yield preserve_context_over_deferred(defer.gatherResults( | |
597 | 603 | [ |
598 | 604 | preserve_fn(self.store.store_server_keys_json)( |
599 | 605 | server_name=server_name, |
606 | 612 | for key_id in updated_key_ids |
607 | 613 | ], |
608 | 614 | consumeErrors=True, |
609 | ).addErrback(unwrapFirstError) | |
615 | )).addErrback(unwrapFirstError) | |
610 | 616 | |
611 | 617 | results[server_name] = response_keys |
612 | 618 | |
634 | 640 | |
635 | 641 | if ("signatures" not in response |
636 | 642 | or server_name not in response["signatures"]): |
637 | raise ValueError("Key response not signed by remote server") | |
643 | raise KeyLookupError("Key response not signed by remote server") | |
638 | 644 | |
639 | 645 | if "tls_certificate" not in response: |
640 | raise ValueError("Key response missing TLS certificate") | |
646 | raise KeyLookupError("Key response missing TLS certificate") | |
641 | 647 | |
642 | 648 | tls_certificate_b64 = response["tls_certificate"] |
643 | 649 | |
644 | 650 | if encode_base64(x509_certificate_bytes) != tls_certificate_b64: |
645 | raise ValueError("TLS certificate doesn't match") | |
651 | raise KeyLookupError("TLS certificate doesn't match") | |
646 | 652 | |
647 | 653 | # Cache the result in the datastore. |
648 | 654 | |
658 | 664 | |
659 | 665 | for key_id in response["signatures"][server_name]: |
660 | 666 | if key_id not in response["verify_keys"]: |
661 | raise ValueError( | |
667 | raise KeyLookupError( | |
662 | 668 | "Key response must include verification keys for all" |
663 | 669 | " signatures" |
664 | 670 | ) |
695 | 701 | A deferred that completes when the keys are stored. |
696 | 702 | """ |
697 | 703 | # TODO(markjh): Store whether the keys have expired. |
698 | yield defer.gatherResults( | |
704 | yield preserve_context_over_deferred(defer.gatherResults( | |
699 | 705 | [ |
700 | 706 | preserve_fn(self.store.store_server_verify_key)( |
701 | 707 | server_name, server_name, key.time_added, key |
703 | 709 | for key_id, key in verify_keys.items() |
704 | 710 | ], |
705 | 711 | consumeErrors=True, |
706 | ).addErrback(unwrapFirstError) | |
712 | )).addErrback(unwrapFirstError) |
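Most of the keyring hunk is mechanical: each parallel fetch is wrapped in `preserve_fn(...)` and each `gatherResults(...)` in `preserve_context_over_deferred(...)` so the logging context survives the fan-out. As a rough structural analogy only (synapse itself uses Twisted Deferreds, not asyncio), the fan-out-and-merge shape is:

```python
import asyncio


async def fetch_all_keys(fetchers):
    # Fan out the per-server lookups concurrently, then merge the
    # (server_name, keys) pairs into one dict.
    pairs = await asyncio.gather(*(fetch() for fetch in fetchers))
    merged = {}
    for server_name, keys in pairs:
        merged.setdefault(server_name, {}).update(keys)
    return merged
```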
87 | 87 | |
88 | 88 | if "age_ts" in event.unsigned: |
89 | 89 | allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] |
90 | if "replaces_state" in event.unsigned: | |
91 | allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"] | |
90 | 92 | |
91 | 93 | return type(event)( |
92 | 94 | allowed_fields, |
22 | 22 | from synapse.api.errors import SynapseError |
23 | 23 | |
24 | 24 | from synapse.util import unwrapFirstError |
25 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
25 | 26 | |
26 | 27 | import logging |
27 | 28 | |
101 | 102 | warn, pdu |
102 | 103 | ) |
103 | 104 | |
104 | valid_pdus = yield defer.gatherResults( | |
105 | valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( | |
105 | 106 | deferreds, |
106 | 107 | consumeErrors=True |
107 | ).addErrback(unwrapFirstError) | |
108 | )).addErrback(unwrapFirstError) | |
108 | 109 | |
109 | 110 | if include_none: |
110 | 111 | defer.returnValue(valid_pdus) |
128 | 129 | for pdu in pdus |
129 | 130 | ] |
130 | 131 | |
131 | deferreds = self.keyring.verify_json_objects_for_server([ | |
132 | deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ | |
132 | 133 | (p.origin, p.get_pdu_json()) |
133 | 134 | for p in redacted_pdus |
134 | 135 | ]) |
26 | 26 | from synapse.util.async import concurrently_execute |
27 | 27 | from synapse.util.caches.expiringcache import ExpiringCache |
28 | 28 | from synapse.util.logutils import log_function |
29 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
29 | 30 | from synapse.events import FrozenEvent |
30 | 31 | import synapse.metrics |
31 | 32 | |
50 | 51 | sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) |
51 | 52 | |
52 | 53 | |
54 | PDU_RETRY_TIME_MS = 1 * 60 * 1000 | |
55 | ||
56 | ||
53 | 57 | class FederationClient(FederationBase): |
54 | 58 | def __init__(self, hs): |
55 | 59 | super(FederationClient, self).__init__(hs) |
60 | ||
61 | self.pdu_destination_tried = {} | |
62 | self._clock.looping_call( | |
63 | self._clear_tried_cache, 60 * 1000, | |
64 | ) | |
65 | ||
66 | def _clear_tried_cache(self): | |
67 | """Clear pdu_destination_tried cache""" | |
68 | now = self._clock.time_msec() | |
69 | ||
70 | old_dict = self.pdu_destination_tried | |
71 | self.pdu_destination_tried = {} | |
72 | ||
73 | for event_id, destination_dict in old_dict.items(): | |
74 | destination_dict = { | |
75 | dest: time | |
76 | for dest, time in destination_dict.items() | |
77 | if time + PDU_RETRY_TIME_MS > now | |
78 | } | |
79 | if destination_dict: | |
80 | self.pdu_destination_tried[event_id] = destination_dict | |
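The expiry sweep added above prunes a dict of dicts, keeping only attempts younger than the retry window. As a standalone sketch (hypothetical function name; the handler method operates on `self.pdu_destination_tried` instead of taking arguments):

```python
PDU_RETRY_TIME_MS = 1 * 60 * 1000  # retry a destination at most once a minute

def clear_tried_cache(tried, now_ms, retry_ms=PDU_RETRY_TIME_MS):
    """Prune an {event_id: {destination: last_attempt_ms}} cache.

    Drops attempts older than retry_ms, and drops event entries that end
    up with no destinations at all, mirroring _clear_tried_cache above.
    """
    fresh = {}
    for event_id, dests in tried.items():
        kept = {
            dest: ts
            for dest, ts in dests.items()
            if ts + retry_ms > now_ms
        }
        if kept:
            fresh[event_id] = kept
    return fresh
```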
56 | 81 | |
57 | 82 | def start_get_pdu_cache(self): |
58 | 83 | self._get_pdu_cache = ExpiringCache( |
200 | 225 | ] |
201 | 226 | |
202 | 227 | # FIXME: We should handle signature failures more gracefully. |
203 | pdus[:] = yield defer.gatherResults( | |
228 | pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( | |
204 | 229 | self._check_sigs_and_hashes(pdus), |
205 | 230 | consumeErrors=True, |
206 | ).addErrback(unwrapFirstError) | |
231 | )).addErrback(unwrapFirstError) | |
207 | 232 | |
208 | 233 | defer.returnValue(pdus) |
209 | 234 | |
239 | 264 | if ev: |
240 | 265 | defer.returnValue(ev) |
241 | 266 | |
267 | pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) | |
268 | ||
242 | 269 | pdu = None |
243 | 270 | for destination in destinations: |
271 | now = self._clock.time_msec() | |
272 | last_attempt = pdu_attempts.get(destination, 0) | |
273 | if last_attempt + PDU_RETRY_TIME_MS > now: | |
274 | continue | |
275 | ||
244 | 276 | try: |
245 | 277 | limiter = yield get_retry_limiter( |
246 | 278 | destination, |
268 | 300 | |
269 | 301 | break |
270 | 302 | |
303 | pdu_attempts[destination] = now | |
304 | ||
271 | 305 | except SynapseError as e: |
272 | 306 | logger.info( |
273 | 307 | "Failed to get PDU %s from %s because %s", |
274 | 308 | event_id, destination, e, |
275 | 309 | ) |
276 | continue | |
277 | except CodeMessageException as e: | |
278 | if 400 <= e.code < 500: | |
279 | raise | |
280 | ||
281 | logger.info( | |
282 | "Failed to get PDU %s from %s because %s", | |
283 | event_id, destination, e, | |
284 | ) | |
285 | continue | |
286 | 310 | except NotRetryingDestination as e: |
287 | 311 | logger.info(e.message) |
288 | 312 | continue |
289 | 313 | except Exception as e: |
314 | pdu_attempts[destination] = now | |
315 | ||
290 | 316 | logger.info( |
291 | 317 | "Failed to get PDU %s from %s because %s", |
292 | 318 | event_id, destination, e, |
405 | 431 | events and the second is a list of event ids that we failed to fetch. |
406 | 432 | """ |
407 | 433 | if return_local: |
408 | seen_events = yield self.store.get_events(event_ids) | |
434 | seen_events = yield self.store.get_events(event_ids, allow_rejected=True) | |
409 | 435 | signed_events = seen_events.values() |
410 | 436 | else: |
411 | 437 | seen_events = yield self.store.have_events(event_ids) |
431 | 457 | batch = set(missing_events[i:i + batch_size]) |
432 | 458 | |
433 | 459 | deferreds = [ |
434 | self.get_pdu( | |
460 | preserve_fn(self.get_pdu)( | |
435 | 461 | destinations=random_server_list(), |
436 | 462 | event_id=e_id, |
437 | 463 | ) |
438 | 464 | for e_id in batch |
439 | 465 | ] |
440 | 466 | |
441 | res = yield defer.DeferredList(deferreds, consumeErrors=True) | |
467 | res = yield preserve_context_over_deferred( | |
468 | defer.DeferredList(deferreds, consumeErrors=True) | |
469 | ) | |
442 | 470 | for success, result in res: |
443 | 471 | if success: |
444 | 472 | signed_events.append(result) |
827 | 855 | return srvs |
828 | 856 | |
829 | 857 | deferreds = [ |
830 | self.get_pdu( | |
858 | preserve_fn(self.get_pdu)( | |
831 | 859 | destinations=random_server_list(), |
832 | 860 | event_id=e_id, |
833 | 861 | ) |
834 | 862 | for e_id, depth in ordered_missing[:limit - len(signed_events)] |
835 | 863 | ] |
836 | 864 | |
837 | res = yield defer.DeferredList(deferreds, consumeErrors=True) | |
865 | res = yield preserve_context_over_deferred( | |
866 | defer.DeferredList(deferreds, consumeErrors=True) | |
867 | ) | |
838 | 868 | for (result, val), (e_id, _) in zip(res, ordered_missing): |
839 | 869 | if result and val: |
840 | 870 | signed_events.append(val) |
20 | 20 | |
21 | 21 | from synapse.api.errors import HttpResponseException |
22 | 22 | from synapse.util.async import run_on_reactor |
23 | from synapse.util.logutils import log_function | |
24 | from synapse.util.logcontext import PreserveLoggingContext | |
23 | from synapse.util.logcontext import preserve_context_over_fn | |
25 | 24 | from synapse.util.retryutils import ( |
26 | 25 | get_retry_limiter, NotRetryingDestination, |
27 | 26 | ) |
27 | from synapse.util.metrics import measure_func | |
28 | 28 | import synapse.metrics |
29 | 29 | |
30 | 30 | import logging |
50 | 50 | |
51 | 51 | self.transport_layer = transport_layer |
52 | 52 | |
53 | self._clock = hs.get_clock() | |
53 | self.clock = hs.get_clock() | |
54 | 54 | |
55 | 55 | # Is a mapping from destinations -> deferreds. Used to keep track |
56 | 56 | # of which destinations have transactions in flight and when they are |
81 | 81 | self.pending_failures_by_dest = {} |
82 | 82 | |
83 | 83 | # HACK to get unique tx id |
84 | self._next_txn_id = int(self._clock.time_msec()) | |
84 | self._next_txn_id = int(self.clock.time_msec()) | |
85 | 85 | |
86 | 86 | def can_send_to(self, destination): |
87 | 87 | """Can we send messages to the given server? |
118 | 118 | if not destinations: |
119 | 119 | return |
120 | 120 | |
121 | deferreds = [] | |
122 | ||
123 | 121 | for destination in destinations: |
124 | deferred = defer.Deferred() | |
125 | 122 | self.pending_pdus_by_dest.setdefault(destination, []).append( |
126 | (pdu, deferred, order) | |
123 | (pdu, order) | |
127 | 124 | ) |
128 | 125 | |
129 | def chain(failure): | |
130 | if not deferred.called: | |
131 | deferred.errback(failure) | |
132 | ||
133 | def log_failure(f): | |
134 | logger.warn("Failed to send pdu to %s: %s", destination, f.value) | |
135 | ||
136 | deferred.addErrback(log_failure) | |
137 | ||
138 | with PreserveLoggingContext(): | |
139 | self._attempt_new_transaction(destination).addErrback(chain) | |
140 | ||
141 | deferreds.append(deferred) | |
142 | ||
143 | # NO inlineCallbacks | |
126 | preserve_context_over_fn( | |
127 | self._attempt_new_transaction, destination | |
128 | ) | |
129 | ||
144 | 130 | def enqueue_edu(self, edu): |
145 | 131 | destination = edu.destination |
146 | 132 | |
147 | 133 | if not self.can_send_to(destination): |
148 | 134 | return |
149 | 135 | |
150 | deferred = defer.Deferred() | |
151 | self.pending_edus_by_dest.setdefault(destination, []).append( | |
152 | (edu, deferred) | |
153 | ) | |
154 | ||
155 | def chain(failure): | |
156 | if not deferred.called: | |
157 | deferred.errback(failure) | |
158 | ||
159 | def log_failure(f): | |
160 | logger.warn("Failed to send edu to %s: %s", destination, f.value) | |
161 | ||
162 | deferred.addErrback(log_failure) | |
163 | ||
164 | with PreserveLoggingContext(): | |
165 | self._attempt_new_transaction(destination).addErrback(chain) | |
166 | ||
167 | return deferred | |
168 | ||
169 | @defer.inlineCallbacks | |
136 | self.pending_edus_by_dest.setdefault(destination, []).append(edu) | |
137 | ||
138 | preserve_context_over_fn( | |
139 | self._attempt_new_transaction, destination | |
140 | ) | |
141 | ||
170 | 142 | def enqueue_failure(self, failure, destination): |
171 | 143 | if destination == self.server_name or destination == "localhost": |
172 | 144 | return |
173 | 145 | |
174 | deferred = defer.Deferred() | |
175 | ||
176 | 146 | if not self.can_send_to(destination): |
177 | 147 | return |
178 | 148 | |
179 | 149 | self.pending_failures_by_dest.setdefault( |
180 | 150 | destination, [] |
181 | ).append( | |
182 | (failure, deferred) | |
183 | ) | |
184 | ||
185 | def chain(f): | |
186 | if not deferred.called: | |
187 | deferred.errback(f) | |
188 | ||
189 | def log_failure(f): | |
190 | logger.warn("Failed to send failure to %s: %s", destination, f.value) | |
191 | ||
192 | deferred.addErrback(log_failure) | |
193 | ||
194 | with PreserveLoggingContext(): | |
195 | self._attempt_new_transaction(destination).addErrback(chain) | |
196 | ||
197 | yield deferred | |
151 | ).append(failure) | |
152 | ||
153 | preserve_context_over_fn( | |
154 | self._attempt_new_transaction, destination | |
155 | ) | |
198 | 156 | |
199 | 157 | @defer.inlineCallbacks |
200 | @log_function | |
201 | 158 | def _attempt_new_transaction(self, destination): |
202 | 159 | yield run_on_reactor() |
203 | ||
204 | # list of (pending_pdu, deferred, order) | |
205 | if destination in self.pending_transactions: | |
206 | # XXX: pending_transactions can get stuck on by a never-ending | |
207 | # request at which point pending_pdus_by_dest just keeps growing. | |
208 | # we need application-layer timeouts of some flavour of these | |
209 | # requests | |
210 | logger.debug( | |
211 | "TX [%s] Transaction already in progress", | |
212 | destination | |
160 | while True: | |
161 | # list of (pending_pdu, order) | |
162 | if destination in self.pending_transactions: | |
163 | # XXX: pending_transactions can get stuck on by a never-ending | |
164 | # request at which point pending_pdus_by_dest just keeps growing. | |
165 | # we need application-layer timeouts of some flavour for these | |
166 | # requests | |
167 | logger.debug( | |
168 | "TX [%s] Transaction already in progress", | |
169 | destination | |
170 | ) | |
171 | return | |
172 | ||
173 | pending_pdus = self.pending_pdus_by_dest.pop(destination, []) | |
174 | pending_edus = self.pending_edus_by_dest.pop(destination, []) | |
175 | pending_failures = self.pending_failures_by_dest.pop(destination, []) | |
176 | ||
177 | if pending_pdus: | |
178 | logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", | |
179 | destination, len(pending_pdus)) | |
180 | ||
181 | if not pending_pdus and not pending_edus and not pending_failures: | |
182 | logger.debug("TX [%s] Nothing to send", destination) | |
183 | return | |
184 | ||
185 | yield self._send_new_transaction( | |
186 | destination, pending_pdus, pending_edus, pending_failures | |
213 | 187 | ) |
214 | return | |
215 | ||
216 | pending_pdus = self.pending_pdus_by_dest.pop(destination, []) | |
217 | pending_edus = self.pending_edus_by_dest.pop(destination, []) | |
218 | pending_failures = self.pending_failures_by_dest.pop(destination, []) | |
219 | ||
220 | if pending_pdus: | |
221 | logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", | |
222 | destination, len(pending_pdus)) | |
223 | ||
224 | if not pending_pdus and not pending_edus and not pending_failures: | |
225 | logger.debug("TX [%s] Nothing to send", destination) | |
226 | return | |
227 | ||
228 | try: | |
229 | self.pending_transactions[destination] = 1 | |
230 | ||
231 | logger.debug("TX [%s] _attempt_new_transaction", destination) | |
188 | ||
189 | @measure_func("_send_new_transaction") | |
190 | @defer.inlineCallbacks | |
191 | def _send_new_transaction(self, destination, pending_pdus, pending_edus, | |
192 | pending_failures): | |
232 | 193 | |
233 | 194 | # Sort based on the order field |
234 | pending_pdus.sort(key=lambda t: t[2]) | |
235 | ||
195 | pending_pdus.sort(key=lambda t: t[1]) | |
236 | 196 | pdus = [x[0] for x in pending_pdus] |
237 | edus = [x[0] for x in pending_edus] | |
238 | failures = [x[0].get_dict() for x in pending_failures] | |
239 | deferreds = [ | |
240 | x[1] | |
241 | for x in pending_pdus + pending_edus + pending_failures | |
242 | ] | |
243 | ||
244 | txn_id = str(self._next_txn_id) | |
245 | ||
246 | limiter = yield get_retry_limiter( | |
247 | destination, | |
248 | self._clock, | |
249 | self.store, | |
250 | ) | |
251 | ||
252 | logger.debug( | |
253 | "TX [%s] {%s} Attempting new transaction" | |
254 | " (pdus: %d, edus: %d, failures: %d)", | |
255 | destination, txn_id, | |
256 | len(pending_pdus), | |
257 | len(pending_edus), | |
258 | len(pending_failures) | |
259 | ) | |
260 | ||
261 | logger.debug("TX [%s] Persisting transaction...", destination) | |
262 | ||
263 | transaction = Transaction.create_new( | |
264 | origin_server_ts=int(self._clock.time_msec()), | |
265 | transaction_id=txn_id, | |
266 | origin=self.server_name, | |
267 | destination=destination, | |
268 | pdus=pdus, | |
269 | edus=edus, | |
270 | pdu_failures=failures, | |
271 | ) | |
272 | ||
273 | self._next_txn_id += 1 | |
274 | ||
275 | yield self.transaction_actions.prepare_to_send(transaction) | |
276 | ||
277 | logger.debug("TX [%s] Persisted transaction", destination) | |
278 | logger.info( | |
279 | "TX [%s] {%s} Sending transaction [%s]," | |
280 | " (PDUs: %d, EDUs: %d, failures: %d)", | |
281 | destination, txn_id, | |
282 | transaction.transaction_id, | |
283 | len(pending_pdus), | |
284 | len(pending_edus), | |
285 | len(pending_failures), | |
286 | ) | |
287 | ||
288 | with limiter: | |
289 | # Actually send the transaction | |
290 | ||
291 | # FIXME (erikj): This is a bit of a hack to make the Pdu age | |
292 | # keys work | |
293 | def json_data_cb(): | |
294 | data = transaction.get_dict() | |
295 | now = int(self._clock.time_msec()) | |
296 | if "pdus" in data: | |
297 | for p in data["pdus"]: | |
298 | if "age_ts" in p: | |
299 | unsigned = p.setdefault("unsigned", {}) | |
300 | unsigned["age"] = now - int(p["age_ts"]) | |
301 | del p["age_ts"] | |
302 | return data | |
303 | ||
304 | try: | |
305 | response = yield self.transport_layer.send_transaction( | |
306 | transaction, json_data_cb | |
197 | edus = pending_edus | |
198 | failures = [x.get_dict() for x in pending_failures] | |
199 | ||
200 | try: | |
201 | self.pending_transactions[destination] = 1 | |
202 | ||
203 | logger.debug("TX [%s] _attempt_new_transaction", destination) | |
204 | ||
205 | txn_id = str(self._next_txn_id) | |
206 | ||
207 | limiter = yield get_retry_limiter( | |
208 | destination, | |
209 | self.clock, | |
210 | self.store, | |
211 | ) | |
212 | ||
213 | logger.debug( | |
214 | "TX [%s] {%s} Attempting new transaction" | |
215 | " (pdus: %d, edus: %d, failures: %d)", | |
216 | destination, txn_id, | |
217 | len(pending_pdus), | |
218 | len(pending_edus), | |
219 | len(pending_failures) | |
220 | ) | |
221 | ||
222 | logger.debug("TX [%s] Persisting transaction...", destination) | |
223 | ||
224 | transaction = Transaction.create_new( | |
225 | origin_server_ts=int(self.clock.time_msec()), | |
226 | transaction_id=txn_id, | |
227 | origin=self.server_name, | |
228 | destination=destination, | |
229 | pdus=pdus, | |
230 | edus=edus, | |
231 | pdu_failures=failures, | |
232 | ) | |
233 | ||
234 | self._next_txn_id += 1 | |
235 | ||
236 | yield self.transaction_actions.prepare_to_send(transaction) | |
237 | ||
238 | logger.debug("TX [%s] Persisted transaction", destination) | |
239 | logger.info( | |
240 | "TX [%s] {%s} Sending transaction [%s]," | |
241 | " (PDUs: %d, EDUs: %d, failures: %d)", | |
242 | destination, txn_id, | |
243 | transaction.transaction_id, | |
244 | len(pending_pdus), | |
245 | len(pending_edus), | |
246 | len(pending_failures), | |
247 | ) | |
248 | ||
249 | with limiter: | |
250 | # Actually send the transaction | |
251 | ||
252 | # FIXME (erikj): This is a bit of a hack to make the Pdu age | |
253 | # keys work | |
254 | def json_data_cb(): | |
255 | data = transaction.get_dict() | |
256 | now = int(self.clock.time_msec()) | |
257 | if "pdus" in data: | |
258 | for p in data["pdus"]: | |
259 | if "age_ts" in p: | |
260 | unsigned = p.setdefault("unsigned", {}) | |
261 | unsigned["age"] = now - int(p["age_ts"]) | |
262 | del p["age_ts"] | |
263 | return data | |
264 | ||
265 | try: | |
266 | response = yield self.transport_layer.send_transaction( | |
267 | transaction, json_data_cb | |
268 | ) | |
269 | code = 200 | |
270 | ||
271 | if response: | |
272 | for e_id, r in response.get("pdus", {}).items(): | |
273 | if "error" in r: | |
274 | logger.warn( | |
275 | "Transaction returned error for %s: %s", | |
276 | e_id, r, | |
277 | ) | |
278 | except HttpResponseException as e: | |
279 | code = e.code | |
280 | response = e.response | |
281 | ||
282 | logger.info( | |
283 | "TX [%s] {%s} got %d response", | |
284 | destination, txn_id, code | |
307 | 285 | ) |
308 | code = 200 | |
309 | ||
310 | if response: | |
311 | for e_id, r in response.get("pdus", {}).items(): | |
312 | if "error" in r: | |
313 | logger.warn( | |
314 | "Transaction returned error for %s: %s", | |
315 | e_id, r, | |
316 | ) | |
317 | except HttpResponseException as e: | |
318 | code = e.code | |
319 | response = e.response | |
320 | ||
286 | ||
287 | logger.debug("TX [%s] Sent transaction", destination) | |
288 | logger.debug("TX [%s] Marking as delivered...", destination) | |
289 | ||
290 | yield self.transaction_actions.delivered( | |
291 | transaction, code, response | |
292 | ) | |
293 | ||
294 | logger.debug("TX [%s] Marked as delivered", destination) | |
295 | ||
296 | if code != 200: | |
297 | for p in pdus: | |
298 | logger.info( | |
299 | "Failed to send event %s to %s", p.event_id, destination | |
300 | ) | |
301 | except NotRetryingDestination: | |
321 | 302 | logger.info( |
322 | "TX [%s] {%s} got %d response", | |
323 | destination, txn_id, code | |
324 | ) | |
325 | ||
326 | logger.debug("TX [%s] Sent transaction", destination) | |
327 | logger.debug("TX [%s] Marking as delivered...", destination) | |
328 | ||
329 | yield self.transaction_actions.delivered( | |
330 | transaction, code, response | |
331 | ) | |
332 | ||
333 | logger.debug("TX [%s] Marked as delivered", destination) | |
334 | ||
335 | logger.debug("TX [%s] Yielding to callbacks...", destination) | |
336 | ||
337 | for deferred in deferreds: | |
338 | if code == 200: | |
339 | deferred.callback(None) | |
340 | else: | |
341 | deferred.errback(RuntimeError("Got status %d" % code)) | |
342 | ||
343 | # Ensures we don't continue until all callbacks on that | |
344 | # deferred have fired | |
345 | try: | |
346 | yield deferred | |
347 | except: | |
348 | pass | |
349 | ||
350 | logger.debug("TX [%s] Yielded to callbacks", destination) | |
351 | except NotRetryingDestination: | |
352 | logger.info( | |
353 | "TX [%s] not ready for retry yet - " | |
354 | "dropping transaction for now", | |
355 | destination, | |
356 | ) | |
357 | except RuntimeError as e: | |
358 | # We capture this here as there as nothing actually listens | |
359 | # for this finishing functions deferred. | |
360 | logger.warn( | |
361 | "TX [%s] Problem in _attempt_transaction: %s", | |
362 | destination, | |
363 | e, | |
364 | ) | |
365 | except Exception as e: | |
366 | # We capture this here as there as nothing actually listens | |
367 | # for this finishing functions deferred. | |
368 | logger.warn( | |
369 | "TX [%s] Problem in _attempt_transaction: %s", | |
370 | destination, | |
371 | e, | |
372 | ) | |
373 | ||
374 | for deferred in deferreds: | |
375 | if not deferred.called: | |
376 | deferred.errback(e) | |
377 | ||
378 | finally: | |
379 | # We want to be *very* sure we delete this after we stop processing | |
380 | self.pending_transactions.pop(destination, None) | |
381 | ||
382 | # Check to see if there is anything else to send. | |
383 | self._attempt_new_transaction(destination) | |
303 | "TX [%s] not ready for retry yet - " | |
304 | "dropping transaction for now", | |
305 | destination, | |
306 | ) | |
307 | except RuntimeError as e: | |
308 | # We capture this here as there is nothing actually listening | |
309 | # for this function's finishing deferred. | |
310 | logger.warn( | |
311 | "TX [%s] Problem in _attempt_transaction: %s", | |
312 | destination, | |
313 | e, | |
314 | ) | |
315 | ||
316 | for p in pdus: | |
317 | logger.info("Failed to send event %s to %s", p.event_id, destination) | |
318 | except Exception as e: | |
319 | # We capture this here as there is nothing actually listening | |
320 | # for this function's finishing deferred. | |
321 | logger.warn( | |
322 | "TX [%s] Problem in _attempt_transaction: %s", | |
323 | destination, | |
324 | e, | |
325 | ) | |
326 | ||
327 | for p in pdus: | |
328 | logger.info("Failed to send event %s to %s", p.event_id, destination) | |
329 | ||
330 | finally: | |
331 | # We want to be *very* sure we delete this after we stop processing | |
332 | self.pending_transactions.pop(destination, None) |
18 | 18 | ) |
19 | 19 | from .room_member import RoomMemberHandler |
20 | 20 | from .message import MessageHandler |
21 | from .events import EventStreamHandler, EventHandler | |
22 | 21 | from .federation import FederationHandler |
23 | 22 | from .profile import ProfileHandler |
24 | 23 | from .directory import DirectoryHandler |
52 | 51 | self.message_handler = MessageHandler(hs) |
53 | 52 | self.room_creation_handler = RoomCreationHandler(hs) |
54 | 53 | self.room_member_handler = RoomMemberHandler(hs) |
55 | self.event_stream_handler = EventStreamHandler(hs) | |
56 | self.event_handler = EventHandler(hs) | |
57 | 54 | self.federation_handler = FederationHandler(hs) |
58 | 55 | self.profile_handler = ProfileHandler(hs) |
59 | 56 | self.directory_handler = DirectoryHandler(hs) |
15 | 15 | from twisted.internet import defer |
16 | 16 | |
17 | 17 | from synapse.api.constants import EventTypes |
18 | from synapse.appservice import ApplicationService | |
18 | from synapse.util.metrics import Measure | |
19 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
19 | 20 | |
20 | 21 | import logging |
21 | 22 | |
41 | 42 | self.appservice_api = hs.get_application_service_api() |
42 | 43 | self.scheduler = hs.get_application_service_scheduler() |
43 | 44 | self.started_scheduler = False |
44 | ||
45 | @defer.inlineCallbacks | |
46 | def notify_interested_services(self, event): | |
45 | self.clock = hs.get_clock() | |
46 | self.notify_appservices = hs.config.notify_appservices | |
47 | ||
48 | self.current_max = 0 | |
49 | self.is_processing = False | |
50 | ||
51 | @defer.inlineCallbacks | |
52 | def notify_interested_services(self, current_id): | |
47 | 53 | """Notifies (pushes) all application services interested in this event. |
48 | 54 | |
49 | 55 | Pushing is done asynchronously, so this method won't block for any |
50 | 56 | prolonged length of time. |
51 | 57 | |
52 | 58 | Args: |
53 | event(Event): The event to push out to interested services. | |
54 | """ | |
55 | # Gather interested services | |
56 | services = yield self._get_services_for_event(event) | |
57 | if len(services) == 0: | |
58 | return # no services need notifying | |
59 | ||
60 | # Do we know this user exists? If not, poke the user query API for | |
61 | # all services which match that user regex. This needs to block as these | |
62 | # user queries need to be made BEFORE pushing the event. | |
63 | yield self._check_user_exists(event.sender) | |
64 | if event.type == EventTypes.Member: | |
65 | yield self._check_user_exists(event.state_key) | |
66 | ||
67 | if not self.started_scheduler: | |
68 | self.scheduler.start().addErrback(log_failure) | |
69 | self.started_scheduler = True | |
70 | ||
71 | # Fork off pushes to these services | |
72 | for service in services: | |
73 | self.scheduler.submit_event_for_as(service, event) | |
59 | current_id(int): The current maximum ID. | |
60 | """ | |
61 | services = yield self.store.get_app_services() | |
62 | if not services or not self.notify_appservices: | |
63 | return | |
64 | ||
65 | self.current_max = max(self.current_max, current_id) | |
66 | if self.is_processing: | |
67 | return | |
68 | ||
69 | with Measure(self.clock, "notify_interested_services"): | |
70 | self.is_processing = True | |
71 | try: | |
72 | upper_bound = self.current_max | |
73 | limit = 100 | |
74 | while True: | |
75 | upper_bound, events = yield self.store.get_new_events_for_appservice( | |
76 | upper_bound, limit | |
77 | ) | |
78 | ||
79 | if not events: | |
80 | break | |
81 | ||
82 | for event in events: | |
83 | # Gather interested services | |
84 | services = yield self._get_services_for_event(event) | |
85 | if len(services) == 0: | |
86 | continue # no services need notifying | |
87 | ||
88 | # Do we know this user exists? If not, poke the user | |
89 | # query API for all services which match that user regex. | |
90 | # This needs to block as these user queries need to be | |
91 | # made BEFORE pushing the event. | |
92 | yield self._check_user_exists(event.sender) | |
93 | if event.type == EventTypes.Member: | |
94 | yield self._check_user_exists(event.state_key) | |
95 | ||
96 | if not self.started_scheduler: | |
97 | self.scheduler.start().addErrback(log_failure) | |
98 | self.started_scheduler = True | |
99 | ||
100 | # Fork off pushes to these services | |
101 | for service in services: | |
102 | preserve_fn(self.scheduler.submit_event_for_as)( | |
103 | service, event | |
104 | ) | |
105 | ||
106 | yield self.store.set_appservice_last_pos(upper_bound) | |
107 | ||
108 | if len(events) < limit: | |
109 | break | |
110 | finally: | |
111 | self.is_processing = False | |
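The rewritten handler above records the new maximum ID, bails out if a run is already in flight, and otherwise drains the event stream in fixed-size batches until a short or empty batch signals it has caught up. The batching loop in isolation, written synchronously with illustrative names (the real code yields Deferreds from the store):

```python
def drain(fetch_batch, upper_bound, limit=100):
    """Repeatedly fetch batches up to upper_bound until caught up.

    fetch_batch(upper_bound, limit) -> (new_upper_bound, events), mirroring
    the store.get_new_events_for_appservice call in the handler above.
    """
    processed = []
    while True:
        upper_bound, events = fetch_batch(upper_bound, limit)
        if not events:
            break
        processed.extend(events)
        # a short batch means the stream is drained for now
        if len(events) < limit:
            break
    return processed
```

The `is_processing` flag plays the role of the single consumer lock: later `notify_interested_services` calls only bump `current_max` and return, so at most one drain loop runs per process.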
74 | 112 | |
75 | 113 | @defer.inlineCallbacks |
76 | 114 | def query_user_exists(self, user_id): |
103 | 141 | association can be found. |
104 | 142 | """ |
105 | 143 | room_alias_str = room_alias.to_string() |
106 | alias_query_services = yield self._get_services_for_event( | |
107 | event=None, | |
108 | restrict_to=ApplicationService.NS_ALIASES, | |
109 | alias_list=[room_alias_str] | |
110 | ) | |
144 | services = yield self.store.get_app_services() | |
145 | alias_query_services = [ | |
146 | s for s in services if ( | |
147 | s.is_interested_in_alias(room_alias_str) | |
148 | ) | |
149 | ] | |
111 | 150 | for alias_service in alias_query_services: |
112 | 151 | is_known_alias = yield self.appservice_api.query_alias( |
113 | 152 | alias_service, room_alias_str |
120 | 159 | defer.returnValue(result) |
121 | 160 | |
122 | 161 | @defer.inlineCallbacks |
123 | def _get_services_for_event(self, event, restrict_to="", alias_list=None): | |
162 | def query_3pe(self, kind, protocol, fields): | |
163 | services = yield self._get_services_for_3pn(protocol) | |
164 | ||
165 | results = yield preserve_context_over_deferred(defer.DeferredList([ | |
166 | preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) | |
167 | for service in services | |
168 | ], consumeErrors=True)) | |
169 | ||
170 | ret = [] | |
171 | for (success, result) in results: | |
172 | if success: | |
173 | ret.extend(result) | |
174 | ||
175 | defer.returnValue(ret) | |
176 | ||
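`query_3pe` above fans out one request per interested service via a `DeferredList` with `consumeErrors=True`, then keeps only the successful results. The collect-successes step minus Twisted, as a plain helper (hypothetical name):

```python
def collect_successes(results):
    """Flatten DeferredList-style [(success, result), ...] pairs.

    Extends the output with each successful result (a list) and skips
    failures, the same filtering the query_3pe loop above applies.
    """
    ret = []
    for success, result in results:
        if success:
            ret.extend(result)
    return ret
```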
177 | @defer.inlineCallbacks | |
178 | def _get_services_for_event(self, event): | |
124 | 179 | """Retrieve a list of application services interested in this event. |
125 | 180 | |
126 | 181 | Args: |
127 | 182 | event(Event): The event to check.
128 | restrict_to(str): The namespace to restrict regex tests to. | |
129 | alias_list: A list of aliases to get services for. If None, this | |
130 | list is obtained from the database. | |
131 | 183 | Returns: |
132 | 184 | list<ApplicationService>: A list of services interested in this |
133 | 185 | event based on the service regex. |
134 | 186 | """ |
135 | member_list = None | |
136 | if hasattr(event, "room_id"): | |
137 | # We need to know the aliases associated with this event.room_id, | |
138 | # if any. | |
139 | if not alias_list: | |
140 | alias_list = yield self.store.get_aliases_for_room( | |
141 | event.room_id | |
142 | ) | |
143 | # We need to know the members associated with this event.room_id, | |
144 | # if any. | |
145 | member_list = yield self.store.get_users_in_room(event.room_id) | |
146 | ||
147 | 187 | services = yield self.store.get_app_services() |
148 | 188 | interested_list = [ |
149 | 189 | s for s in services if ( |
150 | s.is_interested(event, restrict_to, alias_list, member_list) | |
190 | yield s.is_interested(event, self.store) | |
151 | 191 | ) |
152 | 192 | ] |
153 | 193 | defer.returnValue(interested_list) |
159 | 199 | s for s in services if ( |
160 | 200 | s.is_interested_in_user(user_id) |
161 | 201 | ) |
202 | ] | |
203 | defer.returnValue(interested_list) | |
204 | ||
205 | @defer.inlineCallbacks | |
206 | def _get_services_for_3pn(self, protocol): | |
207 | services = yield self.store.get_app_services() | |
208 | interested_list = [ | |
209 | s for s in services if s.is_interested_in_protocol(protocol) | |
162 | 210 | ] |
163 | 211 | defer.returnValue(interested_list) |
164 | 212 |
69 | 69 | self.ldap_uri = hs.config.ldap_uri |
70 | 70 | self.ldap_start_tls = hs.config.ldap_start_tls |
71 | 71 | self.ldap_base = hs.config.ldap_base |
72 | self.ldap_filter = hs.config.ldap_filter | |
73 | 72 | self.ldap_attributes = hs.config.ldap_attributes |
74 | 73 | if self.ldap_mode == LDAPMode.SEARCH: |
75 | 74 | self.ldap_bind_dn = hs.config.ldap_bind_dn |
76 | 75 | self.ldap_bind_password = hs.config.ldap_bind_password |
76 | self.ldap_filter = hs.config.ldap_filter | |
77 | 77 | |
78 | 78 | self.hs = hs # FIXME better possibility to access registrationHandler later? |
79 | 79 | self.device_handler = hs.get_device_handler() |
659 | 659 | else: |
660 | 660 | logger.warn( |
661 | 661 | "ldap registration failed: unexpected (%d!=1) amount of results", |
662 | len(result) | |
662 | len(conn.response) | |
663 | 663 | ) |
664 | 664 | defer.returnValue(False) |
665 | 665 | |
718 | 718 | return macaroon.serialize() |
719 | 719 | |
720 | 720 | def validate_short_term_login_token_and_get_user_id(self, login_token): |
721 | auth_api = self.hs.get_auth() | |
721 | 722 | try: |
722 | 723 | macaroon = pymacaroons.Macaroon.deserialize(login_token) |
723 | auth_api = self.hs.get_auth() | |
724 | auth_api.validate_macaroon(macaroon, "login", True) | |
725 | return self.get_user_from_macaroon(macaroon) | |
726 | except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): | |
727 | raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) | |
724 | user_id = auth_api.get_user_id_from_macaroon(macaroon) | |
725 | auth_api.validate_macaroon(macaroon, "login", True, user_id) | |
726 | return user_id | |
727 | except Exception: | |
728 | raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) | |
728 | 729 | |
729 | 730 | def _generate_base_macaroon(self, user_id): |
730 | 731 | macaroon = pymacaroons.Macaroon( |
735 | 736 | macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) |
736 | 737 | return macaroon |
737 | 738 | |
738 | def get_user_from_macaroon(self, macaroon): | |
739 | user_prefix = "user_id = " | |
740 | for caveat in macaroon.caveats: | |
741 | if caveat.caveat_id.startswith(user_prefix): | |
742 | return caveat.caveat_id[len(user_prefix):] | |
743 | raise AuthError( | |
744 | self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", | |
745 | errcode=Codes.UNKNOWN_TOKEN | |
746 | ) | |
747 | ||
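The removed helper above recovered the user id by scanning the macaroon's first-party caveats for a `user_id = ` prefix (its replacement, `get_user_id_from_macaroon`, is expected to do the equivalent inside the auth API). The extraction itself, sketched over caveats reduced to plain strings:

```python
def user_from_caveats(caveat_ids, prefix="user_id = "):
    """Return the suffix of the first caveat starting with prefix, else None.

    Mirrors the prefix scan the removed get_user_from_macaroon performed
    over macaroon.caveats (here reduced to their caveat_id strings).
    """
    for caveat_id in caveat_ids:
        if caveat_id.startswith(prefix):
            return caveat_id[len(prefix):]
    return None
```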
748 | 739 | @defer.inlineCallbacks |
749 | 740 | def set_password(self, user_id, newpassword, requester=None): |
750 | 741 | password_hash = self.hash(newpassword) |
751 | 742 | |
752 | except_access_token_ids = [requester.access_token_id] if requester else [] | |
743 | except_access_token_id = requester.access_token_id if requester else None | |
753 | 744 | |
754 | 745 | try: |
755 | 746 | yield self.store.user_set_password_hash(user_id, password_hash) |
758 | 749 | raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) |
759 | 750 | raise e |
760 | 751 | yield self.store.user_delete_access_tokens( |
761 | user_id, except_access_token_ids | |
752 | user_id, except_access_token_id | |
762 | 753 | ) |
763 | 754 | yield self.hs.get_pusherpool().remove_pushers_by_user( |
764 | user_id, except_access_token_ids | |
755 | user_id, except_access_token_id | |
765 | 756 | ) |
766 | 757 | |
767 | 758 | @defer.inlineCallbacks |
25 | 25 | from synapse.api.constants import EventTypes, Membership, RejectedReason |
26 | 26 | from synapse.events.validator import EventValidator |
27 | 27 | from synapse.util import unwrapFirstError |
28 | from synapse.util.logcontext import PreserveLoggingContext, preserve_fn | |
28 | from synapse.util.logcontext import ( | |
29 | PreserveLoggingContext, preserve_fn, preserve_context_over_deferred | |
30 | ) | |
29 | 31 | from synapse.util.logutils import log_function |
30 | 32 | from synapse.util.async import run_on_reactor |
31 | 33 | from synapse.util.frozenutils import unfreeze |
248 | 250 | if ev.type != EventTypes.Member: |
249 | 251 | continue |
250 | 252 | try: |
251 | domain = UserID.from_string(ev.state_key).domain | |
253 | domain = get_domain_from_id(ev.state_key) | |
252 | 254 | except: |
253 | 255 | continue |
254 | 256 | |
273 | 275 | |
274 | 276 | @log_function |
275 | 277 | @defer.inlineCallbacks |
276 | def backfill(self, dest, room_id, limit, extremities=[]): | |
278 | def backfill(self, dest, room_id, limit, extremities): | |
277 | 279 | """ Trigger a backfill request to `dest` for the given `room_id` |
278 | 280 | |
279 | 281 | This will attempt to get more events from the remote. This may return |
282 | 284 | """ |
283 | 285 | if dest == self.server_name: |
284 | 286 | raise SynapseError(400, "Can't backfill from self.") |
285 | ||
286 | if not extremities: | |
287 | extremities = yield self.store.get_oldest_events_in_room(room_id) | |
288 | 287 | |
289 | 288 | events = yield self.replication_layer.backfill( |
290 | 289 | dest, |
363 | 362 | missing_auth - failed_to_fetch |
364 | 363 | ) |
365 | 364 | |
366 | results = yield defer.gatherResults( | |
365 | results = yield preserve_context_over_deferred(defer.gatherResults( | |
367 | 366 | [ |
368 | self.replication_layer.get_pdu( | |
367 | preserve_fn(self.replication_layer.get_pdu)( | |
369 | 368 | [dest], |
370 | 369 | event_id, |
371 | 370 | outlier=True, |
374 | 373 | for event_id in missing_auth - failed_to_fetch |
375 | 374 | ], |
376 | 375 | consumeErrors=True |
377 | ).addErrback(unwrapFirstError) | |
378 | auth_events.update({a.event_id: a for a in results}) | |
376 | )).addErrback(unwrapFirstError) | |
377 | auth_events.update({a.event_id: a for a in results if a}) | |
379 | 378 | required_auth.update( |
380 | a_id for event in results for a_id, _ in event.auth_events | |
379 | a_id for event in results for a_id, _ in event.auth_events if event | |
381 | 380 | ) |
382 | 381 | missing_auth = required_auth - set(auth_events) |
383 | 382 | |
453 | 452 | key=lambda e: -int(e[1]) |
454 | 453 | ) |
455 | 454 | max_depth = sorted_extremeties_tuple[0][1] |
455 | ||
456 | # We don't want to specify too many extremities as it causes the backfill | |
457 | # request URI to be too long. | |
458 | extremities = dict(sorted_extremeties_tuple[:5]) | |
456 | 459 | |
457 | 460 | if current_depth > max_depth: |
458 | 461 | logger.debug( |
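The `sorted_extremeties_tuple[:5]` slice above caps how many extremities go into a backfill request so the URI stays under remote servers' path-length limits (the bug fixed by PR #1012). The selection logic, as a standalone sketch (`choose_backfill_extremities` is a hypothetical name):

```python
def choose_backfill_extremities(extremities, limit=5):
    # extremities: {event_id: depth}. Sort deepest-first and keep only a
    # handful, mirroring the [:5] slice in the hunk above.
    sorted_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
    return dict(sorted_tuple[:limit])
```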
550 | 553 | |
551 | 554 | event_ids = list(extremities.keys()) |
552 | 555 | |
553 | states = yield defer.gatherResults([ | |
554 | self.state_handler.resolve_state_groups(room_id, [e]) | |
556 | states = yield preserve_context_over_deferred(defer.gatherResults([ | |
557 | preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) | |
555 | 558 | for e in event_ids |
556 | ]) | |
559 | ])) | |
557 | 560 | states = dict(zip(event_ids, [s[1] for s in states])) |
558 | 561 | |
559 | 562 | for e_id, _ in sorted_extremeties_tuple: |
1092 | 1095 | ) |
1093 | 1096 | |
1094 | 1097 | if event: |
1095 | # FIXME: This is a temporary work around where we occasionally | |
1096 | # return events slightly differently than when they were | |
1097 | # originally signed | |
1098 | event.signatures.update( | |
1099 | compute_event_signature( | |
1100 | event, | |
1101 | self.hs.hostname, | |
1102 | self.hs.config.signing_key[0] | |
1103 | ) | |
1104 | ) | |
1098 | if self.hs.is_mine_id(event.event_id): | |
1099 | # FIXME: This is a temporary work around where we occasionally | |
1100 | # return events slightly differently than when they were | |
1101 | # originally signed | |
1102 | event.signatures.update( | |
1103 | compute_event_signature( | |
1104 | event, | |
1105 | self.hs.hostname, | |
1106 | self.hs.config.signing_key[0] | |
1107 | ) | |
1108 | ) | |
1105 | 1109 | |
1106 | 1110 | if do_auth: |
1107 | 1111 | in_room = yield self.auth.check_host_in_room( |
1110 | 1114 | ) |
1111 | 1115 | if not in_room: |
1112 | 1116 | raise AuthError(403, "Host not in room.") |
1117 | ||
1118 | events = yield self._filter_events_for_server( | |
1119 | origin, event.room_id, [event] | |
1120 | ) | |
1121 | ||
1122 | event = events[0] | |
1113 | 1123 | |
1114 | 1124 | defer.returnValue(event) |
1115 | 1125 | else: |
1157 | 1167 | a bunch of outliers, but not a chunk of individual events that depend |
1158 | 1168 | on each other for state calculations. |
1159 | 1169 | """ |
1160 | contexts = yield defer.gatherResults( | |
1170 | contexts = yield preserve_context_over_deferred(defer.gatherResults( | |
1161 | 1171 | [ |
1162 | self._prep_event( | |
1172 | preserve_fn(self._prep_event)( | |
1163 | 1173 | origin, |
1164 | 1174 | ev_info["event"], |
1165 | 1175 | state=ev_info.get("state"), |
1167 | 1177 | ) |
1168 | 1178 | for ev_info in event_infos |
1169 | 1179 | ] |
1170 | ) | |
1180 | )) | |
1171 | 1181 | |
1172 | 1182 | yield self.store.persist_events( |
1173 | 1183 | [ |
1451 | 1461 | # Do auth conflict res. |
1452 | 1462 | logger.info("Different auth: %s", different_auth) |
1453 | 1463 | |
1454 | different_events = yield defer.gatherResults( | |
1464 | different_events = yield preserve_context_over_deferred(defer.gatherResults( | |
1455 | 1465 | [ |
1456 | self.store.get_event( | |
1466 | preserve_fn(self.store.get_event)( | |
1457 | 1467 | d, |
1458 | 1468 | allow_none=True, |
1459 | 1469 | allow_rejected=False, |
1462 | 1472 | if d in have_events and not have_events[d] |
1463 | 1473 | ], |
1464 | 1474 | consumeErrors=True |
1465 | ).addErrback(unwrapFirstError) | |
1475 | )).addErrback(unwrapFirstError) | |
1466 | 1476 | |
1467 | 1477 | if different_events: |
1468 | 1478 | local_view = dict(auth_events) |
27 | 27 | from synapse.util import unwrapFirstError |
28 | 28 | from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock |
29 | 29 | from synapse.util.caches.snapshot_cache import SnapshotCache |
30 | from synapse.util.logcontext import preserve_fn | |
30 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
31 | from synapse.util.metrics import measure_func | |
31 | 32 | from synapse.visibility import filter_events_for_client |
32 | 33 | |
33 | 34 | from ._base import BaseHandler |
501 | 502 | lambda states: states[event.event_id] |
502 | 503 | ) |
503 | 504 | |
504 | (messages, token), current_state = yield defer.gatherResults( | |
505 | [ | |
506 | self.store.get_recent_events_for_room( | |
507 | event.room_id, | |
508 | limit=limit, | |
509 | end_token=room_end_token, | |
510 | ), | |
511 | deferred_room_state, | |
512 | ] | |
505 | (messages, token), current_state = yield preserve_context_over_deferred( | |
506 | defer.gatherResults( | |
507 | [ | |
508 | preserve_fn(self.store.get_recent_events_for_room)( | |
509 | event.room_id, | |
510 | limit=limit, | |
511 | end_token=room_end_token, | |
512 | ), | |
513 | deferred_room_state, | |
514 | ] | |
515 | ) | |
513 | 516 | ).addErrback(unwrapFirstError) |
514 | 517 | |
515 | 518 | messages = yield filter_events_for_client( |
718 | 721 | |
719 | 722 | presence, receipts, (messages, token) = yield defer.gatherResults( |
720 | 723 | [ |
721 | get_presence(), | |
722 | get_receipts(), | |
723 | self.store.get_recent_events_for_room( | |
724 | preserve_fn(get_presence)(), | |
725 | preserve_fn(get_receipts)(), | |
726 | preserve_fn(self.store.get_recent_events_for_room)( | |
724 | 727 | room_id, |
725 | 728 | limit=limit, |
726 | 729 | end_token=now_token.room_key, |
754 | 757 | |
755 | 758 | defer.returnValue(ret) |
756 | 759 | |
760 | @measure_func("_create_new_client_event") | |
757 | 761 | @defer.inlineCallbacks |
758 | 762 | def _create_new_client_event(self, builder, prev_event_ids=None): |
759 | 763 | if prev_event_ids: |
805 | 809 | (event, context,) |
806 | 810 | ) |
807 | 811 | |
812 | @measure_func("handle_new_client_event") | |
808 | 813 | @defer.inlineCallbacks |
809 | 814 | def handle_new_client_event( |
810 | 815 | self, |
933 | 938 | @defer.inlineCallbacks |
934 | 939 | def _notify(): |
935 | 940 | yield run_on_reactor() |
936 | self.notifier.on_new_room_event( | |
941 | yield self.notifier.on_new_room_event( | |
937 | 942 | event, event_stream_id, max_stream_id, |
938 | 943 | extra_users=extra_users |
939 | 944 | ) |
943 | 948 | # If invite, remove room_state from unsigned before sending. |
944 | 949 | event.unsigned.pop("invite_room_state", None) |
945 | 950 | |
946 | federation_handler.handle_new_event( | |
951 | preserve_fn(federation_handler.handle_new_event)( | |
947 | 952 | event, destinations=destinations, |
948 | 953 | ) |
502 | 502 | defer.returnValue(states) |
503 | 503 | |
504 | 504 | @defer.inlineCallbacks |
505 | def _get_interested_parties(self, states): | |
505 | def _get_interested_parties(self, states, calculate_remote_hosts=True): | |
506 | 506 | """Given a list of states return which entities (rooms, users, servers) |
507 | 507 | are interested in the given states. |
508 | 508 | |
525 | 525 | users_to_states.setdefault(state.user_id, []).append(state) |
526 | 526 | |
527 | 527 | hosts_to_states = {} |
528 | for room_id, states in room_ids_to_states.items(): | |
529 | local_states = filter(lambda s: self.is_mine_id(s.user_id), states) | |
530 | if not local_states: | |
531 | continue | |
532 | ||
533 | hosts = yield self.store.get_joined_hosts_for_room(room_id) | |
534 | for host in hosts: | |
535 | hosts_to_states.setdefault(host, []).extend(local_states) | |
528 | if calculate_remote_hosts: | |
529 | for room_id, states in room_ids_to_states.items(): | |
530 | local_states = filter(lambda s: self.is_mine_id(s.user_id), states) | |
531 | if not local_states: | |
532 | continue | |
533 | ||
534 | hosts = yield self.store.get_joined_hosts_for_room(room_id) | |
535 | for host in hosts: | |
536 | hosts_to_states.setdefault(host, []).extend(local_states) | |
536 | 537 | |
537 | 538 | for user_id, states in users_to_states.items(): |
538 | 539 | local_states = filter(lambda s: self.is_mine_id(s.user_id), states) |
563 | 564 | ) |
564 | 565 | |
565 | 566 | self._push_to_remotes(hosts_to_states) |
567 | ||
568 | @defer.inlineCallbacks | |
569 | def notify_for_states(self, state, stream_id): | |
570 | parties = yield self._get_interested_parties([state]) | |
571 | room_ids_to_states, users_to_states, hosts_to_states = parties | |
572 | ||
573 | self.notifier.on_new_event( | |
574 | "presence_key", stream_id, rooms=room_ids_to_states.keys(), | |
575 | users=[UserID.from_string(u) for u in users_to_states.keys()] | |
576 | ) | |
566 | 577 | |
567 | 578 | def _push_to_remotes(self, hosts_to_states): |
568 | 579 | """Sends state updates to remote servers. |
671 | 682 | ]) |
672 | 683 | |
673 | 684 | @defer.inlineCallbacks |
674 | def set_state(self, target_user, state): | |
685 | def set_state(self, target_user, state, ignore_status_msg=False): | |
675 | 686 | """Set the presence state of the user. |
676 | 687 | """ |
677 | 688 | status_msg = state.get("status_msg", None) |
688 | 699 | prev_state = yield self.current_state_for_user(user_id) |
689 | 700 | |
690 | 701 | new_fields = { |
691 | "state": presence, | |
692 | "status_msg": status_msg if presence != PresenceState.OFFLINE else None | |
702 | "state": presence | |
693 | 703 | } |
704 | ||
705 | if not ignore_status_msg: | |
706 | msg = status_msg if presence != PresenceState.OFFLINE else None | |
707 | new_fields["status_msg"] = msg | |
694 | 708 | |
695 | 709 | if presence == PresenceState.ONLINE: |
696 | 710 | new_fields["last_active_ts"] = self.clock.time_msec() |
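The `ignore_status_msg` flag above is the fix for /sync clobbering `status_msg` (PR #997): presence updates that should not touch the status message simply omit the field. The field-building logic in isolation (a sketch; the hypothetical `presence_update_fields` mirrors the hunk, not the full handler):

```python
def presence_update_fields(presence, status_msg, ignore_status_msg=False):
    # Callers that must not clobber the user's status message (e.g. the
    # /sync-driven "user is online" bump) pass ignore_status_msg=True.
    new_fields = {"state": presence}
    if not ignore_status_msg:
        new_fields["status_msg"] = status_msg if presence != "offline" else None
    return new_fields
```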
58 | 58 | prev_event_ids, |
59 | 59 | txn_id=None, |
60 | 60 | ratelimit=True, |
61 | content=None, | |
61 | 62 | ): |
63 | if content is None: | |
64 | content = {} | |
62 | 65 | msg_handler = self.hs.get_handlers().message_handler |
63 | 66 | |
64 | content = {"membership": membership} | |
67 | content["membership"] = membership | |
65 | 68 | if requester.is_guest: |
66 | 69 | content["kind"] = "guest" |
67 | 70 | |
139 | 142 | remote_room_hosts=None, |
140 | 143 | third_party_signed=None, |
141 | 144 | ratelimit=True, |
145 | content=None, | |
142 | 146 | ): |
143 | key = (target, room_id,) | |
147 | key = (room_id,) | |
144 | 148 | |
145 | 149 | with (yield self.member_linearizer.queue(key)): |
146 | 150 | result = yield self._update_membership( |
152 | 156 | remote_room_hosts=remote_room_hosts, |
153 | 157 | third_party_signed=third_party_signed, |
154 | 158 | ratelimit=ratelimit, |
159 | content=content, | |
155 | 160 | ) |
156 | 161 | |
157 | 162 | defer.returnValue(result) |
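Changing the linearizer key from `(target, room_id)` to `(room_id,)` above widens the lock: all membership events for a room now queue behind one key instead of one key per user, which is the "only process one local membership event per room at a time" change (PR #1005). A toy lock-map analogue of the Linearizer (synapse's real one is Deferred-based, not thread-based):

```python
import threading
from collections import defaultdict

class Linearizer(object):
    # Toy analogue of synapse's Linearizer: hand out one lock per key so
    # work queued under the same key runs strictly one at a time.
    def __init__(self):
        self._locks = defaultdict(threading.Lock)

    def queue(self, key):
        return self._locks[key]
```

With the diff's key of `(room_id,)`, membership changes for two different users in the same room share a lock; the old `(target, room_id)` key let them interleave.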
167 | 172 | remote_room_hosts=None, |
168 | 173 | third_party_signed=None, |
169 | 174 | ratelimit=True, |
175 | content=None, | |
170 | 176 | ): |
177 | if content is None: | |
178 | content = {} | |
179 | ||
171 | 180 | effective_membership_state = action |
172 | 181 | if action in ["kick", "unban"]: |
173 | 182 | effective_membership_state = "leave" |
217 | 226 | if inviter and not self.hs.is_mine(inviter): |
218 | 227 | remote_room_hosts.append(inviter.domain) |
219 | 228 | |
220 | content = {"membership": Membership.JOIN} | |
229 | content["membership"] = Membership.JOIN | |
221 | 230 | |
222 | 231 | profile = self.hs.get_handlers().profile_handler |
223 | 232 | content["displayname"] = yield profile.get_displayname(target) |
271 | 280 | txn_id=txn_id, |
272 | 281 | ratelimit=ratelimit, |
273 | 282 | prev_event_ids=latest_event_ids, |
283 | content=content, | |
274 | 284 | ) |
275 | 285 | |
276 | 286 | @defer.inlineCallbacks |
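The `content=None` threading above implements "pass through user-supplied content in /join" (PR #1039): the caller's dict is kept, but reserved keys are overwritten rather than trusted. A sketch of the merge (the hypothetical `join_content` condenses the hunks; copying avoids both the mutable-default pitfall and mutating the caller's dict):

```python
def join_content(user_content=None, is_guest=False):
    # Keep user-supplied keys, but the membership (and guest kind) fields
    # are always set by the server, never taken from the request body.
    content = dict(user_content) if user_content else {}
    content["membership"] = "join"
    if is_guest:
        content["kind"] = "guest"
    return content
```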
463 | 463 | else: |
464 | 464 | state = {} |
465 | 465 | |
466 | defer.returnValue({ | |
467 | (e.type, e.state_key): e | |
468 | for e in sync_config.filter_collection.filter_room_state(state.values()) | |
469 | }) | |
466 | defer.returnValue({ | |
467 | (e.type, e.state_key): e | |
468 | for e in sync_config.filter_collection.filter_room_state(state.values()) | |
469 | }) | |
470 | 470 | |
471 | 471 | @defer.inlineCallbacks |
472 | 472 | def unread_notifs_for_room_id(self, room_id, sync_config): |
484 | 484 | ) |
485 | 485 | defer.returnValue(notifs) |
486 | 486 | |
487 | # There is no new information in this period, so your notification | |
488 | # count is whatever it was last time. | |
489 | defer.returnValue(None) | |
487 | # There is no new information in this period, so your notification | |
488 | # count is whatever it was last time. | |
489 | defer.returnValue(None) | |
490 | 490 | |
491 | 491 | @defer.inlineCallbacks |
492 | 492 | def generate_sync_result(self, sync_config, since_token=None, full_state=False): |
15 | 15 | from twisted.internet import defer |
16 | 16 | |
17 | 17 | from synapse.api.errors import SynapseError, AuthError |
18 | from synapse.util.logcontext import PreserveLoggingContext | |
18 | from synapse.util.logcontext import ( | |
19 | PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, | |
20 | ) | |
19 | 21 | from synapse.util.metrics import Measure |
20 | 22 | from synapse.types import UserID |
21 | 23 | |
168 | 170 | deferreds = [] |
169 | 171 | for domain in domains: |
170 | 172 | if domain == self.server_name: |
171 | self._push_update_local( | |
173 | preserve_fn(self._push_update_local)( | |
172 | 174 | room_id=room_id, |
173 | 175 | user_id=user_id, |
174 | 176 | typing=typing |
175 | 177 | ) |
176 | 178 | else: |
177 | deferreds.append(self.federation.send_edu( | |
179 | deferreds.append(preserve_fn(self.federation.send_edu)( | |
178 | 180 | destination=domain, |
179 | 181 | edu_type="m.typing", |
180 | 182 | content={ |
184 | 186 | }, |
185 | 187 | )) |
186 | 188 | |
187 | yield defer.DeferredList(deferreds, consumeErrors=True) | |
189 | yield preserve_context_over_deferred( | |
190 | defer.DeferredList(deferreds, consumeErrors=True) | |
191 | ) | |
188 | 192 | |
189 | 193 | @defer.inlineCallbacks |
190 | 194 | def _recv_edu(self, origin, content): |
154 | 154 | time_out=timeout / 1000. if timeout else 60, |
155 | 155 | ) |
156 | 156 | |
157 | response = yield preserve_context_over_fn( | |
158 | send_request, | |
159 | ) | |
157 | response = yield preserve_context_over_fn(send_request) | |
160 | 158 | |
161 | 159 | log_result = "%d %s" % (response.code, response.phrase,) |
162 | 160 | break |
18 | 18 | ) |
19 | 19 | from synapse.util.logcontext import LoggingContext, PreserveLoggingContext |
20 | 20 | from synapse.util.caches import intern_dict |
21 | from synapse.util.metrics import Measure | |
21 | 22 | import synapse.metrics |
22 | 23 | import synapse.events |
23 | 24 | |
73 | 74 | _next_request_id = 0 |
74 | 75 | |
75 | 76 | |
76 | def request_handler(report_metrics=True): | |
77 | def request_handler(include_metrics=False): | |
77 | 78 | """Decorator for ``wrap_request_handler``""" |
78 | return lambda request_handler: wrap_request_handler(request_handler, report_metrics) | |
79 | ||
80 | ||
81 | def wrap_request_handler(request_handler, report_metrics): | |
79 | return lambda request_handler: wrap_request_handler(request_handler, include_metrics) | |
80 | ||
81 | ||
82 | def wrap_request_handler(request_handler, include_metrics=False): | |
82 | 83 | """Wraps a method that acts as a request handler with the necessary logging |
83 | 84 | and exception handling. |
84 | 85 | |
102 | 103 | _next_request_id += 1 |
103 | 104 | |
104 | 105 | with LoggingContext(request_id) as request_context: |
105 | if report_metrics: | |
106 | with Measure(self.clock, "wrapped_request_handler"): | |
106 | 107 | request_metrics = RequestMetrics() |
107 | request_metrics.start(self.clock) | |
108 | ||
109 | request_context.request = request_id | |
110 | with request.processing(): | |
111 | try: | |
112 | with PreserveLoggingContext(request_context): | |
113 | yield request_handler(self, request) | |
114 | except CodeMessageException as e: | |
115 | code = e.code | |
116 | if isinstance(e, SynapseError): | |
117 | logger.info( | |
118 | "%s SynapseError: %s - %s", request, code, e.msg | |
108 | request_metrics.start(self.clock, name=self.__class__.__name__) | |
109 | ||
110 | request_context.request = request_id | |
111 | with request.processing(): | |
112 | try: | |
113 | with PreserveLoggingContext(request_context): | |
114 | if include_metrics: | |
115 | yield request_handler(self, request, request_metrics) | |
116 | else: | |
117 | yield request_handler(self, request) | |
118 | except CodeMessageException as e: | |
119 | code = e.code | |
120 | if isinstance(e, SynapseError): | |
121 | logger.info( | |
122 | "%s SynapseError: %s - %s", request, code, e.msg | |
123 | ) | |
124 | else: | |
125 | logger.exception(e) | |
126 | outgoing_responses_counter.inc(request.method, str(code)) | |
127 | respond_with_json( | |
128 | request, code, cs_exception(e), send_cors=True, | |
129 | pretty_print=_request_user_agent_is_curl(request), | |
130 | version_string=self.version_string, | |
119 | 131 | ) |
120 | else: | |
121 | logger.exception(e) | |
122 | outgoing_responses_counter.inc(request.method, str(code)) | |
123 | respond_with_json( | |
124 | request, code, cs_exception(e), send_cors=True, | |
125 | pretty_print=_request_user_agent_is_curl(request), | |
126 | version_string=self.version_string, | |
127 | ) | |
128 | except: | |
129 | logger.exception( | |
130 | "Failed handle request %s.%s on %r: %r", | |
131 | request_handler.__module__, | |
132 | request_handler.__name__, | |
133 | self, | |
134 | request | |
135 | ) | |
136 | respond_with_json( | |
137 | request, | |
138 | 500, | |
139 | { | |
140 | "error": "Internal server error", | |
141 | "errcode": Codes.UNKNOWN, | |
142 | }, | |
143 | send_cors=True | |
144 | ) | |
145 | finally: | |
146 | try: | |
147 | if report_metrics: | |
132 | except: | |
133 | logger.exception( | |
134 | "Failed handle request %s.%s on %r: %r", | |
135 | request_handler.__module__, | |
136 | request_handler.__name__, | |
137 | self, | |
138 | request | |
139 | ) | |
140 | respond_with_json( | |
141 | request, | |
142 | 500, | |
143 | { | |
144 | "error": "Internal server error", | |
145 | "errcode": Codes.UNKNOWN, | |
146 | }, | |
147 | send_cors=True | |
148 | ) | |
149 | finally: | |
150 | try: | |
148 | 151 | request_metrics.stop( |
149 | self.clock, request, self.__class__.__name__ | |
152 | self.clock, request | |
150 | 153 | ) |
151 | except: | |
152 | pass | |
154 | except Exception as e: | |
155 | logger.warn("Failed to stop metrics: %r", e) | |
153 | 156 | return wrapped_request_handler |
154 | 157 | |
155 | 158 | |
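The `request_handler(include_metrics=...)` refactor above lets `JsonResource` receive the `RequestMetrics` object and name the metric after the servlet it dispatches to, instead of doing its own stop/report dance. A stripped-down model of the decorator factory (synchronous, no Twisted, with a plain dict standing in for `RequestMetrics`):

```python
def request_handler(include_metrics=False):
    # Decorator factory mirroring the refactor above: when include_metrics
    # is True the wrapped handler also receives the metrics object, so the
    # handler can rename it before it is reported.
    def decorator(handler):
        def wrapped(self, request):
            request_metrics = {"name": type(self).__name__}
            if include_metrics:
                return handler(self, request, request_metrics)
            return handler(self, request)
        return wrapped
    return decorator
```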
219 | 222 | # It does its own metric reporting because _async_render dispatches to |
220 | 223 | # a callback and it's the class name of that callback we want to report |
221 | 224 | # against rather than the JsonResource itself. |
222 | @request_handler(report_metrics=False) | |
225 | @request_handler(include_metrics=True) | |
223 | 226 | @defer.inlineCallbacks |
224 | def _async_render(self, request): | |
227 | def _async_render(self, request, request_metrics): | |
225 | 228 | """ This gets called from render() every time someone sends us a request. |
226 | 229 | This checks if anyone has registered a callback for that method and |
227 | 230 | path. |
230 | 233 | self._send_response(request, 200, {}) |
231 | 234 | return |
232 | 235 | |
233 | request_metrics = RequestMetrics() | |
234 | request_metrics.start(self.clock) | |
235 | ||
236 | 236 | # Loop through all the registered callbacks to check if the method |
237 | 237 | # and path regex match |
238 | 238 | for path_entry in self.path_regexs.get(request.method, []): |
246 | 246 | |
247 | 247 | callback = path_entry.callback |
248 | 248 | |
249 | kwargs = intern_dict({ | |
250 | name: urllib.unquote(value).decode("UTF-8") if value else value | |
251 | for name, value in m.groupdict().items() | |
252 | }) | |
253 | ||
254 | callback_return = yield callback(request, **kwargs) | |
255 | if callback_return is not None: | |
256 | code, response = callback_return | |
257 | self._send_response(request, code, response) | |
258 | ||
249 | 259 | servlet_instance = getattr(callback, "__self__", None) |
250 | 260 | if servlet_instance is not None: |
251 | 261 | servlet_classname = servlet_instance.__class__.__name__ |
252 | 262 | else: |
253 | 263 | servlet_classname = "%r" % callback |
254 | 264 | |
255 | kwargs = intern_dict({ | |
256 | name: urllib.unquote(value).decode("UTF-8") if value else value | |
257 | for name, value in m.groupdict().items() | |
258 | }) | |
259 | ||
260 | callback_return = yield callback(request, **kwargs) | |
261 | if callback_return is not None: | |
262 | code, response = callback_return | |
263 | self._send_response(request, code, response) | |
264 | ||
265 | try: | |
266 | request_metrics.stop(self.clock, request, servlet_classname) | |
267 | except: | |
268 | pass | |
265 | request_metrics.name = servlet_classname | |
269 | 266 | |
270 | 267 | return |
271 | 268 | |
297 | 294 | |
298 | 295 | |
299 | 296 | class RequestMetrics(object): |
300 | def start(self, clock): | |
297 | def start(self, clock, name): | |
301 | 298 | self.start = clock.time_msec() |
302 | 299 | self.start_context = LoggingContext.current_context() |
303 | ||
304 | def stop(self, clock, request, servlet_classname): | |
300 | self.name = name | |
301 | ||
302 | def stop(self, clock, request): | |
305 | 303 | context = LoggingContext.current_context() |
306 | 304 | |
307 | 305 | tag = "" |
315 | 313 | ) |
316 | 314 | return |
317 | 315 | |
318 | incoming_requests_counter.inc(request.method, servlet_classname, tag) | |
316 | incoming_requests_counter.inc(request.method, self.name, tag) | |
319 | 317 | |
320 | 318 | response_timer.inc_by( |
321 | 319 | clock.time_msec() - self.start, request.method, |
322 | servlet_classname, tag | |
320 | self.name, tag | |
323 | 321 | ) |
324 | 322 | |
325 | 323 | ru_utime, ru_stime = context.get_resource_usage() |
326 | 324 | |
327 | 325 | response_ru_utime.inc_by( |
328 | ru_utime, request.method, servlet_classname, tag | |
326 | ru_utime, request.method, self.name, tag | |
329 | 327 | ) |
330 | 328 | response_ru_stime.inc_by( |
331 | ru_stime, request.method, servlet_classname, tag | |
329 | ru_stime, request.method, self.name, tag | |
332 | 330 | ) |
333 | 331 | response_db_txn_count.inc_by( |
334 | context.db_txn_count, request.method, servlet_classname, tag | |
332 | context.db_txn_count, request.method, self.name, tag | |
335 | 333 | ) |
336 | 334 | response_db_txn_duration.inc_by( |
337 | context.db_txn_duration, request.method, servlet_classname, tag | |
335 | context.db_txn_duration, request.method, self.name, tag | |
338 | 336 | ) |
339 | 337 | |
340 | 338 |
18 | 18 | |
19 | 19 | from synapse.util.logutils import log_function |
20 | 20 | from synapse.util.async import ObservableDeferred |
21 | from synapse.util.logcontext import PreserveLoggingContext | |
21 | from synapse.util.logcontext import PreserveLoggingContext, preserve_fn | |
22 | from synapse.util.metrics import Measure | |
22 | 23 | from synapse.types import StreamToken |
23 | 24 | from synapse.visibility import filter_events_for_client |
24 | 25 | import synapse.metrics |
66 | 67 | so that it can remove itself from the indexes in the Notifier class. |
67 | 68 | """ |
68 | 69 | |
69 | def __init__(self, user_id, rooms, current_token, time_now_ms, | |
70 | appservice=None): | |
70 | def __init__(self, user_id, rooms, current_token, time_now_ms): | |
71 | 71 | self.user_id = user_id |
72 | self.appservice = appservice | |
73 | 72 | self.rooms = set(rooms) |
74 | 73 | self.current_token = current_token |
75 | 74 | self.last_notified_ms = time_now_ms |
106 | 105 | |
107 | 106 | notifier.user_to_user_stream.pop(self.user_id) |
108 | 107 | |
109 | if self.appservice: | |
110 | notifier.appservice_to_user_streams.get( | |
111 | self.appservice, set() | |
112 | ).discard(self) | |
113 | ||
114 | 108 | def count_listeners(self): |
115 | 109 | return len(self.notify_deferred.observers()) |
116 | 110 | |
141 | 135 | def __init__(self, hs): |
142 | 136 | self.user_to_user_stream = {} |
143 | 137 | self.room_to_user_streams = {} |
144 | self.appservice_to_user_streams = {} | |
145 | 138 | |
146 | 139 | self.event_sources = hs.get_event_sources() |
147 | 140 | self.store = hs.get_datastore() |
167 | 160 | all_user_streams |= x |
168 | 161 | for x in self.user_to_user_stream.values(): |
169 | 162 | all_user_streams.add(x) |
170 | for x in self.appservice_to_user_streams.values(): | |
171 | all_user_streams |= x | |
172 | 163 | |
173 | 164 | return sum(stream.count_listeners() for stream in all_user_streams) |
174 | 165 | metrics.register_callback("listeners", count_listeners) |
181 | 172 | "users", |
182 | 173 | lambda: len(self.user_to_user_stream), |
183 | 174 | ) |
184 | metrics.register_callback( | |
185 | "appservices", | |
186 | lambda: count(bool, self.appservice_to_user_streams.values()), | |
187 | ) | |
188 | ||
175 | ||
176 | @preserve_fn | |
189 | 177 | def on_new_room_event(self, event, room_stream_id, max_room_stream_id, |
190 | 178 | extra_users=[]): |
191 | 179 | """ Used by handlers to inform the notifier something has happened |
207 | 195 | |
208 | 196 | self.notify_replication() |
209 | 197 | |
198 | @preserve_fn | |
210 | 199 | def _notify_pending_new_room_events(self, max_room_stream_id): |
211 | 200 | """Notify for the room events that were queued waiting for a previous |
212 | 201 | event to be persisted. |
224 | 213 | else: |
225 | 214 | self._on_new_room_event(event, room_stream_id, extra_users) |
226 | 215 | |
216 | @preserve_fn | |
227 | 217 | def _on_new_room_event(self, event, room_stream_id, extra_users=[]): |
228 | 218 | """Notify any user streams that are interested in this room event""" |
229 | 219 | # poke any interested application service. |
230 | self.appservice_handler.notify_interested_services(event) | |
231 | ||
232 | app_streams = set() | |
233 | ||
234 | for appservice in self.appservice_to_user_streams: | |
235 | # TODO (kegan): Redundant appservice listener checks? | |
236 | # App services will already be in the room_to_user_streams set, but | |
237 | # that isn't enough. They need to be checked here in order to | |
238 | # receive *invites* for users they are interested in. Does this | |
239 | # make the room_to_user_streams check somewhat obselete? | |
240 | if appservice.is_interested(event): | |
241 | app_user_streams = self.appservice_to_user_streams.get( | |
242 | appservice, set() | |
243 | ) | |
244 | app_streams |= app_user_streams | |
220 | self.appservice_handler.notify_interested_services(room_stream_id) | |
245 | 221 | |
246 | 222 | if event.type == EventTypes.Member and event.membership == Membership.JOIN: |
247 | 223 | self._user_joined_room(event.state_key, event.room_id) |
250 | 226 | "room_key", room_stream_id, |
251 | 227 | users=extra_users, |
252 | 228 | rooms=[event.room_id], |
253 | extra_streams=app_streams, | |
254 | ) | |
255 | ||
256 | def on_new_event(self, stream_key, new_token, users=[], rooms=[], | |
257 | extra_streams=set()): | |
229 | ) | |
230 | ||
231 | @preserve_fn | |
232 | def on_new_event(self, stream_key, new_token, users=[], rooms=[]): | |
258 | 233 | """ Used to inform listeners that something has happend event wise. |
259 | 234 | |
260 | 235 | Will wake up all listeners for the given users and rooms. |
261 | 236 | """ |
262 | 237 | with PreserveLoggingContext(): |
263 | user_streams = set() | |
264 | ||
265 | for user in users: | |
266 | user_stream = self.user_to_user_stream.get(str(user)) | |
267 | if user_stream is not None: | |
268 | user_streams.add(user_stream) | |
269 | ||
270 | for room in rooms: | |
271 | user_streams |= self.room_to_user_streams.get(room, set()) | |
272 | ||
273 | time_now_ms = self.clock.time_msec() | |
274 | for user_stream in user_streams: | |
275 | try: | |
276 | user_stream.notify(stream_key, new_token, time_now_ms) | |
277 | except: | |
278 | logger.exception("Failed to notify listener") | |
279 | ||
280 | self.notify_replication() | |
281 | ||
238 | with Measure(self.clock, "on_new_event"): | |
239 | user_streams = set() | |
240 | ||
241 | for user in users: | |
242 | user_stream = self.user_to_user_stream.get(str(user)) | |
243 | if user_stream is not None: | |
244 | user_streams.add(user_stream) | |
245 | ||
246 | for room in rooms: | |
247 | user_streams |= self.room_to_user_streams.get(room, set()) | |
248 | ||
249 | time_now_ms = self.clock.time_msec() | |
250 | for user_stream in user_streams: | |
251 | try: | |
252 | user_stream.notify(stream_key, new_token, time_now_ms) | |
253 | except: | |
254 | logger.exception("Failed to notify listener") | |
255 | ||
256 | self.notify_replication() | |
257 | ||
258 | @preserve_fn | |
282 | 259 | def on_new_replication_data(self): |
283 | 260 | """Used to inform replication listeners that something has happend |
284 | 261 | without waking up any of the normal user event streams""" |
293 | 270 | """ |
294 | 271 | user_stream = self.user_to_user_stream.get(user_id) |
295 | 272 | if user_stream is None: |
296 | appservice = yield self.store.get_app_service_by_user_id(user_id) | |
297 | 273 | current_token = yield self.event_sources.get_current_token() |
298 | 274 | if room_ids is None: |
299 | 275 | rooms = yield self.store.get_rooms_for_user(user_id) |
301 | 277 | user_stream = _NotifierUserStream( |
302 | 278 | user_id=user_id, |
303 | 279 | rooms=room_ids, |
304 | appservice=appservice, | |
305 | 280 | current_token=current_token, |
306 | 281 | time_now_ms=self.clock.time_msec(), |
307 | 282 | ) |
476 | 451 | s = self.room_to_user_streams.setdefault(room, set()) |
477 | 452 | s.add(user_stream) |
478 | 453 | |
479 | if user_stream.appservice: | |
480 | self.appservice_to_user_stream.setdefault( | |
481 | user_stream.appservice, set() | |
482 | ).add(user_stream) | |
483 | ||
484 | 454 | def _user_joined_room(self, user_id, room_id): |
485 | 455 | new_user_stream = self.user_to_user_stream.get(user_id) |
486 | 456 | if new_user_stream is not None: |
37 | 37 | |
38 | 38 | @defer.inlineCallbacks |
39 | 39 | def handle_push_actions_for_event(self, event, context): |
40 | with Measure(self.clock, "handle_push_actions_for_event"): | |
40 | with Measure(self.clock, "evaluator_for_event"): | |
41 | 41 | bulk_evaluator = yield evaluator_for_event( |
42 | event, self.hs, self.store, context.current_state | |
42 | event, self.hs, self.store, context.state_group, context.current_state | |
43 | 43 | ) |
44 | 44 | |
45 | with Measure(self.clock, "action_for_event_by_user"): | |
45 | 46 | actions_by_user = yield bulk_evaluator.action_for_event_by_user( |
46 | 47 | event, context.current_state |
47 | 48 | ) |
48 | 49 | |
49 | context.push_actions = [ | |
50 | (uid, actions) for uid, actions in actions_by_user.items() | |
51 | ] | |
50 | context.push_actions = [ | |
51 | (uid, actions) for uid, actions in actions_by_user.items() | |
52 | ] |
216 | 216 | 'dont_notify' |
217 | 217 | ] |
218 | 218 | }, |
219 | # This was changed from underride to override so it's closer in priority | |
220 | # to the content rules where the user name highlight rule lives. This | |
221 | # way a room rule is lower priority than both but a custom override rule | |
222 | # is higher priority than both. | |
223 | { | |
224 | 'rule_id': 'global/override/.m.rule.contains_display_name', | |
225 | 'conditions': [ | |
226 | { | |
227 | 'kind': 'contains_display_name' | |
228 | } | |
229 | ], | |
230 | 'actions': [ | |
231 | 'notify', | |
232 | { | |
233 | 'set_tweak': 'sound', | |
234 | 'value': 'default' | |
235 | }, { | |
236 | 'set_tweak': 'highlight' | |
237 | } | |
238 | ] | |
239 | }, | |
219 | 240 | ] |
220 | 241 | |
221 | 242 | |
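The comment in the hunk above explains that the display-name rule moved from underride to override so it outranks room rules but can still be beaten by a custom override. As a reference sketch, the Matrix push-rule kinds are evaluated in a fixed precedence order (per the client-server spec; the helper name here is illustrative, not part of Synapse):

```python
# Matrix push-rule kinds, evaluated highest-priority first.
# After this change, .m.rule.contains_display_name lives in "override",
# above both "content" (username keyword rule) and "room" rules.
RULE_KINDS_BY_PRIORITY = [
    "override",
    "content",
    "room",
    "sender",
    "underride",
]


def kind_priority(kind):
    """Lower index means higher priority when rules are evaluated."""
    return RULE_KINDS_BY_PRIORITY.index(kind)
```

With this ordering, a user's custom override rule (also kind "override", listed earlier within the kind) can still suppress the display-name highlight, while any room rule cannot.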
238 | 259 | }, { |
239 | 260 | 'set_tweak': 'highlight', |
240 | 261 | 'value': False |
241 | } | |
242 | ] | |
243 | }, | |
244 | { | |
245 | 'rule_id': 'global/underride/.m.rule.contains_display_name', | |
246 | 'conditions': [ | |
247 | { | |
248 | 'kind': 'contains_display_name' | |
249 | } | |
250 | ], | |
251 | 'actions': [ | |
252 | 'notify', | |
253 | { | |
254 | 'set_tweak': 'sound', | |
255 | 'value': 'default' | |
256 | }, { | |
257 | 'set_tweak': 'highlight' | |
258 | 262 | } |
259 | 263 | ] |
260 | 264 | }, |
35 | 35 | |
36 | 36 | |
37 | 37 | @defer.inlineCallbacks |
38 | def evaluator_for_event(event, hs, store, current_state): | |
39 | room_id = event.room_id | |
40 | # We also will want to generate notifs for other people in the room so | |
41 | # their unread counts are correct in the event stream, but to avoid | |
42 | # generating them for bot / AS users etc, we only do so for people who've | |
43 | # sent a read receipt into the room. | |
44 | ||
45 | local_users_in_room = set( | |
46 | e.state_key for e in current_state.values() | |
47 | if e.type == EventTypes.Member and e.membership == Membership.JOIN | |
48 | and hs.is_mine_id(e.state_key) | |
38 | def evaluator_for_event(event, hs, store, state_group, current_state): | |
39 | rules_by_user = yield store.bulk_get_push_rules_for_room( | |
40 | event.room_id, state_group, current_state | |
49 | 41 | ) |
50 | ||
51 | # users in the room who have pushers need to get push rules run because | |
52 | # that's how their pushers work | |
53 | if_users_with_pushers = yield store.get_if_users_have_pushers( | |
54 | local_users_in_room | |
55 | ) | |
56 | user_ids = set( | |
57 | uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher | |
58 | ) | |
59 | ||
60 | users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id) | |
61 | ||
62 | # any users with pushers must be ours: they have pushers | |
63 | for uid in users_with_receipts: | |
64 | if uid in local_users_in_room: | |
65 | user_ids.add(uid) | |
66 | 42 | |
67 | 43 | # if this event is an invite event, we may need to run rules for the user |
68 | 44 | # who's been invited, otherwise they won't get told they've been invited |
71 | 47 | if invited_user and hs.is_mine_id(invited_user): |
72 | 48 | has_pusher = yield store.user_has_pusher(invited_user) |
73 | 49 | if has_pusher: |
74 | user_ids.add(invited_user) | |
75 | ||
76 | rules_by_user = yield _get_rules(room_id, user_ids, store) | |
50 | rules_by_user[invited_user] = yield store.get_push_rules_for_user( | |
51 | invited_user | |
52 | ) | |
77 | 53 | |
78 | 54 | defer.returnValue(BulkPushRuleEvaluator( |
79 | room_id, rules_by_user, user_ids, store | |
55 | event.room_id, rules_by_user, store | |
80 | 56 | )) |
81 | 57 | |
82 | 58 | |
89 | 65 | the same logic to run the actual rules, but could be optimised further |
90 | 66 | (see https://matrix.org/jira/browse/SYN-562) |
91 | 67 | """ |
92 | def __init__(self, room_id, rules_by_user, users_in_room, store): | |
68 | def __init__(self, room_id, rules_by_user, store): | |
93 | 69 | self.room_id = room_id |
94 | 70 | self.rules_by_user = rules_by_user |
95 | self.users_in_room = users_in_room | |
96 | 71 | self.store = store |
97 | 72 | |
98 | 73 | @defer.inlineCallbacks |
16 | 16 | from synapse.util.presentable_names import ( |
17 | 17 | calculate_room_name, name_from_member_event |
18 | 18 | ) |
19 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
19 | 20 | |
20 | 21 | |
21 | 22 | @defer.inlineCallbacks |
22 | 23 | def get_badge_count(store, user_id): |
23 | invites, joins = yield defer.gatherResults([ | |
24 | store.get_invited_rooms_for_user(user_id), | |
25 | store.get_rooms_for_user(user_id), | |
26 | ], consumeErrors=True) | |
24 | invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ | |
25 | preserve_fn(store.get_invited_rooms_for_user)(user_id), | |
26 | preserve_fn(store.get_rooms_for_user)(user_id), | |
27 | ], consumeErrors=True)) | |
27 | 28 | |
28 | 29 | my_receipts_by_room = yield store.get_receipts_for_user( |
29 | 30 | user_id, "m.read", |
16 | 16 | from twisted.internet import defer |
17 | 17 | |
18 | 18 | import pusher |
19 | from synapse.util.logcontext import preserve_fn | |
19 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
20 | 20 | from synapse.util.async import run_on_reactor |
21 | 21 | |
22 | 22 | import logging |
101 | 101 | yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) |
102 | 102 | |
103 | 103 | @defer.inlineCallbacks |
104 | def remove_pushers_by_user(self, user_id, except_token_ids=[]): | |
104 | def remove_pushers_by_user(self, user_id, except_access_token_id=None): | |
105 | 105 | all = yield self.store.get_all_pushers() |
106 | 106 | logger.info( |
107 | "Removing all pushers for user %s except access tokens ids %r", | |
108 | user_id, except_token_ids | |
107 | "Removing all pushers for user %s except access token id %r", | |
108 | user_id, except_access_token_id | |
109 | 109 | ) |
110 | 110 | for p in all: |
111 | if p['user_name'] == user_id and p['access_token'] not in except_token_ids: | |
111 | if p['user_name'] == user_id and p['access_token'] != except_access_token_id: | |
112 | 112 | logger.info( |
113 | 113 | "Removing pusher for app id %s, pushkey %s, user %s", |
114 | 114 | p['app_id'], p['pushkey'], p['user_name'] |
129 | 129 | if u in self.pushers: |
130 | 130 | for p in self.pushers[u].values(): |
131 | 131 | deferreds.append( |
132 | p.on_new_notifications(min_stream_id, max_stream_id) | |
132 | preserve_fn(p.on_new_notifications)( | |
133 | min_stream_id, max_stream_id | |
134 | ) | |
133 | 135 | ) |
134 | 136 | |
135 | yield defer.gatherResults(deferreds) | |
137 | yield preserve_context_over_deferred(defer.gatherResults(deferreds)) | |
136 | 138 | except: |
137 | 139 | logger.exception("Exception in pusher on_new_notifications") |
138 | 140 | |
154 | 156 | if u in self.pushers: |
155 | 157 | for p in self.pushers[u].values(): |
156 | 158 | deferreds.append( |
157 | p.on_new_receipts(min_stream_id, max_stream_id) | |
159 | preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) | |
158 | 160 | ) |
159 | 161 | |
160 | yield defer.gatherResults(deferreds) | |
162 | yield preserve_context_over_deferred(defer.gatherResults(deferreds)) | |
161 | 163 | except: |
162 | 164 | logger.exception("Exception in pusher on_new_receipts") |
163 | 165 |
40 | 40 | ("push_rules",), |
41 | 41 | ("pushers",), |
42 | 42 | ("state",), |
43 | ("caches",), | |
43 | 44 | ) |
44 | 45 | |
45 | 46 | |
69 | 70 | * "backfill": Old events that have been backfilled from other servers. |
70 | 71 | * "push_rules": Per user changes to push rules. |
71 | 72 | * "pushers": Per user changes to their pushers. |
73 | * "caches": Cache invalidations. | |
72 | 74 | |
73 | 75 | The API takes two additional query parameters: |
74 | 76 | |
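The replication response serialises each requested stream as a JSON object of rows plus the column names passed to `write_header_and_rows`. A minimal sketch of consuming such a payload for the new "caches" stream — the envelope keys (`position`, `field_names`, `rows`) and the sample values are assumptions for illustration, matching the column tuple used later in this file:

```python
# Hypothetical replication response fragment for the "caches" stream.
# The column names mirror the write_header_and_rows() call for this
# stream; everything else here is made-up sample data.
response = {
    "caches": {
        "position": 215,
        "field_names": ["position", "cache_func", "keys", "invalidation_ts"],
        "rows": [
            [215, "get_user_by_access_token", ["some_token"], 1471000000000],
        ],
    }
}


def rows_as_dicts(stream):
    """Pair each row with the stream's field names for easier access."""
    return [dict(zip(stream["field_names"], row)) for row in stream["rows"]]
```

A consumer would then advance its "caches" token to `position` after applying every row.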
128 | 130 | push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() |
129 | 131 | pushers_token = self.store.get_pushers_stream_token() |
130 | 132 | state_token = self.store.get_state_stream_token() |
133 | caches_token = self.store.get_cache_stream_token() | |
131 | 134 | |
132 | 135 | defer.returnValue(_ReplicationToken( |
133 | 136 | room_stream_token, |
139 | 142 | push_rules_token, |
140 | 143 | pushers_token, |
141 | 144 | state_token, |
145 | caches_token, | |
142 | 146 | )) |
143 | 147 | |
144 | 148 | @request_handler() |
187 | 191 | yield self.push_rules(writer, current_token, limit, request_streams) |
188 | 192 | yield self.pushers(writer, current_token, limit, request_streams) |
189 | 193 | yield self.state(writer, current_token, limit, request_streams) |
194 | yield self.caches(writer, current_token, limit, request_streams) | |
190 | 195 | self.streams(writer, current_token, request_streams) |
191 | 196 | |
192 | 197 | logger.info("Replicated %d rows", writer.total) |
378 | 383 | "position", "type", "state_key", "event_id" |
379 | 384 | )) |
380 | 385 | |
386 | @defer.inlineCallbacks | |
387 | def caches(self, writer, current_token, limit, request_streams): | |
388 | current_position = current_token.caches | |
389 | ||
390 | caches = request_streams.get("caches") | |
391 | ||
392 | if caches is not None: | |
393 | updated_caches = yield self.store.get_all_updated_caches( | |
394 | caches, current_position, limit | |
395 | ) | |
396 | writer.write_header_and_rows("caches", updated_caches, ( | |
397 | "position", "cache_func", "keys", "invalidation_ts" | |
398 | )) | |
399 | ||
381 | 400 | |
382 | 401 | class _Writer(object): |
383 | 402 | """Writes the streams as a JSON object as the response to the request""" |
406 | 425 | |
407 | 426 | class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( |
408 | 427 | "events", "presence", "typing", "receipts", "account_data", "backfill", |
409 | "push_rules", "pushers", "state" | |
428 | "push_rules", "pushers", "state", "caches", | |
410 | 429 | ))): |
411 | 430 | __slots__ = [] |
412 | 431 |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from synapse.storage._base import SQLBaseStore |
16 | from synapse.storage.engines import PostgresEngine | |
16 | 17 | from twisted.internet import defer |
18 | ||
19 | from ._slaved_id_tracker import SlavedIdTracker | |
20 | ||
21 | import logging | |
22 | ||
23 | logger = logging.getLogger(__name__) | |
17 | 24 | |
18 | 25 | |
19 | 26 | class BaseSlavedStore(SQLBaseStore): |
20 | 27 | def __init__(self, db_conn, hs): |
21 | 28 | super(BaseSlavedStore, self).__init__(hs) |
29 | if isinstance(self.database_engine, PostgresEngine): | |
30 | self._cache_id_gen = SlavedIdTracker( | |
31 | db_conn, "cache_invalidation_stream", "stream_id", | |
32 | ) | |
33 | else: | |
34 | self._cache_id_gen = None | |
22 | 35 | |
23 | 36 | def stream_positions(self): |
24 | return {} | |
37 | pos = {} | |
38 | if self._cache_id_gen: | |
39 | pos["caches"] = self._cache_id_gen.get_current_token() | |
40 | return pos | |
25 | 41 | |
26 | 42 | def process_replication(self, result): |
43 | stream = result.get("caches") | |
44 | if stream: | |
45 | for row in stream["rows"]: | |
46 | ( | |
47 | position, cache_func, keys, invalidation_ts, | |
48 | ) = row | |
49 | ||
50 | try: | |
51 | getattr(self, cache_func).invalidate(tuple(keys)) | |
52 | except AttributeError: | |
53 | logger.info("Got unexpected cache_func: %r", cache_func) | |
54 | self._cache_id_gen.advance(int(stream["position"])) | |
27 | 55 | return defer.succeed(None) |
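The `process_replication` hunk above dispatches each invalidation by looking the cache up by name with `getattr` and skipping names the slave store doesn't know. A self-contained sketch of that dispatch pattern (the class and attribute names here are illustrative, not Synapse's):

```python
class FakeCache(object):
    """Stands in for a cached-function descriptor with an invalidate()."""
    def __init__(self):
        self.invalidated = []

    def invalidate(self, key):
        self.invalidated.append(key)


class FakeSlavedStore(object):
    def __init__(self):
        # Cache attributes are keyed by the originating function's name.
        self.get_user_by_id = FakeCache()

    def process_cache_rows(self, rows):
        for position, cache_func, keys, _invalidation_ts in rows:
            try:
                # Dispatch by name; unknown cache names are simply skipped
                # (the real code logs them at INFO).
                getattr(self, cache_func).invalidate(tuple(keys))
            except AttributeError:
                pass


store = FakeSlavedStore()
store.process_cache_rows([(1, "get_user_by_id", ["@alice:example.com"], 0)])
```

Tolerating unknown names matters because master and slave may run slightly different code and therefore cache different functions.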
27 | 27 | |
28 | 28 | get_app_service_by_token = DataStore.get_app_service_by_token.__func__ |
29 | 29 | get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ |
30 | get_app_services = DataStore.get_app_services.__func__ | |
31 | get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ | |
32 | create_appservice_txn = DataStore.create_appservice_txn.__func__ | |
33 | get_appservices_by_state = DataStore.get_appservices_by_state.__func__ | |
34 | get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__ | |
35 | _get_last_txn = DataStore._get_last_txn.__func__ | |
36 | complete_appservice_txn = DataStore.complete_appservice_txn.__func__ | |
37 | get_appservice_state = DataStore.get_appservice_state.__func__ | |
38 | set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__ | |
39 | set_appservice_state = DataStore.set_appservice_state.__func__ |
19 | 19 | class DirectoryStore(BaseSlavedStore): |
20 | 20 | get_aliases_for_room = DirectoryStore.__dict__[ |
21 | 21 | "get_aliases_for_room" |
22 | ].orig | |
22 | ] |
24 | 24 | # TODO: use the cached version and invalidate deleted tokens |
25 | 25 | get_user_by_access_token = RegistrationStore.__dict__[ |
26 | 26 | "get_user_by_access_token" |
27 | ].orig | |
27 | ] | |
28 | 28 | |
29 | 29 | _query_for_auth = DataStore._query_for_auth.__func__ |
30 | get_user_by_id = RegistrationStore.__dict__[ | |
31 | "get_user_by_id" | |
32 | ] |
45 | 45 | account_data, |
46 | 46 | report_event, |
47 | 47 | openid, |
48 | notifications, | |
48 | 49 | devices, |
50 | thirdparty, | |
49 | 51 | ) |
50 | 52 | |
51 | 53 | from synapse.http.server import JsonResource |
90 | 92 | account_data.register_servlets(hs, client_resource) |
91 | 93 | report_event.register_servlets(hs, client_resource) |
92 | 94 | openid.register_servlets(hs, client_resource) |
95 | notifications.register_servlets(hs, client_resource) | |
93 | 96 | devices.register_servlets(hs, client_resource) |
97 | thirdparty.register_servlets(hs, client_resource) |
26 | 26 | |
27 | 27 | class WhoisRestServlet(ClientV1RestServlet): |
28 | 28 | PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") |
29 | ||
30 | def __init__(self, hs): | |
31 | super(WhoisRestServlet, self).__init__(hs) | |
32 | self.handlers = hs.get_handlers() | |
29 | 33 | |
30 | 34 | @defer.inlineCallbacks |
31 | 35 | def on_GET(self, request, user_id): |
81 | 85 | "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" |
82 | 86 | ) |
83 | 87 | |
88 | def __init__(self, hs): | |
89 | super(PurgeHistoryRestServlet, self).__init__(hs) | |
90 | self.handlers = hs.get_handlers() | |
91 | ||
84 | 92 | @defer.inlineCallbacks |
85 | 93 | def on_POST(self, request, room_id, event_id): |
86 | 94 | requester = yield self.auth.get_user_by_req(request) |
56 | 56 | hs (synapse.server.HomeServer): |
57 | 57 | """ |
58 | 58 | self.hs = hs |
59 | self.handlers = hs.get_handlers() | |
60 | 59 | self.builder_factory = hs.get_event_builder_factory() |
61 | 60 | self.auth = hs.get_v1auth() |
62 | 61 | self.txns = HttpTransactionStore() |
34 | 34 | |
35 | 35 | class ClientDirectoryServer(ClientV1RestServlet): |
36 | 36 | PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") |
37 | ||
38 | def __init__(self, hs): | |
39 | super(ClientDirectoryServer, self).__init__(hs) | |
40 | self.handlers = hs.get_handlers() | |
37 | 41 | |
38 | 42 | @defer.inlineCallbacks |
39 | 43 | def on_GET(self, request, room_alias): |
145 | 149 | def __init__(self, hs): |
146 | 150 | super(ClientDirectoryListServer, self).__init__(hs) |
147 | 151 | self.store = hs.get_datastore() |
152 | self.handlers = hs.get_handlers() | |
148 | 153 | |
149 | 154 | @defer.inlineCallbacks |
150 | 155 | def on_GET(self, request, room_id): |
31 | 31 | |
32 | 32 | DEFAULT_LONGPOLL_TIME_MS = 30000 |
33 | 33 | |
34 | def __init__(self, hs): | |
35 | super(EventStreamRestServlet, self).__init__(hs) | |
36 | self.event_stream_handler = hs.get_event_stream_handler() | |
37 | ||
34 | 38 | @defer.inlineCallbacks |
35 | 39 | def on_GET(self, request): |
36 | 40 | requester = yield self.auth.get_user_by_req( |
45 | 49 | if "room_id" in request.args: |
46 | 50 | room_id = request.args["room_id"][0] |
47 | 51 | |
48 | handler = self.handlers.event_stream_handler | |
49 | 52 | pagin_config = PaginationConfig.from_request(request) |
50 | 53 | timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS |
51 | 54 | if "timeout" in request.args: |
56 | 59 | |
57 | 60 | as_client_event = "raw" not in request.args |
58 | 61 | |
59 | chunk = yield handler.get_stream( | |
62 | chunk = yield self.event_stream_handler.get_stream( | |
60 | 63 | requester.user.to_string(), |
61 | 64 | pagin_config, |
62 | 65 | timeout=timeout, |
79 | 82 | def __init__(self, hs): |
80 | 83 | super(EventRestServlet, self).__init__(hs) |
81 | 84 | self.clock = hs.get_clock() |
85 | self.event_handler = hs.get_event_handler() | |
82 | 86 | |
83 | 87 | @defer.inlineCallbacks |
84 | 88 | def on_GET(self, request, event_id): |
85 | 89 | requester = yield self.auth.get_user_by_req(request) |
86 | handler = self.handlers.event_handler | |
87 | event = yield handler.get_event(requester.user, event_id) | |
90 | event = yield self.event_handler.get_event(requester.user, event_id) | |
88 | 91 | |
89 | 92 | time_now = self.clock.time_msec() |
90 | 93 | if event: |
22 | 22 | class InitialSyncRestServlet(ClientV1RestServlet): |
23 | 23 | PATTERNS = client_path_patterns("/initialSync$") |
24 | 24 | |
25 | def __init__(self, hs): | |
26 | super(InitialSyncRestServlet, self).__init__(hs) | |
27 | self.handlers = hs.get_handlers() | |
28 | ||
25 | 29 | @defer.inlineCallbacks |
26 | 30 | def on_GET(self, request): |
27 | 31 | requester = yield self.auth.get_user_by_req(request) |
53 | 53 | self.jwt_secret = hs.config.jwt_secret |
54 | 54 | self.jwt_algorithm = hs.config.jwt_algorithm |
55 | 55 | self.cas_enabled = hs.config.cas_enabled |
56 | self.cas_server_url = hs.config.cas_server_url | |
57 | self.cas_required_attributes = hs.config.cas_required_attributes | |
58 | self.servername = hs.config.server_name | |
59 | self.http_client = hs.get_simple_http_client() | |
60 | 56 | self.auth_handler = self.hs.get_auth_handler() |
61 | 57 | self.device_handler = self.hs.get_device_handler() |
58 | self.handlers = hs.get_handlers() | |
62 | 59 | |
63 | 60 | def on_GET(self, request): |
64 | 61 | flows = [] |
109 | 106 | LoginRestServlet.JWT_TYPE): |
110 | 107 | result = yield self.do_jwt_login(login_submission) |
111 | 108 | defer.returnValue(result) |
112 | # TODO Delete this after all CAS clients switch to token login instead | |
113 | elif self.cas_enabled and (login_submission["type"] == | |
114 | LoginRestServlet.CAS_TYPE): | |
115 | uri = "%s/proxyValidate" % (self.cas_server_url,) | |
116 | args = { | |
117 | "ticket": login_submission["ticket"], | |
118 | "service": login_submission["service"] | |
119 | } | |
120 | body = yield self.http_client.get_raw(uri, args) | |
121 | result = yield self.do_cas_login(body) | |
122 | defer.returnValue(result) | |
123 | 109 | elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: |
124 | 110 | result = yield self.do_token_login(login_submission) |
125 | 111 | defer.returnValue(result) |
187 | 173 | "home_server": self.hs.hostname, |
188 | 174 | "device_id": device_id, |
189 | 175 | } |
190 | ||
191 | defer.returnValue((200, result)) | |
192 | ||
193 | # TODO Delete this after all CAS clients switch to token login instead | |
194 | @defer.inlineCallbacks | |
195 | def do_cas_login(self, cas_response_body): | |
196 | user, attributes = self.parse_cas_response(cas_response_body) | |
197 | ||
198 | for required_attribute, required_value in self.cas_required_attributes.items(): | |
199 | # If required attribute was not in CAS Response - Forbidden | |
200 | if required_attribute not in attributes: | |
201 | raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) | |
202 | ||
203 | # Also need to check value | |
204 | if required_value is not None: | |
205 | actual_value = attributes[required_attribute] | |
206 | # If required attribute value does not match expected - Forbidden | |
207 | if required_value != actual_value: | |
208 | raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) | |
209 | ||
210 | user_id = UserID.create(user, self.hs.hostname).to_string() | |
211 | auth_handler = self.auth_handler | |
212 | registered_user_id = yield auth_handler.check_user_exists(user_id) | |
213 | if registered_user_id: | |
214 | access_token, refresh_token = ( | |
215 | yield auth_handler.get_login_tuple_for_user_id( | |
216 | registered_user_id | |
217 | ) | |
218 | ) | |
219 | result = { | |
220 | "user_id": registered_user_id, # may have changed | |
221 | "access_token": access_token, | |
222 | "refresh_token": refresh_token, | |
223 | "home_server": self.hs.hostname, | |
224 | } | |
225 | ||
226 | else: | |
227 | user_id, access_token = ( | |
228 | yield self.handlers.registration_handler.register(localpart=user) | |
229 | ) | |
230 | result = { | |
231 | "user_id": user_id, # may have changed | |
232 | "access_token": access_token, | |
233 | "home_server": self.hs.hostname, | |
234 | } | |
235 | 176 | |
236 | 177 | defer.returnValue((200, result)) |
237 | 178 | |
291 | 232 | } |
292 | 233 | |
293 | 234 | defer.returnValue((200, result)) |
294 | ||
295 | # TODO Delete this after all CAS clients switch to token login instead | |
296 | def parse_cas_response(self, cas_response_body): | |
297 | root = ET.fromstring(cas_response_body) | |
298 | if not root.tag.endswith("serviceResponse"): | |
299 | raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) | |
300 | if not root[0].tag.endswith("authenticationSuccess"): | |
301 | raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) | |
302 | for child in root[0]: | |
303 | if child.tag.endswith("user"): | |
304 | user = child.text | |
305 | if child.tag.endswith("attributes"): | |
306 | attributes = {} | |
307 | for attribute in child: | |
308 | # ElementTree library expands the namespace in attribute tags | |
309 | # to the full URL of the namespace. | |
310 | # See (https://docs.python.org/2/library/xml.etree.elementtree.html) | |
311 | # We don't care about namespace here and it will always be encased in | |
312 | # curly braces, so we remove them. | |
313 | if "}" in attribute.tag: | |
314 | attributes[attribute.tag.split("}")[1]] = attribute.text | |
315 | else: | |
316 | attributes[attribute.tag] = attribute.text | |
317 | if user is None or attributes is None: | |
318 | raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) | |
319 | ||
320 | return (user, attributes) | |
321 | 235 | |
322 | 236 | def _register_device(self, user_id, login_submission): |
323 | 237 | """Register a device for a user. |
346 | 260 | def __init__(self, hs): |
347 | 261 | super(SAML2RestServlet, self).__init__(hs) |
348 | 262 | self.sp_config = hs.config.saml2_config_path |
263 | self.handlers = hs.get_handlers() | |
349 | 264 | |
350 | 265 | @defer.inlineCallbacks |
351 | 266 | def on_POST(self, request): |
383 | 298 | defer.returnValue((200, {"status": "not_authenticated"})) |
384 | 299 | |
385 | 300 | |
386 | # TODO Delete this after all CAS clients switch to token login instead | |
387 | class CasRestServlet(ClientV1RestServlet): | |
388 | PATTERNS = client_path_patterns("/login/cas", releases=()) | |
389 | ||
390 | def __init__(self, hs): | |
391 | super(CasRestServlet, self).__init__(hs) | |
392 | self.cas_server_url = hs.config.cas_server_url | |
393 | ||
394 | def on_GET(self, request): | |
395 | return (200, {"serverUrl": self.cas_server_url}) | |
396 | ||
397 | ||
398 | 301 | class CasRedirectServlet(ClientV1RestServlet): |
399 | 302 | PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) |
400 | 303 | |
426 | 329 | self.cas_server_url = hs.config.cas_server_url |
427 | 330 | self.cas_service_url = hs.config.cas_service_url |
428 | 331 | self.cas_required_attributes = hs.config.cas_required_attributes |
332 | self.auth_handler = hs.get_auth_handler() | |
333 | self.handlers = hs.get_handlers() | |
429 | 334 | |
430 | 335 | @defer.inlineCallbacks |
431 | 336 | def on_GET(self, request): |
478 | 383 | return urlparse.urlunparse(url_parts) |
479 | 384 | |
480 | 385 | def parse_cas_response(self, cas_response_body): |
481 | root = ET.fromstring(cas_response_body) | |
482 | if not root.tag.endswith("serviceResponse"): | |
483 | raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) | |
484 | if not root[0].tag.endswith("authenticationSuccess"): | |
485 | raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) | |
486 | for child in root[0]: | |
487 | if child.tag.endswith("user"): | |
488 | user = child.text | |
489 | if child.tag.endswith("attributes"): | |
490 | attributes = {} | |
491 | for attribute in child: | |
492 | # ElementTree library expands the namespace in attribute tags | |
493 | # to the full URL of the namespace. | |
494 | # See (https://docs.python.org/2/library/xml.etree.elementtree.html) | |
495 | # We don't care about namespace here and it will always be encased in | |
496 | # curly braces, so we remove them. | |
497 | if "}" in attribute.tag: | |
498 | attributes[attribute.tag.split("}")[1]] = attribute.text | |
499 | else: | |
500 | attributes[attribute.tag] = attribute.text | |
501 | if user is None or attributes is None: | |
502 | raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) | |
503 | ||
504 | return (user, attributes) | |
386 | user = None | |
387 | attributes = None | |
388 | try: | |
389 | root = ET.fromstring(cas_response_body) | |
390 | if not root.tag.endswith("serviceResponse"): | |
391 | raise Exception("root of CAS response is not serviceResponse") | |
392 | success = (root[0].tag.endswith("authenticationSuccess")) | |
393 | for child in root[0]: | |
394 | if child.tag.endswith("user"): | |
395 | user = child.text | |
396 | if child.tag.endswith("attributes"): | |
397 | attributes = {} | |
398 | for attribute in child: | |
399 | # ElementTree library expands the namespace in | |
400 | # attribute tags to the full URL of the namespace. | |
401 | # We don't care about namespace here and it will always | |
402 | # be encased in curly braces, so we remove them. | |
403 | tag = attribute.tag | |
404 | if "}" in tag: | |
405 | tag = tag.split("}")[1] | |
406 | attributes[tag] = attribute.text | |
407 | if user is None: | |
408 | raise Exception("CAS response does not contain user") | |
409 | if attributes is None: | |
410 | raise Exception("CAS response does not contain attributes") | |
411 | except Exception: | |
412 | logger.error("Error parsing CAS response", exc_info=1) | |
413 | raise LoginError(401, "Invalid CAS response", | |
414 | errcode=Codes.UNAUTHORIZED) | |
415 | if not success: | |
416 | raise LoginError(401, "Unsuccessful CAS response", | |
417 | errcode=Codes.UNAUTHORIZED) | |
418 | return user, attributes | |
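The rewritten parser above strips `{namespace}` prefixes with `tag.split("}")[1]`. The behaviour it works around is standard `xml.etree.ElementTree`: namespaced tags come back with the namespace URI expanded into the tag name. A small standalone demonstration (the sample XML is a made-up CAS-style response):

```python
import xml.etree.ElementTree as ET

# Illustrative CAS-style body; the namespace URI is expanded by ElementTree.
body = (
    '<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">'
    '<cas:authenticationSuccess><cas:user>alice</cas:user>'
    '</cas:authenticationSuccess></cas:serviceResponse>'
)
root = ET.fromstring(body)
# Tags arrive as "{http://www.yale.edu/tp/cas}serviceResponse", which is
# why the parser matches with endswith() rather than exact equality.
assert root.tag.endswith("serviceResponse")


def strip_ns(tag):
    """Drop the expanded '{uri}' prefix, mirroring the split('}') above."""
    return tag.split("}")[1] if "}" in tag else tag
```

Matching with `endswith` and stripping the braces keeps the parser independent of whatever prefix or URI the CAS server happens to declare.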
505 | 419 | |
506 | 420 | |
507 | 421 | def register_servlets(hs, http_server): |
511 | 425 | if hs.config.cas_enabled: |
512 | 426 | CasRedirectServlet(hs).register(http_server) |
513 | 427 | CasTicketServlet(hs).register(http_server) |
514 | CasRestServlet(hs).register(http_server) | |
515 | # TODO PasswordResetRestServlet(hs).register(http_server) |
22 | 22 | |
23 | 23 | class ProfileDisplaynameRestServlet(ClientV1RestServlet): |
24 | 24 | PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") |
25 | ||
26 | def __init__(self, hs): | |
27 | super(ProfileDisplaynameRestServlet, self).__init__(hs) | |
28 | self.handlers = hs.get_handlers() | |
25 | 29 | |
26 | 30 | @defer.inlineCallbacks |
27 | 31 | def on_GET(self, request, user_id): |
61 | 65 | class ProfileAvatarURLRestServlet(ClientV1RestServlet): |
62 | 66 | PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") |
63 | 67 | |
68 | def __init__(self, hs): | |
69 | super(ProfileAvatarURLRestServlet, self).__init__(hs) | |
70 | self.handlers = hs.get_handlers() | |
71 | ||
64 | 72 | @defer.inlineCallbacks |
65 | 73 | def on_GET(self, request, user_id): |
66 | 74 | user = UserID.from_string(user_id) |
98 | 106 | class ProfileRestServlet(ClientV1RestServlet): |
99 | 107 | PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") |
100 | 108 | |
109 | def __init__(self, hs): | |
110 | super(ProfileRestServlet, self).__init__(hs) | |
111 | self.handlers = hs.get_handlers() | |
112 | ||
101 | 113 | @defer.inlineCallbacks |
102 | 114 | def on_GET(self, request, user_id): |
103 | 115 | user = UserID.from_string(user_id) |
64 | 64 | self.sessions = {} |
65 | 65 | self.enable_registration = hs.config.enable_registration |
66 | 66 | self.auth_handler = hs.get_auth_handler() |
67 | self.handlers = hs.get_handlers() | |
67 | 68 | |
68 | 69 | def on_GET(self, request): |
69 | 70 | if self.hs.config.enable_registration_captcha: |
382 | 383 | super(CreateUserRestServlet, self).__init__(hs) |
383 | 384 | self.store = hs.get_datastore() |
384 | 385 | self.direct_user_creation_max_duration = hs.config.user_creation_max_duration |
386 | self.handlers = hs.get_handlers() | |
385 | 387 | |
386 | 388 | @defer.inlineCallbacks |
387 | 389 | def on_POST(self, request): |
34 | 34 | class RoomCreateRestServlet(ClientV1RestServlet): |
35 | 35 | # No PATTERN; we have custom dispatch rules here |
36 | 36 | |
37 | def __init__(self, hs): | |
38 | super(RoomCreateRestServlet, self).__init__(hs) | |
39 | self.handlers = hs.get_handlers() | |
40 | ||
37 | 41 | def register(self, http_server): |
38 | 42 | PATTERNS = "/createRoom" |
39 | 43 | register_txn_path(self, PATTERNS, http_server) |
81 | 85 | |
82 | 86 | # TODO: Needs unit testing for generic events |
83 | 87 | class RoomStateEventRestServlet(ClientV1RestServlet): |
88 | def __init__(self, hs): | |
89 | super(RoomStateEventRestServlet, self).__init__(hs) | |
90 | self.handlers = hs.get_handlers() | |
91 | ||
84 | 92 | def register(self, http_server): |
85 | 93 | # /room/$roomid/state/$eventtype |
86 | 94 | no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" |
165 | 173 | # TODO: Needs unit testing for generic events + feedback |
166 | 174 | class RoomSendEventRestServlet(ClientV1RestServlet): |
167 | 175 | |
176 | def __init__(self, hs): | |
177 | super(RoomSendEventRestServlet, self).__init__(hs) | |
178 | self.handlers = hs.get_handlers() | |
179 | ||
168 | 180 | def register(self, http_server): |
169 | 181 | # /rooms/$roomid/send/$event_type[/$txn_id] |
170 | 182 | PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") |
209 | 221 | |
210 | 222 | # TODO: Needs unit testing for room ID + alias joins |
211 | 223 | class JoinRoomAliasServlet(ClientV1RestServlet): |
224 | def __init__(self, hs): | |
225 | super(JoinRoomAliasServlet, self).__init__(hs) | |
226 | self.handlers = hs.get_handlers() | |
212 | 227 | |
213 | 228 | def register(self, http_server): |
214 | 229 | # /join/$room_identifier[/$txn_id] |
252 | 267 | action="join", |
253 | 268 | txn_id=txn_id, |
254 | 269 | remote_room_hosts=remote_room_hosts, |
270 | content=content, | |
255 | 271 | third_party_signed=content.get("third_party_signed", None), |
256 | 272 | ) |
257 | 273 | |
294 | 310 | # TODO: Needs unit testing |
295 | 311 | class RoomMemberListRestServlet(ClientV1RestServlet): |
296 | 312 | PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") |
313 | ||
314 | def __init__(self, hs): | |
315 | super(RoomMemberListRestServlet, self).__init__(hs) | |
316 | self.handlers = hs.get_handlers() | |
297 | 317 | |
298 | 318 | @defer.inlineCallbacks |
299 | 319 | def on_GET(self, request, room_id): |
320 | 340 | # TODO: Needs better unit testing |
321 | 341 | class RoomMessageListRestServlet(ClientV1RestServlet): |
322 | 342 | PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") |
343 | ||
344 | def __init__(self, hs): | |
345 | super(RoomMessageListRestServlet, self).__init__(hs) | |
346 | self.handlers = hs.get_handlers() | |
323 | 347 | |
324 | 348 | @defer.inlineCallbacks |
325 | 349 | def on_GET(self, request, room_id): |
350 | 374 | class RoomStateRestServlet(ClientV1RestServlet): |
351 | 375 | PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") |
352 | 376 | |
377 | def __init__(self, hs): | |
378 | super(RoomStateRestServlet, self).__init__(hs) | |
379 | self.handlers = hs.get_handlers() | |
380 | ||
353 | 381 | @defer.inlineCallbacks |
354 | 382 | def on_GET(self, request, room_id): |
355 | 383 | requester = yield self.auth.get_user_by_req(request, allow_guest=True) |
367 | 395 | class RoomInitialSyncRestServlet(ClientV1RestServlet): |
368 | 396 | PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") |
369 | 397 | |
398 | def __init__(self, hs): | |
399 | super(RoomInitialSyncRestServlet, self).__init__(hs) | |
400 | self.handlers = hs.get_handlers() | |
401 | ||
370 | 402 | @defer.inlineCallbacks |
371 | 403 | def on_GET(self, request, room_id): |
372 | 404 | requester = yield self.auth.get_user_by_req(request, allow_guest=True) |
387 | 419 | def __init__(self, hs): |
388 | 420 | super(RoomEventContext, self).__init__(hs) |
389 | 421 | self.clock = hs.get_clock() |
422 | self.handlers = hs.get_handlers() | |
390 | 423 | |
391 | 424 | @defer.inlineCallbacks |
392 | 425 | def on_GET(self, request, room_id, event_id): |
423 | 456 | |
424 | 457 | |
425 | 458 | class RoomForgetRestServlet(ClientV1RestServlet): |
459 | def __init__(self, hs): | |
460 | super(RoomForgetRestServlet, self).__init__(hs) | |
461 | self.handlers = hs.get_handlers() | |
462 | ||
426 | 463 | def register(self, http_server): |
427 | 464 | PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") |
428 | 465 | register_txn_path(self, PATTERNS, http_server) |
460 | 497 | |
461 | 498 | # TODO: Needs unit testing |
462 | 499 | class RoomMembershipRestServlet(ClientV1RestServlet): |
500 | ||
501 | def __init__(self, hs): | |
502 | super(RoomMembershipRestServlet, self).__init__(hs) | |
503 | self.handlers = hs.get_handlers() | |
463 | 504 | |
464 | 505 | def register(self, http_server): |
465 | 506 | # /rooms/$roomid/[invite|join|leave] |
541 | 582 | |
542 | 583 | |
543 | 584 | class RoomRedactEventRestServlet(ClientV1RestServlet): |
585 | def __init__(self, hs): | |
586 | super(RoomRedactEventRestServlet, self).__init__(hs) | |
587 | self.handlers = hs.get_handlers() | |
588 | ||
544 | 589 | def register(self, http_server): |
545 | 590 | PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") |
546 | 591 | register_txn_path(self, PATTERNS, http_server) |
622 | 667 | PATTERNS = client_path_patterns( |
623 | 668 | "/search$" |
624 | 669 | ) |
670 | ||
671 | def __init__(self, hs): | |
672 | super(SearchRestServlet, self).__init__(hs) | |
673 | self.handlers = hs.get_handlers() | |
625 | 674 | |
626 | 675 | @defer.inlineCallbacks |
627 | 676 | def on_POST(self, request): |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | from synapse.http.servlet import ( | |
18 | RestServlet, parse_string, parse_integer | |
19 | ) | |
20 | from synapse.events.utils import ( | |
21 | serialize_event, format_event_for_client_v2_without_room_id, | |
22 | ) | |
23 | ||
24 | from ._base import client_v2_patterns | |
25 | ||
26 | import logging | |
27 | ||
28 | logger = logging.getLogger(__name__) | |
29 | ||
30 | ||
31 | class NotificationsServlet(RestServlet): | |
32 | PATTERNS = client_v2_patterns("/notifications$", releases=()) | |
33 | ||
34 | def __init__(self, hs): | |
35 | super(NotificationsServlet, self).__init__() | |
36 | self.store = hs.get_datastore() | |
37 | self.auth = hs.get_auth() | |
38 | self.clock = hs.get_clock() | |
39 | ||
40 | @defer.inlineCallbacks | |
41 | def on_GET(self, request): | |
42 | requester = yield self.auth.get_user_by_req(request) | |
43 | user_id = requester.user.to_string() | |
44 | ||
45 | from_token = parse_string(request, "from", required=False) | |
46 | limit = parse_integer(request, "limit", default=50) | |
47 | ||
48 | limit = min(limit, 500) | |
49 | ||
50 | push_actions = yield self.store.get_push_actions_for_user( | |
51 | user_id, from_token, limit | |
52 | ) | |
53 | ||
54 | receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( | |
55 | user_id, 'm.read' | |
56 | ) | |
57 | ||
58 | notif_event_ids = [pa["event_id"] for pa in push_actions] | |
59 | notif_events = yield self.store.get_events(notif_event_ids) | |
60 | ||
61 | returned_push_actions = [] | |
62 | ||
63 | next_token = None | |
64 | ||
65 | for pa in push_actions: | |
66 | returned_pa = { | |
67 | "room_id": pa["room_id"], | |
68 | "profile_tag": pa["profile_tag"], | |
69 | "actions": pa["actions"], | |
70 | "ts": pa["received_ts"], | |
71 | "event": serialize_event( | |
72 | notif_events[pa["event_id"]], | |
73 | self.clock.time_msec(), | |
74 | event_format=format_event_for_client_v2_without_room_id, | |
75 | ), | |
76 | } | |
77 | ||
78 | if pa["room_id"] not in receipts_by_room: | |
79 | returned_pa["read"] = False | |
80 | else: | |
81 | receipt = receipts_by_room[pa["room_id"]] | |
82 | ||
83 | returned_pa["read"] = ( | |
84 | receipt["topological_ordering"], receipt["stream_ordering"] | |
85 | ) >= ( | |
86 | pa["topological_ordering"], pa["stream_ordering"] | |
87 | ) | |
88 | returned_push_actions.append(returned_pa) | |
89 | next_token = pa["stream_ordering"] | |
90 | ||
91 | defer.returnValue((200, { | |
92 | "notifications": returned_push_actions, | |
93 | "next_token": next_token, | |
94 | })) | |
95 | ||
96 | ||
97 | def register_servlets(hs, http_server): | |
98 | NotificationsServlet(hs).register(http_server) |
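The `read` flag in the notifications servlet above is computed by comparing the user's latest read receipt against each push action as `(topological_ordering, stream_ordering)` pairs, relying on Python's lexicographic tuple comparison. A minimal Python 3 sketch of that comparison (field names mirror the hunk, but `mark_read_state` is an illustration, not Synapse's API):

```python
def mark_read_state(push_actions, receipts_by_room):
    """Annotate each push action with whether the user's read receipt
    covers it; (topological, stream) tuples compare lexicographically,
    so topological ordering wins and stream ordering breaks ties."""
    annotated = []
    for pa in push_actions:
        receipt = receipts_by_room.get(pa["room_id"])
        read = receipt is not None and (
            (receipt["topological_ordering"], receipt["stream_ordering"])
            >= (pa["topological_ordering"], pa["stream_ordering"])
        )
        annotated.append(dict(pa, read=read))
    return annotated

actions = [
    {"room_id": "!a", "topological_ordering": 3, "stream_ordering": 10},
    {"room_id": "!a", "topological_ordering": 5, "stream_ordering": 2},
    {"room_id": "!b", "topological_ordering": 1, "stream_ordering": 1},
]
receipts = {"!a": {"topological_ordering": 4, "stream_ordering": 0}}
annotated = mark_read_state(actions, receipts)
```

A room with no receipt at all is simply unread, which is why the servlet special-cases `pa["room_id"] not in receipts_by_room`.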
402 | 402 | # register the user's device |
403 | 403 | device_id = params.get("device_id") |
404 | 404 | initial_display_name = params.get("initial_device_display_name") |
405 | device_id = self.device_handler.check_device_registered( | |
405 | return self.device_handler.check_device_registered( | |
406 | 406 | user_id, device_id, initial_display_name |
407 | 407 | ) |
408 | return device_id | |
409 | 408 | |
410 | 409 | @defer.inlineCallbacks |
411 | 410 | def _do_guest_registration(self): |
145 | 145 | affect_presence = set_presence != PresenceState.OFFLINE |
146 | 146 | |
147 | 147 | if affect_presence: |
148 | yield self.presence_handler.set_state(user, {"presence": set_presence}) | |
148 | yield self.presence_handler.set_state(user, {"presence": set_presence}, True) | |
149 | 149 | |
150 | 150 | context = yield self.presence_handler.user_syncing( |
151 | 151 | user.to_string(), affect_presence=affect_presence, |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | ||
16 | import logging | |
17 | ||
18 | from twisted.internet import defer | |
19 | ||
20 | from synapse.http.servlet import RestServlet | |
21 | from synapse.types import ThirdPartyEntityKind | |
22 | from ._base import client_v2_patterns | |
23 | ||
24 | logger = logging.getLogger(__name__) | |
25 | ||
26 | ||
27 | class ThirdPartyUserServlet(RestServlet): | |
28 | PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", | |
29 | releases=()) | |
30 | ||
31 | def __init__(self, hs): | |
32 | super(ThirdPartyUserServlet, self).__init__() | |
33 | ||
34 | self.auth = hs.get_auth() | |
35 | self.appservice_handler = hs.get_application_service_handler() | |
36 | ||
37 | @defer.inlineCallbacks | |
38 | def on_GET(self, request, protocol): | |
39 | yield self.auth.get_user_by_req(request) | |
40 | ||
41 | fields = request.args | |
42 | del fields["access_token"] | |
43 | ||
44 | results = yield self.appservice_handler.query_3pe( | |
45 | ThirdPartyEntityKind.USER, protocol, fields | |
46 | ) | |
47 | ||
48 | defer.returnValue((200, results)) | |
49 | ||
50 | ||
51 | class ThirdPartyLocationServlet(RestServlet): | |
52 | PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", | |
53 | releases=()) | |
54 | ||
55 | def __init__(self, hs): | |
56 | super(ThirdPartyLocationServlet, self).__init__() | |
57 | ||
58 | self.auth = hs.get_auth() | |
59 | self.appservice_handler = hs.get_application_service_handler() | |
60 | ||
61 | @defer.inlineCallbacks | |
62 | def on_GET(self, request, protocol): | |
63 | yield self.auth.get_user_by_req(request) | |
64 | ||
65 | fields = request.args | |
66 | del fields["access_token"] | |
67 | ||
68 | results = yield self.appservice_handler.query_3pe( | |
69 | ThirdPartyEntityKind.LOCATION, protocol, fields | |
70 | ) | |
71 | ||
72 | defer.returnValue((200, results)) | |
73 | ||
74 | ||
75 | def register_servlets(hs, http_server): | |
76 | ThirdPartyUserServlet(hs).register(http_server) | |
77 | ThirdPartyLocationServlet(hs).register(http_server) |
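The `/3pu(/(?P<protocol>[^/]+))?$` and `/3pl(...)` patterns above make the trailing protocol segment optional, so a single servlet serves both the "all protocols" and "one protocol" forms of the endpoint. A small sketch of how the optional named group behaves (`match_protocol` is a hypothetical helper; the real patterns are wrapped by `client_v2_patterns` with an API-version prefix):

```python
import re

# Simplified form of the servlet patterns above.
PATTERN = re.compile(r"/3pu(/(?P<protocol>[^/]+))?$")

def match_protocol(path):
    """Return the protocol segment if present, else None."""
    m = PATTERN.search(path)
    return m.group("protocol") if m else None
```

When the optional group doesn't participate in the match, `group("protocol")` is `None`, which is how `on_GET` receives `protocol=None` for the bare form.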
14 | 14 | from synapse.http.server import request_handler, respond_with_json_bytes |
15 | 15 | from synapse.http.servlet import parse_integer, parse_json_object_from_request |
16 | 16 | from synapse.api.errors import SynapseError, Codes |
17 | from synapse.crypto.keyring import KeyLookupError | |
17 | 18 | |
18 | 19 | from twisted.web.resource import Resource |
19 | 20 | from twisted.web.server import NOT_DONE_YET |
209 | 210 | yield self.keyring.get_server_verify_key_v2_direct( |
210 | 211 | server_name, key_ids |
211 | 212 | ) |
213 | except KeyLookupError as e: | |
214 | logger.info("Failed to fetch key: %s", e) | |
212 | 215 | except: |
213 | 216 | logger.exception("Failed to get key for %r", server_name) |
214 | pass | |
215 | 217 | yield self.query_keys( |
216 | 218 | request, query, query_remote_on_cache_miss=False |
217 | 219 | ) |
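The keyring hunk above narrows the error handling: the new `KeyLookupError` is caught before the bare `except:`, so an expected lookup failure is logged at INFO without a stack trace, while anything unexpected still gets `logger.exception`. A hedged Python 3 sketch of that specific-before-general ordering, with hypothetical names:

```python
import logging

logger = logging.getLogger("keyring.sketch")

class KeyLookupError(Exception):
    """Expected failure: the remote key simply couldn't be fetched."""

def fetch_key(fetcher):
    """Run fetcher(); returns True on success.

    KeyLookupError is caught first and logged quietly at INFO; any
    other exception falls through to the broad handler, which keeps
    the full stack trace.
    """
    try:
        fetcher()
        return True
    except KeyLookupError as e:
        logger.info("Failed to fetch key: %s", e)
    except Exception:
        logger.exception("Failed to get key")
    return False

def ok_fetch():
    pass

def failing_fetch():
    raise KeyLookupError("no response from remote")

ok = fetch_key(ok_fetch)
expected = fetch_key(failing_fetch)
```

Handler order matters here: Python tries `except` clauses top to bottom, so the specific clause must precede the catch-all or it would never fire.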
44 | 44 | @request_handler() |
45 | 45 | @defer.inlineCallbacks |
46 | 46 | def _async_render_GET(self, request): |
47 | request.setHeader("Content-Security-Policy", "sandbox") | |
47 | 48 | server_name, media_id, name = parse_media_id(request) |
48 | 49 | if server_name == self.server_name: |
49 | 50 | yield self._respond_local_file(request, media_id, name) |
28 | 28 | from synapse.util.async import ObservableDeferred |
29 | 29 | from synapse.util.stringutils import is_ascii |
30 | 30 | |
31 | from copy import deepcopy | |
32 | ||
33 | 31 | import os |
34 | 32 | import re |
35 | 33 | import fnmatch |
36 | 34 | import cgi |
37 | 35 | import ujson as json |
38 | 36 | import urlparse |
37 | import itertools | |
39 | 38 | |
40 | 39 | import logging |
41 | 40 | logger = logging.getLogger(__name__) |
162 | 161 | |
163 | 162 | logger.debug("got media_info of '%s'" % media_info) |
164 | 163 | |
165 | if self._is_media(media_info['media_type']): | |
164 | if _is_media(media_info['media_type']): | |
166 | 165 | dims = yield self.media_repo._generate_local_thumbnails( |
167 | 166 | media_info['filesystem_id'], media_info |
168 | 167 | ) |
183 | 182 | logger.warn("Couldn't get dims for %s" % url) |
184 | 183 | |
185 | 184 | # define our OG response for this media |
186 | elif self._is_html(media_info['media_type']): | |
185 | elif _is_html(media_info['media_type']): | |
187 | 186 | # TODO: somehow stop a big HTML tree from exploding synapse's RAM |
188 | ||
189 | from lxml import etree | |
190 | 187 | |
191 | 188 | file = open(media_info['filename']) |
192 | 189 | body = file.read() |
198 | 195 | match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I) |
199 | 196 | encoding = match.group(1) if match else "utf-8" |
200 | 197 | |
201 | try: | |
202 | parser = etree.HTMLParser(recover=True, encoding=encoding) | |
203 | tree = etree.fromstring(body, parser) | |
204 | og = yield self._calc_og(tree, media_info, requester) | |
205 | except UnicodeDecodeError: | |
206 | # blindly try decoding the body as utf-8, which seems to fix | |
207 | # the charset mismatches on https://google.com | |
208 | parser = etree.HTMLParser(recover=True, encoding=encoding) | |
209 | tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) | |
210 | og = yield self._calc_og(tree, media_info, requester) | |
211 | ||
198 | og = decode_and_calc_og(body, media_info['uri'], encoding) | |
199 | ||
200 | # pre-cache the image for posterity | |
201 | # FIXME: it might be cleaner to use the same flow as the main /preview_url | |
202 | # request itself and benefit from the same caching etc. But for now we | |
203 | # just rely on the caching on the master request to speed things up. | |
204 | if 'og:image' in og and og['og:image']: | |
205 | image_info = yield self._download_url( | |
206 | _rebase_url(og['og:image'], media_info['uri']), requester.user | |
207 | ) | |
208 | ||
209 | if _is_media(image_info['media_type']): | |
210 | # TODO: make sure we don't choke on white-on-transparent images | |
211 | dims = yield self.media_repo._generate_local_thumbnails( | |
212 | image_info['filesystem_id'], image_info | |
213 | ) | |
214 | if dims: | |
215 | og["og:image:width"] = dims['width'] | |
216 | og["og:image:height"] = dims['height'] | |
217 | else: | |
218 | logger.warn("Couldn't get dims for %s" % og["og:image"]) | |
219 | ||
220 | og["og:image"] = "mxc://%s/%s" % ( | |
221 | self.server_name, image_info['filesystem_id'] | |
222 | ) | |
223 | og["og:image:type"] = image_info['media_type'] | |
224 | og["matrix:image:size"] = image_info['media_length'] | |
225 | else: | |
226 | del og["og:image"] | |
212 | 227 | else: |
213 | 228 | logger.warn("Failed to find any OG data in %s", url) |
214 | 229 | og = {} |
230 | 245 | ) |
231 | 246 | |
232 | 247 | respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True) |
233 | ||
234 | @defer.inlineCallbacks | |
235 | def _calc_og(self, tree, media_info, requester): | |
236 | # suck our tree into lxml and define our OG response. | |
237 | ||
238 | # if we see any image URLs in the OG response, then spider them | |
239 | # (although the client could choose to do this by asking for previews of those | |
240 | # URLs to avoid DoSing the server) | |
241 | ||
242 | # "og:type" : "video", | |
243 | # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", | |
244 | # "og:site_name" : "YouTube", | |
245 | # "og:video:type" : "application/x-shockwave-flash", | |
246 | # "og:description" : "Fun stuff happening here", | |
247 | # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", | |
248 | # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", | |
249 | # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", | |
250 | # "og:video:width" : "1280" | |
251 | # "og:video:height" : "720", | |
252 | # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", | |
253 | ||
254 | og = {} | |
255 | for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): | |
256 | if 'content' in tag.attrib: | |
257 | og[tag.attrib['property']] = tag.attrib['content'] | |
258 | ||
259 | # TODO: grab article: meta tags too, e.g.: | |
260 | ||
261 | # "article:publisher" : "https://www.facebook.com/thethudonline" /> | |
262 | # "article:author" content="https://www.facebook.com/thethudonline" /> | |
263 | # "article:tag" content="baby" /> | |
264 | # "article:section" content="Breaking News" /> | |
265 | # "article:published_time" content="2016-03-31T19:58:24+00:00" /> | |
266 | # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> | |
267 | ||
268 | if 'og:title' not in og: | |
269 | # do some basic spidering of the HTML | |
270 | title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") | |
271 | og['og:title'] = title[0].text.strip() if title else None | |
272 | ||
273 | if 'og:image' not in og: | |
274 | # TODO: extract a favicon failing all else | |
275 | meta_image = tree.xpath( | |
276 | "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" | |
277 | ) | |
278 | if meta_image: | |
279 | og['og:image'] = self._rebase_url(meta_image[0], media_info['uri']) | |
280 | else: | |
281 | # TODO: consider inlined CSS styles as well as width & height attribs | |
282 | images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") | |
283 | images = sorted(images, key=lambda i: ( | |
284 | -1 * float(i.attrib['width']) * float(i.attrib['height']) | |
285 | )) | |
286 | if not images: | |
287 | images = tree.xpath("//img[@src]") | |
288 | if images: | |
289 | og['og:image'] = images[0].attrib['src'] | |
290 | ||
291 | # pre-cache the image for posterity | |
292 | # FIXME: it might be cleaner to use the same flow as the main /preview_url | |
293 | # request itself and benefit from the same caching etc. But for now we | |
294 | # just rely on the caching on the master request to speed things up. | |
295 | if 'og:image' in og and og['og:image']: | |
296 | image_info = yield self._download_url( | |
297 | self._rebase_url(og['og:image'], media_info['uri']), requester.user | |
298 | ) | |
299 | ||
300 | if self._is_media(image_info['media_type']): | |
301 | # TODO: make sure we don't choke on white-on-transparent images | |
302 | dims = yield self.media_repo._generate_local_thumbnails( | |
303 | image_info['filesystem_id'], image_info | |
304 | ) | |
305 | if dims: | |
306 | og["og:image:width"] = dims['width'] | |
307 | og["og:image:height"] = dims['height'] | |
308 | else: | |
309 | logger.warn("Couldn't get dims for %s" % og["og:image"]) | |
310 | ||
311 | og["og:image"] = "mxc://%s/%s" % ( | |
312 | self.server_name, image_info['filesystem_id'] | |
313 | ) | |
314 | og["og:image:type"] = image_info['media_type'] | |
315 | og["matrix:image:size"] = image_info['media_length'] | |
316 | else: | |
317 | del og["og:image"] | |
318 | ||
319 | if 'og:description' not in og: | |
320 | meta_description = tree.xpath( | |
321 | "//*/meta" | |
322 | "[translate(@name, 'DESCRIPTION', 'description')='description']" | |
323 | "/@content") | |
324 | if meta_description: | |
325 | og['og:description'] = meta_description[0] | |
326 | else: | |
327 | # grab any text nodes which are inside the <body/> tag... | |
328 | # unless they are within an HTML5 semantic markup tag... | |
329 | # <header/>, <nav/>, <aside/>, <footer/> | |
330 | # ...or if they are within a <script/> or <style/> tag. | |
331 | # This is a very very very coarse approximation to a plain text | |
332 | # render of the page. | |
333 | ||
334 | # We don't just use XPATH here as that is slow on some machines. | |
335 | ||
336 | # We clone `tree` as we modify it. | |
337 | cloned_tree = deepcopy(tree.find("body")) | |
338 | ||
339 | TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",) | |
340 | for el in cloned_tree.iter(TAGS_TO_REMOVE): | |
341 | el.getparent().remove(el) | |
342 | ||
343 | # Split all the text nodes into paragraphs (by splitting on new | |
344 | # lines) | |
345 | text_nodes = ( | |
346 | re.sub(r'\s+', '\n', el.text).strip() | |
347 | for el in cloned_tree.iter() | |
348 | if el.text and isinstance(el.tag, basestring) # Removes comments | |
349 | ) | |
350 | og['og:description'] = summarize_paragraphs(text_nodes) | |
351 | ||
352 | # TODO: delete the url downloads to stop diskfilling, | |
353 | # as we only ever cared about its OG | |
354 | defer.returnValue(og) | |
355 | ||
356 | def _rebase_url(self, url, base): | |
357 | base = list(urlparse.urlparse(base)) | |
358 | url = list(urlparse.urlparse(url)) | |
359 | if not url[0]: # fix up schema | |
360 | url[0] = base[0] or "http" | |
361 | if not url[1]: # fix up hostname | |
362 | url[1] = base[1] | |
363 | if not url[2].startswith('/'): | |
364 | url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] | |
365 | return urlparse.urlunparse(url) | |
366 | 248 | |
367 | 249 | @defer.inlineCallbacks |
368 | 250 | def _download_url(self, url, user): |
444 | 326 | "etag": headers["ETag"][0] if "ETag" in headers else None, |
445 | 327 | }) |
446 | 328 | |
447 | def _is_media(self, content_type): | |
448 | if content_type.lower().startswith("image/"): | |
449 | return True | |
450 | ||
451 | def _is_html(self, content_type): | |
452 | content_type = content_type.lower() | |
453 | if ( | |
454 | content_type.startswith("text/html") or | |
455 | content_type.startswith("application/xhtml") | |
456 | ): | |
457 | return True | |
329 | ||
330 | def decode_and_calc_og(body, media_uri, request_encoding=None): | |
331 | from lxml import etree | |
332 | ||
333 | try: | |
334 | parser = etree.HTMLParser(recover=True, encoding=request_encoding) | |
335 | tree = etree.fromstring(body, parser) | |
336 | og = _calc_og(tree, media_uri) | |
337 | except UnicodeDecodeError: | |
338 | # blindly try decoding the body as utf-8, which seems to fix | |
339 | # the charset mismatches on https://google.com | |
340 | parser = etree.HTMLParser(recover=True, encoding=request_encoding) | |
341 | tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) | |
342 | og = _calc_og(tree, media_uri) | |
343 | ||
344 | return og | |
345 | ||
346 | ||
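`decode_and_calc_og` above parses with the charset declared in the response and, on `UnicodeDecodeError`, retries after decoding the body as UTF-8 with undecodable bytes ignored. That recovery step can be sketched on its own (stdlib only; `decode_with_fallback` is a hypothetical name):

```python
def decode_with_fallback(body, declared_encoding=None):
    """Decode bytes using the declared charset; on failure, blindly
    retry as UTF-8 with undecodable bytes dropped (the same recovery
    the preview code applies when a page mis-declares its charset)."""
    encoding = declared_encoding or "utf-8"
    try:
        return body.decode(encoding)
    except (UnicodeDecodeError, LookupError):
        return body.decode("utf-8", "ignore")

# A UTF-8 body served with a bogus charset still decodes cleanly.
text = decode_with_fallback("caf\u00e9".encode("utf-8"), "ascii")
```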
347 | def _calc_og(tree, media_uri): | |
348 | # suck our tree into lxml and define our OG response. | |
349 | ||
350 | # if we see any image URLs in the OG response, then spider them | |
351 | # (although the client could choose to do this by asking for previews of those | |
352 | # URLs to avoid DoSing the server) | |
353 | ||
354 | # "og:type" : "video", | |
355 | # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", | |
356 | # "og:site_name" : "YouTube", | |
357 | # "og:video:type" : "application/x-shockwave-flash", | |
358 | # "og:description" : "Fun stuff happening here", | |
359 | # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", | |
360 | # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", | |
361 | # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", | |
362 | # "og:video:width" : "1280" | |
363 | # "og:video:height" : "720", | |
364 | # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", | |
365 | ||
366 | og = {} | |
367 | for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): | |
368 | if 'content' in tag.attrib: | |
369 | og[tag.attrib['property']] = tag.attrib['content'] | |
370 | ||
371 | # TODO: grab article: meta tags too, e.g.: | |
372 | ||
373 | # "article:publisher" : "https://www.facebook.com/thethudonline" /> | |
374 | # "article:author" content="https://www.facebook.com/thethudonline" /> | |
375 | # "article:tag" content="baby" /> | |
376 | # "article:section" content="Breaking News" /> | |
377 | # "article:published_time" content="2016-03-31T19:58:24+00:00" /> | |
378 | # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> | |
379 | ||
380 | if 'og:title' not in og: | |
381 | # do some basic spidering of the HTML | |
382 | title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") | |
383 | og['og:title'] = title[0].text.strip() if title else None | |
384 | ||
385 | if 'og:image' not in og: | |
386 | # TODO: extract a favicon failing all else | |
387 | meta_image = tree.xpath( | |
388 | "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" | |
389 | ) | |
390 | if meta_image: | |
391 | og['og:image'] = _rebase_url(meta_image[0], media_uri) | |
392 | else: | |
393 | # TODO: consider inlined CSS styles as well as width & height attribs | |
394 | images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") | |
395 | images = sorted(images, key=lambda i: ( | |
396 | -1 * float(i.attrib['width']) * float(i.attrib['height']) | |
397 | )) | |
398 | if not images: | |
399 | images = tree.xpath("//img[@src]") | |
400 | if images: | |
401 | og['og:image'] = images[0].attrib['src'] | |
402 | ||
403 | if 'og:description' not in og: | |
404 | meta_description = tree.xpath( | |
405 | "//*/meta" | |
406 | "[translate(@name, 'DESCRIPTION', 'description')='description']" | |
407 | "/@content") | |
408 | if meta_description: | |
409 | og['og:description'] = meta_description[0] | |
410 | else: | |
411 | # grab any text nodes which are inside the <body/> tag... | |
412 | # unless they are within an HTML5 semantic markup tag... | |
413 | # <header/>, <nav/>, <aside/>, <footer/> | |
414 | # ...or if they are within a <script/> or <style/> tag. | |
415 | # This is a very very very coarse approximation to a plain text | |
416 | # render of the page. | |
417 | ||
418 | # We don't just use XPATH here as that is slow on some machines. | |
419 | ||
420 | from lxml import etree | |
421 | ||
422 | TAGS_TO_REMOVE = ( | |
423 | "header", "nav", "aside", "footer", "script", "style", etree.Comment | |
424 | ) | |
425 | ||
426 | # Split all the text nodes into paragraphs (by splitting on new | |
427 | # lines) | |
428 | text_nodes = ( | |
429 | re.sub(r'\s+', '\n', el).strip() | |
430 | for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) | |
431 | ) | |
432 | og['og:description'] = summarize_paragraphs(text_nodes) | |
433 | ||
434 | # TODO: delete the url downloads to stop diskfilling, | |
435 | # as we only ever cared about its OG | |
436 | return og | |
437 | ||
438 | ||
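`_calc_og` above collects every `<meta property="og:*">` tag into a dict and falls back to spidering the page (first `<title>`, largest `<img>`, meta description) for missing fields. A stdlib-only Python 3 sketch of the core og:-tag collection and the title fallback (the real code uses lxml XPath; `OGParser` and `calc_og` are hypothetical):

```python
from html.parser import HTMLParser

class OGParser(HTMLParser):
    """Collect og: meta properties and the first <title> text."""
    def __init__(self):
        super().__init__()
        self.og = {}
        self.title = None
        self._in_title = False

    def handle_starttag(self, tag, attrs):
        a = dict(attrs)
        if tag == "meta" and (a.get("property") or "").startswith("og:"):
            if "content" in a:
                self.og[a["property"]] = a["content"]
        elif tag == "title" and self.title is None:
            self._in_title = True

    def handle_data(self, data):
        if self._in_title:
            self.title = data.strip()
            self._in_title = False

    def handle_endtag(self, tag):
        if tag == "title":
            self._in_title = False

def calc_og(html):
    p = OGParser()
    p.feed(html)
    og = dict(p.og)
    if "og:title" not in og:
        # Basic spidering fallback, as in the code above.
        og["og:title"] = p.title
    return og

og = calc_og(
    '<html><head><title>Fallback</title>'
    '<meta property="og:description" content="Fun stuff"/>'
    '</head></html>'
)
```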
439 | def _iterate_over_text(tree, *tags_to_ignore): | |
440 |     """Iterate over the tree, returning text nodes in a depth-first fashion, | |
441 |     skipping text nodes inside certain tags. | |
442 | """ | |
443 | # This is basically a stack that we extend using itertools.chain. | |
444 | # This will either consist of an element to iterate over *or* a string | |
445 | # to be returned. | |
446 | elements = iter([tree]) | |
447 | while True: | |
448 | el = elements.next() | |
449 | if isinstance(el, basestring): | |
450 | yield el | |
451 | elif el is not None and el.tag not in tags_to_ignore: | |
452 | # el.text is the text before the first child, so we can immediately | |
453 | # return it if the text exists. | |
454 | if el.text: | |
455 | yield el.text | |
456 | ||
457 |             # We add to the stack all the element's children, interspersed with | |
458 | # each child's tail text (if it exists). The tail text of a node | |
459 | # is text that comes *after* the node, so we always include it even | |
460 | # if we ignore the child node. | |
461 | elements = itertools.chain( | |
462 | itertools.chain.from_iterable( # Basically a flatmap | |
463 | [child, child.tail] if child.tail else [child] | |
464 | for child in el.iterchildren() | |
465 | ), | |
466 | elements | |
467 | ) | |
468 | ||
469 | ||
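`_iterate_over_text` above walks the tree depth-first using an explicit `itertools.chain` stack, skipping the subtrees of ignored tags while still yielding their tail text. A Python 3 sketch of the same algorithm over `xml.etree` elements (the original iterates lxml elements and is Python 2, hence its `elements.next()` and `basestring`):

```python
import itertools
import xml.etree.ElementTree as ET

def iterate_over_text(tree, *tags_to_ignore):
    """Depth-first text iteration: skip ignored subtrees entirely but
    keep their tail text, which sits *after* the closing tag."""
    elements = iter([tree])
    while True:
        try:
            el = next(elements)
        except StopIteration:
            return
        if isinstance(el, str):
            yield el
        elif el is not None and el.tag not in tags_to_ignore:
            if el.text:
                yield el.text
            # Push the children (interleaved with their tails) onto the
            # front of the stack, ahead of whatever was left.
            elements = itertools.chain(
                itertools.chain.from_iterable(
                    [child, child.tail] if child.tail else [child]
                    for child in el
                ),
                elements,
            )

body = ET.fromstring(
    "<body><p>Hello <b>bold</b> world</p><script>skip()</script> tail</body>"
)
texts = list(iterate_over_text(body, "script"))
```

Note how the `<script>` body is dropped but its tail (" tail") survives, exactly the behaviour the comment in the hunk describes.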
470 | def _rebase_url(url, base): | |
471 | base = list(urlparse.urlparse(base)) | |
472 | url = list(urlparse.urlparse(url)) | |
473 | if not url[0]: # fix up schema | |
474 | url[0] = base[0] or "http" | |
475 | if not url[1]: # fix up hostname | |
476 | url[1] = base[1] | |
477 | if not url[2].startswith('/'): | |
478 | url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] | |
479 | return urlparse.urlunparse(url) | |
480 | ||
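`_rebase_url` above fills in the missing scheme, host, or leading path of a possibly-relative URL from the page it was scraped from. The same logic in Python 3 (`urllib.parse` replaces the Python 2 `urlparse` module; `rebase_url` is a hypothetical stand-alone copy):

```python
import re
from urllib.parse import urlparse, urlunparse

def rebase_url(url, base):
    """Resolve a possibly-relative URL against the page it came from,
    applying the same three fix-ups as the function above."""
    b = list(urlparse(base))
    u = list(urlparse(url))
    if not u[0]:                  # fix up missing scheme
        u[0] = b[0] or "http"
    if not u[1]:                  # fix up missing hostname
        u[1] = b[1]
    if not u[2].startswith("/"):  # relative path: graft onto base dir
        u[2] = re.sub(r"/[^/]+$", "/", b[2]) + u[2]
    return urlunparse(u)

rebased = rebase_url("logo.png", "https://example.com/articles/post.html")
```

Protocol-relative URLs (`//cdn.example.com/...`) already carry a host, so only the scheme is filled in for them.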
481 | ||
482 | def _is_media(content_type): | |
483 | if content_type.lower().startswith("image/"): | |
484 | return True | |
485 | ||
486 | ||
487 | def _is_html(content_type): | |
488 | content_type = content_type.lower() | |
489 | if ( | |
490 | content_type.startswith("text/html") or | |
491 | content_type.startswith("application/xhtml") | |
492 | ): | |
493 | return True | |
458 | 494 | |
459 | 495 | |
460 | 496 | def summarize_paragraphs(text_nodes, min_size=200, max_size=500): |
40 | 40 | from synapse.handlers.room import RoomListHandler |
41 | 41 | from synapse.handlers.sync import SyncHandler |
42 | 42 | from synapse.handlers.typing import TypingHandler |
43 | from synapse.handlers.events import EventHandler, EventStreamHandler | |
43 | 44 | from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory |
44 | 45 | from synapse.http.matrixfederationclient import MatrixFederationHttpClient |
45 | 46 | from synapse.notifier import Notifier |
93 | 94 | 'auth_handler', |
94 | 95 | 'device_handler', |
95 | 96 | 'e2e_keys_handler', |
97 | 'event_handler', | |
98 | 'event_stream_handler', | |
96 | 99 | 'application_service_api', |
97 | 100 | 'application_service_scheduler', |
98 | 101 | 'application_service_handler', |
213 | 216 | def build_application_service_handler(self): |
214 | 217 | return ApplicationServicesHandler(self) |
215 | 218 | |
219 | def build_event_handler(self): | |
220 | return EventHandler(self) | |
221 | ||
222 | def build_event_stream_handler(self): | |
223 | return EventStreamHandler(self) | |
224 | ||
216 | 225 | def build_event_sources(self): |
217 | 226 | return EventSources(self) |
218 | 227 |
0 | import synapse.api.auth | |
0 | 1 | import synapse.handlers |
1 | 2 | import synapse.handlers.auth |
2 | 3 | import synapse.handlers.device |
5 | 6 | import synapse.state |
6 | 7 | |
7 | 8 | class HomeServer(object): |
9 | def get_auth(self) -> synapse.api.auth.Auth: | |
10 | pass | |
11 | ||
8 | 12 | def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: |
9 | 13 | pass |
10 | 14 |
49 | 49 | from .client_ips import ClientIpStore |
50 | 50 | |
51 | 51 | from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator |
52 | from .engines import PostgresEngine | |
52 | 53 | |
53 | 54 | from synapse.api.constants import PresenceState |
54 | 55 | from synapse.util.caches.stream_change_cache import StreamChangeCache |
122 | 123 | extra_tables=[("deleted_pushers", "stream_id")], |
123 | 124 | ) |
124 | 125 | |
126 | if isinstance(self.database_engine, PostgresEngine): | |
127 | self._cache_id_gen = StreamIdGenerator( | |
128 | db_conn, "cache_invalidation_stream", "stream_id", | |
129 | ) | |
130 | else: | |
131 | self._cache_id_gen = None | |
132 | ||
125 | 133 | events_max = self._stream_id_gen.get_current_token() |
126 | 134 | event_cache_prefill, min_event_val = self._get_cache_dict( |
127 | 135 | db_conn, "events", |
18 | 18 | from synapse.util.caches.dictionary_cache import DictionaryCache |
19 | 19 | from synapse.util.caches.descriptors import Cache |
20 | 20 | from synapse.util.caches import intern_dict |
21 | from synapse.storage.engines import PostgresEngine | |
21 | 22 | import synapse.metrics |
22 | 23 | |
23 | 24 | |
164 | 165 | self._txn_perf_counters = PerformanceCounters() |
165 | 166 | self._get_event_counters = PerformanceCounters() |
166 | 167 | |
167 | self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, | |
168 | self._get_event_cache = Cache("*getEvent*", keylen=3, | |
168 | 169 | max_entries=hs.config.event_cache_size) |
169 | 170 | |
170 | 171 | self._state_group_cache = DictionaryCache( |
304 | 305 | func, *args, **kwargs |
305 | 306 | ) |
306 | 307 | |
307 | with PreserveLoggingContext(): | |
308 | result = yield self._db_pool.runWithConnection( | |
309 | inner_func, *args, **kwargs | |
310 | ) | |
311 | ||
312 | for after_callback, after_args in after_callbacks: | |
313 | after_callback(*after_args) | |
308 | try: | |
309 | with PreserveLoggingContext(): | |
310 | result = yield self._db_pool.runWithConnection( | |
311 | inner_func, *args, **kwargs | |
312 | ) | |
313 | finally: | |
314 | for after_callback, after_args in after_callbacks: | |
315 | after_callback(*after_args) | |
314 | 316 | defer.returnValue(result) |
315 | 317 | |
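The hunk above moves the after-callback loop into a `finally`, so callbacks queued during the interaction (typically cache invalidations) run even when the interaction raises. A minimal sketch of that pattern with a plain function rather than Twisted's deferred machinery (the names here are illustrative, not Synapse APIs):

```python
def run_interaction(func, after_callbacks):
    # Run the work; whatever happens, flush the queued after-callbacks,
    # mirroring the try/finally introduced in the hunk above.
    try:
        return func()
    finally:
        for cb, args in after_callbacks:
            cb(*args)

ran = []

def boom():
    raise ValueError("txn failed")

try:
    run_interaction(boom, [(ran.append, ("invalidate",))])
except ValueError:
    pass
```

Even though `boom` raised, the callback still fired, which is the point of the change: a failed transaction must not leave stale cache entries behind.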
316 | 318 | @defer.inlineCallbacks |
859 | 861 | |
860 | 862 | return cache, min_val |
861 | 863 | |
864 | def _invalidate_cache_and_stream(self, txn, cache_func, keys): | |
865 | """Invalidates the cache and adds it to the cache stream so slaves | |
866 | will know to invalidate their caches. | |
867 | ||
868 | This should only be used to invalidate caches where slaves won't | |
869 | otherwise know from other replication streams that the cache should | |
870 | be invalidated. | |
871 | """ | |
872 | txn.call_after(cache_func.invalidate, keys) | |
873 | ||
874 | if isinstance(self.database_engine, PostgresEngine): | |
875 | # get_next() returns a context manager which is designed to wrap | |
876 | # the whole transaction. However, we only want an ID once we are | |
877 | # about to use it, so we call __enter__ manually here and arrange | |
878 | # for __exit__ to be called after the transaction finishes. | |
879 | ctx = self._cache_id_gen.get_next() | |
880 | stream_id = ctx.__enter__() | |
881 | txn.call_after(ctx.__exit__, None, None, None) | |
882 | txn.call_after(self.hs.get_notifier().on_new_replication_data) | |
883 | ||
884 | self._simple_insert_txn( | |
885 | txn, | |
886 | table="cache_invalidation_stream", | |
887 | values={ | |
888 | "stream_id": stream_id, | |
889 | "cache_func": cache_func.__name__, | |
890 | "keys": list(keys), | |
891 | "invalidation_ts": self.clock.time_msec(), | |
892 | } | |
893 | ) | |
894 | ||
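The comment in `_invalidate_cache_and_stream` describes driving a context manager by hand: `get_next()` is entered immediately to obtain a stream ID, and its `__exit__` is deferred until the transaction's after-callbacks run. A minimal sketch of that manual-driving pattern, using an illustrative generator in place of Synapse's `StreamIdGenerator`:

```python
from contextlib import contextmanager

@contextmanager
def get_next(counter):
    # Allocate an ID on entry; record it as finished on exit.
    counter["next"] += 1
    allocated = counter["next"]
    try:
        yield allocated
    finally:
        counter["finished"].append(allocated)

counter = {"next": 0, "finished": []}
after_callbacks = []

# Drive the context manager manually: enter now to get the ID,
# schedule __exit__ to run when the "transaction" completes.
ctx = get_next(counter)
stream_id = ctx.__enter__()
after_callbacks.append(lambda: ctx.__exit__(None, None, None))

# ... the transaction body would use stream_id here ...

for cb in after_callbacks:
    cb()
```

Splitting `__enter__` from `__exit__` like this is unusual but deliberate: the ID must exist inside the transaction, while the generator's cleanup must not run until the transaction has finished.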
895 | def get_all_updated_caches(self, last_id, current_id, limit): | |
896 | if last_id == current_id: | |
897 | return defer.succeed([]) | |
898 | ||
899 | def get_all_updated_caches_txn(txn): | |
900 | # We purposefully don't bound by the current token, as we want to | |
901 | # send across cache invalidations as quickly as possible. Cache | |
902 | # invalidations are idempotent, so duplicates are fine. | |
903 | sql = ( | |
904 | "SELECT stream_id, cache_func, keys, invalidation_ts" | |
905 | " FROM cache_invalidation_stream" | |
906 | " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" | |
907 | ) | |
908 | txn.execute(sql, (last_id, limit,)) | |
909 | return txn.fetchall() | |
910 | return self.runInteraction( | |
911 | "get_all_updated_caches", get_all_updated_caches_txn | |
912 | ) | |
913 | ||
914 | def get_cache_stream_token(self): | |
915 | if self._cache_id_gen: | |
916 | return self._cache_id_gen.get_current_token() | |
917 | else: | |
918 | return 0 | |
919 | ||
862 | 920 | |
863 | 921 | class _RollbackButIsFineException(Exception): |
864 | 922 | """ This exception is used to rollback a transaction without implying |
217 | 217 | Returns: |
218 | 218 | AppServiceTransaction: A new transaction. |
219 | 219 | """ |
220 | def _create_appservice_txn(txn): | |
221 | # work out new txn id (highest txn id for this service += 1) | |
222 | # The highest id may be the last one sent (in which case it is last_txn) | |
223 | # or it may be the highest in the txns list (which are waiting to be/are | |
224 | # being sent) | |
225 | last_txn_id = self._get_last_txn(txn, service.id) | |
226 | ||
227 | txn.execute( | |
228 | "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", | |
229 | (service.id,) | |
230 | ) | |
231 | highest_txn_id = txn.fetchone()[0] | |
232 | if highest_txn_id is None: | |
233 | highest_txn_id = 0 | |
234 | ||
235 | new_txn_id = max(highest_txn_id, last_txn_id) + 1 | |
236 | ||
237 | # Insert new txn into txn table | |
238 | event_ids = json.dumps([e.event_id for e in events]) | |
239 | txn.execute( | |
240 | "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " | |
241 | "VALUES(?,?,?)", | |
242 | (service.id, new_txn_id, event_ids) | |
243 | ) | |
244 | return AppServiceTransaction( | |
245 | service=service, id=new_txn_id, events=events | |
246 | ) | |
247 | ||
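The ID computation in `_create_appservice_txn` takes the maximum of the last transaction actually sent and the highest transaction still queued, then adds one, so IDs stay monotonic whether or not the queue has drained. The rule can be sketched as a standalone function (illustrative, not a Synapse helper):

```python
def next_txn_id(last_sent_txn_id, pending_txn_ids):
    # The new ID must exceed both the last transaction actually sent
    # (last_txn) and any transaction still waiting in the txns table.
    highest_pending = max(pending_txn_ids) if pending_txn_ids else 0
    return max(highest_pending, last_sent_txn_id) + 1
```

This matches the SQL above: `MAX(txn_id)` supplies `highest_pending` (0 when the table is empty, the `NULL` case), and `_get_last_txn` supplies `last_sent_txn_id`.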
220 | 248 | return self.runInteraction( |
221 | 249 | "create_appservice_txn", |
222 | self._create_appservice_txn, | |
223 | service, events | |
224 | ) | |
225 | ||
226 | def _create_appservice_txn(self, txn, service, events): | |
227 | # work out new txn id (highest txn id for this service += 1) | |
228 | # The highest id may be the last one sent (in which case it is last_txn) | |
229 | # or it may be the highest in the txns list (which are waiting to be/are | |
230 | # being sent) | |
231 | last_txn_id = self._get_last_txn(txn, service.id) | |
232 | ||
233 | txn.execute( | |
234 | "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", | |
235 | (service.id,) | |
236 | ) | |
237 | highest_txn_id = txn.fetchone()[0] | |
238 | if highest_txn_id is None: | |
239 | highest_txn_id = 0 | |
240 | ||
241 | new_txn_id = max(highest_txn_id, last_txn_id) + 1 | |
242 | ||
243 | # Insert new txn into txn table | |
244 | event_ids = json.dumps([e.event_id for e in events]) | |
245 | txn.execute( | |
246 | "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " | |
247 | "VALUES(?,?,?)", | |
248 | (service.id, new_txn_id, event_ids) | |
249 | ) | |
250 | return AppServiceTransaction( | |
251 | service=service, id=new_txn_id, events=events | |
250 | _create_appservice_txn, | |
252 | 251 | ) |
253 | 252 | |
254 | 253 | def complete_appservice_txn(self, txn_id, service): |
262 | 261 | A Deferred which resolves if this transaction was stored |
263 | 262 | successfully. |
264 | 263 | """ |
264 | txn_id = int(txn_id) | |
265 | ||
266 | def _complete_appservice_txn(txn): | |
267 | # Debugging query: Make sure the txn being completed is EXACTLY +1 from | |
268 | # what was there before. If it isn't, we've got problems (e.g. the AS | |
269 | # has probably missed some events), so whine loudly but still continue, | |
270 | # since it shouldn't fail completion of the transaction. | |
271 | last_txn_id = self._get_last_txn(txn, service.id) | |
272 | if (last_txn_id + 1) != txn_id: | |
273 | logger.error( | |
274 | "appservice: Completing a transaction which has an ID > 1 from " | |
275 | "the last ID sent to this AS. We've either dropped events or " | |
276 | "sent it to the AS out of order. FIX ME. last_txn=%s " | |
277 | "completing_txn=%s service_id=%s", last_txn_id, txn_id, | |
278 | service.id | |
279 | ) | |
280 | ||
281 | # Set current txn_id for AS to 'txn_id' | |
282 | self._simple_upsert_txn( | |
283 | txn, "application_services_state", dict(as_id=service.id), | |
284 | dict(last_txn=txn_id) | |
285 | ) | |
286 | ||
287 | # Delete txn | |
288 | self._simple_delete_txn( | |
289 | txn, "application_services_txns", | |
290 | dict(txn_id=txn_id, as_id=service.id) | |
291 | ) | |
292 | ||
265 | 293 | return self.runInteraction( |
266 | 294 | "complete_appservice_txn", |
267 | self._complete_appservice_txn, | |
268 | txn_id, service | |
269 | ) | |
270 | ||
271 | def _complete_appservice_txn(self, txn, txn_id, service): | |
272 | txn_id = int(txn_id) | |
273 | ||
274 | # Debugging query: Make sure the txn being completed is EXACTLY +1 from | |
275 | # what was there before. If it isn't, we've got problems (e.g. the AS | |
276 | # has probably missed some events), so whine loudly but still continue, | |
277 | # since it shouldn't fail completion of the transaction. | |
278 | last_txn_id = self._get_last_txn(txn, service.id) | |
279 | if (last_txn_id + 1) != txn_id: | |
280 | logger.error( | |
281 | "appservice: Completing a transaction which has an ID > 1 from " | |
282 | "the last ID sent to this AS. We've either dropped events or " | |
283 | "sent it to the AS out of order. FIX ME. last_txn=%s " | |
284 | "completing_txn=%s service_id=%s", last_txn_id, txn_id, | |
285 | service.id | |
286 | ) | |
287 | ||
288 | # Set current txn_id for AS to 'txn_id' | |
289 | self._simple_upsert_txn( | |
290 | txn, "application_services_state", dict(as_id=service.id), | |
291 | dict(last_txn=txn_id) | |
292 | ) | |
293 | ||
294 | # Delete txn | |
295 | self._simple_delete_txn( | |
296 | txn, "application_services_txns", | |
297 | dict(txn_id=txn_id, as_id=service.id) | |
295 | _complete_appservice_txn, | |
298 | 296 | ) |
299 | 297 | |
300 | 298 | @defer.inlineCallbacks |
308 | 306 | A Deferred which resolves to an AppServiceTransaction or |
309 | 307 | None. |
310 | 308 | """ |
309 | def _get_oldest_unsent_txn(txn): | |
310 | # Monotonically increasing txn ids, so just select the smallest | |
311 | # one in the txns table (we delete them when they are sent) | |
312 | txn.execute( | |
313 | "SELECT * FROM application_services_txns WHERE as_id=?" | |
314 | " ORDER BY txn_id ASC LIMIT 1", | |
315 | (service.id,) | |
316 | ) | |
317 | rows = self.cursor_to_dict(txn) | |
318 | if not rows: | |
319 | return None | |
320 | ||
321 | entry = rows[0] | |
322 | ||
323 | return entry | |
324 | ||
311 | 325 | entry = yield self.runInteraction( |
312 | 326 | "get_oldest_unsent_appservice_txn", |
313 | self._get_oldest_unsent_txn, | |
314 | service | |
327 | _get_oldest_unsent_txn, | |
315 | 328 | ) |
316 | 329 | |
317 | 330 | if not entry: |
324 | 337 | defer.returnValue(AppServiceTransaction( |
325 | 338 | service=service, id=entry["txn_id"], events=events |
326 | 339 | )) |
327 | ||
328 | def _get_oldest_unsent_txn(self, txn, service): | |
329 | # Monotonically increasing txn ids, so just select the smallest | |
330 | # one in the txns table (we delete them when they are sent) | |
331 | txn.execute( | |
332 | "SELECT * FROM application_services_txns WHERE as_id=?" | |
333 | " ORDER BY txn_id ASC LIMIT 1", | |
334 | (service.id,) | |
335 | ) | |
336 | rows = self.cursor_to_dict(txn) | |
337 | if not rows: | |
338 | return None | |
339 | ||
340 | entry = rows[0] | |
341 | ||
342 | return entry | |
343 | 340 | |
344 | 341 | def _get_last_txn(self, txn, service_id): |
345 | 342 | txn.execute( |
351 | 348 | return 0 |
352 | 349 | else: |
353 | 350 | return int(last_txn_id[0]) # select 'last_txn' col |
351 | ||
352 | def set_appservice_last_pos(self, pos): | |
353 | def set_appservice_last_pos_txn(txn): | |
354 | txn.execute( | |
355 | "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) | |
356 | ) | |
357 | return self.runInteraction( | |
358 | "set_appservice_last_pos", set_appservice_last_pos_txn | |
359 | ) | |
360 | ||
361 | @defer.inlineCallbacks | |
362 | def get_new_events_for_appservice(self, current_id, limit): | |
363 | """Get all new events""" | |
364 | ||
365 | def get_new_events_for_appservice_txn(txn): | |
366 | sql = ( | |
367 | "SELECT e.stream_ordering, e.event_id" | |
368 | " FROM events AS e" | |
369 | " WHERE" | |
370 | " (SELECT stream_ordering FROM appservice_stream_position)" | |
371 | " < e.stream_ordering" | |
372 | " AND e.stream_ordering <= ?" | |
373 | " ORDER BY e.stream_ordering ASC" | |
374 | " LIMIT ?" | |
375 | ) | |
376 | ||
377 | txn.execute(sql, (current_id, limit)) | |
378 | rows = txn.fetchall() | |
379 | ||
380 | upper_bound = current_id | |
381 | if len(rows) == limit: | |
382 | upper_bound = rows[-1][0] | |
383 | ||
384 | return upper_bound, [row[1] for row in rows] | |
385 | ||
386 | upper_bound, event_ids = yield self.runInteraction( | |
387 | "get_new_events_for_appservice", get_new_events_for_appservice_txn, | |
388 | ) | |
389 | ||
390 | events = yield self._get_events(event_ids) | |
391 | ||
392 | defer.returnValue((upper_bound, events)) |
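`get_new_events_for_appservice_txn` clamps the returned token: when the query fills the limit, the next poll must resume from the last row actually returned rather than jumping to `current_id`, or any events beyond the limit would be skipped. The same logic over an in-memory list (illustrative, not the Synapse storage API):

```python
def get_new_events(rows, current_id, limit):
    # rows: (stream_ordering, event_id) pairs, sorted ascending.
    page = [r for r in rows if r[0] <= current_id][:limit]
    upper_bound = current_id
    if len(page) == limit:
        # We may have truncated the results: only advance the token
        # as far as we actually read, so nothing is skipped next time.
        upper_bound = page[-1][0]
    return upper_bound, [event_id for _, event_id in page]
```

Returning `(upper_bound, event_ids)` lets the caller persist `upper_bound` via `set_appservice_last_pos` and pick up exactly where this batch ended.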
81 | 81 | Returns: |
82 | 82 | Deferred |
83 | 83 | """ |
84 | try: | |
85 | yield self._simple_insert( | |
84 | def alias_txn(txn): | |
85 | self._simple_insert_txn( | |
86 | txn, | |
86 | 87 | "room_aliases", |
87 | 88 | { |
88 | 89 | "room_alias": room_alias.to_string(), |
89 | 90 | "room_id": room_id, |
90 | 91 | "creator": creator, |
91 | 92 | }, |
92 | desc="create_room_alias_association", | |
93 | ) | |
94 | ||
95 | self._simple_insert_many_txn( | |
96 | txn, | |
97 | table="room_alias_servers", | |
98 | values=[{ | |
99 | "room_alias": room_alias.to_string(), | |
100 | "server": server, | |
101 | } for server in servers], | |
102 | ) | |
103 | ||
104 | self._invalidate_cache_and_stream( | |
105 | txn, self.get_aliases_for_room, (room_id,) | |
106 | ) | |
107 | ||
108 | try: | |
109 | ret = yield self.runInteraction( | |
110 | "create_room_alias_association", alias_txn | |
93 | 111 | ) |
94 | 112 | except self.database_engine.module.IntegrityError: |
95 | 113 | raise SynapseError( |
96 | 114 | 409, "Room alias %s already exists" % room_alias.to_string() |
97 | 115 | ) |
98 | ||
99 | for server in servers: | |
100 | # TODO(erikj): Fix this to bulk insert | |
101 | yield self._simple_insert( | |
102 | "room_alias_servers", | |
103 | { | |
104 | "room_alias": room_alias.to_string(), | |
105 | "server": server, | |
106 | }, | |
107 | desc="create_room_alias_association", | |
108 | ) | |
109 | self.get_aliases_for_room.invalidate((room_id,)) | |
116 | defer.returnValue(ret) | |
110 | 117 | |
111 | 118 | def get_room_alias_creator(self, room_alias): |
112 | 119 | return self._simple_select_one_onecol( |
55 | 55 | ) |
56 | 56 | self._simple_insert_many_txn(txn, "event_push_actions", values) |
57 | 57 | |
58 | @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) | |
58 | @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) | |
59 | 59 | def get_unread_event_push_actions_by_room_for_user( |
60 | 60 | self, room_id, user_id, last_read_event_id |
61 | 61 | ): |
337 | 337 | defer.returnValue(notifs[:limit]) |
338 | 338 | |
339 | 339 | @defer.inlineCallbacks |
340 | def get_push_actions_for_user(self, user_id, before=None, limit=50): | |
341 | def f(txn): | |
342 | before_clause = "" | |
343 | if before: | |
344 | before_clause = "AND stream_ordering < ?" | |
345 | args = [user_id, before, limit] | |
346 | else: | |
347 | args = [user_id, limit] | |
348 | sql = ( | |
349 | "SELECT epa.event_id, epa.room_id," | |
350 | " epa.stream_ordering, epa.topological_ordering," | |
351 | " epa.actions, epa.profile_tag, e.received_ts" | |
352 | " FROM event_push_actions epa, events e" | |
353 | " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id" | |
354 | " AND epa.user_id = ? %s" | |
355 | " ORDER BY epa.stream_ordering DESC" | |
356 | " LIMIT ?" | |
357 | % (before_clause,) | |
358 | ) | |
359 | txn.execute(sql, args) | |
360 | return self.cursor_to_dict(txn) | |
361 | ||
362 | push_actions = yield self.runInteraction( | |
363 | "get_push_actions_for_user", f | |
364 | ) | |
365 | for pa in push_actions: | |
366 | pa["actions"] = json.loads(pa["actions"]) | |
367 | defer.returnValue(push_actions) | |
368 | ||
369 | @defer.inlineCallbacks | |
340 | 370 | def get_time_of_last_push_action_before(self, stream_ordering): |
341 | 371 | def f(txn): |
342 | 372 | sql = ( |
19 | 19 | from synapse.events.utils import prune_event |
20 | 20 | |
21 | 21 | from synapse.util.async import ObservableDeferred |
22 | from synapse.util.logcontext import preserve_fn, PreserveLoggingContext | |
22 | from synapse.util.logcontext import ( | |
23 | preserve_fn, PreserveLoggingContext, preserve_context_over_deferred | |
24 | ) | |
23 | 25 | from synapse.util.logutils import log_function |
26 | from synapse.util.metrics import Measure | |
24 | 27 | from synapse.api.constants import EventTypes |
25 | 28 | from synapse.api.errors import SynapseError |
26 | 29 | |
200 | 203 | |
201 | 204 | deferreds = [] |
202 | 205 | for room_id, evs_ctxs in partitioned.items(): |
203 | d = self._event_persist_queue.add_to_queue( | |
206 | d = preserve_fn(self._event_persist_queue.add_to_queue)( | |
204 | 207 | room_id, evs_ctxs, |
205 | 208 | backfilled=backfilled, |
206 | 209 | current_state=None, |
210 | 213 | for room_id in partitioned.keys(): |
211 | 214 | self._maybe_start_persisting(room_id) |
212 | 215 | |
213 | return defer.gatherResults(deferreds, consumeErrors=True) | |
216 | return preserve_context_over_deferred( | |
217 | defer.gatherResults(deferreds, consumeErrors=True) | |
218 | ) | |
214 | 219 | |
215 | 220 | @defer.inlineCallbacks |
216 | 221 | @log_function |
223 | 228 | |
224 | 229 | self._maybe_start_persisting(event.room_id) |
225 | 230 | |
226 | yield deferred | |
231 | yield preserve_context_over_deferred(deferred) | |
227 | 232 | |
228 | 233 | max_persisted_id = yield self._stream_id_gen.get_current_token() |
229 | 234 | defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) |
599 | 604 | "rejections", |
600 | 605 | "redactions", |
601 | 606 | "room_memberships", |
602 | "state_events" | |
607 | "state_events", | |
608 | "topics" | |
603 | 609 | ): |
604 | 610 | txn.executemany( |
605 | 611 | "DELETE FROM %s WHERE event_id = ?" % (table,), |
1085 | 1091 | if not allow_rejected: |
1086 | 1092 | rows[:] = [r for r in rows if not r["rejects"]] |
1087 | 1093 | |
1088 | res = yield defer.gatherResults( | |
1094 | res = yield preserve_context_over_deferred(defer.gatherResults( | |
1089 | 1095 | [ |
1090 | 1096 | preserve_fn(self._get_event_from_row)( |
1091 | 1097 | row["internal_metadata"], row["json"], row["redacts"], |
1094 | 1100 | for row in rows |
1095 | 1101 | ], |
1096 | 1102 | consumeErrors=True |
1097 | ) | |
1103 | )) | |
1098 | 1104 | |
1099 | 1105 | defer.returnValue({ |
1100 | 1106 | e.event.event_id: e |
1130 | 1136 | @defer.inlineCallbacks |
1131 | 1137 | def _get_event_from_row(self, internal_metadata, js, redacted, |
1132 | 1138 | rejected_reason=None): |
1133 | d = json.loads(js) | |
1134 | internal_metadata = json.loads(internal_metadata) | |
1135 | ||
1136 | if rejected_reason: | |
1137 | rejected_reason = yield self._simple_select_one_onecol( | |
1138 | table="rejections", | |
1139 | keyvalues={"event_id": rejected_reason}, | |
1140 | retcol="reason", | |
1141 | desc="_get_event_from_row_rejected_reason", | |
1142 | ) | |
1143 | ||
1144 | original_ev = FrozenEvent( | |
1145 | d, | |
1146 | internal_metadata_dict=internal_metadata, | |
1147 | rejected_reason=rejected_reason, | |
1148 | ) | |
1149 | ||
1150 | redacted_event = None | |
1151 | if redacted: | |
1152 | redacted_event = prune_event(original_ev) | |
1153 | ||
1154 | redaction_id = yield self._simple_select_one_onecol( | |
1155 | table="redactions", | |
1156 | keyvalues={"redacts": redacted_event.event_id}, | |
1157 | retcol="event_id", | |
1158 | desc="_get_event_from_row_redactions", | |
1159 | ) | |
1160 | ||
1161 | redacted_event.unsigned["redacted_by"] = redaction_id | |
1162 | # Get the redaction event. | |
1163 | ||
1164 | because = yield self.get_event( | |
1165 | redaction_id, | |
1166 | check_redacted=False, | |
1167 | allow_none=True, | |
1168 | ) | |
1169 | ||
1170 | if because: | |
1171 | # It's fine to add the event directly, since get_pdu_json | |
1172 | # will serialise this field correctly | |
1173 | redacted_event.unsigned["redacted_because"] = because | |
1174 | ||
1175 | cache_entry = _EventCacheEntry( | |
1176 | event=original_ev, | |
1177 | redacted_event=redacted_event, | |
1178 | ) | |
1179 | ||
1180 | self._get_event_cache.prefill((original_ev.event_id,), cache_entry) | |
1139 | with Measure(self._clock, "_get_event_from_row"): | |
1140 | d = json.loads(js) | |
1141 | internal_metadata = json.loads(internal_metadata) | |
1142 | ||
1143 | if rejected_reason: | |
1144 | rejected_reason = yield self._simple_select_one_onecol( | |
1145 | table="rejections", | |
1146 | keyvalues={"event_id": rejected_reason}, | |
1147 | retcol="reason", | |
1148 | desc="_get_event_from_row_rejected_reason", | |
1149 | ) | |
1150 | ||
1151 | original_ev = FrozenEvent( | |
1152 | d, | |
1153 | internal_metadata_dict=internal_metadata, | |
1154 | rejected_reason=rejected_reason, | |
1155 | ) | |
1156 | ||
1157 | redacted_event = None | |
1158 | if redacted: | |
1159 | redacted_event = prune_event(original_ev) | |
1160 | ||
1161 | redaction_id = yield self._simple_select_one_onecol( | |
1162 | table="redactions", | |
1163 | keyvalues={"redacts": redacted_event.event_id}, | |
1164 | retcol="event_id", | |
1165 | desc="_get_event_from_row_redactions", | |
1166 | ) | |
1167 | ||
1168 | redacted_event.unsigned["redacted_by"] = redaction_id | |
1169 | # Get the redaction event. | |
1170 | ||
1171 | because = yield self.get_event( | |
1172 | redaction_id, | |
1173 | check_redacted=False, | |
1174 | allow_none=True, | |
1175 | ) | |
1176 | ||
1177 | if because: | |
1178 | # It's fine to add the event directly, since get_pdu_json | |
1179 | # will serialise this field correctly | |
1180 | redacted_event.unsigned["redacted_because"] = because | |
1181 | ||
1182 | cache_entry = _EventCacheEntry( | |
1183 | event=original_ev, | |
1184 | redacted_event=redacted_event, | |
1185 | ) | |
1186 | ||
1187 | self._get_event_cache.prefill((original_ev.event_id,), cache_entry) | |
1181 | 1188 | |
1182 | 1189 | defer.returnValue(cache_entry) |
1183 | 1190 |
24 | 24 | |
25 | 25 | # Remember to update this number every time a change is made to database |
26 | 26 | # schema files, so the users will be informed on server restarts. |
27 | SCHEMA_VERSION = 33 | |
27 | SCHEMA_VERSION = 34 | |
28 | 28 | |
29 | 29 | dir_path = os.path.abspath(os.path.dirname(__file__)) |
30 | 30 |
188 | 188 | desc="add_presence_list_pending", |
189 | 189 | ) |
190 | 190 | |
191 | @defer.inlineCallbacks | |
192 | 191 | def set_presence_list_accepted(self, observer_localpart, observed_userid): |
193 | result = yield self._simple_update_one( | |
194 | table="presence_list", | |
195 | keyvalues={"user_id": observer_localpart, | |
196 | "observed_user_id": observed_userid}, | |
197 | updatevalues={"accepted": True}, | |
198 | desc="set_presence_list_accepted", | |
199 | ) | |
200 | self.get_presence_list_accepted.invalidate((observer_localpart,)) | |
201 | self.get_presence_list_observers_accepted.invalidate((observed_userid,)) | |
202 | defer.returnValue(result) | |
192 | def update_presence_list_txn(txn): | |
193 | result = self._simple_update_one_txn( | |
194 | txn, | |
195 | table="presence_list", | |
196 | keyvalues={ | |
197 | "user_id": observer_localpart, | |
198 | "observed_user_id": observed_userid | |
199 | }, | |
200 | updatevalues={"accepted": True}, | |
201 | ) | |
202 | ||
203 | self._invalidate_cache_and_stream( | |
204 | txn, self.get_presence_list_accepted, (observer_localpart,) | |
205 | ) | |
206 | self._invalidate_cache_and_stream( | |
207 | txn, self.get_presence_list_observers_accepted, (observed_userid,) | |
208 | ) | |
209 | ||
210 | return result | |
211 | ||
212 | return self.runInteraction( | |
213 | "set_presence_list_accepted", update_presence_list_txn, | |
214 | ) | |
203 | 215 | |
204 | 216 | def get_presence_list(self, observer_localpart, accepted=None): |
205 | 217 | if accepted: |
15 | 15 | from ._base import SQLBaseStore |
16 | 16 | from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList |
17 | 17 | from synapse.push.baserules import list_with_base_rules |
18 | from synapse.api.constants import EventTypes, Membership | |
18 | 19 | from twisted.internet import defer |
19 | 20 | |
20 | 21 | import logging |
47 | 48 | |
48 | 49 | |
49 | 50 | class PushRuleStore(SQLBaseStore): |
50 | @cachedInlineCallbacks(lru=True) | |
51 | @cachedInlineCallbacks() | |
51 | 52 | def get_push_rules_for_user(self, user_id): |
52 | 53 | rows = yield self._simple_select_list( |
53 | 54 | table="push_rules", |
71 | 72 | |
72 | 73 | defer.returnValue(rules) |
73 | 74 | |
74 | @cachedInlineCallbacks(lru=True) | |
75 | @cachedInlineCallbacks() | |
75 | 76 | def get_push_rules_enabled_for_user(self, user_id): |
76 | 77 | results = yield self._simple_select_list( |
77 | 78 | table="push_rules_enable", |
121 | 122 | ) |
122 | 123 | |
123 | 124 | defer.returnValue(results) |
125 | ||
126 | def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): | |
127 | if not state_group: | |
128 | # If state_group is None it means it has yet to be assigned a | |
129 | # state group, i.e. we need to make sure that calls with a state_group | |
130 | # of None don't hit previous cached calls with a None state_group. | |
131 | # To do this we set the state_group to a new object as object() != object() | |
132 | state_group = object() | |
133 | ||
134 | return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) | |
135 | ||
136 | @cachedInlineCallbacks(num_args=2, cache_context=True) | |
137 | def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, | |
138 | cache_context): | |
139 | # We don't use `state_group`, it's there so that we can cache based | |
140 | # on it. However, it's important that it's never None, since two | |
141 | # current_states with a state_group of None are likely to be different. | |
142 | # See bulk_get_push_rules_for_room for how we work around this. | |
143 | assert state_group is not None | |
144 | ||
145 | # We also will want to generate notifs for other people in the room so | |
146 | # their unread counts are correct in the event stream, but to avoid | |
147 | # generating them for bot / AS users etc, we only do so for people who've | |
148 | # sent a read receipt into the room. | |
149 | local_users_in_room = set( | |
150 | e.state_key for e in current_state.values() | |
151 | if e.type == EventTypes.Member and e.membership == Membership.JOIN | |
152 | and self.hs.is_mine_id(e.state_key) | |
153 | ) | |
154 | ||
155 | # users in the room who have pushers need to get push rules run because | |
156 | # that's how their pushers work | |
157 | if_users_with_pushers = yield self.get_if_users_have_pushers( | |
158 | local_users_in_room, on_invalidate=cache_context.invalidate, | |
159 | ) | |
160 | user_ids = set( | |
161 | uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher | |
162 | ) | |
163 | ||
164 | users_with_receipts = yield self.get_users_with_read_receipts_in_room( | |
165 | room_id, on_invalidate=cache_context.invalidate, | |
166 | ) | |
167 | ||
168 | # any users with pushers must be ours: they have pushers | |
169 | for uid in users_with_receipts: | |
170 | if uid in local_users_in_room: | |
171 | user_ids.add(uid) | |
172 | ||
173 | rules_by_user = yield self.bulk_get_push_rules( | |
174 | user_ids, on_invalidate=cache_context.invalidate, | |
175 | ) | |
176 | ||
177 | rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} | |
178 | ||
179 | defer.returnValue(rules_by_user) | |
124 | 180 | |
125 | 181 | @cachedList(cached_method_name="get_push_rules_enabled_for_user", |
126 | 182 | list_name="user_ids", num_args=1, inlineCallbacks=True) |
134 | 134 | "get_all_updated_pushers", get_all_updated_pushers_txn |
135 | 135 | ) |
136 | 136 | |
137 | @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000) | |
137 | @cachedInlineCallbacks(num_args=1, max_entries=15000) | |
138 | 138 | def get_if_user_has_pusher(self, user_id): |
139 | 139 | result = yield self._simple_select_many_batch( |
140 | 140 | table='pushers', |
94 | 94 | defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) |
95 | 95 | |
96 | 96 | @defer.inlineCallbacks |
97 | def get_receipts_for_user_with_orderings(self, user_id, receipt_type): | |
98 | def f(txn): | |
99 | sql = ( | |
100 | "SELECT rl.room_id, rl.event_id," | |
101 | " e.topological_ordering, e.stream_ordering" | |
102 | " FROM receipts_linearized AS rl" | |
103 | " INNER JOIN events AS e USING (room_id, event_id)" | |
104 | " WHERE rl.room_id = e.room_id" | |
105 | " AND rl.event_id = e.event_id" | |
106 | " AND user_id = ?" | |
107 | ) | |
108 | txn.execute(sql, (user_id,)) | |
109 | return txn.fetchall() | |
110 | rows = yield self.runInteraction( | |
111 | "get_receipts_for_user_with_orderings", f | |
112 | ) | |
113 | defer.returnValue({ | |
114 | row[0]: { | |
115 | "event_id": row[1], | |
116 | "topological_ordering": row[2], | |
117 | "stream_ordering": row[3], | |
118 | } for row in rows | |
119 | }) | |
120 | ||
121 | @defer.inlineCallbacks | |
97 | 122 | def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): |
98 | 123 | """Get receipts for multiple rooms for sending to clients. |
99 | 124 | |
119 | 144 | |
120 | 145 | defer.returnValue([ev for res in results.values() for ev in res]) |
121 | 146 | |
122 | @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True) | |
147 | @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) | |
123 | 148 | def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): |
124 | 149 | """Get receipts for a single room for sending to clients. |
125 | 150 |
92 | 92 | desc="add_refresh_token_to_user", |
93 | 93 | ) |
94 | 94 | |
95 | @defer.inlineCallbacks | |
96 | 95 | def register(self, user_id, token=None, password_hash=None, |
97 | 96 | was_guest=False, make_guest=False, appservice_id=None, |
98 | 97 | create_profile_with_localpart=None, admin=False): |
114 | 113 | Raises: |
115 | 114 | StoreError if the user_id could not be registered. |
116 | 115 | """ |
117 | yield self.runInteraction( | |
116 | return self.runInteraction( | |
118 | 117 | "register", |
119 | 118 | self._register, |
120 | 119 | user_id, |
126 | 125 | create_profile_with_localpart, |
127 | 126 | admin |
128 | 127 | ) |
129 | self.get_user_by_id.invalidate((user_id,)) | |
130 | self.is_guest.invalidate((user_id,)) | |
131 | 128 | |
132 | 129 | def _register( |
133 | 130 | self, |
209 | 206 | (create_profile_with_localpart,) |
210 | 207 | ) |
211 | 208 | |
209 | self._invalidate_cache_and_stream( | |
210 | txn, self.get_user_by_id, (user_id,) | |
211 | ) | |
212 | txn.call_after(self.is_guest.invalidate, (user_id,)) | |
213 | ||
212 | 214 | @cached() |
213 | 215 | def get_user_by_id(self, user_id): |
214 | 216 | return self._simple_select_one( |
235 | 237 | |
236 | 238 | return self.runInteraction("get_users_by_id_case_insensitive", f) |
237 | 239 | |
238 | @defer.inlineCallbacks | |
239 | 240 | def user_set_password_hash(self, user_id, password_hash): |
240 | 241 | """ |
241 | 242 | NB. This does *not* evict any cache because the one use for this |
242 | 243 | removes most of the entries subsequently anyway so it would be |
243 | 244 | pointless. Use flush_user separately. |
244 | 245 | """ |
245 | yield self._simple_update_one('users', { | |
246 | 'name': user_id | |
247 | }, { | |
248 | 'password_hash': password_hash | |
249 | }) | |
250 | self.get_user_by_id.invalidate((user_id,)) | |
251 | ||
252 | @defer.inlineCallbacks | |
253 | def user_delete_access_tokens(self, user_id, except_token_ids=[], | |
246 | def user_set_password_hash_txn(txn): | |
247 | self._simple_update_one_txn( | |
248 | txn, | |
249 | 'users', { | |
250 | 'name': user_id | |
251 | }, | |
252 | { | |
253 | 'password_hash': password_hash | |
254 | } | |
255 | ) | |
256 | self._invalidate_cache_and_stream( | |
257 | txn, self.get_user_by_id, (user_id,) | |
258 | ) | |
259 | return self.runInteraction( | |
260 | "user_set_password_hash", user_set_password_hash_txn | |
261 | ) | |
262 | ||
263 | @defer.inlineCallbacks | |
264 | def user_delete_access_tokens(self, user_id, except_token_id=None, | |
254 | 265 | device_id=None, |
255 | 266 | delete_refresh_tokens=False): |
256 | 267 | """ |
258 | 269 | |
259 | 270 | Args: |
260 | 271 | user_id (str): ID of user the tokens belong to |
261 | except_token_ids (list[str]): list of access_tokens which should | |
272 | except_token_id (str): ID of the access token which should | |
262 | 273 | *not* be deleted |
263 | 274 | device_id (str|None): ID of device the tokens are associated with. |
264 | 275 | If None, tokens associated with any device (or no device) will |
268 | 279 | Returns: |
269 | 280 | defer.Deferred: |
270 | 281 | """ |
271 | def f(txn, table, except_tokens, call_after_delete): | |
272 | sql = "SELECT token FROM %s WHERE user_id = ?" % table | |
273 | clauses = [user_id] | |
274 | ||
282 | def f(txn): | |
283 | keyvalues = { | |
284 | "user_id": user_id, | |
285 | } | |
275 | 286 | if device_id is not None: |
276 | sql += " AND device_id = ?" | |
277 | clauses.append(device_id) | |
278 | ||
279 | if except_tokens: | |
280 | sql += " AND id NOT IN (%s)" % ( | |
281 | ",".join(["?" for _ in except_tokens]), | |
287 | keyvalues["device_id"] = device_id | |
288 | ||
289 | if delete_refresh_tokens: | |
290 | self._simple_delete_txn( | |
291 | txn, | |
292 | table="refresh_tokens", | |
293 | keyvalues=keyvalues, | |
282 | 294 | ) |
283 | clauses += except_tokens | |
284 | ||
285 | txn.execute(sql, clauses) | |
286 | ||
287 | rows = txn.fetchall() | |
288 | ||
289 | n = 100 | |
290 | chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] | |
291 | for chunk in chunks: | |
292 | if call_after_delete: | |
293 | for row in chunk: | |
294 | txn.call_after(call_after_delete, (row[0],)) | |
295 | ||
296 | txn.execute( | |
297 | "DELETE FROM %s WHERE token in (%s)" % ( | |
298 | table, | |
299 | ",".join(["?" for _ in chunk]), | |
300 | ), [r[0] for r in chunk] | |
295 | ||
296 | items = keyvalues.items() | |
297 | where_clause = " AND ".join(k + " = ?" for k, _ in items) | |
298 | values = [v for _, v in items] | |
299 | if except_token_id: | |
300 | where_clause += " AND id != ?" | |
301 | values.append(except_token_id) | |
302 | ||
303 | txn.execute( | |
304 | "SELECT token FROM access_tokens WHERE %s" % where_clause, | |
305 | values | |
306 | ) | |
307 | rows = self.cursor_to_dict(txn) | |
308 | ||
309 | for row in rows: | |
310 | self._invalidate_cache_and_stream( | |
311 | txn, self.get_user_by_access_token, (row["token"],) | |
301 | 312 | ) |
302 | 313 | |
303 | # delete refresh tokens first, to stop new access tokens being | |
304 | # allocated while our backs are turned | |
305 | if delete_refresh_tokens: | |
306 | yield self.runInteraction( | |
307 | "user_delete_access_tokens", f, | |
308 | table="refresh_tokens", | |
309 | except_tokens=[], | |
310 | call_after_delete=None, | |
314 | txn.execute( | |
315 | "DELETE FROM access_tokens WHERE %s" % where_clause, | |
316 | values | |
311 | 317 | ) |
312 | 318 | |
313 | 319 | yield self.runInteraction( |
314 | 320 | "user_delete_access_tokens", f, |
315 | table="access_tokens", | |
316 | except_tokens=except_token_ids, | |
317 | call_after_delete=self.get_user_by_access_token.invalidate, | |
318 | 321 | ) |
319 | 322 | |
320 | 323 | def delete_access_token(self, access_token): |
327 | 330 | }, |
328 | 331 | ) |
329 | 332 | |
330 | txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) | |
333 | self._invalidate_cache_and_stream( | |
334 | txn, self.get_user_by_access_token, (access_token,) | |
335 | ) | |
331 | 336 | |
332 | 337 | return self.runInteraction("delete_access_token", f) |
333 | 338 |
276 | 276 | user_id, membership_list=[Membership.JOIN], |
277 | 277 | ) |
278 | 278 | |
279 | @defer.inlineCallbacks | |
280 | 279 | def forget(self, user_id, room_id): |
281 | 280 | """Indicate that user_id wishes to discard history for room_id.""" |
282 | 281 | def f(txn): |
291 | 290 | " room_id = ?" |
292 | 291 | ) |
293 | 292 | txn.execute(sql, (user_id, room_id)) |
294 | yield self.runInteraction("forget_membership", f) | |
295 | self.was_forgotten_at.invalidate_all() | |
296 | self.who_forgot_in_room.invalidate_all() | |
297 | self.did_forget.invalidate((user_id, room_id)) | |
293 | ||
294 | txn.call_after(self.was_forgotten_at.invalidate_all) | |
295 | txn.call_after(self.did_forget.invalidate, (user_id, room_id)) | |
296 | self._invalidate_cache_and_stream( | |
297 | txn, self.who_forgot_in_room, (room_id,) | |
298 | ) | |
299 | return self.runInteraction("forget_membership", f) | |
298 | 300 | |
299 | 301 | @cachedInlineCallbacks(num_args=2) |
300 | 302 | def did_forget(self, user_id, room_id): |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | CREATE TABLE IF NOT EXISTS appservice_stream_position( | |
16 | Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. | |
17 | stream_ordering BIGINT, | |
18 | CHECK (Lock='X') | |
19 | ); | |
20 | ||
21 | INSERT INTO appservice_stream_position (stream_ordering) | |
22 | SELECT COALESCE(MAX(stream_ordering), 0) FROM events; |
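The `Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE` column plus the `CHECK (Lock='X')` constraint is a common SQL trick for guaranteeing a table can never hold more than one row. A minimal sqlite3 sketch of the pattern (table and column set trimmed for illustration, not the production schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE stream_position(
        Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- UNIQUE + CHECK: at most one row
        stream_ordering BIGINT,
        CHECK (Lock = 'X')
    )
    """
)
conn.execute("INSERT INTO stream_position (stream_ordering) VALUES (0)")

# A second insert gets the same default 'X' and violates the UNIQUE constraint.
try:
    conn.execute("INSERT INTO stream_position (stream_ordering) VALUES (1)")
    inserted_twice = True
except sqlite3.IntegrityError:
    inserted_twice = False

# Updating the single row still works as normal.
conn.execute("UPDATE stream_position SET stream_ordering = 42")
row_count = conn.execute("SELECT COUNT(*) FROM stream_position").fetchone()[0]
```

Because the row always exists after the migration's seeding `INSERT ... SELECT`, readers can use a plain `UPDATE`/`SELECT` without worrying about an empty table.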
0 | # Copyright 2016 OpenMarket Ltd | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from synapse.storage.prepare_database import get_statements | |
15 | from synapse.storage.engines import PostgresEngine | |
16 | ||
17 | import logging | |
18 | ||
19 | logger = logging.getLogger(__name__) | |
20 | ||
21 | ||
22 | # This stream is used to notify replication slaves that some caches have | |
23 | # been invalidated that they cannot infer from the other streams. | |
24 | CREATE_TABLE = """ | |
25 | CREATE TABLE cache_invalidation_stream ( | |
26 | stream_id BIGINT, | |
27 | cache_func TEXT, | |
28 | keys TEXT[], | |
29 | invalidation_ts BIGINT | |
30 | ); | |
31 | ||
32 | CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id); | |
33 | """ | |
34 | ||
35 | ||
36 | def run_create(cur, database_engine, *args, **kwargs): | |
37 | if not isinstance(database_engine, PostgresEngine): | |
38 | return | |
39 | ||
40 | for statement in get_statements(CREATE_TABLE.splitlines()): | |
41 | cur.execute(statement) | |
42 | ||
43 | ||
44 | def run_upgrade(cur, database_engine, *args, **kwargs): | |
45 | pass |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | DELETE FROM push_rules WHERE rule_id = 'global/override/.m.rule.contains_display_name'; | |
16 | UPDATE push_rules SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; | |
17 | ||
18 | DELETE FROM push_rules_enable WHERE rule_id = 'global/override/.m.rule.contains_display_name'; | |
19 | UPDATE push_rules_enable SET rule_id = 'global/override/.m.rule.contains_display_name' WHERE rule_id = 'global/underride/.m.rule.contains_display_name'; |
0 | # Copyright 2016 OpenMarket Ltd | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from synapse.storage.engines import PostgresEngine | |
15 | ||
16 | import logging | |
17 | ||
18 | logger = logging.getLogger(__name__) | |
19 | ||
20 | ||
21 | def run_create(cur, database_engine, *args, **kwargs): | |
22 | if isinstance(database_engine, PostgresEngine): | |
23 | cur.execute("TRUNCATE received_transactions") | |
24 | else: | |
25 | cur.execute("DELETE FROM received_transactions") | |
26 | ||
27 | cur.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)") | |
28 | ||
29 | ||
30 | def run_upgrade(cur, database_engine, *args, **kwargs): | |
31 | pass |
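This migration empties `received_transactions` and adds an index on `ts`, which is what makes the periodic age-based purge (the `_cleanup_transactions` loop added in the transactions store below) a cheap range delete. A minimal sqlite3 sketch of that retention pattern, with the column set trimmed for illustration:

```python
import sqlite3
import time

MONTH_MS = 30 * 24 * 60 * 60 * 1000  # same 30-day horizon as the cleanup loop

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE received_transactions(transaction_id TEXT, origin TEXT, ts BIGINT)"
)
# The index on ts turns the DELETE below into an index range scan.
conn.execute("CREATE INDEX received_transactions_ts ON received_transactions(ts)")

now = int(time.time() * 1000)
rows = [
    ("t1", "a.example", now - 2 * MONTH_MS),  # stale: older than a month
    ("t2", "b.example", now - 1000),          # fresh
]
conn.executemany("INSERT INTO received_transactions VALUES (?, ?, ?)", rows)

# Same shape as the periodic cleanup: drop everything older than a month.
conn.execute("DELETE FROM received_transactions WHERE ts < ?", (now - MONTH_MS,))
remaining = [r[0] for r in conn.execute("SELECT transaction_id FROM received_transactions")]
```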
24 | 24 | class SignatureStore(SQLBaseStore): |
25 | 25 | """Persistence for event signatures and hashes""" |
26 | 26 | |
27 | @cached(lru=True) | |
27 | @cached() | |
28 | 28 | def get_event_reference_hash(self, event_id): |
29 | 29 | return self._get_event_reference_hashes_txn(event_id) |
30 | 30 |
173 | 173 | return [r[0] for r in results] |
174 | 174 | return self.runInteraction("get_current_state_for_key", f) |
175 | 175 | |
176 | @cached(num_args=2, lru=True, max_entries=1000) | |
176 | @cached(num_args=2, max_entries=1000) | |
177 | 177 | def _get_state_group_from_group(self, group, types): |
178 | 178 | raise NotImplementedError() |
179 | 179 | |
271 | 271 | state_map = yield self.get_state_for_events([event_id], types) |
272 | 272 | defer.returnValue(state_map[event_id]) |
273 | 273 | |
274 | @cached(num_args=2, lru=True, max_entries=10000) | |
274 | @cached(num_args=2, max_entries=10000) | |
275 | 275 | def _get_state_group_for_event(self, room_id, event_id): |
276 | 276 | return self._simple_select_one_onecol( |
277 | 277 | table="event_to_state_groups", |
38 | 38 | from synapse.util.caches.descriptors import cached |
39 | 39 | from synapse.api.constants import EventTypes |
40 | 40 | from synapse.types import RoomStreamToken |
41 | from synapse.util.logcontext import preserve_fn | |
41 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
42 | 42 | from synapse.storage.engines import PostgresEngine, Sqlite3Engine |
43 | 43 | |
44 | 44 | import logging |
233 | 233 | results = {} |
234 | 234 | room_ids = list(room_ids) |
235 | 235 | for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): |
236 | res = yield defer.gatherResults([ | |
236 | res = yield preserve_context_over_deferred(defer.gatherResults([ | |
237 | 237 | preserve_fn(self.get_room_events_stream_for_room)( |
238 | 238 | room_id, from_key, to_key, limit, order=order, |
239 | 239 | ) |
240 | 240 | for room_id in rm_ids |
241 | ]) | |
241 | ])) | |
242 | 242 | results.update(dict(zip(rm_ids, res))) |
243 | 243 | |
244 | 244 | defer.returnValue(results) |
61 | 61 | self.last_transaction = {} |
62 | 62 | |
63 | 63 | reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns) |
64 | hs.get_clock().looping_call( | |
65 | self._persist_in_mem_txns, | |
66 | 1000, | |
67 | ) | |
64 | self._clock.looping_call(self._persist_in_mem_txns, 1000) | |
65 | ||
66 | self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) | |
68 | 67 | |
69 | 68 | def get_received_txn_response(self, transaction_id, origin): |
70 | 69 | """For an incoming transaction from a given origin, check if we have |
126 | 125 | "origin": origin, |
127 | 126 | "response_code": code, |
128 | 127 | "response_json": buffer(encode_canonical_json(response_dict)), |
128 | "ts": self._clock.time_msec(), | |
129 | 129 | }, |
130 | 130 | or_ignore=True, |
131 | 131 | desc="set_received_txn_response", |
382 | 382 | yield self.runInteraction("_persist_in_mem_txns", f) |
383 | 383 | except: |
384 | 384 | logger.exception("Failed to persist transactions!") |
385 | ||
386 | def _cleanup_transactions(self): | |
387 | now = self._clock.time_msec() | |
388 | month_ago = now - 30 * 24 * 60 * 60 * 1000 | |
389 | ||
390 | def _cleanup_transactions_txn(txn): | |
391 | txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) | |
392 | ||
393 | return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
268 | 268 | return "t%d-%d" % (self.topological, self.stream) |
269 | 269 | else: |
270 | 270 | return "s%d" % (self.stream,) |
271 | ||
272 | ||
273 | # Some arbitrary constants used for internal API enumerations. Don't rely on | |
274 | # exact values; always pass or compare symbolically | |
275 | class ThirdPartyEntityKind(object): | |
276 | USER = 'user' | |
277 | LOCATION = 'location' |
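The comment on `ThirdPartyEntityKind` asks callers to compare against the symbolic names rather than the literal strings, so the underlying values can change without breaking call sites. For example:

```python
class ThirdPartyEntityKind(object):
    USER = 'user'
    LOCATION = 'location'


def describe(kind):
    # Compare against the symbolic constant, never the literal 'user',
    # so the enum's backing values stay an implementation detail.
    if kind == ThirdPartyEntityKind.USER:
        return "a user"
    elif kind == ThirdPartyEntityKind.LOCATION:
        return "a location"
    raise ValueError("unknown ThirdPartyEntityKind: %r" % (kind,))


result = describe(ThirdPartyEntityKind.USER)
```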
145 | 145 | except StopIteration: |
146 | 146 | pass |
147 | 147 | |
148 | return defer.gatherResults([ | |
148 | return preserve_context_over_deferred(defer.gatherResults([ | |
149 | 149 | preserve_fn(_concurrently_execute_inner)() |
150 | 150 | for _ in xrange(limit) |
151 | ], consumeErrors=True).addErrback(unwrapFirstError) | |
151 | ], consumeErrors=True)).addErrback(unwrapFirstError) | |
152 | 152 | |
153 | 153 | |
154 | 154 | class Linearizer(object): |
180 | 180 | self.key_to_defer[key] = new_defer |
181 | 181 | |
182 | 182 | if current_defer: |
183 | yield preserve_context_over_deferred(current_defer) | |
183 | with PreserveLoggingContext(): | |
184 | yield current_defer | |
184 | 185 | |
185 | 186 | @contextmanager |
186 | 187 | def _ctx_manager(): |
263 | 264 | curr_readers.clear() |
264 | 265 | self.key_to_current_writer[key] = new_defer |
265 | 266 | |
266 | yield defer.gatherResults(to_wait_on) | |
267 | yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) | |
267 | 268 | |
268 | 269 | @contextmanager |
269 | 270 | def _ctx_manager(): |
24 | 24 | from . import DEBUG_CACHES, register_cache |
25 | 25 | |
26 | 26 | from twisted.internet import defer |
27 | ||
28 | from collections import OrderedDict | |
27 | from collections import namedtuple | |
29 | 28 | |
30 | 29 | import os |
31 | 30 | import functools |
53 | 52 | "metrics", |
54 | 53 | ) |
55 | 54 | |
56 | def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): | |
57 | if lru: | |
58 | cache_type = TreeCache if tree else dict | |
59 | self.cache = LruCache( | |
60 | max_size=max_entries, keylen=keylen, cache_type=cache_type | |
61 | ) | |
62 | self.max_entries = None | |
63 | else: | |
64 | self.cache = OrderedDict() | |
65 | self.max_entries = max_entries | |
55 | def __init__(self, name, max_entries=1000, keylen=1, tree=False): | |
56 | cache_type = TreeCache if tree else dict | |
57 | self.cache = LruCache( | |
58 | max_size=max_entries, keylen=keylen, cache_type=cache_type | |
59 | ) | |
66 | 60 | |
67 | 61 | self.name = name |
68 | 62 | self.keylen = keylen |
80 | 74 | "Cache objects can only be accessed from the main thread" |
81 | 75 | ) |
82 | 76 | |
83 | def get(self, key, default=_CacheSentinel): | |
84 | val = self.cache.get(key, _CacheSentinel) | |
77 | def get(self, key, default=_CacheSentinel, callback=None): | |
78 | val = self.cache.get(key, _CacheSentinel, callback=callback) | |
85 | 79 | if val is not _CacheSentinel: |
86 | 80 | self.metrics.inc_hits() |
87 | 81 | return val |
93 | 87 | else: |
94 | 88 | return default |
95 | 89 | |
96 | def update(self, sequence, key, value): | |
90 | def update(self, sequence, key, value, callback=None): | |
97 | 91 | self.check_thread() |
98 | 92 | if self.sequence == sequence: |
99 | 93 | # Only update the cache if the caches sequence number matches the |
100 | 94 | # number that the cache had before the SELECT was started (SYN-369) |
101 | self.prefill(key, value) | |
102 | ||
103 | def prefill(self, key, value): | |
104 | if self.max_entries is not None: | |
105 | while len(self.cache) >= self.max_entries: | |
106 | self.cache.popitem(last=False) | |
107 | ||
108 | self.cache[key] = value | |
95 | self.prefill(key, value, callback=callback) | |
96 | ||
97 | def prefill(self, key, value, callback=None): | |
98 | self.cache.set(key, value, callback=callback) | |
109 | 99 | |
110 | 100 | def invalidate(self, key): |
111 | 101 | self.check_thread() |
150 | 140 | The wrapped function has another additional callable, called "prefill", |
151 | 141 | which can be used to insert values into the cache specifically, without |
152 | 142 | calling the calculation function. |
143 | ||
144 | Cached functions can be "chained" (i.e. a cached function can call other cached | |
145 | functions and get appropriately invalidated when the caches they call are
146 | invalidated) by adding a special "cache_context" argument to the function | |
147 | and passing that as a kwarg to all caches called. For example:: | |
148 | ||
149 | @cachedInlineCallbacks(cache_context=True) | |
150 | def foo(self, key, cache_context): | |
151 | r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) | |
152 | r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) | |
153 | defer.returnValue(r1 + r2) | |
154 | ||
153 | 155 | """ |
154 | def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, | |
155 | inlineCallbacks=False): | |
156 | def __init__(self, orig, max_entries=1000, num_args=1, tree=False, | |
157 | inlineCallbacks=False, cache_context=False): | |
156 | 158 | max_entries = int(max_entries * CACHE_SIZE_FACTOR) |
157 | 159 | |
158 | 160 | self.orig = orig |
164 | 166 | |
165 | 167 | self.max_entries = max_entries |
166 | 168 | self.num_args = num_args |
167 | self.lru = lru | |
168 | 169 | self.tree = tree |
169 | 170 | |
170 | self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] | |
171 | all_args = inspect.getargspec(orig) | |
172 | self.arg_names = all_args.args[1:num_args + 1] | |
173 | ||
174 | if "cache_context" in all_args.args: | |
175 | if not cache_context: | |
176 | raise ValueError( | |
177 | "Cannot have a 'cache_context' arg without setting" | |
178 | " cache_context=True" | |
179 | ) | |
180 | try: | |
181 | self.arg_names.remove("cache_context") | |
182 | except ValueError: | |
183 | pass | |
184 | elif cache_context: | |
185 | raise ValueError( | |
186 | "Cannot have cache_context=True without having an arg" | |
187 | " named `cache_context`" | |
188 | ) | |
189 | ||
190 | self.add_cache_context = cache_context | |
171 | 191 | |
172 | 192 | if len(self.arg_names) < self.num_args: |
173 | 193 | raise Exception( |
174 | 194 | "Not enough explicit positional arguments to key off of for %r." |
175 | " (@cached cannot key off of *args or **kwars)" | |
195 | " (@cached cannot key off of *args or **kwargs)" | |
176 | 196 | % (orig.__name__,) |
177 | 197 | ) |
178 | 198 | |
181 | 201 | name=self.orig.__name__, |
182 | 202 | max_entries=self.max_entries, |
183 | 203 | keylen=self.num_args, |
184 | lru=self.lru, | |
185 | 204 | tree=self.tree, |
186 | 205 | ) |
187 | 206 | |
188 | 207 | @functools.wraps(self.orig) |
189 | 208 | def wrapped(*args, **kwargs): |
209 | # If we're passed a cache_context then we'll want to call its invalidate() | |
210 | # whenever we are invalidated | |
211 | invalidate_callback = kwargs.pop("on_invalidate", None) | |
212 | ||
213 | # Add temp cache_context so inspect.getcallargs doesn't explode | |
214 | if self.add_cache_context: | |
215 | kwargs["cache_context"] = None | |
216 | ||
190 | 217 | arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) |
191 | 218 | cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) |
219 | ||
220 | # Add our own `cache_context` to argument list if the wrapped function | |
221 | # has asked for one | |
222 | if self.add_cache_context: | |
223 | kwargs["cache_context"] = _CacheContext(cache, cache_key) | |
224 | ||
192 | 225 | try: |
193 | cached_result_d = cache.get(cache_key) | |
226 | cached_result_d = cache.get(cache_key, callback=invalidate_callback) | |
194 | 227 | |
195 | 228 | observer = cached_result_d.observe() |
196 | 229 | if DEBUG_CACHES: |
227 | 260 | ret.addErrback(onErr) |
228 | 261 | |
229 | 262 | ret = ObservableDeferred(ret, consumeErrors=True) |
230 | cache.update(sequence, cache_key, ret) | |
263 | cache.update(sequence, cache_key, ret, callback=invalidate_callback) | |
231 | 264 | |
232 | 265 | return preserve_context_over_deferred(ret.observe()) |
233 | 266 | |
296 | 329 | |
297 | 330 | @functools.wraps(self.orig) |
298 | 331 | def wrapped(*args, **kwargs): |
332 | # If we're passed a cache_context then we'll want to call its invalidate() | |
333 | # whenever we are invalidated | |
334 | invalidate_callback = kwargs.pop("on_invalidate", None) | |
335 | ||
299 | 336 | arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) |
300 | 337 | keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] |
301 | 338 | list_args = arg_dict[self.list_name] |
310 | 347 | key[self.list_pos] = arg |
311 | 348 | |
312 | 349 | try: |
313 | res = cache.get(tuple(key)) | |
350 | res = cache.get(tuple(key), callback=invalidate_callback) | |
314 | 351 | if not res.has_succeeded(): |
315 | 352 | res = res.observe() |
316 | 353 | res.addCallback(lambda r, arg: (arg, r), arg) |
344 | 381 | |
345 | 382 | key = list(keyargs) |
346 | 383 | key[self.list_pos] = arg |
347 | cache.update(sequence, tuple(key), observer) | |
384 | cache.update( | |
385 | sequence, tuple(key), observer, | |
386 | callback=invalidate_callback | |
387 | ) | |
348 | 388 | |
349 | 389 | def invalidate(f, key): |
350 | 390 | cache.invalidate(key) |
375 | 415 | return wrapped |
376 | 416 | |
377 | 417 | |
378 | def cached(max_entries=1000, num_args=1, lru=True, tree=False): | |
418 | class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): | |
419 | def invalidate(self): | |
420 | self.cache.invalidate(self.key) | |
421 | ||
422 | ||
423 | def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): | |
379 | 424 | return lambda orig: CacheDescriptor( |
380 | 425 | orig, |
381 | 426 | max_entries=max_entries, |
382 | 427 | num_args=num_args, |
383 | lru=lru, | |
384 | 428 | tree=tree, |
429 | cache_context=cache_context, | |
385 | 430 | ) |
386 | 431 | |
387 | 432 | |
388 | def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): | |
433 | def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): | |
389 | 434 | return lambda orig: CacheDescriptor( |
390 | 435 | orig, |
391 | 436 | max_entries=max_entries, |
392 | 437 | num_args=num_args, |
393 | lru=lru, | |
394 | 438 | tree=tree, |
395 | 439 | inlineCallbacks=True, |
440 | cache_context=cache_context, | |
396 | 441 | ) |
397 | 442 | |
398 | 443 |
29 | 29 | |
30 | 30 | |
31 | 31 | class _Node(object): |
32 | __slots__ = ["prev_node", "next_node", "key", "value"] | |
33 | ||
34 | def __init__(self, prev_node, next_node, key, value): | |
32 | __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] | |
33 | ||
34 | def __init__(self, prev_node, next_node, key, value, callbacks=set()): | |
35 | 35 | self.prev_node = prev_node |
36 | 36 | self.next_node = next_node |
37 | 37 | self.key = key |
38 | 38 | self.value = value |
39 | self.callbacks = callbacks | |
39 | 40 | |
40 | 41 | |
41 | 42 | class LruCache(object): |
43 | 44 | Least-recently-used cache. |
44 | 45 | Supports del_multi only if cache_type=TreeCache |
45 | 46 | If cache_type=TreeCache, all keys must be tuples. |
47 | ||
48 | Can also set callbacks on objects when getting/setting which are fired | |
49 | when that key gets invalidated/evicted. | |
46 | 50 | """ |
47 | 51 | def __init__(self, max_size, keylen=1, cache_type=dict): |
48 | 52 | cache = cache_type() |
61 | 65 | |
62 | 66 | return inner |
63 | 67 | |
64 | def add_node(key, value): | |
68 | def add_node(key, value, callbacks=set()): | |
65 | 69 | prev_node = list_root |
66 | 70 | next_node = prev_node.next_node |
67 | node = _Node(prev_node, next_node, key, value) | |
71 | node = _Node(prev_node, next_node, key, value, callbacks) | |
68 | 72 | prev_node.next_node = node |
69 | 73 | next_node.prev_node = node |
70 | 74 | cache[key] = node |
87 | 91 | prev_node.next_node = next_node |
88 | 92 | next_node.prev_node = prev_node |
89 | 93 | |
90 | @synchronized | |
91 | def cache_get(key, default=None): | |
94 | for cb in node.callbacks: | |
95 | cb() | |
96 | node.callbacks.clear() | |
97 | ||
98 | @synchronized | |
99 | def cache_get(key, default=None, callback=None): | |
92 | 100 | node = cache.get(key, None) |
93 | 101 | if node is not None: |
94 | 102 | move_node_to_front(node) |
103 | if callback: | |
104 | node.callbacks.add(callback) | |
95 | 105 | return node.value |
96 | 106 | else: |
97 | 107 | return default |
98 | 108 | |
99 | 109 | @synchronized |
100 | def cache_set(key, value): | |
110 | def cache_set(key, value, callback=None): | |
101 | 111 | node = cache.get(key, None) |
102 | 112 | if node is not None: |
113 | if value != node.value: | |
114 | for cb in node.callbacks: | |
115 | cb() | |
116 | node.callbacks.clear() | |
117 | ||
118 | if callback: | |
119 | node.callbacks.add(callback) | |
120 | ||
103 | 121 | move_node_to_front(node) |
104 | 122 | node.value = value |
105 | 123 | else: |
106 | add_node(key, value) | |
124 | if callback: | |
125 | callbacks = set([callback]) | |
126 | else: | |
127 | callbacks = set() | |
128 | add_node(key, value, callbacks) | |
107 | 129 | if len(cache) > max_size: |
108 | 130 | todelete = list_root.prev_node |
109 | 131 | delete_node(todelete) |
147 | 169 | def cache_clear(): |
148 | 170 | list_root.next_node = list_root |
149 | 171 | list_root.prev_node = list_root |
172 | for node in cache.values(): | |
173 | for cb in node.callbacks: | |
174 | cb() | |
150 | 175 | cache.clear() |
151 | 176 | |
152 | 177 | @synchronized |
63 | 63 | self.size -= cnt |
64 | 64 | return popped |
65 | 65 | |
66 | def values(self): | |
67 | return [e.value for e in self.root.values()] | |
68 | ||
66 | 69 | def __len__(self): |
67 | 70 | return self.size |
68 | 71 |
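The `LruCache` changes above attach a set of callbacks to each node, fired when the key is evicted, replaced with a different value, or cleared. A minimal standalone sketch of that callback-on-eviction idea using an `OrderedDict` in place of the linked list (not the synapse class):

```python
from collections import OrderedDict


class MiniLru(object):
    """LRU cache whose set() can register callbacks fired on eviction."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._cache = OrderedDict()  # key -> (value, set of callbacks)

    def set(self, key, value, callback=None):
        if key in self._cache:
            old_value, callbacks = self._cache.pop(key)
            if value != old_value:
                # Mirror cache_set above: fire only when the value changes.
                for cb in callbacks:
                    cb()
                callbacks = set()
        else:
            callbacks = set()
        if callback:
            callbacks.add(callback)
        self._cache[key] = (value, callbacks)
        while len(self._cache) > self.max_size:
            # Evict the least recently used entry and fire its callbacks.
            _, (_, evicted_cbs) = self._cache.popitem(last=False)
            for cb in evicted_cbs:
                cb()


fired = []
cache = MiniLru(max_size=2)
cache.set("a", 1, callback=lambda: fired.append("a"))
cache.set("b", 2)
cache.set("c", 3)  # over capacity: evicts "a" and fires its callback
```

This is the hook the `Cache.get`/`Cache.update` callback plumbing in `descriptors.py` builds on.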
296 | 296 | return res |
297 | 297 | |
298 | 298 | |
299 | def preserve_context_over_deferred(deferred): | |
299 | def preserve_context_over_deferred(deferred, context=None): | |
300 | 300 | """Given a deferred wrap it such that any callbacks added later to it will |
301 | 301 | be invoked with the current context. |
302 | 302 | """ |
303 | current_context = LoggingContext.current_context() | |
304 | d = _PreservingContextDeferred(current_context) | |
303 | if context is None: | |
304 | context = LoggingContext.current_context() | |
305 | d = _PreservingContextDeferred(context) | |
305 | 306 | deferred.chainDeferred(d) |
306 | 307 | return d |
307 | 308 | |
315 | 316 | |
316 | 317 | def g(*args, **kwargs): |
317 | 318 | with PreserveLoggingContext(current): |
318 | return f(*args, **kwargs) | |
319 | ||
319 | res = f(*args, **kwargs) | |
320 | if isinstance(res, defer.Deferred): | |
321 | return preserve_context_over_deferred( | |
322 | res, context=LoggingContext.sentinel | |
323 | ) | |
324 | else: | |
325 | return res | |
320 | 326 | return g |
321 | 327 | |
322 | 328 |
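The logcontext changes above all serve one goal: callbacks should run under the logging context that was current when the work was scheduled, not whatever context happens to be active when a deferred eventually fires. In current Python the stdlib `contextvars` module solves the analogous problem; a rough sketch of the same capture-and-restore idea (an analogy, not the Twisted implementation):

```python
import contextvars

# Context that identifies the request being processed.
request_ctx = contextvars.ContextVar("request", default="sentinel")


def preserve_fn(f):
    # Capture the current context now; run f under it later, like the
    # synapse preserve_fn wrapper does for logging contexts.
    ctx = contextvars.copy_context()
    return lambda *args, **kwargs: ctx.run(f, *args, **kwargs)


request_ctx.set("req-1")
callback = preserve_fn(lambda: request_ctx.get())

request_ctx.set("req-2")  # the context has moved on by the time we fire
seen = callback()         # but the callback still sees the captured context
```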
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | from twisted.internet import defer | |
15 | 16 | |
16 | 17 | from synapse.util.logcontext import LoggingContext |
17 | 18 | import synapse.metrics |
18 | 19 | |
20 | from functools import wraps | |
19 | 21 | import logging |
20 | 22 | |
21 | 23 | |
46 | 48 | ) |
47 | 49 | |
48 | 50 | |
51 | def measure_func(name): | |
52 | def wrapper(func): | |
53 | @wraps(func) | |
54 | @defer.inlineCallbacks | |
55 | def measured_func(self, *args, **kwargs): | |
56 | with Measure(self.clock, name): | |
57 | r = yield func(self, *args, **kwargs) | |
58 | defer.returnValue(r) | |
59 | return measured_func | |
60 | return wrapper | |
61 | ||
62 | ||
49 | 63 | class Measure(object): |
50 | 64 | __slots__ = [ |
51 | 65 | "clock", "name", "start_context", "start", "new_context", "ru_utime", |
63 | 77 | self.start = self.clock.time_msec() |
64 | 78 | self.start_context = LoggingContext.current_context() |
65 | 79 | if not self.start_context: |
66 | logger.warn("Entered Measure without log context: %s", self.name) | |
67 | 80 | self.start_context = LoggingContext("Measure") |
68 | 81 | self.start_context.__enter__() |
69 | 82 | self.created_context = True |
73 | 86 | self.db_txn_duration = self.start_context.db_txn_duration |
74 | 87 | |
75 | 88 | def __exit__(self, exc_type, exc_val, exc_tb): |
76 | if exc_type is not None or not self.start_context: | |
89 | if (exc_type is not None and issubclass(exc_type, Exception)) or not self.start_context:
77 | 90 | return |
78 | 91 | |
79 | 92 | duration = self.clock.time_msec() - self.start |
84 | 97 | if context != self.start_context: |
85 | 98 | logger.warn( |
86 | 99 | "Context has unexpectedly changed from '%s' to '%s'. (%r)", |
87 | context, self.start_context, self.name | |
100 | self.start_context, context, self.name | |
88 | 101 | ) |
89 | 102 | return |
90 | 103 |
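The new `measure_func` decorator wraps an inlineCallbacks method so its whole run is timed inside a `Measure` block. A synchronous standalone sketch of the same decorator shape, with a plain list standing in for the metrics machinery (names here are illustrative):

```python
import functools
import time


def measure_func(name, sink):
    """Decorator factory: time each call of the wrapped function.

    Timings are appended to `sink` as (name, seconds) pairs; in the real
    decorator a Measure context manager records them as metrics instead.
    """
    def wrapper(func):
        @functools.wraps(func)
        def measured_func(*args, **kwargs):
            start = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                # Record the duration even if the call raises.
                sink.append((name, time.time() - start))
        return measured_func
    return wrapper


timings = []


@measure_func("work", timings)
def work(x):
    return x * 2


result = work(21)
```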
16 | 16 | |
17 | 17 | from synapse.api.constants import Membership, EventTypes |
18 | 18 | |
19 | from synapse.util.logcontext import preserve_fn | |
19 | from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred | |
20 | 20 | |
21 | 21 | import logging |
22 | 22 | |
54 | 54 | given events |
55 | 55 | events ([synapse.events.EventBase]): list of events to filter |
56 | 56 | """ |
57 | forgotten = yield defer.gatherResults([ | |
57 | forgotten = yield preserve_context_over_deferred(defer.gatherResults([ | |
58 | 58 | preserve_fn(store.who_forgot_in_room)( |
59 | 59 | room_id, |
60 | 60 | ) |
61 | 61 | for room_id in frozenset(e.room_id for e in events) |
62 | ], consumeErrors=True) | |
62 | ], consumeErrors=True)) | |
63 | 63 | |
64 | 64 | # Set of membership event_ids that have been forgotten |
65 | 65 | event_id_forgotten = frozenset( |
13 | 13 | # limitations under the License. |
14 | 14 | from synapse.appservice import ApplicationService |
15 | 15 | |
16 | from twisted.internet import defer | |
17 | ||
16 | 18 | from mock import Mock |
17 | 19 | from tests import unittest |
18 | 20 | |
41 | 43 | type="m.something", room_id="!foo:bar", sender="@someone:somewhere" |
42 | 44 | ) |
43 | 45 | |
46 | self.store = Mock() | |
47 | ||
48 | @defer.inlineCallbacks | |
44 | 49 | def test_regex_user_id_prefix_match(self): |
45 | 50 | self.service.namespaces[ApplicationService.NS_USERS].append( |
46 | 51 | _regex("@irc_.*") |
47 | 52 | ) |
48 | 53 | self.event.sender = "@irc_foobar:matrix.org" |
49 | self.assertTrue(self.service.is_interested(self.event)) | |
50 | ||
54 | self.assertTrue((yield self.service.is_interested(self.event))) | |
55 | ||
56 | @defer.inlineCallbacks | |
51 | 57 | def test_regex_user_id_prefix_no_match(self): |
52 | 58 | self.service.namespaces[ApplicationService.NS_USERS].append( |
53 | 59 | _regex("@irc_.*") |
54 | 60 | ) |
55 | 61 | self.event.sender = "@someone_else:matrix.org" |
56 | self.assertFalse(self.service.is_interested(self.event)) | |
57 | ||
62 | self.assertFalse((yield self.service.is_interested(self.event))) | |
63 | ||
64 | @defer.inlineCallbacks | |
58 | 65 | def test_regex_room_member_is_checked(self): |
59 | 66 | self.service.namespaces[ApplicationService.NS_USERS].append( |
60 | 67 | _regex("@irc_.*") |
62 | 69 | self.event.sender = "@someone_else:matrix.org" |
63 | 70 | self.event.type = "m.room.member" |
64 | 71 | self.event.state_key = "@irc_foobar:matrix.org" |
65 | self.assertTrue(self.service.is_interested(self.event)) | |
66 | ||
72 | self.assertTrue((yield self.service.is_interested(self.event))) | |
73 | ||
74 | @defer.inlineCallbacks | |
67 | 75 | def test_regex_room_id_match(self): |
68 | 76 | self.service.namespaces[ApplicationService.NS_ROOMS].append( |
69 | 77 | _regex("!some_prefix.*some_suffix:matrix.org") |
70 | 78 | ) |
71 | 79 | self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" |
72 | self.assertTrue(self.service.is_interested(self.event)) | |
73 | ||
80 | self.assertTrue((yield self.service.is_interested(self.event))) | |
81 | ||
82 | @defer.inlineCallbacks | |
74 | 83 | def test_regex_room_id_no_match(self): |
75 | 84 | self.service.namespaces[ApplicationService.NS_ROOMS].append( |
76 | 85 | _regex("!some_prefix.*some_suffix:matrix.org") |
77 | 86 | ) |
78 | 87 | self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" |
79 | self.assertFalse(self.service.is_interested(self.event)) | |
80 | ||
88 | self.assertFalse((yield self.service.is_interested(self.event))) | |
89 | ||
90 | @defer.inlineCallbacks | |
81 | 91 | def test_regex_alias_match(self): |
82 | 92 | self.service.namespaces[ApplicationService.NS_ALIASES].append( |
83 | 93 | _regex("#irc_.*:matrix.org") |
84 | 94 | ) |
85 | self.assertTrue(self.service.is_interested( | |
86 | self.event, | |
87 | aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"] | |
88 | )) | |
95 | self.store.get_aliases_for_room.return_value = [ | |
96 | "#irc_foobar:matrix.org", "#athing:matrix.org" | |
97 | ] | |
98 | self.store.get_users_in_room.return_value = [] | |
99 | self.assertTrue((yield self.service.is_interested( | |
100 | self.event, self.store | |
101 | ))) | |
89 | 102 | |
90 | 103 | def test_non_exclusive_alias(self): |
91 | 104 | self.service.namespaces[ApplicationService.NS_ALIASES].append( |
135 | 148 | "!irc_foobar:matrix.org" |
136 | 149 | )) |
137 | 150 | |
151 | @defer.inlineCallbacks | |
138 | 152 | def test_regex_alias_no_match(self): |
139 | 153 | self.service.namespaces[ApplicationService.NS_ALIASES].append( |
140 | 154 | _regex("#irc_.*:matrix.org") |
141 | 155 | ) |
142 | self.assertFalse(self.service.is_interested( | |
143 | self.event, | |
144 | aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] | |
145 | )) | |
146 | ||
156 | self.store.get_aliases_for_room.return_value = [ | |
157 | "#xmpp_foobar:matrix.org", "#athing:matrix.org" | |
158 | ] | |
159 | self.store.get_users_in_room.return_value = [] | |
160 | self.assertFalse((yield self.service.is_interested( | |
161 | self.event, self.store | |
162 | ))) | |
163 | ||
164 | @defer.inlineCallbacks | |
147 | 165 | def test_regex_multiple_matches(self): |
148 | 166 | self.service.namespaces[ApplicationService.NS_ALIASES].append( |
149 | 167 | _regex("#irc_.*:matrix.org") |
152 | 170 | _regex("@irc_.*") |
153 | 171 | ) |
154 | 172 | self.event.sender = "@irc_foobar:matrix.org" |
155 | self.assertTrue(self.service.is_interested( | |
156 | self.event, | |
157 | aliases_for_event=["#irc_barfoo:matrix.org"] | |
158 | )) | |
159 | ||
160 | def test_restrict_to_rooms(self): | |
161 | self.service.namespaces[ApplicationService.NS_ROOMS].append( | |
162 | _regex("!flibble_.*:matrix.org") | |
163 | ) | |
164 | self.service.namespaces[ApplicationService.NS_USERS].append( | |
165 | _regex("@irc_.*") | |
166 | ) | |
167 | self.event.sender = "@irc_foobar:matrix.org" | |
168 | self.event.room_id = "!wibblewoo:matrix.org" | |
169 | self.assertFalse(self.service.is_interested( | |
170 | self.event, | |
171 | restrict_to=ApplicationService.NS_ROOMS | |
172 | )) | |
173 | ||
174 | def test_restrict_to_aliases(self): | |
175 | self.service.namespaces[ApplicationService.NS_ALIASES].append( | |
176 | _regex("#xmpp_.*:matrix.org") | |
177 | ) | |
178 | self.service.namespaces[ApplicationService.NS_USERS].append( | |
179 | _regex("@irc_.*") | |
180 | ) | |
181 | self.event.sender = "@irc_foobar:matrix.org" | |
182 | self.assertFalse(self.service.is_interested( | |
183 | self.event, | |
184 | restrict_to=ApplicationService.NS_ALIASES, | |
185 | aliases_for_event=["#irc_barfoo:matrix.org"] | |
186 | )) | |
187 | ||
188 | def test_restrict_to_senders(self): | |
189 | self.service.namespaces[ApplicationService.NS_ALIASES].append( | |
190 | _regex("#xmpp_.*:matrix.org") | |
191 | ) | |
192 | self.service.namespaces[ApplicationService.NS_USERS].append( | |
193 | _regex("@irc_.*") | |
194 | ) | |
195 | self.event.sender = "@xmpp_foobar:matrix.org" | |
196 | self.assertFalse(self.service.is_interested( | |
197 | self.event, | |
198 | restrict_to=ApplicationService.NS_USERS, | |
199 | aliases_for_event=["#xmpp_barfoo:matrix.org"] | |
200 | )) | |
201 | ||
173 | self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] | |
174 | self.store.get_users_in_room.return_value = [] | |
175 | self.assertTrue((yield self.service.is_interested( | |
176 | self.event, self.store | |
177 | ))) | |
178 | ||
179 | @defer.inlineCallbacks | |
202 | 180 | def test_interested_in_self(self): |
203 | 181 | # make sure invites get through |
204 | 182 | self.service.sender = "@appservice:name" |
210 | 188 | "membership": "invite" |
211 | 189 | } |
212 | 190 | self.event.state_key = self.service.sender |
213 | self.assertTrue(self.service.is_interested(self.event)) | |
214 | ||
191 | self.assertTrue((yield self.service.is_interested(self.event))) | |
192 | ||
193 | @defer.inlineCallbacks | |
215 | 194 | def test_member_list_match(self): |
216 | 195 | self.service.namespaces[ApplicationService.NS_USERS].append( |
217 | 196 | _regex("@irc_.*") |
218 | 197 | ) |
219 | join_list = [ | |
198 | self.store.get_users_in_room.return_value = [ | |
220 | 199 | "@alice:here", |
221 | 200 | "@irc_fo:here", # AS user |
222 | 201 | "@bob:here", |
223 | 202 | ] |
203 | self.store.get_aliases_for_room.return_value = [] | |
224 | 204 | |
225 | 205 | self.event.sender = "@xmpp_foobar:matrix.org" |
226 | self.assertTrue(self.service.is_interested( | |
227 | event=self.event, | |
228 | member_list=join_list | |
229 | )) | |
206 | self.assertTrue((yield self.service.is_interested( | |
207 | event=self.event, store=self.store | |
208 | ))) |
192 | 192 | |
193 | 193 | def setUp(self): |
194 | 194 | self.txn_ctrl = Mock() |
195 | self.queuer = _ServiceQueuer(self.txn_ctrl) | |
195 | self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) | |
196 | 196 | |
197 | 197 | def test_send_single_event_no_queue(self): |
198 | 198 | # Expect the event to be sent immediately. |
14 | 14 | |
15 | 15 | from twisted.internet import defer |
16 | 16 | from .. import unittest |
17 | from tests.utils import MockClock | |
17 | 18 | |
18 | 19 | from synapse.handlers.appservice import ApplicationServicesHandler |
19 | 20 | |
31 | 32 | hs.get_datastore = Mock(return_value=self.mock_store) |
32 | 33 | hs.get_application_service_api = Mock(return_value=self.mock_as_api) |
33 | 34 | hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) |
35 | hs.get_clock.return_value = MockClock() | |
34 | 36 | self.handler = ApplicationServicesHandler(hs) |
35 | 37 | |
36 | 38 | @defer.inlineCallbacks |
50 | 52 | type="m.room.message", |
51 | 53 | room_id="!foo:bar" |
52 | 54 | ) |
55 | self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) | |
53 | 56 | self.mock_as_api.push = Mock() |
54 | yield self.handler.notify_interested_services(event) | |
57 | yield self.handler.notify_interested_services(0) | |
55 | 58 | self.mock_scheduler.submit_event_for_as.assert_called_once_with( |
56 | 59 | interested_service, event |
57 | 60 | ) |
71 | 74 | ) |
72 | 75 | self.mock_as_api.push = Mock() |
73 | 76 | self.mock_as_api.query_user = Mock() |
74 | yield self.handler.notify_interested_services(event) | |
77 | self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) | |
78 | yield self.handler.notify_interested_services(0) | |
75 | 79 | self.mock_as_api.query_user.assert_called_once_with( |
76 | 80 | services[0], user_id |
77 | 81 | ) |
93 | 97 | ) |
94 | 98 | self.mock_as_api.push = Mock() |
95 | 99 | self.mock_as_api.query_user = Mock() |
96 | yield self.handler.notify_interested_services(event) | |
100 | self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) | |
101 | yield self.handler.notify_interested_services(0) | |
97 | 102 | self.assertFalse( |
98 | 103 | self.mock_as_api.query_user.called, |
99 | 104 | "query_user called when it shouldn't have been." |
107 | 112 | |
108 | 113 | room_id = "!alpha:bet" |
109 | 114 | servers = ["aperture"] |
110 | interested_service = self._mkservice(is_interested=True) | |
115 | interested_service = self._mkservice_alias(is_interested_in_alias=True) | |
111 | 116 | services = [ |
112 | self._mkservice(is_interested=False), | |
117 | self._mkservice_alias(is_interested_in_alias=False), | |
113 | 118 | interested_service, |
114 | self._mkservice(is_interested=False) | |
119 | self._mkservice_alias(is_interested_in_alias=False) | |
115 | 120 | ] |
116 | 121 | |
117 | 122 | self.mock_store.get_app_services = Mock(return_value=services) |
134 | 139 | service.token = "mock_service_token" |
135 | 140 | service.url = "mock_service_url" |
136 | 141 | return service |
142 | ||
143 | def _mkservice_alias(self, is_interested_in_alias): | |
144 | service = Mock() | |
145 | service.is_interested_in_alias = Mock(return_value=is_interested_in_alias) | |
146 | service.token = "mock_service_token" | |
147 | service.url = "mock_service_url" | |
148 | return service |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import pymacaroons |
16 | from twisted.internet import defer | |
16 | 17 | |
18 | import synapse | |
19 | import synapse.api.errors | |
17 | 20 | from synapse.handlers.auth import AuthHandler |
18 | 21 | from tests import unittest |
19 | 22 | from tests.utils import setup_test_homeserver |
20 | from twisted.internet import defer | |
21 | 23 | |
22 | 24 | |
23 | 25 | class AuthHandlers(object): |
30 | 32 | def setUp(self): |
31 | 33 | self.hs = yield setup_test_homeserver(handlers=None) |
32 | 34 | self.hs.handlers = AuthHandlers(self.hs) |
35 | self.auth_handler = self.hs.handlers.auth_handler | |
33 | 36 | |
34 | 37 | def test_token_is_a_macaroon(self): |
35 | 38 | self.hs.config.macaroon_secret_key = "this key is a huge secret" |
36 | 39 | |
37 | token = self.hs.handlers.auth_handler.generate_access_token("some_user") | |
40 | token = self.auth_handler.generate_access_token("some_user") | |
38 | 41 | # Check that we can parse the thing with pymacaroons |
39 | 42 | macaroon = pymacaroons.Macaroon.deserialize(token) |
40 | 43 | # The most basic of sanity checks |
45 | 48 | self.hs.config.macaroon_secret_key = "this key is a massive secret" |
46 | 49 | self.hs.clock.now = 5000 |
47 | 50 | |
48 | token = self.hs.handlers.auth_handler.generate_access_token("a_user") | |
51 | token = self.auth_handler.generate_access_token("a_user") | |
49 | 52 | macaroon = pymacaroons.Macaroon.deserialize(token) |
50 | 53 | |
51 | 54 | def verify_gen(caveat): |
66 | 69 | v.satisfy_general(verify_type) |
67 | 70 | v.satisfy_general(verify_expiry) |
68 | 71 | v.verify(macaroon, self.hs.config.macaroon_secret_key) |
72 | ||
73 | def test_short_term_login_token_gives_user_id(self): | |
74 | self.hs.clock.now = 1000 | |
75 | ||
76 | token = self.auth_handler.generate_short_term_login_token( | |
77 | "a_user", 5000 | |
78 | ) | |
79 | ||
80 | self.assertEqual( | |
81 | "a_user", | |
82 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
83 | token | |
84 | ) | |
85 | ) | |
86 | ||
87 | # when we advance the clock, the token should be rejected | |
88 | self.hs.clock.now = 6000 | |
89 | with self.assertRaises(synapse.api.errors.AuthError): | |
90 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
91 | token | |
92 | ) | |
93 | ||
94 | def test_short_term_login_token_cannot_replace_user_id(self): | |
95 | token = self.auth_handler.generate_short_term_login_token( | |
96 | "a_user", 5000 | |
97 | ) | |
98 | macaroon = pymacaroons.Macaroon.deserialize(token) | |
99 | ||
100 | self.assertEqual( | |
101 | "a_user", | |
102 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
103 | macaroon.serialize() | |
104 | ) | |
105 | ) | |
106 | ||
107 | # add another "user_id" caveat, which might allow us to override the | |
108 | # user_id. | |
109 | macaroon.add_first_party_caveat("user_id = b_user") | |
110 | ||
111 | with self.assertRaises(synapse.api.errors.AuthError): | |
112 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
113 | macaroon.serialize() | |
114 | ) |
16 | 16 | from tests import unittest |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | from mock import Mock | |
20 | ||
19 | 21 | from synapse.util.async import ObservableDeferred |
20 | 22 | |
21 | 23 | from synapse.util.caches.descriptors import Cache, cached |
71 | 73 | cache.get(3) |
72 | 74 | |
73 | 75 | def test_eviction_lru(self): |
74 | cache = Cache("test", max_entries=2, lru=True) | |
76 | cache = Cache("test", max_entries=2) | |
75 | 77 | |
76 | 78 | cache.prefill(1, "one") |
77 | 79 | cache.prefill(2, "two") |
198 | 200 | |
199 | 201 | self.assertEquals(a.func("foo").result, d.result) |
200 | 202 | self.assertEquals(callcount[0], 0) |
203 | ||
204 | @defer.inlineCallbacks | |
205 | def test_invalidate_context(self): | |
206 | callcount = [0] | |
207 | callcount2 = [0] | |
208 | ||
209 | class A(object): | |
210 | @cached() | |
211 | def func(self, key): | |
212 | callcount[0] += 1 | |
213 | return key | |
214 | ||
215 | @cached(cache_context=True) | |
216 | def func2(self, key, cache_context): | |
217 | callcount2[0] += 1 | |
218 | return self.func(key, on_invalidate=cache_context.invalidate) | |
219 | ||
220 | a = A() | |
221 | yield a.func2("foo") | |
222 | ||
223 | self.assertEquals(callcount[0], 1) | |
224 | self.assertEquals(callcount2[0], 1) | |
225 | ||
226 | a.func.invalidate(("foo",)) | |
227 | yield a.func("foo") | |
228 | ||
229 | self.assertEquals(callcount[0], 2) | |
230 | self.assertEquals(callcount2[0], 1) | |
231 | ||
232 | yield a.func2("foo") | |
233 | ||
234 | self.assertEquals(callcount[0], 2) | |
235 | self.assertEquals(callcount2[0], 2) | |
236 | ||
237 | @defer.inlineCallbacks | |
238 | def test_eviction_context(self): | |
239 | callcount = [0] | |
240 | callcount2 = [0] | |
241 | ||
242 | class A(object): | |
243 | @cached(max_entries=2) | |
244 | def func(self, key): | |
245 | callcount[0] += 1 | |
246 | return key | |
247 | ||
248 | @cached(cache_context=True) | |
249 | def func2(self, key, cache_context): | |
250 | callcount2[0] += 1 | |
251 | return self.func(key, on_invalidate=cache_context.invalidate) | |
252 | ||
253 | a = A() | |
254 | yield a.func2("foo") | |
255 | yield a.func2("foo2") | |
256 | ||
257 | self.assertEquals(callcount[0], 2) | |
258 | self.assertEquals(callcount2[0], 2) | |
259 | ||
260 | yield a.func("foo3") | |
261 | ||
262 | self.assertEquals(callcount[0], 3) | |
263 | self.assertEquals(callcount2[0], 2) | |
264 | ||
265 | yield a.func2("foo") | |
266 | ||
267 | self.assertEquals(callcount[0], 4) | |
268 | self.assertEquals(callcount2[0], 3) | |
269 | ||
270 | @defer.inlineCallbacks | |
271 | def test_double_get(self): | |
272 | callcount = [0] | |
273 | callcount2 = [0] | |
274 | ||
275 | class A(object): | |
276 | @cached() | |
277 | def func(self, key): | |
278 | callcount[0] += 1 | |
279 | return key | |
280 | ||
281 | @cached(cache_context=True) | |
282 | def func2(self, key, cache_context): | |
283 | callcount2[0] += 1 | |
284 | return self.func(key, on_invalidate=cache_context.invalidate) | |
285 | ||
286 | a = A() | |
287 | a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) | |
288 | ||
289 | yield a.func2("foo") | |
290 | ||
291 | self.assertEquals(callcount[0], 1) | |
292 | self.assertEquals(callcount2[0], 1) | |
293 | ||
294 | a.func2.invalidate(("foo",)) | |
295 | self.assertEquals(a.func2.cache.cache.pop.call_count, 1) | |
296 | ||
297 | yield a.func2("foo") | |
298 | a.func2.invalidate(("foo",)) | |
299 | self.assertEquals(a.func2.cache.cache.pop.call_count, 2) | |
300 | ||
301 | self.assertEquals(callcount[0], 1) | |
302 | self.assertEquals(callcount2[0], 2) | |
303 | ||
304 | a.func.invalidate(("foo",)) | |
305 | self.assertEquals(a.func2.cache.cache.pop.call_count, 3) | |
306 | yield a.func("foo") | |
307 | ||
308 | self.assertEquals(callcount[0], 2) | |
309 | self.assertEquals(callcount2[0], 2) | |
310 | ||
311 | yield a.func2("foo") | |
312 | ||
313 | self.assertEquals(callcount[0], 2) | |
314 | self.assertEquals(callcount2[0], 3) |
14 | 14 | |
15 | 15 | from . import unittest |
16 | 16 | |
17 | from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs | |
17 | from synapse.rest.media.v1.preview_url_resource import ( | |
18 | summarize_paragraphs, decode_and_calc_og | |
19 | ) | |
18 | 20 | |
19 | 21 | |
20 | 22 | class PreviewTestCase(unittest.TestCase): |
136 | 138 | " of old wooden houses in Northern Norway, the oldest house dating from" |
137 | 139 | " 1789. The Arctic Cathedral, a modern church…" |
138 | 140 | ) |
141 | ||
142 | ||
143 | class PreviewUrlTestCase(unittest.TestCase): | |
144 | def test_simple(self): | |
145 | html = """ | |
146 | <html> | |
147 | <head><title>Foo</title></head> | |
148 | <body> | |
149 | Some text. | |
150 | </body> | |
151 | </html> | |
152 | """ | |
153 | ||
154 | og = decode_and_calc_og(html, "http://example.com/test.html") | |
155 | ||
156 | self.assertEquals(og, { | |
157 | "og:title": "Foo", | |
158 | "og:description": "Some text." | |
159 | }) | |
160 | ||
161 | def test_comment(self): | |
162 | html = """ | |
163 | <html> | |
164 | <head><title>Foo</title></head> | |
165 | <body> | |
166 | <!-- HTML comment --> | |
167 | Some text. | |
168 | </body> | |
169 | </html> | |
170 | """ | |
171 | ||
172 | og = decode_and_calc_og(html, "http://example.com/test.html") | |
173 | ||
174 | self.assertEquals(og, { | |
175 | "og:title": "Foo", | |
176 | "og:description": "Some text." | |
177 | }) | |
178 | ||
179 | def test_comment2(self): | |
180 | html = """ | |
181 | <html> | |
182 | <head><title>Foo</title></head> | |
183 | <body> | |
184 | Some text. | |
185 | <!-- HTML comment --> | |
186 | Some more text. | |
187 | <p>Text</p> | |
188 | More text | |
189 | </body> | |
190 | </html> | |
191 | """ | |
192 | ||
193 | og = decode_and_calc_og(html, "http://example.com/test.html") | |
194 | ||
195 | self.assertEquals(og, { | |
196 | "og:title": "Foo", | |
197 | "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text" | |
198 | }) | |
199 | ||
200 | def test_script(self): | |
201 | html = """ | |
202 | <html> | |
203 | <head><title>Foo</title></head> | |
204 | <body> | |
205 | <script> (function() {})() </script> | |
206 | Some text. | |
207 | </body> | |
208 | </html> | |
209 | """ | |
210 | ||
211 | og = decode_and_calc_og(html, "http://example.com/test.html") | |
212 | ||
213 | self.assertEquals(og, { | |
214 | "og:title": "Foo", | |
215 | "og:description": "Some text." | |
216 | }) |
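
The new `decode_and_calc_og` tests pin down the expected extraction semantics: `<title>` becomes `og:title`, visible body text becomes `og:description`, and comments and scripts are ignored. A minimal stdlib sketch of those semantics (hypothetical helper names; Synapse's real implementation uses lxml, not `html.parser`):

```python
from html.parser import HTMLParser

class _OgSketch(HTMLParser):
    """Collect the <title> text and the visible text chunks of a page,
    skipping <script>/<style> contents; comments are dropped by the parser."""

    def __init__(self):
        super().__init__()
        self._in_title = False
        self._in_script = False
        self.title = None
        self.chunks = []

    def handle_starttag(self, tag, attrs):
        if tag == "title":
            self._in_title = True
        elif tag in ("script", "style"):
            self._in_script = True

    def handle_endtag(self, tag):
        if tag == "title":
            self._in_title = False
        elif tag in ("script", "style"):
            self._in_script = False

    def handle_data(self, data):
        if self._in_title:
            self.title = data
        elif not self._in_script and data.strip():
            self.chunks.append(data.strip())

def calc_og_sketch(html):
    parser = _OgSketch()
    parser.feed(html)
    # Text runs separated by markup/comments become separate paragraphs,
    # matching the "\n\n"-joined description asserted in test_comment2.
    return {"og:title": parser.title, "og:description": "\n\n".join(parser.chunks)}
```

Note how a comment in the middle of the body splits the text into two chunks, which is why `test_comment2` expects the paragraphs joined with blank lines rather than one flattened string.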
18 | 18 | from synapse.util.caches.lrucache import LruCache |
19 | 19 | from synapse.util.caches.treecache import TreeCache |
20 | 20 | |
21 | from mock import Mock | |
22 | ||
21 | 23 | |
22 | 24 | class LruCacheTestCase(unittest.TestCase): |
23 | 25 | |
47 | 49 | self.assertEquals(cache.get("key"), 1) |
48 | 50 | self.assertEquals(cache.setdefault("key", 2), 1) |
49 | 51 | self.assertEquals(cache.get("key"), 1) |
52 | cache["key"] = 2 # Make sure overriding works. | |
53 | self.assertEquals(cache.get("key"), 2) | |
50 | 54 | |
51 | 55 | def test_pop(self): |
52 | 56 | cache = LruCache(1) |
78 | 82 | cache["key"] = 1 |
79 | 83 | cache.clear() |
80 | 84 | self.assertEquals(len(cache), 0) |
85 | ||
86 | ||
87 | class LruCacheCallbacksTestCase(unittest.TestCase): | |
88 | def test_get(self): | |
89 | m = Mock() | |
90 | cache = LruCache(1) | |
91 | ||
92 | cache.set("key", "value") | |
93 | self.assertFalse(m.called) | |
94 | ||
95 | cache.get("key", callback=m) | |
96 | self.assertFalse(m.called) | |
97 | ||
98 | cache.get("key", "value") | |
99 | self.assertFalse(m.called) | |
100 | ||
101 | cache.set("key", "value2") | |
102 | self.assertEquals(m.call_count, 1) | |
103 | ||
104 | cache.set("key", "value") | |
105 | self.assertEquals(m.call_count, 1) | |
106 | ||
107 | def test_multi_get(self): | |
108 | m = Mock() | |
109 | cache = LruCache(1) | |
110 | ||
111 | cache.set("key", "value") | |
112 | self.assertFalse(m.called) | |
113 | ||
114 | cache.get("key", callback=m) | |
115 | self.assertFalse(m.called) | |
116 | ||
117 | cache.get("key", callback=m) | |
118 | self.assertFalse(m.called) | |
119 | ||
120 | cache.set("key", "value2") | |
121 | self.assertEquals(m.call_count, 1) | |
122 | ||
123 | cache.set("key", "value") | |
124 | self.assertEquals(m.call_count, 1) | |
125 | ||
126 | def test_set(self): | |
127 | m = Mock() | |
128 | cache = LruCache(1) | |
129 | ||
130 | cache.set("key", "value", m) | |
131 | self.assertFalse(m.called) | |
132 | ||
133 | cache.set("key", "value") | |
134 | self.assertFalse(m.called) | |
135 | ||
136 | cache.set("key", "value2") | |
137 | self.assertEquals(m.call_count, 1) | |
138 | ||
139 | cache.set("key", "value") | |
140 | self.assertEquals(m.call_count, 1) | |
141 | ||
142 | def test_pop(self): | |
143 | m = Mock() | |
144 | cache = LruCache(1) | |
145 | ||
146 | cache.set("key", "value", m) | |
147 | self.assertFalse(m.called) | |
148 | ||
149 | cache.pop("key") | |
150 | self.assertEquals(m.call_count, 1) | |
151 | ||
152 | cache.set("key", "value") | |
153 | self.assertEquals(m.call_count, 1) | |
154 | ||
155 | cache.pop("key") | |
156 | self.assertEquals(m.call_count, 1) | |
157 | ||
158 | def test_del_multi(self): | |
159 | m1 = Mock() | |
160 | m2 = Mock() | |
161 | m3 = Mock() | |
162 | m4 = Mock() | |
163 | cache = LruCache(4, 2, cache_type=TreeCache) | |
164 | ||
165 | cache.set(("a", "1"), "value", m1) | |
166 | cache.set(("a", "2"), "value", m2) | |
167 | cache.set(("b", "1"), "value", m3) | |
168 | cache.set(("b", "2"), "value", m4) | |
169 | ||
170 | self.assertEquals(m1.call_count, 0) | |
171 | self.assertEquals(m2.call_count, 0) | |
172 | self.assertEquals(m3.call_count, 0) | |
173 | self.assertEquals(m4.call_count, 0) | |
174 | ||
175 | cache.del_multi(("a",)) | |
176 | ||
177 | self.assertEquals(m1.call_count, 1) | |
178 | self.assertEquals(m2.call_count, 1) | |
179 | self.assertEquals(m3.call_count, 0) | |
180 | self.assertEquals(m4.call_count, 0) | |
181 | ||
182 | def test_clear(self): | |
183 | m1 = Mock() | |
184 | m2 = Mock() | |
185 | cache = LruCache(5) | |
186 | ||
187 | cache.set("key1", "value", m1) | |
188 | cache.set("key2", "value", m2) | |
189 | ||
190 | self.assertEquals(m1.call_count, 0) | |
191 | self.assertEquals(m2.call_count, 0) | |
192 | ||
193 | cache.clear() | |
194 | ||
195 | self.assertEquals(m1.call_count, 1) | |
196 | self.assertEquals(m2.call_count, 1) | |
197 | ||
198 | def test_eviction(self): | |
199 | m1 = Mock(name="m1") | |
200 | m2 = Mock(name="m2") | |
201 | m3 = Mock(name="m3") | |
202 | cache = LruCache(2) | |
203 | ||
204 | cache.set("key1", "value", m1) | |
205 | cache.set("key2", "value", m2) | |
206 | ||
207 | self.assertEquals(m1.call_count, 0) | |
208 | self.assertEquals(m2.call_count, 0) | |
209 | self.assertEquals(m3.call_count, 0) | |
210 | ||
211 | cache.set("key3", "value", m3) | |
212 | ||
213 | self.assertEquals(m1.call_count, 1) | |
214 | self.assertEquals(m2.call_count, 0) | |
215 | self.assertEquals(m3.call_count, 0) | |
216 | ||
217 | cache.set("key3", "value") | |
218 | ||
219 | self.assertEquals(m1.call_count, 1) | |
220 | self.assertEquals(m2.call_count, 0) | |
221 | self.assertEquals(m3.call_count, 0) | |
222 | ||
223 | cache.get("key2") | |
224 | ||
225 | self.assertEquals(m1.call_count, 1) | |
226 | self.assertEquals(m2.call_count, 0) | |
227 | self.assertEquals(m3.call_count, 0) | |
228 | ||
229 | cache.set("key1", "value", m1) | |
230 | ||
231 | self.assertEquals(m1.call_count, 1) | |
232 | self.assertEquals(m2.call_count, 0) | |
233 | self.assertEquals(m3.call_count, 1) |